From 2bd5bc1f4f92e5a6ef4de8b45698358ecb24a50d Mon Sep 17 00:00:00 2001
From: Moritz Ibing <ibing@cs.rwth-aachen.de>
Date: Thu, 24 Feb 2022 09:16:42 +0100
Subject: [PATCH] revert latent filtering

---
 .../check_sequence_length_transform.py        | 24 ++-----------------
 modules/generative_head/composite_head_A.py   | 14 ++---------
 modules/generative_head/composite_head_B.py   | 14 ++---------
 modules/generative_head/composite_head_C.py   | 14 ++---------
 modules/generative_head/composite_head_D.py   | 14 ++---------
 .../double_substitution_head.py               |  9 -------
 modules/generative_head/substitution_head.py  |  6 -----
 .../double_substitution_embedding.py          | 23 ++++--------------
 .../token_embedding/substitution_embedding.py | 19 +++------------
 9 files changed, 17 insertions(+), 120 deletions(-)

diff --git a/data/transform/check_sequence_length_transform.py b/data/transform/check_sequence_length_transform.py
index 79ff339..961610a 100644
--- a/data/transform/check_sequence_length_transform.py
+++ b/data/transform/check_sequence_length_transform.py
@@ -1,6 +1,3 @@
-import torch
-
-
 class CheckSequenceLenghtTransform():
 
     # TODO: make this maps actually properties of the embedding class or decouple them from this module
@@ -67,26 +64,9 @@ class CheckSequenceLenghtTransform():
         for i in range(min(len(self.substitution_level), max_depth)):
             sub_diff = self.substitution_level[i]
             conv_fac = self.convolution_factor[i]
-            dep_level = i + 1
-
-            if sub_diff == 0:
-                num_vectors = torch.sum(torch.from_numpy(dep) == dep_level) // conv_fac
-            elif sub_diff == 1:
-                val_1 = torch.from_numpy(val)[torch.from_numpy(dep) == (dep_level - 1)]
-                num_vectors = (val_1.view(-1, conv_fac) == 2).max(dim=-1)[0].sum()
-            elif sub_diff == 2:
-                val_1 = torch.from_numpy(val)[torch.from_numpy(dep) == (dep_level - 1)]
-                val_2 = torch.from_numpy(val)[torch.from_numpy(dep) == (dep_level - 2)]
-                mask_1 = (val_1.view(-1, 8) == 2).max(dim=-1)[0]
-                mask_2 = torch.zeros_like(val_2, dtype=torch.bool)
-                mask_2[val_2 == 2] = mask_1
-                mask_2 = mask_2.view(-1, conv_fac).max(dim=-1)[0]
-                num_vectors = mask_2.sum()
-            else:
-                print("ERROR: substitution factors bigger than 2 are not implemented")
-                return None
+            dep_level = i + 1 - sub_diff
 
-            sum_sequence_length += num_vectors
+            sum_sequence_length += len(dep[dep == dep_level]) // conv_fac
 
         if sum_sequence_length > self.num_positions:
             return None
diff --git a/modules/generative_head/composite_head_A.py b/modules/generative_head/composite_head_A.py
index d10911e..09a4709 100644
--- a/modules/generative_head/composite_head_A.py
+++ b/modules/generative_head/composite_head_A.py
@@ -107,10 +107,7 @@ class CompositeHeadA(nn.Module):
                     layer_dep = torch.cat([dep[dep == (layer_depth - 1)], dep[dep == layer_depth]])
                     layer_pos = torch.cat([pos[dep == (layer_depth - 1)], pos[dep == layer_depth]])
                     # compute number of vectors in latent vector of current layer
-                    val_1 = val[dep == (layer_depth - 1)]
-                    num_vectors = (val_1.view(-1, self.reduction_factor[layer_depth]) == 2).max(dim=-1)[0].sum()
-
-                    # num_vectors = torch.sum(dep == (layer_depth - 1)) // self.reduction_factor[layer_depth]
+                    num_vectors = torch.sum(dep == (layer_depth - 1)) // self.reduction_factor[layer_depth]
                 elif layer_depth in (7, 8):  # handle substitution
                     # get value, depth and position sequence of previous and current layer
                     layer_val = torch.cat(
@@ -135,14 +132,7 @@ class CompositeHeadA(nn.Module):
                         ]
                     )
                     # compute number of vectors in latent vector of current layer
