Callbacks

The callbacks module contains classes for monitoring and adjusting the training process.


class probflow.callbacks.Callback(*args)[source]

Bases: probflow.utils.base.BaseCallback

Base class for all callbacks.

See the user guide section on Callbacks.

on_train_start()[source]

Will be called at the start of training. By default does nothing.

on_epoch_start()[source]

Will be called at the start of each training epoch. By default does nothing.

on_epoch_end()[source]

Will be called at the end of each training epoch. By default does nothing.

on_train_end()[source]

Will be called at the end of training. By default does nothing.
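The order in which the four hooks fire can be sketched as below. This is only an illustration, not ProbFlow's training loop: `HookBase` is a hypothetical stand-in for `Callback` (so the snippet runs without ProbFlow installed), and `EventLogger` and `run_training` are names invented here. In real use, one would presumably subclass `probflow.callbacks.Callback` and override only the hooks that are needed.

```python
class HookBase:
    """Hypothetical stand-in for probflow.callbacks.Callback (hooks only)."""

    def on_train_start(self):
        pass

    def on_epoch_start(self):
        pass

    def on_epoch_end(self):
        pass

    def on_train_end(self):
        pass


class EventLogger(HookBase):
    """Records which hooks were called, in order."""

    def __init__(self):
        self.events = []

    def on_train_start(self):
        self.events.append("train_start")

    def on_epoch_start(self):
        self.events.append("epoch_start")

    def on_epoch_end(self):
        self.events.append("epoch_end")

    def on_train_end(self):
        self.events.append("train_end")


def run_training(callbacks, epochs):
    """Assumed hook order; the actual fitting step is omitted."""
    for cb in callbacks:
        cb.on_train_start()
    for _ in range(epochs):
        for cb in callbacks:
            cb.on_epoch_start()
        # ... one pass over the training data would happen here ...
        for cb in callbacks:
            cb.on_epoch_end()
    for cb in callbacks:
        cb.on_train_end()


logger = EventLogger()
run_training([logger], epochs=2)
print(logger.events)
```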

class probflow.callbacks.EarlyStopping(metric_fn, patience=0, verbose=True, name='EarlyStopping')[source]

Bases: probflow.callbacks.callback.Callback

Stop training early when some metric stops decreasing

Parameters
  • metric_fn (callable, MonitorMetric, or MonitorELBO) – Any arbitrary function, or a MonitorMetric or MonitorELBO callback. Training is stopped when the value returned by that function stops decreasing (or, if metric_fn is a MonitorMetric or MonitorELBO, when the monitored metric or the ELBO stops decreasing).

  • patience (int) – Number of epochs to allow training to continue even if metric is not decreasing. Default is 0.

  • restore_best_weights (bool) – Whether or not to restore the weights from the best epoch after training is stopped. Default = False.

  • verbose (bool) – Whether to print a message when training is stopped. Default = True

  • name (str) – Name for this callback

Example

See the user guide section on Ending training when a metric stops improving.
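The effect of patience can be illustrated with a small stand-alone sketch. It is not ProbFlow's implementation: the function name `stopping_epoch` is hypothetical, and the rule it encodes (stop once the metric has failed to improve on its best value for more than `patience` consecutive epochs) is an assumption based on the parameter descriptions above.

```python
def stopping_epoch(metric_values, patience=0):
    """Return the index of the epoch at which training would stop, or None.

    Assumption: training stops once the metric has failed to improve on
    its best value for more than `patience` consecutive epochs.
    """
    best = float("inf")
    epochs_without_improvement = 0
    for epoch, value in enumerate(metric_values):
        if value < best:
            # New best value: remember it and reset the counter
            best = value
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement > patience:
                return epoch
    return None  # the metric kept improving often enough; never stopped


history = [3.0, 2.0, 2.5, 2.6, 1.0]  # made-up per-epoch metric values
print(stopping_epoch(history, patience=0))
print(stopping_epoch(history, patience=1))
print(stopping_epoch(history, patience=2))
```

With a larger patience, a temporary rise in the metric is tolerated, so the run in this made-up history gets the chance to reach its later, lower value.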

on_epoch_end()[source]

Stop training if there was no improvement since the last epoch.

on_epoch_start()

Will be called at the start of each training epoch. By default does nothing.

on_train_end()

Will be called at the end of training. By default does nothing.

on_train_start()

Will be called at the start of training. By default does nothing.

class probflow.callbacks.KLWeightScheduler(fn, verbose=False)[source]

Bases: probflow.callbacks.callback.Callback

Set the weight of the KL term’s contribution to the ELBO loss each epoch

Parameters
  • fn (callable) – Function which takes the current epoch as an argument and returns the KL weight, a float between 0 and 1

  • verbose (bool) – Whether to print the KL weight each epoch (if True) or not (if False). Default = False

Examples

See the user guide section on Changing the KL weight over training.
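As a rough sketch of what fn might look like, the function below ramps the KL weight linearly from 0 to 1 and then holds it at 1. The name `linear_kl_warmup` and the `warmup_epochs` parameter are hypothetical, and the sketch makes no assumption about whether ProbFlow counts epochs from 0 or 1.

```python
def linear_kl_warmup(epoch, warmup_epochs=50):
    """KL weight that rises linearly from 0 to 1, then stays at 1."""
    # Clamp to [0, 1] so the result is always a valid KL weight
    return min(1.0, max(0.0, epoch / warmup_epochs))


for epoch in (0, 25, 50, 200):
    print(epoch, linear_kl_warmup(epoch))
```

Such a function would then be passed as the fn argument, for example `KLWeightScheduler(linear_kl_warmup)`, and the resulting callback given to the model's fit method through its callbacks list.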

on_epoch_start()[source]

Set the KL weight at the start of each epoch.

plot(**kwargs)[source]

Plot the KL weight as a function of epoch

Parameters

**kwargs – Additional keyword arguments are passed to plt.plot

on_epoch_end()

Will be called at the end of each training epoch. By default does nothing.

on_train_end()

Will be called at the end of training. By default does nothing.

on_train_start()

Will be called at the start of training. By default does nothing.

class probflow.callbacks.LearningRateScheduler(fn, verbose: bool = False)[source]

Bases: probflow.callbacks.callback.Callback

Set the learning rate as a function of the current epoch

Parameters
  • fn (callable) – Function which takes the current epoch as an argument and returns a learning rate.

  • verbose (bool) – Whether to print the learning rate each epoch (if True) or not (if False). Default = False

Examples

See the user guide section on Changing the learning rate over training.
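One possible shape for fn is a step decay, sketched below. The function name `step_decay` and its default values are illustrative assumptions, not values recommended by ProbFlow.

```python
def step_decay(epoch, initial_lr=1e-2, drop=0.5, epochs_per_drop=10):
    """Learning rate multiplied by `drop` every `epochs_per_drop` epochs."""
    return initial_lr * drop ** (epoch // epochs_per_drop)


for epoch in (0, 9, 10, 25):
    print(epoch, step_decay(epoch))
```

Such a function would be passed as the fn argument, for example `LearningRateScheduler(step_decay)`. To use different settings, wrap it in a lambda, since fn receives only the epoch.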

on_epoch_start()[source]

Set the learning rate at the start of each epoch.

plot(**kwargs)[source]

Plot the learning rate as a function of epoch

Parameters

**kwargs – Additional keyword arguments are passed to matplotlib.pyplot.plot

on_epoch_end()

Will be called at the end of each training epoch. By default does nothing.

on_train_end()

Will be called at the end of training. By default does nothing.

on_train_start()

Will be called at the start of training. By default does nothing.

class probflow.callbacks.MonitorELBO(verbose=False)[source]

Bases: probflow.callbacks.callback.Callback

Monitor the ELBO on the training data

Parameters

verbose (bool) – Whether to print the average ELBO at the end of every training epoch (if True) or not (if False). Default = False

Example

See the user guide section on Monitoring the loss.

on_epoch_start()[source]

Record start time at the beginning of the first epoch

on_epoch_end()[source]

Store the ELBO at the end of each epoch.

plot(x='epoch', **kwargs)[source]

Plot the ELBO as a function of epoch

Parameters
  • x (str {'epoch' or 'time'}) – Whether to plot the ELBO as a function of epoch or wall time. Default is to plot by epoch.

  • **kwargs – Additional keyword arguments are passed to plt.plot

on_train_end()

Will be called at the end of training. By default does nothing.

on_train_start()

Will be called at the start of training. By default does nothing.

class probflow.callbacks.MonitorMetric(metric, x, y=None, verbose=False)[source]

Bases: probflow.callbacks.callback.Callback

Monitor some metric on validation data

Parameters
  • metric (str) – Name of the metric to evaluate. See Model.metric() for a list of available metrics.

  • x (ndarray or DataFrame or Series or Tensor or DataGenerator) – Independent variable values of the validation dataset to evaluate (aka the “features”). Or a DataGenerator to generate both x and y.

  • y (ndarray or DataFrame or Series or Tensor) – Dependent variable values of the validation dataset to evaluate (aka the “target”).

  • verbose (bool) – Whether to print the value of the metric at the end of every training epoch (if True) or not (if False). Default = False

Example

See the user guide section on Monitoring a metric.

on_epoch_start()[source]

Record start time at the beginning of the first epoch

on_epoch_end()[source]

Compute the metric on validation data at the end of each epoch.

on_train_end()

Will be called at the end of training. By default does nothing.

on_train_start()

Will be called at the start of training. By default does nothing.

plot(x='epoch', **kwargs)[source]

Plot the metric being monitored as a function of epoch

Parameters
  • x (str {'epoch' or 'time'}) – Whether to plot the metric as a function of epoch or wall time. Default is to plot by epoch.

  • **kwargs – Additional keyword arguments are passed to plt.plot

class probflow.callbacks.MonitorParameter(params)[source]

Bases: probflow.callbacks.callback.Callback

Monitor the mean value of Parameter(s) over the course of training

Parameters

params (str or List[str] or None) – Name(s) of the parameters to monitor.

Examples

See the user guide section on Monitoring the value of parameter(s).

on_epoch_end()[source]

Store mean values of Parameter(s) at the end of each epoch.

plot(param=None, **kwargs)[source]

Plot the parameter value(s) as a function of epoch

Parameters

param (None or str) – Parameter to plot. If None, assumes we’ve only been monitoring one parameter and plots that. If a str, plots the parameter with that name (assuming we’ve been monitoring it).

on_epoch_start()

Will be called at the start of each training epoch. By default does nothing.

on_train_end()

Will be called at the end of training. By default does nothing.

on_train_start()

Will be called at the start of training. By default does nothing.

class probflow.callbacks.TimeOut(time_limit, verbose=True)[source]

Bases: probflow.callbacks.callback.Callback

Stop training after a certain amount of time

Parameters
  • time_limit (float or int) – Number of seconds after which to stop training

  • verbose (bool) – Whether to print that we stopped training early (if True) or not (if False). Default = True

Example

Stop training after five hours:

import probflow as pf

time_out = pf.callbacks.TimeOut(5*60*60)  # time limit in seconds
model.fit(x, y, callbacks=[time_out])

on_epoch_start()[source]

Record start time at the beginning of the first epoch

on_epoch_end()[source]

Stop training if time limit has been passed

on_train_end()

Will be called at the end of training. By default does nothing.

on_train_start()

Will be called at the start of training. By default does nothing.