Commit 842a635e authored by Moritz Ibing

some fixes in embeddings + head, as well as a new composite embedding and loss

parent 575f0e5b
......@@ -15,6 +15,7 @@ class CheckSequenceLenghtTransform():
'composite_A': [0, 0, 0, 0, 0, 1, 2, 2],
'composite_B': [0, 0, 0, 0, 0, 1, 2, 2],
'composite_C': [0, 0, 0, 0, 1, 2],
'composite_D': [0, 0, 0, 0, 1, 2, 2, 2],
}
_convolution_factor_map = {
......@@ -31,6 +32,7 @@ class CheckSequenceLenghtTransform():
'composite_A': [1, 1, 1, 4, 8, 8, 8, 8],
'composite_B': [1, 1, 1, 2, 8, 8, 8, 8],
'composite_C': [1, 1, 4, 8, 8, 8],
'composite_D': [1, 1, 2, 8, 4, 2, 4, 8],
}
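# Note: the 'composite_D' factors [1, 1, 2, 8, 4, 2, 4, 8] match the per-depth
# reduction_factor table of the new CompositeHeadD added in this commit (linear heads
# at depths 1-2, then convolution/substitution heads with conv sizes 2, 8, 4, 2, 4, 8).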
def __init__(self, num_positions, embedding):
......
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pad_sequence
from .convolution_head_A import ConvolutionHeadA, ConvolutionHeadAutoregressive
from .double_substitution_head import DoubleSubstitutionHead, DoubleSubstitutionHeadAutoRegressive
from .linear_head import LinearHead
from .substitution_head import SubstitutionHead, SubstitutionHeadAutoregressive
class CompositeHeadD(nn.Module):
def __init__(self, spatial_encoding, num_vocab, embed_dim, head_dim, n_layer, resolution, **_):
""" Performs a transformation from transformer latent space into target value logits.
Uses a different head for each depth layer, possibly increasing the overall sequence length.
Note: The token value '0' is reserved as a padding value, which does not propagate gradients.
Args:
num_vocab: Number of different target token values (excluding the padding token '0').
embed_dim: Dimension of the latent embedding space of the transformer.
head_dim: Size of embedding dimensions used in the head layers.
n_layer: Number of layers used in each linear or convolution block.
resolution: Spatial resolution of sequence encoding.
"""
super(CompositeHeadD, self).__init__()
kwargs = {
"spatial_encoding": spatial_encoding,
"num_vocab": num_vocab,
"embed_dim": embed_dim,
"head_dim": head_dim,
"n_layer": n_layer
}
modules = []
if resolution >= 2:
modules += [LinearHead(**kwargs)]
if resolution >= 4:
modules += [LinearHead(**kwargs)]
if resolution >= 8:
modules += [ConvolutionHeadA(**kwargs, conv_size=2)]
if resolution >= 16:
modules += [ConvolutionHeadA(**kwargs, conv_size=8)]
if resolution >= 32:
modules += [SubstitutionHead(**kwargs, conv_size=4)]
if resolution >= 64:
modules += [DoubleSubstitutionHead(**kwargs, conv_size=2)]
if resolution >= 128:
modules += [DoubleSubstitutionHead(**kwargs, conv_size=4)]
if resolution >= 256:
modules += [DoubleSubstitutionHead(**kwargs, conv_size=8)]
# heads
self.heads = nn.ModuleList(modules)
self.reduction_factor = {
1: 1,
2: 1,
3: 2,
4: 8,
5: 4, # Note: 'substitution'
6: 2, # Note: 'double_substitution'
7: 4, # Note: 'double_substitution'
8: 8, # Note: 'double_substitution'
}
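# The reduction_factor at depth d equals the conv_size of the head registered for that
# depth (1 for the two linear heads). forward() uses it to convert token counts into the
# number of latent vectors consumed per depth: tokens at depth d for d < 5, at depth d-1
# for the substitution head (d == 5), and at depth d-2 for the double-substitution heads
# (d in 6..8), each divided by reduction_factor[d].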
def forward(self, x, value, depth, position):
""" Transforms the output of the transformer target value logits.
Args:
x: Output of the transformer, the latent vector [N, T, E].
value: Target value token sequence [N, T].
depth: Target depth token sequence [N, T].
position: Target position token sequence [N, T, A].
Return:
Logits of target value sequence.
"""
batch_depth = torch.max(depth)
out = []
# process each sample individually
for latent_vec, val, dep, pos in zip(x, value, depth, position):
logits = torch.tensor([], device=x.device)
vector_idx = 0
# compute logits layerwise
for layer_idx, head in enumerate(self.heads):
layer_depth = layer_idx + 1
if layer_depth > batch_depth:
break # reached max depth layer
if layer_depth < 5:
# get value, depth and position sequence of current layer
layer_val = val[dep == layer_depth]
layer_dep = dep[dep == layer_depth]
layer_pos = pos[dep == layer_depth]
# compute number of vectors in latent vector of current layer
num_vectors = torch.sum(dep == layer_depth) // self.reduction_factor[layer_depth]
elif layer_depth == 5: # handle substitution
# get value, depth and position sequence of previous and current layer
layer_val = torch.cat([val[dep == (layer_depth - 1)], val[dep == layer_depth]])
layer_dep = torch.cat([dep[dep == (layer_depth - 1)], dep[dep == layer_depth]])
layer_pos = torch.cat([pos[dep == (layer_depth - 1)], pos[dep == layer_depth]])
# compute number of vectors in latent vector of current layer
num_vectors = torch.sum(dep == (layer_depth - 1)) // self.reduction_factor[layer_depth]
elif layer_depth in (6, 7, 8): # handle double substitution
# get value, depth and position sequence of previous and current layer
layer_val = torch.cat(
[
val[dep == (layer_depth - 2)],
val[dep == (layer_depth - 1)],
val[dep == layer_depth],
]
)
layer_dep = torch.cat(
[
dep[dep == (layer_depth - 2)],
dep[dep == (layer_depth - 1)],
dep[dep == layer_depth],
]
)
layer_pos = torch.cat(
[
pos[dep == (layer_depth - 2)],
pos[dep == (layer_depth - 1)],
pos[dep == layer_depth],
]
)
# compute number of vectors in latent vector of current layer
num_vectors = torch.sum(dep == (layer_depth - 2)) // self.reduction_factor[layer_depth]
# filter latent vector of current layer
layer_vec = latent_vec[vector_idx:vector_idx + num_vectors]
# handle clipped values in transformer
if len(layer_vec) == 0:
continue
# compute layer logits
layer_logits = head(
layer_vec.unsqueeze(0),
layer_val.unsqueeze(0),
layer_dep.unsqueeze(0),
layer_pos.unsqueeze(0),
)[0]
logits = torch.cat([logits, layer_logits])
# discard processed tokens
vector_idx += num_vectors
out += [logits]
# pad logit sequence
return pad_sequence(out, batch_first=True, padding_value=0.0)
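# Minimal usage sketch; the concrete spatial_encoding module and all shapes/sizes below
# are assumptions, not part of this file:
#
#   head = CompositeHeadD(spatial_encoding, num_vocab=16, embed_dim=256,
#                         head_dim=32, n_layer=2, resolution=64)
#   logits = head(x, value, depth, position)  # x: [N, T, 256] -> logits: [N, T_out, V]
#
# T_out may exceed T, since each depth's head expands its latent vectors back into
# per-token logits ("possibly increasing the overall sequence length").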
class CompositeHeadAutoregressiveD(CompositeHeadD):
def __init__(self, spatial_encoding, num_vocab, embed_dim, head_dim, n_layer, resolution, **_):
""" Performs a transformation from transformer latent space into target value logits.
Uses a different head for each depth layer, possibly increasing the overall sequence length.
Note: The token value '0' is reserved as a padding value, which does not propagate gradients.
Args:
num_vocab: Number of different target token values (excluding the padding token '0').
embed_dim: Dimension of the latent embedding space of the transformer.
head_dim: Size of embedding dimensions used in the head layers.
n_layer: Number of layers used in each linear or convolution block.
resolution: Spatial resolution of sequence encoding.
"""
super(CompositeHeadAutoregressiveD, self).__init__(spatial_encoding, num_vocab, embed_dim, head_dim, n_layer,
resolution, **_)
kwargs = {
"spatial_encoding": spatial_encoding,
"num_vocab": num_vocab,
"embed_dim": embed_dim,
"head_dim": head_dim,
"n_layer": n_layer
}
modules = []
if resolution >= 2:
modules += [LinearHead(**kwargs)]
if resolution >= 4:
modules += [LinearHead(**kwargs)]
if resolution >= 8:
modules += [ConvolutionHeadAutoregressive(**kwargs, conv_size=2)]
if resolution >= 16:
modules += [ConvolutionHeadAutoregressive(**kwargs, conv_size=8)]
if resolution >= 32:
modules += [SubstitutionHeadAutoregressive(**kwargs, conv_size=4)]
if resolution >= 64:
modules += [DoubleSubstitutionHeadAutoRegressive(**kwargs, conv_size=2)]
if resolution >= 128:
modules += [DoubleSubstitutionHeadAutoRegressive(**kwargs, conv_size=4)]
if resolution >= 256:
modules += [DoubleSubstitutionHeadAutoRegressive(**kwargs, conv_size=8)]
# heads
self.heads = nn.ModuleList(modules)
self.reduction_factor = {
1: 1,
2: 1,
3: 2,
4: 8,
5: 4, # Note: 'substitution'
6: 2, # Note: 'double_substitution'
7: 4, # Note: 'double_substitution'
8: 8, # Note: 'double_substitution'
}
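# Note: the parent constructor already builds the non-autoregressive heads; they are
# immediately replaced here by their autoregressive counterparts. Only the head modules
# differ from CompositeHeadD; the reduction_factor table is identical.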
......@@ -127,7 +127,6 @@ class DoubleSubstitutionHeadAutoRegressive(nn.Module):
"""
super(DoubleSubstitutionHeadAutoRegressive, self).__init__()
self.head_dim = head_dim
- self.conv_size = conv_size
deconvolution_2 = [nn.GELU(), Deconvolution(embed_dim, head_dim, conv_size)]
for i in range(n_layer - 1):
......@@ -214,12 +213,12 @@ class DoubleSubstitutionHeadAutoRegressive(nn.Module):
emb_1 = torch.zeros((batch_size, torch.max(len_1), self.head_dim), dtype=torch.float, device=value.device)
# substitute all mixed token embeddings of penultimate layer, with token embeddings of last layer
- emb_1[val_1 == 2] = emb_0[:, (self.conv_size - 1)::self.conv_size] # [N, T1, C]
+ emb_1[val_1 == 2] = emb_0[:, 7::8] # [N, T1, C]
emb_1 = self.convolution_1(emb_1)
emb_2 = torch.zeros((batch_size, torch.max(len_2), self.head_dim), dtype=torch.float, device=value.device)
# substitute all mixed token embeddings of third to last layer, with token embeddings of penultimate layer
- emb_2[val_2 == 2] = emb_1[:, (self.conv_size - 1)::self.conv_size] # [N, T1, C]
+ emb_2[val_2 == 2] = emb_1[:, 7::8] # [N, T1, C]
emb_2 = self.convolution_2(emb_2)
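# The slice stride is now a fixed 8 (last element of each group of 8) instead of being
# derived from self.conv_size, which this commit removes from __init__. This mirrors the
# embedding side, where DoubleSubstitutionEmbedding's convolution_0/convolution_1 are
# hard-coded to kernel size 8 in the same commit; conv_size still controls the final
# deconvolution stage.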
# create intermediate list to hold vectors
......
......@@ -5,6 +5,7 @@ from modules.utils import PositionalEncodingLearned, PositionalEncodingLearnedLo
from .composite_head_A import CompositeHeadA, CompositeHeadAutoregressiveA
from .composite_head_B import CompositeHeadB
from .composite_head_C import CompositeHeadC
from .composite_head_D import CompositeHeadD, CompositeHeadAutoregressiveD
from .convolution_head_A import ConvolutionHeadA
from .double_substitution_head import DoubleSubstitutionHead
from .linear_head import LinearHead
......@@ -75,6 +76,10 @@ def _create_head(name, positional_encoding, num_vocab, embed_dim, head_dim, n_la
return CompositeHeadB(**kwargs)
elif name in ('composite_C'):
return CompositeHeadC(**kwargs)
elif name in ('composite_D'):
return CompositeHeadD(**kwargs)
elif name in ('composite_autoregressive_D'):
return CompositeHeadAutoregressiveD(**kwargs)
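# Note: `name in ('composite_D')` is a substring test against a plain string (the
# parentheses do not create a tuple); the exact names still match, but
# `name == 'composite_D'` or `name in ('composite_D',)` would be the stricter check.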
else:
raise ValueError(f"ERROR: {name} head not implemented.")
......
......@@ -109,7 +109,6 @@ class SubstitutionHeadAutoregressive(nn.Module):
"""
super(SubstitutionHeadAutoregressive, self).__init__()
self.head_dim = head_dim
- self.conv_size = conv_size
deconvolution_1 = [nn.GELU(), Deconvolution(embed_dim, head_dim, conv_size)]
for i in range(n_layer - 1):
......@@ -184,7 +183,7 @@ class SubstitutionHeadAutoregressive(nn.Module):
emb_1 = torch.zeros((batch_size, torch.max(len_1), self.head_dim), dtype=torch.float, device=value.device)
# substitute all mixed token embeddings of penultimate layer, with token embeddings of last layer
- emb_1[val_1 == 2] = emb_0[:, (self.conv_size - 1)::self.conv_size] # [N, T1, C]
+ emb_1[val_1 == 2] = emb_0[:, 7::8] # [N, T1, C]
emb_1 = self.convolution_1(emb_1)
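# As in DoubleSubstitutionHeadAutoRegressive above, the slice stride is now a fixed 8
# rather than self.conv_size, which this commit removes from __init__.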
x_0 = torch.zeros((batch_size, torch.max(mix_1), self.head_dim), device=value.device)
......
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pad_sequence
from .basic_embedding_A import BasicEmbeddingA
from .convolution_embedding_A import ConvolutionEmbeddingA
from .double_substitution_embedding import DoubleSubstitutionEmbedding
from .substitution_embedding import SubstitutionEmbedding
class CompositeEmbeddingD(nn.Module):
def __init__(self, encoding, num_vocab, embed_dim, resolution, spatial_dim, **_):
""" Performs an embedding of token sequences into an embedding space of higher dimension.
Uses a different embedding for each depth layer, possibly reducing the overall sequence length.
Note: The token value '0' is reserved as a padding value, which does not propagate gradients.
Args:
encoding: Defines how the tokens are encoded before being reduced.
num_vocab: Number of different token values (excluding the padding token '0').
embed_dim: Dimension of the returned embedding space.
resolution: Spatial resolution of sequence encoding.
spatial_dim: Spatial dimension (2D, 3D, ...) of sequence encoding.
"""
super(CompositeEmbeddingD, self).__init__()
kwargs = {
"encoding": None,
"num_vocab": num_vocab,
"embed_dim": embed_dim,
"resolution": resolution,
"spatial_dim": spatial_dim,
}
modules = []
if resolution >= 2:
modules += [BasicEmbeddingA(**kwargs)]
if resolution >= 4:
modules += [BasicEmbeddingA(**kwargs)]
if resolution >= 8:
modules += [ConvolutionEmbeddingA(**kwargs, conv_size=2)]
if resolution >= 16:
modules += [ConvolutionEmbeddingA(**kwargs, conv_size=8)]
if resolution >= 32:
modules += [SubstitutionEmbedding(**kwargs, conv_size=4)]
if resolution >= 64:
modules += [DoubleSubstitutionEmbedding(**kwargs, conv_size=2)]
if resolution >= 128:
modules += [DoubleSubstitutionEmbedding(**kwargs, conv_size=4)]
if resolution >= 256:
modules += [DoubleSubstitutionEmbedding(**kwargs, conv_size=8)]
# embeddings
self.embedding = encoding
self.reductions = nn.ModuleList(modules)
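# The shared `encoding` module is applied once in forward(); the per-depth reduction
# embeddings are built with encoding=None and only reduce the already embedded sequence
# in reduce().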
def reduce(self, embedding, value, depth, position):
""" Transform sequences of token into an embedding space.
Args:
embdedding: Embedding sequence
value: Value token sequence.
depth: Depth token sequence.
position: Position token sequence.
Return:
Token sequence in the embedding space.
"""
batch_depth = torch.max(depth)
batch_size = len(value)
x = []
padding_mask = []
# process each sample individually
for i in range(batch_size):
# extract value, depth and position sequence of current sample
emb, val, dep, pos = embedding[i], value[i], depth[i], position[i]
b_emb = torch.tensor([], device=value.device)
# embed layerwise
for layer_idx, reduction in enumerate(self.reductions):
layer_depth = layer_idx + 1
if layer_depth > batch_depth:
break # reached max depth layer
# filter layers for embeddings
if layer_depth < 5: # only last layer
emb_seq = emb[dep == layer_depth]
val_seq = val[dep == layer_depth]
dep_seq = dep[dep == layer_depth]
pos_seq = pos[dep == layer_depth]
elif layer_depth == 5: # penultimate and last layer
emb_seq = torch.cat([emb[dep == (layer_depth - 1)], emb[dep == layer_depth]])
val_seq = torch.cat([val[dep == (layer_depth - 1)], val[dep == layer_depth]])
dep_seq = torch.cat([dep[dep == (layer_depth - 1)], dep[dep == layer_depth]])
pos_seq = torch.cat([pos[dep == (layer_depth - 1)], pos[dep == layer_depth]])
elif layer_depth in (6, 7, 8): # third-, second- and last layer
emb_seq = torch.cat(
[emb[dep == (layer_depth - 2)], emb[dep == (layer_depth - 1)], emb[dep == layer_depth]]
)
val_seq = torch.cat(
[val[dep == (layer_depth - 2)], val[dep == (layer_depth - 1)], val[dep == layer_depth]]
)
dep_seq = torch.cat(
[dep[dep == (layer_depth - 2)], dep[dep == (layer_depth - 1)], dep[dep == layer_depth]]
)
pos_seq = torch.cat(
[pos[dep == (layer_depth - 2)], pos[dep == (layer_depth - 1)], pos[dep == layer_depth]]
)
# compute layer embedding
layer_emb = reduction.reduce(
emb_seq.unsqueeze(0),
val_seq.unsqueeze(0),
dep_seq.unsqueeze(0),
pos_seq.unsqueeze(0),
)[0]
b_emb = torch.cat([b_emb, layer_emb])
# append embedding
x += [b_emb]
padding_mask += [torch.zeros(b_emb.shape[0], dtype=torch.bool, device=value.device)]
# create padding mask
self.mask = pad_sequence(padding_mask, batch_first=True, padding_value=1)
# pad embedding sequence
return pad_sequence(x, batch_first=True, padding_value=0.0)
def forward(self, value, depth, position):
""" Transform sequences of token into an embedding space.
Args:
value: Value token sequence.
depth: Depth token sequence.
position: Position token sequence.
Return:
Token sequence in the embedding space.
"""
return self.reduce(self.embedding(value, depth, position), value, depth, position)
def padding_mask(self):
""" Returns a padding mask, where padding tokens '0' of the value sequence are masked out. """
return self.mask
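# Minimal usage sketch; the concrete `encoding` module and all shapes/sizes below are
# assumptions, not part of this file:
#
#   embedding = CompositeEmbeddingD(encoding, num_vocab=16, embed_dim=256,
#                                   resolution=64, spatial_dim=3)
#   x = embedding(value, depth, position)  # [N, T', embed_dim], T' <= T after reduction
#   mask = embedding.padding_mask()        # [N, T'] bool, True at padded positions
#
# padding_mask() is only meaningful after a forward pass, since self.mask is set inside
# reduce().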
......@@ -32,8 +32,8 @@ class DoubleSubstitutionEmbedding(nn.Module):
self.embedding = encoding
# convolutions
- self.convolution_0 = Convolution(embed_dim, embed_dim, conv_size)
- self.convolution_1 = Convolution(embed_dim, embed_dim, conv_size)
+ self.convolution_0 = Convolution(embed_dim, embed_dim, 8)
+ self.convolution_1 = Convolution(embed_dim, embed_dim, 8)
self.convolution_2 = Convolution(embed_dim, embed_dim, conv_size)
def reduce(self, embedding, value, depth, position):
......@@ -103,12 +103,12 @@ class DoubleSubstitutionEmbedding(nn.Module):
# convolute embedded tokens of last layer
y_0 = self.convolution_0(x_0) # [N, S'_0, E // 4]
# substitute all mixed token embeddings of second-last layer, with token embeddings of last layer
- x_1[val_1 == 2] = y_0[val_0[:, ::self.conv_size] != 0] # [N, S_1, E // 4]
+ x_1[val_1 == 2] = y_0[val_0[:, ::8] != 0] # [N, S_1, E // 4]
# convolute substituted tokens of second-last layer
y_1 = self.convolution_1(x_1.contiguous()) # [N, S'_1, E // 4]
# substitute all mixed token embeddings of third-last layer, with token embeddings of second-last layer
- x_2[val_2 == 2] = y_1[val_1[:, ::self.conv_size] != 0] # [N, S_2, E // 2]
+ x_2[val_2 == 2] = y_1[val_1[:, ::8] != 0] # [N, S_2, E // 2]
# convolute substituted tokens of third-last layer
return self.convolution_2(x_2.contiguous()) # [N, S'_2, E]
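# With convolution_0/convolution_1 fixed to kernel size 8, the matching ::8 slices pick
# one value per convolution group, so each boolean mask selects exactly one convolved
# embedding per non-padded group of 8 tokens.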
......
......@@ -6,6 +6,7 @@ from .basic_embedding_A import BasicEmbeddingA
from .composite_embedding_A import CompositeEmbeddingA
from .composite_embedding_B import CompositeEmbeddingB
from .composite_embedding_C import CompositeEmbeddingC
from .composite_embedding_D import CompositeEmbeddingD
from .convolution_embedding_A import ConvolutionEmbeddingA
from .double_substitution_embedding import DoubleSubstitutionEmbedding
from .multi_conv_embedding_A import MultiConvolutionEmbeddingA
......@@ -71,6 +72,8 @@ def _create_embedding(name, positional_encoding, num_vocab, embed_dim, resolutio
return CompositeEmbeddingB(**kwargs)
elif name in ('composite_C'):
return CompositeEmbeddingC(**kwargs)
elif name in ('composite_D'):
return CompositeEmbeddingD(**kwargs)
else:
raise ValueError(f"ERROR: {name} embedding not implemented.")
......
......@@ -24,7 +24,7 @@ class SubstitutionEmbedding(nn.Module):
conv_size: Convolution kernel size and stride.
"""
super(SubstitutionEmbedding, self).__init__()
- self.chunck_size = conv_size
+ self.conv_size = conv_size
self.spatial_dim = spatial_dim
self.mask = None
......@@ -32,7 +32,7 @@ class SubstitutionEmbedding(nn.Module):
self.embedding = encoding
# convolutions
- self.convolution_0 = Convolution(embed_dim, embed_dim, conv_size)
+ self.convolution_0 = Convolution(embed_dim, embed_dim, 8)
self.convolution_1 = Convolution(embed_dim, embed_dim, conv_size)
def reduce(self, embedding, value, depth, position):
......@@ -83,13 +83,13 @@ class SubstitutionEmbedding(nn.Module):
pos_0[i, :len_0[i]] = position[i, len_1[i]:len_1[i] + len_0[i]]
# precompute padding mask
- self.mask = padding_mask(val_1[:, ::self.chunck_size], device=value.device)
+ self.mask = padding_mask(val_1[:, ::self.conv_size], device=value.device)
# convolute embedded tokens of last layer
y_0 = self.convolution_0(x_0) # [N, T2', C]
# substitute all mixed token embeddings of penultimate layer, with token embeddings of last layer
- x_1[val_1 == 2] = y_0[val_0[:, ::self.chunck_size] != 0] # [N, T1, C]
+ x_1[val_1 == 2] = y_0[val_0[:, ::8] != 0] # [N, T1, C]
# convolute substituted tokens of penultimate layer
return self.convolution_1(x_1.contiguous()) # [N, T1', E]
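# convolution_0 is now likewise hard-coded to kernel size 8 (presumably assuming 3D data,
# where each mixed token is replaced by 2**3 = 8 children), while conv_size still
# controls convolution_1 and hence the overall reduction factor.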
......
......@@ -32,5 +32,7 @@ def create_loss(name, ignore_index, max_depth, spatial_dim):
return DepthWeightedCrossEntropyLoss(**kwargs, basis=0.4)
elif name == 'depth_cross_entropy_D':
return DepthWeightedCrossEntropyLoss(**kwargs, basis=0.3)
elif name == 'depth_cross_entropy_E':
return DepthWeightedCrossEntropyLoss(**kwargs, basis=0.125)
else:
raise ValueError(f"ERROR: {name} loss not implemented.")