From a155f85cf9b07c730e50a943e7d1c4e821d9324b Mon Sep 17 00:00:00 2001
From: Gregor Kobsik <gregor.kobsik@rwth-aachen.de>
Date: Thu, 8 Jul 2021 14:43:47 +0200
Subject: [PATCH] refactor EMD, rename 'transformer'->'architecture'

Merged the encoder/decoder passes of the EMD transformer into a single process function.
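
A rough, self-contained usage sketch of the refactored module. The dummy
embedding/head classes and all hyperparameter values below are illustrative
only and not part of the code base; they merely satisfy the interface that
EncoderMultiDecoder.forward expects (source(), src_padding_mask() and a
callable head per sequence layer):

    import torch
    import torch.nn as nn

    from modules.architecture import EncoderMultiDecoder

    class DummyEmbedding(nn.Module):
        def __init__(self, num_vocab, embed_dim):
            super().__init__()
            self.emb = nn.Embedding(num_vocab, embed_dim)

        def source(self, value, depth, pos):
            return self.emb(value)  # [N, L, E]

        def src_padding_mask(self, value, depth):
            return value == 0  # [N, L], assumes 0 is the padding token

    class DummyHead(nn.Module):
        def __init__(self, num_vocab, embed_dim):
            super().__init__()
            self.linear = nn.Linear(embed_dim, num_vocab)

        def forward(self, x, value, depth, pos):
            return self.linear(x)  # [N, L, V]

    num_seq_layers, num_vocab, embed_dim = 3, 16, 32
    model = EncoderMultiDecoder(
        embed_dim=embed_dim,
        num_heads=4,
        num_layers=2,
        num_positions=64,
        token_embedding=nn.ModuleList(
            [DummyEmbedding(num_vocab, embed_dim) for _ in range(num_seq_layers)]),
        generative_head=nn.ModuleList(
            [DummyHead(num_vocab, embed_dim) for _ in range(num_seq_layers)]),
    )

    # one (value, depth, position) tuple per sequence layer
    sequence = [
        (torch.randint(1, num_vocab, (2, 8)),  # value - [N, L]
         torch.randint(1, 6, (2, 8)),          # depth - [N, L]
         torch.randint(0, 8, (2, 8, 3)))       # pos   - [N, L, A]
        for _ in range(num_seq_layers)
    ]
    logits = model(sequence)  # [N, L, V], logits of the final layer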
---
 factories/transformer_factory.py              |   2 +-
 modules/architecture/__init__.py              |  11 +
 .../autoencoder.py                            |   0
 .../encoder_decoder.py                        |   0
 modules/architecture/encoder_multi_decoder.py | 169 ++++++++++++++++
 .../encoder_only.py                           |   0
 modules/transformer/__init__.py               |  11 -
 modules/transformer/encoder_multi_decoder.py  | 191 ------------------
 8 files changed, 181 insertions(+), 203 deletions(-)
 create mode 100644 modules/architecture/__init__.py
 rename modules/{transformer => architecture}/autoencoder.py (100%)
 rename modules/{transformer => architecture}/encoder_decoder.py (100%)
 create mode 100644 modules/architecture/encoder_multi_decoder.py
 rename modules/{transformer => architecture}/encoder_only.py (100%)
 delete mode 100644 modules/transformer/__init__.py
 delete mode 100644 modules/transformer/encoder_multi_decoder.py

diff --git a/factories/transformer_factory.py b/factories/transformer_factory.py
index 7e983de..fd8b110 100644
--- a/factories/transformer_factory.py
+++ b/factories/transformer_factory.py
@@ -1,4 +1,4 @@
-from modules.transformer import (
+from modules.architecture import (
     Autoencoder,
     EncoderOnly,
     EncoderDecoder,
diff --git a/modules/architecture/__init__.py b/modules/architecture/__init__.py
new file mode 100644
index 0000000..8aa6d43
--- /dev/null
+++ b/modules/architecture/__init__.py
@@ -0,0 +1,11 @@
+from modules.architecture.autoencoder import Autoencoder
+from modules.architecture.encoder_only import EncoderOnly
+from modules.architecture.encoder_decoder import EncoderDecoder
+from modules.architecture.encoder_multi_decoder import EncoderMultiDecoder
+
+__all__ = [
+    "Autoencoder",
+    "EncoderOnly",
+    "EncoderDecoder",
+    "EncoderMultiDecoder",
+]
diff --git a/modules/transformer/autoencoder.py b/modules/architecture/autoencoder.py
similarity index 100%
rename from modules/transformer/autoencoder.py
rename to modules/architecture/autoencoder.py
diff --git a/modules/transformer/encoder_decoder.py b/modules/architecture/encoder_decoder.py
similarity index 100%
rename from modules/transformer/encoder_decoder.py
rename to modules/architecture/encoder_decoder.py
diff --git a/modules/architecture/encoder_multi_decoder.py b/modules/architecture/encoder_multi_decoder.py
new file mode 100644
index 0000000..bc178bb
--- /dev/null
+++ b/modules/architecture/encoder_multi_decoder.py
@@ -0,0 +1,169 @@
+import torch
+import torch.nn as nn
+
+from masks import look_ahead_mask, full_mask
+
+
+class EncoderMultiDecoder(nn.Module):
+    def __init__(self, embed_dim, num_heads, num_layers, num_positions, token_embedding, generative_head, **_):
+        """ Creates an instance of an encoder multi decoder transformer.
+
+        It accepts different implementations of `token_embedding`s and `generative_head`s. The following abbreviations
+        are used to reference the size and content of the tensor dimensions used below.
+
+        Shapes:
+            N: batch size
+            L: layer sequence length
+            M: memory length
+            E: embedding dimension
+            A: spatial dimension
+            V: vocabulary size
+
+        Args:
+            embed_dim: Number of embedding dimensions used by the attention.
+            num_heads: Number of heads used by the attention.
+            num_layers: Number of layers for each of the 'encoder' and 'decoder' parts of the transformer.
+            num_positions: Maximum number of input tokens processed by the transformer. Longer sequences can be
+                passed as input, but they are truncated after the embedding and before being fed into the
+                transformer. Thus a non-basic embedding can accept longer sequences and possibly compress them to
+                stay within the limit.
+            token_embedding: List of embedding layers, one per sequence layer, each of which embeds the given token
+                sequences into an embedding space that is the direct input to the transformer layers.
+            generative_head: List of head layers, one per sequence layer, each of which transforms the output of the
+                transformer into logits.
+        """
+        super(EncoderMultiDecoder, self).__init__()
+
+        self.embed_dim = embed_dim  # E
+        self.num_positions = num_positions
+        assert len(token_embedding) == len(generative_head), "Number of embeddings and heads is not equal."
+        num_decoders = len(generative_head) - 1
+
+        # token embedding
+        self.embedding = token_embedding
+
+        # start of sequence token
+        self.sos = torch.nn.Parameter(torch.zeros(embed_dim))
+        nn.init.normal_(self.sos)
+
+        # transformer encoder layer
+        encoder_layer = nn.TransformerEncoderLayer(
+            d_model=embed_dim,
+            nhead=num_heads,
+            dim_feedforward=4 * embed_dim,
+            dropout=0.0,
+            activation='gelu',
+        )
+
+        # transformer decoder layer
+        decoder_layer = nn.TransformerDecoderLayer(
+            d_model=embed_dim,
+            nhead=num_heads,
+            dim_feedforward=4 * embed_dim,
+            dropout=0.0,
+            activation='gelu',
+        )
+
+        # encoder multi decoder transformer
+        self.emd_transformer = nn.ModuleList()
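+        # index 0 holds the encoder, indices 1..num_decoders hold one decoder per subsequent sequence layer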
+        self.emd_transformer.append(
+            nn.TransformerEncoder(
+                encoder_layer=encoder_layer,
+                num_layers=num_layers,
+                norm=nn.LayerNorm(embed_dim),
+            )
+        )
+        self.emd_transformer.extend(
+            [
+                nn.TransformerDecoder(
+                    decoder_layer=decoder_layer,
+                    num_layers=num_layers,
+                    norm=nn.LayerNorm(embed_dim),
+                ) for _ in range(num_decoders)
+            ]
+        )
+
+        # generative head
+        self.head = generative_head
+
+    def _prepend_sos_token(self, x):
+        """ Shifts given sequence one token to right and pads with start of sequence (sos) token. """
+        batch_size = x.shape[0]
+        sos = torch.ones(batch_size, 1, self.embed_dim, device=x.device) * self.sos  # [N, 1, E]
+        return torch.cat([sos, x[:, :-1]], axis=1)  # [N, S, E]
+
+    def _transpose(self, x):
+        """ Transposes the first and second dimension of the input tensor. """
+        return torch.transpose(x, 0, 1)
+
+    def process(self, seq, memory, padding_mask, layer_idx, is_final):
+        """ Performs computations in the decoder part of the transformer.
+
+        It embeds the target token sequence into the embedding space of the decoder and creates an upper triangular
+        mask to allow only for autoregressive token access.
+
+        Args:
+            seq: Token layer sequence in embedding space - [N, L, E]
+            memory: Output of the last transformer layer - [N, M, E].
+            padding_mask: Value token layer sequence padding mask - [N, L].
+            layer_idx: Defines which transformer layer should be used.
+            is_final: Defines if the current layer is final, e.g. if the transformer should be 'autoregressive'.
+
+        Return:
+            The output of the last layer of the decoder in latent decoder space - [N, L, E].
+        """
+        # limit sequence length to max `num_positions`
+        seq = seq[:, :self.num_positions]  # [N, L, E]
+        padding_mask = padding_mask[:, :self.num_positions]  # [N, L]
+
+        # create attention mask
+        seq_len = seq.shape[1]
+        if is_final:  # attention mask is autoregressive in the final layer
+            attn_mask = look_ahead_mask(seq_len, device=seq.device)  # [L, L]
+            # shift the sequence one token to the right to predict tokens autoregressively
+            seq = self._prepend_sos_token(seq)  # [N, L, E]
+        else:  # otherwise we allow access to all tokens
+            attn_mask = full_mask(seq_len, device=seq.device)  # [L, L]
+
+        # process one transformer layer
+        if layer_idx == 0:  # encoder part of the transformer
+            out = self.emd_transformer[0](
+                src=self._transpose(seq),  # [L, N, E], pytorch expects the sequence dimension to be first
+                mask=attn_mask,  # [L, L]
+                src_key_padding_mask=padding_mask,  # [N, L]
+            )  # [L, N, E]
+        else:  # decoder part of the transformer
+            out = self.emd_transformer[layer_idx](
+                tgt=self._transpose(seq),  # [L, N, E], pytorch expects the sequence dimension to be first
+                memory=self._transpose(memory),  # [M, N, E]
+                tgt_mask=attn_mask,  # [L, L]
+                tgt_key_padding_mask=padding_mask  # [N, L]
+            )  # [L, N, E]
+
+        return self._transpose(out)  # [N, L, E]
+
+    def forward(self, sequence):
+        """ Performs a full transformer pass of the input sequence through embedding, transformer and generative head.
+
+        Args:
+            sequence: List of input sequence layers, where each element is a (value, depth, position) tuple with the
+                shapes ([N, L], [N, L], [N, L, A]), respectively.
+
+        Return:
+            Logits which describe the autoregressive likelihood of the next target token, with shape [N, L, V].
+        """
+        seq_len = len(sequence)
+        memory = None
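+        # layer 0 is processed by the encoder and does not use the memory; every
+        # following layer is processed by a decoder attending to the previous layer's output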
+
+        # process sequence layers individually
+        for idx, seq_layer in enumerate(sequence):
+            is_final = idx == seq_len - 1
+
+            # embed sequence tokens
+            emb = self.embedding[idx].source(*seq_layer)  # [N, L, E]
+            seq_mask = self.embedding[idx].src_padding_mask(*seq_layer[:2])  # [N, L]
+
+            # compute memory / process sequence
+            memory = self.process(emb, memory, seq_mask, idx, is_final)  # [N, L, E]
+
+            # return logits
+            if is_final:  # compute only for final layer
+                return self.head[idx](memory, *seq_layer)  # [N, L, V]
diff --git a/modules/transformer/encoder_only.py b/modules/architecture/encoder_only.py
similarity index 100%
rename from modules/transformer/encoder_only.py
rename to modules/architecture/encoder_only.py
diff --git a/modules/transformer/__init__.py b/modules/transformer/__init__.py
deleted file mode 100644
index 24fb325..0000000
--- a/modules/transformer/__init__.py
+++ /dev/null
@@ -1,11 +0,0 @@
-from modules.transformer.autoencoder import Autoencoder
-from modules.transformer.encoder_only import EncoderOnly
-from modules.transformer.encoder_decoder import EncoderDecoder
-from modules.transformer.encoder_multi_decoder import EncoderMultiDecoder
-
-__all__ = [
-    "Autoencoder",
-    "EncoderOnly",
-    "EncoderDecoder",
-    "EncoderMultiDecoder",
-]
diff --git a/modules/transformer/encoder_multi_decoder.py b/modules/transformer/encoder_multi_decoder.py
deleted file mode 100644
index 1bb0282..0000000
--- a/modules/transformer/encoder_multi_decoder.py
+++ /dev/null
@@ -1,191 +0,0 @@
-import torch
-import torch.nn as nn
-
-from masks import look_ahead_mask, full_mask
-
-
-class EncoderMultiDecoder(nn.Module):
-    def __init__(self, embed_dim, num_heads, num_layers, num_positions, token_embedding, generative_head, **_):
-        """ Creates an instance of an encoder multi decoder transformer.
-
-        It accepts different implementations of `token_embedding`s and `generative_head`s. The following abbrevations
-        are used to reference the size and the content of a dimension in used tensors.
-
-        Shapes:
-            N: batch size
-            S: source sequence length
-            T: target sequence length
-            E: embedding dimension
-            A: spatial dimension
-            V: vocabulary size
-
-        Args:
-            embed_dim: Number of embedding dimensions used by the attention.
-            num_heads: Number of heads used by the attention.
-            num_layers: Number of layers for each the 'decoder' and 'encoder' part of the transformer.
-            num_positions: Maximal length of processed input tokens. You can pass longer sequences as input, but they
-                will be truncated before feeding into the transformer, but after the embedding. Thus longer sequences
-                can be accepted by a non-basic embedding and possibly compressed to stay within the limit.
-            token_embedding: Instance of an embedding layer, which embedds given sequences of tokens into an embedding
-                space, which is the direct input for the transformer layers.
-            generative_head: Instance of a head layer, which transforms the output of the transformer into logits.
-        """
-        super(EncoderMultiDecoder, self).__init__()
-
-        self.embed_dim = embed_dim  # E
-        self.num_positions = num_positions
-        num_decoders = len(generative_head)
-
-        # token embedding
-        self.embedding = token_embedding
-
-        # start of sequence token
-        self.sos = torch.nn.Parameter(torch.zeros(embed_dim))
-        nn.init.normal_(self.sos)
-
-        # transformer encoder
-        encoder_layer = nn.TransformerEncoderLayer(
-            d_model=embed_dim,
-            nhead=num_heads,
-            dim_feedforward=4 * embed_dim,
-            dropout=0.0,
-            activation='gelu',
-        )
-        self.transformer_encoder = nn.TransformerEncoder(
-            encoder_layer=encoder_layer,
-            num_layers=num_layers,
-            norm=nn.LayerNorm(embed_dim),
-        )
-
-        # transformer decoder
-        decoder_layer = nn.TransformerDecoderLayer(
-            d_model=embed_dim,
-            nhead=num_heads,
-            dim_feedforward=4 * embed_dim,
-            dropout=0.0,
-            activation='gelu',
-        )
-        self.transformer_decoder = nn.ModuleList(
-            [
-                nn.TransformerDecoder(
-                    decoder_layer=decoder_layer,
-                    num_layers=num_layers,
-                    norm=nn.LayerNorm(embed_dim),
-                ) for _ in range(num_decoders)
-            ]
-        )
-
-        # generative head
-        self.head = generative_head
-
-    def _prepend_sos_token(self, x):
-        """ Shifts given sequence one token to right and fills missing token with start of sequence token. """
-        batch_size = x.shape[0]
-        sos = torch.ones(batch_size, 1, self.embed_dim, device=x.device) * self.sos  # [N, 1, E]
-        return torch.cat([sos, x[:, :-1]], axis=1)  # [N, S, E]
-
-    def _transpose(self, x):
-        """ Transposes the first and second dimension of the input tensor. """
-        return torch.transpose(x, 0, 1)
-
-    def encode(self, value, depth, pos):
-        """ Performs computations in the encoder part of the transformer.
-
-        It embedds the given token sequences into the embedding space, given the `token_embedding`. Next, it creates
-        a full mask therefore every token can access every other token.
-
-        Args:
-            value: Value token sequence - [N, S].
-            depth: Depth token sequence - [N, S].
-            pos: Position token sequences with a single token for each spatial dimension of the data - [N, S, A].
-
-        Return:
-            The output of the last layer of the encoder in latent encoder space - [N, S, E].
-        """
-        # compute the embedding vector sequence for encoder input
-        src = self.embedding[0].source(value, depth, pos)  # [N, S, E]
-
-        # create padding mask
-        padding_mask = self.embedding[0].src_padding_mask(value, depth)  # [N, S]
-
-        # limit sequence length to max `num_position`
-        src = src[:, :self.num_positions]  # [N, S, E]
-        padding_mask = padding_mask[:, :self.num_positions]  # [N, S]
-
-        # encoder part of the transformer - pytorch expects, the sequence dimension to be first.
-        out = self.transformer_encoder(
-            src=self._transpose(src),  # [S, N, E]
-            mask=None,
-            src_key_padding_mask=padding_mask,  # [N, S]
-        )  # [S, N, E]
-        return self._transpose(out)  # [N, S, E]
-
-    def decode(self, value, depth, pos, memory, decoder_idx, final):
-        """ Performs computations in the decoder part of the transformer.
-
-        It embeds the target token sequence into the embedding space of the decoder and creates an upper triangular
-        mask to allow only for autoregressive token access.
-
-        Args:
-            value: Target value token sequence - [N, T].
-            depth: Target depth token sequence - [N, T].
-            pos: Target position token sequences with a single token for each spatial dimension of the data - [N, T, A].
-            memory: The output of the last encoder/decoder layer - [N, S/T, E].
-            decoder_idx: Defines which decoder instance should be used.
-            final: Defines if the current input is final, e.g. if the decoder should be 'autoregressive'.
-
-        Return:
-            The output of the last layer of the decoder in latent decoder space - [N, T, E].
-        """
-        # compute the embedding vector sequence for decoder input
-        tgt = self.embedding[decoder_idx + 1].target(value, depth, pos)  # [N, T, E]
-        tgt = self._prepend_sos_token(tgt)  # [N, T, E]
-
-        # create autoregressive attention and padding masks
-        tgt_len = tgt.shape[1]
-        if final:
-            attn_mask = look_ahead_mask(tgt_len, device=tgt.device)  # [T, T]
-        else:
-            attn_mask = full_mask(tgt_len, device=tgt.device)  # [T, T]
-        padding_mask = self.embedding[decoder_idx + 1].tgt_padding_mask(value, depth)  # [N, T]
-
-        # limit sequence length to max `num_position`
-        tgt = tgt[:, :self.num_positions]  # [N, T, E]
-        attn_mask = attn_mask[:self.num_positions, :self.num_positions]  # [T, T]
-        padding_mask = padding_mask[:, :self.num_positions]  # [N, T]
-
-        # decoder part of the transformer - pytorch expects, the sequence dimension to be first.
-        out = self.transformer_decoder[decoder_idx](
-            tgt=self._transpose(tgt),  # [T, N, E]
-            memory=self._transpose(memory),  # [S, N, E]
-            tgt_mask=attn_mask,  # [T, T]
-            tgt_key_padding_mask=padding_mask  # [N, T]
-        )  # [T, N, E]
-        return self._transpose(out)  # [N, T, E]
-
-    def forward(self, sequence):
-        """ Performs a full transformer pass of the input sequence through embedding, transformer and generative head.
-
-        Args:
-            sequence: Tuple containing input sequences as a tuple of (encoder_sequence, decoder_sequence), where each
-                of the elements is another tuple of (value, depth, position) sequence inputs for the encoder and decoder
-                with the shape ([N, S/T], [N, S/T], [N, S/T, A]), respectively.
-
-        Return:
-            Logits which describe the autoregressive likelihood of the next target token, with shape [N, T, V].
-        """
-        seq_enc, seq_dec = sequence
-
-        # process encoder
-        memory = self.encode(*seq_enc)  # [N, S, E]
-
-        # process every other layer separatly
-        for idx, seq_layer in enumerate(seq_dec):
-            is_final = idx == len(seq_dec) - 1
-
-            # process decoder
-            z = self.decode(*seq_layer, memory, idx, is_final)  # [N, T*, E]
-
-            # return logits in final layer
-            if is_final:
-                return self.head[idx](z, *seq_dec[-1])  # [N, T, V]
-- 
GitLab