-                    val_1 = val[dep == (layer_depth - 1)]
-                    val_2 = val[dep == (layer_depth - 2)]
-                    mask_1 = (val_1.view(-1, 8) == 2).max(dim=-1)[0]
-                    mask_2 = torch.zeros_like(val_2, dtype=torch.bool)
-                    mask_2[val_2 == 2] = mask_1
-                    num_vectors = mask_2.view(-1, self.reduction_factor[layer_depth]).max(dim=-1)[0].sum()
-
-                    # num_vectors = torch.sum(dep == (layer_depth - 2)) // self.reduction_factor[layer_depth]
+                    num_vectors = torch.sum(dep == (layer_depth - 2)) // self.reduction_factor[layer_depth]
 
                 # filter latent vector of current layer
                 layer_vec = latent_vec[vector_idx:vector_idx + num_vectors]
diff --git a/modules/generative_head/composite_head_B.py b/modules/generative_head/composite_head_B.py
index 635ac11..14133ff 100644
--- a/modules/generative_head/composite_head_B.py
+++ b/modules/generative_head/composite_head_B.py
@@ -100,10 +100,7 @@ class CompositeHeadB(nn.Module):
                     layer_dep = torch.cat([dep[dep == (layer_depth - 1)], dep[dep == layer_depth]])
                     layer_pos = torch.cat([pos[dep == (layer_depth - 1)], pos[dep == layer_depth]])
                     # compute number of vectors in latent vector of current layer
-                    val_1 = val[dep == (layer_depth - 1)]
-                    num_vectors = (val_1.view(-1, self.reduction_factor[layer_depth]) == 2).max(dim=-1)[0].sum()
-
-                    # num_vectors = torch.sum(dep == (layer_depth - 1)) // self.reduction_factor[layer_depth]
+                    num_vectors = torch.sum(dep == (layer_depth - 1)) // self.reduction_factor[layer_depth]
                 elif layer_depth in (7, 8):  # handle substitution
                     # get value, depth and position sequence of previous and current layer
                     layer_val = torch.cat(
@@ -128,14 +125,7 @@ class CompositeHeadB(nn.Module):
                         ]
                     )
                     # compute number of vectors in latent vector of current layer
-                    val_1 = val[dep == (layer_depth - 1)]
-                    val_2 = val[dep == (layer_depth - 2)]
-                    mask_1 = (val_1.view(-1, 8) == 2).max(dim=-1)[0]
-                    mask_2 = torch.zeros_like(val_2, dtype=torch.bool)
-                    mask_2[val_2 == 2] = mask_1
-                    num_vectors = mask_2.view(-1, self.reduction_factor[layer_depth]).max(dim=-1)[0].sum()
-
-                    # num_vectors = torch.sum(dep == (layer_depth - 2)) // self.reduction_factor[layer_depth]
+                    num_vectors = torch.sum(dep == (layer_depth - 2)) // self.reduction_factor[layer_depth]
 
                 # filter latent vector of current layer
                 layer_vec = latent_vec[vector_idx:vector_idx + num_vectors]
diff --git a/modules/generative_head/composite_head_C.py b/modules/generative_head/composite_head_C.py
index cfd4e71..39f58f0 100644
--- a/modules/generative_head/composite_head_C.py
+++ b/modules/generative_head/composite_head_C.py
@@ -100,10 +100,7 @@ class CompositeHeadC(nn.Module):
                     layer_dep = torch.cat([dep[dep == (layer_depth - 1)], dep[dep == layer_depth]])
                     layer_pos = torch.cat([pos[dep == (layer_depth - 1)], pos[dep == layer_depth]])
                     # compute number of vectors in latent vector of current layer
