diff --git a/modules/generative_head/double_substitution_head.py b/modules/generative_head/double_substitution_head.py
index 6e42d274376ecb4431dd68202cd12eb85c6cd852..e1d41dda06ac2a57fceb31a122f22c3ccd2cc2fb 100644
--- a/modules/generative_head/double_substitution_head.py
+++ b/modules/generative_head/double_substitution_head.py
@@ -51,6 +51,9 @@ class DoubleSubstitutionHead(nn.Module):
             convolution_0 += [nn.GELU(), BlockConvolution(head_dim, head_dim, 8)]
         self.convolution_0 = nn.Sequential(*convolution_0)
 
+        self.down_convolution_1 = Convolution(head_dim, head_dim, 8)
+        self.down_convolution_0 = Convolution(head_dim, head_dim, 8)
+
         linear = []
         for i in range(n_layer - 1):
             linear += [nn.GELU(), nn.Linear(head_dim, head_dim)]
@@ -102,16 +105,17 @@ class DoubleSubstitutionHead(nn.Module):
         # add spatial decoding if available
         if self.spatial_encoding is not None:
             emb_0 = emb_0 + self.spatial_encoding(pos[:, -len_0:])
-        emb_0 = self.convolution_0(emb_0[:, :mix_1 * 8])
 
         emb_1 = torch.zeros((batch_size, torch.max(len_1), self.head_dim), dtype=torch.float, device=value.device)
         # substitute all mixed token embeddings of penultimate layer, with token embeddings of last layer
-        emb_1[val_1 == 2] = emb_0[:, 7::8]  # [N, T1, C]
-        emb_1 = self.convolution_1(emb_1)
+        emb_1[val_1 == 2] = self.down_convolution_0(emb_0[:, :mix_1 * 8])  # [N, T1, C]
 
         emb_2 = torch.zeros((batch_size, torch.max(len_2), self.head_dim), dtype=torch.float, device=value.device)
         # substitute all mixed token embeddings of third to last layer, with token embeddings of penultimate layer
-        emb_2[val_2 == 2] = emb_1[:, 7::8]  # [N, T1, C]
+        emb_2[val_2 == 2] = self.down_convolution_1(emb_1)  # [N, T1, C]
+
+        emb_0 = self.convolution_0(emb_0[:, :mix_1 * 8])
+        emb_1 = self.convolution_1(emb_1)
         emb_2 = self.convolution_2(emb_2)
 
         # create intermediate list to hold vectors
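The new `down_convolution_0` / `down_convolution_1` layers replace the fixed `emb[:, 7::8]` slice, which kept only the last token of each block of eight, with a learned summary of the whole block before it is substituted into the parent layer. `Convolution` is a project module whose implementation is not shown in this diff; the snippet below is a minimal stand-in, assuming it reduces the sequence length by its kernel size (stride equal to kernel), which is what the masked assignments in this hunk require.

```python
import torch
import torch.nn as nn


class Convolution(nn.Module):
    """Hypothetical stand-in for the project's Convolution module:
    a strided Conv1d that downsamples the sequence length by its kernel size."""

    def __init__(self, in_dim, out_dim, kernel):
        super().__init__()
        self.conv = nn.Conv1d(in_dim, out_dim, kernel_size=kernel, stride=kernel)

    def forward(self, x):
        # x: [N, T, C] -> Conv1d expects [N, C, T]; output is [N, T // kernel, C]
        return self.conv(x.transpose(1, 2)).transpose(1, 2)


N, C = 2, 16
emb_0 = torch.randn(N, 64, C)      # last-layer embeddings, T0 = 64 = 8 blocks of 8
down = Convolution(C, C, 8)
summary = down(emb_0)              # [N, 8, C], one vector per block of 8 tokens
old_slice = emb_0[:, 7::8]         # [N, 8, C], only the last token of each block
assert summary.shape == old_slice.shape
```

Under this assumption the substituted vectors have the same shape as before, but each one now aggregates all eight child tokens instead of forwarding only the last one.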
diff --git a/modules/generative_head/substitution_head.py b/modules/generative_head/substitution_head.py
index e5474bca84787803fb6b34b24c49bf8f98260285..b4a7fd63eb91ec24377aeb99c65aba53c11bbca5 100644
--- a/modules/generative_head/substitution_head.py
+++ b/modules/generative_head/substitution_head.py
@@ -41,6 +41,8 @@ class SubstitutionHead(nn.Module):
             convolution_0 += [nn.GELU(), BlockConvolution(head_dim, head_dim, 8)]
         self.convolution_0 = nn.Sequential(*convolution_0)
 
+        self.down_convolution = Convolution(head_dim, head_dim, 8)
+
         linear = []
         for i in range(n_layer - 1):
             linear += [nn.GELU(), nn.Linear(head_dim, head_dim)]
@@ -90,11 +92,12 @@ class SubstitutionHead(nn.Module):
         # add spatial decoding if available
         if self.spatial_encoding is not None:
             emb_0 = emb_0 + self.spatial_encoding(pos[:, -len_0:])
-        emb_0 = self.convolution_0(emb_0[:, :mix_1 * 8])
 
         emb_1 = torch.zeros((batch_size, torch.max(len_1), self.head_dim), dtype=torch.float, device=value.device)
         # substitute all mixed token embeddings of penultimate layer, with token embeddings of last layer
-        emb_1[val_1 == 2] = emb_0[:, 7::8]  # [N, T1, C]
+        emb_1[val_1 == 2] = self.down_convolution(emb_0[:, :mix_1 * 8])  # [N, T1, C]
+
+        emb_0 = self.convolution_0(emb_0[:, :mix_1 * 8])
         emb_1 = self.convolution_1(emb_1)
 
         x_0 = torch.zeros((batch_size, torch.max(mix_1), self.head_dim), device=value.device)
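In `SubstitutionHead` the same replacement is applied once: the down-convolved last-layer embeddings are scattered into the penultimate-layer positions whose value is 2 before `convolution_0` and `convolution_1` run. The snippet below illustrates that masked assignment with made-up toy shapes; the tensor sizes and values are assumptions, and the point is only that the down-convolution must emit exactly as many vectors as there are mixed positions.

```python
import torch

N, C = 1, 4
emb_1 = torch.zeros(N, 6, C)                # penultimate-layer embeddings
val_1 = torch.tensor([[1, 2, 2, 1, 2, 1]])  # value 2 marks "mixed" tokens
num_mixed = int((val_1 == 2).sum())         # 3 mixed positions
downsampled = torch.randn(num_mixed, C)     # e.g. output of self.down_convolution

# Boolean indexing assigns one downsampled vector per mixed position (in
# sequence order), so the number of emitted vectors must match the mask count.
emb_1[val_1 == 2] = downsampled
assert torch.equal(emb_1[0, 1], downsampled[0])
```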