Commit c9a3c74e authored by Gregor Kobsik

fix naming 'sliding_window'

parent 985546c4
@@ -17,7 +17,7 @@ def create_data_collate(architecture, embeddings, resolution):
     """
     if architecture == "autoencoder":
         return AutoencoderCollate(embeddings)
-    if architecture in ("encoder_only", 'pytorch', 'fast', 'fast-recurrent', 'sliding-window'):
+    if architecture in ("encoder_only", 'pytorch', 'fast', 'fast-recurrent', 'fast_recurrent', 'sliding_window'):
         return EncoderOnlyCollate()
     if architecture == "encoder_decoder":
         return EncoderDecoderCollate(embeddings)
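For illustration, a minimal usage sketch of the collate factory after the rename; the resolution value is a placeholder, and embeddings may be None here because the diff shows EncoderOnlyCollate is constructed with no arguments on this path:

# Hypothetical usage: the underscored name now selects EncoderOnlyCollate;
# the old hyphenated 'sliding-window' is no longer matched by this factory.
collate_fn = create_data_collate(
    architecture="sliding_window",  # previously "sliding-window"
    embeddings=None,                # unused on this path: EncoderOnlyCollate()
    resolution=32,                  # placeholder value
)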
@@ -62,13 +62,13 @@ def create_architecture(
         return Transformer(**kwargs, num_decoders=len(token_embedding) - 1)
     elif architecture == "pytorch":
         return PytorchTransformer(**kwargs)
-    elif architecture == "sliding-window":
+    elif architecture == "sliding_window":
         return SlidingWindowTransformer(**kwargs)
     elif architecture == "fast":
         # include `pytorch-fast-transformers` as an optional module
         from .fast_transformer import FastTransformer
         return FastTransformer(**kwargs)
-    elif architecture == "fast-recurrent":
+    elif architecture in ("fast-recurrent", "fast_recurrent"):
         # include `pytorch-fast-transformers` as an optional module
         from .fast_recurrent_transformer import FastRecurrentTransformer
         return FastRecurrentTransformer(**kwargs)
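As a design note, the commit keeps both spellings of 'fast-recurrent' by enumerating them per branch. A hypothetical alternative (this helper is not in the repository) would normalize legacy names once before dispatch:

def _normalize_architecture(name):
    # Hypothetical helper: fold legacy hyphenated names into the new
    # underscored spellings so every factory branch checks one string.
    aliases = {
        "sliding-window": "sliding_window",
        "fast-recurrent": "fast_recurrent",
    }
    return aliases.get(name, name)

# e.g.: create_architecture(_normalize_architecture(architecture), ...)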
@@ -13,7 +13,7 @@ from ..token_generator.recurrent import create_recurrent_token_generator
 class RecurrentSampler:
     def __init__(self, model, head, spatial_dim, max_resolution, position_encoding, device, **_):
-        """ Provides a basic implementation of the sampler for the 'fast-recurrent-transformer' architecture.
+        """ Provides a basic implementation of the sampler for the 'fast_recurrent-transformer' architecture.
 
         Args:
             model: Model which is used for sampling.
@@ -37,9 +37,9 @@ def create_sampler(
     if architecture == "autoencoder":
         return AutoencoderSampler(**kwargs)
-    elif architecture in ("encoder_only", "fast", "pytorch", "sliding-window"):
+    elif architecture in ("encoder_only", "fast", "pytorch", "sliding_window"):
         return EncoderOnlySampler(**kwargs)
-    elif architecture in ("fast-recurrent"):
+    elif architecture in ("fast-recurrent", "fast_recurrent"):
         return RecurrentSampler(**kwargs)
     elif architecture == "encoder_decoder":
         return EncoderDecoderSampler(**kwargs)
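Beyond the rename, this hunk fixes a latent bug: ("fast-recurrent") is a plain string, not a one-element tuple, so the old `in` test performed substring matching instead of membership. A short demonstration of that Python behavior:

# Parentheses alone do not make a tuple; a trailing comma does.
print("fast" in ("fast-recurrent"))       # True  -- substring test on a str
print("fast" in ("fast-recurrent",))      # False -- membership in a 1-tuple
print("recurrent" in ("fast-recurrent"))  # True  -- also a substring match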
@@ -20,7 +20,7 @@ class ShapeSampler:
         pl_module = ShapeTransformer.load_from_checkpoint(checkpoint_path)
         if fast_recurrent is True and pl_module.hparams['architecture'] == 'fast':
             print("Reload model as a recurrent implementation for a major improvement of inference time.")
-            pl_module = ShapeTransformer.load_from_checkpoint(checkpoint_path, architecture='fast-recurrent')
+            pl_module = ShapeTransformer.load_from_checkpoint(checkpoint_path, architecture='fast_recurrent')
         pl_module.freeze()
 
         # extract hyperparameters from the model
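Finally, a usage sketch of the reload path above, assuming ShapeSampler's constructor accepts checkpoint_path and fast_recurrent as the hunk suggests; the checkpoint path is a placeholder:

# Hypothetical call: a checkpoint trained with architecture='fast' is
# transparently reloaded under the renamed 'fast_recurrent' architecture.
sampler = ShapeSampler(
    checkpoint_path="checkpoints/shape_transformer.ckpt",  # placeholder
    fast_recurrent=True,  # triggers the reload branch shown above
)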