From a155f85cf9b07c730e50a943e7d1c4e821d9324b Mon Sep 17 00:00:00 2001
From: Gregor Kobsik <gregor.kobsik@rwth-aachen.de>
Date: Thu, 8 Jul 2021 14:43:47 +0200
Subject: [PATCH] refactor EMD, rename 'transformers'->'architecture'

concatenated encoder/decoder in the EMD transformer into one process function
---
 factories/transformer_factory.py              |   2 +-
 modules/architecture/__init__.py              |  11 +
 .../autoencoder.py                            |   0
 .../encoder_decoder.py                        |   0
 modules/architecture/encoder_multi_decoder.py | 169 ++++++++++++++++
 .../encoder_only.py                           |   0
 modules/transformer/__init__.py               |  11 -
 modules/transformer/encoder_multi_decoder.py  | 191 ------------------
 8 files changed, 181 insertions(+), 203 deletions(-)
 create mode 100644 modules/architecture/__init__.py
 rename modules/{transformer => architecture}/autoencoder.py (100%)
 rename modules/{transformer => architecture}/encoder_decoder.py (100%)
 create mode 100644 modules/architecture/encoder_multi_decoder.py
 rename modules/{transformer => architecture}/encoder_only.py (100%)
 delete mode 100644 modules/transformer/__init__.py
 delete mode 100644 modules/transformer/encoder_multi_decoder.py

diff --git a/factories/transformer_factory.py b/factories/transformer_factory.py
index 7e983de..fd8b110 100644
--- a/factories/transformer_factory.py
+++ b/factories/transformer_factory.py
@@ -1,4 +1,4 @@
-from modules.transformer import (
+from modules.architecture import (
     Autoencoder,
     EncoderOnly,
     EncoderDecoder,
diff --git a/modules/architecture/__init__.py b/modules/architecture/__init__.py
new file mode 100644
index 0000000..8aa6d43
--- /dev/null
+++ b/modules/architecture/__init__.py
@@ -0,0 +1,11 @@
+from modules.architecture.autoencoder import Autoencoder
+from modules.architecture.encoder_only import EncoderOnly
+from modules.architecture.encoder_decoder import EncoderDecoder
+from modules.architecture.encoder_multi_decoder import EncoderMultiDecoder
+
+__all__ = [
+    "Autoencoder",
+    "EncoderOnly",
+    "EncoderDecoder",
+    "EncoderMultiDecoder",
+]
diff --git a/modules/transformer/autoencoder.py b/modules/architecture/autoencoder.py
similarity index 100%
rename from modules/transformer/autoencoder.py
rename to modules/architecture/autoencoder.py
diff --git a/modules/transformer/encoder_decoder.py b/modules/architecture/encoder_decoder.py
similarity index 100%
rename from modules/transformer/encoder_decoder.py
rename to modules/architecture/encoder_decoder.py
diff --git a/modules/architecture/encoder_multi_decoder.py b/modules/architecture/encoder_multi_decoder.py
new file mode 100644
index 0000000..bc178bb
--- /dev/null
+++ b/modules/architecture/encoder_multi_decoder.py
@@ -0,0 +1,169 @@
+import torch
+import torch.nn as nn
+
+from masks import look_ahead_mask, full_mask
+
+
+class EncoderMultiDecoder(nn.Module):
+    def __init__(self, embed_dim, num_heads, num_layers, num_positions, token_embedding, generative_head, **_):
+        """ Creates an instance of an encoder multi decoder transformer.
+
+        It accepts different implementations of `token_embedding`s and `generative_head`s. The following
+        abbreviations are used to reference the size and the content of a dimension in the used tensors.
+
+        Shapes:
+            N: batch size
+            L: layer sequence length
+            M: memory length
+            E: embedding dimension
+            A: spatial dimension
+            V: vocabulary size
+
+        Args:
+            embed_dim: Number of embedding dimensions used by the attention.
+            num_heads: Number of heads used by the attention.
+            num_layers: Number of layers for each of the 'encoder' and 'decoder' parts of the transformer.
+            num_positions: Maximum length of the processed token sequences. Longer sequences can be passed as
+                input, but they are truncated after the embedding and before they are fed into the transformer.
+                Thus a non-basic embedding can accept longer sequences and possibly compress them to stay
+                within the limit.
+            token_embedding: Instance of an embedding layer, which embeds given sequences of tokens into an
+                embedding space, which is the direct input for the transformer layers.
+            generative_head: Instance of a head layer, which transforms the output of the transformer into logits.
+        """
+        super(EncoderMultiDecoder, self).__init__()
+
+        self.embed_dim = embed_dim  # E
+        self.num_positions = num_positions
+        assert len(token_embedding) == len(generative_head), "Number of embeddings and heads is not equal."
+        num_decoders = len(generative_head) - 1
+
+        # token embedding
+        self.embedding = token_embedding
+
+        # start of sequence token
+        self.sos = torch.nn.Parameter(torch.zeros(embed_dim))
+        nn.init.normal_(self.sos)
+
+        # transformer encoder layer
+        encoder_layer = nn.TransformerEncoderLayer(
+            d_model=embed_dim,
+            nhead=num_heads,
+            dim_feedforward=4 * embed_dim,
+            dropout=0.0,
+            activation='gelu',
+        )
+
+        # transformer decoder layer
+        decoder_layer = nn.TransformerDecoderLayer(
+            d_model=embed_dim,
+            nhead=num_heads,
+            dim_feedforward=4 * embed_dim,
+            dropout=0.0,
+            activation='gelu',
+        )
+
+        # encoder multi decoder transformer
+        self.emd_transformer = nn.ModuleList()
+        self.emd_transformer.append(
+            nn.TransformerEncoder(
+                encoder_layer=encoder_layer,
+                num_layers=num_layers,
+                norm=nn.LayerNorm(embed_dim),
+            )
+        )
+        self.emd_transformer.extend(
+            [
+                nn.TransformerDecoder(
+                    decoder_layer=decoder_layer,
+                    num_layers=num_layers,
+                    norm=nn.LayerNorm(embed_dim),
+                ) for _ in range(num_decoders)
+            ]
+        )
+
+        # generative head
+        self.head = generative_head
+
+    def _prepend_sos_token(self, x):
+        """ Shifts the given sequence one token to the right and pads it with the start of sequence (sos) token. """
+        batch_size = x.shape[0]
+        sos = torch.ones(batch_size, 1, self.embed_dim, device=x.device) * self.sos  # [N, 1, E]
+        return torch.cat([sos, x[:, :-1]], axis=1)  # [N, L, E]
+
+    def _transpose(self, x):
+        """ Transposes the first and second dimension of the input tensor. """
+        return torch.transpose(x, 0, 1)
+
+    def process(self, seq, memory, padding_mask, layer_idx, is_final):
+        """ Performs the computations of a single encoder or decoder layer of the transformer.
+
+        The first layer (`layer_idx == 0`) is processed by the encoder, every following layer by its own
+        decoder. In the final layer an upper triangular attention mask allows only autoregressive token
+        access; all other layers may attend to every token.
+
+        Args:
+            seq: Token layer sequence in embedding space - [N, L, E].
+            memory: Output of the previous transformer layer - [N, M, E].
+            padding_mask: Value token layer sequence padding mask - [N, L].
+            layer_idx: Defines which transformer layer should be used.
+            is_final: Defines whether the current layer is the final one, i.e. whether the transformer should
+                be 'autoregressive'.
+
+        Return:
+            The output of the selected transformer layer in latent space - [N, L, E].
+        """
+        # limit sequence length to max `num_positions`
+        seq = seq[:, :self.num_positions]  # [N, L, E]
+        padding_mask = padding_mask[:, :self.num_positions]  # [N, L]
+
+        # create attention mask
+        seq_len = seq.shape[1]
+        if is_final:  # attention mask is autoregressive in the final layer
+            attn_mask = look_ahead_mask(seq_len, device=seq.device)  # [L, L]
+            # shift sequence one token to the right to predict tokens autoregressively
+            seq = self._prepend_sos_token(seq)  # [N, L, E]
+        else:  # otherwise we allow access to all tokens
+            attn_mask = full_mask(seq_len, device=seq.device)  # [L, L]
+
+        # process one transformer layer
+        if layer_idx == 0:  # encoder part of the transformer
+            out = self.emd_transformer[0](
+                src=self._transpose(seq),  # [L, N, E], pytorch expects the sequence dimension to be first
+                mask=attn_mask,  # [L, L]
+                src_key_padding_mask=padding_mask,  # [N, L]
+            )  # [L, N, E]
+        else:  # decoder part of the transformer
+            out = self.emd_transformer[layer_idx](
+                tgt=self._transpose(seq),  # [L, N, E], pytorch expects the sequence dimension to be first
+                memory=self._transpose(memory),  # [M, N, E]
+                tgt_mask=attn_mask,  # [L, L]
+                tgt_key_padding_mask=padding_mask  # [N, L]
+            )  # [L, N, E]
+
+        return self._transpose(out)  # [N, L, E]
+
+    def forward(self, sequence):
+        """ Performs a full transformer pass of the input sequence through embedding, transformer and generative head.
+
+        Args:
+            sequence: List containing the input sequences, where each element is a tuple of (value, depth, position)
+                sequence layers for the transformer with the shapes ([N, L], [N, L], [N, L, A]), respectively.
+
+        Return:
+            Logits which describe the autoregressive likelihood of the next target token, with shape [N, L, V].
+        """
+        seq_len = len(sequence)
+        memory = None
+
+        # process sequence layers individually
+        for idx, seq_layer in enumerate(sequence):
+            is_final = idx == seq_len - 1
+
+            # embed sequence tokens
+            emb = self.embedding[idx].source(*seq_layer)  # [N, L, E]
+            seq_mask = self.embedding[idx].src_padding_mask(*seq_layer[:2])  # [N, L]
+
+            # compute memory / process sequence
+            memory = self.process(emb, memory, seq_mask, idx, is_final)  # [N, L, E]
+
+            # return logits
+            if is_final:  # compute only for the final layer
+                return self.head[idx](memory, *seq_layer)  # [N, L, V]
diff --git a/modules/transformer/encoder_only.py b/modules/architecture/encoder_only.py
similarity index 100%
rename from modules/transformer/encoder_only.py
rename to modules/architecture/encoder_only.py
diff --git a/modules/transformer/__init__.py b/modules/transformer/__init__.py
deleted file mode 100644
index 24fb325..0000000
--- a/modules/transformer/__init__.py
+++ /dev/null
@@ -1,11 +0,0 @@
-from modules.transformer.autoencoder import Autoencoder
-from modules.transformer.encoder_only import EncoderOnly
-from modules.transformer.encoder_decoder import EncoderDecoder
-from modules.transformer.encoder_multi_decoder import EncoderMultiDecoder
-
-__all__ = [
-    "Autoencoder",
-    "EncoderOnly",
-    "EncoderDecoder",
-    "EncoderMultiDecoder",
-]
diff --git a/modules/transformer/encoder_multi_decoder.py b/modules/transformer/encoder_multi_decoder.py
deleted file mode 100644
index 1bb0282..0000000
--- a/modules/transformer/encoder_multi_decoder.py
+++ /dev/null
@@ -1,191 +0,0 @@
-import torch
-import torch.nn as nn
-
-from masks import look_ahead_mask, full_mask
-
-
-class EncoderMultiDecoder(nn.Module):
-    def __init__(self, embed_dim, num_heads, num_layers, num_positions, token_embedding, generative_head, **_):
-        """ Creates an instance of an encoder multi decoder transformer.
-
-        It accepts different implementations of `token_embedding`s and `generative_head`s. The following abbrevations
-        are used to reference the size and the content of a dimension in used tensors.
-
-        Shapes:
-            N: batch size
-            S: source sequence length
-            T: target sequence length
-            E: embedding dimension
-            A: spatial dimension
-            V: vocabulary size
-
-        Args:
-            embed_dim: Number of embedding dimensions used by the attention.
-            num_heads: Number of heads used by the attention.
-            num_layers: Number of layers for each the 'decoder' and 'encoder' part of the transformer.
-            num_positions: Maximal length of processed input tokens. You can pass longer sequences as input, but they
-                will be truncated before feeding into the transformer, but after the embedding. Thus longer sequences
-                can be accepted by a non-basic embedding and possibly compressed to stay within the limit.
-            token_embedding: Instance of an embedding layer, which embedds given sequences of tokens into an embedding
-                space, which is the direct input for the transformer layers.
-            generative_head: Instance of a head layer, which transforms the output of the transformer into logits.
-        """
-        super(EncoderMultiDecoder, self).__init__()
-
-        self.embed_dim = embed_dim  # E
-        self.num_positions = num_positions
-        num_decoders = len(generative_head)
-
-        # token embedding
-        self.embedding = token_embedding
-
-        # start of sequence token
-        self.sos = torch.nn.Parameter(torch.zeros(embed_dim))
-        nn.init.normal_(self.sos)
-
-        # transformer encoder
-        encoder_layer = nn.TransformerEncoderLayer(
-            d_model=embed_dim,
-            nhead=num_heads,
-            dim_feedforward=4 * embed_dim,
-            dropout=0.0,
-            activation='gelu',
-        )
-        self.transformer_encoder = nn.TransformerEncoder(
-            encoder_layer=encoder_layer,
-            num_layers=num_layers,
-            norm=nn.LayerNorm(embed_dim),
-        )
-
-        # transformer decoder
-        decoder_layer = nn.TransformerDecoderLayer(
-            d_model=embed_dim,
-            nhead=num_heads,
-            dim_feedforward=4 * embed_dim,
-            dropout=0.0,
-            activation='gelu',
-        )
-        self.transformer_decoder = nn.ModuleList(
-            [
-                nn.TransformerDecoder(
-                    decoder_layer=decoder_layer,
-                    num_layers=num_layers,
-                    norm=nn.LayerNorm(embed_dim),
-                ) for _ in range(num_decoders)
-            ]
-        )
-
-        # generative head
-        self.head = generative_head
-
-    def _prepend_sos_token(self, x):
-        """ Shifts given sequence one token to right and fills missing token with start of sequence token. """
-        batch_size = x.shape[0]
-        sos = torch.ones(batch_size, 1, self.embed_dim, device=x.device) * self.sos  # [N, 1, E]
-        return torch.cat([sos, x[:, :-1]], axis=1)  # [N, S, E]
-
-    def _transpose(self, x):
-        """ Transposes the first and second dimension of the input tensor. """
-        return torch.transpose(x, 0, 1)
-
-    def encode(self, value, depth, pos):
-        """ Performs computations in the encoder part of the transformer.
-
-        It embedds the given token sequences into the embedding space, given the `token_embedding`. Next, it creates
-        a full mask therefore every token can access every other token.
-
-        Args:
-            value: Value token sequence - [N, S].
-            depth: Depth token sequence - [N, S].
-            pos: Position token sequences with a single token for each spatial dimension of the data - [N, S, A].
-
-        Return:
-            The output of the last layer of the encoder in latent encoder space - [N, S, E].
-        """
-        # compute the embedding vector sequence for encoder input
-        src = self.embedding[0].source(value, depth, pos)  # [N, S, E]
-
-        # create padding mask
-        padding_mask = self.embedding[0].src_padding_mask(value, depth)  # [N, S]
-
-        # limit sequence length to max `num_position`
-        src = src[:, :self.num_positions]  # [N, S, E]
-        padding_mask = padding_mask[:, :self.num_positions]  # [N, S]
-
-        # encoder part of the transformer - pytorch expects, the sequence dimension to be first.
-        out = self.transformer_encoder(
-            src=self._transpose(src),  # [S, N, E]
-            mask=None,
-            src_key_padding_mask=padding_mask,  # [N, S]
-        )  # [S, N, E]
-        return self._transpose(out)  # [N, S, E]
-
-    def decode(self, value, depth, pos, memory, decoder_idx, final):
-        """ Performs computations in the decoder part of the transformer.
-
-        It embeds the target token sequence into the embedding space of the decoder and creates an upper triangular
-        mask to allow only for autoregressive token access.
-
-        Args:
-            value: Target value token sequence - [N, T].
-            depth: Target depth token sequence - [N, T].
-            pos: Target position token sequences with a single token for each spatial dimension of the data - [N, T, A].
-            memory: The output of the last encoder/decoder layer - [N, S/T, E].
-            decoder_idx: Defines which decoder instance should be used.
-            final: Defines if the current input is final, e.g. if the decoder should be 'autoregressive'.
-
-        Return:
-            The output of the last layer of the decoder in latent decoder space - [N, T, E].
-        """
-        # compute the embedding vector sequence for decoder input
-        tgt = self.embedding[decoder_idx + 1].target(value, depth, pos)  # [N, T, E]
-        tgt = self._prepend_sos_token(tgt)  # [N, T, E]
-
-        # create autoregressive attention and padding masks
-        tgt_len = tgt.shape[1]
-        if final:
-            attn_mask = look_ahead_mask(tgt_len, device=tgt.device)  # [T, T]
-        else:
-            attn_mask = full_mask(tgt_len, device=tgt.device)  # [T, T]
-        padding_mask = self.embedding[decoder_idx + 1].tgt_padding_mask(value, depth)  # [N, T]
-
-        # limit sequence length to max `num_position`
-        tgt = tgt[:, :self.num_positions]  # [N, T, E]
-        attn_mask = attn_mask[:self.num_positions, :self.num_positions]  # [T, T]
-        padding_mask = padding_mask[:, :self.num_positions]  # [N, T]
-
-        # decoder part of the transformer - pytorch expects, the sequence dimension to be first.
-        out = self.transformer_decoder[decoder_idx](
-            tgt=self._transpose(tgt),  # [T, N, E]
-            memory=self._transpose(memory),  # [S, N, E]
-            tgt_mask=attn_mask,  # [T, T]
-            tgt_key_padding_mask=padding_mask  # [N, T]
-        )  # [T, N, E]
-        return self._transpose(out)  # [N, T, E]
-
-    def forward(self, sequence):
-        """ Performs a full transformer pass of the input sequence through embedding, transformer and generative head.
-
-        Args:
-            sequence: Tuple containing input sequences as a tuple of (encoder_sequence, decoder_sequence), where each
-                of the elements is another tuple of (value, depth, position) sequence inputs for the encoder and decoder
-                with the shape ([N, S/T], [N, S/T], [N, S/T, A]), respectively.
-
-        Return:
-            Logits which describe the autoregressive likelihood of the next target token, with shape [N, T, V].
-        """
-        seq_enc, seq_dec = sequence
-
-        # process encoder
-        memory = self.encode(*seq_enc)  # [N, S, E]
-
-        # process every other layer separatly
-        for idx, seq_layer in enumerate(seq_dec):
-            is_final = idx == len(seq_dec) - 1
-
-            # process decoder
-            z = self.decode(*seq_layer, memory, idx, is_final)  # [N, T*, E]
-
-            # return logits in final layer
-            if is_final:
-                return self.head[idx](z, *seq_dec[-1])  # [N, T, V]
--
GitLab
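
For reference, a minimal usage sketch of the refactored EncoderMultiDecoder. The concrete token_embedding and generative_head implementations are not part of this patch, so StubEmbedding and StubHead below are hypothetical stand-ins that only satisfy the interface used by forward/process (an indexable container whose elements provide source(), src_padding_mask(), and a callable head); the real classes in the repository will differ.

import torch
import torch.nn as nn

from modules.architecture import EncoderMultiDecoder  # module path after this patch


class StubEmbedding(nn.Module):
    """ Hypothetical token embedding: embeds value tokens, ignores depth and position. """
    def __init__(self, num_vocab, embed_dim):
        super().__init__()
        self.emb = nn.Embedding(num_vocab, embed_dim)

    def source(self, value, depth, pos):
        return self.emb(value)  # [N, L, E]

    def src_padding_mask(self, value, depth):
        return value == 0  # [N, L], assuming token id 0 marks padding


class StubHead(nn.Module):
    """ Hypothetical generative head: projects latent vectors to vocabulary logits. """
    def __init__(self, num_vocab, embed_dim):
        super().__init__()
        self.linear = nn.Linear(embed_dim, num_vocab)

    def forward(self, latent, value, depth, pos):
        return self.linear(latent)  # [N, L, V]


num_vocab, embed_dim, spatial_dim = 16, 32, 3
num_seq_layers = 3  # first layer -> encoder, every remaining layer -> its own decoder

model = EncoderMultiDecoder(
    embed_dim=embed_dim,
    num_heads=4,
    num_layers=2,
    num_positions=64,
    token_embedding=nn.ModuleList(StubEmbedding(num_vocab, embed_dim) for _ in range(num_seq_layers)),
    generative_head=nn.ModuleList(StubHead(num_vocab, embed_dim) for _ in range(num_seq_layers)),
)

# one (value, depth, position) tuple per sequence layer; layer lengths may differ
sequence = [
    (
        torch.randint(1, num_vocab, (2, length)),        # value    - [N, L]
        torch.randint(1, 8, (2, length)),                # depth    - [N, L]
        torch.randint(0, 32, (2, length, spatial_dim)),  # position - [N, L, A]
    )
    for length in (8, 16, 32)
]

logits = model(sequence)  # logits of the final sequence layer
print(logits.shape)       # torch.Size([2, 32, 16])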
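
The masks helpers imported at the top of encoder_multi_decoder.py (look_ahead_mask, full_mask) are not shown in this patch. Assuming they follow the additive float-mask convention of nn.TransformerEncoder/nn.TransformerDecoder (0.0 allows attention, -inf blocks it), a minimal sketch could look like the following; the actual implementation in the repository may differ.

import torch


def look_ahead_mask(seq_len, device=None):
    """ Upper triangular mask: position i may only attend to positions <= i. """
    mask = torch.full((seq_len, seq_len), float('-inf'), device=device)
    return torch.triu(mask, diagonal=1)  # -inf strictly above the diagonal, 0.0 elsewhere


def full_mask(seq_len, device=None):
    """ All-zero mask: every position may attend to every other position. """
    return torch.zeros(seq_len, seq_len, device=device)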