Utils

The utils module contains utility classes, functions, and settings which ProbFlow uses internally. The sub-modules of utils are:

  • settings - backend, datatype, and sampling settings

  • base - abstract base classes for ProbFlow objects

  • ops - backend-independent mathematical operations

  • casting - backend-independent casting operations

  • initializers - backend-independent variable initializer functions

  • io - functions for loading and saving models

  • metrics - functions for computing various model performance metrics

  • plotting - functions for plotting distributions, posteriors, etc.

  • torch_distributions - manual implementations of distributions missing from PyTorch

  • validation - functions for data type validation

Settings

The utils.settings module contains global settings about the backend to use, what sampling method to use, the default device, and default datatype.

Backend

Which backend to use. Can be either TensorFlow 2.0 or PyTorch.

Datatype

Which datatype to use as the default for parameters. Depending on your model, you might have to set the default datatype to match the datatype of your data.

Samples

Whether and how many samples to draw from parameter posterior distributions. If None, the maximum a posteriori estimate of each parameter will be used. If an integer greater than 0, that many samples from each parameter’s posterior distribution will be used.

Flipout

Whether to use Flipout where possible.

Static posterior sampling

Whether or not to use static posterior sampling (i.e., take a random sample from the posterior, but take the same random sample on repeated calls), and control the UUID of the current static sampling regime.

Sampling context manager

A context manager which controls how Parameters sample from their variational distributions while inside the context manager.

probflow.utils.settings.get_backend()

Get which backend is currently being used.

Returns

backend – The current backend

Return type

str {'tensorflow' or 'pytorch'}

probflow.utils.settings.set_backend(backend)

Set which backend is currently being used.

Parameters

backend (str {'tensorflow' or 'pytorch'}) – The backend to use

probflow.utils.settings.get_datatype()

Get the default datatype used for Tensors

Returns

dtype – The current default datatype

Return type

tf.dtype or torch.dtype

probflow.utils.settings.set_datatype(datatype)

Set the datatype to use for Tensors

Parameters

datatype (tf.dtype or torch.dtype) – The default datatype to use

probflow.utils.settings.get_samples()

Get how many samples (if any) are being drawn from parameter posteriors

Returns

n – Number of samples (if any) to draw from parameters’ posteriors. Default = None (i.e., use the maximum a posteriori estimate)

Return type

None or int > 0

probflow.utils.settings.set_samples(samples)

Set how many samples (if any) to draw from parameter posteriors

Parameters

samples (None or int > 0) – Number of samples (if any) to draw from parameters’ posteriors.

probflow.utils.settings.get_flipout()

Get whether flipout is currently being used where possible.

Returns

flipout – Whether flipout is currently being used where possible while sampling during training.

Return type

bool

probflow.utils.settings.set_flipout(flipout)

Set whether to use flipout where possible while sampling during training

Parameters

flipout (bool) – Whether to use flipout where possible while sampling during training.

probflow.utils.settings.get_static_sampling_uuid()

Get the current static sampling UUID

probflow.utils.settings.set_static_sampling_uuid(uuid_value)

Set the current static sampling UUID

class probflow.utils.settings.Sampling(n=None, flipout=None, static=None)

Bases: object

Use sampling while within this context manager.

Keyword Arguments
  • n (None or int > 0) – Number of samples (if any) to draw from parameters’ posteriors. Default = 1

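  • flipout (bool) – Whether to use flipout where possible while sampling during training. Default = False

  • static (bool) – Whether to use static sampling, i.e. return the same sample(s) on repeated calls while inside this context manager.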

Example

To use maximum a posteriori estimates of the parameter values, don’t use the sampling context manager:

>>> import probflow as pf
>>> param = pf.Parameter()
>>> param()
[0.07226744]
>>> param() # MAP estimate is always the same
[0.07226744]

To use a single sample, use the sampling context manager with n=1:

>>> with pf.Sampling(n=1):
...     param()
[-2.2228503]
>>> with pf.Sampling(n=1):
...     param()  # samples are different
[1.3473024]

To use multiple samples, use the sampling context manager and set the number of samples to take with the n keyword argument:

>>> with pf.Sampling(n=3):
...     param()
[[ 0.10457394]
 [ 0.14018342]
 [-1.8649881 ]]
>>> with pf.Sampling(n=5):
...     param()
[[ 2.1035051]
 [-2.641631 ]
 [-2.9091313]
 [ 3.5294306]
 [ 1.6596333]]

To use static samples - that is, to always return the same samples while in the same context manager - use the sampling context manager with the static keyword argument set to True:

>>> with pf.Sampling(static=True):
...     param()
...     param()  # repeated samples yield the same value
[0.10457394]
[0.10457394]
>>> with pf.Sampling(static=True):
...     param()  # under a new context manager they yield new samples
...     param()  # but remain the same while under the same context
[-2.641631]
[-2.641631]
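The settings above amount to module-level state plus a context manager that temporarily changes it. The plain-Python sketch below illustrates that pattern under stated assumptions: the `_SETTINGS` dict and the `ToyParameter` class are hypothetical, this `Sampling` restores the enclosing settings on exit (the real class may instead reset them to defaults), and caching static samples by UUID is one plausible way to get the documented behavior, not necessarily ProbFlow's.

```python
import random
import uuid

# Hypothetical global settings (a stand-in for probflow.utils.settings)
_SETTINGS = {"samples": None, "flipout": False, "static_uuid": None}


def get_samples():
    return _SETTINGS["samples"]


def set_samples(samples):
    if samples is not None and (not isinstance(samples, int) or samples < 1):
        raise ValueError("samples must be None or an int > 0")
    _SETTINGS["samples"] = samples


def get_flipout():
    return _SETTINGS["flipout"]


def get_static_sampling_uuid():
    return _SETTINGS["static_uuid"]


def set_static_sampling_uuid(uuid_value):
    _SETTINGS["static_uuid"] = uuid_value


class Sampling:
    """Sketch of a sampling context manager (not ProbFlow's source)."""

    def __init__(self, n=None, flipout=None, static=None):
        self._n, self._flipout, self._static = n, flipout, static
        self._saved = None

    def __enter__(self):
        self._saved = dict(_SETTINGS)  # remember the enclosing settings
        set_samples(1 if self._n is None else self._n)  # n defaults to 1
        if self._flipout is not None:
            _SETTINGS["flipout"] = bool(self._flipout)
        if self._static:
            # A fresh UUID starts a new static sampling regime
            set_static_sampling_uuid(uuid.uuid4())
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        _SETTINGS.update(self._saved)  # restore the enclosing settings
        return False  # do not suppress exceptions


class ToyParameter:
    """Hypothetical scalar parameter with a Normal(loc, 1) posterior."""

    def __init__(self, loc=0.0, seed=None):
        self.loc = loc
        self._rng = random.Random(seed)
        self._cache = (None, None)  # (static UUID, cached samples)

    def __call__(self):
        n = get_samples()
        if n is None:
            return [self.loc]  # MAP estimate: the same value every call
        key = get_static_sampling_uuid()
        if key is not None and self._cache[0] == key:
            return self._cache[1]  # same static regime: reuse the samples
        samples = [self._rng.gauss(self.loc, 1.0) for _ in range(n)]
        self._cache = (key, samples)
        return samples
```

Keying the cache on the UUID means a parameter never has to be told explicitly when a static regime ends: a new `Sampling(static=True)` simply produces a key it has not seen.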

Base

The utils.base module contains abstract base classes (ABCs) for all of ProbFlow’s classes.

class probflow.utils.base.BaseDistribution(*args)

Bases: abc.ABC

Abstract base class for ProbFlow Distributions

prob(y)

Compute the probability of some data given this distribution

log_prob(y)

Compute the log probability of some data given this distribution

cdf(y)

Compute the cumulative probability of some data under this distribution

mean()

Compute the mean of this distribution

Note that this uses the mode of distributions for which the mean is undefined (for example, a categorical distribution)

mode()

Compute the mode of this distribution

sample(n=1)

Generate a random sample from this distribution

class probflow.utils.base.BaseParameter(*args)

Bases: abc.ABC

Abstract base class for ProbFlow Parameters

abstract kl_loss()

Compute the sum of the Kullback–Leibler divergences between this parameter’s priors and its variational posteriors.

abstract posterior_mean()

Get the mean of the posterior distribution(s).

abstract posterior_sample()

Draw a sample from the posterior distribution(s).

abstract prior_sample()

Draw a sample from the prior distribution(s).

class probflow.utils.base.BaseModule(*args)

Bases: abc.ABC

Abstract base class for ProbFlow Modules

class probflow.utils.base.BaseDataGenerator(*args)

Bases: abc.ABC

Abstract base class for ProbFlow DataGenerators

on_epoch_start()

Will be called at the start of each training epoch

on_epoch_end()

Will be called at the end of each training epoch

abstract property n_samples

Number of samples in the dataset

abstract property batch_size

Number of samples in each minibatch

class probflow.utils.base.BaseCallback(*args)

Bases: abc.ABC

Abstract base class for ProbFlow Callbacks

abstract on_epoch_start()

Will be called at the start of each training epoch

abstract on_epoch_end()

Will be called at the end of each training epoch

abstract on_train_end()

Will be called at the end of training
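Because these are abstract base classes, a subclass must implement every abstract method before it can be instantiated. The sketch below uses a local stand-in that mirrors the three abstract hooks documented for `BaseCallback` (it is not imported from ProbFlow), plus a hypothetical `EpochCounter` callback; the loop that drives the hooks is a simulation, not ProbFlow's training loop.

```python
from abc import ABC, abstractmethod


class BaseCallback(ABC):
    """Local stand-in mirroring the documented abstract hooks."""

    @abstractmethod
    def on_epoch_start(self):
        """Called at the start of each training epoch."""

    @abstractmethod
    def on_epoch_end(self):
        """Called at the end of each training epoch."""

    @abstractmethod
    def on_train_end(self):
        """Called at the end of training."""


class EpochCounter(BaseCallback):
    """Hypothetical callback that counts completed epochs."""

    def __init__(self):
        self.epochs = 0
        self.finished = False

    def on_epoch_start(self):
        pass  # nothing to do before an epoch

    def on_epoch_end(self):
        self.epochs += 1

    def on_train_end(self):
        self.finished = True


class Incomplete(BaseCallback):
    """Does not implement on_train_end, so it remains abstract."""

    def on_epoch_start(self):
        pass

    def on_epoch_end(self):
        pass


# Simulated three-epoch training loop driving the hooks
cb = EpochCounter()
for _ in range(3):
    cb.on_epoch_start()
    cb.on_epoch_end()
cb.on_train_end()
```

Calling `Incomplete()` raises `TypeError`, which is how an ABC flags a callback that forgot one of its hooks.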

Ops

The utils.ops module contains operations which run using the current backend.


probflow.utils.ops.kl_divergence(P, Q)

Compute the Kullback–Leibler divergence between two distributions.

Parameters
  • P – The first distribution

  • Q – The second distribution

Returns

kld – The Kullback–Leibler divergence between P and Q (KL(P || Q))

Return type

Tensor
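As a worked instance of KL(P || Q), the divergence between two univariate normals has the closed form \(\log(\sigma_Q/\sigma_P) + \frac{\sigma_P^2 + (\mu_P - \mu_Q)^2}{2\sigma_Q^2} - \frac{1}{2}\). The plain-Python sketch below evaluates it; `kl_normal` is a hypothetical helper for illustration, whereas `kl_divergence` itself takes backend distribution objects.

```python
import math


def kl_normal(mu_p, sigma_p, mu_q, sigma_q):
    """KL(P || Q) for P = Normal(mu_p, sigma_p), Q = Normal(mu_q, sigma_q)."""
    return (
        math.log(sigma_q / sigma_p)
        + (sigma_p**2 + (mu_p - mu_q) ** 2) / (2.0 * sigma_q**2)
        - 0.5
    )


print(kl_normal(0.0, 1.0, 0.0, 1.0))  # identical distributions → 0.0
print(kl_normal(1.0, 1.0, 0.0, 1.0))  # mean shifted by one std → 0.5
```

Swapping the arguments generally gives a different value, since the KL divergence is not symmetric.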

probflow.utils.ops.ones(shape)

Tensor full of ones.

probflow.utils.ops.zeros(shape)

Tensor full of zeros.

probflow.utils.ops.full(shape, value)

Tensor full of some value.

probflow.utils.ops.randn(shape)

Tensor full of random values drawn from a standard normal.

probflow.utils.ops.rand_rademacher(shape)

Tensor full of random -1s or 1s (i.e. drawn from a Rademacher dist).

probflow.utils.ops.shape(x)

Get a list of integers representing this tensor’s shape

probflow.utils.ops.eye(dims)

Identity matrix.

probflow.utils.ops.sum(val, axis=-1, keepdims=False)

The sum.

probflow.utils.ops.prod(val, axis=-1, keepdims=False)

The product.

probflow.utils.ops.mean(val, axis=-1, keepdims=False)

The mean.

probflow.utils.ops.std(val, axis=-1, keepdims=False)

The uncorrected sample standard deviation.
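"Uncorrected" means the squared deviations are averaged over N rather than N − 1 (no Bessel correction). Python's standard `statistics` module exposes both conventions, so the difference is easy to check; this comparison is illustrative and does not call ProbFlow.

```python
import statistics

vals = [1.0, 2.0, 3.0, 4.0]

uncorrected = statistics.pstdev(vals)  # sqrt(sum((v - mean)**2) / N)
corrected = statistics.stdev(vals)     # sqrt(sum((v - mean)**2) / (N - 1))
print(round(uncorrected, 4), round(corrected, 4))  # → 1.118 1.291
```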

probflow.utils.ops.round(val)

Round to the closest integer

probflow.utils.ops.abs(val)

Absolute value

probflow.utils.ops.square(val)

Power of 2

probflow.utils.ops.sqrt(val)

The square root.

probflow.utils.ops.exp(val)

The natural exponential function (e raised to the power val).

probflow.utils.ops.relu(val)

Linear rectification.

probflow.utils.ops.softplus(val)

The softplus function, log(1 + exp(val)), a smooth approximation to linear rectification.

probflow.utils.ops.sigmoid(val)

Sigmoid function.
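For reference, the element-wise definitions behind `relu`, `softplus`, and `sigmoid` are sketched below for scalar floats in plain Python (the ProbFlow functions operate on backend tensors instead). The overflow-safe rewrites of softplus and sigmoid are implementation choices of this sketch.

```python
import math


def relu(x):
    return max(0.0, x)  # linear rectification


def softplus(x):
    # log(1 + e^x), rewritten as max(x, 0) + log1p(e^-|x|) to avoid overflow
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def sigmoid(x):
    # 1 / (1 + e^-x), evaluated so that exp never overflows
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)
```

Unlike `relu`, softplus is smooth everywhere and (mathematically) strictly positive, while still approaching `x` for large positive inputs.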

probflow.utils.ops.gather(vals, inds, axis=0)

Gather values by index

probflow.utils.ops.cat(vals, axis=0)

Concatenate tensors

probflow.utils.ops.additive_logistic_transform(vals)

The additive logistic transformation

probflow.utils.ops.insert_col_of(vals, val)

Add a column of a value to the left side of a tensor
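The additive logistic transformation maps K unconstrained values to K + 1 probabilities by fixing a reference entry at 0 and normalizing the exponentials: \(p_0 = 1/(1 + \sum_j e^{x_j})\) and \(p_i = e^{x_i}/(1 + \sum_j e^{x_j})\). The plain-Python sketch below works on a single row; putting the reference entry on the left mirrors `insert_col_of`, but how ProbFlow actually combines the two functions is an assumption here.

```python
import math


def insert_col_of(row, val):
    """Prepend val to a single row (list), like adding a left column."""
    return [val] + list(row)


def additive_logistic_transform(row):
    """Map K unconstrained values to K + 1 probabilities that sum to 1."""
    padded = insert_col_of(row, 0.0)  # reference category fixed at 0
    m = max(padded)  # subtract the max for numerical stability
    exps = [math.exp(v - m) for v in padded]
    total = sum(exps)
    return [e / total for e in exps]


print(additive_logistic_transform([0.0, 0.0]))  # three equal probabilities of 1/3
```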

probflow.utils.ops.new_variable(initial_values)

Get a new variable with the current backend, and initialize it

probflow.utils.ops.log_cholesky_transform(x)

Perform the log-Cholesky transform on a vector of values.

This turns a vector of \(\frac{N(N+1)}{2}\) unconstrained values into a valid \(N \times N\) covariance matrix.
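A plain-Python sketch of the idea, under assumptions: the vector fills the lower-triangular factor \(L\) row by row (the ordering ProbFlow uses is not stated here), the diagonal entries are exponentiated so they are positive, and the result is \(\Sigma = L L^\top\), which is symmetric positive definite for any real input.

```python
import math


def log_cholesky_transform(x):
    """Map N(N+1)/2 unconstrained values to an N x N covariance matrix."""
    m = len(x)
    n = (math.isqrt(8 * m + 1) - 1) // 2  # solve N(N+1)/2 = m for N
    if n * (n + 1) // 2 != m:
        raise ValueError("len(x) must equal N(N+1)/2 for some integer N")
    # Fill the lower triangle of L row by row (assumed ordering)
    L = [[0.0] * n for _ in range(n)]
    it = iter(x)
    for i in range(n):
        for j in range(i + 1):
            v = next(it)
            L[i][j] = math.exp(v) if i == j else v  # exp keeps the diagonal > 0
    # Sigma = L @ L^T
    return [[sum(L[i][k] * L[j][k] for k in range(n)) for j in range(n)]
            for i in range(n)]


print(log_cholesky_transform([0.0, 0.5, 0.0]))  # → [[1.0, 0.5], [0.5, 1.25]]
```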

probflow.utils.ops.copy_tensor(x)

Copy a tensor, detaching the copy from the current backend’s gradient computation

Casting

The utils.casting module contains functions for casting back and forth between Tensors and numpy arrays.


probflow.utils.casting.to_numpy(x)

Convert tensor to numpy array

probflow.utils.casting.to_tensor(x)

Make x a tensor if not already

Initializers

Functions to initialize posterior distribution variables.


probflow.utils.initializers.xavier(shape)

Xavier initializer

probflow.utils.initializers.scale_xavier(shape)

Xavier initializer for scale variables

probflow.utils.initializers.pos_xavier(shape)

Xavier initializer for positive variables

probflow.utils.initializers.full_of(val)

Get an initializer which returns a tensor full of a single value
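A plain-Python sketch of two of these initializers, under assumptions: tensors are represented as nested lists, and `xavier` draws from a normal with variance \(2/\sum_i d_i\) over the dimensions \(d_i\) of `shape` (for a 2-D weight matrix this is the Glorot/Xavier normal variance \(2/(n_{in} + n_{out})\)). ProbFlow's exact scaling may differ, and the offsets used by `scale_xavier` and `pos_xavier` are not reproduced.

```python
import math
import random


def _fill(shape, draw):
    """Build a nested list of the given shape, calling draw() per element."""
    if len(shape) == 1:
        return [draw() for _ in range(shape[0])]
    return [_fill(shape[1:], draw) for _ in range(shape[0])]


def xavier(shape, rng=random):
    """Normal values with variance 2 / sum(shape) (Glorot-style)."""
    std = math.sqrt(2.0 / sum(shape))
    return _fill(shape, lambda: rng.gauss(0.0, std))


def full_of(val):
    """Return an initializer that fills any requested shape with val."""
    def init(shape):
        return _fill(shape, lambda: val)
    return init


zeros_init = full_of(0.0)
print(zeros_init([2, 3]))  # → [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]
```

Returning a closure from `full_of` fixes the value up front while the shape is supplied later, which is why it returns an initializer rather than a tensor.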

IO

Functions for saving and loading ProbFlow objects

probflow.utils.io.dumps(obj)

Serialize a ProbFlow object to a JSON-safe string.

Note

This removes the compiled _train_fn attribute of a Model, which is a compiled TensorFlow or PyTorch function that performs a single training step. Cloudpickle can’t serialize it, and after deserialization it is simply JIT re-compiled when needed.

probflow.utils.io.loads(s)

Deserialize a ProbFlow object from a string

probflow.utils.io.dump(obj, filename)

Serialize a ProbFlow object to a file

Note

This removes the compiled _train_fn attribute of a Model, which is a compiled TensorFlow or PyTorch function that performs a single training step. Cloudpickle can’t serialize it, and after deserialization it is simply JIT re-compiled when needed.

probflow.utils.io.load(filename)

Deserialize a ProbFlow object from a file
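The serialization notes above describe a common pattern: drop an attribute that cannot be pickled, serialize the rest, and rebuild the dropped piece lazily. The sketch below illustrates it with the standard-library `pickle` and `base64` modules in place of cloudpickle, a `types.SimpleNamespace` standing in for a Model, and a hypothetical `get_train_fn` helper; it is not ProbFlow's implementation.

```python
import base64
import pickle
from types import SimpleNamespace


def dumps(obj):
    """Serialize obj to an ASCII (JSON-safe) string, omitting _train_fn."""
    state = {k: v for k, v in vars(obj).items() if k != "_train_fn"}
    payload = pickle.dumps(SimpleNamespace(**state))
    return base64.b64encode(payload).decode("ascii")


def loads(s):
    """Inverse of dumps: decode the string and unpickle the object."""
    return pickle.loads(base64.b64decode(s.encode("ascii")))


def get_train_fn(model):
    """Hypothetical lazy (re)build of a model's training-step function."""
    if getattr(model, "_train_fn", None) is None:
        # A lambda stands in for a compiled function: pickle cannot serialize it
        model._train_fn = lambda x: [w * x for w in model.weights]
    return model._train_fn


model = SimpleNamespace(weights=[0.5, -1.0])
get_train_fn(model)  # model now holds an unpicklable function
s = dumps(model)     # works because _train_fn is left out
restored = loads(s)
print(hasattr(restored, "_train_fn"))  # → False
print(get_train_fn(restored)(2.0))     # → [1.0, -2.0]
```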

Metrics

Evaluation metrics

  • log_prob()

  • acc()

  • accuracy()

  • mse()

  • sse()

  • mae()


probflow.utils.metrics.get_metric_fn(metric)

Get a function corresponding to a metric string
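A plain-Python sketch of a string-to-function lookup in the spirit of `get_metric_fn`, under assumptions: each metric takes `(y_true, y_pred)` sequences of point predictions, `acc` is an alias for `accuracy`, and `log_prob` is omitted because it needs a predictive distribution rather than point predictions. ProbFlow's real metric signatures may differ.

```python
def accuracy(y_true, y_pred):
    """Fraction of predictions that exactly match the targets."""
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)


def sse(y_true, y_pred):
    """Sum of squared errors."""
    return sum((t - p) ** 2 for t, p in zip(y_true, y_pred))


def mse(y_true, y_pred):
    """Mean squared error."""
    return sse(y_true, y_pred) / len(y_true)


def mae(y_true, y_pred):
    """Mean absolute error."""
    return sum(abs(t - p) for t, p in zip(y_true, y_pred)) / len(y_true)


_METRICS = {"acc": accuracy, "accuracy": accuracy, "mse": mse, "sse": sse, "mae": mae}


def get_metric_fn(metric):
    """Look up a metric function by name (hypothetical sketch)."""
    try:
        return _METRICS[metric]
    except KeyError:
        raise ValueError(f"unknown metric: {metric!r}") from None


print(get_metric_fn("mse")([1.0, 2.0], [1.0, 4.0]))  # → 2.0
```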

Plotting

Plotting utilities.


probflow.utils.plotting.approx_kde(data, bins=500, bw=0.075)

A fast approximation to kernel density estimation.
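One common way to approximate a kernel density estimate quickly is to histogram the data and then convolve the bin counts with a discretized Gaussian kernel, which avoids evaluating a kernel at every data point. The plain-Python sketch below takes that approach under assumptions: `bw` is treated as the kernel standard deviation as a fraction of the data range, and the kernel is truncated at three standard deviations. ProbFlow's actual implementation and its exact use of `bw` may differ.

```python
import math
import random


def approx_kde(data, bins=500, bw=0.075):
    """Histogram the data, then smooth the counts with a Gaussian kernel."""
    lo, hi = min(data), max(data)
    width = (hi - lo) / bins or 1.0  # bin width (guard against constant data)
    sigma = max(bw * bins, 1e-12)  # kernel std in bins: bw * range / width
    half = math.ceil(3 * sigma)  # kernel half-width in bins (3 std devs)

    # 1. Histogram into `bins` equal-width bins, zero-padded on both sides
    counts = [0.0] * (bins + 2 * half)
    for x in data:
        i = min(int((x - lo) / width), bins - 1)  # the max value -> last bin
        counts[i + half] += 1.0

    # 2. Convolve the counts with a discretized Gaussian kernel
    kernel = [math.exp(-0.5 * (k / sigma) ** 2) for k in range(-half, half + 1)]
    smooth = [0.0] * len(counts)
    for i, c in enumerate(counts):
        if c:  # only data bins are nonzero, so the padding absorbs the tails
            for k, w in enumerate(kernel):
                smooth[i + k - half] += c * w

    # 3. Normalize so the curve integrates to 1; x values are bin centers
    total = sum(smooth) * width
    xs = [lo + (i - half + 0.5) * width for i in range(len(smooth))]
    return xs, [s / total for s in smooth]
```

The cost is roughly the number of bins times the kernel width, independent of how many samples are in `data` beyond the single histogram pass.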

probflow.utils.plotting.get_next_color(def_color, ix)

Get the next color in the color cycle

probflow.utils.plotting.get_ix_label(ix, shape)

Get a string representation of the current index

probflow.utils.plotting.plot_dist(data, xlabel='', style='fill', bins=20, ci=0.0, bw=0.075, alpha=0.4, color=None, legend=True)

Plot the distribution of samples.

Parameters
  • data (ndarray) – Samples to plot. Should be of size (Nsamples,…)

  • xlabel (str) – Label for the x axis

  • style (str) –

    Which style of plot to create. Available types are:

    • 'fill' - filled density plot (the default)

    • 'line' - line density plot

    • 'hist' - histogram

  • bins (int or list or ndarray) – Number of bins to use for the histogram (if style='hist'), or a list or vector of bin edges.

  • ci (float between 0 and 1) – Confidence interval to plot. Default = 0.0 (i.e., not plotted)

  • bw (float) – Bandwidth of the kernel density estimate (if using style='line' or style='fill'). Default is 0.075

  • alpha (float between 0 and 1) – Transparency of the plot (if style='fill' or 'hist')

  • color (matplotlib color code or list of them) – Color(s) to use to plot the distribution. See https://matplotlib.org/tutorials/colors/colors.html Default = use the default matplotlib color cycle

  • legend (bool) – Whether to show legends for plots with >1 distribution. Default = True

probflow.utils.plotting.plot_line(xdata, ydata, xlabel='', ylabel='', fmt='-', color=None)

Plot lines.

Parameters
  • xdata (ndarray) – X values of points to plot. Should be vector of length Nsamples.

  • ydata (ndarray) – Y values of points to plot. Should be of size (Nsamples,...).

  • xlabel (str) – Label for the x axis. Default is no x axis label.

  • ylabel (str) – Label for the y axis. Default is no y axis label.

  • fmt (str or matplotlib linespec) – Line marker to use. Default = '-' (a normal line).

  • color (matplotlib color code or list of them) – Color(s) to use to plot the distribution. See https://matplotlib.org/tutorials/colors/colors.html Default = use the default matplotlib color cycle

probflow.utils.plotting.fill_between(xdata, lb, ub, xlabel='', ylabel='', alpha=0.3, color=None)

Fill between lines.

Parameters
  • xdata (ndarray) – X values of points to plot. Should be vector of length Nsamples.

  • lb (ndarray) – Lower bound of fill. Should be of size (Nsamples,...).

  • ub (ndarray) – Upper bound of fill. Should be same size as lb.

  • xlabel (str) – Label for the x axis. Default is no x axis label.

  • ylabel (str) – Label for the y axis. Default is no y axis label.

  • alpha (float between 0 and 1) – Transparency of the fill. Default = 0.3

  • color (matplotlib color code or list of them) – Color(s) to use to plot the distribution. See https://matplotlib.org/tutorials/colors/colors.html Default = use the default matplotlib color cycle

probflow.utils.plotting.centered_text(text)

Display text centered in the figure

probflow.utils.plotting.plot_discrete_dist(x)

Plot histogram of discrete variable

probflow.utils.plotting.plot_categorical_dist(x)

Plot histogram of categorical variable

probflow.utils.plotting.plot_by(x, data, bins=30, func='mean', plot=True, bootstrap=100, ci=0.95, **kwargs)

Compute and plot some function func of data as a function of x.

Parameters
  • x (ndarray) – Coordinates of data to plot

  • data (ndarray) – Data to plot by bins of x

  • bins (int) – Number of bins to bin x into

  • func (callable or str) –

    Function to apply on elements of data in each x bin. Can be a callable or one of the following str:

    • 'count'

    • 'sum'

    • 'mean'

    • 'median'

    Default = 'mean'

  • plot (bool) – Whether to plot data as a function of x. Default = True

  • bootstrap (None or int > 0) – Number of bootstrap samples to use for estimating the uncertainty of func in each x bin. Default = 100

  • ci (list of float between 0 and 1) – Bootstrapped confidence interval percentile(s) to show. Default = 0.95

  • **kwargs – Additional arguments are passed to plt.plot or fill_between

Returns

  • x_o (ndarray) – x bin centers

  • data_o (ndarray) – func applied to data values in each x bin
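Without the plotting and bootstrapping, the computation behind `plot_by` amounts to binning `x` and applying `func` to the `data` values that land in each bin. The plain-Python sketch below assumes equal-width bins spanning the range of `x` and reports NaN for `'mean'` or `'median'` in an empty bin; `binned_stat` is a hypothetical name, not part of ProbFlow.

```python
import math
import statistics


def _nan_if_empty(f):
    # mean and median are undefined for an empty bin
    return lambda group: f(group) if group else float("nan")


_FUNCS = {
    "count": len,
    "sum": math.fsum,
    "mean": _nan_if_empty(statistics.mean),
    "median": _nan_if_empty(statistics.median),
}


def binned_stat(x, data, bins=30, func="mean"):
    """Return (bin centers, func applied to data in each equal-width x bin)."""
    f = _FUNCS[func] if isinstance(func, str) else func
    lo, hi = min(x), max(x)
    width = (hi - lo) / bins or 1.0  # guard against constant x
    groups = [[] for _ in range(bins)]
    for xi, di in zip(x, data):
        # the maximum x value falls into the last bin
        groups[min(int((xi - lo) / width), bins - 1)].append(di)
    centers = [lo + (i + 0.5) * width for i in range(bins)]
    return centers, [f(g) for g in groups]


centers, means = binned_stat([0, 1, 2, 3], [1.0, 2.0, 3.0, 4.0], bins=2)
print(centers, means)  # → [0.75, 2.25] [1.5, 3.5]
```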

Torch Distributions

Torch backend distributions

Validation

The utils.validation module contains functions for checking that inputs have the correct type.


probflow.utils.validation.ensure_tensor_like(obj, name)

Determine whether an object can be cast to a Tensor