from abc import ABC, abstractmethod
[docs]class CallbackException(Exception):
"""
Exception raised by callbacks to stop the training procedure.
"""
[docs] def __init__(self, message, verbose=False):
super().__init__(message)
self.verbose = verbose
[docs] def print(self):
"""
Prints the message of this exception if it is verbose. Otherwise it is a no-op.
"""
if self.verbose:
print(self)
[docs]class TrainingCallback(ABC):
"""
Abstract class to be subclassed by all training callbacks. These callbacks are passed to an
engine which calls the implemented methods at specific points during training.
"""
[docs] def before_training(self, model, num_epochs):
"""
Method is called prior to the start of the training. This method must not raise exceptions.
Parameters
----------
model: torch.nn.Module
The model which is trained.
num_epochs: int
The maximum number of epochs performed during training.
"""
[docs] def before_epoch(self, current, num_iterations):
"""
Method is called at the begining of every epoch during training. This method may raise
exceptions if training should be stopped. Note, however, that stopping training at this
stage is an advanced scenario.
Parameters
----------
current: int
The index of the epoch that is about to start.
num_iterations: int
The expected number of iterations for the batch.
"""
[docs] def after_batch(self, metrics):
"""
Method is called at the end of a mini-batch. If the data is not partitioned into batches,
this function is never called. The method may not raise exceptions.
Parameters
----------
metrics: float or tuple or dict
The metrics obtained from the last mini-batch. This is the value that is returned from
an engine's :meth:`train_batch` method.
"""
[docs] def after_epoch(self, metrics):
"""
Method is called at the end of every epoch during training. This method may raise
exceptions if training should be stopped.
Parameters
----------
metrics: float or tuple or dict
Metrics obtained after training this epoch.
"""
[docs] def after_training(self):
"""
Method is called upon end of the training procedure. The method may not raise exceptions.
"""
[docs]class ValueTrainingCallback(TrainingCallback, ABC):
"""
A training callback with an additional method that can be used to obtain a dynamically
changing value from the callback.
"""
[docs] @abstractmethod
def read(self):
"""
Returns the value that this training callback stores.
"""
[docs]class PredictionCallback(ABC):
"""
Abstract class to be subclassed by all prediction callbacks. These callbacks are passed to an
engine which calls the implemented methods at specific points during inference.
"""
[docs] def before_predictions(self, model, num_iterations):
"""
Called before prediction making starts.
Parameters
----------
model: torch.nn.Module
The model which is used to make predictions.
num_iterations: int
The number of iterations/batches performed for prediction.
"""
[docs] def after_batch(self, *args):
"""
Called after prediction is done for one batch.
Parameters
----------
args: varargs
Usually empty, just to be able to implement both TrainingCallback and
PredictionCallback.
"""
[docs] def after_predictions(self):
"""
Called after all predictions have been made.
"""