-                    val_1 = val[dep == (layer_depth - 1)]
-                    num_vectors = (val_1.view(-1, self.reduction_factor[layer_depth]) == 2).max(dim=-1)[0].sum()
-
-                    # num_vectors = torch.sum(dep == (layer_depth - 1)) // self.reduction_factor[layer_depth]
+                    num_vectors = torch.sum(dep == (layer_depth - 1)) // self.reduction_factor[layer_depth]
                 elif layer_depth in (7, 8):  # handle double substitution
                     # get value, depth and position sequence of previous and current layer
                     layer_val = torch.cat(
@@ -127,15 +124,8 @@ class CompositeHeadC(nn.Module):
                             pos[dep == layer_depth],
                         ]
                     )
-                    val_1 = val[dep == (layer_depth - 1)]
-                    val_2 = val[dep == (layer_depth - 2)]
-                    mask_1 = (val_1.view(-1, 8) == 2).max(dim=-1)[0]
-                    mask_2 = torch.zeros_like(val_2, dtype=torch.bool)
-                    mask_2[val_2 == 2] = mask_1
-                    num_vectors = mask_2.view(-1, self.reduction_factor[layer_depth]).max(dim=-1)[0].sum()
-
                     # compute number of vectors in latent vector of current layer
-                    # num_vectors = torch.sum(dep == (layer_depth - 2)) // self.reduction_factor[layer_depth]
+                    num_vectors = torch.sum(dep == (layer_depth - 2)) // self.reduction_factor[layer_depth]
 
                 # filter latent vector of current layer
                 layer_vec = latent_vec[vector_idx:vector_idx + num_vectors]
diff --git a/modules/generative_head/composite_head_D.py b/modules/generative_head/composite_head_D.py
index ae09651..480e6a6 100644
--- a/modules/generative_head/composite_head_D.py
+++ b/modules/generative_head/composite_head_D.py
@@ -107,10 +107,7 @@ class CompositeHeadD(nn.Module):
                     layer_dep = torch.cat([dep[dep == (layer_depth - 1)], dep[dep == layer_depth]])
                     layer_pos = torch.cat([pos[dep == (layer_depth - 1)], pos[dep == layer_depth]])
                     # compute number of vectors in latent vector of current layer
-                    val_1 = val[dep == (layer_depth - 1)]
-                    num_vectors = (val_1.view(-1, self.reduction_factor[layer_depth]) == 2).max(dim=-1)[0].sum()
-
-                    # num_vectors = torch.sum(dep == (layer_depth - 1)) // self.reduction_factor[layer_depth]
+                    num_vectors = torch.sum(dep == (layer_depth - 1)) // self.reduction_factor[layer_depth]
                 elif layer_depth in (7, 8):  # handle double substitution
                     # get value, depth and position sequence of previous and current layer
                     layer_val = torch.cat(
@@ -134,15 +131,8 @@ class CompositeHeadD(nn.Module):
                             pos[dep == layer_depth],
                         ]
                     )
-                    val_1 = val[dep == (layer_depth - 1)]
-                    val_2 = val[dep == (layer_depth - 2)]
-                    mask_1 = (val_1.view(-1, 8) == 2).max(dim=-1)[0]
-                    mask_2 = torch.zeros_like(val_2, dtype=torch.bool)
-                    mask_2[val_2 == 2] = mask_1
-                    num_vectors = mask_2.view(-1, self.reduction_factor[layer_depth]).max(dim=-1)[0].sum()
-
                     # compute number of vectors in latent vector of current layer
-                    # num_vectors = torch.sum(dep == (layer_depth - 2)) // self.reduction_factor[layer_depth]
+                    num_vectors = torch.sum(dep == (layer_depth - 2)) // self.reduction_factor[layer_depth]
 
                 # filter latent vector of current layer
                 layer_vec = latent_vec[vector_idx:vector_idx + num_vectors]
