Source code for pyblaze.nn.callbacks.early_stopping

import copy
from .base import TrainingCallback, CallbackException

[docs]class EarlyStopping(TrainingCallback): """ The early stopping callback watches a specified metric and interrupts training if the metric does not decrease for a specified number of epochs. """
[docs] def __init__(self, metric='val_loss', patience=5, restore_best=False, minimize=True): """ Initializes a new early stopping callback. Parameters ---------- metric: str or list of str, default: 'val_loss' The metric to watch during training. If a list is given, the sum of the given metrics is considered. patience: int, default: 5 The number of epochs that training still continues although the watched metric is greater than its smallest value during training. restore_best: bool, default: False Whether the model's parameter should be set to the parameters which showed the best performance in terms of the watched metric. minimize: bool, default: True Whether to minimize or maximize the given metric. """ self.patience = patience self.epoch = None self.counter = None self.best_metric = None self.model = None self.state_dict = None self.restore_best = restore_best self.metric = metric self.minimize = minimize
def before_training(self, model, num_epochs): self.model = model if self.restore_best: self.state_dict = copy.deepcopy(model.state_dict()) self.epoch = 0 self.counter = 0 self.best_metric = float('inf') if self.minimize else -float('inf') def after_epoch(self, metrics): prev_epoch = self.epoch self.epoch += 1 try: is_better = self._is_metric_better(metrics) except KeyError: if prev_epoch > 0: # In this case, we can ignore the key error and just skip -- the engine does not # perform evaluation on every iteration return if is_better: if self.restore_best: self.state_dict = copy.deepcopy(self.model.state_dict()) self.counter = 0 self.best_metric = self._current_metric(metrics) else: self.counter += 1 if self.counter == self.patience: raise CallbackException( f"Early stopping after epoch {self.epoch} (patience {self.patience}).", verbose=True ) def after_training(self): if self.restore_best and self.counter > 0: self.model.load_state_dict(self.state_dict) def _is_metric_better(self, metrics): if self.minimize: return self._current_metric(metrics) < self.best_metric return self._current_metric(metrics) > self.best_metric def _current_metric(self, metrics): if isinstance(self.metric, list): return sum(metrics[m] for m in self.metric) return metrics[self.metric]