diff --git a/data/collate/data_collate_factory.py b/data/collate/data_collate_factory.py
index 999a1a78786041461866834a284250df48a69ec1..20507fcaac1164060a74aeb28409fd7ab124b451 100644
--- a/data/collate/data_collate_factory.py
+++ b/data/collate/data_collate_factory.py
@@ -17,7 +17,7 @@ def create_data_collate(architecture, embeddings, resolution):
     """
     if architecture == "autoencoder":
         return AutoencoderCollate(embeddings)
-    if architecture in ("encoder_only", 'pytorch', 'fast', 'fast-recurrent', 'sliding-window'):
+    if architecture in ("encoder_only", 'pytorch', 'fast', 'fast-recurrent', 'fast_recurrent', 'sliding_window'):
         return EncoderOnlyCollate()
     if architecture == "encoder_decoder":
         return EncoderDecoderCollate(embeddings)
diff --git a/modules/architecture/architecture_factory.py b/modules/architecture/architecture_factory.py
index ed041d9a4b63c279130e64b7774f5dc2273096de..4f2b0a2f2c53745eb60f64676c5a777f8b148d87 100644
--- a/modules/architecture/architecture_factory.py
+++ b/modules/architecture/architecture_factory.py
@@ -62,13 +62,13 @@ def create_architecture(
         return Transformer(**kwargs, num_decoders=len(token_embedding) - 1)
     elif architecture == "pytorch":
         return PytorchTransformer(**kwargs)
-    elif architecture == "sliding-window":
+    elif architecture == "sliding_window":
         return SlidingWindowTransformer(**kwargs)
     elif architecture == "fast":
         # include `pytorch-fast-transformers` as an optional module
         from .fast_transformer import FastTransformer
         return FastTransformer(**kwargs)
-    elif architecture == "fast-recurrent":
+    elif architecture in ("fast-recurrent", "fast_recurrent"):
         # include `pytorch-fast-transformers` as an optional module
         from .fast_recurrent_transformer import FastRecurrentTransformer
         return FastRecurrentTransformer(**kwargs)
diff --git a/sample/layer_sampler/recurrent_sampler.py b/sample/layer_sampler/recurrent_sampler.py
index 0d2e1be09986b17ee3ebe4e5537b9aa8cd9b00bd..901ca742b7fbccdb43fd24ca4e78545db50d2d8d 100644
--- a/sample/layer_sampler/recurrent_sampler.py
+++ b/sample/layer_sampler/recurrent_sampler.py
@@ -13,7 +13,7 @@ from ..token_generator.recurrent import create_recurrent_token_generator
 
 class RecurrentSampler:
     def __init__(self, model, head, spatial_dim, max_resolution, position_encoding, device, **_):
-        """ Provides a basic implementation of the sampler for the 'fast-recurrent-transformer' architecture.
+        """ Provides a basic implementation of the sampler for the 'fast_recurrent' transformer architecture.
 
         Args:
             model: Model which is used for sampling.
diff --git a/sample/layer_sampler/sampler_factory.py b/sample/layer_sampler/sampler_factory.py
index 6a383de75065db0392179f7fb30fa0d36722534c..0d37d9fbc8cb0993352b90eeea7a4459933ffaf7 100644
--- a/sample/layer_sampler/sampler_factory.py
+++ b/sample/layer_sampler/sampler_factory.py
@@ -37,9 +37,9 @@ def create_sampler(
 
     if architecture == "autoencoder":
         return AutoencoderSampler(**kwargs)
-    elif architecture in ("encoder_only", "fast", "pytorch", "sliding-window"):
+    elif architecture in ("encoder_only", "fast", "pytorch", "sliding_window"):
         return EncoderOnlySampler(**kwargs)
-    elif architecture in ("fast-recurrent"):
+    elif architecture in ("fast-recurrent", "fast_recurrent"):
        return RecurrentSampler(**kwargs)
    elif architecture == "encoder_decoder":
        return EncoderDecoderSampler(**kwargs)
diff --git a/sample/shape_sampler.py b/sample/shape_sampler.py
index 37c311f711e24e92a09dee8591c66f7a4fcfd487..4ca127e1a97af0a53949f55d39a403f8e3f69c35 100644
--- a/sample/shape_sampler.py
+++ b/sample/shape_sampler.py
@@ -20,7 +20,7 @@ class ShapeSampler:
         pl_module = ShapeTransformer.load_from_checkpoint(checkpoint_path)
         if fast_recurrent is True and pl_module.hparams['architecture'] == 'fast':
             print("Reload model as a recurrent implementation for a major improvement of inference time.")
-            pl_module = ShapeTransformer.load_from_checkpoint(checkpoint_path, architecture='fast-recurrent')
+            pl_module = ShapeTransformer.load_from_checkpoint(checkpoint_path, architecture='fast_recurrent')
         pl_module.freeze()
 
         # extract hyperparameters from the model
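Note: beyond the hyphen-to-underscore rename, the sampler_factory.py hunk fixes a latent bug. `("fast-recurrent")` is a parenthesized string rather than a one-element tuple, so the old `architecture in ("fast-recurrent")` check performed substring matching. A minimal sketch of the difference (plain Python semantics, independent of this repository):

    "fast" in ("fast-recurrent")        # True  -- `in` against a string is a substring test
    "fast" in ("fast-recurrent",)       # False -- the trailing comma makes a tuple,
                                        #          so membership is an exact-element test
    "fast_recurrent" in ("fast-recurrent", "fast_recurrent")   # True

With the tuple form, only the exact keys 'fast-recurrent' and 'fast_recurrent' route to RecurrentSampler.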