Skip to content
Snippets Groups Projects
Commit 888c3d61 authored by Gregor Kobsik's avatar Gregor Kobsik
Browse files

fixed logging of loss while training

parent ba2a7bde
No related branches found
No related tags found
No related merge requests found
......@@ -72,7 +72,7 @@ class ShapeTransformer(pl.LightningModule):
logits = self.model(x)
loss = self.criterion(logits.view(-1, logits.size(-1)), x.view(-1))
self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True)
self.log('train_loss', loss, on_step=False, on_epoch=False, prog_bar=True)
return {'loss': loss}
def validation_step(self, batch, batch_idx):
......@@ -82,13 +82,12 @@ class ShapeTransformer(pl.LightningModule):
logits = self.model(x)
loss = self.criterion(logits.view(-1, logits.size(-1)), x.view(-1))
self.log('val_loss', loss, on_step=True, on_epoch=True, prog_bar=True)
self.log('val_loss', loss, on_step=False, on_epoch=False, prog_bar=True)
return {'val_loss': loss}
def validation_epoch_end(self, outs):
avg_loss = torch.stack([x["val_loss"] for x in outs]).mean()
self.log('avg_loss', avg_loss, on_epoch=True, prog_bar=True)
return {'val_loss': avg_loss}
self.log('val_loss', avg_loss, on_step=False, on_epoch=True, prog_bar=True)
def test_step(self, batch, batch_idx):
return self.validation_step(batch, batch_idx)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment