Skip to content
Snippets Groups Projects
Commit 68aed754 authored by Gregor Kobsik's avatar Gregor Kobsik
Browse files

include `pytorch-fast-transformers` as an optional module

parent 467ca569
No related branches found
No related tags found
No related merge requests found
from .autoencoder import Autoencoder
from .transformer import Transformer
from .pytorch_transformer import PytorchTransformer
from .fast_transformer import FastTransformer
from .fast_recurrent_transformer import FastRecurrentTransformer
def create_architecture(
......@@ -61,8 +59,12 @@ def create_architecture(
elif architecture == "pytorch":
return PytorchTransformer(**kwargs)
elif architecture == "fast":
# include `pytorch-fast-transformers` as an optional module
from .fast_transformer import FastTransformer
return FastTransformer(**kwargs)
elif architecture == "fast-recurrent":
# include `pytorch-fast-transformers` as an optional module
from .fast_recurrent_transformer import FastRecurrentTransformer
return FastRecurrentTransformer(**kwargs)
else:
raise ValueError(f"ERROR: {attention}_{architecture} transformer architecture not implemented.")
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment