Skip to content
Snippets Groups Projects
Commit a62ca879 authored by Gregor Kobsik's avatar Gregor Kobsik
Browse files

added main.py as command line parser, which can call either the training or...

added main.py as command line parser, which can call either the training or the testing routine of the model.
parent f0cf44ce
No related branches found
No related tags found
No related merge requests found
from executable.train import train
from executable.test import test
__all__ = [
"train",
"test",
]
import yaml
import pytorch_lightning as pl
from utils.data import dataloaders
from models import ShapeTransformer
def test(args):
with open(args.config, "rb") as f:
config = yaml.safe_load(f)
trainer = pl.Trainer(gpus=config['gpus'])
_, _, test_dl = dataloaders(config['dataset'], config['batch_size'])
model = ShapeTransformer.load_from_checkpoint(args.checkpoint)
trainer.test(model, test_dataloaders=test_dl)
import yaml
import pytorch_lightning as pl
from pytorch_lightning import loggers as pl_loggers
from models import ShapeTransformer
from utils.data import dataloaders
def train(args):
with open(args.config, "rb") as f:
config = yaml.safe_load(f)
name = f"{config['name']}"
if args.pretrained is not None:
model = ShapeTransformer.load_from_checkpoint(args.pretrained)
model.learning_rate = config['learning_rate']
else:
model = ShapeTransformer(
embed_dim=config['embed_dim'],
num_heads=config['num_heads'],
num_layers=config['num_layers'],
num_positions=config['num_positions'],
num_vocab=config['num_vocab'],
learning_rate=config['learning_rate'],
steps=config['steps'],
warmup_steps=config['warmup_steps'],
)
train_dl, valid_dl, _ = dataloaders(config['dataset'], config['batch_size'])
logger = pl_loggers.TensorBoardLogger("logs", name=name)
checkpoint = pl.callbacks.ModelCheckpoint(monitor="val_loss", save_last=True)
trainer = pl.Trainer(
max_steps=config['steps'],
gpus=config['gpus'],
precision=config['precision'],
accumulate_grad_batches=config['accumulate_grad_batches'],
checkpoint_callback=checkpoint,
logger=logger,
)
trainer.fit(model, train_dl, valid_dl)
main.py 0 → 100644
from argparse import ArgumentParser
from executable import train, test, sample
if __name__ == "__main__":
""" Parses the console input and calls one of the executable functions.
Func:
train: Creates a new model with random or pretrained weights and trains it according to the config file.
test: Tests the loss of the given checkpoint on the test data set.
"""
parser = ArgumentParser()
subparsers = parser.add_subparsers()
# TRAINING
parser_train = subparsers.add_parser("train")
parser_train.add_argument("--pretrained", type=str, default=None)
parser_train.add_argument("config", type=str)
parser_train.set_defaults(func=train)
# TESTING
parser_test = subparsers.add_parser("test")
parser_test.add_argument("checkpoint", type=str)
parser_test.add_argument("config", type=str)
parser_test.set_defaults(func=test)
args = parser.parse_args()
args.func(args)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment