Commit 3089a2b7 authored by Gregor Kobsik

format embedding/head factory

- move parameters into a kwargs variable (see the sketch below)
- improve readability / reduce maintenance
parent 8d7e9c79
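
The refactor follows a common Python pattern: collect the constructor arguments shared by all branches into a dict once, unpack it with ** at each call site, and override single keys where a branch needs a different value. A minimal standalone sketch of the pattern (BasicThing and create_thing are illustrative names, not from this repository):

class BasicThing:
    def __init__(self, num_vocab, embed_dim):
        self.num_vocab = num_vocab
        self.embed_dim = embed_dim


def create_thing(name, num_vocab, embed_dim):
    kwargs = {
        "num_vocab": num_vocab,
        "embed_dim": embed_dim,
    }
    if name == 'basic':
        return BasicThing(**kwargs)          # was: BasicThing(num_vocab, embed_dim)
    elif name == 'scaled':
        kwargs['num_vocab'] = 2 * num_vocab  # branch-specific override before unpacking
        return BasicThing(**kwargs)
    else:
        raise ValueError(f"ERROR: {name} not implemented.")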
@@ -22,18 +22,26 @@ def _create_embedding(name, num_vocab, embed_dim, resolution, spatial_dim):
 
     Return:
         Token embedding initialised with specified parameters.
     """
+    kwargs = {
+        "num_vocab": num_vocab,
+        "embed_dim": embed_dim,
+        "resolution": resolution,
+        "spatial_dim": spatial_dim,
+    }
+
     if name in ('basic', 'basic_A'):
-        return BasicEmbeddingA(num_vocab, embed_dim, resolution, spatial_dim)
+        return BasicEmbeddingA(**kwargs)
     elif name == 'discrete_transformation':
-        return BasicEmbeddingA(num_vocab**2**spatial_dim + 1, embed_dim, resolution, spatial_dim)
+        kwargs['num_vocab'] = num_vocab**2**spatial_dim + 1
+        return BasicEmbeddingA(**kwargs)
     elif name == 'half_conv_A':
-        return HalfConvolutionalEmbeddingA(num_vocab, embed_dim, resolution, spatial_dim)
+        return HalfConvolutionalEmbeddingA(**kwargs)
     elif name in ('single_conv', 'single_conv_A'):
-        return SingleConvolutionalEmbeddingA(num_vocab, embed_dim, resolution, spatial_dim)
+        return SingleConvolutionalEmbeddingA(**kwargs)
     elif name == 'multi_conv_A':
-        return MultiConvolutionalEmbeddingA(num_vocab, embed_dim, resolution, spatial_dim)
+        return MultiConvolutionalEmbeddingA(**kwargs)
     elif name == 'substitution':
-        return SubstitutionEmbedding(num_vocab, embed_dim, resolution, spatial_dim)
+        return SubstitutionEmbedding(**kwargs)
     else:
         raise ValueError(f"ERROR: {name} embedding not implemented.")
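
A subtlety in the 'discrete_transformation' branch: Python's ** operator is right-associative, so num_vocab**2**spatial_dim parses as num_vocab ** (2 ** spatial_dim), i.e. one token value per combination of the 2**spatial_dim child values of a node, plus one additional token. A quick check with illustrative numbers (num_vocab=2, spatial_dim=3 are assumed example values, not taken from the commit):

num_vocab, spatial_dim = 2, 3
# right-associative: 2 ** (2 ** 3) + 1 = 2 ** 8 + 1 = 257, not (2 ** 2) ** 3 + 1 = 65
assert num_vocab**2**spatial_dim + 1 == num_vocab ** (2 ** spatial_dim) + 1 == 257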
@@ -22,20 +22,27 @@ def _create_head(name, num_vocab, embed_dim, spatial_dim):
 
     Return:
         Generative head initialised with specified parameters.
     """
+    kwargs = {
+        "num_vocab": num_vocab,
+        "embed_dim": embed_dim,
+        "spatial_dim": spatial_dim,
+    }
+
     if name in ('generative_basic', 'linear', 'basic'):
-        return LinearHead(num_vocab, embed_dim)
+        return LinearHead(**kwargs)
     elif name in ('single_conv', 'single_conv_A'):
-        return SingleConvolutionalHeadA(num_vocab, embed_dim, spatial_dim)
+        return SingleConvolutionalHeadA(**kwargs)
     elif name == 'split_B':
-        return SplitHeadB(num_vocab, embed_dim, spatial_dim)
+        return SplitHeadB(**kwargs)
     elif name == 'substitution':
-        return SubstitutionHead(num_vocab, embed_dim, spatial_dim)
+        return SubstitutionHead(**kwargs)
     elif name == 'discrete_transformation':
-        return LinearHead(num_vocab**2**spatial_dim + 1, embed_dim)
+        kwargs["num_vocab"] = num_vocab**2**spatial_dim + 1
+        return LinearHead(**kwargs)
     elif name == 'half_conv_A':
-        return HalfConvolutionalHeadA(num_vocab, embed_dim, spatial_dim)
+        return HalfConvolutionalHeadA(**kwargs)
     elif name == 'multi_conv_A':
-        return MultiConvolutionalHeadA(num_vocab, embed_dim, spatial_dim)
+        return MultiConvolutionalHeadA(**kwargs)
     else:
         raise ValueError(f"ERROR: {name} head not implemented.")
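
Unpacking the complete kwargs dict means every head constructor must now accept all three keys, even ones it does not use; this is why LinearHead gains a spatial_dim parameter in the next hunk. A minimal sketch of the failure mode the signature change avoids (OldLinearHead is a hypothetical stand-in for the previous signature):

class OldLinearHead:
    def __init__(self, num_vocab, embed_dim):  # no spatial_dim parameter
        pass


kwargs = {"num_vocab": 16, "embed_dim": 64, "spatial_dim": 3}
try:
    OldLinearHead(**kwargs)
except TypeError as err:
    print(err)  # __init__() got an unexpected keyword argument 'spatial_dim'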
@@ -2,7 +2,7 @@ import torch.nn as nn
 
 
 class LinearHead(nn.Module):
-    def __init__(self, num_vocab, embed_dim):
+    def __init__(self, num_vocab, embed_dim, spatial_dim):
         """ Performs a linear transformation from transformer latent space into target value logits.
 
         Note: The token value '0' is reserved as a padding value, which does not propagate gradients.
@@ -10,6 +10,7 @@ class LinearHead(nn.Module):
         Args:
             num_vocab: Number of different target token values (exclusive padding token '0').
             embed_dim: Dimension of the latent embedding space of the transformer.
+            spatial_dim: unused.
         """
 
         super(LinearHead, self).__init__()
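
For context, a sketch of how the updated class might look as a whole; only the signature and docstring changes above are from the commit, while the nn.Linear body and forward pass are assumptions about the surrounding code:

import torch.nn as nn


class LinearHead(nn.Module):
    def __init__(self, num_vocab, embed_dim, spatial_dim):
        super(LinearHead, self).__init__()
        # assumed body: project each latent vector to logits over num_vocab
        # token values plus the reserved padding value '0'
        self.linear = nn.Linear(embed_dim, num_vocab + 1)

    def forward(self, x):
        return self.linear(x)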