diff --git a/modules/generative_head/composite_head_A.py b/modules/generative_head/composite_head_A.py
index 5d09c79b9b90c1dd980c2f32b9f31ae60e40e861..072c6ada1cec5eebb48a7a284aea3a8738dbd3de 100644
--- a/modules/generative_head/composite_head_A.py
+++ b/modules/generative_head/composite_head_A.py
@@ -9,7 +9,7 @@ from .substitution_head import SubstitutionHead, SubstitutionHeadAutoregressive
 
 
 class CompositeHeadA(nn.Module):
-    def __init__(self, spatial_encoding, num_vocab, embed_dim, resolution, spatial_dim, **_):
+    def __init__(self, spatial_encoding, num_vocab, embed_dim, head_dim, n_layer, resolution, **_):
         """ Performs a transformation from transformer latent space into target value logits.
 
         Uses a different heads for each depth layer, possibly increasing the overall sequence lenght.
@@ -17,9 +17,10 @@ class CompositeHeadA(nn.Module):
 
         Args:
             num_vocab: Number of different target token values (exclusive padding token '0').
-            embded_dim: Dimension of the latent embedding space of the transformer.
+            embed_dim: Dimension of the latent embedding space of the transformer.
+            head_dim: Dimension of the embedding space used in the head layers.
+            n_layer: Number of layers used in each linear or convolution block.
             resolution: Spatial resolution of sequence encoding.
-            spatial_dim: Spatial dimension (2D/3D) of the sequence data.
         """
         super(CompositeHeadA, self).__init__()
 
@@ -27,7 +28,8 @@ class CompositeHeadA(nn.Module):
             "spatial_encoding": spatial_encoding,
             "num_vocab": num_vocab,
             "embed_dim": embed_dim,
-            "spatial_dim": spatial_dim,
+            "head_dim": head_dim,
+            "n_layer": n_layer
         }
 
         modules = []
@@ -38,15 +40,15 @@ class CompositeHeadA(nn.Module):
         if resolution >= 8:
             modules += [LinearHead(**kwargs)]
         if resolution >= 16:
-            modules += [ConvolutionHeadA(**kwargs, conv_size=2 ** (spatial_dim - 1))]
+            modules += [ConvolutionHeadA(**kwargs, conv_size=4)]
         if resolution >= 32:
-            modules += [ConvolutionHeadA(**kwargs, conv_size=2 ** spatial_dim)]
+            modules += [ConvolutionHeadA(**kwargs, conv_size=8)]
         if resolution >= 64:
-            modules += [SubstitutionHead(**kwargs, conv_size=2 ** spatial_dim)]
+            modules += [SubstitutionHead(**kwargs, conv_size=8)]
         if resolution >= 128:
-            modules += [DoubleSubstitutionHead(**kwargs, conv_size=2 ** spatial_dim)]
+            modules += [DoubleSubstitutionHead(**kwargs, conv_size=8)]
         if resolution >= 256:
-            modules += [DoubleSubstitutionHead(**kwargs, conv_size=2 ** spatial_dim)]
+            modules += [DoubleSubstitutionHead(**kwargs, conv_size=8)]
 
         # embeddings
         self.heads = nn.ModuleList(modules)
@@ -55,11 +57,11 @@ class CompositeHeadA(nn.Module):
             1: 1,
             2: 1,
             3: 1,
-            4: 2 ** (spatial_dim - 1),
-            5: 2 ** spatial_dim,
-            6: 2 ** spatial_dim,  # Note: 'substitution'
-            7: 2 ** spatial_dim,  # Note: 'double_substitution'
-            8: 2 ** spatial_dim,  # Note: 'double_substitution'
+            4: 4,
+            5: 8,
+            6: 8,  # Note: 'substitution'
+            7: 8,  # Note: 'double_substitution'
+            8: 8,  # Note: 'double_substitution'
         }
 
     def forward(self, x, value, depth, position):
@@ -154,7 +156,7 @@ class CompositeHeadA(nn.Module):
 
 
 class CompositeHeadAutoregressiveA(CompositeHeadA):
-    def __init__(self, spatial_encoding, num_vocab, embed_dim, resolution, spatial_dim, **_):
+    def __init__(self, spatial_encoding, num_vocab, embed_dim, head_dim, n_layer, resolution, **_):
         """ Performs a transformation from transformer latent space into target value logits.
 
         Uses a different heads for each depth layer, possibly increasing the overall sequence lenght.
@@ -163,17 +165,19 @@ class CompositeHeadAutoregressiveA(CompositeHeadA):
         Args:
             num_vocab: Number of different target token values (exclusive padding token '0').
             embded_dim: Dimension of the latent embedding space of the transformer.
+            head_dim: Dimension of the embedding space used in the head layers.
+            n_layer: Number of layers used in each linear or convolution block.
             resolution: Spatial resolution of sequence encoding.
-            spatial_dim: Spatial dimension (2D/3D) of the sequence data.
         """
-        super(CompositeHeadAutoregressiveA, self).__init__(spatial_encoding, num_vocab, embed_dim, resolution,
-                                                           spatial_dim, **_)
+        super(CompositeHeadAutoregressiveA, self).__init__(spatial_encoding, num_vocab, embed_dim, head_dim, n_layer,
+                                                           resolution, **_)
 
         kwargs = {
             "spatial_encoding": spatial_encoding,
             "num_vocab": num_vocab,
             "embed_dim": embed_dim,
-            "spatial_dim": spatial_dim,
+            "head_dim": head_dim,
+            "n_layer": n_layer
         }
 
         modules = []
@@ -184,15 +188,15 @@ class CompositeHeadAutoregressiveA(CompositeHeadA):
         if resolution >= 8:
             modules += [LinearHead(**kwargs)]
         if resolution >= 16:
-            modules += [ConvolutionHeadAutoregressive(**kwargs, conv_size=2 ** (spatial_dim - 1))]
+            modules += [ConvolutionHeadAutoregressive(**kwargs, conv_size=4)]
         if resolution >= 32:
-            modules += [ConvolutionHeadAutoregressive(**kwargs, conv_size=2 ** spatial_dim)]
+            modules += [ConvolutionHeadAutoregressive(**kwargs, conv_size=8)]
         if resolution >= 64:
-            modules += [SubstitutionHeadAutoregressive(**kwargs, conv_size=2 ** spatial_dim)]
+            modules += [SubstitutionHeadAutoregressive(**kwargs, conv_size=8)]
         if resolution >= 128:
-            modules += [DoubleSubstitutionHeadAutoRegressive(**kwargs, conv_size=2 ** spatial_dim)]
+            modules += [DoubleSubstitutionHeadAutoRegressive(**kwargs, conv_size=8)]
         if resolution >= 256:
-            modules += [DoubleSubstitutionHeadAutoRegressive(**kwargs, conv_size=2 ** spatial_dim)]
+            modules += [DoubleSubstitutionHeadAutoRegressive(**kwargs, conv_size=8)]
 
         # embeddings
         self.heads = nn.ModuleList(modules)
@@ -201,9 +205,9 @@ class CompositeHeadAutoregressiveA(CompositeHeadA):
             1: 1,
             2: 1,
             3: 1,
-            4: 2 ** (spatial_dim - 1),
-            5: 2 ** spatial_dim,
-            6: 2 ** spatial_dim,  # Note: 'substitution'
-            7: 2 ** spatial_dim,  # Note: 'double_substitution'
-            8: 2 ** spatial_dim,  # Note: 'double_substitution'
+            4: 4,
+            5: 8,
+            6: 8,  # Note: 'substitution'
+            7: 8,  # Note: 'double_substitution'
+            8: 8,  # Note: 'double_substitution'
         }
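
The composite heads now take head_dim and n_layer in place of spatial_dim; the former
2 ** (spatial_dim - 1) and 2 ** spatial_dim factors are fixed at 4 and 8, i.e. the octree
case (spatial_dim = 3). A hypothetical instantiation of the updated constructor, with
illustrative values that are not taken from this patch:

    from modules.generative_head.composite_head_A import CompositeHeadA

    head = CompositeHeadA(
        spatial_encoding=None,  # or one of the learned positional encodings built on head_dim
        num_vocab=3,            # illustrative vocabulary size
        embed_dim=256,          # width of the transformer latent space
        head_dim=64,            # new: width used inside the head layers
        n_layer=2,              # new: depth of each linear / convolution block
        resolution=64,          # selects which per-depth heads are instantiated
    )
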
diff --git a/modules/generative_head/composite_head_B.py b/modules/generative_head/composite_head_B.py
index b8eec0d472124fba321bb1697bfbdefc08dcf340..e980929ead74c2fc46e4ffd50be7164cb5c0ba2f 100644
--- a/modules/generative_head/composite_head_B.py
+++ b/modules/generative_head/composite_head_B.py
@@ -8,7 +8,7 @@ from .double_substitution_head import DoubleSubstitutionHead
 
 
 class CompositeHeadB(CompositeHeadA):
-    def __init__(self, spatial_encoding, num_vocab, embed_dim, resolution, spatial_dim, **_):
+    def __init__(self, spatial_encoding, num_vocab, embed_dim, head_dim, n_layer, resolution, **_):
         """ Performs a transformation from transformer latent space into target value logits.
 
         Uses a different heads for each depth layer, possibly increasing the overall sequence lenght.
@@ -16,17 +16,19 @@ class CompositeHeadB(CompositeHeadA):
 
         Args:
             num_vocab: Number of different target token values (exclusive padding token '0').
-            embded_dim: Dimension of the latent embedding space of the transformer.
+            embed_dim: Dimension of the latent embedding space of the transformer.
+            head_dim: Dimension of the embedding space used in the head layers.
+            n_layer: Number of layers used in each linear or convolution block.
             resolution: Spatial resolution of sequence encoding.
-            spatial_dim: Spatial dimension (2D/3D) of the sequence data.
         """
-        super(CompositeHeadB, self).__init__(num_vocab, embed_dim, resolution, spatial_dim)
+        super(CompositeHeadB, self).__init__(spatial_encoding, num_vocab, embed_dim, head_dim, n_layer, resolution)
 
         kwargs = {
             "spatial_encoding": spatial_encoding,
             "num_vocab": num_vocab,
             "embed_dim": embed_dim,
-            "spatial_dim": spatial_dim,
+            "head_dim": head_dim,
+            "n_layer": n_layer
         }
 
         modules = []
@@ -37,13 +39,13 @@ class CompositeHeadB(CompositeHeadA):
         if resolution >= 8:
             modules += [LinearHead(**kwargs)]
         if resolution >= 16:
-            modules += [ConvolutionHeadA(**kwargs, conv_size=2**(spatial_dim - 2))]
+            modules += [ConvolutionHeadA(**kwargs, conv_size=4)]
         if resolution >= 32:
-            modules += [ConvolutionHeadA(**kwargs, conv_size=2**spatial_dim)]
+            modules += [ConvolutionHeadA(**kwargs, conv_size=8)]
         if resolution >= 64:
-            modules += [SubstitutionHead(**kwargs, conv_size=2**spatial_dim)]
+            modules += [SubstitutionHead(**kwargs, conv_size=8)]
         if resolution >= 128:
-            modules += [DoubleSubstitutionHead(**kwargs, conv_size=2**spatial_dim)]
+            modules += [DoubleSubstitutionHead(**kwargs, conv_size=8)]
 
         # embeddings
         self.heads = nn.ModuleList(modules)
@@ -52,8 +54,8 @@ class CompositeHeadB(CompositeHeadA):
             1: 1,
             2: 1,
             3: 1,
-            4: 2**(spatial_dim - 2),
-            5: 2**spatial_dim,
-            6: 2**spatial_dim,  # Note: 'substitution'
-            7: 2**spatial_dim,  # Note: 'double_substitution'
+            4: 4,
+            5: 8,
+            6: 8,  # Note: 'substitution'
+            7: 8,  # Note: 'double_substitution'
         }
diff --git a/modules/generative_head/composite_head_C.py b/modules/generative_head/composite_head_C.py
index 601be52d407c1871e6df68510ff2a73180f117e0..65effd645cb55984790df24e5719656f7889e479 100644
--- a/modules/generative_head/composite_head_C.py
+++ b/modules/generative_head/composite_head_C.py
@@ -9,7 +9,7 @@ from .double_substitution_head import DoubleSubstitutionHead
 
 
 class CompositeHeadC(nn.Module):
-    def __init__(self, spatial_encoding, num_vocab, embed_dim, resolution, spatial_dim, **_):
+    def __init__(self, spatial_encoding, num_vocab, embed_dim, head_dim, n_layer, resolution, **_):
         """ Performs a transformation from transformer latent space into target value logits.
 
         Uses a different heads for each depth layer, possibly increasing the overall sequence lenght.
@@ -17,9 +17,10 @@ class CompositeHeadC(nn.Module):
 
         Args:
             num_vocab: Number of different target token values (exclusive padding token '0').
-            embded_dim: Dimension of the latent embedding space of the transformer.
+            embed_dim: Dimension of the latent embedding space of the transformer.
+            head_dim: Dimension of the embedding space used in the head layers.
+            n_layer: Number of layers used in each linear or convolution block.
             resolution: Spatial resolution of sequence encoding.
-            spatial_dim: Spatial dimension (2D/3D) of the sequence data.
         """
         super(CompositeHeadC, self).__init__()
 
@@ -27,7 +28,8 @@ class CompositeHeadC(nn.Module):
             "spatial_encoding": spatial_encoding,
             "num_vocab": num_vocab,
             "embed_dim": embed_dim,
-            "spatial_dim": spatial_dim,
+            "head_dim": head_dim,
+            "n_layer": n_layer
         }
 
         modules = []
@@ -36,13 +38,13 @@ class CompositeHeadC(nn.Module):
         if resolution >= 4:
             modules += [LinearHead(**kwargs)]
         if resolution >= 8:
-            modules += [ConvolutionHeadA(**kwargs, conv_size=2**(spatial_dim - 1))]
+            modules += [ConvolutionHeadA(**kwargs, conv_size=4)]
         if resolution >= 16:
-            modules += [ConvolutionHeadA(**kwargs, conv_size=2**spatial_dim)]
+            modules += [ConvolutionHeadA(**kwargs, conv_size=8)]
         if resolution >= 32:
-            modules += [SubstitutionHead(**kwargs, conv_size=2**spatial_dim)]
+            modules += [SubstitutionHead(**kwargs, conv_size=8)]
         if resolution >= 64:
-            modules += [DoubleSubstitutionHead(**kwargs, conv_size=2**spatial_dim)]
+            modules += [DoubleSubstitutionHead(**kwargs, conv_size=8)]
 
         # embeddings
         self.heads = nn.ModuleList(modules)
@@ -50,10 +52,10 @@ class CompositeHeadC(nn.Module):
         self.reduction_factor = {
             1: 1,
             2: 1,
-            3: 2**(spatial_dim - 1),
-            4: 2**spatial_dim,
-            5: 2**spatial_dim,  # Note: 'substitution'
-            6: 2**spatial_dim,  # Note: 'double_substitution'
+            3: 4,
+            4: 8,
+            5: 8,  # Note: 'substitution'
+            6: 8,  # Note: 'double_substitution'
         }
 
     def forward(self, x, value, depth, position):
diff --git a/modules/generative_head/convolution_head_A.py b/modules/generative_head/convolution_head_A.py
index c090fb4b8daa80979fa2ce2b0ae98a36e49078c8..140321497fb066e530d2a4fab1ccb9fb08297a54 100644
--- a/modules/generative_head/convolution_head_A.py
+++ b/modules/generative_head/convolution_head_A.py
@@ -1,10 +1,10 @@
 import torch.nn as nn
 
-from ..utils import Deconvolution, BlockConvolution, Linear
+from ..utils import Deconvolution, Convolution, BlockConvolution, Linear
 
 
 class ConvolutionHeadA(nn.Module):
-    def __init__(self, spatial_encoding, num_vocab, embed_dim, spatial_dim, conv_size, **_):
+    def __init__(self, spatial_encoding, num_vocab, embed_dim, head_dim, n_layer, conv_size, **_):
         """ Performs a convolutional transformation from transformer latent space into target value logits.
 
         Note: The token value '0' is reserved as a padding value, which does not propagate gradients.
@@ -12,13 +12,24 @@ class ConvolutionHeadA(nn.Module):
         Args:
             num_vocab: Number of different target token values (exclusive padding token '0').
             embded_dim: Dimension of the latent embedding space of the transformer.
+            head_dim: Dimension of the embedding space used in the head layers.
+            n_layer: Number of layers used in each linear or convolution block.
             spatial_dim: Spatial dimension (2D/3D) of the sequence data.
             conv_size: Convolution kernel size and stride.
         """
         super(ConvolutionHeadA, self).__init__()
 
-        self.deconvolution = Deconvolution(embed_dim, embed_dim, conv_size)
-        self.linear = Linear(embed_dim, num_vocab)
+        deconvolution = [nn.GELU(), Deconvolution(embed_dim, head_dim, conv_size)]
+        for i in range(n_layer - 1):
+            deconvolution += [nn.GELU(), Convolution(head_dim, head_dim, 1)]
+        self.deconvolution = nn.Sequential(*deconvolution)
+
+        linear = []
+        for i in range(n_layer - 1):
+            linear += [nn.GELU(), nn.Linear(head_dim, head_dim)]
+        linear += [nn.GELU(), Linear(head_dim, num_vocab)]
+        self.linear = nn.Sequential(*linear)
+
         self.spatial_encoding = spatial_encoding
 
     def forward(self, x, value, depth, pos):
@@ -45,7 +56,7 @@ class ConvolutionHeadA(nn.Module):
 
 
 class ConvolutionHeadAutoregressive(nn.Module):
-    def __init__(self, spatial_encoding, num_vocab, embed_dim, spatial_dim, conv_size, **_):
+    def __init__(self, spatial_encoding, num_vocab, embed_dim, head_dim, n_layer, conv_size, **_):
         """ Performs a convolutional transformation from transformer latent space into target value logits.
 
         Note: The token value '0' is reserved as a padding value, which does not propagate gradients.
@@ -53,17 +64,33 @@ class ConvolutionHeadAutoregressive(nn.Module):
         Args:
             num_vocab: Number of different target token values (exclusive padding token '0').
             embded_dim: Dimension of the latent embedding space of the transformer.
+            head_dim: Dimension of the embedding space used in the head layers.
+            n_layer: Number of layers used in each linear or convolution block.
             spatial_dim: Spatial dimension (2D/3D) of the sequence data.
             conv_size: Convolution kernel size and stride.
         """
         super(ConvolutionHeadAutoregressive, self).__init__()
 
         self.conv_size = conv_size
-        self.deconvolution = Deconvolution(embed_dim, embed_dim, conv_size)
-        self.convolution = BlockConvolution(embed_dim, embed_dim, conv_size)
-        self.linear = Linear(embed_dim, num_vocab)
+
+        deconvolution = [nn.GELU(), Deconvolution(embed_dim, head_dim, conv_size)]
+        for i in range(n_layer - 1):
+            deconvolution += [nn.GELU(), Convolution(head_dim, head_dim, 1)]
+        self.deconvolution = nn.Sequential(*deconvolution)
+
+        convolution = [BlockConvolution(head_dim, head_dim, conv_size)]
+        for i in range(n_layer - 1):
+            convolution += [nn.GELU(), BlockConvolution(head_dim, head_dim, conv_size)]
+        self.convolution = nn.Sequential(*convolution)
+
+        linear = []
+        for i in range(n_layer - 1):
+            linear += [nn.GELU(), nn.Linear(head_dim, head_dim)]
+        linear += [nn.GELU(), Linear(head_dim, num_vocab)]
+        self.linear = nn.Sequential(*linear)
+
         self.spatial_encoding = spatial_encoding
-        self.value_embedding = nn.Embedding(num_vocab + 1, embed_dim, padding_idx=0)
+        self.value_embedding = nn.Embedding(num_vocab + 1, head_dim, padding_idx=0)
 
     def forward(self, x, value, depth, pos):
         """ Transforms the output of the transformer target value logits.
diff --git a/modules/generative_head/double_substitution_head.py b/modules/generative_head/double_substitution_head.py
index e9064e0ee719f56439f6b9c14dacca3e09d9b7f4..2807f616f2f4c9ed809a878b86b487841a999ef2 100644
--- a/modules/generative_head/double_substitution_head.py
+++ b/modules/generative_head/double_substitution_head.py
@@ -1,11 +1,11 @@
 import torch
 import torch.nn as nn
 
-from ..utils import Deconvolution, BlockConvolution, Linear
+from ..utils import Deconvolution, Convolution, BlockConvolution, Linear
 
 
 class DoubleSubstitutionHead(nn.Module):
-    def __init__(self, spatial_encoding, num_vocab, embed_dim, spatial_dim, conv_size, **_):
+    def __init__(self, spatial_encoding, num_vocab, embed_dim, head_dim, n_layer, conv_size, **_):
         """ Performs a twice a substitution transformation from transformer latent space into target value logits.
 
         Note: The token value '0' is reserved as a padding value, which does not propagate gradients.
@@ -13,20 +13,37 @@ class DoubleSubstitutionHead(nn.Module):
         Args:
             num_vocab: Number of different target token values (exclusive padding token '0').
             embded_dim: Dimension of the latent embedding space of the transformer.
+            head_dim: Dimension of the embedding space used in the head layers.
+            n_layer: Number of layers used in each linear or convolution block.
             spatial_dim: Spatial dimension (2D/3D) of the sequence data.
             conv_size: Convolution kernel size and stride.
         """
         super(DoubleSubstitutionHead, self).__init__()
-        self.embed_dim = embed_dim
+        self.head_dim = head_dim
 
         # deconvolutions
-        self.deconvolution_2 = Deconvolution(embed_dim, embed_dim, conv_size)
-        self.deconvolution_1 = Deconvolution(embed_dim, embed_dim, conv_size)
-        self.deconvolution_0 = Deconvolution(embed_dim, embed_dim, conv_size)
-        self.spatial_encoding = spatial_encoding
+        deconvolution_2 = [nn.GELU(), Deconvolution(embed_dim, head_dim, conv_size)]
+        for i in range(n_layer - 1):
+            deconvolution_2 += [nn.GELU(), Convolution(head_dim, head_dim, 1)]
+        self.deconvolution_2 = nn.Sequential(*deconvolution_2)
+
+        deconvolution_1 = [nn.GELU(), Deconvolution(head_dim, head_dim, 8)]
+        for i in range(n_layer - 1):
+            deconvolution_1 += [nn.GELU(), Convolution(head_dim, head_dim, 1)]
+        self.deconvolution_1 = nn.Sequential(*deconvolution_1)
+
+        deconvolution_0 = [nn.GELU(), Deconvolution(head_dim, head_dim, 8)]
+        for i in range(n_layer - 1):
+            deconvolution_0 += [nn.GELU(), Convolution(head_dim, head_dim, 1)]
+        self.deconvolution_0 = nn.Sequential(*deconvolution_0)
+
+        linear = []
+        for i in range(n_layer - 1):
+            linear += [nn.GELU(), nn.Linear(head_dim, head_dim)]
+        linear += [nn.GELU(), Linear(head_dim, num_vocab)]
+        self.linear = nn.Sequential(*linear)
 
-        # head
-        self.linear = Linear(embed_dim, num_vocab)
+        self.spatial_encoding = spatial_encoding
 
     def forward(self, x, value, depth, pos):
         """ Transforms the output of the transformer target value logits.
@@ -64,8 +81,8 @@ class DoubleSubstitutionHead(nn.Module):
         mix_2 = torch.sum(val_2 == 2, dim=1)
 
         # create intermediate list to hold vectors
-        x_0 = torch.zeros((batch_size, torch.max(mix_1), self.embed_dim), device=value.device)
-        x_1 = torch.zeros((batch_size, torch.max(mix_2), self.embed_dim), device=value.device)
+        x_0 = torch.zeros((batch_size, torch.max(mix_1), self.head_dim), device=value.device)
+        x_1 = torch.zeros((batch_size, torch.max(mix_2), self.head_dim), device=value.device)
 
         # deconvolute the latent space - sequence length equals number of tokens in the penultimate layer
         y_2 = self.deconvolution_2(x)
@@ -95,7 +112,7 @@ class DoubleSubstitutionHead(nn.Module):
 
 
 class DoubleSubstitutionHeadAutoRegressive(nn.Module):
-    def __init__(self, spatial_encoding, num_vocab, embed_dim, spatial_dim, conv_size, **_):
+    def __init__(self, spatial_encoding, num_vocab, embed_dim, head_dim, n_layer, conv_size, **_):
         """ Performs a twice a substitution transformation from transformer latent space into target value logits.
 
         Note: The token value '0' is reserved as a padding value, which does not propagate gradients.
@@ -103,26 +120,53 @@ class DoubleSubstitutionHeadAutoRegressive(nn.Module):
         Args:
             num_vocab: Number of different target token values (exclusive padding token '0').
             embded_dim: Dimension of the latent embedding space of the transformer.
+            head_dim: Dimension of the embedding space used in the head layers.
+            n_layer: Number of layers used in each linear or convolution block.
             spatial_dim: Spatial dimension (2D/3D) of the sequence data.
             conv_size: Convolution kernel size and stride.
         """
         super(DoubleSubstitutionHeadAutoRegressive, self).__init__()
-        self.embed_dim = embed_dim
+        self.head_dim = head_dim
         self.conv_size = conv_size
 
-        # deconvolutions
-        self.deconvolution_2 = Deconvolution(embed_dim, embed_dim, conv_size)
-        self.deconvolution_1 = Deconvolution(embed_dim, embed_dim, conv_size)
-        self.deconvolution_0 = Deconvolution(embed_dim, embed_dim, conv_size)
+        deconvolution_2 = [nn.GELU(), Deconvolution(embed_dim, head_dim, conv_size)]
+        for i in range(n_layer - 1):
+            deconvolution_2 += [nn.GELU(), Convolution(head_dim, head_dim, 1)]
+        self.deconvolution_2 = nn.Sequential(*deconvolution_2)
+
+        deconvolution_1 = [nn.GELU(), Deconvolution(head_dim, head_dim, 8)]
+        for i in range(n_layer - 1):
+            deconvolution_1 += [nn.GELU(), Convolution(head_dim, head_dim, 1)]
+        self.deconvolution_1 = nn.Sequential(*deconvolution_1)
+
+        deconvolution_0 = [nn.GELU(), Deconvolution(head_dim, head_dim, 8)]
+        for i in range(n_layer - 1):
+            deconvolution_0 += [nn.GELU(), Convolution(head_dim, head_dim, 1)]
+        self.deconvolution_0 = nn.Sequential(*deconvolution_0)
+
+        convolution_2 = []
+        for i in range(n_layer):
+            convolution_2 += [nn.GELU(), BlockConvolution(head_dim, head_dim, conv_size)]
+        self.convolution_2 = nn.Sequential(*convolution_2)
+
+        convolution_1 = []
+        for i in range(n_layer):
+            convolution_1 += [nn.GELU(), BlockConvolution(head_dim, head_dim, 8)]
+        self.convolution_1 = nn.Sequential(*convolution_1)
+
+        convolution_0 = [BlockConvolution(head_dim, head_dim, 8)]
+        for i in range(n_layer - 1):
+            convolution_0 += [nn.GELU(), BlockConvolution(head_dim, head_dim, 8)]
+        self.convolution_0 = nn.Sequential(*convolution_0)
+
+        linear = []
+        for i in range(n_layer - 1):
+            linear += [nn.GELU(), nn.Linear(head_dim, head_dim)]
+        linear += [nn.GELU(), Linear(head_dim, num_vocab)]
+        self.linear = nn.Sequential(*linear)
 
-        self.convolution_2 = BlockConvolution(embed_dim, embed_dim, conv_size)
-        self.convolution_1 = BlockConvolution(embed_dim, embed_dim, conv_size)
-        self.convolution_0 = BlockConvolution(embed_dim, embed_dim, conv_size)
-
-        # head
-        self.linear = Linear(embed_dim, num_vocab)
         self.spatial_encoding = spatial_encoding
-        self.value_embedding = nn.Embedding(num_vocab + 1, embed_dim, padding_idx=0)
+        self.value_embedding = nn.Embedding(num_vocab + 1, head_dim, padding_idx=0)
 
     def forward(self, x, value, depth, pos):
         """ Transforms the output of the transformer target value logits.
@@ -168,19 +212,19 @@ class DoubleSubstitutionHeadAutoRegressive(nn.Module):
             emb_0 = emb_0 + self.spatial_encoding(pos[:, -len_0:])
         emb_0 = self.convolution_0(emb_0)
 
-        emb_1 = torch.zeros((batch_size, torch.max(len_1), self.embed_dim), dtype=torch.float, device=value.device)
+        emb_1 = torch.zeros((batch_size, torch.max(len_1), self.head_dim), dtype=torch.float, device=value.device)
         # substitute all mixed token embeddings of penultimate layer, with token embeddings of last layer
         emb_1[val_1 == 2] = emb_0[:, (self.conv_size - 1)::self.conv_size]  # [N, T1, C]
         emb_1 = self.convolution_1(emb_1)
 
-        emb_2 = torch.zeros((batch_size, torch.max(len_2), self.embed_dim), dtype=torch.float, device=value.device)
+        emb_2 = torch.zeros((batch_size, torch.max(len_2), self.head_dim), dtype=torch.float, device=value.device)
         # substitute all mixed token embeddings of third to last layer, with token embeddings of penultimate layer
         emb_2[val_2 == 2] = emb_1[:, (self.conv_size - 1)::self.conv_size]  # [N, T1, C]
         emb_2 = self.convolution_2(emb_2)
 
         # create intermediate list to hold vectors
-        x_0 = torch.zeros((batch_size, torch.max(mix_1), self.embed_dim), device=value.device)
-        x_1 = torch.zeros((batch_size, torch.max(mix_2), self.embed_dim), device=value.device)
+        x_0 = torch.zeros((batch_size, torch.max(mix_1), self.head_dim), device=value.device)
+        x_1 = torch.zeros((batch_size, torch.max(mix_2), self.head_dim), device=value.device)
 
         # deconvolute the latent space - sequence length equals number of tokens in the penultimate layer
         y_2 = self.deconvolution_2(x)
@@ -205,7 +249,7 @@ class DoubleSubstitutionHeadAutoRegressive(nn.Module):
         # add spatial decoding if available
         if self.spatial_encoding is not None:
             len_last = torch.sum(depth == max_depth, dim=1)
-            assert((depth[:, -len_last:] == max_depth).all())
+            assert (depth[:, -len_last:] == max_depth).all()
             y_0 = y_0 + self.spatial_encoding(pos[:, -len_last:])
 
         # compute logits of generated tokens
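
The substitution logic above repeatedly uses the strided slice emb[:, (conv_size - 1)::conv_size]
to pick one vector per substituted parent token. A toy illustration of what that slice selects
(not part of the patch):

    import torch

    conv_size = 4
    t = torch.arange(12)                  # stand-in for a deconvolved sequence of length 12
    print(t[(conv_size - 1)::conv_size])  # tensor([ 3,  7, 11]) -- the last entry of each block
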
diff --git a/modules/generative_head/head_factory.py b/modules/generative_head/head_factory.py
index 3d48aa4f490b070b6a108f0716dae3a7f8a89358..52bbd1422e2bfaaabef35ddba97b3075336cad3c 100644
--- a/modules/generative_head/head_factory.py
+++ b/modules/generative_head/head_factory.py
@@ -12,7 +12,7 @@ from .multi_conv_head_A import MultiConvolutionHeadA
 from .substitution_head import SubstitutionHead
 
 
-def _create_head(name, positional_encoding, num_vocab, embed_dim, resolution, spatial_dim):
+def _create_head(name, positional_encoding, num_vocab, embed_dim, head_dim, n_layer, resolution):
     """ Creates a generative head.
 
     If the module specified in `name` does not exist raises a value error.
@@ -22,8 +22,9 @@ def _create_head(name, positional_encoding, num_vocab, embed_dim, resolution, sp
         name: Defines which generative head will be created.
         num_vocab: Number of different vocabs in the vocabulary set.
         embed_dim: Size of embedding dimensions used by the transformer model.
+        head_dim: Size of embedding dimensions used in the head layers.
+        n_layer: Number of layers used in each linear or convolution block.
         resolution: Spatial resolution of sequence encoding.
-        spatial_dim: Spatial dimensionality of input data.
 
     Return:
         Generative head initialised with specified parameters.
@@ -32,11 +33,11 @@ def _create_head(name, positional_encoding, num_vocab, embed_dim, resolution, sp
     if positional_encoding == 'None':
         spatial_encoding = None
     elif positional_encoding == 'basic':
-        spatial_encoding = PositionalEncodingLearned(embed_dim, resolution, spatial_dim)
+        spatial_encoding = PositionalEncodingLearned(head_dim, resolution)
     elif positional_encoding == 'look_ahead':
-        spatial_encoding = PositionalEncodingLearnedLookAhead(embed_dim, resolution, spatial_dim)
+        spatial_encoding = PositionalEncodingLearnedLookAhead(head_dim, resolution)
     elif positional_encoding == 'look_ahead_split':
-        spatial_encoding = PositionalEncodingLearnedLookAheadSplit(embed_dim, resolution, spatial_dim)
+        spatial_encoding = PositionalEncodingLearnedLookAheadSplit(head_dim, resolution)
     else:
         raise ValueError(f"ERROR: {positional_encoding} encoding not implemented.")
 
@@ -44,18 +45,19 @@ def _create_head(name, positional_encoding, num_vocab, embed_dim, resolution, sp
         "spatial_encoding": spatial_encoding,
         "num_vocab": num_vocab,
         "embed_dim": embed_dim,
+        "head_dim": head_dim,
+        "n_layer": n_layer,
         "resolution": resolution,
-        "spatial_dim": spatial_dim,
-        "conv_size": 2 ** spatial_dim,
+        "conv_size": 8,
     }
 
     if name in ('generative_basic', 'linear', 'basic'):
         return LinearHead(**kwargs)
     elif name == 'discrete_transformation':
-        kwargs["num_vocab"] = num_vocab ** 2 ** spatial_dim
+        kwargs["num_vocab"] = num_vocab ** 2 ** 3
         return LinearHead(**kwargs)
     elif name in ('half_conv', 'half_conv_A'):
-        kwargs["conv_size"] = 2 ** (spatial_dim - 1)
+        kwargs["conv_size"] = 2 ** (3 - 1)
         return ConvolutionHeadA(**kwargs)
     elif name in ('single_conv', 'single_conv_A'):
         return ConvolutionHeadA(**kwargs)
@@ -77,7 +79,7 @@ def _create_head(name, positional_encoding, num_vocab, embed_dim, resolution, sp
         raise ValueError(f"ERROR: {name} head not implemented.")
 
 
-def create_head(name, positional_encoding, num_vocab, embed_dim, resolution, spatial_dim):
+def create_head(name, positional_encoding, num_vocab, embed_dim, head_dim, n_layer, resolution):
     """ Creates a generative head.
 
     If `name` is a list, creates a list of heads for each element of the list, otherwise a single one. If the module
@@ -89,14 +91,16 @@ def create_head(name, positional_encoding, num_vocab, embed_dim, resolution, spa
         name: Defines which generative head will be created.
         num_vocab: Number of different vocabs in the vocabulary set.
         embed_dim: Size of embedding dimensions used by the transformer model.
+        head_dim: Size of embedding dimensions used in the head layers.
+        n_layer: Number of layers used in each linear or convolution block.
         resolution: Spatial resolution of sequence encoding.
-        spatial_dim: Spatial dimensionality of input data.
 
     Return:
         Generative head or a list of heads initialised with specified parameters.
     """
     if type(name) == list:
         return nn.ModuleList(
-            [_create_head(n, positional_encoding, num_vocab, embed_dim, resolution, spatial_dim) for n in name])
+            [_create_head(n, positional_encoding, num_vocab, embed_dim, head_dim, n_layer, resolution) for n
+             in name])
     else:
-        return _create_head(name, positional_encoding, num_vocab, embed_dim, resolution, spatial_dim)
+        return _create_head(name, positional_encoding, num_vocab, embed_dim, head_dim, n_layer, resolution)
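
With spatial_dim removed from the factory, the data is treated as 3D throughout
(conv_size defaults to 8 = 2 ** 3). A call with the new signature might look like this
(parameter values are illustrative):

    from modules.generative_head.head_factory import create_head

    head = create_head(
        name='single_conv_A',        # any name handled in _create_head, or a list of names
        positional_encoding='basic',
        num_vocab=3,
        embed_dim=256,
        head_dim=64,
        n_layer=2,
        resolution=64,
    )
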
diff --git a/modules/generative_head/linear_head.py b/modules/generative_head/linear_head.py
index 98589fc8ae7f512a97e6f8e7ec1dc0ef23f12bf7..de6c21457c346b66ffebecd76a510ff276aa0d2f 100644
--- a/modules/generative_head/linear_head.py
+++ b/modules/generative_head/linear_head.py
@@ -4,7 +4,7 @@ from ..utils import Linear
 
 
 class LinearHead(nn.Module):
-    def __init__(self, spatial_encoding, num_vocab, embed_dim, **_):
+    def __init__(self, spatial_encoding, num_vocab, embed_dim, head_dim, n_layer, **_):
         """ Performs a linear transformation from transformer latent space into target value logits.
 
         Note: The token value '0' is reserved as a padding value, which does not propagate gradients.
@@ -12,10 +12,20 @@ class LinearHead(nn.Module):
         Args:
             num_vocab: Number of different target token values (exclusive padding token '0').
             embded_dim: Dimension of the latent embedding space of the transformer.
+            head_dim: Dimension of the embedding space used in the head layers.
+            n_layer: Number of layers used in each linear or convolution block.
         """
         super(LinearHead, self).__init__()
 
-        self.linear = Linear(embed_dim, num_vocab)
+        if n_layer > 1:
+            linear = [nn.GELU(), nn.Linear(embed_dim, head_dim)]
+            for i in range(n_layer - 2):
+                linear += [nn.GELU(), nn.Linear(head_dim, head_dim)]
+            linear += [nn.GELU(), Linear(head_dim, num_vocab)]
+            self.linear = nn.Sequential(*linear)
+        else:
+            self.linear = Linear(embed_dim, num_vocab)
+
         self.spatial_encoding = spatial_encoding
 
     def forward(self, x, value, depth, pos):
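
With n_layer = 1 the head keeps its earlier single-layer form, Linear(embed_dim, num_vocab);
for larger values the branch above builds an MLP that first projects into head_dim.
Expanded for n_layer = 3 (sketch):

    self.linear = nn.Sequential(
        nn.GELU(), nn.Linear(embed_dim, head_dim),
        nn.GELU(), nn.Linear(head_dim, head_dim),
        nn.GELU(), Linear(head_dim, num_vocab),  # final projection to num_vocab + 1 logits
    )
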
diff --git a/modules/generative_head/substitution_head.py b/modules/generative_head/substitution_head.py
index 187b2f16943e0cca4504d22d281f9adc3d973d68..4f74d09b8267fdd5d2d9b0e075eb023f31509cb6 100644
--- a/modules/generative_head/substitution_head.py
+++ b/modules/generative_head/substitution_head.py
@@ -1,11 +1,11 @@
 import torch
 import torch.nn as nn
 
-from ..utils import BlockConvolution, Deconvolution, Linear
+from ..utils import Convolution, BlockConvolution, Deconvolution, Linear
 
 
 class SubstitutionHead(nn.Module):
-    def __init__(self, spatial_encoding, num_vocab, embed_dim, spatial_dim, conv_size, **_):
+    def __init__(self, spatial_encoding, num_vocab, embed_dim, head_dim, n_layer, conv_size, **_):
         """ Performs a substitution transformation from transformer latent space into target value logits.
 
         Note: The token value '0' is reserved as a padding value, which does not propagate gradients.
@@ -13,15 +13,30 @@ class SubstitutionHead(nn.Module):
         Args:
             num_vocab: Number of different target token values (exclusive padding token '0').
             embded_dim: Dimension of the latent embedding space of the transformer.
+            head_dim: Dimension of the embedding space used in the head layers.
+            n_layer: Number of layers used in each linear or convolution block.
             spatial_dim: Spatial dimension (2D/3D) of the sequence data.
             conv_size: Convolution kernel size and stride.
         """
         super(SubstitutionHead, self).__init__()
-        self.embed_dim = embed_dim
+        self.head_dim = head_dim
+
+        deconvolution_1 = [nn.GELU(), Deconvolution(embed_dim, head_dim, conv_size)]
+        for i in range(n_layer - 1):
+            deconvolution_1 += [nn.GELU(), Convolution(head_dim, head_dim, 1)]
+        self.deconvolution_1 = nn.Sequential(*deconvolution_1)
+
+        deconvolution_0 = [nn.GELU(), Deconvolution(head_dim, head_dim, 8)]
+        for i in range(n_layer - 1):
+            deconvolution_0 += [nn.GELU(), Convolution(head_dim, head_dim, 1)]
+        self.deconvolution_0 = nn.Sequential(*deconvolution_0)
+
+        linear = []
+        for i in range(n_layer - 1):
+            linear += [nn.GELU(), nn.Linear(head_dim, head_dim)]
+        linear += [nn.GELU(), Linear(head_dim, num_vocab)]
+        self.linear = nn.Sequential(*linear)
 
-        self.deconvolution_1 = Deconvolution(embed_dim, embed_dim, conv_size)
-        self.deconvolution_0 = Deconvolution(embed_dim, embed_dim, conv_size)
-        self.linear = Linear(embed_dim, num_vocab)
         self.spatial_encoding = spatial_encoding
 
     def forward(self, x, value, depth, pos):
@@ -56,7 +71,7 @@ class SubstitutionHead(nn.Module):
         mix_1 = torch.sum(val_1 == 2, dim=1)
 
         # create intermediate list to hold vectors
-        x_0 = torch.zeros((batch_size, torch.max(mix_1), self.embed_dim), device=value.device)
+        x_0 = torch.zeros((batch_size, torch.max(mix_1), self.head_dim), device=value.device)
 
         # deconvolute the latent space - sequence length equals number of tokens in the penultimate layer
         y_1 = self.deconvolution_1(x)
@@ -79,7 +94,7 @@ class SubstitutionHead(nn.Module):
 
 
 class SubstitutionHeadAutoregressive(nn.Module):
-    def __init__(self, spatial_encoding, num_vocab, embed_dim, spatial_dim, conv_size, **_):
+    def __init__(self, spatial_encoding, num_vocab, embed_dim, head_dim, n_layer, conv_size, **_):
         """ Performs a substitution transformation from transformer latent space into target value logits.
 
         Note: The token value '0' is reserved as a padding value, which does not propagate gradients.
@@ -87,22 +102,43 @@ class SubstitutionHeadAutoregressive(nn.Module):
         Args:
             num_vocab: Number of different target token values (exclusive padding token '0').
             embded_dim: Dimension of the latent embedding space of the transformer.
+            head_dim: Dimension of the embedding space used in the head layers.
+            n_layer: Number of layers used in each linear or convolution block.
             spatial_dim: Spatial dimension (2D/3D) of the sequence data.
             conv_size: Convolution kernel size and stride.
         """
         super(SubstitutionHeadAutoregressive, self).__init__()
-        self.embed_dim = embed_dim
+        self.head_dim = head_dim
         self.conv_size = conv_size
 
-        self.deconvolution_1 = Deconvolution(embed_dim, embed_dim, conv_size)
-        self.deconvolution_0 = Deconvolution(embed_dim, embed_dim, conv_size)
+        deconvolution_1 = [nn.GELU(), Deconvolution(embed_dim, head_dim, conv_size)]
+        for i in range(n_layer - 1):
+            deconvolution_1 += [nn.GELU(), Convolution(head_dim, head_dim, 1)]
+        self.deconvolution_1 = nn.Sequential(*deconvolution_1)
+
+        deconvolution_0 = [nn.GELU(), Deconvolution(head_dim, head_dim, 8)]
+        for i in range(n_layer - 1):
+            deconvolution_0 += [nn.GELU(), Convolution(head_dim, head_dim, 1)]
+        self.deconvolution_0 = nn.Sequential(*deconvolution_0)
+
+        convolution_1 = []
+        for i in range(n_layer):
+            convolution_1 += [nn.GELU(), BlockConvolution(head_dim, head_dim, conv_size)]
+        self.convolution_1 = nn.Sequential(*convolution_1)
+
+        convolution_0 = [BlockConvolution(head_dim, head_dim, 8)]
+        for i in range(n_layer - 1):
+            convolution_0 += [nn.GELU(), BlockConvolution(head_dim, head_dim, 8)]
+        self.convolution_0 = nn.Sequential(*convolution_0)
 
-        self.convolution_1 = BlockConvolution(embed_dim, embed_dim, conv_size)
-        self.convolution_0 = BlockConvolution(embed_dim, embed_dim, conv_size)
+        linear = []
+        for i in range(n_layer - 1):
+            linear += [nn.GELU(), nn.Linear(head_dim, head_dim)]
+        linear += [nn.GELU(), Linear(head_dim, num_vocab)]
+        self.linear = nn.Sequential(*linear)
 
-        self.linear = Linear(embed_dim, num_vocab)
         self.spatial_encoding = spatial_encoding
-        self.value_embedding = nn.Embedding(num_vocab + 1, embed_dim, padding_idx=0)
+        self.value_embedding = nn.Embedding(num_vocab + 1, head_dim, padding_idx=0)
 
     def forward(self, x, value, depth, pos):
         """ Transforms the output of the transformer target value logits.
@@ -146,12 +182,12 @@ class SubstitutionHeadAutoregressive(nn.Module):
             emb_0 = emb_0 + self.spatial_encoding(pos[:, -len_0:])
         emb_0 = self.convolution_0(emb_0)
 
-        emb_1 = torch.zeros((batch_size, torch.max(len_1), self.embed_dim), dtype=torch.float, device=value.device)
+        emb_1 = torch.zeros((batch_size, torch.max(len_1), self.head_dim), dtype=torch.float, device=value.device)
         # substitite all mixed token embeddings of penultimate layer, with token embeddings of last layer
         emb_1[val_1 == 2] = emb_0[:, (self.conv_size - 1)::self.conv_size]  # [N, T1, C]
         emb_1 = self.convolution_1(emb_1)
 
-        x_0 = torch.zeros((batch_size, torch.max(mix_1), self.embed_dim), device=value.device)
+        x_0 = torch.zeros((batch_size, torch.max(mix_1), self.head_dim), device=value.device)
 
         # deconvolute the latent space - sequence length equals number of tokens in the penultimate layer
         y_1 = self.deconvolution_1(x)
diff --git a/modules/shape_transformer.py b/modules/shape_transformer.py
index a3a60e91bbce6adeb416605d09ccbffd3c5defc6..a5c096fad061b2e6062b8b0464125d5e4772290d 100644
--- a/modules/shape_transformer.py
+++ b/modules/shape_transformer.py
@@ -46,6 +46,8 @@ class ShapeTransformer(pl.LightningModule):
     def __init__(
         self,
         embed_dim=16,
+        head_dim=16,
+        n_layer_head=1,
         num_heads=2,
         num_layers=8,
         num_positions=512,
@@ -88,8 +90,9 @@ class ShapeTransformer(pl.LightningModule):
             positional_encoding=head_pos_encoding,
             num_vocab=num_vocab,
             embed_dim=embed_dim,
+            head_dim=head_dim,
+            n_layer=n_layer_head,
             resolution=resolution,
-            spatial_dim=spatial_dim,
         )
 
         # transformer model
diff --git a/modules/token_embedding/embedding_factory.py b/modules/token_embedding/embedding_factory.py
index 2c5e4baeb5455a90ef3be49859f1e2d7f07a6932..dc6ff55e59d48bdff73851b49ff709e8c1462c73 100644
--- a/modules/token_embedding/embedding_factory.py
+++ b/modules/token_embedding/embedding_factory.py
@@ -30,11 +30,11 @@ def _create_embedding(name, positional_encoding, num_vocab, embed_dim, resolutio
     """
 
     if positional_encoding == 'basic':
-        spatial_encoding = PositionalEncodingLearned(embed_dim, resolution, spatial_dim)
+        spatial_encoding = PositionalEncodingLearned(embed_dim, resolution)
     elif positional_encoding == 'look_ahead':
-        spatial_encoding = PositionalEncodingLearnedLookAhead(embed_dim, resolution, spatial_dim)
+        spatial_encoding = PositionalEncodingLearnedLookAhead(embed_dim, resolution)
     elif positional_encoding == 'look_ahead_split':
-        spatial_encoding = PositionalEncodingLearnedLookAheadSplit(embed_dim, resolution, spatial_dim)
+        spatial_encoding = PositionalEncodingLearnedLookAheadSplit(embed_dim, resolution)
     else:
         raise ValueError(f"ERROR: {positional_encoding} encoding not implemented.")
 
diff --git a/modules/utils/embedding.py b/modules/utils/embedding.py
index bb0cdd8fefd5d7d40cea12ccbceee1310f03339b..88c8e14b47a453bf116619e0a438659df78d375f 100644
--- a/modules/utils/embedding.py
+++ b/modules/utils/embedding.py
@@ -3,7 +3,7 @@ import torch.nn as nn
 
 
 class PositionalEncodingLearned(nn.Module):
-    def __init__(self, embed_dim, resolution, spatial_dim):
+    def __init__(self, embed_dim, resolution):
         """ Performs an embedding of token sequences into an embedding space of higher dimension.
 
         Note: The token value '0' is reserved as a padding value, which does not propagate gradients.
@@ -11,13 +11,12 @@ class PositionalEncodingLearned(nn.Module):
         Args:
             embed_dim: Dimension of returned embedding space.
             resolution: Spatial resolution of sequence encoding.
-            spatial_dim: Spatial dimension (2D, 3D, ...) of sequence encoding.
         """
         super(PositionalEncodingLearned, self).__init__()
 
         self.embed_dim = embed_dim
         self.spatial_embeddings = nn.ModuleList(
-            [nn.Embedding(2 * resolution, embed_dim, padding_idx=0) for _ in range(spatial_dim)]
+            [nn.Embedding(2 * resolution, embed_dim, padding_idx=0) for _ in range(3)]
         )
 
     def forward(self, position):
@@ -36,7 +35,7 @@ class PositionalEncodingLearned(nn.Module):
 
 
 class PositionalEncodingLearnedLookAhead(nn.Module):
-    def __init__(self, embed_dim, resolution, spatial_dim):
+    def __init__(self, embed_dim, resolution):
         """ Performs an embedding of token sequences into an embedding space of higher dimension.
 
         Note: The token value '0' is reserved as a padding value, which does not propagate gradients.
@@ -44,13 +43,12 @@ class PositionalEncodingLearnedLookAhead(nn.Module):
         Args:
             embed_dim: Dimension of returned embedding space.
             resolution: Spatial resolution of sequence encoding.
-            spatial_dim: Spatial dimension (2D, 3D, ...) of sequence encoding.
         """
         super(PositionalEncodingLearnedLookAhead, self).__init__()
 
         self.embed_dim = embed_dim
         self.spatial_embeddings = nn.ModuleList(
-            [nn.Embedding(2 * resolution, embed_dim, padding_idx=0) for _ in range(spatial_dim)]
+            [nn.Embedding(2 * resolution, embed_dim, padding_idx=0) for _ in range(3)]
         )
         # end of sequence positional token
         self.eos = torch.nn.Parameter(torch.zeros(embed_dim))
@@ -82,7 +80,7 @@ class PositionalEncodingLearnedLookAhead(nn.Module):
 
 
 class PositionalEncodingLearnedLookAheadSplit(nn.Module):
-    def __init__(self, embed_dim, resolution, spatial_dim):
+    def __init__(self, embed_dim, resolution):
         """ Performs an embedding of token sequences into an embedding space of higher dimension.
 
         Note: The token value '0' is reserved as a padding value, which does not propagate gradients.
@@ -90,16 +88,15 @@ class PositionalEncodingLearnedLookAheadSplit(nn.Module):
         Args:
             embed_dim: Dimension of returned embedding space.
             resolution: Spatial resolution of sequence encoding.
-            spatial_dim: Spatial dimension (2D, 3D, ...) of sequence encoding.
         """
         super(PositionalEncodingLearnedLookAheadSplit, self).__init__()
 
         self.embed_dim = embed_dim
         self.spatial_embeddings = nn.ModuleList(
-            [nn.Embedding(2 * resolution, embed_dim, padding_idx=0) for _ in range(spatial_dim)]
+            [nn.Embedding(2 * resolution, embed_dim, padding_idx=0) for _ in range(3)]
         )
         self.spatial_embeddings_look_ahead = nn.ModuleList(
-            [nn.Embedding(2 * resolution, embed_dim, padding_idx=0) for _ in range(spatial_dim)]
+            [nn.Embedding(2 * resolution, embed_dim, padding_idx=0) for _ in range(3)]
         )
         # end of sequence positional token
         self.eos = torch.nn.Parameter(torch.zeros(embed_dim))
diff --git a/modules/utils/linear.py b/modules/utils/linear.py
index 6bb4c2e4ab4b4c1805fc465b99858071ba541318..d797b2374e3aedddc255a568b10bf3184415225f 100644
--- a/modules/utils/linear.py
+++ b/modules/utils/linear.py
@@ -9,4 +9,4 @@ class Linear(nn.Linear):
             embed_dim: Dimension of returned embedding space.
             num_vocab: Number of different token values (exclusive padding token '0').
         """
-        super(Linear, self).__init__(embed_dim, num_vocab + 1, bias=False)
+        super(Linear, self).__init__(embed_dim, num_vocab + 1, bias=True)