diff --git a/modules/generative_head/composite_head_A.py b/modules/generative_head/composite_head_A.py
index 072c6ada1cec5eebb48a7a284aea3a8738dbd3de..eab66115ea82517c283b9495da3b212abbabd0d4 100644
--- a/modules/generative_head/composite_head_A.py
+++ b/modules/generative_head/composite_head_A.py
@@ -2,10 +2,10 @@ import torch
 import torch.nn as nn
 from torch.nn.utils.rnn import pad_sequence
 
-from .convolution_head_A import ConvolutionHeadA, ConvolutionHeadAutoregressive
-from .double_substitution_head import DoubleSubstitutionHead, DoubleSubstitutionHeadAutoRegressive
+from .convolution_head import ConvolutionHead
+from .double_substitution_head import DoubleSubstitutionHead
 from .linear_head import LinearHead
-from .substitution_head import SubstitutionHead, SubstitutionHeadAutoregressive
+from .substitution_head import SubstitutionHead
 
 
 class CompositeHeadA(nn.Module):
@@ -40,9 +40,9 @@ class CompositeHeadA(nn.Module):
         if resolution >= 8:
             modules += [LinearHead(**kwargs)]
         if resolution >= 16:
-            modules += [ConvolutionHeadA(**kwargs, conv_size=4)]
+            modules += [ConvolutionHead(**kwargs, conv_size=4)]
         if resolution >= 32:
-            modules += [ConvolutionHeadA(**kwargs, conv_size=8)]
+            modules += [ConvolutionHead(**kwargs, conv_size=8)]
         if resolution >= 64:
             modules += [SubstitutionHead(**kwargs, conv_size=8)]
         if resolution >= 128:
@@ -153,61 +153,3 @@ class CompositeHeadA(nn.Module):
 
         # pad embedding sequence
         return pad_sequence(out, batch_first=True, padding_value=0.0)
-
-
-class CompositeHeadAutoregressiveA(CompositeHeadA):
-    def __init__(self, spatial_encoding, num_vocab, embed_dim, head_dim, n_layer, resolution, **_):
-        """ Performs a transformation from transformer latent space into target value logits.
-
-        Uses a different heads for each depth layer, possibly increasing the overall sequence lenght.
-        Note: The token value '0' is reserved as a padding value, which does not propagate gradients.
-
-        Args:
-            num_vocab: Number of different target token values (exclusive padding token '0').
-            embded_dim: Dimension of the latent embedding space of the transformer.
-            head_dim: Size of embedding dimensions used in the head layers.
-            n_layer: Number of layers used in each linear or convolution block.
-            resolution: Spatial resolution of sequence encoding.
- """ - super(CompositeHeadAutoregressiveA, self).__init__(spatial_encoding, num_vocab, embed_dim, head_dim, n_layer, - resolution, **_) - - kwargs = { - "spatial_encoding": spatial_encoding, - "num_vocab": num_vocab, - "embed_dim": embed_dim, - "head_dim": head_dim, - "n_layer": n_layer - } - - modules = [] - if resolution >= 2: - modules += [LinearHead(**kwargs)] - if resolution >= 4: - modules += [LinearHead(**kwargs)] - if resolution >= 8: - modules += [LinearHead(**kwargs)] - if resolution >= 16: - modules += [ConvolutionHeadAutoregressive(**kwargs, conv_size=4)] - if resolution >= 32: - modules += [ConvolutionHeadAutoregressive(**kwargs, conv_size=8)] - if resolution >= 64: - modules += [SubstitutionHeadAutoregressive(**kwargs, conv_size=8)] - if resolution >= 128: - modules += [DoubleSubstitutionHeadAutoRegressive(**kwargs, conv_size=8)] - if resolution >= 256: - modules += [DoubleSubstitutionHeadAutoRegressive(**kwargs, conv_size=8)] - - # embeddings - self.heads = nn.ModuleList(modules) - - self.reduction_factor = { - 1: 1, - 2: 1, - 3: 1, - 4: 4, - 5: 8, - 6: 8, # Note: 'substitution' - 7: 8, # Note: 'double_substitution' - 8: 8, # Note: 'double_substitution' - } diff --git a/modules/generative_head/composite_head_B.py b/modules/generative_head/composite_head_B.py index e980929ead74c2fc46e4ffd50be7164cb5c0ba2f..354f3c137d9c136eae41c6076c2ee83e4221ff20 100644 --- a/modules/generative_head/composite_head_B.py +++ b/modules/generative_head/composite_head_B.py @@ -1,13 +1,13 @@ +import torch import torch.nn as nn +from torch.nn.utils.rnn import pad_sequence -from .composite_head_A import CompositeHeadA +from .convolution_head import ConvolutionHead from .linear_head import LinearHead -from .convolution_head_A import ConvolutionHeadA from .substitution_head import SubstitutionHead -from .double_substitution_head import DoubleSubstitutionHead -class CompositeHeadB(CompositeHeadA): +class CompositeHeadB(nn.Module): def __init__(self, spatial_encoding, num_vocab, embed_dim, head_dim, n_layer, resolution, **_): """ Performs a transformation from transformer latent space into target value logits. @@ -21,7 +21,7 @@ class CompositeHeadB(CompositeHeadA): n_layer: Number of layers used in each linear or convolution block. resolution: Spatial resolution of sequence encoding. """ - super(CompositeHeadB, self).__init__(spatial_encoding, num_vocab, embed_dim, head_dim, n_layer, resolution) + super(CompositeHeadB, self).__init__() kwargs = { "spatial_encoding": spatial_encoding, @@ -39,13 +39,11 @@ class CompositeHeadB(CompositeHeadA): if resolution >= 8: modules += [LinearHead(**kwargs)] if resolution >= 16: - modules += [ConvolutionHeadA(**kwargs, conv_size=4)] + modules += [LinearHead(**kwargs)] if resolution >= 32: - modules += [ConvolutionHeadA(**kwargs, conv_size=8)] + modules += [ConvolutionHead(**kwargs, conv_size=8)] if resolution >= 64: modules += [SubstitutionHead(**kwargs, conv_size=8)] - if resolution >= 128: - modules += [DoubleSubstitutionHead(**kwargs, conv_size=8)] # embeddings self.heads = nn.ModuleList(modules) @@ -54,8 +52,97 @@ class CompositeHeadB(CompositeHeadA): 1: 1, 2: 1, 3: 1, - 4: 4, + 4: 1, 5: 8, 6: 8, # Note: 'substitution' - 7: 8, # Note: 'double_substitution' } + + def forward(self, x, value, depth, position): + """ Transforms the output of the transformer target value logits. + + Args: + x: Output of the transformer, the latent vector [N, T, E]. + value: Target value token sequence [N, T]. + depth: Target depth token sequence [N, T]. 
+ position: Target position token sequence [N, T, A]. + + Return + Logits of target value sequence. + """ + batch_depth = torch.max(depth) + out = [] + + # process each sample individually + for latent_vec, val, dep, pos in zip(x, value, depth, position): + logits = torch.tensor([], device=x.device) + vector_idx = 0 + + # compute logits layerwise + for layer_idx, head in enumerate(self.heads): + layer_depth = layer_idx + 1 + if layer_depth > batch_depth: + break # reached max depth layer + + if layer_depth < 6: + # get value, depth and position sequence of current layer + layer_val = val[dep == layer_depth] + layer_dep = dep[dep == layer_depth] + layer_pos = pos[dep == layer_depth] + # compute number of vectors in latent vector of current layer + num_vectors = torch.sum(dep == layer_depth) // self.reduction_factor[layer_depth] + elif layer_depth == 6: # handle substitution + # get value, depth and position sequence of previous and current layer + layer_val = torch.cat([val[dep == (layer_depth - 1)], val[dep == layer_depth]]) + layer_dep = torch.cat([dep[dep == (layer_depth - 1)], dep[dep == layer_depth]]) + layer_pos = torch.cat([pos[dep == (layer_depth - 1)], pos[dep == layer_depth]]) + # compute number of vectors in latent vector of current layer + num_vectors = torch.sum(dep == (layer_depth - 1)) // self.reduction_factor[layer_depth] + elif layer_depth in (7, 8): # handle substitution + # get value, depth and position sequence of previous and current layer + layer_val = torch.cat( + [ + val[dep == (layer_depth - 2)], + val[dep == (layer_depth - 1)], + val[dep == layer_depth], + ] + ) + layer_dep = torch.cat( + [ + dep[dep == (layer_depth - 2)], + dep[dep == (layer_depth - 1)], + dep[dep == layer_depth], + ] + ) + layer_pos = torch.cat( + [ + pos[dep == (layer_depth - 2)], + pos[dep == (layer_depth - 1)], + pos[dep == layer_depth], + ] + ) + # compute number of vectors in latent vector of current layer + num_vectors = torch.sum(dep == (layer_depth - 2)) // self.reduction_factor[layer_depth] + + # filter latent vector of current layer + layer_vec = latent_vec[vector_idx:vector_idx + num_vectors] + + # handle clipped values in transformer + if len(layer_vec) == 0: + continue + + # compute layer logits + layer_logits = head( + layer_vec.unsqueeze(0), + layer_val.unsqueeze(0), + layer_dep.unsqueeze(0), + layer_pos.unsqueeze(0), + )[0] + logits = torch.cat([logits, layer_logits]) + + # discard processed tokens + vector_idx += num_vectors + + out += [logits] + + # pad embedding sequence + return pad_sequence(out, batch_first=True, padding_value=0.0) diff --git a/modules/generative_head/composite_head_C.py b/modules/generative_head/composite_head_C.py index 65effd645cb55984790df24e5719656f7889e479..45f23715b84f595c63d1eeab92f1d8fa2435a424 100644 --- a/modules/generative_head/composite_head_C.py +++ b/modules/generative_head/composite_head_C.py @@ -2,10 +2,9 @@ import torch import torch.nn as nn from torch.nn.utils.rnn import pad_sequence +from .convolution_head import ConvolutionHead from .linear_head import LinearHead -from .convolution_head_A import ConvolutionHeadA from .substitution_head import SubstitutionHead -from .double_substitution_head import DoubleSubstitutionHead class CompositeHeadC(nn.Module): @@ -38,13 +37,13 @@ class CompositeHeadC(nn.Module): if resolution >= 4: modules += [LinearHead(**kwargs)] if resolution >= 8: - modules += [ConvolutionHeadA(**kwargs, conv_size=4)] + modules += [ConvolutionHead(**kwargs, conv_size=2)] if resolution >= 16: - modules += 
[ConvolutionHeadA(**kwargs, conv_size=8)] + modules += [ConvolutionHead(**kwargs, conv_size=4)] if resolution >= 32: - modules += [SubstitutionHead(**kwargs, conv_size=8)] + modules += [ConvolutionHead(**kwargs, conv_size=8)] if resolution >= 64: - modules += [DoubleSubstitutionHead(**kwargs, conv_size=8)] + modules += [SubstitutionHead(**kwargs, conv_size=4)] # embeddings self.heads = nn.ModuleList(modules) @@ -52,10 +51,10 @@ class CompositeHeadC(nn.Module): self.reduction_factor = { 1: 1, 2: 1, - 3: 4, - 4: 8, - 5: 8, # Note: 'substitution' - 6: 8, # Note: 'double_substitution' + 3: 2, + 4: 4, + 5: 8, + 6: 4, # Note: 'substitution' } def forward(self, x, value, depth, position): @@ -84,21 +83,21 @@ class CompositeHeadC(nn.Module): if layer_depth > batch_depth: break # reached max depth layer - if layer_depth < 5: + if layer_depth < 6: # get value, depth and position sequence of current layer layer_val = val[dep == layer_depth] layer_dep = dep[dep == layer_depth] layer_pos = pos[dep == layer_depth] # compute number of vectors in latent vector of current layer num_vectors = torch.sum(dep == layer_depth) // self.reduction_factor[layer_depth] - elif layer_depth == 5: # handle substitution + elif layer_depth == 6: # handle substitution # get value, depth and position sequence of previous and current layer layer_val = torch.cat([val[dep == (layer_depth - 1)], val[dep == layer_depth]]) layer_dep = torch.cat([dep[dep == (layer_depth - 1)], dep[dep == layer_depth]]) layer_pos = torch.cat([pos[dep == (layer_depth - 1)], pos[dep == layer_depth]]) # compute number of vectors in latent vector of current layer num_vectors = torch.sum(dep == (layer_depth - 1)) // self.reduction_factor[layer_depth] - elif layer_depth == 6: # handle double substitution + elif layer_depth in (7,8): # handle double substitution # get value, depth and position sequence of previous and current layer layer_val = torch.cat( [ diff --git a/modules/generative_head/composite_head_D.py b/modules/generative_head/composite_head_D.py index 2192cfdf8c9eb49d71911d4f75babfed2359f257..b9a28517dedb012e0718fabe441e0f392134c63f 100644 --- a/modules/generative_head/composite_head_D.py +++ b/modules/generative_head/composite_head_D.py @@ -2,10 +2,10 @@ import torch import torch.nn as nn from torch.nn.utils.rnn import pad_sequence -from .convolution_head_A import ConvolutionHeadA, ConvolutionHeadAutoregressive -from .double_substitution_head import DoubleSubstitutionHead, DoubleSubstitutionHeadAutoRegressive +from .convolution_head import ConvolutionHead +from .double_substitution_head import DoubleSubstitutionHead from .linear_head import LinearHead -from .substitution_head import SubstitutionHead, SubstitutionHeadAutoregressive +from .substitution_head import SubstitutionHead class CompositeHeadD(nn.Module): @@ -38,13 +38,13 @@ class CompositeHeadD(nn.Module): if resolution >= 4: modules += [LinearHead(**kwargs)] if resolution >= 8: - modules += [ConvolutionHeadA(**kwargs, conv_size=2)] + modules += [ConvolutionHead(**kwargs, conv_size=4)] if resolution >= 16: - modules += [ConvolutionHeadA(**kwargs, conv_size=8)] + modules += [ConvolutionHead(**kwargs, conv_size=8)] if resolution >= 32: modules += [SubstitutionHead(**kwargs, conv_size=4)] if resolution >= 64: - modules += [DoubleSubstitutionHead(**kwargs, conv_size=2)] + modules += [SubstitutionHead(**kwargs, conv_size=8)] if resolution >= 128: modules += [DoubleSubstitutionHead(**kwargs, conv_size=4)] if resolution >= 256: @@ -56,10 +56,10 @@ class CompositeHeadD(nn.Module): 
self.reduction_factor = { 1: 1, 2: 1, - 3: 2, + 3: 4, 4: 8, 5: 4, # Note: 'substitution' - 6: 2, # Note: 'double_substitution' + 6: 8, # Note: 'substitution' 7: 4, # Note: 'double_substitution' 8: 8, # Note: 'double_substitution' } @@ -97,14 +97,14 @@ class CompositeHeadD(nn.Module): layer_pos = pos[dep == layer_depth] # compute number of vectors in latent vector of current layer num_vectors = torch.sum(dep == layer_depth) // self.reduction_factor[layer_depth] - elif layer_depth == 5: # handle substitution + elif layer_depth in (5, 6): # handle substitution # get value, depth and position sequence of previous and current layer layer_val = torch.cat([val[dep == (layer_depth - 1)], val[dep == layer_depth]]) layer_dep = torch.cat([dep[dep == (layer_depth - 1)], dep[dep == layer_depth]]) layer_pos = torch.cat([pos[dep == (layer_depth - 1)], pos[dep == layer_depth]]) # compute number of vectors in latent vector of current layer num_vectors = torch.sum(dep == (layer_depth - 1)) // self.reduction_factor[layer_depth] - elif layer_depth in (6, 7, 8): # handle double substitution + elif layer_depth in (7, 8): # handle double substitution # get value, depth and position sequence of previous and current layer layer_val = torch.cat( [ @@ -153,61 +153,3 @@ class CompositeHeadD(nn.Module): # pad embedding sequence return pad_sequence(out, batch_first=True, padding_value=0.0) - - -class CompositeHeadAutoregressiveD(CompositeHeadD): - def __init__(self, spatial_encoding, num_vocab, embed_dim, head_dim, n_layer, resolution, **_): - """ Performs a transformation from transformer latent space into target value logits. - - Uses a different heads for each depth layer, possibly increasing the overall sequence lenght. - Note: The token value '0' is reserved as a padding value, which does not propagate gradients. - - Args: - num_vocab: Number of different target token values (exclusive padding token '0'). - embded_dim: Dimension of the latent embedding space of the transformer. - head_dim: Size of embedding dimensions used in the head layers. - n_layer: Number of layers used in each linear or convolution block. - resolution: Spatial resolution of sequence encoding. 
- """ - super(CompositeHeadAutoregressiveD, self).__init__(spatial_encoding, num_vocab, embed_dim, head_dim, n_layer, - resolution, **_) - - kwargs = { - "spatial_encoding": spatial_encoding, - "num_vocab": num_vocab, - "embed_dim": embed_dim, - "head_dim": head_dim, - "n_layer": n_layer - } - - modules = [] - if resolution >= 2: - modules += [LinearHead(**kwargs)] - if resolution >= 4: - modules += [LinearHead(**kwargs)] - if resolution >= 8: - modules += [ConvolutionHeadAutoregressive(**kwargs, conv_size=2)] - if resolution >= 16: - modules += [ConvolutionHeadAutoregressive(**kwargs, conv_size=8)] - if resolution >= 32: - modules += [SubstitutionHeadAutoregressive(**kwargs, conv_size=4)] - if resolution >= 64: - modules += [DoubleSubstitutionHeadAutoRegressive(**kwargs, conv_size=2)] - if resolution >= 128: - modules += [DoubleSubstitutionHeadAutoRegressive(**kwargs, conv_size=4)] - if resolution >= 256: - modules += [DoubleSubstitutionHeadAutoRegressive(**kwargs, conv_size=8)] - - # embeddings - self.heads = nn.ModuleList(modules) - - self.reduction_factor = { - 1: 1, - 2: 1, - 3: 2, - 4: 8, - 5: 4, # Note: 'substitution' - 6: 2, # Note: 'double_substitution' - 7: 4, # Note: 'double_substitution' - 8: 8, # Note: 'double_substitution' - } diff --git a/modules/generative_head/convolution_head_A.py b/modules/generative_head/convolution_head.py similarity index 53% rename from modules/generative_head/convolution_head_A.py rename to modules/generative_head/convolution_head.py index 140321497fb066e530d2a4fab1ccb9fb08297a54..9f120f6f5ba516eb171160cc77440c62cf3667e6 100644 --- a/modules/generative_head/convolution_head_A.py +++ b/modules/generative_head/convolution_head.py @@ -3,7 +3,7 @@ import torch.nn as nn from ..utils import Deconvolution, Convolution, BlockConvolution, Linear -class ConvolutionHeadA(nn.Module): +class ConvolutionHead(nn.Module): def __init__(self, spatial_encoding, num_vocab, embed_dim, head_dim, n_layer, conv_size, **_): """ Performs a convolutional transformation from transformer latent space into target value logits. @@ -17,59 +17,7 @@ class ConvolutionHeadA(nn.Module): spatial_dim: Spatial dimension (2D/3D) of the sequence data. conv_size: Convolution kernel size and stride. """ - super(ConvolutionHeadA, self).__init__() - - deconvolution = [nn.GELU(), Deconvolution(embed_dim, head_dim, conv_size)] - for i in range(n_layer - 1): - deconvolution += [nn.GELU(), Convolution(head_dim, head_dim, (1,))] - self.deconvolution = nn.Sequential(*deconvolution) - - linear = [] - for i in range(n_layer - 1): - linear += [nn.GELU(), nn.Linear(head_dim, head_dim)] - linear += [nn.GELU(), Linear(head_dim, num_vocab)] - self.linear = nn.Sequential(*linear) - - self.spatial_encoding = spatial_encoding - - def forward(self, x, value, depth, pos): - """ Transforms the output of the transformer target value logits. - - Args: - x: Output of the transformer, the latent vector [N, T', E]. - value: Target value token sequence [N, T]. - depth: Target depth token sequence [N, T]. - pos: Target position token sequence [N, T, A]. - - Return - Logits of target value sequence. 
- """ - # deconvolute the latent space - create new tokens - x = self.deconvolution(x) # [N, T, E] - - # add spatial decoding if available - if self.spatial_encoding is not None: - x = x + self.spatial_encoding(pos) - - # compute logits for each token - return self.linear(x) # [N, T, V] - - -class ConvolutionHeadAutoregressive(nn.Module): - def __init__(self, spatial_encoding, num_vocab, embed_dim, head_dim, n_layer, conv_size, **_): - """ Performs a convolutional transformation from transformer latent space into target value logits. - - Note: The token value '0' is reserved as a padding value, which does not propagate gradients. - - Args: - num_vocab: Number of different target token values (exclusive padding token '0'). - embded_dim: Dimension of the latent embedding space of the transformer. - head_dim: Size of embedding dimensions used in the head layers. - n_layer: Number of layers used in each linear or convolution block. - spatial_dim: Spatial dimension (2D/3D) of the sequence data. - conv_size: Convolution kernel size and stride. - """ - super(ConvolutionHeadAutoregressive, self).__init__() + super(ConvolutionHead, self).__init__() self.conv_size = conv_size diff --git a/modules/generative_head/double_substitution_head.py b/modules/generative_head/double_substitution_head.py index e50f095df1c16d8789346007065b2a68288bf7a4..6444e5a4a890a01685e6a54b40fe511c86d6e208 100644 --- a/modules/generative_head/double_substitution_head.py +++ b/modules/generative_head/double_substitution_head.py @@ -21,113 +21,6 @@ class DoubleSubstitutionHead(nn.Module): super(DoubleSubstitutionHead, self).__init__() self.head_dim = head_dim - # deconvolutions - deconvolution_2 = [nn.GELU(), Deconvolution(embed_dim, head_dim, conv_size)] - for i in range(n_layer - 1): - deconvolution_2 += [nn.GELU(), Convolution(head_dim, head_dim, 1)] - self.deconvolution_2 = nn.Sequential(*deconvolution_2) - - deconvolution_1 = [nn.GELU(), Deconvolution(head_dim, head_dim, 8)] - for i in range(n_layer - 1): - deconvolution_1 += [nn.GELU(), Convolution(head_dim, head_dim, 1)] - self.deconvolution_1 = nn.Sequential(*deconvolution_1) - - deconvolution_0 = [nn.GELU(), Deconvolution(head_dim, head_dim, 8)] - for i in range(n_layer - 1): - deconvolution_0 += [nn.GELU(), Convolution(head_dim, head_dim, 1)] - self.deconvolution_0 = nn.Sequential(*deconvolution_0) - - linear = [] - for i in range(n_layer - 1): - linear += [nn.GELU(), nn.Linear(head_dim, head_dim)] - linear += [nn.GELU(), Linear(head_dim, num_vocab)] - self.linear = nn.Sequential(*linear) - - self.spatial_encoding = spatial_encoding - - def forward(self, x, value, depth, pos): - """ Transforms the output of the transformer target value logits. - - Transforms one token of the latent vector into multiple tokens of the target vector through de-convolutional - operations. In the case of a quadtree one token is responsible for up to 16 target tokens. In the case of a - octree one token is responsible for up to 64 target tokens. Only tokens, which correspond to a mixed target - value token in the penultimate layer are transformed into target sequence tokens. - - Args: - x: Output of the transformer, the latent vector [N, T'', E]. - value: Value token sequence, with penultimate and last layer. - depth: Depth token sequence, with penultimate and last layer. - pos: Position token sequence, with penultimate and last layer. - - Return - Logits of target value sequence. 
- """ - batch_size = value.shape[0] - max_depth = torch.max(depth) - len_1 = torch.sum(depth == (max_depth - 1), dim=1) - len_2 = torch.sum(depth == (max_depth - 2), dim=1) - - # create intermediate list to hold values - val_1 = torch.zeros((batch_size, torch.max(len_1)), device=value.device) - val_2 = torch.zeros((batch_size, torch.max(len_2)), device=value.device) - - # splitt input in second-last (1) layer - for i in range(batch_size): - val_2[i, :len_2[i]] = value[i, :len_2[i]] - val_1[i, :len_1[i]] = value[i, len_2[i]:len_2[i] + len_1[i]] - - # compute the number of mixed tokens in mask - mix_1 = torch.sum(val_1 == 2, dim=1) - mix_2 = torch.sum(val_2 == 2, dim=1) - - # create intermediate list to hold vectors - x_0 = torch.zeros((batch_size, torch.max(mix_1), self.head_dim), device=value.device) - x_1 = torch.zeros((batch_size, torch.max(mix_2), self.head_dim), device=value.device) - - # deconvolute the latent space - sequence length equals number of tokens in the penultimate layer - y_2 = self.deconvolution_2(x) - # select only latent vectors, which correspond to mixed tokens in third-last layer - for i in range(batch_size): - mix_2_mask_i = (val_2[i] == 2) - x_1[i, :torch.sum(mix_2_mask_i)] = y_2[i, mix_2_mask_i] # [N, T', C] - - # deconvolute the latent space - sequence length equals number of tokens in the penultimate layer - y_1 = self.deconvolution_1(x_1) - # select only latent vectors, which correspond to mixed tokens in third-last layer - for i in range(batch_size): - mix_1_mask_i = (val_1[i] == 2) - x_0[i, :torch.sum(mix_1_mask_i)] = y_1[i, mix_1_mask_i] # [N, T', C] - - # deconvolute the intermediate latent space - create new tokens in latent space for each mixed token - y_0 = self.deconvolution_0(x_0) # [N, T, C] - - # add spatial decoding if available - if self.spatial_encoding is not None: - len_last = torch.sum(depth == max_depth, dim=1) - assert ((depth[:, -len_last:] == max_depth).all()) - y_0 = y_0 + self.spatial_encoding(pos[:, -len_last:]) - - # compute logits of generated tokens - return self.linear(y_0) # [N, T, V] - - -class DoubleSubstitutionHeadAutoRegressive(nn.Module): - def __init__(self, spatial_encoding, num_vocab, embed_dim, head_dim, n_layer, conv_size, **_): - """ Performs a twice a substitution transformation from transformer latent space into target value logits. - - Note: The token value '0' is reserved as a padding value, which does not propagate gradients. - - Args: - num_vocab: Number of different target token values (exclusive padding token '0'). - embded_dim: Dimension of the latent embedding space of the transformer. - head_dim: Size of embedding dimensions used in the head layers. - n_layer: Number of layers used in each linear or convolution block. - spatial_dim: Spatial dimension (2D/3D) of the sequence data. - conv_size: Convolution kernel size and stride. 
- """ - super(DoubleSubstitutionHeadAutoRegressive, self).__init__() - self.head_dim = head_dim - deconvolution_2 = [nn.GELU(), Deconvolution(embed_dim, head_dim, conv_size)] for i in range(n_layer - 1): deconvolution_2 += [nn.GELU(), Convolution(head_dim, head_dim, 1)] diff --git a/modules/generative_head/head_factory.py b/modules/generative_head/head_factory.py index d7aaa7b0059209e802b516aeaf5a5b3a5c473b38..531c06cb5001a818356414bf14b653f8c2202b11 100644 --- a/modules/generative_head/head_factory.py +++ b/modules/generative_head/head_factory.py @@ -2,11 +2,11 @@ import torch.nn as nn from modules.utils import PositionalEncodingLearned, PositionalEncodingLearnedLookAhead, \ PositionalEncodingLearnedLookAheadSplit -from .composite_head_A import CompositeHeadA, CompositeHeadAutoregressiveA +from .composite_head_A import CompositeHeadA from .composite_head_B import CompositeHeadB from .composite_head_C import CompositeHeadC -from .composite_head_D import CompositeHeadD, CompositeHeadAutoregressiveD -from .convolution_head_A import ConvolutionHeadA +from .composite_head_D import CompositeHeadD +from .convolution_head import ConvolutionHead from .double_substitution_head import DoubleSubstitutionHead from .linear_head import LinearHead from .multi_conv_head_A import MultiConvolutionHeadA @@ -59,9 +59,9 @@ def _create_head(name, positional_encoding, num_vocab, embed_dim, head_dim, n_la return LinearHead(**kwargs) elif name in ('half_conv', 'half_conv_A'): kwargs["conv_size"] = 2 ** (3 - 1) - return ConvolutionHeadA(**kwargs) + return ConvolutionHead(**kwargs) elif name in ('single_conv', 'single_conv_A'): - return ConvolutionHeadA(**kwargs) + return ConvolutionHead(**kwargs) elif name == 'multi_conv_A': return MultiConvolutionHeadA(**kwargs) elif name == 'substitution': @@ -70,16 +70,12 @@ def _create_head(name, positional_encoding, num_vocab, embed_dim, head_dim, n_la return DoubleSubstitutionHead(**kwargs) elif name in ('composite', 'composite_A'): return CompositeHeadA(**kwargs) - elif name in ('composite_autoregressive_A'): - return CompositeHeadAutoregressiveA(**kwargs) elif name in ('composite_B'): return CompositeHeadB(**kwargs) elif name in ('composite_C'): return CompositeHeadC(**kwargs) elif name in ('composite_D'): return CompositeHeadD(**kwargs) - elif name in ('composite_autoregressive_D'): - return CompositeHeadAutoregressiveD(**kwargs) else: raise ValueError(f"ERROR: {name} head not implemented.") diff --git a/modules/generative_head/substitution_head.py b/modules/generative_head/substitution_head.py index 09602f95e5c9f599995c529ec8a841b32695c091..87e943e2276053de176eb3377804e4e511d2fdac 100644 --- a/modules/generative_head/substitution_head.py +++ b/modules/generative_head/substitution_head.py @@ -31,95 +31,6 @@ class SubstitutionHead(nn.Module): deconvolution_0 += [nn.GELU(), Convolution(head_dim, head_dim, 1)] self.deconvolution_0 = nn.Sequential(*deconvolution_0) - linear = [] - for i in range(n_layer - 1): - linear += [nn.GELU(), nn.Linear(head_dim, head_dim)] - linear += [nn.GELU(), Linear(head_dim, num_vocab)] - self.linear = nn.Sequential(*linear) - - self.spatial_encoding = spatial_encoding - - def forward(self, x, value, depth, pos): - """ Transforms the output of the transformer target value logits. - - Transforms one token of the latent vector into multiple tokens of the target vector through de-convolutional - operations. In the case of a quadtree one token is responsible for up to 16 target tokens. 
In the case of a - octree one token is responsible for up to 64 target tokens. Only tokens, which correspond to a mixed target - value token in the penultimate layer are transformed into target sequence tokens. - - Args: - x: Output of the transformer, the latent vector [N, T'', E]. - value: Value token sequence, with penultimate and last layer. - depth: Depth token sequence, with penultimate and last layer. - pos: Position token sequence, with penultimate and last layer. - - Return - Logits of target value sequence. - """ - batch_size = value.shape[0] - max_depth = torch.max(depth) - len_1 = torch.sum(depth == (max_depth - 1), dim=1) - - # create intermediate list to hold values - val_1 = torch.zeros((batch_size, torch.max(len_1)), device=value.device) - - # splitt input in second-last (1) layer - for i in range(batch_size): - val_1[i, :len_1[i]] = value[i, :len_1[i]] - - # compute the number of mixed tokens in mask - mix_1 = torch.sum(val_1 == 2, dim=1) - - # create intermediate list to hold vectors - x_0 = torch.zeros((batch_size, torch.max(mix_1), self.head_dim), device=value.device) - - # deconvolute the latent space - sequence length equals number of tokens in the penultimate layer - y_1 = self.deconvolution_1(x) - # select only latent vectors, which correspond to mixed tokens in the penultimate layer - for i in range(batch_size): - mix_1_mask_i = (val_1[i] == 2) - x_0[i, :torch.sum(mix_1_mask_i)] = y_1[i, mix_1_mask_i] # [N, T', C] - - # deconvolute the intermediate latent space - create new tokens in latent space for each mixed token - y_0 = self.deconvolution_0(x_0) # [N, T, C] - - # add spatial decoding if available - if self.spatial_encoding is not None: - len_last = torch.sum(depth == max_depth, dim=1) - assert ((depth[:, -len_last:] == max_depth).all()) - y_0 = y_0 + self.spatial_encoding(pos[:, -len_last:]) - - # compute logits of generated tokens - return self.linear(y_0) # [N, T, V] - - -class SubstitutionHeadAutoregressive(nn.Module): - def __init__(self, spatial_encoding, num_vocab, embed_dim, head_dim, n_layer, conv_size, **_): - """ Performs a substitution transformation from transformer latent space into target value logits. - - Note: The token value '0' is reserved as a padding value, which does not propagate gradients. - - Args: - num_vocab: Number of different target token values (exclusive padding token '0'). - embded_dim: Dimension of the latent embedding space of the transformer. - head_dim: Size of embedding dimensions used in the head layers. - n_layer: Number of layers used in each linear or convolution block. - spatial_dim: Spatial dimension (2D/3D) of the sequence data. - conv_size: Convolution kernel size and stride. 
- """ - super(SubstitutionHeadAutoregressive, self).__init__() - self.head_dim = head_dim - - deconvolution_1 = [nn.GELU(), Deconvolution(embed_dim, head_dim, conv_size)] - for i in range(n_layer - 1): - deconvolution_1 += [nn.GELU(), Convolution(head_dim, head_dim, 1)] - self.deconvolution_1 = nn.Sequential(*deconvolution_1) - - deconvolution_0 = [nn.GELU(), Deconvolution(head_dim, head_dim, 8)] - for i in range(n_layer - 1): - deconvolution_0 += [nn.GELU(), Convolution(head_dim, head_dim, 1)] - self.deconvolution_0 = nn.Sequential(*deconvolution_0) - convolution_1 = [] for i in range(n_layer): convolution_1 += [nn.GELU(), BlockConvolution(head_dim, head_dim, conv_size)] diff --git a/modules/token_embedding/basic_embedding_A.py b/modules/token_embedding/basic_embedding.py similarity index 95% rename from modules/token_embedding/basic_embedding_A.py rename to modules/token_embedding/basic_embedding.py index 990009285f45b354476377594bfa54e36effd950..13cc57ab2418cd0e2e6fb66621e066484d8e8f87 100644 --- a/modules/token_embedding/basic_embedding_A.py +++ b/modules/token_embedding/basic_embedding.py @@ -3,7 +3,7 @@ import torch.nn as nn from utils.masks import padding_mask -class BasicEmbeddingA(nn.Module): +class BasicEmbedding(nn.Module): def __init__(self, encoding, num_vocab, embed_dim, resolution, spatial_dim, **_): """ Performs an embedding of token sequences into an embedding space of higher dimension. @@ -16,7 +16,7 @@ class BasicEmbeddingA(nn.Module): resolution: Spatial resolution of sequence encoding. spatial_dim: Spatial dimension (2D, 3D, ...) of sequence encoding. """ - super(BasicEmbeddingA, self).__init__() + super(BasicEmbedding, self).__init__() self.mask = None # embeddings diff --git a/modules/token_embedding/composite_embedding_A.py b/modules/token_embedding/composite_embedding_A.py index 01589ee4a16dc9de7e5d03d825523e99e0ec27b9..b436dc0841d0e753d47bed837887eb28ecfe667c 100644 --- a/modules/token_embedding/composite_embedding_A.py +++ b/modules/token_embedding/composite_embedding_A.py @@ -2,8 +2,8 @@ import torch import torch.nn as nn from torch.nn.utils.rnn import pad_sequence -from .basic_embedding_A import BasicEmbeddingA -from .convolution_embedding_A import ConvolutionEmbeddingA +from .basic_embedding import BasicEmbedding +from .convolution_embedding import ConvolutionEmbedding from .double_substitution_embedding import DoubleSubstitutionEmbedding from .substitution_embedding import SubstitutionEmbedding @@ -34,21 +34,21 @@ class CompositeEmbeddingA(nn.Module): modules = [] if resolution >= 2: - modules += [BasicEmbeddingA(**kwargs)] + modules += [BasicEmbedding(**kwargs)] if resolution >= 4: - modules += [BasicEmbeddingA(**kwargs)] + modules += [BasicEmbedding(**kwargs)] if resolution >= 8: - modules += [BasicEmbeddingA(**kwargs)] + modules += [BasicEmbedding(**kwargs)] if resolution >= 16: - modules += [ConvolutionEmbeddingA(**kwargs, conv_size=2 ** (spatial_dim - 1))] + modules += [ConvolutionEmbedding(**kwargs, conv_size=4)] if resolution >= 32: - modules += [ConvolutionEmbeddingA(**kwargs, conv_size=2 ** spatial_dim)] + modules += [ConvolutionEmbedding(**kwargs, conv_size=8)] if resolution >= 64: - modules += [SubstitutionEmbedding(**kwargs, conv_size=2 ** spatial_dim)] + modules += [SubstitutionEmbedding(**kwargs, conv_size=8)] if resolution >= 128: - modules += [DoubleSubstitutionEmbedding(**kwargs, conv_size=2 ** spatial_dim)] + modules += [DoubleSubstitutionEmbedding(**kwargs, conv_size=8)] if resolution >= 256: - modules += 
[DoubleSubstitutionEmbedding(**kwargs, conv_size=2 ** spatial_dim)] + modules += [DoubleSubstitutionEmbedding(**kwargs, conv_size=8)] # embeddings self.embedding = encoding diff --git a/modules/token_embedding/composite_embedding_B.py b/modules/token_embedding/composite_embedding_B.py index c8c840fcb2645c3a0084195d5613608569946311..29491648b7f5e8f0e75653d6d159a9bfca16d945 100644 --- a/modules/token_embedding/composite_embedding_B.py +++ b/modules/token_embedding/composite_embedding_B.py @@ -2,9 +2,8 @@ import torch import torch.nn as nn from torch.nn.utils.rnn import pad_sequence -from .basic_embedding_A import BasicEmbeddingA -from .convolution_embedding_A import ConvolutionEmbeddingA -from .double_substitution_embedding import DoubleSubstitutionEmbedding +from .basic_embedding import BasicEmbedding +from .convolution_embedding import ConvolutionEmbedding from .substitution_embedding import SubstitutionEmbedding @@ -34,19 +33,17 @@ class CompositeEmbeddingB(nn.Module): modules = [] if resolution >= 2: - modules += [BasicEmbeddingA(**kwargs)] + modules += [BasicEmbedding(**kwargs)] if resolution >= 4: - modules += [BasicEmbeddingA(**kwargs)] + modules += [BasicEmbedding(**kwargs)] if resolution >= 8: - modules += [BasicEmbeddingA(**kwargs)] + modules += [BasicEmbedding(**kwargs)] if resolution >= 16: - modules += [ConvolutionEmbeddingA(**kwargs, conv_size=2 ** (spatial_dim - 2))] + modules += [BasicEmbedding(**kwargs)] if resolution >= 32: - modules += [ConvolutionEmbeddingA(**kwargs, conv_size=2 ** spatial_dim)] + modules += [ConvolutionEmbedding(**kwargs, conv_size=8)] if resolution >= 64: - modules += [SubstitutionEmbedding(**kwargs, conv_size=2 ** spatial_dim)] - if resolution >= 128: - modules += [DoubleSubstitutionEmbedding(**kwargs, conv_size=2 ** spatial_dim)] + modules += [SubstitutionEmbedding(**kwargs, conv_size=8)] # embeddings self.embedding = encoding @@ -93,7 +90,7 @@ class CompositeEmbeddingB(nn.Module): val_seq = torch.cat([val[dep == (layer_depth - 1)], val[dep == layer_depth]]) dep_seq = torch.cat([dep[dep == (layer_depth - 1)], dep[dep == layer_depth]]) pos_seq = torch.cat([pos[dep == (layer_depth - 1)], pos[dep == layer_depth]]) - elif layer_depth == 7: # third-, second- and last layer + elif layer_depth in (7, 8): # third-, second- and last layer emb_seq = torch.cat( [emb[dep == (layer_depth - 2)], emb[dep == (layer_depth - 1)], emb[dep == layer_depth]] ) diff --git a/modules/token_embedding/composite_embedding_C.py b/modules/token_embedding/composite_embedding_C.py index 02594b64c0022cdb9bcae3213001818fbb119b3c..66b1b2bef063b0874b0323e9cc80b65239ccb173 100644 --- a/modules/token_embedding/composite_embedding_C.py +++ b/modules/token_embedding/composite_embedding_C.py @@ -2,9 +2,8 @@ import torch import torch.nn as nn from torch.nn.utils.rnn import pad_sequence -from .basic_embedding_A import BasicEmbeddingA -from .convolution_embedding_A import ConvolutionEmbeddingA -from .double_substitution_embedding import DoubleSubstitutionEmbedding +from .basic_embedding import BasicEmbedding +from .convolution_embedding import ConvolutionEmbedding from .substitution_embedding import SubstitutionEmbedding @@ -34,17 +33,17 @@ class CompositeEmbeddingC(nn.Module): modules = [] if resolution >= 2: - modules += [BasicEmbeddingA(**kwargs)] + modules += [BasicEmbedding(**kwargs)] if resolution >= 4: - modules += [BasicEmbeddingA(**kwargs)] + modules += [BasicEmbedding(**kwargs)] if resolution >= 8: - modules += [ConvolutionEmbeddingA(**kwargs, conv_size=2 ** (spatial_dim - 1))] + 
modules += [ConvolutionEmbedding(**kwargs, conv_size=2)] if resolution >= 16: - modules += [ConvolutionEmbeddingA(**kwargs, conv_size=2 ** spatial_dim)] + modules += [ConvolutionEmbedding(**kwargs, conv_size=4)] if resolution >= 32: - modules += [SubstitutionEmbedding(**kwargs, conv_size=2 ** spatial_dim)] + modules += [ConvolutionEmbedding(**kwargs, conv_size=8)] if resolution >= 64: - modules += [DoubleSubstitutionEmbedding(**kwargs, conv_size=2 ** spatial_dim)] + modules += [SubstitutionEmbedding(**kwargs, conv_size=4)] # embeddings self.embedding = encoding @@ -81,17 +80,17 @@ class CompositeEmbeddingC(nn.Module): break # reached max depth layer # filter layers for embeddings - if layer_depth < 5: # only last layer + if layer_depth < 6: # only last layer emb_seq = emb[dep == layer_depth] val_seq = val[dep == layer_depth] dep_seq = dep[dep == layer_depth] pos_seq = pos[dep == layer_depth] - elif layer_depth == 5: # penultimate and last layer + elif layer_depth == 6: # penultimate and last layer emb_seq = torch.cat([emb[dep == (layer_depth - 1)], emb[dep == layer_depth]]) val_seq = torch.cat([val[dep == (layer_depth - 1)], val[dep == layer_depth]]) dep_seq = torch.cat([dep[dep == (layer_depth - 1)], dep[dep == layer_depth]]) pos_seq = torch.cat([pos[dep == (layer_depth - 1)], pos[dep == layer_depth]]) - elif layer_depth == 6: # third-, second- and last layer + elif layer_depth in (7, 8): # third-, second- and last layer emb_seq = torch.cat( [emb[dep == (layer_depth - 2)], emb[dep == (layer_depth - 1)], emb[dep == layer_depth]] ) diff --git a/modules/token_embedding/composite_embedding_D.py b/modules/token_embedding/composite_embedding_D.py index bd441356bd43682b94a371c2d5ea8f62705fd75e..ab1c0dc16627acd0392f000c20f1cb72fcf3d982 100644 --- a/modules/token_embedding/composite_embedding_D.py +++ b/modules/token_embedding/composite_embedding_D.py @@ -2,8 +2,8 @@ import torch import torch.nn as nn from torch.nn.utils.rnn import pad_sequence -from .basic_embedding_A import BasicEmbeddingA -from .convolution_embedding_A import ConvolutionEmbeddingA +from .basic_embedding import BasicEmbedding +from .convolution_embedding import ConvolutionEmbedding from .double_substitution_embedding import DoubleSubstitutionEmbedding from .substitution_embedding import SubstitutionEmbedding @@ -34,17 +34,17 @@ class CompositeEmbeddingD(nn.Module): modules = [] if resolution >= 2: - modules += [BasicEmbeddingA(**kwargs)] + modules += [BasicEmbedding(**kwargs)] if resolution >= 4: - modules += [BasicEmbeddingA(**kwargs)] + modules += [BasicEmbedding(**kwargs)] if resolution >= 8: - modules += [ConvolutionEmbeddingA(**kwargs, conv_size=2)] + modules += [ConvolutionEmbedding(**kwargs, conv_size=4)] if resolution >= 16: - modules += [ConvolutionEmbeddingA(**kwargs, conv_size=8)] + modules += [ConvolutionEmbedding(**kwargs, conv_size=8)] if resolution >= 32: modules += [SubstitutionEmbedding(**kwargs, conv_size=4)] if resolution >= 64: - modules += [DoubleSubstitutionEmbedding(**kwargs, conv_size=2)] + modules += [SubstitutionEmbedding(**kwargs, conv_size=8)] if resolution >= 128: modules += [DoubleSubstitutionEmbedding(**kwargs, conv_size=4)] if resolution >= 256: @@ -90,12 +90,12 @@ class CompositeEmbeddingD(nn.Module): val_seq = val[dep == layer_depth] dep_seq = dep[dep == layer_depth] pos_seq = pos[dep == layer_depth] - elif layer_depth == 5: # penultimate and last layer + elif layer_depth in (5, 6): # penultimate and last layer emb_seq = torch.cat([emb[dep == (layer_depth - 1)], emb[dep == layer_depth]]) val_seq 
= torch.cat([val[dep == (layer_depth - 1)], val[dep == layer_depth]]) dep_seq = torch.cat([dep[dep == (layer_depth - 1)], dep[dep == layer_depth]]) pos_seq = torch.cat([pos[dep == (layer_depth - 1)], pos[dep == layer_depth]]) - elif layer_depth in (6, 7, 8): # third-, second- and last layer + elif layer_depth in (7, 8): # third-, second- and last layer emb_seq = torch.cat( [emb[dep == (layer_depth - 2)], emb[dep == (layer_depth - 1)], emb[dep == layer_depth]] ) diff --git a/modules/token_embedding/convolution_embedding_A.py b/modules/token_embedding/convolution_embedding.py similarity index 96% rename from modules/token_embedding/convolution_embedding_A.py rename to modules/token_embedding/convolution_embedding.py index 11435ceba56db668c6061594ce42f993841a6316..65fb7d1d38635e39047f237d3ee8798303fddec6 100644 --- a/modules/token_embedding/convolution_embedding_A.py +++ b/modules/token_embedding/convolution_embedding.py @@ -4,7 +4,7 @@ from utils.masks import padding_mask from ..utils import Convolution -class ConvolutionEmbeddingA(nn.Module): +class ConvolutionEmbedding(nn.Module): def __init__(self, encoding, num_vocab, embed_dim, resolution, spatial_dim, conv_size, **_): """ Performs an embedding of token sequences into an embedding space of higher dimension. @@ -21,7 +21,7 @@ class ConvolutionEmbeddingA(nn.Module): spatial_dim: Spatial dimension (2D, 3D, ...) of sequence encoding. conv_size: Convolution kernel size and stride. """ - super(ConvolutionEmbeddingA, self).__init__() + super(ConvolutionEmbedding, self).__init__() self.chunk_size = conv_size self.mask = None diff --git a/modules/token_embedding/embedding_factory.py b/modules/token_embedding/embedding_factory.py index 37cec5d55191a29e4acc5a72bdd0eebfb1d2351a..ade0336f8554507c871acfbf0dbdafd435322972 100644 --- a/modules/token_embedding/embedding_factory.py +++ b/modules/token_embedding/embedding_factory.py @@ -2,12 +2,12 @@ import torch.nn as nn from modules.utils import Embedding, PositionalEncodingLearned, PositionalEncodingLearnedLookAhead, \ PositionalEncodingLearnedLookAheadSplit -from .basic_embedding_A import BasicEmbeddingA +from .basic_embedding import BasicEmbedding from .composite_embedding_A import CompositeEmbeddingA from .composite_embedding_B import CompositeEmbeddingB from .composite_embedding_C import CompositeEmbeddingC from .composite_embedding_D import CompositeEmbeddingD -from .convolution_embedding_A import ConvolutionEmbeddingA +from .convolution_embedding import ConvolutionEmbedding from .double_substitution_embedding import DoubleSubstitutionEmbedding from .multi_conv_embedding_A import MultiConvolutionEmbeddingA from .substitution_embedding import SubstitutionEmbedding @@ -51,15 +51,15 @@ def _create_embedding(name, positional_encoding, num_vocab, embed_dim, resolutio } if name in ('basic', 'basic_A'): - return BasicEmbeddingA(**kwargs) + return BasicEmbedding(**kwargs) elif name == 'discrete_transformation': kwargs['num_vocab'] = num_vocab ** 2 ** spatial_dim - return BasicEmbeddingA(**kwargs) + return BasicEmbedding(**kwargs) elif name in ('half_conv', 'half_conv_A'): kwargs['conv_size'] = 2 ** (spatial_dim - 1) - return ConvolutionEmbeddingA(**kwargs) + return ConvolutionEmbedding(**kwargs) elif name in ('single_conv', 'single_conv_A'): - return ConvolutionEmbeddingA(**kwargs) + return ConvolutionEmbedding(**kwargs) elif name == 'multi_conv_A': return MultiConvolutionEmbeddingA(**kwargs) elif name == 'substitution': diff --git a/modules/utils/block_convolution.py b/modules/utils/block_convolution.py 
index e0c7363cdc459cebd4d1c22e42b4f55d09ca7c5b..31a01f6bb9c1c1c58b50cfb3830dbcf6ef918a87 100644
--- a/modules/utils/block_convolution.py
+++ b/modules/utils/block_convolution.py
@@ -1,3 +1,5 @@
+import math
+
 import torch
 import torch.nn as nn
 
@@ -16,8 +18,11 @@ class BlockConvolution(nn.Module):
         self.block_size = block_size
 
         self.convolutions = nn.ModuleList([
-            nn.Conv1d(source_dim, target_dim, (i + 1,), block_size, bias=True) for i in range(block_size-1)
+            nn.Conv1d(source_dim, target_dim, (i + 1,), block_size, bias=False) for i in range(block_size - 1)
         ])
+        sigma = math.sqrt(1. / (block_size * source_dim))
+        self.bias = nn.Parameter(torch.empty(block_size))
+        nn.init.uniform_(self.bias, -sigma, sigma)
 
     def forward(self, seq_vector):
         """ Convolute tokens to reduce sequence length
@@ -30,7 +35,9 @@
         """
         out = torch.zeros_like(seq_vector)
 
+        out[:, ::self.block_size] += self.bias[0]
         for i, conv in enumerate(self.convolutions):
             out[:, 1 + i::self.block_size] = conv(seq_vector.transpose(1, 2)).transpose(1, 2)
+            out[:, 1 + i::self.block_size] += self.bias[1 + i]
 
         return out
diff --git a/utils/loss/loss_factory.py b/utils/loss/loss_factory.py
index eaa1d12c1bef26249d7917233a7bbc50fa4bc540..9faa90f274d5a1caadc97d4a8c831eec6c4d6721 100644
--- a/utils/loss/loss_factory.py
+++ b/utils/loss/loss_factory.py
@@ -27,11 +27,11 @@ def create_loss(name, ignore_index, max_depth, spatial_dim):
     elif name == 'depth_cross_entropy_A':
         return DepthWeightedCrossEntropyLoss(**kwargs, basis=0.8)
     elif name == 'depth_cross_entropy_B':
-        return DepthWeightedCrossEntropyLoss(**kwargs, basis=0.6)
+        return DepthWeightedCrossEntropyLoss(**kwargs, basis=0.5)
     elif name == 'depth_cross_entropy_C':
         return DepthWeightedCrossEntropyLoss(**kwargs, basis=0.4)
     elif name == 'depth_cross_entropy_D':
-        return DepthWeightedCrossEntropyLoss(**kwargs, basis=0.3)
+        return DepthWeightedCrossEntropyLoss(**kwargs, basis=0.25)
     elif name == 'depth_cross_entropy_E':
         return DepthWeightedCrossEntropyLoss(**kwargs, basis=0.125)
     else:
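
Note on the composite heads above: each CompositeHead* pairs its list of sub-heads with a reduction_factor table that records how many target tokens a single latent vector covers at a given octree depth, and forward() uses that table to slice the transformer output before handing each slice to the matching sub-head. The sketch below is illustrative only and not part of the patch: the per-depth token counts are made-up example values, the table is CompositeHeadB's, and only the single-substitution case is shown (the double-substitution layers of heads C/D count against the depth two levels up instead).

import torch

# CompositeHeadB's mapping: depth -> target tokens covered by one latent vector
reduction_factor = {1: 1, 2: 1, 3: 1, 4: 1, 5: 8, 6: 8}  # 6: 'substitution'

# made-up per-depth token counts for a single octree sample
depth = torch.tensor([1] * 8 + [2] * 8 + [3] * 16 + [4] * 32 + [5] * 64 + [6] * 128)

vector_idx = 0
for layer_depth in sorted(reduction_factor):
    if layer_depth < 6:
        num_tokens = int((depth == layer_depth).sum())
    else:  # substitution head: latent vectors are counted against the previous depth
        num_tokens = int((depth == layer_depth - 1).sum())
    num_vectors = num_tokens // reduction_factor[layer_depth]
    print(f"depth {layer_depth}: latent vectors [{vector_idx}:{vector_idx + num_vectors}]")
    vector_idx += num_vectors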
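
Note on BlockConvolution: the patch drops the per-channel Conv1d bias (bias=True) in favour of one learned scalar per block position, shared across channels and initialised with the usual fan-in bound sqrt(1 / (block_size * source_dim)). The self-contained sketch below mirrors that change; the constructor signature (source_dim, target_dim, block_size) is inferred from the surrounding context lines and the class name is hypothetical.

import math

import torch
import torch.nn as nn


class BlockConvolutionSketch(nn.Module):
    def __init__(self, source_dim, target_dim, block_size):
        super().__init__()
        self.block_size = block_size
        # one bias-free strided convolution per block offset, as in the patched __init__
        self.convolutions = nn.ModuleList([
            nn.Conv1d(source_dim, target_dim, (i + 1,), block_size, bias=False)
            for i in range(block_size - 1)
        ])
        # shared scalar bias per block position
        sigma = math.sqrt(1. / (block_size * source_dim))
        self.bias = nn.Parameter(torch.empty(block_size))
        nn.init.uniform_(self.bias, -sigma, sigma)

    def forward(self, seq_vector):  # [N, T, C] with T divisible by block_size
        out = torch.zeros_like(seq_vector)
        out[:, ::self.block_size] += self.bias[0]
        for i, conv in enumerate(self.convolutions):
            out[:, 1 + i::self.block_size] = conv(seq_vector.transpose(1, 2)).transpose(1, 2)
            out[:, 1 + i::self.block_size] += self.bias[1 + i]
        return out


x = torch.randn(2, 16, 32)                          # 2 samples, 16 tokens, 32 channels
print(BlockConvolutionSketch(32, 32, 4)(x).shape)   # torch.Size([2, 16, 32])

Positions at block offset 0 receive only the bias term, exactly as in the patched forward().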