Callbacks
The callbacks module contains classes for monitoring and adjusting the training process.
- Callback: abstract base class for all callbacks
- EarlyStopping: stop training if some metric stops improving
- KLWeightScheduler: set the KL weight by epoch
- LearningRateScheduler: set the learning rate by epoch
- MonitorELBO: record the ELBO loss over the course of training
- MonitorMetric: record a metric over the course of training
- MonitorParameter: record a parameter over the course of training
- TimeOut: stop training after a certain amount of time
- class probflow.callbacks.Callback(*args)
Bases: probflow.utils.base.BaseCallback
Base class for all callbacks.
See the user guide section on Callbacks.
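A custom callback works by overriding the hook methods documented on this page (on_train_start, on_epoch_start, on_epoch_end, on_train_end). The sketch below shows the order in which those hooks would fire during training. So that it runs without ProbFlow installed, it uses a plain class and a hypothetical toy_fit loop in place of model.fit; a real callback would subclass pf.callbacks.Callback instead. The exact calling sequence inside ProbFlow is an assumption drawn from the method descriptions ("at the start of training", "at the start of each training epoch", and so on).

```python
from typing import Callable, List


class RecordingCallback:
    """Plain stand-in for a pf.callbacks.Callback subclass (hypothetical).

    Each hook appends its name to self.calls so the call order can be inspected.
    """

    def __init__(self) -> None:
        self.calls: List[str] = []

    def on_train_start(self) -> None:
        self.calls.append("train_start")

    def on_epoch_start(self) -> None:
        self.calls.append("epoch_start")

    def on_epoch_end(self) -> None:
        self.calls.append("epoch_end")

    def on_train_end(self) -> None:
        self.calls.append("train_end")


def toy_fit(epochs: int, callbacks: list, train_one_epoch: Callable[[], None]) -> None:
    """Hypothetical stand-in for model.fit, showing where each hook fires."""
    for cb in callbacks:
        cb.on_train_start()
    for _ in range(epochs):
        for cb in callbacks:
            cb.on_epoch_start()
        train_one_epoch()  # optimization steps for one epoch would run here
        for cb in callbacks:
            cb.on_epoch_end()
    for cb in callbacks:
        cb.on_train_end()


recorder = RecordingCallback()
toy_fit(epochs=2, callbacks=[recorder], train_one_epoch=lambda: None)
print(recorder.calls)
# ['train_start', 'epoch_start', 'epoch_end', 'epoch_start', 'epoch_end', 'train_end']
```

Because every hook "by default does nothing", a subclass only needs to override the hooks it actually uses.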
- class probflow.callbacks.EarlyStopping(metric_fn, patience=0, verbose=True, name='EarlyStopping')
Bases: probflow.callbacks.callback.Callback
Stop training early when some metric stops decreasing
- Parameters
metric_fn (callable, MonitorMetric, or MonitorELBO) – Any arbitrary function, or a MonitorMetric or MonitorELBO callback. Training will be stopped when the value returned by that function stops decreasing (or, if metric_fn was a MonitorMetric or MonitorELBO, when the metric being monitored or the ELBO stops decreasing).
patience (int) – Number of epochs to allow training to continue even if the metric is not decreasing. Default = 0.
restore_best_weights (bool) – Whether or not to restore the weights from the best epoch after training is stopped. Default = False.
verbose (bool) – Whether to print when training was stopped. Default = True.
name (str) – Name for this callback.
Example
See the user guide section on Ending training when a metric stops improving.
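The patience rule described above can be illustrated with a small standalone function. This is a sketch of the documented behavior, not ProbFlow's implementation: it assumes that "stops decreasing" means failing to beat the best value seen so far, and the name early_stopping_epoch is hypothetical.

```python
import math
from typing import Optional, Sequence


def early_stopping_epoch(values: Sequence[float], patience: int = 0) -> Optional[int]:
    """Return the epoch at which training would stop, or None if it never stops.

    Assumed rule: training continues while the metric keeps improving on its
    best value so far, and stops once it has failed to improve for more than
    `patience` consecutive epochs.
    """
    best = math.inf
    epochs_without_improvement = 0
    for epoch, value in enumerate(values):
        if value < best:
            best = value
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement > patience:
                return epoch
    return None


val_losses = [5.0, 4.0, 3.0, 3.5, 3.6, 2.0]
print(early_stopping_epoch(val_losses, patience=0))  # 3
print(early_stopping_epoch(val_losses, patience=1))  # 4
print(early_stopping_epoch(val_losses, patience=2))  # None
```

In ProbFlow itself, metric_fn would typically be a MonitorMetric instance, for example monitor = pf.callbacks.MonitorMetric('mse', x_val, y_val) and then pf.callbacks.EarlyStopping(monitor, patience=2), with both callbacks passed to model.fit(x, y, callbacks=[...]). The 'mse' metric name and the x_val/y_val arrays are assumptions for illustration. A larger patience lets training ride out short plateaus, as the third call shows.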
- on_epoch_start()
Will be called at the start of each training epoch. By default does nothing.
- on_train_end()
Will be called at the end of training. By default does nothing.
- on_train_start()
Will be called at the start of training. By default does nothing.
- class probflow.callbacks.KLWeightScheduler(fn, verbose=False)
Bases: probflow.callbacks.callback.Callback
Set the weight of the KL term’s contribution to the ELBO loss each epoch
- Parameters
fn (callable) – Function which takes the current epoch as an argument and returns a KL weight, a float between 0 and 1.
verbose (bool) – Whether to print the KL weight each epoch (if True) or not (if False). Default = False
Examples
See the user guide section on Changing the KL weight over training.
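The fn argument is simply a function from epoch to KL weight. One common choice is a linear warm-up that starts the KL weight at 0 and ramps it up to 1. The sketch below assumes epochs are counted from 0; kl_warmup and warmup_epochs are hypothetical names.

```python
def kl_warmup(epoch: int, warmup_epochs: int = 100) -> float:
    """Hypothetical linear warm-up: KL weight rises from 0 to 1, then stays at 1.

    Assumes epochs are counted from 0. The result is clipped to [0, 1],
    the range the KLWeightScheduler fn description asks for.
    """
    if warmup_epochs <= 0:
        return 1.0
    return min(1.0, max(0.0, epoch / warmup_epochs))


print([kl_warmup(e, warmup_epochs=4) for e in range(6)])
# [0.0, 0.25, 0.5, 0.75, 1.0, 1.0]
```

The function would then be passed as pf.callbacks.KLWeightScheduler(kl_warmup), or wrapped in a lambda such as lambda epoch: kl_warmup(epoch, warmup_epochs=50) to change the warm-up length.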
- plot(**kwargs)
Plot the KL weight as a function of epoch
- Parameters
**kwargs – Additional keyword arguments are passed to matplotlib.pyplot.plot
- on_epoch_end()
Will be called at the end of each training epoch. By default does nothing.
- on_train_end()
Will be called at the end of training. By default does nothing.
- on_train_start()
Will be called at the start of training. By default does nothing.
- class probflow.callbacks.LearningRateScheduler(fn, verbose: bool = False)
Bases: probflow.callbacks.callback.Callback
Set the learning rate as a function of the current epoch
- Parameters
fn (callable) – Function which takes the current epoch as an argument and returns a learning rate.
verbose (bool) – Whether to print the learning rate each epoch (if True) or not (if False). Default = False
Examples
See the user guide section on Changing the learning rate over training.
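As with the KL weight, fn only has to map an epoch number to a learning rate. The sketch below is a step-decay schedule that halves the rate every 10 epochs; step_decay and its default values are hypothetical, and epochs are assumed to be counted from 0.

```python
def step_decay(epoch: int, initial_lr: float = 1e-3, drop: float = 0.5,
               epochs_per_drop: int = 10) -> float:
    """Hypothetical step-decay schedule.

    Returns initial_lr for epochs 0..epochs_per_drop-1, then multiplies the
    rate by `drop` after every further `epochs_per_drop` epochs.
    """
    return initial_lr * drop ** (epoch // epochs_per_drop)


for epoch in (0, 9, 10, 25):
    print(epoch, step_decay(epoch))
```

It would be passed as pf.callbacks.LearningRateScheduler(step_decay). Step decay is just one option; any function of the epoch that returns a learning rate fits the fn description.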
- plot(**kwargs)
Plot the learning rate as a function of epoch
- Parameters
**kwargs – Additional keyword arguments are passed to matplotlib.pyplot.plot
- on_epoch_end()
Will be called at the end of each training epoch. By default does nothing.
- on_train_end()
Will be called at the end of training. By default does nothing.
- on_train_start()
Will be called at the start of training. By default does nothing.
- class probflow.callbacks.MonitorELBO(verbose=False)
Bases: probflow.callbacks.callback.Callback
Monitor the ELBO on the training data
- Parameters
verbose (bool) – Whether to print the average ELBO at the end of every training epoch (if True) or not (if False). Default = False
Example
See the user guide section on Monitoring the loss.
- plot(x='epoch', **kwargs)
Plot the ELBO as a function of epoch
- Parameters
x (str {'epoch' or 'time'}) – Whether to plot the ELBO as a function of epoch or of wall time. Default is to plot by epoch.
**kwargs – Additional keyword arguments are passed to matplotlib.pyplot.plot
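To make the x='epoch' versus x='time' distinction concrete, the sketch below records one value per epoch together with the wall-clock time elapsed since recording began, and returns either series. MetricHistory and its injectable clock are hypothetical; this page does not describe how MonitorELBO stores its data internally. The recorded numbers are placeholders, not real ELBO values.

```python
import time
from typing import Callable, List, Tuple


class MetricHistory:
    """Hypothetical recorder illustrating plot(x='epoch') vs plot(x='time').

    Each record stores the epoch index, the wall-clock seconds elapsed since
    the recorder was created, and the recorded value.
    """

    def __init__(self, clock: Callable[[], float] = time.perf_counter) -> None:
        self._clock = clock
        self._start = clock()
        self.epochs: List[int] = []
        self.times: List[float] = []
        self.values: List[float] = []

    def record(self, value: float) -> None:
        self.epochs.append(len(self.values))  # 0-based epoch index
        self.times.append(self._clock() - self._start)
        self.values.append(value)

    def series(self, x: str = "epoch") -> Tuple[list, List[float]]:
        """Return (x values, recorded values), mirroring plot()'s x options."""
        if x == "epoch":
            return list(self.epochs), list(self.values)
        if x == "time":
            return list(self.times), list(self.values)
        raise ValueError("x must be 'epoch' or 'time'")


# Fake clock readings (seconds) so the output is deterministic.
ticks = iter([0.0, 1.5, 4.0])
history = MetricHistory(clock=lambda: next(ticks))
history.record(120.0)
history.record(95.5)
print(history.series("epoch"))  # ([0, 1], [120.0, 95.5])
print(history.series("time"))   # ([1.5, 4.0], [120.0, 95.5])
```

Plotting by time rather than by epoch is mainly useful when epochs take unequal amounts of time, or when comparing runs with different batch sizes.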
- on_train_end()
Will be called at the end of training. By default does nothing.
- on_train_start()
Will be called at the start of training. By default does nothing.
- class probflow.callbacks.MonitorMetric(metric, x, y=None, verbose=False)
Bases: probflow.callbacks.callback.Callback
Monitor some metric on validation data
- Parameters
metric (str) – Name of the metric to evaluate. See Model.metric() for a list of available metrics.
x (ndarray or DataFrame or Series or Tensor or DataGenerator) – Independent variable values of the validation dataset to evaluate (aka the “features”), or a DataGenerator that generates both x and y.
y (ndarray or DataFrame or Series or Tensor) – Dependent variable values of the validation dataset to evaluate (aka the “target”). Not needed when x is a DataGenerator. Default = None
verbose (bool) – Whether to print the value of the metric at the end of every training epoch (if True) or not (if False). Default = False
Example
See the user guide section on Monitoring a metric.
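The metric argument is a name that is looked up and evaluated on the validation data. The sketch below illustrates that idea with a small name-to-function table and two standard regression metrics. The METRICS table and the evaluate helper are hypothetical; the real list of names is the one documented under Model.metric().

```python
from typing import Callable, Dict, Sequence


def mse(y_true: Sequence[float], y_pred: Sequence[float]) -> float:
    """Mean squared error."""
    return sum((t - p) ** 2 for t, p in zip(y_true, y_pred)) / len(y_true)


def mae(y_true: Sequence[float], y_pred: Sequence[float]) -> float:
    """Mean absolute error."""
    return sum(abs(t - p) for t, p in zip(y_true, y_pred)) / len(y_true)


# Hypothetical name -> function table; ProbFlow's own metric names are
# listed under Model.metric() and may differ.
METRICS: Dict[str, Callable[[Sequence[float], Sequence[float]], float]] = {
    "mse": mse,
    "mae": mae,
}


def evaluate(metric: str, y_true: Sequence[float], y_pred: Sequence[float]) -> float:
    """Look up a metric by name and compute it on validation targets/predictions."""
    try:
        fn = METRICS[metric]
    except KeyError:
        raise ValueError(f"unknown metric {metric!r}") from None
    return fn(y_true, y_pred)


y_val = [1.0, 2.0, 3.0]
preds = [1.5, 2.0, 2.0]
print(evaluate("mse", y_val, preds))
print(evaluate("mae", y_val, preds))  # 0.5
```

With ProbFlow, the documented constructor would be called as pf.callbacks.MonitorMetric('mse', x_val, y_val), assuming 'mse' is among the names listed under Model.metric() and x_val, y_val hold held-out data.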
- on_train_end()
Will be called at the end of training. By default does nothing.
- on_train_start()
Will be called at the start of training. By default does nothing.
- class probflow.callbacks.MonitorParameter(params)
Bases: probflow.callbacks.callback.Callback
Monitor the mean value of Parameter(s) over the course of training
Examples
See the user guide section on Monitoring the value of parameter(s).
- on_epoch_start()
Will be called at the start of each training epoch. By default does nothing.
- on_train_end()
Will be called at the end of training. By default does nothing.
- on_train_start()
Will be called at the start of training. By default does nothing.
- class probflow.callbacks.TimeOut(time_limit, verbose=True)
Bases: probflow.callbacks.callback.Callback
Stop training after a certain amount of time
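- Parameters
time_limit – Maximum training time, in seconds.
verbose (bool) – Whether to print when training was stopped. Default = True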
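Example
Stop training after five hours (time_limit is given in seconds):
    import probflow as pf

    time_out = pf.callbacks.TimeOut(5*60*60)  # 5 hours = 18,000 seconds
    model.fit(x, y, callbacks=[time_out])  # assumes model, x, and y are already defined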
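The time-limit rule can be sketched as a deadline check against a monotonic clock. TimeBudget is a hypothetical helper: whether TimeOut compares with > or >=, and how often during training it checks, are assumptions. The injectable clock exists only to make the example deterministic.

```python
import time
from typing import Callable, Optional


class TimeBudget:
    """Hypothetical helper mirroring TimeOut's rule: training should stop once
    more than `time_limit` seconds have passed since training started."""

    def __init__(self, time_limit: float,
                 clock: Callable[[], float] = time.monotonic) -> None:
        self.time_limit = time_limit
        self._clock = clock
        self._start: Optional[float] = None

    def start(self) -> None:
        """Record the start time (the role on_train_start would play)."""
        self._start = self._clock()

    def expired(self) -> bool:
        """True once the elapsed time exceeds the limit."""
        if self._start is None:
            raise RuntimeError("start() must be called before expired()")
        return self._clock() - self._start > self.time_limit


# Fake clock readings (in seconds) so the example is deterministic.
readings = iter([100.0, 100.0 + 3600.0, 100.0 + 5 * 60 * 60 + 1.0])
budget = TimeBudget(5 * 60 * 60, clock=lambda: next(readings))
budget.start()
print(budget.expired())  # False: one hour elapsed
print(budget.expired())  # True: just over five hours elapsed
```

A monotonic clock is used rather than time.time() so that system clock adjustments during a long run cannot shorten or extend the budget.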
- on_train_end()
Will be called at the end of training. By default does nothing.
- on_train_start()
Will be called at the start of training. By default does nothing.