# Source code for probflow.callbacks.early_stopping

import numpy as np

from .callback import Callback
from .monitor_elbo import MonitorELBO
from .monitor_metric import MonitorMetric


class EarlyStopping(Callback):
    """Stop training early when some metric stops decreasing

    Parameters
    ----------
    metric_fn : callable, MonitorMetric, or MonitorELBO
        Any arbitrary function, or a :class:`.MonitorMetric` or
        :class:`.MonitorELBO` callback.  Training will be stopped when the
        value returned by that function stops decreasing (or, if
        ``metric_fn`` was a :class:`.MonitorMetric` or
        :class:`.MonitorELBO`, training is stopped when the metric being
        monitored or the ELBO stops decreasing).
    patience : int
        Number of consecutive epochs without improvement to tolerate before
        stopping training.  Must be non-negative.  Default is 0.
    verbose : bool
        Whether to print when training was stopped.  Default = True
    name : str
        Name for this callback.  Used as the prefix of the message printed
        when ``verbose`` is True.  Default = "EarlyStopping"

    Raises
    ------
    TypeError
        If ``patience`` is not an int, or ``metric_fn`` is neither callable
        nor a :class:`.MonitorMetric` / :class:`.MonitorELBO`.
    ValueError
        If ``patience`` is negative.

    Example
    -------

    See the user guide section on :ref:`user-guide-early-stopping`.
    """

    def __init__(
        self, metric_fn, patience=0, verbose=True, name="EarlyStopping"
    ):

        # Check types
        if not isinstance(patience, int):
            raise TypeError("patience must be an int")
        if patience < 0:
            raise ValueError("patience must be non-negative")
        if not callable(metric_fn) and not isinstance(
            metric_fn, (MonitorMetric, MonitorELBO)
        ):
            raise TypeError(
                "metric_fn must be a callable, MonitorMetric, or MonitorELBO"
            )

        # Store values
        self.metric_fn = metric_fn
        self.patience = patience
        # Best (lowest) metric value seen so far.  np.inf (lower-case) is
        # used because the np.Inf alias was removed in NumPy 2.0.
        self.best = np.inf
        # Number of consecutive epochs without improvement
        self.count = 0
        # Number of epochs seen by this callback
        self.epoch = 0
        self.verbose = verbose
        self.name = name

        # TODO: restore_best_weights? using save_model and load_model?

    def on_epoch_end(self):
        """Stop training once the metric has failed to improve for more than
        ``patience`` consecutive epochs.

        The metric is read from the wrapped :class:`.MonitorMetric`
        (``current_metric``) or :class:`.MonitorELBO` (``current_elbo``), or
        obtained by calling ``metric_fn`` directly.  A strictly lower value
        than the best seen so far resets the counter; anything else
        (including equal values) counts as a non-improving epoch.
        """

        self.epoch += 1

        # Get the current value of the monitored metric
        if isinstance(self.metric_fn, MonitorMetric):
            metric = self.metric_fn.current_metric
        elif isinstance(self.metric_fn, MonitorELBO):
            metric = self.metric_fn.current_elbo
        else:
            metric = self.metric_fn()

        # Track the best value and how long it has been since it improved
        if metric < self.best:
            self.best = metric
            self.count = 0
        else:
            self.count += 1

        # Stop only after *more than* `patience` non-improving epochs, so
        # patience=0 stops on the first epoch without improvement.
        if self.count > self.patience:
            self.model.stop_training()
            if self.verbose:
                print(self.name + " after " + str(self.epoch) + " epochs")