Skip to content
Snippets Groups Projects
Commit ad18c097 authored by Gregor Kobsik's avatar Gregor Kobsik
Browse files

refactored transformer - compute memory/logits

 - added external functions to compute memory or logits for sampler
parent f1b5f801
No related branches found
No related tags found
No related merge requests found
......@@ -166,15 +166,24 @@ class Transformer(nn.Module):
# process sequence layers individually
for idx, seq_layer in enumerate(sequence):
is_final = idx == seq_len - 1
if idx < seq_len - 1: # intermediate layer
memory = self.compute_memory(seq_layer, memory, idx, False) # [N, L, E]
else: # only final layer
return self.compute_logits(seq_layer, memory, idx) # [N, T, V]
def compute_memory(self, seq_layer, memory, idx, is_final):
    """Embed one sequence layer and process it into a new memory.

    Args:
        seq_layer: Sequence of inputs for this layer; unpacked positionally
            into the layer's embedding module and its `padding_mask`.
        memory: Memory from the preceding layer, handed to `self.process`
            unchanged.
        idx: Index of the layer; selects `self.embedding[idx]` and is
            forwarded to `self.process`.
        is_final: Forwarded to `self.process` to mark the final layer.

    Returns:
        The result of `self.process` for this layer -- [N, L, E].
    """
    # embed sequence tokens
    embedding = self.embedding[idx]
    emb = embedding(*seq_layer)  # [N, L, E]
    seq_mask = embedding.padding_mask(*seq_layer)  # [N, L]

    # compute memory / process sequence -- exactly once, so the incoming
    # `memory` (not an already processed one) is what gets attended to
    return self.process(emb, memory, seq_mask, idx, is_final)  # [N, L, E]
def compute_logits(self, seq_layer, memory, idx):
    """Compute the output logits for the final sequence layer.

    Args:
        seq_layer: Sequence of inputs for this layer; passed to
            `self.compute_memory` and unpacked positionally into the head.
        memory: Memory from the preceding layer.
        idx: Index of the layer; selects `self.head[idx]`.

    Returns:
        The output of `self.head[idx]` -- [N, T, V].
    """
    # compute memory; this method is the final-layer path, so the layer is
    # always processed with `is_final=True`
    memory = self.compute_memory(seq_layer, memory, idx, True)  # [N, L, E]

    # return logits unconditionally -- there is no `is_final` in this scope
    return self.head[idx](memory, *seq_layer)  # [N, T, V]
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment