# Source code for probflow.parameters.dirichlet_parameter

from typing import List, Union

import probflow.utils.ops as O
from probflow.distributions import Dirichlet
from probflow.utils.initializers import pos_xavier

from .parameter import Parameter

class DirichletParameter(Parameter):
    r"""Dirichlet parameter.

    This is a convenience class for creating a parameter
    :math:`\theta` with a Dirichlet posterior:

    .. math::

        \theta \sim \text{Dirichlet}(\mathbf{\alpha})

    By default, a uniform Dirichlet prior (all concentrations equal to one)
    is used:

    .. math::

        \theta \sim \text{Dirichlet}_K(\mathbf{1})

    TODO: explain that a sample is a categorical prob dist (as compared to
    CategoricalParameter, where a sample is a single value)


    Parameters
    ----------
    k : int >= 2
        Number of categories.
        Default = ``2``
    shape : int or List[int]
        Shape of the array containing the parameters.  An ``int`` is treated
        as a one-element shape, and any other sequence of ints (e.g. a
        tuple) is converted to a list.  The underlying variable array has
        shape ``shape + [k]``.
        Default = ``[]``
    posterior : |Distribution| class
        Probability distribution class to use to approximate the posterior.
        Default = :class:`.Dirichlet`
    prior : |Distribution| object
        Prior probability distribution function which has been instantiated
        with parameters.
        Default = :class:`.Dirichlet` ``(1)``
    transform : callable
        Transform to apply to the random variable.
        Default is to use no transform.
    initializer : Dict[str, callable]
        Initializer functions to use for each variable of the variational
        posterior distribution.  Keys correspond to variable names
        (arguments to the distribution), and values contain functions to
        initialize those variables given ``shape`` as the single argument.
    var_transform : Dict[str, callable]
        Transform to apply to each variable of the variational posterior.
    name : str
        Name of the parameter(s).
        Default = ``'DirichletParameter'``


    Raises
    ------
    TypeError
        If ``k`` is not an integer.
    ValueError
        If ``k`` is less than 2.


    Examples
    --------

    TODO: creating variable

    """

    def __init__(
        self,
        k: int = 2,
        shape: Union[int, List[int]] = [],
        posterior=Dirichlet,
        prior=None,
        transform=None,
        initializer={"concentration": pos_xavier},
        var_transform={"concentration": O.softplus},
        name="DirichletParameter",
    ):

        # Check type of k
        if not isinstance(k, int):
            raise TypeError("k must be an integer")
        if k < 2:
            raise ValueError("k must be >1")

        # Normalise shape to a fresh list.  Copying (rather than reusing the
        # argument) means a tuple shape no longer fails on ``tuple + list``
        # and neither the caller's list nor the shared ``[]`` default is
        # ever aliased by the value handed to the base class.
        if isinstance(shape, int):
            shape = [shape]
        else:
            shape = list(shape)

        # Create shape of underlying variable array: the trailing axis holds
        # one concentration per category.
        shape = shape + [k]

        # Use a uniform prior (all-ones concentration) unless one was given
        if prior is None:
            prior = Dirichlet(O.ones(shape))

        # Initialize the parameter.  ``initializer`` and ``var_transform``
        # are forwarded as-is; this method never mutates them.
        super().__init__(
            shape=shape,
            posterior=posterior,
            prior=prior,
            transform=transform,
            initializer=initializer,
            var_transform=var_transform,
            name=name,
        )