Source code for probflow.callbacks.monitor_metric

import time

import matplotlib.pyplot as plt
import numpy as np

from probflow.data import make_generator
from probflow.utils.metrics import get_metric_fn

from .callback import Callback


class MonitorMetric(Callback):
    """Monitor some metric on validation data

    At the end of every epoch the metric is evaluated on the validation
    data via ``self.model.metric`` and appended, together with the epoch
    number and the elapsed time, to the history lists kept on this object.

    Parameters
    ----------
    metric : str or callable
        Name of the metric to evaluate.  See :meth:`.Model.metric` for a
        list of available metrics.  Whatever is passed is resolved with
        ``get_metric_fn``; when it is not a string, the resolved function's
        ``__name__`` is used as the display name.
    x : |ndarray| or |DataFrame| or |Series| or Tensor or |DataGenerator|
        Independent variable values of the validation dataset to evaluate
        (aka the "features").  Or a |DataGenerator| to generate both x and
        y.
    y : |ndarray| or |DataFrame| or |Series| or Tensor
        Dependent variable values of the validation dataset to evaluate
        (aka the "target").
    verbose : bool
        Whether to print the value of the metric at the end of every
        training epoch (if True) or not (if False).  Default = False

    Attributes
    ----------
    current_metric
        Most recently computed metric value (``np.nan`` before the first
        epoch has ended).
    current_epoch : int
        Number of epochs that have ended so far.
    metrics : list
        Metric value recorded at the end of each epoch.
    epochs : list
        Epoch number (starting at 1) for each entry in ``metrics``.
    wall_times : list
        Seconds elapsed since ``start_time`` for each entry in ``metrics``.

    Example
    -------

    See the user guide section on :ref:`monitoring-a-metric`.
    """

    def __init__(self, metric, x, y=None, verbose=False):

        # Store metric
        self.metric_fn = get_metric_fn(metric)
        if isinstance(metric, str):
            self.metric_name = metric
        else:
            self.metric_name = self.metric_fn.__name__

        # Store validation data
        self.data = make_generator(x, y)

        # Store metrics and epochs
        self.current_metric = np.nan
        self.current_epoch = 0
        self.metrics = []
        self.epochs = []
        self.verbose = verbose

        # Timing: start_time stays None until the clock is started
        self.start_time = None
        self.wall_times = []

    def on_epoch_start(self):
        """Record start time at the beginning of the first epoch"""
        # Only the first call sets the reference time, so that wall times
        # are measured from the start of training, not from each epoch.
        if self.start_time is None:
            self.start_time = time.time()

    def on_epoch_end(self):
        """Compute the metric on validation data at the end of each epoch.

        Appends one entry each to ``metrics``, ``epochs`` and
        ``wall_times``, and prints a progress line when ``verbose`` is set.
        """
        # NOTE(review): self.model is not set in this class; presumably it
        # is attached by the Callback base class / training loop - confirm.
        self.current_metric = self.model.metric(self.metric_fn, self.data)
        self.current_epoch += 1

        now = time.time()
        if self.start_time is None:
            # on_epoch_start was never called: start the clock here rather
            # than failing with a TypeError on ``float - None``.  The first
            # recorded wall time is then 0.
            self.start_time = now

        self.metrics.append(self.current_metric)
        self.epochs.append(self.current_epoch)
        self.wall_times.append(now - self.start_time)

        if self.verbose:
            print(
                "Epoch {} \t{}: {}".format(
                    self.current_epoch, self.metric_name, self.current_metric
                )
            )

    def plot(self, x="epoch", **kwargs):
        """Plot the metric being monitored as a function of epoch

        Parameters
        ----------
        x : str {'epoch' or 'time'}
            Whether to plot the metric as a function of epoch or wall time.
            Default is to plot by epoch.  Any value other than ``'time'``
            is treated as ``'epoch'``.
        **kwargs
            Additional keyword arguments are passed to plt.plot
        """
        if x == "time":
            plt.plot(self.wall_times, self.metrics, **kwargs)
            plt.xlabel("Time (s)")
        else:
            plt.plot(self.epochs, self.metrics, **kwargs)
            plt.xlabel("Epoch")
        plt.ylabel(self.metric_name)