Source code for probflow.distributions.one_hot_categorical

from probflow.utils.base import BaseDistribution
from probflow.utils.settings import get_backend
from probflow.utils.validation import ensure_tensor_like


class OneHotCategorical(BaseDistribution):
    r"""The one-hot Categorical distribution, parameterized by categories-len vectors.

    A thin wrapper that builds the backend's ``OneHotCategorical``
    distribution (PyTorch or TensorFlow Probability, depending on the
    active backend) from either logits or probabilities.

    TODO: explain

    TODO: example image of the distribution

    TODO: logits vs probs

    Parameters
    ----------
    logits : int, float, |ndarray|, or Tensor
        Logit-transformed category probabilities
    probs : int, float, |ndarray|, or Tensor
        Category probabilities

    Notes
    -----
    At least one of ``logits`` or ``probs`` must be given. The backend
    distributions expect exactly one of the two to be specified.

    Raises
    ------
    TypeError
        If neither ``logits`` nor ``probs`` is specified. Any supplied
        argument is also passed to ``ensure_tensor_like``, which rejects
        values that are not tensor-like (exception type defined there;
        presumably TypeError -- confirm).
    """

    def __init__(self, logits=None, probs=None):

        # Check input
        if logits is None and probs is None:
            raise TypeError("either logits or probs must be specified")

        # Validate each argument independently, so that every value the
        # caller supplied is checked -- including when both are given.
        if logits is not None:
            ensure_tensor_like(logits, "logits")
        if probs is not None:
            ensure_tensor_like(probs, "probs")

        # Store args
        self.logits = logits
        self.probs = probs

    def __call__(self):
        """Get the distribution object from the backend.

        Returns
        -------
        ``torch.distributions.one_hot_categorical.OneHotCategorical`` when
        the active backend is ``"pytorch"``, otherwise
        ``tensorflow_probability.distributions.OneHotCategorical``.
        """
        # Backend libraries are imported inside each branch, presumably so
        # that only the active backend needs to be installed.
        # ``self[...]`` is resolved by BaseDistribution.__getitem__ (defined
        # elsewhere) -- presumably maps stored args to backend values; confirm.
        if get_backend() == "pytorch":
            import torch.distributions as tod

            return tod.one_hot_categorical.OneHotCategorical(
                logits=self["logits"], probs=self["probs"]
            )
        else:
            from tensorflow_probability import distributions as tfd

            return tfd.OneHotCategorical(
                logits=self["logits"], probs=self["probs"]
            )