Commit ba2a7bde authored by Gregor Kobsik

Changed configuration files from steps to epochs to get better control over training.

parent 8568ffe8
@@ -2,7 +2,7 @@ name: mnist_s
 dataset: mnist
 # training
-steps: 10_000
+epochs: 200
 warmup_steps: 500
 batch_size: 4
 accumulate_grad_batches: 16

-name: mnist_s
+name: mnist_xs
 dataset: mnist
 # training
-steps: 10_000
+epochs: 100
 warmup_steps: 500
 batch_size: 16
 accumulate_grad_batches: 4
@@ -10,7 +10,7 @@ accumulate_grad_batches: 4
 learning_rate: 0.001
 # architecture
-embed_dim: 256
+embed_dim: 64
 num_heads: 4
 num_layers: 16
 num_positions: 512
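For orientation, the architecture keys above are standard transformer hyperparameters. A minimal sketch of how they could map onto PyTorch's built-in modules, purely as an illustration (the real ShapeTransformer internals are not part of this commit, and the module choices below are assumptions):

import torch.nn as nn

# Values taken from the mnist_xs config and the argparse default for num_vocab.
embed_dim, num_heads, num_layers, num_positions, num_vocab = 64, 4, 16, 512, 16

token_embedding = nn.Embedding(num_vocab, embed_dim)          # one vector per vocabulary token
position_embedding = nn.Embedding(num_positions, embed_dim)   # learned positional encoding
encoder_layer = nn.TransformerEncoderLayer(d_model=embed_dim, nhead=num_heads)
encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)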
@@ -2,7 +2,7 @@ name: mnist_xxs
 dataset: mnist
 # training
-steps: 25_000
+epochs: 30
 warmup_steps: 500
 batch_size: 64
 accumulate_grad_batches: 1
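All three configs are plain YAML, so the training script can read them directly. A hedged sketch of loading one, assuming PyYAML and a hypothetical path (the repository's actual config layout is not shown in this diff):

import yaml

# Hypothetical path; adjust to wherever the configs live in the repository.
with open("configs/mnist_xxs.yaml") as f:
    config = yaml.safe_load(f)

# After this commit, training length is expressed in epochs instead of raw steps.
print(config["epochs"], config["batch_size"], config["accumulate_grad_batches"])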
@@ -13,6 +13,19 @@ def train(args):
     name = f"{config['name']}"
+    train_dl, valid_dl, _ = dataloaders(config['dataset'], config['batch_size'])
+    logger = pl_loggers.TensorBoardLogger("logs", name=name)
+    checkpoint = pl.callbacks.ModelCheckpoint(monitor="val_loss", save_last=True)
+    trainer = pl.Trainer(
+        max_epochs=config['epochs'],
+        gpus=config['gpus'],
+        precision=config['precision'],
+        accumulate_grad_batches=config['accumulate_grad_batches'],
+        checkpoint_callback=checkpoint,
+        logger=logger,
+    )
     if args.pretrained is not None:
         model = ShapeTransformer.load_from_checkpoint(args.pretrained)
         model.learning_rate = config['learning_rate']
@@ -24,21 +37,8 @@ def train(args):
             num_positions=config['num_positions'],
             num_vocab=config['num_vocab'],
             learning_rate=config['learning_rate'],
-            steps=config['steps'],
+            steps=len(train_dl) * config['epochs'] / config['batch_size'],
             warmup_steps=config['warmup_steps'],
         )
-    train_dl, valid_dl, _ = dataloaders(config['dataset'], config['batch_size'])
-    logger = pl_loggers.TensorBoardLogger("logs", name=name)
-    checkpoint = pl.callbacks.ModelCheckpoint(monitor="val_loss", save_last=True)
-    trainer = pl.Trainer(
-        max_steps=config['steps'],
-        gpus=config['gpus'],
-        precision=config['precision'],
-        accumulate_grad_batches=config['accumulate_grad_batches'],
-        checkpoint_callback=checkpoint,
-        logger=logger,
-    )
     trainer.fit(model, train_dl, valid_dl)
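Since the learning-rate schedule still needs a step count, the epoch-based config has to be converted back into optimizer steps somewhere. A hedged sketch of that conversion, assuming len(train_dl) returns the number of batches per epoch (the usual DataLoader semantics) and that one optimizer step is taken every accumulate_grad_batches batches; this is not necessarily the repository's exact formula:

# Sketch only, under the assumptions stated above.
batches_per_epoch = len(train_dl)
optimizer_steps = batches_per_epoch * config['epochs'] // config['accumulate_grad_batches']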
@@ -49,7 +49,7 @@ class ShapeTransformer(pl.LightningModule):
         parser.add_argument("--num_vocab", type=int, default=16)
         parser.add_argument("--batch_size", type=int, default=64)
         parser.add_argument("--learning_rate", type=float, default=3e-3)
-        parser.add_argument("--steps", type=int, default=25_000)
+        parser.add_argument("--epochs", type=int, default=50)
         return parser

     def configure_optimizers(self):
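The body of configure_optimizers is not touched by this commit, but it is the consumer of the steps and warmup_steps values shown above. As an illustration only, a warmup-then-linear-decay schedule could be built like this (the helper name and the decay shape are assumptions, not the repository's code):

import torch

def warmup_linear_decay(optimizer, warmup_steps, total_steps):
    # Linear warmup from 0 to the base LR, then linear decay back toward 0.
    def lr_lambda(step):
        if step < warmup_steps:
            return step / max(1, warmup_steps)
        progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
        return max(0.0, 1.0 - progress)
    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)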