Commit d90b6609 authored by Moritz Ibing

Simple fix to pass full information to higher layers

parent aa25ef26
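The change replaces the strided pick emb_0[:, 7::8] — which forwards only the last of every 8 child embeddings to the parent layer — with a learned stride-8 down-convolution over all 8 children, and defers convolution_0 / convolution_1 until after the substitution. Below is a minimal illustrative sketch of that difference; it assumes Convolution(head_dim, head_dim, 8) behaves like a Conv1d with kernel size and stride 8 over the token axis (the repo's own wrapper is not shown here), and all shapes are toy values.

# Illustrative sketch only (not the repo's code): why a stride-8 convolution
# passes fuller information upward than picking every 8th embedding.
import torch
import torch.nn as nn

batch_size, n_parents, head_dim = 2, 3, 16
emb_0 = torch.randn(batch_size, n_parents * 8, head_dim)   # 8 child tokens per parent

# old behaviour: keep only the last child embedding of each group of 8
picked = emb_0[:, 7::8]                                     # [2, 3, 16]

# new behaviour: learn a summary over all 8 children of each parent
down_convolution = nn.Conv1d(head_dim, head_dim, kernel_size=8, stride=8)
aggregated = down_convolution(emb_0.transpose(1, 2)).transpose(1, 2)  # [2, 3, 16]

print(picked.shape, aggregated.shape)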
@@ -51,6 +51,9 @@ class DoubleSubstitutionHead(nn.Module):
             convolution_0 += [nn.GELU(), BlockConvolution(head_dim, head_dim, 8)]
         self.convolution_0 = nn.Sequential(*convolution_0)
+        self.down_convolution_1 = Convolution(head_dim, head_dim, 8)
+        self.down_convolution_0 = Convolution(head_dim, head_dim, 8)
+
         linear = []
         for i in range(n_layer - 1):
             linear += [nn.GELU(), nn.Linear(head_dim, head_dim)]
@@ -102,16 +105,17 @@ class DoubleSubstitutionHead(nn.Module):
         # add spatial decoding if available
         if self.spatial_encoding is not None:
             emb_0 = emb_0 + self.spatial_encoding(pos[:, -len_0:])
-        emb_0 = self.convolution_0(emb_0[:, :mix_1 * 8])
         emb_1 = torch.zeros((batch_size, torch.max(len_1), self.head_dim), dtype=torch.float, device=value.device)
         # substitute all mixed token embeddings of penultimate layer with token embeddings of last layer
-        emb_1[val_1 == 2] = emb_0[:, 7::8]  # [N, T1, C]
-        emb_1 = self.convolution_1(emb_1)
+        emb_1[val_1 == 2] = self.down_convolution_0(emb_0[:, :mix_1 * 8])  # [N, T1, C]
         emb_2 = torch.zeros((batch_size, torch.max(len_2), self.head_dim), dtype=torch.float, device=value.device)
         # substitute all mixed token embeddings of third to last layer with token embeddings of penultimate layer
-        emb_2[val_2 == 2] = emb_1[:, 7::8]  # [N, T1, C]
+        emb_2[val_2 == 2] = self.down_convolution_1(emb_1)  # [N, T1, C]
+        emb_0 = self.convolution_0(emb_0[:, :mix_1 * 8])
+        emb_1 = self.convolution_1(emb_1)
         emb_2 = self.convolution_2(emb_2)
         # create intermediate list to hold vectors
......
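Both forward hunks scatter the down-convolved child summaries into zero-initialised parent tensors using a boolean mask (a value of 2 marks a "mixed" parent token). A minimal sketch of that indexing pattern, with made-up toy values:

# Illustrative only: boolean-mask scatter as used for emb_1 / emb_2 above.
import torch

head_dim = 4
val_1 = torch.tensor([[1, 2, 2, 0],
                      [2, 1, 0, 0]])                 # parent token values, 2 = mixed
emb_1 = torch.zeros(2, 4, head_dim)                  # parent embeddings, initially zero
child_summaries = torch.randn(int((val_1 == 2).sum()), head_dim)  # one vector per mixed parent

emb_1[val_1 == 2] = child_summaries                  # fill only the mixed positions
print(emb_1[val_1 == 2].shape)                       # torch.Size([3, 4])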
@@ -41,6 +41,8 @@ class SubstitutionHead(nn.Module):
             convolution_0 += [nn.GELU(), BlockConvolution(head_dim, head_dim, 8)]
         self.convolution_0 = nn.Sequential(*convolution_0)
+        self.down_convolution = Convolution(head_dim, head_dim, 8)
+
         linear = []
         for i in range(n_layer - 1):
             linear += [nn.GELU(), nn.Linear(head_dim, head_dim)]
@@ -90,11 +92,12 @@ class SubstitutionHead(nn.Module):
         # add spatial decoding if available
         if self.spatial_encoding is not None:
             emb_0 = emb_0 + self.spatial_encoding(pos[:, -len_0:])
-        emb_0 = self.convolution_0(emb_0[:, :mix_1 * 8])
         emb_1 = torch.zeros((batch_size, torch.max(len_1), self.head_dim), dtype=torch.float, device=value.device)
         # substitute all mixed token embeddings of penultimate layer with token embeddings of last layer
-        emb_1[val_1 == 2] = emb_0[:, 7::8]  # [N, T1, C]
+        emb_1[val_1 == 2] = self.down_convolution(emb_0[:, :mix_1 * 8])  # [N, T1, C]
+        emb_0 = self.convolution_0(emb_0[:, :mix_1 * 8])
         emb_1 = self.convolution_1(emb_1)
         x_0 = torch.zeros((batch_size, torch.max(mix_1), self.head_dim), device=value.device)
......
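In both classes the fix also reorders the computation: the summary handed to the next layer is now taken from the raw emb_0 blocks, and convolution_0 is applied afterwards, instead of convolving first and then sampling. A reduced sketch of that reordering (an assumed simplification of the diff, with toy shapes and Identity standing in for the repo's BlockConvolution stack):

# Illustrative only: old vs. new data flow for the parent-layer input.
import torch
import torch.nn as nn

N, mix_1, C = 1, 2, 8
emb_0 = torch.randn(N, mix_1 * 8, C)
convolution_0 = nn.Identity()                       # stand-in for the BlockConvolution stack
down_convolution = nn.Conv1d(C, C, kernel_size=8, stride=8)

# old: convolve first, then keep the last child of each block of 8
old_parent_input = convolution_0(emb_0[:, :mix_1 * 8])[:, 7::8]

# new: summarise the raw children first; convolution_0 runs on emb_0 afterwards
new_parent_input = down_convolution(emb_0[:, :mix_1 * 8].transpose(1, 2)).transpose(1, 2)
emb_0 = convolution_0(emb_0[:, :mix_1 * 8])

print(old_parent_input.shape, new_parent_input.shape)   # both [1, 2, 8]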