Source code for probflow.models.discrete_model

import matplotlib.pyplot as plt
import numpy as np

from probflow.utils.plotting import plot_discrete_dist

from .continuous_model import ContinuousModel

class DiscreteModel(ContinuousModel):
    """Abstract base class for probflow models where the dependent variable
    (the target) is discrete (e.g. drawn from a Poisson distribution).

    TODO : why use this over just Model

    This class inherits several methods from :class:`.Module`:

    * :attr:`~parameters`
    * :attr:`~modules`
    * :attr:`~trainable_variables`
    * :meth:`~kl_loss`
    * :meth:`~kl_loss_batch`
    * :meth:`~reset_kl_loss`
    * :meth:`~add_kl_loss`

    as well as several methods from :class:`.Model`:

    * :meth:`~log_likelihood`
    * :meth:`~train_step`
    * :meth:`~fit`
    * :meth:`~stop_training`
    * :meth:`~set_learning_rate`
    * :meth:`~predictive_sample`
    * :meth:`~aleatoric_sample`
    * :meth:`~epistemic_sample`
    * :meth:`~predict`
    * :meth:`~metric`
    * :meth:`~posterior_mean`
    * :meth:`~posterior_sample`
    * :meth:`~posterior_ci`
    * :meth:`~prior_sample`
    * :meth:`~posterior_plot`
    * :meth:`~prior_plot`
    * :meth:`~log_prob`
    * :meth:`~prob`
    * :meth:`~save`
    * :meth:`~summary`

    as well as several methods from :class:`.ContinuousModel`:

    * :meth:`~predictive_interval`
    * :meth:`~predictive_prc`
    * :meth:`~pred_dist_covered`
    * :meth:`~pred_dist_coverage`
    * :meth:`~coverage_by`
    * :meth:`~residuals`

    but overrides the following discrete-model-specific methods:

    * :meth:`~pred_dist_plot`
    * :meth:`~residuals_plot` (NOTE(review): no override is defined in this
      class body as shown -- confirm whether it is inherited unchanged)

    Note that :class:`.DiscreteModel` does *not* support :meth:`~r_squared`
    or :meth:`~r_squared_plot`: both are overridden here to raise a
    ``RuntimeError``.

    Example
    -------

    TODO

    """

    def pred_dist_plot(self, x, n=10000, cols=1, **kwargs):
        """Plot posterior predictive distribution from the model given ``x``.

        Draws ``n`` samples from the predictive distribution and plots one
        discrete distribution per datapoint, each in its own subplot.

        Parameters
        ----------
        x : |ndarray| or |DataFrame| or |Series| or |DataGenerator|
            Independent variable values of the dataset to evaluate (aka the
            "features").
        n : int
            Number of samples to draw from the model given ``x``.
            Default = 10000
        cols : int
            Divide the subplots into a grid with this many columns.
            Default = 1
        **kwargs
            Intended as additional keyword arguments for
            :func:`.plot_discrete_dist`.  NOTE(review): they are currently
            accepted but not forwarded -- confirm the signature of
            :func:`.plot_discrete_dist` before passing them through.

        Raises
        ------
        NotImplementedError
            If the predictive samples have any trailing dimension (beyond
            the sample and datapoint axes) of size greater than 1.
        """

        # Sample from the predictive distribution
        samples = self.predictive_sample(x, n=n)

        # Dependent variable must be scalar: the first axis indexes samples,
        # the second indexes datapoints, any further axes must have size 1
        Ns = samples.shape[0]
        N = samples.shape[1]
        if samples.ndim > 2 and any(e > 1 for e in samples.shape[2:]):
            raise NotImplementedError(
                "only scalar dependent variables are supported"
            )
        samples = samples.reshape([Ns, N])

        # np.ceil returns a float, but the subplot grid needs an integer
        # number of rows
        rows = int(np.ceil(N / cols))

        # Plot the predictive distributions, one subplot per datapoint
        for i in range(N):
            plt.subplot(rows, cols, i + 1)
            plot_discrete_dist(samples[:, i])
            plt.xlabel("Datapoint " + str(i))
        plt.tight_layout()

    def r_squared(self, *args, **kwargs):
        """Cannot compute R squared for a discrete model

        Raises
        ------
        RuntimeError
            Always, regardless of the arguments passed.
        """
        raise RuntimeError("Cannot compute R squared for a discrete model")

    def r_squared_plot(self, *args, **kwargs):
        """Cannot compute R squared for a discrete model

        Raises
        ------
        RuntimeError
            Always, regardless of the arguments passed.
        """
        raise RuntimeError("Cannot compute R squared for a discrete model")