import time
import matplotlib.pyplot as plt
import numpy as np
from .callback import Callback
class MonitorELBO(Callback):
    """Monitor the ELBO on the training data.

    At the end of every epoch this callback records the current ELBO, the
    epoch number, and the wall time elapsed since the first epoch started,
    so that the training curve can be inspected or plotted afterwards.

    Parameters
    ----------
    verbose : bool
        Whether to print the average ELBO at the end of every training epoch
        (if True) or not (if False).  Default = False

    Example
    -------
    See the user guide section on :ref:`monitoring-the-loss`.
    """

    def __init__(self, verbose=False):
        # Most recently recorded ELBO (NaN until the first epoch ends).
        self.current_elbo = np.nan
        # Number of epochs that have finished so far.
        self.current_epoch = 0
        # Per-epoch history; the three lists are appended to in lockstep.
        self.elbos = []
        self.epochs = []
        self.wall_times = []
        self.verbose = verbose
        # Timestamp of the first epoch start; None until training begins.
        self.start_time = None

    def on_epoch_start(self):
        """Record start time at the beginning of the first epoch."""
        if self.start_time is None:
            self.start_time = time.time()

    def on_epoch_end(self):
        """Store the ELBO at the end of each epoch.

        Appends the current ELBO, the epoch number, and the elapsed wall
        time (in seconds) to the history lists, and prints the ELBO when
        ``verbose`` is True.
        """
        # NOTE(review): ``self.model`` is not set anywhere in this class;
        # presumably it is attached by the Callback base class or the
        # training loop -- confirm.
        self.current_elbo = self.model.get_elbo()
        self.current_epoch += 1

        # Guard against on_epoch_start never having been called: without
        # this, ``time.time() - None`` would raise a TypeError.  Starting
        # the clock here makes the first recorded wall time ~0 seconds.
        if self.start_time is None:
            self.start_time = time.time()

        self.elbos.append(self.current_elbo)
        self.epochs.append(self.current_epoch)
        self.wall_times.append(time.time() - self.start_time)

        if self.verbose:
            print(
                "Epoch {} \tELBO: {}".format(
                    self.current_epoch, self.current_elbo
                )
            )

    def plot(self, x="epoch", **kwargs):
        """Plot the ELBO as a function of epoch or wall time.

        Parameters
        ----------
        x : str {'epoch' or 'time'}
            Whether to plot the metric as a function of epoch or wall time.
            Default is to plot by epoch.  Any value other than ``'time'``
            is treated as ``'epoch'``.
        **kwargs
            Additional keyword arguments are passed to plt.plot
        """
        if x == "time":
            plt.plot(self.wall_times, self.elbos, **kwargs)
            plt.xlabel("Time (s)")
        else:
            plt.plot(self.epochs, self.elbos, **kwargs)
            plt.xlabel("Epoch")
        plt.ylabel("ELBO")