From 2bd5bc1f4f92e5a6ef4de8b45698358ecb24a50d Mon Sep 17 00:00:00 2001
From: Moritz Ibing <ibing@cs.rwth-aachen.de>
Date: Thu, 24 Feb 2022 09:16:42 +0100
Subject: [PATCH] revert latent filtering

---
 .../check_sequence_length_transform.py        | 24 ++-----------------
 modules/generative_head/composite_head_A.py   | 14 ++---------
 modules/generative_head/composite_head_B.py   | 14 ++---------
 modules/generative_head/composite_head_C.py   | 14 ++---------
 modules/generative_head/composite_head_D.py   | 14 ++---------
 .../double_substitution_head.py               |  9 -------
 modules/generative_head/substitution_head.py  |  6 -----
 .../double_substitution_embedding.py          | 23 ++++--------------
 .../token_embedding/substitution_embedding.py | 19 +++-------------
 9 files changed, 17 insertions(+), 120 deletions(-)

diff --git a/data/transform/check_sequence_length_transform.py b/data/transform/check_sequence_length_transform.py
index 79ff339..961610a 100644
--- a/data/transform/check_sequence_length_transform.py
+++ b/data/transform/check_sequence_length_transform.py
@@ -1,6 +1,3 @@
-import torch
-
-
 class CheckSequenceLenghtTransform():
     # TODO: make this maps actually properties of the embedding class or decouple them from this module
 
@@ -67,26 +64,9 @@ class CheckSequenceLenghtTransform():
         for i in range(min(len(self.substitution_level), max_depth)):
             sub_diff = self.substitution_level[i]
             conv_fac = self.convolution_factor[i]
-            dep_level = i + 1
-
-            if sub_diff == 0:
-                num_vectors = torch.sum(torch.from_numpy(dep) == dep_level) // conv_fac
-            elif sub_diff == 1:
-                val_1 = torch.from_numpy(val)[torch.from_numpy(dep) == (dep_level - 1)]
-                num_vectors = (val_1.view(-1, conv_fac) == 2).max(dim=-1)[0].sum()
-            elif sub_diff == 2:
-                val_1 = torch.from_numpy(val)[torch.from_numpy(dep) == (dep_level - 1)]
-                val_2 = torch.from_numpy(val)[torch.from_numpy(dep) == (dep_level - 2)]
-                mask_1 = (val_1.view(-1, 8) == 2).max(dim=-1)[0]
-                mask_2 = torch.zeros_like(val_2, dtype=torch.bool)
-                mask_2[val_2 == 2] = mask_1
-                mask_2 = mask_2.view(-1, conv_fac).max(dim=-1)[0]
-                num_vectors = mask_2.sum()
-            else:
-                print("ERROR: substitution factors bigger than 2 are not implemented")
-                return None
+            dep_level = i + 1 - sub_diff
 
-            sum_sequence_length += num_vectors
+            sum_sequence_length += len(dep[dep == dep_level]) // conv_fac
 
         if sum_sequence_length > self.num_positions:
             return None
diff --git a/modules/generative_head/composite_head_A.py b/modules/generative_head/composite_head_A.py
index d10911e..09a4709 100644
--- a/modules/generative_head/composite_head_A.py
+++ b/modules/generative_head/composite_head_A.py
@@ -107,10 +107,7 @@ class CompositeHeadA(nn.Module):
                 layer_dep = torch.cat([dep[dep == (layer_depth - 1)], dep[dep == layer_depth]])
                 layer_pos = torch.cat([pos[dep == (layer_depth - 1)], pos[dep == layer_depth]])
                 # compute number of vectors in latent vector of current layer
-                val_1 = val[dep == (layer_depth - 1)]
-                num_vectors = (val_1.view(-1, self.reduction_factor[layer_depth]) == 2).max(dim=-1)[0].sum()
-
-                # num_vectors = torch.sum(dep == (layer_depth - 1)) // self.reduction_factor[layer_depth]
+                num_vectors = torch.sum(dep == (layer_depth - 1)) // self.reduction_factor[layer_depth]
             elif layer_depth in (7, 8):  # handle substitution
                 # get value, depth and position sequence of previous and current layer
                 layer_val = torch.cat(
@@ -135,14 +132,7 @@ class CompositeHeadA(nn.Module):
                     ]
                 )
                 # compute number of vectors in latent vector of current layer
