diff --git a/modules/generative_head/composite_head_A.py b/modules/generative_head/composite_head_A.py
index 5d09c79b9b90c1dd980c2f32b9f31ae60e40e861..072c6ada1cec5eebb48a7a284aea3a8738dbd3de 100644
--- a/modules/generative_head/composite_head_A.py
+++ b/modules/generative_head/composite_head_A.py
@@ -9,7 +9,7 @@ from .substitution_head import SubstitutionHead, SubstitutionHeadAutoregressive


 class CompositeHeadA(nn.Module):
-    def __init__(self, spatial_encoding, num_vocab, embed_dim, resolution, spatial_dim, **_):
+    def __init__(self, spatial_encoding, num_vocab, embed_dim, head_dim, n_layer, resolution, **_):
         """ Performs a transformation from transformer latent space into target value logits.

         Uses a different heads for each depth layer, possibly increasing the overall sequence lenght.
@@ -17,9 +17,10 @@ class CompositeHeadA(nn.Module):

         Args:
             num_vocab: Number of different target token values (exclusive padding token '0').
-            embded_dim: Dimension of the latent embedding space of the transformer.
+            embed_dim: Dimension of the latent embedding space of the transformer.
+            head_dim: Size of embedding dimensions used in the head layers.
+            n_layer: Number of layers used in each linear or convolution block.
             resolution: Spatial resolution of sequence encoding.
-            spatial_dim: Spatial dimension (2D/3D) of the sequence data.
         """
         super(CompositeHeadA, self).__init__()
@@ -27,7 +28,8 @@
             "spatial_encoding": spatial_encoding,
             "num_vocab": num_vocab,
             "embed_dim": embed_dim,
-            "spatial_dim": spatial_dim,
+            "head_dim": head_dim,
+            "n_layer": n_layer
         }

         modules = []
@@ -38,15 +40,15 @@
         if resolution >= 8:
             modules += [LinearHead(**kwargs)]
         if resolution >= 16:
-            modules += [ConvolutionHeadA(**kwargs, conv_size=2 ** (spatial_dim - 1))]
+            modules += [ConvolutionHeadA(**kwargs, conv_size=4)]
         if resolution >= 32:
-            modules += [ConvolutionHeadA(**kwargs, conv_size=2 ** spatial_dim)]
+            modules += [ConvolutionHeadA(**kwargs, conv_size=8)]
         if resolution >= 64:
-            modules += [SubstitutionHead(**kwargs, conv_size=2 ** spatial_dim)]
+            modules += [SubstitutionHead(**kwargs, conv_size=8)]
         if resolution >= 128:
-            modules += [DoubleSubstitutionHead(**kwargs, conv_size=2 ** spatial_dim)]
+            modules += [DoubleSubstitutionHead(**kwargs, conv_size=8)]
         if resolution >= 256:
-            modules += [DoubleSubstitutionHead(**kwargs, conv_size=2 ** spatial_dim)]
+            modules += [DoubleSubstitutionHead(**kwargs, conv_size=8)]

         # embeddings
         self.heads = nn.ModuleList(modules)
@@ -55,11 +57,11 @@
             1: 1,
             2: 1,
             3: 1,
-            4: 2 ** (spatial_dim - 1),
-            5: 2 ** spatial_dim,
-            6: 2 ** spatial_dim,  # Note: 'substitution'
-            7: 2 ** spatial_dim,  # Note: 'double_substitution'
-            8: 2 ** spatial_dim,  # Note: 'double_substitution'
+            4: 4,
+            5: 8,
+            6: 8,  # Note: 'substitution'
+            7: 8,  # Note: 'double_substitution'
+            8: 8,  # Note: 'double_substitution'
         }

     def forward(self, x, value, depth, position):
@@ -154,7 +156,7 @@


 class CompositeHeadAutoregressiveA(CompositeHeadA):
-    def __init__(self, spatial_encoding, num_vocab, embed_dim, resolution, spatial_dim, **_):
+    def __init__(self, spatial_encoding, num_vocab, embed_dim, head_dim, n_layer, resolution, **_):
         """ Performs a transformation from transformer latent space into target value logits.

         Uses a different heads for each depth layer, possibly increasing the overall sequence lenght.
@@ -163,17 +165,19 @@ class CompositeHeadAutoregressiveA(CompositeHeadA):

         Args:
             num_vocab: Number of different target token values (exclusive padding token '0').
             embded_dim: Dimension of the latent embedding space of the transformer.
+            head_dim: Size of embedding dimensions used in the head layers.
+            n_layer: Number of layers used in each linear or convolution block.
             resolution: Spatial resolution of sequence encoding.
-            spatial_dim: Spatial dimension (2D/3D) of the sequence data.
         """
-        super(CompositeHeadAutoregressiveA, self).__init__(spatial_encoding, num_vocab, embed_dim, resolution,
-                                                           spatial_dim, **_)
+        super(CompositeHeadAutoregressiveA, self).__init__(spatial_encoding, num_vocab, embed_dim, head_dim, n_layer,
+                                                           resolution, **_)

         kwargs = {
             "spatial_encoding": spatial_encoding,
             "num_vocab": num_vocab,
             "embed_dim": embed_dim,
-            "spatial_dim": spatial_dim,
+            "head_dim": head_dim,
+            "n_layer": n_layer
         }

         modules = []
@@ -184,15 +188,15 @@ class CompositeHeadAutoregressiveA(CompositeHeadA):
         if resolution >= 8:
             modules += [LinearHead(**kwargs)]
         if resolution >= 16:
-            modules += [ConvolutionHeadAutoregressive(**kwargs, conv_size=2 ** (spatial_dim - 1))]
+            modules += [ConvolutionHeadAutoregressive(**kwargs, conv_size=4)]
         if resolution >= 32:
-            modules += [ConvolutionHeadAutoregressive(**kwargs, conv_size=2 ** spatial_dim)]
+            modules += [ConvolutionHeadAutoregressive(**kwargs, conv_size=8)]
         if resolution >= 64:
-            modules += [SubstitutionHeadAutoregressive(**kwargs, conv_size=2 ** spatial_dim)]
+            modules += [SubstitutionHeadAutoregressive(**kwargs, conv_size=8)]
         if resolution >= 128:
-            modules += [DoubleSubstitutionHeadAutoRegressive(**kwargs, conv_size=2 ** spatial_dim)]
+            modules += [DoubleSubstitutionHeadAutoRegressive(**kwargs, conv_size=8)]
         if resolution >= 256:
-            modules += [DoubleSubstitutionHeadAutoRegressive(**kwargs, conv_size=2 ** spatial_dim)]
+            modules += [DoubleSubstitutionHeadAutoRegressive(**kwargs, conv_size=8)]

         # embeddings
         self.heads = nn.ModuleList(modules)
@@ -201,9 +205,9 @@
             1: 1,
             2: 1,
             3: 1,
-            4: 2 ** (spatial_dim - 1),
-            5: 2 ** spatial_dim,
-            6: 2 ** spatial_dim,  # Note: 'substitution'
-            7: 2 ** spatial_dim,  # Note: 'double_substitution'
-            8: 2 ** spatial_dim,  # Note: 'double_substitution'
+            4: 4,
+            5: 8,
+            6: 8,  # Note: 'substitution'
+            7: 8,  # Note: 'double_substitution'
+            8: 8,  # Note: 'double_substitution'
         }
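Reviewer note: throughout this patch the former `2 ** (spatial_dim - 1)` and `2 ** spatial_dim` factors are replaced by the literals 4 and 8, i.e. the heads now assume octree (3D) sequences only. A minimal sanity sketch of that equivalence (the helper below is illustrative and not part of the patch):

    # assumption: spatial_dim is fixed to 3 (octree data) after this change
    def legacy_factors(spatial_dim: int):
        return 2 ** (spatial_dim - 1), 2 ** spatial_dim

    assert legacy_factors(3) == (4, 8)  # matches the hard-coded conv_size / reduction factors above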
diff --git a/modules/generative_head/composite_head_B.py b/modules/generative_head/composite_head_B.py
index b8eec0d472124fba321bb1697bfbdefc08dcf340..e980929ead74c2fc46e4ffd50be7164cb5c0ba2f 100644
--- a/modules/generative_head/composite_head_B.py
+++ b/modules/generative_head/composite_head_B.py
@@ -8,7 +8,7 @@ from .double_substitution_head import DoubleSubstitutionHead


 class CompositeHeadB(CompositeHeadA):
-    def __init__(self, spatial_encoding, num_vocab, embed_dim, resolution, spatial_dim, **_):
+    def __init__(self, spatial_encoding, num_vocab, embed_dim, head_dim, n_layer, resolution, **_):
         """ Performs a transformation from transformer latent space into target value logits.

         Uses a different heads for each depth layer, possibly increasing the overall sequence lenght.
@@ -16,17 +16,19 @@ class CompositeHeadB(CompositeHeadA):

         Args:
             num_vocab: Number of different target token values (exclusive padding token '0').
-            embded_dim: Dimension of the latent embedding space of the transformer.
+            embed_dim: Dimension of the latent embedding space of the transformer.
+            head_dim: Size of embedding dimensions used in the head layers.
+            n_layer: Number of layers used in each linear or convolution block.
             resolution: Spatial resolution of sequence encoding.
-            spatial_dim: Spatial dimension (2D/3D) of the sequence data.
         """
-        super(CompositeHeadB, self).__init__(num_vocab, embed_dim, resolution, spatial_dim)
+        super(CompositeHeadB, self).__init__(spatial_encoding, num_vocab, embed_dim, head_dim, n_layer, resolution)

         kwargs = {
             "spatial_encoding": spatial_encoding,
             "num_vocab": num_vocab,
             "embed_dim": embed_dim,
-            "spatial_dim": spatial_dim,
+            "head_dim": head_dim,
+            "n_layer": n_layer
         }

         modules = []
@@ -37,13 +39,13 @@
         if resolution >= 8:
             modules += [LinearHead(**kwargs)]
         if resolution >= 16:
-            modules += [ConvolutionHeadA(**kwargs, conv_size=2**(spatial_dim - 2))]
+            modules += [ConvolutionHeadA(**kwargs, conv_size=4)]
         if resolution >= 32:
-            modules += [ConvolutionHeadA(**kwargs, conv_size=2**spatial_dim)]
+            modules += [ConvolutionHeadA(**kwargs, conv_size=8)]
         if resolution >= 64:
-            modules += [SubstitutionHead(**kwargs, conv_size=2**spatial_dim)]
+            modules += [SubstitutionHead(**kwargs, conv_size=8)]
         if resolution >= 128:
-            modules += [DoubleSubstitutionHead(**kwargs, conv_size=2**spatial_dim)]
+            modules += [DoubleSubstitutionHead(**kwargs, conv_size=8)]

         # embeddings
         self.heads = nn.ModuleList(modules)
@@ -52,8 +54,8 @@
             1: 1,
             2: 1,
             3: 1,
-            4: 2**(spatial_dim - 2),
-            5: 2**spatial_dim,
-            6: 2**spatial_dim,  # Note: 'substitution'
-            7: 2**spatial_dim,  # Note: 'double_substitution'
+            4: 4,
+            5: 8,
+            6: 8,  # Note: 'substitution'
+            7: 8,  # Note: 'double_substitution'
         }
diff --git a/modules/generative_head/composite_head_C.py b/modules/generative_head/composite_head_C.py
index 601be52d407c1871e6df68510ff2a73180f117e0..65effd645cb55984790df24e5719656f7889e479 100644
--- a/modules/generative_head/composite_head_C.py
+++ b/modules/generative_head/composite_head_C.py
@@ -9,7 +9,7 @@ from .double_substitution_head import DoubleSubstitutionHead


 class CompositeHeadC(nn.Module):
-    def __init__(self, spatial_encoding, num_vocab, embed_dim, resolution, spatial_dim, **_):
+    def __init__(self, spatial_encoding, num_vocab, embed_dim, head_dim, n_layer, resolution, **_):
         """ Performs a transformation from transformer latent space into target value logits.

         Uses a different heads for each depth layer, possibly increasing the overall sequence lenght.
@@ -17,9 +17,10 @@ class CompositeHeadC(nn.Module):

         Args:
             num_vocab: Number of different target token values (exclusive padding token '0').
-            embded_dim: Dimension of the latent embedding space of the transformer.
+            embed_dim: Dimension of the latent embedding space of the transformer.
+            head_dim: Size of embedding dimensions used in the head layers.
+            n_layer: Number of layers used in each linear or convolution block.
             resolution: Spatial resolution of sequence encoding.
-            spatial_dim: Spatial dimension (2D/3D) of the sequence data.
         """
         super(CompositeHeadC, self).__init__()
@@ -27,7 +28,8 @@
             "spatial_encoding": spatial_encoding,
             "num_vocab": num_vocab,
             "embed_dim": embed_dim,
-            "spatial_dim": spatial_dim,
+            "head_dim": head_dim,
+            "n_layer": n_layer
         }

         modules = []
@@ -36,13 +38,13 @@
         if resolution >= 4:
             modules += [LinearHead(**kwargs)]
         if resolution >= 8:
-            modules += [ConvolutionHeadA(**kwargs, conv_size=2**(spatial_dim - 1))]
+            modules += [ConvolutionHeadA(**kwargs, conv_size=4)]
         if resolution >= 16:
-            modules += [ConvolutionHeadA(**kwargs, conv_size=2**spatial_dim)]
+            modules += [ConvolutionHeadA(**kwargs, conv_size=8)]
         if resolution >= 32:
-            modules += [SubstitutionHead(**kwargs, conv_size=2**spatial_dim)]
+            modules += [SubstitutionHead(**kwargs, conv_size=8)]
         if resolution >= 64:
-            modules += [DoubleSubstitutionHead(**kwargs, conv_size=2**spatial_dim)]
+            modules += [DoubleSubstitutionHead(**kwargs, conv_size=8)]

         # embeddings
         self.heads = nn.ModuleList(modules)
@@ -50,10 +52,10 @@
         self.reduction_factor = {
             1: 1,
             2: 1,
-            3: 2**(spatial_dim - 1),
-            4: 2**spatial_dim,
-            5: 2**spatial_dim,  # Note: 'substitution'
-            6: 2**spatial_dim,  # Note: 'double_substitution'
+            3: 4,
+            4: 8,
+            5: 8,  # Note: 'substitution'
+            6: 8,  # Note: 'double_substitution'
         }

     def forward(self, x, value, depth, position):
diff --git a/modules/generative_head/convolution_head_A.py b/modules/generative_head/convolution_head_A.py
index c090fb4b8daa80979fa2ce2b0ae98a36e49078c8..140321497fb066e530d2a4fab1ccb9fb08297a54 100644
--- a/modules/generative_head/convolution_head_A.py
+++ b/modules/generative_head/convolution_head_A.py
@@ -1,10 +1,10 @@
 import torch.nn as nn

-from ..utils import Deconvolution, BlockConvolution, Linear
+from ..utils import Deconvolution, Convolution, BlockConvolution, Linear


 class ConvolutionHeadA(nn.Module):
-    def __init__(self, spatial_encoding, num_vocab, embed_dim, spatial_dim, conv_size, **_):
+    def __init__(self, spatial_encoding, num_vocab, embed_dim, head_dim, n_layer, conv_size, **_):
         """ Performs a convolutional transformation from transformer latent space into target value logits.

         Note: The token value '0' is reserved as a padding value, which does not propagate gradients.
@@ -12,13 +12,24 @@ class ConvolutionHeadA(nn.Module):
         Args:
             num_vocab: Number of different target token values (exclusive padding token '0').
             embded_dim: Dimension of the latent embedding space of the transformer.
+            head_dim: Size of embedding dimensions used in the head layers.
+            n_layer: Number of layers used in each linear or convolution block.
             spatial_dim: Spatial dimension (2D/3D) of the sequence data.
             conv_size: Convolution kernel size and stride.
         """
         super(ConvolutionHeadA, self).__init__()

-        self.deconvolution = Deconvolution(embed_dim, embed_dim, conv_size)
-        self.linear = Linear(embed_dim, num_vocab)
+        deconvolution = [nn.GELU(), Deconvolution(embed_dim, head_dim, conv_size)]
+        for i in range(n_layer - 1):
+            deconvolution += [nn.GELU(), Convolution(head_dim, head_dim, (1,))]
+        self.deconvolution = nn.Sequential(*deconvolution)
+
+        linear = []
+        for i in range(n_layer - 1):
+            linear += [nn.GELU(), nn.Linear(head_dim, head_dim)]
+        linear += [nn.GELU(), Linear(head_dim, num_vocab)]
+        self.linear = nn.Sequential(*linear)
+
         self.spatial_encoding = spatial_encoding

     def forward(self, x, value, depth, pos):
@@ -45,7 +56,7 @@


 class ConvolutionHeadAutoregressive(nn.Module):
-    def __init__(self, spatial_encoding, num_vocab, embed_dim, spatial_dim, conv_size, **_):
+    def __init__(self, spatial_encoding, num_vocab, embed_dim, head_dim, n_layer, conv_size, **_):
         """ Performs a convolutional transformation from transformer latent space into target value logits.

         Note: The token value '0' is reserved as a padding value, which does not propagate gradients.
@@ -53,17 +64,33 @@ class ConvolutionHeadAutoregressive(nn.Module):
         Args:
             num_vocab: Number of different target token values (exclusive padding token '0').
             embded_dim: Dimension of the latent embedding space of the transformer.
+            head_dim: Size of embedding dimensions used in the head layers.
+            n_layer: Number of layers used in each linear or convolution block.
             spatial_dim: Spatial dimension (2D/3D) of the sequence data.
             conv_size: Convolution kernel size and stride.
         """
         super(ConvolutionHeadAutoregressive, self).__init__()
         self.conv_size = conv_size
-        self.deconvolution = Deconvolution(embed_dim, embed_dim, conv_size)
-        self.convolution = BlockConvolution(embed_dim, embed_dim, conv_size)
-        self.linear = Linear(embed_dim, num_vocab)
+
+        deconvolution = [nn.GELU(), Deconvolution(embed_dim, head_dim, conv_size)]
+        for i in range(n_layer - 1):
+            deconvolution += [nn.GELU(), Convolution(head_dim, head_dim, (1,))]
+        self.deconvolution = nn.Sequential(*deconvolution)
+
+        convolution = [BlockConvolution(head_dim, head_dim, conv_size)]
+        for i in range(n_layer - 1):
+            convolution += [nn.GELU(), BlockConvolution(head_dim, head_dim, conv_size)]
+        self.convolution = nn.Sequential(*convolution)
+
+        linear = []
+        for i in range(n_layer - 1):
+            linear += [nn.GELU(), nn.Linear(head_dim, head_dim)]
+        linear += [nn.GELU(), Linear(head_dim, num_vocab)]
+        self.linear = nn.Sequential(*linear)
+
         self.spatial_encoding = spatial_encoding
-        self.value_embedding = nn.Embedding(num_vocab + 1, embed_dim, padding_idx=0)
+        self.value_embedding = nn.Embedding(num_vocab + 1, head_dim, padding_idx=0)

     def forward(self, x, value, depth, pos):
         """ Transforms the output of the transformer target value logits.
diff --git a/modules/generative_head/double_substitution_head.py b/modules/generative_head/double_substitution_head.py
index e9064e0ee719f56439f6b9c14dacca3e09d9b7f4..2807f616f2f4c9ed809a878b86b487841a999ef2 100644
--- a/modules/generative_head/double_substitution_head.py
+++ b/modules/generative_head/double_substitution_head.py
@@ -1,11 +1,11 @@
 import torch
 import torch.nn as nn

-from ..utils import Deconvolution, BlockConvolution, Linear
+from ..utils import Deconvolution, Convolution, BlockConvolution, Linear


 class DoubleSubstitutionHead(nn.Module):
-    def __init__(self, spatial_encoding, num_vocab, embed_dim, spatial_dim, conv_size, **_):
+    def __init__(self, spatial_encoding, num_vocab, embed_dim, head_dim, n_layer, conv_size, **_):
         """ Performs a twice a substitution transformation from transformer latent space into target value logits.

         Note: The token value '0' is reserved as a padding value, which does not propagate gradients.
@@ -13,20 +13,37 @@ class DoubleSubstitutionHead(nn.Module):
         Args:
             num_vocab: Number of different target token values (exclusive padding token '0').
             embded_dim: Dimension of the latent embedding space of the transformer.
+            head_dim: Size of embedding dimensions used in the head layers.
+            n_layer: Number of layers used in each linear or convolution block.
             spatial_dim: Spatial dimension (2D/3D) of the sequence data.
             conv_size: Convolution kernel size and stride.
         """
         super(DoubleSubstitutionHead, self).__init__()
-        self.embed_dim = embed_dim
+        self.head_dim = head_dim
         # deconvolutions
-        self.deconvolution_2 = Deconvolution(embed_dim, embed_dim, conv_size)
-        self.deconvolution_1 = Deconvolution(embed_dim, embed_dim, conv_size)
-        self.deconvolution_0 = Deconvolution(embed_dim, embed_dim, conv_size)
-        self.spatial_encoding = spatial_encoding
+        deconvolution_2 = [nn.GELU(), Deconvolution(embed_dim, head_dim, conv_size)]
+        for i in range(n_layer - 1):
+            deconvolution_2 += [nn.GELU(), Convolution(head_dim, head_dim, 1)]
+        self.deconvolution_2 = nn.Sequential(*deconvolution_2)
+
+        deconvolution_1 = [nn.GELU(), Deconvolution(head_dim, head_dim, 8)]
+        for i in range(n_layer - 1):
+            deconvolution_1 += [nn.GELU(), Convolution(head_dim, head_dim, 1)]
+        self.deconvolution_1 = nn.Sequential(*deconvolution_1)
+
+        deconvolution_0 = [nn.GELU(), Deconvolution(head_dim, head_dim, 8)]
+        for i in range(n_layer - 1):
+            deconvolution_0 += [nn.GELU(), Convolution(head_dim, head_dim, 1)]
+        self.deconvolution_0 = nn.Sequential(*deconvolution_0)
+
+        linear = []
+        for i in range(n_layer - 1):
+            linear += [nn.GELU(), nn.Linear(head_dim, head_dim)]
+        linear += [nn.GELU(), Linear(head_dim, num_vocab)]
+        self.linear = nn.Sequential(*linear)

-        # head
-        self.linear = Linear(embed_dim, num_vocab)
+        self.spatial_encoding = spatial_encoding

     def forward(self, x, value, depth, pos):
         """ Transforms the output of the transformer target value logits.
@@ -64,8 +81,8 @@ class DoubleSubstitutionHead(nn.Module):
         mix_2 = torch.sum(val_2 == 2, dim=1)

         # create intermediate list to hold vectors
-        x_0 = torch.zeros((batch_size, torch.max(mix_1), self.embed_dim), device=value.device)
-        x_1 = torch.zeros((batch_size, torch.max(mix_2), self.embed_dim), device=value.device)
+        x_0 = torch.zeros((batch_size, torch.max(mix_1), self.head_dim), device=value.device)
+        x_1 = torch.zeros((batch_size, torch.max(mix_2), self.head_dim), device=value.device)

         # deconvolute the latent space - sequence length equals number of tokens in the penultimate layer
         y_2 = self.deconvolution_2(x)
@@ -95,7 +112,7 @@ class DoubleSubstitutionHead(nn.Module):


 class DoubleSubstitutionHeadAutoRegressive(nn.Module):
-    def __init__(self, spatial_encoding, num_vocab, embed_dim, spatial_dim, conv_size, **_):
+    def __init__(self, spatial_encoding, num_vocab, embed_dim, head_dim, n_layer, conv_size, **_):
         """ Performs a twice a substitution transformation from transformer latent space into target value logits.

         Note: The token value '0' is reserved as a padding value, which does not propagate gradients.
@@ -103,26 +120,53 @@ class DoubleSubstitutionHeadAutoRegressive(nn.Module):
         Args:
             num_vocab: Number of different target token values (exclusive padding token '0').
             embded_dim: Dimension of the latent embedding space of the transformer.
+            head_dim: Size of embedding dimensions used in the head layers.
+            n_layer: Number of layers used in each linear or convolution block.
             spatial_dim: Spatial dimension (2D/3D) of the sequence data.
             conv_size: Convolution kernel size and stride.
         """
         super(DoubleSubstitutionHeadAutoRegressive, self).__init__()
-        self.embed_dim = embed_dim
+        self.head_dim = head_dim
         self.conv_size = conv_size
-        # deconvolutions
-        self.deconvolution_2 = Deconvolution(embed_dim, embed_dim, conv_size)
-        self.deconvolution_1 = Deconvolution(embed_dim, embed_dim, conv_size)
-        self.deconvolution_0 = Deconvolution(embed_dim, embed_dim, conv_size)
+        deconvolution_2 = [nn.GELU(), Deconvolution(embed_dim, head_dim, conv_size)]
+        for i in range(n_layer - 1):
+            deconvolution_2 += [nn.GELU(), Convolution(head_dim, head_dim, 1)]
+        self.deconvolution_2 = nn.Sequential(*deconvolution_2)
+
+        deconvolution_1 = [nn.GELU(), Deconvolution(head_dim, head_dim, 8)]
+        for i in range(n_layer - 1):
+            deconvolution_1 += [nn.GELU(), Convolution(head_dim, head_dim, 1)]
+        self.deconvolution_1 = nn.Sequential(*deconvolution_1)
+
+        deconvolution_0 = [nn.GELU(), Deconvolution(head_dim, head_dim, 8)]
+        for i in range(n_layer - 1):
+            deconvolution_0 += [nn.GELU(), Convolution(head_dim, head_dim, 1)]
+        self.deconvolution_0 = nn.Sequential(*deconvolution_0)
+
+        convolution_2 = []
+        for i in range(n_layer):
+            convolution_2 += [nn.GELU(), BlockConvolution(head_dim, head_dim, conv_size)]
+        self.convolution_2 = nn.Sequential(*convolution_2)
+
+        convolution_1 = []
+        for i in range(n_layer):
+            convolution_1 += [nn.GELU(), BlockConvolution(head_dim, head_dim, 8)]
+        self.convolution_1 = nn.Sequential(*convolution_1)
+
+        convolution_0 = [BlockConvolution(head_dim, head_dim, 8)]
+        for i in range(n_layer - 1):
+            convolution_0 += [nn.GELU(), BlockConvolution(head_dim, head_dim, 8)]
+        self.convolution_0 = nn.Sequential(*convolution_0)
+
+        linear = []
+        for i in range(n_layer - 1):
+            linear += [nn.GELU(), nn.Linear(head_dim, head_dim)]
+        linear += [nn.GELU(), Linear(head_dim, num_vocab)]
+        self.linear = nn.Sequential(*linear)

-        self.convolution_2 = BlockConvolution(embed_dim, embed_dim, conv_size)
-        self.convolution_1 = BlockConvolution(embed_dim, embed_dim, conv_size)
-        self.convolution_0 = BlockConvolution(embed_dim, embed_dim, conv_size)
-
-        # head
-        self.linear = Linear(embed_dim, num_vocab)
         self.spatial_encoding = spatial_encoding
-        self.value_embedding = nn.Embedding(num_vocab + 1, embed_dim, padding_idx=0)
+        self.value_embedding = nn.Embedding(num_vocab + 1, head_dim, padding_idx=0)

     def forward(self, x, value, depth, pos):
         """ Transforms the output of the transformer target value logits.
@@ -168,19 +212,19 @@ class DoubleSubstitutionHeadAutoRegressive(nn.Module):
             emb_0 = emb_0 + self.spatial_encoding(pos[:, -len_0:])
         emb_0 = self.convolution_0(emb_0)

-        emb_1 = torch.zeros((batch_size, torch.max(len_1), self.embed_dim), dtype=torch.float, device=value.device)
+        emb_1 = torch.zeros((batch_size, torch.max(len_1), self.head_dim), dtype=torch.float, device=value.device)
         # substitute all mixed token embeddings of penultimate layer, with token embeddings of last layer
         emb_1[val_1 == 2] = emb_0[:, (self.conv_size - 1)::self.conv_size]  # [N, T1, C]
         emb_1 = self.convolution_1(emb_1)

-        emb_2 = torch.zeros((batch_size, torch.max(len_2), self.embed_dim), dtype=torch.float, device=value.device)
+        emb_2 = torch.zeros((batch_size, torch.max(len_2), self.head_dim), dtype=torch.float, device=value.device)
         # substitute all mixed token embeddings of third to last layer, with token embeddings of penultimate layer
         emb_2[val_2 == 2] = emb_1[:, (self.conv_size - 1)::self.conv_size]  # [N, T1, C]
         emb_2 = self.convolution_2(emb_2)

         # create intermediate list to hold vectors
-        x_0 = torch.zeros((batch_size, torch.max(mix_1), self.embed_dim), device=value.device)
-        x_1 = torch.zeros((batch_size, torch.max(mix_2), self.embed_dim), device=value.device)
+        x_0 = torch.zeros((batch_size, torch.max(mix_1), self.head_dim), device=value.device)
+        x_1 = torch.zeros((batch_size, torch.max(mix_2), self.head_dim), device=value.device)

         # deconvolute the latent space - sequence length equals number of tokens in the penultimate layer
         y_2 = self.deconvolution_2(x)
@@ -205,7 +249,7 @@ class DoubleSubstitutionHeadAutoRegressive(nn.Module):
         # add spatial decoding if available
         if self.spatial_encoding is not None:
             len_last = torch.sum(depth == max_depth, dim=1)
-            assert((depth[:, -len_last:] == max_depth).all())
+            assert ((depth[:, -len_last:] == max_depth).all())
             y_0 = y_0 + self.spatial_encoding(pos[:, -len_last:])

         # compute logits of generated tokens
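Worth calling out: the inner deconvolution stages of the substitution heads now expand by a fixed factor of 8 (one octree node into its eight children), while only the outermost stage still uses the configurable conv_size. A hypothetical helper for reasoning about the resulting token counts (the name and the assumption that every stage expands exactly by its stride are mine, not the patch's):

    def substitution_lengths(t_2: int, mixed_2: int, mixed_1: int, conv_size: int = 8):
        # sequence lengths after each deconvolution stage of DoubleSubstitutionHead (assumed)
        len_2 = conv_size * t_2   # deconvolution_2: one latent vector per third-to-last-layer token
        len_1 = 8 * mixed_2       # deconvolution_1: only mixed ('2') tokens are expanded further
        len_0 = 8 * mixed_1       # deconvolution_0: final layer, fed into the linear logit stack
        return len_2, len_1, len_0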
diff --git a/modules/generative_head/head_factory.py b/modules/generative_head/head_factory.py
index 3d48aa4f490b070b6a108f0716dae3a7f8a89358..52bbd1422e2bfaaabef35ddba97b3075336cad3c 100644
--- a/modules/generative_head/head_factory.py
+++ b/modules/generative_head/head_factory.py
@@ -12,7 +12,7 @@ from .multi_conv_head_A import MultiConvolutionHeadA
 from .substitution_head import SubstitutionHead


-def _create_head(name, positional_encoding, num_vocab, embed_dim, resolution, spatial_dim):
+def _create_head(name, positional_encoding, num_vocab, embed_dim, head_dim, n_layer, resolution):
     """ Creates a generative head.

     If the module specified in `name` does not exist raises a value error.
@@ -22,8 +22,9 @@ def _create_head(name, positional_encoding, num_vocab, embed_dim, resolution, sp
         name: Defines which generative head will be created.
         num_vocab: Number of different vocabs in the vocabulary set.
         embed_dim: Size of embedding dimensions used by the transformer model.
+        head_dim: Size of embedding dimensions used in the head layers.
+        n_layer: Number of layers used in each linear or convolution block.
         resolution: Spatial resolution of sequence encoding.
-        spatial_dim: Spatial dimensionality of input data.

     Return:
         Generative head initialised with specified parameters.
@@ -32,11 +33,11 @@ def _create_head(name, positional_encoding, num_vocab, embed_dim, resolution, sp
     if positional_encoding == 'None':
         spatial_encoding = None
     elif positional_encoding == 'basic':
-        spatial_encoding = PositionalEncodingLearned(embed_dim, resolution, spatial_dim)
+        spatial_encoding = PositionalEncodingLearned(head_dim, resolution)
     elif positional_encoding == 'look_ahead':
-        spatial_encoding = PositionalEncodingLearnedLookAhead(embed_dim, resolution, spatial_dim)
+        spatial_encoding = PositionalEncodingLearnedLookAhead(head_dim, resolution)
     elif positional_encoding == 'look_ahead_split':
-        spatial_encoding = PositionalEncodingLearnedLookAheadSplit(embed_dim, resolution, spatial_dim)
+        spatial_encoding = PositionalEncodingLearnedLookAheadSplit(head_dim, resolution)
     else:
         raise ValueError(f"ERROR: {positional_encoding} encoding not implemented.")

@@ -44,18 +45,19 @@ def _create_head(name, positional_encoding, num_vocab, embed_dim, resolution, sp
         "spatial_encoding": spatial_encoding,
         "num_vocab": num_vocab,
         "embed_dim": embed_dim,
+        "head_dim": head_dim,
+        "n_layer": n_layer,
         "resolution": resolution,
-        "spatial_dim": spatial_dim,
-        "conv_size": 2 ** spatial_dim,
+        "conv_size": 8,
     }

     if name in ('generative_basic', 'linear', 'basic'):
         return LinearHead(**kwargs)
     elif name == 'discrete_transformation':
-        kwargs["num_vocab"] = num_vocab ** 2 ** spatial_dim
+        kwargs["num_vocab"] = num_vocab ** 2 ** 3
         return LinearHead(**kwargs)
     elif name in ('half_conv', 'half_conv_A'):
-        kwargs["conv_size"] = 2 ** (spatial_dim - 1)
+        kwargs["conv_size"] = 2 ** (3 - 1)
         return ConvolutionHeadA(**kwargs)
     elif name in ('single_conv', 'single_conv_A'):
         return ConvolutionHeadA(**kwargs)
@@ -77,7 +79,7 @@ def _create_head(name, positional_encoding, num_vocab, embed_dim, resolution, sp
     raise ValueError(f"ERROR: {name} head not implemented.")


-def create_head(name, positional_encoding, num_vocab, embed_dim, resolution, spatial_dim):
+def create_head(name, positional_encoding, num_vocab, embed_dim, head_dim, n_layer, resolution):
     """ Creates a generative head.

     If `name` is a list, creates a list of heads for each element of the list, otherwise a single one. If the module
@@ -89,14 +91,16 @@ def create_head(name, positional_encoding, num_vocab, embed_dim, resolution, spa
         name: Defines which generative head will be created.
         num_vocab: Number of different vocabs in the vocabulary set.
         embed_dim: Size of embedding dimensions used by the transformer model.
+        head_dim: Size of embedding dimensions used in the head layers.
+        n_layer: Number of layers used in each linear or convolution block.
         resolution: Spatial resolution of sequence encoding.
-        spatial_dim: Spatial dimensionality of input data.

     Return:
         Generative head or a list of heads initialised with specified parameters.
     """
     if type(name) == list:
         return nn.ModuleList(
-            [_create_head(n, positional_encoding, num_vocab, embed_dim, resolution, spatial_dim) for n in name])
+            [_create_head(n, positional_encoding, num_vocab, embed_dim, head_dim, n_layer, resolution) for n
+             in name])
     else:
-        return _create_head(name, positional_encoding, num_vocab, embed_dim, resolution, spatial_dim)
+        return _create_head(name, positional_encoding, num_vocab, embed_dim, head_dim, n_layer, resolution)
diff --git a/modules/generative_head/linear_head.py b/modules/generative_head/linear_head.py
index 98589fc8ae7f512a97e6f8e7ec1dc0ef23f12bf7..de6c21457c346b66ffebecd76a510ff276aa0d2f 100644
--- a/modules/generative_head/linear_head.py
+++ b/modules/generative_head/linear_head.py
@@ -4,7 +4,7 @@ from ..utils import Linear


 class LinearHead(nn.Module):
-    def __init__(self, spatial_encoding, num_vocab, embed_dim, **_):
+    def __init__(self, spatial_encoding, num_vocab, embed_dim, head_dim, n_layer, **_):
         """ Performs a linear transformation from transformer latent space into target value logits.

         Note: The token value '0' is reserved as a padding value, which does not propagate gradients.
@@ -12,10 +12,20 @@ class LinearHead(nn.Module):

         Args:
             num_vocab: Number of different target token values (exclusive padding token '0').
             embded_dim: Dimension of the latent embedding space of the transformer.
+            head_dim: Size of embedding dimensions used in the head layers.
+            n_layer: Number of layers used in each linear or convolution block.
         """
         super(LinearHead, self).__init__()
-        self.linear = Linear(embed_dim, num_vocab)
+        if n_layer > 1:
+            linear = [nn.GELU(), nn.Linear(embed_dim, head_dim)]
+            for i in range(n_layer - 2):
+                linear += [nn.GELU(), nn.Linear(head_dim, head_dim)]
+            linear += [nn.GELU(), Linear(head_dim, num_vocab)]
+            self.linear = nn.Sequential(*linear)
+        else:
+            self.linear = Linear(embed_dim, num_vocab)
+
         self.spatial_encoding = spatial_encoding

     def forward(self, x, value, depth, pos):
""" super(SubstitutionHead, self).__init__() - self.embed_dim = embed_dim + self.head_dim = head_dim + + deconvolution_1 = [nn.GELU(), Deconvolution(embed_dim, head_dim, conv_size)] + for i in range(n_layer - 1): + deconvolution_1 += [nn.GELU(), Convolution(head_dim, head_dim, 1)] + self.deconvolution_1 = nn.Sequential(*deconvolution_1) + + deconvolution_0 = [nn.GELU(), Deconvolution(head_dim, head_dim, 8)] + for i in range(n_layer - 1): + deconvolution_0 += [nn.GELU(), Convolution(head_dim, head_dim, 1)] + self.deconvolution_0 = nn.Sequential(*deconvolution_0) + + linear = [] + for i in range(n_layer - 1): + linear += [nn.GELU(), nn.Linear(head_dim, head_dim)] + linear += [nn.GELU(), Linear(head_dim, num_vocab)] + self.linear = nn.Sequential(*linear) - self.deconvolution_1 = Deconvolution(embed_dim, embed_dim, conv_size) - self.deconvolution_0 = Deconvolution(embed_dim, embed_dim, conv_size) - self.linear = Linear(embed_dim, num_vocab) self.spatial_encoding = spatial_encoding def forward(self, x, value, depth, pos): @@ -56,7 +71,7 @@ class SubstitutionHead(nn.Module): mix_1 = torch.sum(val_1 == 2, dim=1) # create intermediate list to hold vectors - x_0 = torch.zeros((batch_size, torch.max(mix_1), self.embed_dim), device=value.device) + x_0 = torch.zeros((batch_size, torch.max(mix_1), self.head_dim), device=value.device) # deconvolute the latent space - sequence length equals number of tokens in the penultimate layer y_1 = self.deconvolution_1(x) @@ -79,7 +94,7 @@ class SubstitutionHead(nn.Module): class SubstitutionHeadAutoregressive(nn.Module): - def __init__(self, spatial_encoding, num_vocab, embed_dim, spatial_dim, conv_size, **_): + def __init__(self, spatial_encoding, num_vocab, embed_dim, head_dim, n_layer, conv_size, **_): """ Performs a substitution transformation from transformer latent space into target value logits. Note: The token value '0' is reserved as a padding value, which does not propagate gradients. @@ -87,22 +102,43 @@ class SubstitutionHeadAutoregressive(nn.Module): Args: num_vocab: Number of different target token values (exclusive padding token '0'). embded_dim: Dimension of the latent embedding space of the transformer. + head_dim: Size of embedding dimensions used in the head layers. + n_layer: Number of layers used in each linear or convolution block. spatial_dim: Spatial dimension (2D/3D) of the sequence data. conv_size: Convolution kernel size and stride. 
""" super(SubstitutionHeadAutoregressive, self).__init__() - self.embed_dim = embed_dim + self.head_dim = head_dim self.conv_size = conv_size - self.deconvolution_1 = Deconvolution(embed_dim, embed_dim, conv_size) - self.deconvolution_0 = Deconvolution(embed_dim, embed_dim, conv_size) + deconvolution_1 = [nn.GELU(), Deconvolution(embed_dim, head_dim, conv_size)] + for i in range(n_layer - 1): + deconvolution_1 += [nn.GELU(), Convolution(head_dim, head_dim, 1)] + self.deconvolution_1 = nn.Sequential(*deconvolution_1) + + deconvolution_0 = [nn.GELU(), Deconvolution(head_dim, head_dim, 8)] + for i in range(n_layer - 1): + deconvolution_0 += [nn.GELU(), Convolution(head_dim, head_dim, 1)] + self.deconvolution_0 = nn.Sequential(*deconvolution_0) + + convolution_1 = [] + for i in range(n_layer): + convolution_1 += [nn.GELU(), BlockConvolution(head_dim, head_dim, conv_size)] + self.convolution_1 = nn.Sequential(*convolution_1) + + convolution_0 = [BlockConvolution(head_dim, head_dim, 8)] + for i in range(n_layer - 1): + convolution_0 += [nn.GELU(), BlockConvolution(head_dim, head_dim, 8)] + self.convolution_0 = nn.Sequential(*convolution_0) - self.convolution_1 = BlockConvolution(embed_dim, embed_dim, conv_size) - self.convolution_0 = BlockConvolution(embed_dim, embed_dim, conv_size) + linear = [] + for i in range(n_layer - 1): + linear += [nn.GELU(), nn.Linear(head_dim, head_dim)] + linear += [nn.GELU(), Linear(head_dim, num_vocab)] + self.linear = nn.Sequential(*linear) - self.linear = Linear(embed_dim, num_vocab) self.spatial_encoding = spatial_encoding - self.value_embedding = nn.Embedding(num_vocab + 1, embed_dim, padding_idx=0) + self.value_embedding = nn.Embedding(num_vocab + 1, head_dim, padding_idx=0) def forward(self, x, value, depth, pos): """ Transforms the output of the transformer target value logits. 
@@ -146,12 +182,12 @@ class SubstitutionHeadAutoregressive(nn.Module): emb_0 = emb_0 + self.spatial_encoding(pos[:, -len_0:]) emb_0 = self.convolution_0(emb_0) - emb_1 = torch.zeros((batch_size, torch.max(len_1), self.embed_dim), dtype=torch.float, device=value.device) + emb_1 = torch.zeros((batch_size, torch.max(len_1), self.head_dim), dtype=torch.float, device=value.device) # substitite all mixed token embeddings of penultimate layer, with token embeddings of last layer emb_1[val_1 == 2] = emb_0[:, (self.conv_size - 1)::self.conv_size] # [N, T1, C] emb_1 = self.convolution_1(emb_1) - x_0 = torch.zeros((batch_size, torch.max(mix_1), self.embed_dim), device=value.device) + x_0 = torch.zeros((batch_size, torch.max(mix_1), self.head_dim), device=value.device) # deconvolute the latent space - sequence length equals number of tokens in the penultimate layer y_1 = self.deconvolution_1(x) diff --git a/modules/shape_transformer.py b/modules/shape_transformer.py index a3a60e91bbce6adeb416605d09ccbffd3c5defc6..a5c096fad061b2e6062b8b0464125d5e4772290d 100644 --- a/modules/shape_transformer.py +++ b/modules/shape_transformer.py @@ -46,6 +46,8 @@ class ShapeTransformer(pl.LightningModule): def __init__( self, embed_dim=16, + head_dim=16, + n_layer_head=1, num_heads=2, num_layers=8, num_positions=512, @@ -88,8 +90,9 @@ class ShapeTransformer(pl.LightningModule): positional_encoding=head_pos_encoding, num_vocab=num_vocab, embed_dim=embed_dim, + head_dim=head_dim, + n_layer=n_layer_head, resolution=resolution, - spatial_dim=spatial_dim, ) # transformer model diff --git a/modules/token_embedding/embedding_factory.py b/modules/token_embedding/embedding_factory.py index 2c5e4baeb5455a90ef3be49859f1e2d7f07a6932..dc6ff55e59d48bdff73851b49ff709e8c1462c73 100644 --- a/modules/token_embedding/embedding_factory.py +++ b/modules/token_embedding/embedding_factory.py @@ -30,11 +30,11 @@ def _create_embedding(name, positional_encoding, num_vocab, embed_dim, resolutio """ if positional_encoding == 'basic': - spatial_encoding = PositionalEncodingLearned(embed_dim, resolution, spatial_dim) + spatial_encoding = PositionalEncodingLearned(embed_dim, resolution) elif positional_encoding == 'look_ahead': - spatial_encoding = PositionalEncodingLearnedLookAhead(embed_dim, resolution, spatial_dim) + spatial_encoding = PositionalEncodingLearnedLookAhead(embed_dim, resolution) elif positional_encoding == 'look_ahead_split': - spatial_encoding = PositionalEncodingLearnedLookAheadSplit(embed_dim, resolution, spatial_dim) + spatial_encoding = PositionalEncodingLearnedLookAheadSplit(embed_dim, resolution) else: raise ValueError(f"ERROR: {positional_encoding} encoding not implemented.") diff --git a/modules/utils/embedding.py b/modules/utils/embedding.py index bb0cdd8fefd5d7d40cea12ccbceee1310f03339b..88c8e14b47a453bf116619e0a438659df78d375f 100644 --- a/modules/utils/embedding.py +++ b/modules/utils/embedding.py @@ -3,7 +3,7 @@ import torch.nn as nn class PositionalEncodingLearned(nn.Module): - def __init__(self, embed_dim, resolution, spatial_dim): + def __init__(self, embed_dim, resolution): """ Performs an embedding of token sequences into an embedding space of higher dimension. Note: The token value '0' is reserved as a padding value, which does not propagate gradients. @@ -11,13 +11,12 @@ class PositionalEncodingLearned(nn.Module): Args: embed_dim: Dimension of returned embedding space. resolution: Spatial resolution of sequence encoding. - spatial_dim: Spatial dimension (2D, 3D, ...) of sequence encoding. 
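With spatial_dim removed from their signatures, the learned positional encodings below always allocate one embedding table per axis of a 3D position. A condensed stand-in that captures the constructor change (the forward shown here is an assumption for illustration, not the project's implementation):

    import torch.nn as nn

    class PositionalEncodingSketch(nn.Module):
        def __init__(self, embed_dim, resolution):
            super().__init__()
            # one table per x/y/z axis, index 0 reserved for padding
            self.spatial_embeddings = nn.ModuleList(
                [nn.Embedding(2 * resolution, embed_dim, padding_idx=0) for _ in range(3)]
            )

        def forward(self, position):  # position: [N, T, 3] integer coordinates
            return sum(emb(position[..., i]) for i, emb in enumerate(self.spatial_embeddings))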
""" super(PositionalEncodingLearned, self).__init__() self.embed_dim = embed_dim self.spatial_embeddings = nn.ModuleList( - [nn.Embedding(2 * resolution, embed_dim, padding_idx=0) for _ in range(spatial_dim)] + [nn.Embedding(2 * resolution, embed_dim, padding_idx=0) for _ in range(3)] ) def forward(self, position): @@ -36,7 +35,7 @@ class PositionalEncodingLearned(nn.Module): class PositionalEncodingLearnedLookAhead(nn.Module): - def __init__(self, embed_dim, resolution, spatial_dim): + def __init__(self, embed_dim, resolution): """ Performs an embedding of token sequences into an embedding space of higher dimension. Note: The token value '0' is reserved as a padding value, which does not propagate gradients. @@ -44,13 +43,12 @@ class PositionalEncodingLearnedLookAhead(nn.Module): Args: embed_dim: Dimension of returned embedding space. resolution: Spatial resolution of sequence encoding. - spatial_dim: Spatial dimension (2D, 3D, ...) of sequence encoding. """ super(PositionalEncodingLearnedLookAhead, self).__init__() self.embed_dim = embed_dim self.spatial_embeddings = nn.ModuleList( - [nn.Embedding(2 * resolution, embed_dim, padding_idx=0) for _ in range(spatial_dim)] + [nn.Embedding(2 * resolution, embed_dim, padding_idx=0) for _ in range(3)] ) # end of sequence positional token self.eos = torch.nn.Parameter(torch.zeros(embed_dim)) @@ -82,7 +80,7 @@ class PositionalEncodingLearnedLookAhead(nn.Module): class PositionalEncodingLearnedLookAheadSplit(nn.Module): - def __init__(self, embed_dim, resolution, spatial_dim): + def __init__(self, embed_dim, resolution): """ Performs an embedding of token sequences into an embedding space of higher dimension. Note: The token value '0' is reserved as a padding value, which does not propagate gradients. @@ -90,16 +88,15 @@ class PositionalEncodingLearnedLookAheadSplit(nn.Module): Args: embed_dim: Dimension of returned embedding space. resolution: Spatial resolution of sequence encoding. - spatial_dim: Spatial dimension (2D, 3D, ...) of sequence encoding. """ super(PositionalEncodingLearnedLookAheadSplit, self).__init__() self.embed_dim = embed_dim self.spatial_embeddings = nn.ModuleList( - [nn.Embedding(2 * resolution, embed_dim, padding_idx=0) for _ in range(spatial_dim)] + [nn.Embedding(2 * resolution, embed_dim, padding_idx=0) for _ in range(3)] ) self.spatial_embeddings_look_ahead = nn.ModuleList( - [nn.Embedding(2 * resolution, embed_dim, padding_idx=0) for _ in range(spatial_dim)] + [nn.Embedding(2 * resolution, embed_dim, padding_idx=0) for _ in range(3)] ) # end of sequence positional token self.eos = torch.nn.Parameter(torch.zeros(embed_dim)) diff --git a/modules/utils/linear.py b/modules/utils/linear.py index 6bb4c2e4ab4b4c1805fc465b99858071ba541318..d797b2374e3aedddc255a568b10bf3184415225f 100644 --- a/modules/utils/linear.py +++ b/modules/utils/linear.py @@ -9,4 +9,4 @@ class Linear(nn.Linear): embed_dim: Dimension of returned embedding space. num_vocab: Number of different token values (exclusive padding token '0'). """ - super(Linear, self).__init__(embed_dim, num_vocab + 1, bias=False) + super(Linear, self).__init__(embed_dim, num_vocab + 1, bias=True)