Commit 8cf95852 authored by Gregor Kobsik

Rework loading of 'fast-recurrent' sampler.

parent e85530c1
@@ -12,15 +12,15 @@ class ShapeSampler:
     Args:
         checkpoint_path: Relative or absolute path to a checkpoint file ("*.ckpt") containing a trained model.
         fast_recurrent: Changes the 'fast' architecture of the Transformer into an equivalent, but recurrent
-            formulation durring inference time.
+            formulation during inference time; otherwise the standard full-pass technique is used.
         device: Selects the device on which the sampling should be performed; either "cpu" or "cuda"
             (GPU support) is available.
     """
     # load and restore model from checkpoint
-    if fast_recurrent is True:
-        pl_module = ShapeTransformer.load_from_checkpoint(checkpoint_path, architecture='fast-recurrent')
-    else:
-        pl_module = ShapeTransformer.load_from_checkpoint(checkpoint_path)
+    pl_module = ShapeTransformer.load_from_checkpoint(checkpoint_path)
+    if fast_recurrent is True and pl_module.hparams['architecture'] == 'fast':
+        print("Reload model as a recurrent implementation for a major improvement of inference time.")
+        pl_module = ShapeTransformer.load_from_checkpoint(checkpoint_path, architecture='fast-recurrent')
     pl_module.freeze()
# extract hyperparameters from the model
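The reworked logic above loads the checkpoint once, inspects the stored hyperparameters, and only reloads checkpoints that were trained with the 'fast' architecture as 'fast-recurrent'. A minimal, self-contained sketch of that control flow is below; the stub stands in for `ShapeTransformer.load_from_checkpoint` (the real class comes from the project and is a pytorch-lightning module), and the helper name `load_sampler_module` is hypothetical.

```python
class StubModule:
    """Stand-in for the pl.LightningModule; only hparams and freeze() matter here."""

    def __init__(self, architecture):
        self.hparams = {'architecture': architecture}
        self.frozen = False

    def freeze(self):
        self.frozen = True


def load_from_checkpoint(trained_architecture, architecture=None):
    # Mimics load_from_checkpoint(checkpoint_path, **kwargs): keyword arguments
    # override hyperparameters stored in the checkpoint.
    return StubModule(architecture or trained_architecture)


def load_sampler_module(trained_architecture, fast_recurrent):
    # Load once, then decide from the stored hparams whether a reload
    # as the recurrent formulation is possible.
    pl_module = load_from_checkpoint(trained_architecture)
    if fast_recurrent and pl_module.hparams['architecture'] == 'fast':
        # Equivalent recurrent formulation, faster at inference time.
        pl_module = load_from_checkpoint(trained_architecture,
                                         architecture='fast-recurrent')
    pl_module.freeze()
    return pl_module


# A 'fast' checkpoint is upgraded; any other architecture is left untouched.
print(load_sampler_module('fast', True).hparams['architecture'])    # fast-recurrent
print(load_sampler_module('basic', True).hparams['architecture'])   # basic
```

Compared to the old version, requesting `fast_recurrent` on a checkpoint trained with a different architecture no longer forces the 'fast-recurrent' override; the model is simply used as trained.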