Commit 60b361ff authored by Gregor Kobsik

remove old/unused implementations of transformers

parent 7ec82c90
from modules.encoder_decoder.basic_encoder_decoder_module import BasicEncoderDecoderModule
__all__ = [
"BasicEncoderDecoderModule",
]
import torch
import torch.nn as nn
import torch.nn.functional as F
from masks import look_ahead_mask, padding_mask
class BasicEncoderDecoderModule(nn.Module):
def __init__(
self,
embed_dim,
num_heads,
num_layers,
num_positions,
num_vocab,
spatial_dim,
tree_depth,
attention,
):
super(BasicEncoderDecoderModule, self).__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
self.num_vocab = num_vocab
self.spatial_dim = spatial_dim
self.attention = attention
# start of sequence token
self.sos = torch.nn.Parameter(torch.zeros(embed_dim))
nn.init.normal_(self.sos)
# embeddings
self.src_token_embedding = nn.Embedding(num_vocab + 1, embed_dim, padding_idx=0)
self.tgt_token_embedding = nn.Embedding(num_vocab**(2**self.spatial_dim) + 1, embed_dim, padding_idx=0)
self.depth_embedding = nn.Embedding(tree_depth + 1, embed_dim, padding_idx=0)
self.spatial_embeddings = nn.ModuleList(
[nn.Embedding(2**tree_depth + 1, embed_dim, padding_idx=0) for _ in range(spatial_dim)]
) # TODO: + 1 unnecessary
# transformer encoder
encoder_layer = nn.TransformerEncoderLayer(
d_model=embed_dim,
nhead=num_heads,
dim_feedforward=4 * embed_dim,
dropout=0.0,
activation='gelu',
)
self.transformer_encoder = nn.TransformerEncoder(
encoder_layer=encoder_layer,
num_layers=num_layers,
norm=nn.LayerNorm(embed_dim),
)
# transformer decoder
decoder_layer = nn.TransformerDecoderLayer(
d_model=embed_dim,
nhead=num_heads,
dim_feedforward=4 * embed_dim,
dropout=0.0,
activation='gelu',
)
self.transformer_decoder = nn.TransformerDecoder(
decoder_layer=decoder_layer,
num_layers=num_layers,
norm=nn.LayerNorm(embed_dim),
)
# final linear layer
self.head = nn.Linear(embed_dim, num_vocab**(2**self.spatial_dim) + 1, bias=False)
def _embed(self, x, depth, pos):
x = x + self.depth_embedding(depth)
for axis, spatial_embedding in enumerate(self.spatial_embeddings):
x = x + spatial_embedding(pos[:, :, axis])
return x
def encode(self, value, depth, pos):
"""
Expect input as shape:
value: (S, N)
depth: (S, N)
pos: (S, N, A)
shapes:
S: source length
N: batch size
E: embedding dimension
A: spatial dimension
"""
# embeddings -> [S, N, E]
x = self.src_token_embedding(value) # [S, N, E]
src = self._embed(x, depth, pos) # [S, N, E]
# encoder part of the transformer - compute memory
return self.transformer_encoder(src, src_key_padding_mask=padding_mask(value, device=src.device))
def decode(self, target, depth, pos, memory):
"""
Expect input as shape:
target: (T, N)
depth: (T, N)
pos: (T, N, A)
shapes:
T: target length
N: batch size
E: embedding dimension
A: spatial dimension
"""
tgt_len, batch = target.shape # [T, N]
# embeddings -> [T, N, E]
x = self.tgt_token_embedding(target) # [T, N, E]
tgt = self._embed(x, depth, pos) # [T, N, E]
# prepend start of sequence token -> [T, N, E]
sos = torch.ones(1, batch, self.embed_dim, device=tgt.device) * self.sos # [1, N, E]
tgt = torch.cat([sos, tgt[:-1]], axis=0) # [T, N, E]
# decoder part of the transformer
h = self.transformer_decoder(
tgt,
memory,
tgt_mask=look_ahead_mask(tgt_len, device=tgt.device),
tgt_key_padding_mask=padding_mask(target, device=tgt.device),
) # [T, N, E]
# return logits: [T, N, E] -> [T, N, num_vocab**(2**spatial_dim) + 1]
return self.head(h) # [T, N, num_vocab**(2**spatial_dim) + 1]
def forward(self, value, depth, pos, target):
"""
Expect input as shape:
value: (S, N)
depth: (S, N)
pos: (S, N, A)
target: (T, N)
shapes:
S: source length
T: target length
N: batch size
E: embedding dimension
A: spatial dimension
"""
# extract the valid target sequence, in case the target is longer than the last layer of the input
tgt_len, _ = target.shape # [T, N]
# start index of the deepest (last) layer; argmax over the flattened depth tensor,
# which assumes the depth layout is identical across the batch
tgt_idx = torch.argmax(depth)
max_tgt_len = len(depth[tgt_idx:])
tgt_depth = depth[tgt_idx:tgt_idx + tgt_len] + 1 # [T, N] - targets live one layer deeper
tgt_pos = pos[tgt_idx:tgt_idx + tgt_len] # [T, N, A]
# transformer encoder decoder
memory = self.encode(value, depth, pos) # [S, N, E]
output = self.decode(target[:max_tgt_len], tgt_depth, tgt_pos, memory) # [min(T, max_tgt_len), N, num_vocab**(2**spatial_dim) + 1]
# pad output if target sequence was extracted
logits = F.pad(output, pad=(0, 0, 0, 0, 0, tgt_len - max_tgt_len), mode='constant', value=0)
# return logits
return logits
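A minimal usage sketch for the encoder-decoder module above, with hypothetical hyperparameters and dummy tensors; it assumes the repository's `masks` helpers are importable and that the batch shares a single depth layout:

# --- usage sketch (not part of the repository) ---
model = BasicEncoderDecoderModule(
    embed_dim=32, num_heads=2, num_layers=2, num_positions=512,
    num_vocab=16, spatial_dim=2, tree_depth=4, attention="basic",
)
S, T, N, A = 12, 8, 1, 2
value = torch.randint(1, 17, (S, N))                               # source tokens, 0 is reserved for padding
depth = torch.cat([torch.full((5, N), 2), torch.full((7, N), 3)])  # deepest layer (7 tokens) comes last
pos = torch.randint(1, 17, (S, N, A))                              # per-axis spatial positions
target = torch.randint(1, 16**(2**2) + 1, (T, N))                  # one token per parent node of the next layer
logits = model(value, depth, pos, target)
print(logits.shape)  # torch.Size([8, 1, 65537]) = [T, N, num_vocab**(2**spatial_dim) + 1]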
from modules.encoder_only.basic_transformer_module import BasicTransformerModule
from modules.encoder_only.fast_transformer_module import FastTransformerModule
from modules.encoder_only.performer_module import PerformerModule
from modules.encoder_only.reformer_module import ReformerModule
from modules.encoder_only.routing_transformer_module import RoutingTransformerModule
from modules.encoder_only.sinkhorn_transformer_module import SinkhornTransformerModule
from modules.encoder_only.linear_transformer_module import LinearTransformerModule
__all__ = [
"BasicTransformerModule",
"FastTransformerModule",
"PerformerModule",
"ReformerModule",
"RoutingTransformerModule",
"SinkhornTransformerModule",
"LinearTransformerModule",
]
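These exports are presumably consumed by a model factory elsewhere in the repository (not shown in this commit). One plausible way to dispatch on the `attention` string, purely as a hypothetical sketch:

# --- hypothetical dispatch sketch (the real selection logic is not part of this diff) ---
from modules.encoder_only import (
    BasicTransformerModule, FastTransformerModule, PerformerModule, ReformerModule,
    RoutingTransformerModule, SinkhornTransformerModule, LinearTransformerModule,
)

def create_encoder_only_module(attention, **kwargs):
    # "basic"/"basic_ancestor" and the "fast_*" names appear in the modules below;
    # the remaining keys are assumptions
    if attention.startswith("basic"):
        return BasicTransformerModule(attention=attention, **kwargs)
    if attention.startswith("fast_"):
        return FastTransformerModule(attention=attention, **kwargs)
    module_by_name = {
        "performer": PerformerModule,
        "reformer": ReformerModule,
        "routing": RoutingTransformerModule,
        "sinkhorn": SinkhornTransformerModule,
        "linear": LinearTransformerModule,
    }
    return module_by_name[attention](attention=attention, **kwargs)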
import torch
import torch.nn as nn
from masks import look_ahead_mask, padding_mask, ancestor_mask
class BasicTransformerModule(nn.Module):
def __init__(
self,
embed_dim,
num_heads,
num_layers,
num_positions,
num_vocab,
spatial_dim,
tree_depth,
attention,
):
super(BasicTransformerModule, self).__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
self.num_vocab = num_vocab
self.spatial_dim = spatial_dim
self.attention = attention
# start of sequence token
self.sos = torch.nn.Parameter(torch.zeros(embed_dim))
nn.init.normal_(self.sos)
# embeddings
self.token_embedding = nn.Embedding(num_vocab + 1, embed_dim, padding_idx=0)
self.depth_embedding = nn.Embedding(tree_depth + 1, embed_dim, padding_idx=0)
self.spatial_embeddings = nn.ModuleList(
[nn.Embedding(2**tree_depth + 1, embed_dim, padding_idx=0) for _ in range(spatial_dim)]
)
# transformer encoder
encoder_layer = nn.TransformerEncoderLayer(
d_model=embed_dim,
nhead=num_heads,
dim_feedforward=4 * embed_dim,
dropout=0.0,
activation='gelu',
)
self.transformer_encoder = nn.TransformerEncoder(
encoder_layer=encoder_layer,
num_layers=num_layers,
norm=nn.LayerNorm(embed_dim),
)
# final linear layer
self.head = nn.Linear(embed_dim, num_vocab + 1, bias=False)
def forward(self, value, depth, pos):
"""
Expect input as shape:
value: (S, N)
depth: (S, N)
pos: (S, N, A)
shapes:
S: sequence length
N: batch size
E: embedding dimension
A: spatial dimension
"""
# embeddings
h = self.token_embedding(value) # [S, N, E]
h = h + self.depth_embedding(depth) # [S, N, E]
for axis, spatial_embedding in enumerate(self.spatial_embeddings):
h = h + spatial_embedding(pos[:, :, axis]) # [S, N, E]
# prepend start of sequence token
seq_len, batch = value.shape # [S, N]
sos = torch.ones(1, batch, self.embed_dim, device=value.device) * self.sos # [1, N, E]
h = torch.cat([sos, h[:-1, :, :]], axis=0) # [S, N, E]
# create attention mask
if self.attention == "basic_ancestor":
attn_mask = ancestor_mask(value, h.device)
attn_mask = torch.repeat_interleave(attn_mask, self.num_heads, dim=0)
else:
attn_mask = look_ahead_mask(seq_len, device=h.device)
# transformer encoder
h = self.transformer_encoder(
src=h,
mask=attn_mask,
src_key_padding_mask=padding_mask(value, device=h.device),
) # [S, N, E]
# return logits
return self.head(h)
def get_attn_weights(self):
# `_attention_weights` is not a standard nn.TransformerEncoder attribute; it is
# presumably populated elsewhere in the repository (e.g. via hooks) when attention logging is enabled
return self.transformer_encoder._attention_weights
def get_attn_activations(self):
return self.transformer_encoder._attention_activations
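The `masks` helpers used above are not part of this diff. A plausible minimal implementation that matches how they are called by the two basic modules (an assumption, not the repository's actual code; the tree-specific `ancestor_mask` is omitted):

# --- assumed semantics of the mask helpers (sketch) ---
import torch

def look_ahead_mask(seq_len, device=None):
    # additive causal mask with -inf above the diagonal, the format expected by
    # nn.TransformerEncoder/Decoder (equivalent to nn.Transformer.generate_square_subsequent_mask)
    mask = torch.full((seq_len, seq_len), float('-inf'), device=device)
    return torch.triu(mask, diagonal=1)

def padding_mask(value, device=None):
    # value: [S, N] token ids with 0 as padding; key padding masks are [N, S]
    # with True at positions that should be ignored
    return (value == 0).transpose(0, 1).to(device=device)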
import torch
import torch.nn as nn
import torch.nn.functional as F
from fast_transformers.builders import TransformerEncoderBuilder
from fast_transformers.masking import TriangularCausalMask, FullMask
from fast_transformers.feature_maps import Favor
_attention_map = {
'fast_full': 'full',
'fast_linear': 'causal-linear',
'fast_local': 'local',
'fast_reformer': 'reformer',
'fast_favor': 'causal-linear',
'fast_performer': 'causal-linear', # legacy
}
_feature_map = {
'fast_full': None,
'fast_linear': None,
'fast_local': None,
'fast_reformer': None,
'fast_favor': Favor.factory(),
'fast_performer': Favor.factory(), # legacy
}
class FastTransformerModule(nn.Module):
def __init__(
self,
embed_dim,
num_heads,
num_layers,
num_positions,
num_vocab,
spatial_dim,
tree_depth,
attention,
):
super(FastTransformerModule, self).__init__()
self.attention_type = _attention_map[attention]
self.embed_dim = embed_dim
self.num_vocab = num_vocab
self.spatial_dim = spatial_dim
# start of sequence token
self.sos = torch.nn.Parameter(torch.zeros(embed_dim))
nn.init.normal_(self.sos)
# embeddings
self.token_embedding = nn.Embedding(num_vocab + 1, embed_dim, padding_idx=0)
self.depth_embedding = nn.Embedding(tree_depth + 1, embed_dim, padding_idx=0)
self.spatial_embeddings = nn.ModuleList(
[nn.Embedding(2**tree_depth + 1, embed_dim, padding_idx=0) for _ in range(spatial_dim)]
)
# transformer encoder
kwargs = {
'attention_type': _attention_map[attention],
'local_context': 512,
'n_layers': num_layers,
'n_heads': num_heads,
'feed_forward_dimensions': embed_dim * 4,
'query_dimensions': embed_dim // num_heads,
'value_dimensions': embed_dim // num_heads,
'dropout': 0.0,
'attention_dropout': 0.0,
'activation': "gelu",
'feature_map': _feature_map[attention],
}
self.transformer_encoder = TransformerEncoderBuilder.from_kwargs(**kwargs).get()
# final linear layer
self.head = nn.Linear(embed_dim, num_vocab + 1, bias=False)
def forward(self, value, depth, pos):
"""
Expect input as shape:
value: (N, S)
depth: (N, S)
pos: (N, S, A)
shapes:
N: batch size
S: sequence length
E: embedding dimension
A: spatial dimension
"""
batch, seq_len = value.shape # [N, S]
if self.attention_type == "reformer":
pad_len = 128 - (seq_len % 128)
value = F.pad(input=value, pad=(0, pad_len))
depth = F.pad(input=depth, pad=(0, pad_len))
pos = F.pad(input=pos, pad=(0, 0, 0, pad_len)) # pad the sequence dim of [N, S, A], not the spatial axes
# triangular causal and padding masks
causal_mask = TriangularCausalMask(value.shape[1], device=value.device) # [S, S]
padding_mask = FullMask(mask=value != 0, device=value.device) # [N, S]
# embeddings
x = self.token_embedding(value) # [N, S, E]
x = x + self.depth_embedding(depth) # [N, S, E]
for axis, spatial_embedding in enumerate(self.spatial_embeddings):
x = x + spatial_embedding(pos[:, :, axis]) # [N, S, E]
# prepend start of sequence token
sos = torch.ones(batch, 1, self.embed_dim, device=value.device) * self.sos # [N, 1, E]
x = torch.cat([sos, x[:, :-1, :]], axis=1) # [N, S, E]
# transformer encoder
x = self.transformer_encoder(
x=x,
attn_mask=causal_mask,
length_mask=padding_mask,
) # [N, S, E]
# return logits
return self.head(x)[:, :seq_len]
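Unlike the two basic modules, this module and the ones below expect batch-first [N, S] inputs. A minimal usage sketch with hypothetical hyperparameters (requires the pytorch-fast-transformers package):

# --- usage sketch (not part of the repository) ---
module = FastTransformerModule(
    embed_dim=32, num_heads=2, num_layers=2, num_positions=512,
    num_vocab=16, spatial_dim=2, tree_depth=4, attention="fast_full",
)
N, S, A = 1, 24, 2
value = torch.randint(1, 17, (N, S))   # token ids, 0 reserved for padding
depth = torch.randint(1, 5, (N, S))    # tree depth per token
pos = torch.randint(1, 17, (N, S, A))  # per-axis spatial positions
logits = module(value, depth, pos)
print(logits.shape)                    # torch.Size([1, 24, 17]) = [N, S, num_vocab + 1]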
import torch
import torch.nn as nn
from linear_attention_transformer import LinearAttentionTransformer
class LinearTransformerModule(nn.Module):
def __init__(
self,
embed_dim,
num_heads,
num_layers,
num_positions,
num_vocab,
spatial_dim,
tree_depth,
attention,
):
super(LinearTransformerModule, self).__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
self.num_vocab = num_vocab
self.spatial_dim = spatial_dim
self.attention = attention
# start of sequence token
self.sos = torch.nn.Parameter(torch.zeros(embed_dim))
nn.init.normal_(self.sos)
# embeddings
self.token_embedding = nn.Embedding(num_vocab + 1, embed_dim, padding_idx=0)
self.depth_embedding = nn.Embedding(tree_depth + 1, embed_dim, padding_idx=0)
self.spatial_embeddings = nn.ModuleList(
[nn.Embedding(2**tree_depth + 1, embed_dim, padding_idx=0) for _ in range(spatial_dim)]
)
# linear attention transformer encoder
self.transformer_encoder = LinearAttentionTransformer(
dim=embed_dim,
depth=num_layers,
heads=num_heads,
max_seq_len=num_positions,
)
# final linear layer
self.head = nn.Linear(embed_dim, num_vocab + 1, bias=False)
def forward(self, value, depth, pos):
"""
Expect input as shape:
value: (N, S)
depth: (N, S)
pos: (N, S, A)
shapes:
S: sequence length
N: batch size
E: embedding dimension
A: spatial dimension
"""
batch, seq_len = value.shape # [N, S]
# embeddings
x = self.token_embedding(value) # [N, S, E]
x = x + self.depth_embedding(depth) # [N, S, E]
for axis, spatial_embedding in enumerate(self.spatial_embeddings):
x = x + spatial_embedding(pos[:, :, axis]) # [N, S, E]
# prepend start of sequence token
sos = torch.ones(batch, 1, self.embed_dim, device=value.device) * self.sos # [N, 1, E]
x = torch.cat([sos, x[:, :-1, :]], axis=1) # [N, S, E]
# transformer encoder - TODO: pass a padding mask to handle batched inputs (N > 1)
x = self.transformer_encoder(x) # [N, S, E]
# return logits
return self.head(x)
import torch
import torch.nn as nn
from performer_pytorch import Performer
class PerformerModule(nn.Module):
def __init__(
self,
embed_dim,
num_heads,
num_layers,
num_positions,
num_vocab,
spatial_dim,
tree_depth,
attention,
):
super(PerformerModule, self).__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
self.num_vocab = num_vocab
self.spatial_dim = spatial_dim
self.attention = attention
# start of sequence token
self.sos = torch.nn.Parameter(torch.zeros(embed_dim))
nn.init.normal_(self.sos)
# embeddings
self.token_embedding = nn.Embedding(num_vocab + 1, embed_dim, padding_idx=0)
self.depth_embedding = nn.Embedding(tree_depth + 1, embed_dim, padding_idx=0)
self.spatial_embeddings = nn.ModuleList(
[nn.Embedding(2**tree_depth + 1, embed_dim, padding_idx=0) for _ in range(spatial_dim)]
)
# performer encoder
self.transformer_encoder = Performer(
dim=embed_dim,
depth=num_layers,
heads=num_heads,
causal=True,
)
# final linear layer
self.head = nn.Linear(embed_dim, num_vocab + 1, bias=False)
def forward(self, value, depth, pos):
"""
Expect input as shape:
value: (N, S)
depth: (N, S)
pos: (N, S, A)
shapes:
S: sequence length
N: batch size
E: embedding dimension
A: spatial dimension
"""
batch, seq_len = value.shape # [N, S]
# embeddings
x = self.token_embedding(value) # [N, S, E]
x = x + self.depth_embedding(depth) # [N, S, E]
for axis, spatial_embedding in enumerate(self.spatial_embeddings):
x = x + spatial_embedding(pos[:, :, axis]) # [N, S, E]
# prepend start of sequence token
sos = torch.ones(batch, 1, self.embed_dim, device=value.device) * self.sos # [N, 1, E]
x = torch.cat([sos, x[:, :-1, :]], axis=1) # [N, S, E]
# transformer encoder - TODO: pass a padding mask to handle batched inputs (N > 1)
x = self.transformer_encoder(x) # [N, S, E]
# return logits
return self.head(x)
import torch
import torch.nn as nn
import torch.nn.functional as F
from reformer_pytorch import Reformer
class ReformerModule(nn.Module):
def __init__(
self,
embed_dim,
num_heads,
num_layers,
num_positions,
num_vocab,
spatial_dim,
tree_depth,
attention,
):
super(ReformerModule, self).__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
self.num_vocab = num_vocab
self.spatial_dim = spatial_dim
self.attention = attention
# start of sequence token
self.sos = torch.nn.Parameter(torch.zeros(embed_dim))
nn.init.normal_(self.sos)
# embeddings
self.token_embedding = nn.Embedding(num_vocab + 1, embed_dim, padding_idx=0)
self.depth_embedding = nn.Embedding(tree_depth + 1, embed_dim, padding_idx=0)
self.spatial_embeddings = nn.ModuleList(
[nn.Embedding(2**tree_depth + 1, embed_dim, padding_idx=0) for _ in range(spatial_dim)]
)
# reformer encoder
self.transformer_encoder = Reformer(
dim=embed_dim,
depth=num_layers,
max_seq_len=num_positions,
heads=num_heads,
dim_head=None,
bucket_size=64,
n_hashes=8,
ff_chunks=100,
attn_chunks=None,
causal=True,
)
# final linear layer
self.head = nn.Linear(embed_dim, num_vocab + 1, bias=False)
def forward(self, value, depth, pos):
"""
Expect input as shape:
value: (N, S)
depth: (N, S)
pos: (N, S, A)
shapes:
S: sequence length
N: batch size
E: embedding dimension
A: spatial dimension
"""
batch, seq_len = value.shape # [N, S]
# pad input - sequence length must be divisible by bucket_size * 2 (= 128)
pad_len = 128 - (seq_len % 128)
value = F.pad(input=value, pad=(0, pad_len))
depth = F.pad(input=depth, pad=(0, pad_len))
pos = F.pad(input=pos, pad=(0, 0, 0, pad_len)) # pad the sequence dim of [N, S, A], not the spatial axes
# embeddings
x = self.token_embedding(value) # [N, S, E]
x = x + self.depth_embedding(depth) # [N, S, E]
for axis, spatial_embedding in enumerate(self.spatial_embeddings):
x = x + spatial_embedding(pos[:, :, axis]) # [N, S, E]
# prepend start of sequence token
sos = torch.ones(batch, 1, self.embed_dim, device=value.device) * self.sos # [N, 1, E]
x = torch.cat([sos, x[:, :-1, :]], axis=1) # [N, S, E]
# transformer encoder - TODO: pass a padding mask to handle batched inputs (N > 1)
x = self.transformer_encoder(x) # [N, S, E]
# return logits
return self.head(x)[:, :seq_len]
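F.pad interprets its `pad` tuple from the last dimension backwards, which is why the [N, S] tensors and the [N, S, A] position tensor need different pad tuples when padding the sequence to a multiple of the bucket size. A small worked example (standalone, not part of the module):

# --- F.pad padding semantics (worked example) ---
import torch
import torch.nn.functional as F

value = torch.ones(2, 100, dtype=torch.long)     # [N, S]
pos = torch.ones(2, 100, 3, dtype=torch.long)    # [N, S, A]
pad_len = 128 - (100 % 128)                      # 28 -> padded length 128
print(F.pad(value, pad=(0, pad_len)).shape)      # torch.Size([2, 128])      - pads S (last dim of a 2D tensor)
print(F.pad(pos, pad=(0, pad_len)).shape)        # torch.Size([2, 100, 31])  - pads A, not S
print(F.pad(pos, pad=(0, 0, 0, pad_len)).shape)  # torch.Size([2, 128, 3])   - pads S as intended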
import torch
import torch.nn as nn
from routing_transformer import RoutingTransformer
class RoutingTransformerModule(nn.Module):
def __init__(
self,
embed_dim,
num_heads,
num_layers,
num_positions,
num_vocab,
spatial_dim,
tree_depth,
attention,
):
super(RoutingTransformerModule, self).__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
self.num_vocab = num_vocab
self.spatial_dim = spatial_dim
self.attention = attention
# start of sequence token
self.sos = torch.nn.Parameter(torch.zeros(embed_dim))
nn.init.normal_(self.sos)
# embeddings
self.token_embedding = nn.Embedding(num_vocab + 1, embed_dim, padding_idx=0)
self.depth_embedding = nn.Embedding(tree_depth + 1, embed_dim, padding_idx=0)
self.spatial_embeddings = nn.ModuleList(
[nn.Embedding(2**tree_depth + 1, embed_dim, padding_idx=0) for _ in range(spatial_dim)]
)
# routing transformer encoder
self.transformer_encoder = RoutingTransformer(
dim=embed_dim,
depth=num_layers,
max_seq_len=num_positions,
heads=num_heads,
dim_head=None,
window_size=64,
local_attn_window_size=256,
local_attn_radius_blocks=1,
causal=True,
)
# final linear layer
self.head = nn.Linear(embed_dim, num_vocab + 1, bias=False)
def forward(self, value, depth, pos):
"""
Expect input as shape:
value: (N, S)
depth: (N, S)
pos: (N, S, A)
shapes:
S: sequence length
N: batch size
E: embedding dimension
A: spatial dimension
"""
batch, seq_len = value.shape # [N, S]
# padding mask (causal attention is handled internally via causal=True)
padding_mask = value != 0 # [N, S]
# embeddings
x = self.token_embedding(value) # [N, S, E]
x = x + self.depth_embedding(depth) # [N, S, E]
for axis, spatial_embedding in enumerate(self.spatial_embeddings):
x = x + spatial_embedding(pos[:, :, axis]) # [N, S, E]
# prepend start of sequence token
sos = torch.ones(batch, 1, self.embed_dim, device=value.device) * self.sos # [N, 1, E]
x = torch.cat([sos, x[:, :-1, :]], axis=1) # [N, S, E]
# transformer encoder - the padding mask is passed; the auxiliary routing loss is currently discarded
x, aux_loss = self.transformer_encoder(x, input_mask=padding_mask) # [N, S, E]
# return logits
return self.head(x)
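The routing transformer's forward pass also returns an auxiliary clustering loss, which the module above discards. If it were surfaced to the caller, a training step would typically add it to the main objective, e.g. (sketch; `routing_module_with_aux`, `targets` and the unit weighting are assumptions, not the repository's training code):

# --- hypothetical handling of the auxiliary routing loss (sketch) ---
import torch.nn.functional as F

logits, aux_loss = routing_module_with_aux(value, depth, pos)  # hypothetical variant that returns both
ce = F.cross_entropy(
    logits.reshape(-1, logits.shape[-1]),  # [N*S, num_vocab + 1]
    targets.reshape(-1),                   # [N*S] next-token targets
    ignore_index=0,                        # skip padding tokens
)
loss = ce + aux_loss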
import torch
import torch.nn as nn
import torch.nn.functional as F
from sinkhorn_transformer import SinkhornTransformer
class SinkhornTransformerModule(nn.Module):
def __init__(
self,
embed_dim,
num_heads,
num_layers,
num_positions,
num_vocab,
spatial_dim,
tree_depth,
attention,
):
super(SinkhornTransformerModule, self).__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
self.num_vocab = num_vocab
self.spatial_dim = spatial_dim
self.attention = attention
# start of sequence token
self.sos = torch.nn.Parameter(torch.zeros(embed_dim))
nn.init.normal_(self.sos)
# embeddings
self.token_embedding = nn.Embedding(num_vocab + 1, embed_dim, padding_idx=0)
self.depth_embedding = nn.Embedding(tree_depth + 1, embed_dim, padding_idx=0)
self.spatial_embeddings = nn.ModuleList(
[nn.Embedding(2**tree_depth + 1, embed_dim, padding_idx=0) for _ in range(spatial_dim)]
)
# sinkhorn transformer encoder
self.transformer_encoder = SinkhornTransformer(
dim=embed_dim,
depth=num_layers,
max_seq_len=num_positions,
heads=num_heads,
dim_head=None,
bucket_size=64,
causal=True,
)
# final linear layer
self.head = nn.Linear(embed_dim, num_vocab + 1, bias=False)
def forward(self, value, depth, pos):
"""
Expect input as shape:
value: (N, S)
depth: (N, S)
pos: (N, S, A)
shapes:
S: sequence length
N: batch size
E: embedding dimension
A: spatial dimension
"""
batch, seq_len = value.shape # [N, S]
# pad input - sequence length must be divisible by bucket_size (= 64)
pad_len = 64 - (seq_len % 64)
value = F.pad(input=value, pad=(0, pad_len))
depth = F.pad(input=depth, pad=(0, pad_len))
pos = F.pad(input=pos, pad=(0, 0, 0, pad_len)) # pad the sequence dim of [N, S, A], not the spatial axes
# embeddings
x = self.token_embedding(value) # [N, S, E]
x = x + self.depth_embedding(depth) # [N, S, E]
for axis, spatial_embedding in enumerate(self.spatial_embeddings):
x = x + spatial_embedding(pos[:, :, axis]) # [N, S, E]
# prepend start of sequence token
sos = torch.ones(batch, 1, self.embed_dim, device=value.device) * self.sos # [N, 1, E]
x = torch.cat([sos, x[:, :-1, :]], axis=1) # [N, S, E]
# transformer encoder - TODO: pass a padding mask to handle batched inputs (N > 1)
x = self.transformer_encoder(x) # [N, S, E]
# return logits
return self.head(x)[:, :seq_len]