diff --git a/modules/generative_head/double_substitution_head.py b/modules/generative_head/double_substitution_head.py
index da18761..e1d41dd 100644
--- a/modules/generative_head/double_substitution_head.py
+++ b/modules/generative_head/double_substitution_head.py
@@ -20,7 +20,6 @@ class DoubleSubstitutionHead(nn.Module):
         """
         super(DoubleSubstitutionHead, self).__init__()
         self.head_dim = head_dim
-        self.conv_size = conv_size
 
         deconvolution_2 = [nn.GELU(), Deconvolution(embed_dim, head_dim, conv_size)]
         for i in range(n_layer - 1):
@@ -115,14 +114,6 @@ class DoubleSubstitutionHead(nn.Module):
         # substitute all mixed token embeddings of third to last layer, with token embeddings of penultimate layer
         emb_2[val_2 == 2] = self.down_convolution_1(emb_1)  # [N, T1, C]
 
-        # filter out all tokens, that do not have any descendants in last layer
-        mask_1 = (val_1.view(batch_size, -1, 8) == 2).max(dim=-1)[0]
-        mask_2 = torch.zeros_like(val_2, dtype=torch.bool)
-        mask_2[val_2 == 2] = mask_1
-        mask_2 = mask_2.view(batch_size, -1, self.conv_size).max(dim=-1)[0]
-        emb_2 = emb_2.view(batch_size, -1, self.conv_size, self.head_dim)[mask_2].view(batch_size, -1, self.head_dim)
-        val_2 = val_2.view(batch_size, -1, self.conv_size)[mask_2].view(batch_size, -1)
-
         emb_0 = self.convolution_0(emb_0[:, :mix_1 * 8])
         emb_1 = self.convolution_1(emb_1)
         emb_2 = self.convolution_2(emb_2)
diff --git a/modules/generative_head/substitution_head.py b/modules/generative_head/substitution_head.py
index 9fe20d4..b4a7fd6 100644
--- a/modules/generative_head/substitution_head.py
+++ b/modules/generative_head/substitution_head.py
@@ -20,7 +20,6 @@ class SubstitutionHead(nn.Module):
         """
         super(SubstitutionHead, self).__init__()
         self.head_dim = head_dim
-        self.conv_size = conv_size
 
         deconvolution_1 = [nn.GELU(), Deconvolution(embed_dim, head_dim, conv_size)]
         for i in range(n_layer - 1):
@@ -98,11 +97,6 @@ class SubstitutionHead(nn.Module):
         # substitite all mixed token embeddings of penultimate layer, with token embeddings of last layer
         emb_1[val_1 == 2] = self.down_convolution(emb_0[:, :mix_1 * 8])  # [N, T1, C]
 
-        # filter out all tokens, that do not have any descendants in last layer
-        mask = (val_1.view(batch_size, -1, self.conv_size) == 2).max(dim=-1)[0]
-        emb_1 = emb_1.view(batch_size, -1, self.conv_size, self.head_dim)[mask].view(batch_size, -1, self.head_dim)
-        val_1 = val_1.view(batch_size, -1, self.conv_size)[mask].view(batch_size, -1)
-
         emb_0 = self.convolution_0(emb_0[:, :mix_1 * 8])
         emb_1 = self.convolution_1(emb_1)
 
diff --git a/modules/token_embedding/double_substitution_embedding.py b/modules/token_embedding/double_substitution_embedding.py
index bad4581..868a91f 100644
--- a/modules/token_embedding/double_substitution_embedding.py
+++ b/modules/token_embedding/double_substitution_embedding.py
@@ -97,6 +97,9 @@ class DoubleSubstitutionEmbedding(nn.Module):
             dep_0[i, :len_0[i]] = depth[i, len_2[i] + len_1[i]:len_2[i] + len_1[i] + len_0[i]]
             pos_0[i, :len_0[i]] = position[i, len_2[i] + len_1[i]:len_2[i] + len_1[i] + len_0[i]]
 
+        # precompute padding mask
+        self.mask = padding_mask(val_2[:, ::self.conv_size], device=value.device)  # [N, S'_2, E]
+
         # convolute embedded tokens of last layer
         y_0 = self.convolution_0(x_0)  # [N, S'_0, E // 4]
         # substitite all mixed token embeddings of second-last layer, with token embeddings of last layer
@@ -108,25 +111,7 @@ class DoubleSubstitutionEmbedding(nn.Module):
         x_2[val_2 == 2] = y_1[val_1[:, ::8] != 0]  # [N, S_2, E // 2]
 
         # convolute substituted tokens of second-last layer
-        x_out = self.convolution_2(x_2.contiguous())  # [N, S'_2, E]
-
-        # filter out all tokens, that do not have any descendants in last layer
-        mask_1 = (val_1.view(batch_size, -1, 8) == 2).max(dim=-1)[0]
-        mask_2 = torch.zeros_like(val_2, dtype=torch.bool)
-        mask_2[val_2 == 2] = mask_1
-        mask_2 = mask_2.view(batch_size, -1, self.conv_size).max(dim=-1)[0]
-        len_out = torch.max(torch.sum(mask_2, dim=-1)).item()
-
-        x_masked = torch.zeros(batch_size, len_out, embedding.shape[2], dtype=torch.float, device=value.device)
-        val_masked = torch.zeros((batch_size, len_out), dtype=torch.long, device=value.device)
-        for i in range(batch_size):
-            x_masked[i] = x_out[i, mask_2[i].nonzero().squeeze(-1)]
-            val_masked[i] = val_2[:, ::self.conv_size][i, mask_2[i].nonzero().squeeze(-1)]
-
-        # precompute padding mask
-        self.mask = padding_mask(val_masked, device=value.device)  # [N, S'_2, E]
-
-        return x_masked
+        return self.convolution_2(x_2.contiguous())  # [N, S'_2, E]
 
     def forward(self, value, depth, position):
         """ Transform sequences into embedding space for the encoder.
diff --git a/modules/token_embedding/substitution_embedding.py b/modules/token_embedding/substitution_embedding.py
index b41f181..24f15b3 100644
--- a/modules/token_embedding/substitution_embedding.py
+++ b/modules/token_embedding/substitution_embedding.py
@@ -82,6 +82,8 @@ class SubstitutionEmbedding(nn.Module):
             dep_0[i, :len_0[i]] = depth[i, len_1[i]:len_1[i] + len_0[i]]
             pos_0[i, :len_0[i]] = position[i, len_1[i]:len_1[i] + len_0[i]]
 
+        # precompute padding mask
+        self.mask = padding_mask(val_1[:, ::self.conv_size], device=value.device)
 
         # convolve embedded tokens of last layer
         y_0 = self.convolution_0(x_0)  # [N, T2', C]
@@ -90,22 +92,7 @@ class SubstitutionEmbedding(nn.Module):
         x_1[val_1 == 2] = y_0[val_0[:, ::8] != 0]  # [N, T1, C]
 
         # convolve substituted tokens of penultimate layer
-        x_out = self.convolution_1(x_1.contiguous())  # [N, T1', E]
-
-        # filter out all tokens, that do not have any descendants in last layer
-        mask = (val_1.view(batch_size, -1, self.conv_size) == 2).max(dim=-1)[0]
-        len_out = torch.max(torch.sum(mask, dim=-1)).item()
-
-        x_masked = torch.zeros(batch_size, len_out, embedding.shape[2], dtype=torch.float, device=value.device)
-        val_masked = torch.zeros((batch_size, len_out), dtype=torch.long, device=value.device)
-        for i in range(batch_size):
-            x_masked[i] = x_out[i, mask[i].nonzero().squeeze(-1)]
-            val_masked[i] = val_1[i, ::self.conv_size][mask[i].nonzero().squeeze(-1)]
-
-        # precompute padding mask
-        self.mask = padding_mask(val_masked, device=value.device)
-
-        return x_masked
+        return self.convolution_1(x_1.contiguous())  # [N, T1', E]
 
     def forward(self, value, depth, position):
         """ Transform sequences into embedding space for the encoder.
-- 
GitLab