diff --git a/data/transform/check_sequence_length_transform.py b/data/transform/check_sequence_length_transform.py index 79ff33986684a00968d66ee9d1ee556ff791df48..961610ad867f8009ae1e25865efdce1ec2039974 100644 --- a/data/transform/check_sequence_length_transform.py +++ b/data/transform/check_sequence_length_transform.py @@ -1,6 +1,3 @@ -import torch - - class CheckSequenceLenghtTransform(): # TODO: make this maps actually properties of the embedding class or decouple them from this module @@ -67,26 +64,9 @@ class CheckSequenceLenghtTransform(): for i in range(min(len(self.substitution_level), max_depth)): sub_diff = self.substitution_level[i] conv_fac = self.convolution_factor[i] - dep_level = i + 1 - - if sub_diff == 0: - num_vectors = torch.sum(torch.from_numpy(dep) == dep_level) // conv_fac - elif sub_diff == 1: - val_1 = torch.from_numpy(val)[torch.from_numpy(dep) == (dep_level - 1)] - num_vectors = (val_1.view(-1, conv_fac) == 2).max(dim=-1)[0].sum() - elif sub_diff == 2: - val_1 = torch.from_numpy(val)[torch.from_numpy(dep) == (dep_level - 1)] - val_2 = torch.from_numpy(val)[torch.from_numpy(dep) == (dep_level - 2)] - mask_1 = (val_1.view(-1, 8) == 2).max(dim=-1)[0] - mask_2 = torch.zeros_like(val_2, dtype=torch.bool) - mask_2[val_2 == 2] = mask_1 - mask_2 = mask_2.view(-1, conv_fac).max(dim=-1)[0] - num_vectors = mask_2.sum() - else: - print("ERROR: substitution factors bigger than 2 are not implemented") - return None + dep_level = i + 1 - sub_diff - sum_sequence_length += num_vectors + sum_sequence_length += len(dep[dep == dep_level]) // conv_fac if sum_sequence_length > self.num_positions: return None diff --git a/modules/generative_head/composite_head_A.py b/modules/generative_head/composite_head_A.py index d10911e2c4c07b129101634802d5c40daac17d63..09a470984c3df728817bc4f8b3ae4e9efa185151 100644 --- a/modules/generative_head/composite_head_A.py +++ b/modules/generative_head/composite_head_A.py @@ -107,10 +107,7 @@ class CompositeHeadA(nn.Module): layer_dep = torch.cat([dep[dep == (layer_depth - 1)], dep[dep == layer_depth]]) layer_pos = torch.cat([pos[dep == (layer_depth - 1)], pos[dep == layer_depth]]) # compute number of vectors in latent vector of current layer - val_1 = val[dep == (layer_depth - 1)] - num_vectors = (val_1.view(-1, self.reduction_factor[layer_depth]) == 2).max(dim=-1)[0].sum() - - # num_vectors = torch.sum(dep == (layer_depth - 1)) // self.reduction_factor[layer_depth] + num_vectors = torch.sum(dep == (layer_depth - 1)) // self.reduction_factor[layer_depth] elif layer_depth in (7, 8): # handle substitution # get value, depth and position sequence of previous and current layer layer_val = torch.cat( @@ -135,14 +132,7 @@ class CompositeHeadA(nn.Module): ] ) # compute number of vectors in latent vector of current layer - val_1 = val[dep == (layer_depth - 1)] - val_2 = val[dep == (layer_depth - 2)] - mask_1 = (val_1.view(-1, 8) == 2).max(dim=-1)[0] - mask_2 = torch.zeros_like(val_2, dtype=torch.bool) - mask_2[val_2 == 2] = mask_1 - num_vectors = mask_2.view(-1, self.reduction_factor[layer_depth]).max(dim=-1)[0].sum() - - # num_vectors = torch.sum(dep == (layer_depth - 2)) // self.reduction_factor[layer_depth] + num_vectors = torch.sum(dep == (layer_depth - 2)) // self.reduction_factor[layer_depth] # filter latent vector of current layer layer_vec = latent_vec[vector_idx:vector_idx + num_vectors] diff --git a/modules/generative_head/composite_head_B.py b/modules/generative_head/composite_head_B.py index 635ac11267efb8aabbccf4517d3697c5d9e17c7d..14133ff5fd786295d762648f4526474f403c58b5 100644 --- a/modules/generative_head/composite_head_B.py +++ b/modules/generative_head/composite_head_B.py @@ -100,10 +100,7 @@ class CompositeHeadB(nn.Module): layer_dep = torch.cat([dep[dep == (layer_depth - 1)], dep[dep == layer_depth]]) layer_pos = torch.cat([pos[dep == (layer_depth - 1)], pos[dep == layer_depth]]) # compute number of vectors in latent vector of current layer - val_1 = val[dep == (layer_depth - 1)] - num_vectors = (val_1.view(-1, self.reduction_factor[layer_depth]) == 2).max(dim=-1)[0].sum() - - # num_vectors = torch.sum(dep == (layer_depth - 1)) // self.reduction_factor[layer_depth] + num_vectors = torch.sum(dep == (layer_depth - 1)) // self.reduction_factor[layer_depth] elif layer_depth in (7, 8): # handle substitution # get value, depth and position sequence of previous and current layer layer_val = torch.cat( @@ -128,14 +125,7 @@ class CompositeHeadB(nn.Module): ] ) # compute number of vectors in latent vector of current layer - val_1 = val[dep == (layer_depth - 1)] - val_2 = val[dep == (layer_depth - 2)] - mask_1 = (val_1.view(-1, 8) == 2).max(dim=-1)[0] - mask_2 = torch.zeros_like(val_2, dtype=torch.bool) - mask_2[val_2 == 2] = mask_1 - num_vectors = mask_2.view(-1, self.reduction_factor[layer_depth]).max(dim=-1)[0].sum() - - # num_vectors = torch.sum(dep == (layer_depth - 2)) // self.reduction_factor[layer_depth] + num_vectors = torch.sum(dep == (layer_depth - 2)) // self.reduction_factor[layer_depth] # filter latent vector of current layer layer_vec = latent_vec[vector_idx:vector_idx + num_vectors] diff --git a/modules/generative_head/composite_head_C.py b/modules/generative_head/composite_head_C.py index cfd4e71bace3909ad76e4f76ca2fa246743293a2..39f58f02aa1e86a4aaba49324a61b92698ebdc06 100644 --- a/modules/generative_head/composite_head_C.py +++ b/modules/generative_head/composite_head_C.py @@ -100,10 +100,7 @@ class CompositeHeadC(nn.Module): layer_dep = torch.cat([dep[dep == (layer_depth - 1)], dep[dep == layer_depth]]) layer_pos = torch.cat([pos[dep == (layer_depth - 1)], pos[dep == layer_depth]]) # compute number of vectors in latent vector of current layer - val_1 = val[dep == (layer_depth - 1)] - num_vectors = (val_1.view(-1, self.reduction_factor[layer_depth]) == 2).max(dim=-1)[0].sum() - - # num_vectors = torch.sum(dep == (layer_depth - 1)) // self.reduction_factor[layer_depth] + num_vectors = torch.sum(dep == (layer_depth - 1)) // self.reduction_factor[layer_depth] elif layer_depth in (7, 8): # handle double substitution # get value, depth and position sequence of previous and current layer layer_val = torch.cat( @@ -127,15 +124,8 @@ class CompositeHeadC(nn.Module): pos[dep == layer_depth], ] ) - val_1 = val[dep == (layer_depth - 1)] - val_2 = val[dep == (layer_depth - 2)] - mask_1 = (val_1.view(-1, 8) == 2).max(dim=-1)[0] - mask_2 = torch.zeros_like(val_2, dtype=torch.bool) - mask_2[val_2 == 2] = mask_1 - num_vectors = mask_2.view(-1, self.reduction_factor[layer_depth]).max(dim=-1)[0].sum() - # compute number of vectors in latent vector of current layer - # num_vectors = torch.sum(dep == (layer_depth - 2)) // self.reduction_factor[layer_depth] + num_vectors = torch.sum(dep == (layer_depth - 2)) // self.reduction_factor[layer_depth] # filter latent vector of current layer layer_vec = latent_vec[vector_idx:vector_idx + num_vectors] diff --git a/modules/generative_head/composite_head_D.py b/modules/generative_head/composite_head_D.py index ae09651393aabe4e7b578a551e95cbdcf3e301a8..480e6a6aa85b1f9025be3b220df5d28595f5db97 100644 --- a/modules/generative_head/composite_head_D.py +++ b/modules/generative_head/composite_head_D.py @@ -107,10 +107,7 @@ class CompositeHeadD(nn.Module): layer_dep = torch.cat([dep[dep == (layer_depth - 1)], dep[dep == layer_depth]]) layer_pos = torch.cat([pos[dep == (layer_depth - 1)], pos[dep == layer_depth]]) # compute number of vectors in latent vector of current layer - val_1 = val[dep == (layer_depth - 1)] - num_vectors = (val_1.view(-1, self.reduction_factor[layer_depth]) == 2).max(dim=-1)[0].sum() - - # num_vectors = torch.sum(dep == (layer_depth - 1)) // self.reduction_factor[layer_depth] + num_vectors = torch.sum(dep == (layer_depth - 1)) // self.reduction_factor[layer_depth] elif layer_depth in (7, 8): # handle double substitution # get value, depth and position sequence of previous and current layer layer_val = torch.cat( @@ -134,15 +131,8 @@ class CompositeHeadD(nn.Module): pos[dep == layer_depth], ] ) - val_1 = val[dep == (layer_depth - 1)] - val_2 = val[dep == (layer_depth - 2)] - mask_1 = (val_1.view(-1, 8) == 2).max(dim=-1)[0] - mask_2 = torch.zeros_like(val_2, dtype=torch.bool) - mask_2[val_2 == 2] = mask_1 - num_vectors = mask_2.view(-1, self.reduction_factor[layer_depth]).max(dim=-1)[0].sum() - # compute number of vectors in latent vector of current layer - # num_vectors = torch.sum(dep == (layer_depth - 2)) // self.reduction_factor[layer_depth] + num_vectors = torch.sum(dep == (layer_depth - 2)) // self.reduction_factor[layer_depth] # filter latent vector of current layer layer_vec = latent_vec[vector_idx:vector_idx + num_vectors] diff --git a/modules/generative_head/double_substitution_head.py b/modules/generative_head/double_substitution_head.py index da187612938c276757b6c6ea78193c4a5befe925..e1d41dda06ac2a57fceb31a122f22c3ccd2cc2fb 100644 --- a/modules/generative_head/double_substitution_head.py +++ b/modules/generative_head/double_substitution_head.py @@ -20,7 +20,6 @@ class DoubleSubstitutionHead(nn.Module): """ super(DoubleSubstitutionHead, self).__init__() self.head_dim = head_dim - self.conv_size = conv_size deconvolution_2 = [nn.GELU(), Deconvolution(embed_dim, head_dim, conv_size)] for i in range(n_layer - 1): @@ -115,14 +114,6 @@ class DoubleSubstitutionHead(nn.Module): # substitute all mixed token embeddings of third to last layer, with token embeddings of penultimate layer emb_2[val_2 == 2] = self.down_convolution_1(emb_1) # [N, T1, C] - # filter out all tokens, that do not have any descendants in last layer - mask_1 = (val_1.view(batch_size, -1, 8) == 2).max(dim=-1)[0] - mask_2 = torch.zeros_like(val_2, dtype=torch.bool) - mask_2[val_2 == 2] = mask_1 - mask_2 = mask_2.view(batch_size, -1, self.conv_size).max(dim=-1)[0] - emb_2 = emb_2.view(batch_size, -1, self.conv_size, self.head_dim)[mask_2].view(batch_size, -1, self.head_dim) - val_2 = val_2.view(batch_size, -1, self.conv_size)[mask_2].view(batch_size, -1) - emb_0 = self.convolution_0(emb_0[:, :mix_1 * 8]) emb_1 = self.convolution_1(emb_1) emb_2 = self.convolution_2(emb_2) diff --git a/modules/generative_head/substitution_head.py b/modules/generative_head/substitution_head.py index 9fe20d44ee572c92c09434a2ceada85350587752..b4a7fd63eb91ec24377aeb99c65aba53c11bbca5 100644 --- a/modules/generative_head/substitution_head.py +++ b/modules/generative_head/substitution_head.py @@ -20,7 +20,6 @@ class SubstitutionHead(nn.Module): """ super(SubstitutionHead, self).__init__() self.head_dim = head_dim - self.conv_size = conv_size deconvolution_1 = [nn.GELU(), Deconvolution(embed_dim, head_dim, conv_size)] for i in range(n_layer - 1): @@ -98,11 +97,6 @@ class SubstitutionHead(nn.Module): # substitite all mixed token embeddings of penultimate layer, with token embeddings of last layer emb_1[val_1 == 2] = self.down_convolution(emb_0[:, :mix_1 * 8]) # [N, T1, C] - # filter out all tokens, that do not have any descendants in last layer - mask = (val_1.view(batch_size, -1, self.conv_size) == 2).max(dim=-1)[0] - emb_1 = emb_1.view(batch_size, -1, self.conv_size, self.head_dim)[mask].view(batch_size, -1, self.head_dim) - val_1 = val_1.view(batch_size, -1, self.conv_size)[mask].view(batch_size, -1) - emb_0 = self.convolution_0(emb_0[:, :mix_1 * 8]) emb_1 = self.convolution_1(emb_1) diff --git a/modules/token_embedding/double_substitution_embedding.py b/modules/token_embedding/double_substitution_embedding.py index bad4581c01c70da381933bdd57177ed65c7a80aa..868a91fad68ca37a01ffd7826ddff5e94f7ec2a3 100644 --- a/modules/token_embedding/double_substitution_embedding.py +++ b/modules/token_embedding/double_substitution_embedding.py @@ -97,6 +97,9 @@ class DoubleSubstitutionEmbedding(nn.Module): dep_0[i, :len_0[i]] = depth[i, len_2[i] + len_1[i]:len_2[i] + len_1[i] + len_0[i]] pos_0[i, :len_0[i]] = position[i, len_2[i] + len_1[i]:len_2[i] + len_1[i] + len_0[i]] + # precompute padding mask + self.mask = padding_mask(val_2[:, ::self.conv_size], device=value.device) # [N, S'_2, E] + # convolute embedded tokens of last layer y_0 = self.convolution_0(x_0) # [N, S'_0, E // 4] # substitite all mixed token embeddings of second-last layer, with token embeddings of last layer @@ -108,25 +111,7 @@ class DoubleSubstitutionEmbedding(nn.Module): x_2[val_2 == 2] = y_1[val_1[:, ::8] != 0] # [N, S_2, E // 2] # convolute substituted tokens of second-last layer - x_out = self.convolution_2(x_2.contiguous()) # [N, S'_2, E] - - # filter out all tokens, that do not have any descendants in last layer - mask_1 = (val_1.view(batch_size, -1, 8) == 2).max(dim=-1)[0] - mask_2 = torch.zeros_like(val_2, dtype=torch.bool) - mask_2[val_2 == 2] = mask_1 - mask_2 = mask_2.view(batch_size, -1, self.conv_size).max(dim=-1)[0] - len_out = torch.max(torch.sum(mask_2, dim=-1)).item() - - x_masked = torch.zeros(batch_size, len_out, embedding.shape[2], dtype=torch.float, device=value.device) - val_masked = torch.zeros((batch_size, len_out), dtype=torch.long, device=value.device) - for i in range(batch_size): - x_masked[i] = x_out[i, mask_2[i].nonzero().squeeze(-1)] - val_masked[i] = val_2[:, ::self.conv_size][i, mask_2[i].nonzero().squeeze(-1)] - - # precompute padding mask - self.mask = padding_mask(val_masked, device=value.device) # [N, S'_2, E] - - return x_masked + return self.convolution_2(x_2.contiguous()) # [N, S'_2, E] def forward(self, value, depth, position): """ Transform sequences into embedding space for the encoder. diff --git a/modules/token_embedding/substitution_embedding.py b/modules/token_embedding/substitution_embedding.py index b41f181d34ebc5d4262108a9dafc472e5a830757..24f15b39fc18df68904b3b8f89362b129c179eca 100644 --- a/modules/token_embedding/substitution_embedding.py +++ b/modules/token_embedding/substitution_embedding.py @@ -82,6 +82,8 @@ class SubstitutionEmbedding(nn.Module): dep_0[i, :len_0[i]] = depth[i, len_1[i]:len_1[i] + len_0[i]] pos_0[i, :len_0[i]] = position[i, len_1[i]:len_1[i] + len_0[i]] + # precompute padding mask + self.mask = padding_mask(val_1[:, ::self.conv_size], device=value.device) # convolve embedded tokens of last layer y_0 = self.convolution_0(x_0) # [N, T2', C] @@ -90,22 +92,7 @@ class SubstitutionEmbedding(nn.Module): x_1[val_1 == 2] = y_0[val_0[:, ::8] != 0] # [N, T1, C] # convolve substituted tokens of penultimate layer - x_out = self.convolution_1(x_1.contiguous()) # [N, T1', E] - - # filter out all tokens, that do not have any descendants in last layer - mask = (val_1.view(batch_size, -1, self.conv_size) == 2).max(dim=-1)[0] - len_out = torch.max(torch.sum(mask, dim=-1)).item() - - x_masked = torch.zeros(batch_size, len_out, embedding.shape[2], dtype=torch.float, device=value.device) - val_masked = torch.zeros((batch_size, len_out), dtype=torch.long, device=value.device) - for i in range(batch_size): - x_masked[i] = x_out[i, mask[i].nonzero().squeeze(-1)] - val_masked[i] = val_1[i, ::self.conv_size][mask[i].nonzero().squeeze(-1)] - - # precompute padding mask - self.mask = padding_mask(val_masked, device=value.device) - - return x_masked + return self.convolution_1(x_1.contiguous()) # [N, T1', E] def forward(self, value, depth, position): """ Transform sequences into embedding space for the encoder.