Source code for probflow.utils.base

"""
The utils.base module contains abstract base classes (ABCs) for all of
ProbFlow’s classes.

"""


# Names exported by ``from probflow.utils.base import *``: the abstract base
# classes defined in this module.
__all__ = [
    "BaseDistribution", "BaseParameter", "BaseModule",
    "BaseDataGenerator", "BaseCallback",
]


from abc import ABC, abstractmethod
from math import ceil

from probflow.utils.casting import to_tensor
from probflow.utils.settings import get_backend


class BaseDistribution(ABC):
    """Abstract base class for ProbFlow Distributions.

    Subclasses implement ``__init__`` and ``__call__``; ``__call__`` is
    expected to return the backend distribution object that every other
    method here delegates to.  The backend is looked up with
    ``get_backend()`` on each call: ``"pytorch"`` takes dedicated code
    paths, any other backend uses the object's ``prob``/``mean()``/
    ``mode()``/``sample`` methods.
    """

    @abstractmethod
    def __init__(self, *args):
        """Initialize the distribution"""

    def __call__(self):
        """Get the distribution object from the backend.

        The base implementation returns ``None``; subclasses override it.
        """

    def __getitem__(self, key):
        """Look up attribute ``key``; if it is callable (e.g. a
        probflow.Parameter), return the result of calling it (a sample)."""
        attr = getattr(self, key)
        return attr() if callable(attr) else attr

    def prob(self, y):
        """Compute the probability of data ``y`` under this distribution."""
        if get_backend() != "pytorch":
            return self().prob(to_tensor(y))
        # The pytorch branch goes through log_prob and exponentiates it.
        return self().log_prob(to_tensor(y)).exp()

    def log_prob(self, y):
        """Compute the log probability of data ``y`` under this distribution."""
        dist = self()
        return dist.log_prob(to_tensor(y))

    def cdf(self, y):
        """Cumulative probability of data ``y`` along this distribution."""
        dist = self()
        return dist.cdf(to_tensor(y))

    def mean(self):
        """Compute the mean of this distribution.

        On non-pytorch backends, if the backend's ``mean()`` raises
        ``NotImplementedError`` (a distribution whose mean is undefined, for
        example a categorical distribution), the mode is returned instead.
        """
        if get_backend() == "pytorch":
            # The pytorch branch reads ``mean`` as an attribute, not a method.
            return self().mean
        try:
            return self().mean()
        except NotImplementedError:
            return self().mode()

    def mode(self):
        """Compute the mode of this distribution.

        Raises
        ------
        NotImplementedError
            Always, when the backend is ``"pytorch"``.
        """
        if get_backend() != "pytorch":
            return self().mode()
        raise NotImplementedError

    def sample(self, n=1):
        """Generate a random sample from this distribution.

        Parameters
        ----------
        n : int or shape
            ``n == 1`` (an int) calls the backend sampler with no arguments.
            On pytorch, other ints are passed as ``[n]`` and non-ints are
            passed through; on other backends anything else is passed as-is.
        """
        single = isinstance(n, int) and n == 1

        if get_backend() != "pytorch":
            return self().sample() if single else self().sample(n)

        def shape_args():
            # Built fresh for every call so each sampler gets its own list.
            if single:
                return ()
            if isinstance(n, int):
                return ([n],)
            return (n,)

        try:
            return self().rsample(*shape_args())
        except NotImplementedError:
            # Fall back to the non-reparameterized sampler when rsample
            # is not available for this distribution.
            return self().sample(*shape_args())
class BaseParameter(ABC):
    """Abstract base class for ProbFlow Parameters.

    Every method is abstract, so concrete parameters must implement the
    constructor, ``__call__``, ``kl_loss``, ``posterior_mean``,
    ``posterior_sample`` and ``prior_sample``.
    """

    @abstractmethod
    def __init__(self, *args):
        """Initialize the parameter"""

    @abstractmethod
    def __call__(self):
        """Return a sample from or the MAP estimate of this parameter."""

    @abstractmethod
    def kl_loss(self):
        """Compute the sum of the Kullback–Leibler divergences between this
        parameter's priors and its variational posteriors."""

    @abstractmethod
    def posterior_mean(self):
        """Get the mean of the posterior distribution(s)."""

    @abstractmethod
    def posterior_sample(self):
        """Draw sample(s) from the posterior distribution(s)."""

    @abstractmethod
    def prior_sample(self):
        """Draw sample(s) from the prior distribution(s)."""
class BaseModule(ABC):
    """Abstract base class for ProbFlow Modules.

    Both the constructor and the forward pass (``__call__``) are abstract;
    concrete modules must provide each of them.
    """

    @abstractmethod
    def __call__(self):
        """Run the module's forward pass (abstract)."""

    @abstractmethod
    def __init__(self, *args):
        """Set up the module (abstract)."""
class BaseDataGenerator(ABC):
    """Abstract base class for ProbFlow DataGenerators.

    Subclasses must implement ``__init__``, the ``n_samples`` and
    ``batch_size`` properties, ``__getitem__``, ``__iter__`` and
    ``__next__``.  ``__len__`` is derived from ``n_samples`` and
    ``batch_size``.  ``on_epoch_start`` / ``on_epoch_end`` are optional
    hooks whose base implementations do nothing.
    """

    @abstractmethod
    def __init__(self, *args):
        """Initialize the data generator"""

    def on_epoch_start(self):
        """Will be called at the start of each training epoch"""

    def on_epoch_end(self):
        """Will be called at the end of each training epoch"""

    @property
    @abstractmethod
    def n_samples(self):
        """Number of samples in the dataset"""

    @property
    @abstractmethod
    def batch_size(self):
        """Number of samples to generate each minibatch"""

    def __len__(self):
        """Number of batches per epoch, i.e. ``ceil(n_samples / batch_size)``.

        Computed with floor division on negated operands instead of
        ``ceil(a / b)``: true division goes through a float and silently
        loses precision once ``n_samples`` exceeds 2**53, whereas
        ``-(-a // b)`` is exact for integers.
        """
        return int(-(-self.n_samples // self.batch_size))

    @abstractmethod
    def __getitem__(self, index):
        """Generate one batch of data"""

    @abstractmethod
    def __iter__(self):
        """Get an iterator over batches"""

    @abstractmethod
    def __next__(self):
        """Get the next batch"""
class BaseCallback(ABC):
    """Abstract base class for ProbFlow Callbacks.

    The constructor and all three training hooks are abstract; none of them
    has a default implementation.
    """

    # Reference to the model (class-level default: None)
    model = None

    @abstractmethod
    def __init__(self, *args):
        """Set up the callback"""

    @abstractmethod
    def on_epoch_start(self):
        """Hook to be called when each training epoch begins"""

    @abstractmethod
    def on_epoch_end(self):
        """Hook to be called when each training epoch finishes"""

    @abstractmethod
    def on_train_end(self):
        """Hook to be called once training has finished"""