Commit ff59ff02 authored by Gregor Kobsik

Added callbacks to track the weights, biases, and gradients of the model

parent e4097bd0

callbacks/__init__.py

from callbacks.weights_and_biases_logger import WeightsAndBiasesLogger
from callbacks.tracked_gradient_output import TrackedGradientOutput

__all__ = [
    "WeightsAndBiasesLogger",
    "TrackedGradientOutput",
]
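
For context, a minimal consumer sketch (hypothetical script, not part of the commit): `__all__` above constrains exactly what a star import pulls in.

# Hypothetical usage of the package's public API: the star import
# brings in only the two callbacks exported via __all__.
from callbacks import *

callbacks = [WeightsAndBiasesLogger(), TrackedGradientOutput()]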

callbacks/tracked_gradient_output.py

import pytorch_lightning as pl
from pytorch_lightning import Callback


class TrackedGradientOutput(Callback):
    """Reorganizes the gradient norms produced by Lightning's GradInformation mixin."""

    def __init__(self, subcategory='gradients/', global_only=False):
        super().__init__()
        self.subcategory = subcategory
        self.global_only = global_only

    def _remove_per_weight_norms(self, func):
        # Wrap grad_norm so only the aggregated '<norm>_total' entry survives,
        # dropping the per-parameter norms.
        def f(*args):
            norms = func(*args)
            norms = dict(filter(lambda elem: '_total' in elem[0], norms.items()))
            return norms
        return f

    def _add_category(self, func, subcategory):
        # Wrap grad_norm so every logged key is prefixed with a subcategory,
        # grouping the gradient norms in the logger UI.
        def f(*args):
            norms = func(*args)
            norms = {subcategory + k: v for k, v in norms.items()}
            return norms
        return f

    def on_train_start(self, trainer, pl_module):
        # Monkey-patch the grad_norm method once, right before training begins.
        if self.global_only:
            pl.core.grads.GradInformation.grad_norm = self._remove_per_weight_norms(
                pl.core.grads.GradInformation.grad_norm
            )
        if self.subcategory is not None:
            pl.core.grads.GradInformation.grad_norm = self._add_category(
                pl.core.grads.GradInformation.grad_norm, self.subcategory
            )
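
A hedged usage sketch: the callback only rewrites the keys that `grad_norm` produces, so gradient tracking itself has to be switched on via the Trainer's `track_grad_norm` flag (pre-2.0 Lightning API, matching the `pl.core.grads` module patched above). The model and datamodule are assumed to exist and are not part of this commit.

import pytorch_lightning as pl
from callbacks import TrackedGradientOutput

trainer = pl.Trainer(
    max_epochs=10,
    track_grad_norm=2,  # ask Lightning to compute 2-norms of the gradients
    callbacks=[TrackedGradientOutput(subcategory='gradients/', global_only=True)],
)
# trainer.fit(model, datamodule)  # hypothetical model / datamodule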

callbacks/weights_and_biases_logger.py

from pytorch_lightning import Callback


class WeightsAndBiasesLogger(Callback):
    """Logs histograms of all model parameters every n epochs."""

    def __init__(self, log_every_n_epoch=5):
        super().__init__()
        self.log_every_n_epoch = log_every_n_epoch
        assert log_every_n_epoch != 0

    def log_weights_and_biases(self, trainer, pl_module):
        # Iterate over all named parameters and log each one as a histogram.
        # TODO: log weights and biases in separate subcategories
        for name, params in pl_module.named_parameters():
            if 'weight' in name:
                trainer.logger.experiment.add_histogram('weights/' + name, params, trainer.current_epoch)
            elif 'bias' in name:
                trainer.logger.experiment.add_histogram('biases/' + name, params, trainer.current_epoch)
            else:
                trainer.logger.experiment.add_histogram(name, params, trainer.current_epoch)

    def on_epoch_end(self, trainer, pl_module):
        # Log on every n-th epoch; a negative interval disables periodic logging.
        if trainer.current_epoch % self.log_every_n_epoch == 0 and self.log_every_n_epoch > 0:
            self.log_weights_and_biases(trainer, pl_module)

    def on_fit_end(self, trainer, pl_module):
        # Capture the final state if the last epoch was not already logged above.
        if trainer.current_epoch % self.log_every_n_epoch != 0 and self.log_every_n_epoch > 0:
            self.log_weights_and_biases(trainer, pl_module)
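
And a matching sketch for the histogram logger: `add_histogram` is a TensorBoard `SummaryWriter` method, so the callback assumes the Trainer's logger exposes one via `.experiment` (Lightning's default `TensorBoardLogger` does). The model is again assumed.

import pytorch_lightning as pl
from pytorch_lightning.loggers import TensorBoardLogger
from callbacks import WeightsAndBiasesLogger

trainer = pl.Trainer(
    max_epochs=20,
    logger=TensorBoardLogger(save_dir='lightning_logs/'),
    callbacks=[WeightsAndBiasesLogger(log_every_n_epoch=5)],
)
# trainer.fit(model)  # histograms land under weights/ and biases/ in TensorBoard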