-                val_1 = val[dep == (layer_depth - 1)]
-                val_2 = val[dep == (layer_depth - 2)]
-                mask_1 = (val_1.view(-1, 8) == 2).max(dim=-1)[0]
-                mask_2 = torch.zeros_like(val_2, dtype=torch.bool)
-                mask_2[val_2 == 2] = mask_1
-                num_vectors = mask_2.view(-1, self.reduction_factor[layer_depth]).max(dim=-1)[0].sum()
-
-                # num_vectors = torch.sum(dep == (layer_depth - 2)) // self.reduction_factor[layer_depth]
+                num_vectors = torch.sum(dep == (layer_depth - 2)) // self.reduction_factor[layer_depth]
 
             # filter latent vector of current layer
             layer_vec = latent_vec[vector_idx:vector_idx + num_vectors]
diff --git a/modules/generative_head/composite_head_B.py b/modules/generative_head/composite_head_B.py
index 635ac11..14133ff 100644
--- a/modules/generative_head/composite_head_B.py
+++ b/modules/generative_head/composite_head_B.py
@@ -100,10 +100,7 @@ class CompositeHeadB(nn.Module):
                 layer_dep = torch.cat([dep[dep == (layer_depth - 1)], dep[dep == layer_depth]])
                 layer_pos = torch.cat([pos[dep == (layer_depth - 1)], pos[dep == layer_depth]])
                 # compute number of vectors in latent vector of current layer
-                val_1 = val[dep == (layer_depth - 1)]
-                num_vectors = (val_1.view(-1, self.reduction_factor[layer_depth]) == 2).max(dim=-1)[0].sum()
-
-                # num_vectors = torch.sum(dep == (layer_depth - 1)) // self.reduction_factor[layer_depth]
+                num_vectors = torch.sum(dep == (layer_depth - 1)) // self.reduction_factor[layer_depth]
             elif layer_depth in (7, 8):  # handle substitution
                 # get value, depth and position sequence of previous and current layer
                 layer_val = torch.cat(
@@ -128,14 +125,7 @@ class CompositeHeadB(nn.Module):
                     ]
                 )
                 # compute number of vectors in latent vector of current layer
-                val_1 = val[dep == (layer_depth - 1)]
-                val_2 = val[dep == (layer_depth - 2)]
-                mask_1 = (val_1.view(-1, 8) == 2).max(dim=-1)[0]
-                mask_2 = torch.zeros_like(val_2, dtype=torch.bool)
-                mask_2[val_2 == 2] = mask_1
-                num_vectors = mask_2.view(-1, self.reduction_factor[layer_depth]).max(dim=-1)[0].sum()
-
-                # num_vectors = torch.sum(dep == (layer_depth - 2)) // self.reduction_factor[layer_depth]
+                num_vectors = torch.sum(dep == (layer_depth - 2)) // self.reduction_factor[layer_depth]
 
             # filter latent vector of current layer
             layer_vec = latent_vec[vector_idx:vector_idx + num_vectors]
diff --git a/modules/generative_head/composite_head_C.py b/modules/generative_head/composite_head_C.py
index cfd4e71..39f58f0 100644
--- a/modules/generative_head/composite_head_C.py
+++ b/modules/generative_head/composite_head_C.py
@@ -100,10 +100,7 @@ class CompositeHeadC(nn.Module):
                 layer_dep = torch.cat([dep[dep == (layer_depth - 1)], dep[dep == layer_depth]])
                 layer_pos = torch.cat([pos[dep == (layer_depth - 1)], pos[dep == layer_depth]])
                 # compute number of vectors in latent vector of current layer
-                val_1 = val[dep == (layer_depth - 1)]
-                num_vectors = (val_1.view(-1, self.reduction_factor[layer_depth]) == 2).max(dim=-1)[0].sum()
-
-                # num_vectors = torch.sum(dep == (layer_depth - 1)) // self.reduction_factor[layer_depth]
+                num_vectors = torch.sum(dep == (layer_depth - 1)) // self.reduction_factor[layer_depth]
             elif layer_depth in (7, 8):  # handle double substitution
                 # get value, depth and position sequence of previous and current layer
                 layer_val = torch.cat(
@@ -127,15 +124,8 @@ class CompositeHeadC(nn.Module):
                         pos[dep == layer_depth],
                     ]
                 )
-                val_1 = val[dep == (layer_depth - 1)]
-                val_2 = val[dep == (layer_depth - 2)]
-                mask_1 = (val_1.view(-1, 8) == 2).max(dim=-1)[0]
-                mask_2 = torch.zeros_like(val_2, dtype=torch.bool)
-                mask_2[val_2 == 2] = mask_1
-                num_vectors = mask_2.view(-1, self.reduction_factor[layer_depth]).max(dim=-1)[0].sum()
-
                 # compute number of vectors in latent vector of current layer
-                # num_vectors = torch.sum(dep == (layer_depth - 2)) // self.reduction_factor[layer_depth]
+                num_vectors = torch.sum(dep == (layer_depth - 2)) // self.reduction_factor[layer_depth]
 
             # filter latent vector of current layer
             layer_vec = latent_vec[vector_idx:vector_idx + num_vectors]
diff --git a/modules/generative_head/composite_head_D.py b/modules/generative_head/composite_head_D.py
index ae09651..480e6a6 100644
--- a/modules/generative_head/composite_head_D.py
+++ b/modules/generative_head/composite_head_D.py
@@ -107,10 +107,7 @@ class CompositeHeadD(nn.Module):
                 layer_dep = torch.cat([dep[dep == (layer_depth - 1)], dep[dep == layer_depth]])
                 layer_pos = torch.cat([pos[dep == (layer_depth - 1)], pos[dep == layer_depth]])
                 # compute number of vectors in latent vector of current layer
-                val_1 = val[dep == (layer_depth - 1)]
-                num_vectors = (val_1.view(-1, self.reduction_factor[layer_depth]) == 2).max(dim=-1)[0].sum()
-
-                # num_vectors = torch.sum(dep == (layer_depth - 1)) // self.reduction_factor[layer_depth]
+                num_vectors = torch.sum(dep == (layer_depth - 1)) // self.reduction_factor[layer_depth]
             elif layer_depth in (7, 8):  # handle double substitution
                 # get value, depth and position sequence of previous and current layer
                 layer_val = torch.cat(
@@ -134,15 +131,8 @@ class CompositeHeadD(nn.Module):
                        pos[dep == layer_depth],
                    ]
                )
-                val_1 = val[dep == (layer_depth - 1)]
-                val_2 = val[dep == (layer_depth - 2)]
-                mask_1 = (val_1.view(-1, 8) == 2).max(dim=-1)[0]
-                mask_2 = torch.zeros_like(val_2, dtype=torch.bool)
-                mask_2[val_2 == 2] = mask_1
-                num_vectors = mask_2.view(-1, self.reduction_factor[layer_depth]).max(dim=-1)[0].sum()
-
                 # compute number of vectors in latent vector of current layer
-                # num_vectors = torch.sum(dep == (layer_depth - 2)) // self.reduction_factor[layer_depth]
+                num_vectors = torch.sum(dep == (layer_depth - 2)) // self.reduction_factor[layer_depth]
 
             # filter latent vector of current layer
             layer_vec = latent_vec[vector_idx:vector_idx + num_vectors]
diff --git a/modules/generative_head/double_substitution_head.py b/modules/generative_head/double_substitution_head.py
index da18761..e1d41dd 100644
--- a/modules/generative_head/double_substitution_head.py
+++ b/modules/generative_head/double_substitution_head.py
@@ -20,7 +20,6 @@ class DoubleSubstitutionHead(nn.Module):
         """
         super(DoubleSubstitutionHead, self).__init__()
         self.head_dim = head_dim
-        self.conv_size = conv_size
 
         deconvolution_2 = [nn.GELU(), Deconvolution(embed_dim, head_dim, conv_size)]
         for i in range(n_layer - 1):
@@ -115,14 +114,6 @@ class DoubleSubstitutionHead(nn.Module):
         # substitute all mixed token embeddings of third to last layer, with token embeddings of penultimate layer
         emb_2[val_2 == 2] = self.down_convolution_1(emb_1)  # [N, T1, C]
 
-        # filter out all tokens, that do not have any descendants in last layer
-        mask_1 = (val_1.view(batch_size, -1, 8) == 2).max(dim=-1)[0]
-        mask_2 = torch.zeros_like(val_2, dtype=torch.bool)
-        mask_2[val_2 == 2] = mask_1
-        mask_2 = mask_2.view(batch_size, -1, self.conv_size).max(dim=-1)[0]
-        emb_2 = emb_2.view(batch_size, -1, self.conv_size, self.head_dim)[mask_2].view(batch_size, -1, self.head_dim)
-        val_2 = val_2.view(batch_size, -1, self.conv_size)[mask_2].view(batch_size, -1)
-
         emb_0 = self.convolution_0(emb_0[:, :mix_1 * 8])
         emb_1 = self.convolution_1(emb_1)
         emb_2 = self.convolution_2(emb_2)
diff --git a/modules/generative_head/substitution_head.py b/modules/generative_head/substitution_head.py
index 9fe20d4..b4a7fd6 100644
--- a/modules/generative_head/substitution_head.py
+++ b/modules/generative_head/substitution_head.py
@@ -20,7 +20,6 @@ class SubstitutionHead(nn.Module):
         """
         super(SubstitutionHead, self).__init__()
         self.head_dim = head_dim
-        self.conv_size = conv_size
 
         deconvolution_1 = [nn.GELU(), Deconvolution(embed_dim, head_dim, conv_size)]
         for i in range(n_layer - 1):
@@ -98,11 +97,6 @@ class SubstitutionHead(nn.Module):
         # substitite all mixed token embeddings of penultimate layer, with token embeddings of last layer
         emb_1[val_1 == 2] = self.down_convolution(emb_0[:, :mix_1 * 8])  # [N, T1, C]
 
-        # filter out all tokens, that do not have any descendants in last layer
-        mask = (val_1.view(batch_size, -1, self.conv_size) == 2).max(dim=-1)[0]
-        emb_1 = emb_1.view(batch_size, -1, self.conv_size, self.head_dim)[mask].view(batch_size, -1, self.head_dim)
-        val_1 = val_1.view(batch_size, -1, self.conv_size)[mask].view(batch_size, -1)
-
         emb_0 = self.convolution_0(emb_0[:, :mix_1 * 8])
         emb_1 = self.convolution_1(emb_1)
 
diff --git a/modules/token_embedding/double_substitution_embedding.py b/modules/token_embedding/double_substitution_embedding.py
index bad4581..868a91f 100644
--- a/modules/token_embedding/double_substitution_embedding.py
+++ b/modules/token_embedding/double_substitution_embedding.py
@@ -97,6 +97,9 @@ class DoubleSubstitutionEmbedding(nn.Module):
             dep_0[i, :len_0[i]] = depth[i, len_2[i] + len_1[i]:len_2[i] + len_1[i] + len_0[i]]
             pos_0[i, :len_0[i]] = position[i, len_2[i] + len_1[i]:len_2[i] + len_1[i] + len_0[i]]
 
+        # precompute padding mask
+        self.mask = padding_mask(val_2[:, ::self.conv_size], device=value.device)  # [N, S'_2, E]
+
         # convolute embedded tokens of last layer
         y_0 = self.convolution_0(x_0)  # [N, S'_0, E // 4]
         # substitite all mixed token embeddings of second-last layer, with token embeddings of last layer
@@ -105,25 +108,7 @@ class DoubleSubstitutionEmbedding(nn.Module):
         x_1[val_1 == 2] = y_0[val_0[:, ::8] != 0]  # [N, S_1, E // 2]
         x_2[val_2 == 2] = y_1[val_1[:, ::8] != 0]  # [N, S_2, E // 2]
         # convolute substituted tokens of second-last layer
-        x_out = self.convolution_2(x_2.contiguous())  # [N, S'_2, E]
-
-        # filter out all tokens, that do not have any descendants in last layer
-        mask_1 = (val_1.view(batch_size, -1, 8) == 2).max(dim=-1)[0]
-        mask_2 = torch.zeros_like(val_2, dtype=torch.bool)
-        mask_2[val_2 == 2] = mask_1
-        mask_2 = mask_2.view(batch_size, -1, self.conv_size).max(dim=-1)[0]
-        len_out = torch.max(torch.sum(mask_2, dim=-1)).item()
-
-        x_masked = torch.zeros(batch_size, len_out, embedding.shape[2], dtype=torch.float, device=value.device)
-        val_masked = torch.zeros((batch_size, len_out), dtype=torch.long, device=value.device)
-        for i in range(batch_size):
-            x_masked[i] = x_out[i, mask_2[i].nonzero().squeeze(-1)]
-            val_masked[i] = val_2[:, ::self.conv_size][i, mask_2[i].nonzero().squeeze(-1)]
-
-        # precompute padding mask
-        self.mask = padding_mask(val_masked, device=value.device)  # [N, S'_2, E]
-
-        return x_masked
+        return self.convolution_2(x_2.contiguous())  # [N, S'_2, E]
 
     def forward(self, value, depth, position):
         """ Transform sequences into embedding space for the encoder.
diff --git a/modules/token_embedding/substitution_embedding.py b/modules/token_embedding/substitution_embedding.py
index b41f181..24f15b3 100644
--- a/modules/token_embedding/substitution_embedding.py
+++ b/modules/token_embedding/substitution_embedding.py
@@ -82,6 +82,8 @@ class SubstitutionEmbedding(nn.Module):
             dep_0[i, :len_0[i]] = depth[i, len_1[i]:len_1[i] + len_0[i]]
             pos_0[i, :len_0[i]] = position[i, len_1[i]:len_1[i] + len_0[i]]
 
+        # precompute padding mask
+        self.mask = padding_mask(val_1[:, ::self.conv_size], device=value.device)
         # convolve embedded tokens of last layer
         y_0 = self.convolution_0(x_0)  # [N, T2', C]
 
@@ -90,22 +92,7 @@ class SubstitutionEmbedding(nn.Module):
         # substitite all mixed token embeddings of penultimate layer, with token embeddings of last layer
         x_1[val_1 == 2] = y_0[val_0[:, ::8] != 0]  # [N, T1, C]
         # convolve substituted tokens of penultimate layer
-        x_out = self.convolution_1(x_1.contiguous())  # [N, T1', E]
-
-        # filter out all tokens, that do not have any descendants in last layer
-        mask = (val_1.view(batch_size, -1, self.conv_size) == 2).max(dim=-1)[0]
-        len_out = torch.max(torch.sum(mask, dim=-1)).item()
-
-        x_masked = torch.zeros(batch_size, len_out, embedding.shape[2], dtype=torch.float, device=value.device)
-        val_masked = torch.zeros((batch_size, len_out), dtype=torch.long, device=value.device)
-        for i in range(batch_size):
-            x_masked[i] = x_out[i, mask[i].nonzero().squeeze(-1)]
-            val_masked[i] = val_1[i, ::self.conv_size][mask[i].nonzero().squeeze(-1)]
-
-        # precompute padding mask
-        self.mask = padding_mask(val_masked, device=value.device)
-
-        return x_masked
+        return self.convolution_1(x_1.contiguous())  # [N, T1', E]
 
     def forward(self, value, depth, position):
         """ Transform sequences into embedding space for the encoder.
-- 
GitLab