Source code for pyblaze.nn.callbacks.schedule

from .base import TrainingCallback, ValueTrainingCallback

class LearningRateScheduler(TrainingCallback):
    """
    The learning rate scheduler may be used with a PyTorch learning rate scheduler. The callback is
    automatically triggered after the end of every iteration or epoch.
    """

    def __init__(self, scheduler, metric=None, after_batch=False):
        """
        Initializes a new learning rate scheduler for the given PyTorch scheduler.

        Parameters
        ----------
        scheduler: torch.optim.lr_scheduler
            The PyTorch scheduler. Its :code:`step` method is invoked either after every batch or
            after every epoch, depending on :code:`after_batch`.
        metric: str, default: None
            The metric to pass to the scheduler, e.g. useful for reducing the learning rate as the
            validation loss plateaus. If given, the value stored under this key in the metrics
            passed to the callback is forwarded to :code:`scheduler.step`. Typically, it should
            only be used with :code:`after_batch` set to `False`.
        after_batch: bool, default: False
            Whether to call the scheduler after every batch or after every epoch.
        """
        # Stored under a different name because `after_batch` is also a callback method below.
        self.exec_after_batch = after_batch
        self.metric = metric
        self.scheduler = scheduler

    def after_batch(self, metrics):
        """
        Steps the scheduler after a batch if the callback was configured with
        :code:`after_batch=True`; does nothing otherwise.
        """
        if self.exec_after_batch:
            self._exec(metrics)

    def after_epoch(self, metrics):
        """
        Steps the scheduler after an epoch if the callback was configured with
        :code:`after_batch=False`; does nothing otherwise.
        """
        if not self.exec_after_batch:
            self._exec(metrics)

    def _exec(self, metrics):
        """
        Performs a single scheduler step, forwarding the configured metric if there is one.

        Raises
        ------
        KeyError
            If a metric was configured but is not present in :code:`metrics`.
        """
        if self.metric is None:
            self.scheduler.step()
            return
        try:
            value = metrics[self.metric]
        except KeyError:
            # Re-raise with context: a bare KeyError only shows the missing name, which makes a
            # misspelled or not-yet-computed metric hard to diagnose.
            raise KeyError(
                "Metric '{}' required by the learning rate scheduler was not found. "
                "Available metrics: {}".format(self.metric, list(metrics))
            ) from None
        self.scheduler.step(value)
class ParameterScheduler(ValueTrainingCallback):
    """
    The parameter scheduler is able to change the value of a variable over the course of the
    training. The current value is returned by :meth:`read`.
    """

    def __init__(self, initial, schedule, *args, **kwargs):
        r"""
        Initializes a new scheduler for the given parameter.

        Parameters
        ----------
        initial: object
            The initial value of the parameter which should be modified over the course of the
            training.
        schedule: func (object, int, int, \*args, \**kwargs) -> object
            Function which should return the new value of the parameter based on the current
            value of the parameter, the current epoch, and the iteration counter. The iteration
            counter is the number of batches processed since the start of training (it starts at
            1 for the first batch and is NOT reset at the beginning of an epoch). The function is
            called after every iteration (i.e. batch) and once more after every epoch; the
            epoch-end call does not advance the iteration counter, so it receives the same epoch
            and iteration values as the last batch-level call preceding it. It is further passed
            the arguments given to this initializer.
        args: variadic argument
            Additional arguments passed to the :code:`schedule` function.
        kwargs: keyword arguments
            Additional keyword arguments passed to the :code:`schedule` function.
        """
        self.parameter = initial
        self.schedule = schedule
        self.args = args
        self.kwargs = kwargs
        # Only meaningful while training; set by before_epoch / before_training and cleared
        # again in after_training.
        self.epoch = None
        self.iteration = None

    def read(self):
        """
        Returns the current value of the scheduled parameter.
        """
        return self.parameter

    def before_training(self, model, num_epochs):
        """
        Resets the iteration counter. This is the only place where it is reset, so the counter
        keeps running across epochs.
        """
        self.iteration = 0

    def before_epoch(self, current, num_iterations):
        """
        Records the current epoch so it can be passed to the schedule function.
        """
        self.epoch = current

    def after_batch(self, metrics):
        """
        Advances the iteration counter and updates the parameter.
        """
        # Incremented before the update, so the first batch is reported as iteration 1.
        self.iteration += 1
        self._update()

    def after_epoch(self, metrics):
        """
        Updates the parameter once more at the end of the epoch (without advancing the iteration
        counter).
        """
        self._update()

    def after_training(self):
        """
        Clears the epoch and iteration state once training has finished.
        """
        self.epoch = None
        self.iteration = None

    def _update(self):
        """
        Replaces the parameter by the value returned from the schedule function.
        """
        self.parameter = self.schedule(
            self.parameter, self.epoch, self.iteration, *self.args, **self.kwargs
        )