diff --git a/modules/generative_head/double_substitution_head.py b/modules/generative_head/double_substitution_head.py
index 6e42d274376ecb4431dd68202cd12eb85c6cd852..e1d41dda06ac2a57fceb31a122f22c3ccd2cc2fb 100644
--- a/modules/generative_head/double_substitution_head.py
+++ b/modules/generative_head/double_substitution_head.py
@@ -51,6 +51,9 @@ class DoubleSubstitutionHead(nn.Module):
             convolution_0 += [nn.GELU(), BlockConvolution(head_dim, head_dim, 8)]
         self.convolution_0 = nn.Sequential(*convolution_0)
 
+        self.down_convolution_1 = Convolution(head_dim, head_dim, 8)
+        self.down_convolution_0 = Convolution(head_dim, head_dim, 8)
+
         linear = []
         for i in range(n_layer - 1):
             linear += [nn.GELU(), nn.Linear(head_dim, head_dim)]
@@ -102,16 +105,17 @@ class DoubleSubstitutionHead(nn.Module):
         # add spatial decoding if available
         if self.spatial_encoding is not None:
             emb_0 = emb_0 + self.spatial_encoding(pos[:, -len_0:])
-        emb_0 = self.convolution_0(emb_0[:, :mix_1 * 8])
 
         emb_1 = torch.zeros((batch_size, torch.max(len_1), self.head_dim), dtype=torch.float, device=value.device)
         # substitute all mixed token embeddings of penultimate layer, with token embeddings of last layer
-        emb_1[val_1 == 2] = emb_0[:, 7::8]  # [N, T1, C]
-        emb_1 = self.convolution_1(emb_1)
+        emb_1[val_1 == 2] = self.down_convolution_0(emb_0[:, :mix_1 * 8])  # [N, T1, C]
 
         emb_2 = torch.zeros((batch_size, torch.max(len_2), self.head_dim), dtype=torch.float, device=value.device)
         # substitute all mixed token embeddings of third to last layer, with token embeddings of penultimate layer
-        emb_2[val_2 == 2] = emb_1[:, 7::8]  # [N, T1, C]
+        emb_2[val_2 == 2] = self.down_convolution_1(emb_1)  # [N, T1, C]
+
+        emb_0 = self.convolution_0(emb_0[:, :mix_1 * 8])
+        emb_1 = self.convolution_1(emb_1)
         emb_2 = self.convolution_2(emb_2)
 
         # create intermediate list to hold vectors
diff --git a/modules/generative_head/substitution_head.py b/modules/generative_head/substitution_head.py
index e5474bca84787803fb6b34b24c49bf8f98260285..b4a7fd63eb91ec24377aeb99c65aba53c11bbca5 100644
--- a/modules/generative_head/substitution_head.py
+++ b/modules/generative_head/substitution_head.py
@@ -41,6 +41,8 @@ class SubstitutionHead(nn.Module):
             convolution_0 += [nn.GELU(), BlockConvolution(head_dim, head_dim, 8)]
         self.convolution_0 = nn.Sequential(*convolution_0)
 
+        self.down_convolution = Convolution(head_dim, head_dim, 8)
+
         linear = []
         for i in range(n_layer - 1):
             linear += [nn.GELU(), nn.Linear(head_dim, head_dim)]
@@ -90,11 +92,12 @@ class SubstitutionHead(nn.Module):
         # add spatial decoding if available
         if self.spatial_encoding is not None:
             emb_0 = emb_0 + self.spatial_encoding(pos[:, -len_0:])
-        emb_0 = self.convolution_0(emb_0[:, :mix_1 * 8])
 
         emb_1 = torch.zeros((batch_size, torch.max(len_1), self.head_dim), dtype=torch.float, device=value.device)
         # substitute all mixed token embeddings of penultimate layer, with token embeddings of last layer
-        emb_1[val_1 == 2] = emb_0[:, 7::8]  # [N, T1, C]
+        emb_1[val_1 == 2] = self.down_convolution(emb_0[:, :mix_1 * 8])  # [N, T1, C]
+
+        emb_0 = self.convolution_0(emb_0[:, :mix_1 * 8])
         emb_1 = self.convolution_1(emb_1)
 
         x_0 = torch.zeros((batch_size, torch.max(mix_1), self.head_dim), device=value.device)
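Note on the change (a sketch, not part of the diff): the old code substituted coarse-layer embeddings with the strided slice `emb_0[:, 7::8]`, so only the last embedding of each block of 8 child tokens survived the downsampling. The new `down_convolution` modules replace that slice with a learned aggregation over the whole block. Below is a minimal sketch of what `Convolution(head_dim, head_dim, 8)` is assumed to construct, namely a 1-D convolution with kernel size equal to stride over `[N, T, C]` tensors; the constructor signature is inferred from the call sites and the repository's actual module may differ.

```python
import torch
import torch.nn as nn


class Convolution(nn.Module):
    """Sketch of the assumed `Convolution` block: a strided 1-D convolution
    that pools each window of `kernel` tokens into a single output token."""

    def __init__(self, in_dim, out_dim, kernel):
        super().__init__()
        self.conv = nn.Conv1d(in_dim, out_dim, kernel_size=kernel, stride=kernel)

    def forward(self, x):  # x: [N, T, C]
        # Conv1d operates on [N, C, T], so transpose in and back out.
        return self.conv(x.transpose(1, 2)).transpose(1, 2)  # [N, T // kernel, C]


# Compare the old strided slice with the learned pooling on toy data.
N, T, C = 2, 32, 16                   # sequence length must be a multiple of 8
emb_0 = torch.randn(N, T, C)

# Old behaviour: keep only the last embedding of each block of 8.
strided = emb_0[:, 7::8]              # [N, T // 8, C]

# New behaviour: learn an aggregation over all 8 embeddings per block.
down_convolution = Convolution(C, C, 8)
pooled = down_convolution(emb_0)      # [N, T // 8, C]

assert strided.shape == pooled.shape == (N, T // 8, C)
```

Because the substitution targets (`emb_1[val_1 == 2]`, `emb_2[val_2 == 2]`) now receive the pooled output directly, the per-layer `convolution_0`/`convolution_1` blocks are applied afterwards and no longer feed the substitution path.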