Skip to content
Snippets Groups Projects
Commit 1f7c86a6 authored by Gregor Kobsik's avatar Gregor Kobsik
Browse files

incorporate `pytorch` and `fast` architecture into sampling pipeline

parent 575f0e5b
No related branches found
No related tags found
No related merge requests found
......@@ -114,3 +114,7 @@ class FastTransformer(nn.Module):
# return logits
return self.head[0](output_seq, *seq) # [N, L, V]
def compute_logits(self, seq_layer, memory, idx, cls):
""" Alias for 'forward' to make this module compatible to old sampling pipeline. """
return self.forward([seq_layer], cls)
......@@ -117,3 +117,7 @@ class PytorchTransformer(nn.Module):
# return logits
return self.head[0](output_seq, *seq) # [N, L, V]
def compute_logits(self, seq_layer, memory, idx, cls):
""" Alias for 'forward' to make this module compatible to old sampling pipeline. """
return self.forward([seq_layer], cls)
......@@ -36,7 +36,7 @@ def create_sampler(
if architecture == "autoencoder":
return AutoencoderSampler(**kwargs)
elif architecture == "encoder_only":
elif architecture in ("encoder_only", "fast", "pytorch"):
return EncoderOnlySampler(**kwargs)
elif architecture == "encoder_decoder":
return EncoderDecoderSampler(**kwargs)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment