"""
The utils.ops module contains operations which run using the current backend.

* :func:`.kl_divergence`
* :func:`.expand_dims`
* :func:`.squeeze`
* :func:`.ones`
* :func:`.zeros`
* :func:`.full`
* :func:`.randn`
* :func:`.rand_rademacher`
* :func:`.shape`
* :func:`.eye`
* :func:`.sum`
* :func:`.prod`
* :func:`.mean`
* :func:`.std`
* :func:`.round`
* :func:`.abs`
* :func:`.square`
* :func:`.sqrt`
* :func:`.exp`
* :func:`.relu`
* :func:`.softplus`
* :func:`.sigmoid`
* :func:`.gather`
* :func:`.cat`
* :func:`.additive_logistic_transform`
* :func:`.insert_col_of`
* :func:`.new_variable`
* :func:`.log_cholesky_transform`
* :func:`.copy_tensor`

----------

"""
# Public names exported by ``from <this module> import *``.
# NOTE(review): "additive_logistic_transform" and "log_cholesky_transform" are
# listed here but no definition for them appears in this module's visible
# source, while transpose() and reshape() are defined below but not listed.
# Confirm against the full file: a star-import raises AttributeError if any
# listed name is missing from the module.
__all__ = [
    "kl_divergence",
    "expand_dims",
    "squeeze",
    "ones",
    "zeros",
    "full",
    "randn",
    "rand_rademacher",
    "shape",
    "eye",
    "sum",
    "prod",
    "mean",
    "std",
    "round",
    "abs",
    "square",
    "sqrt",
    "exp",
    "relu",
    "softplus",
    "sigmoid",
    "gather",
    "cat",
    "additive_logistic_transform",
    "insert_col_of",
    "new_variable",
    "log_cholesky_transform",
    "copy_tensor",
]
from probflow.utils.base import BaseDistribution
from probflow.utils.casting import make_input_tensor, to_tensor
from probflow.utils.settings import get_backend, get_datatype
def kl_divergence(P, Q):
    """Compute the Kullback–Leibler divergence between two distributions.

    Parameters
    ----------
    P : |tfp.Distribution| or |torch.Distribution| or BaseDistribution
        The first distribution.  A ProbFlow ``BaseDistribution`` is called
        to obtain its backend distribution object first.
    Q : |tfp.Distribution| or |torch.Distribution| or BaseDistribution
        The second distribution (unwrapped the same way as ``P``).

    Returns
    -------
    kld : Tensor
        The Kullback–Leibler divergence between P and Q (KL(P || Q))
    """
    # Unwrap ProbFlow distributions into backend distribution objects
    P = P() if isinstance(P, BaseDistribution) else P
    Q = Q() if isinstance(Q, BaseDistribution) else Q

    if get_backend() != "pytorch":
        import tensorflow_probability as tfp

        return tfp.distributions.kl_divergence(P, Q)

    import torch

    return torch.distributions.kl.kl_divergence(P, Q)
@make_input_tensor
def expand_dims(val, axis):
    """Insert a new size-1 dimension into a tensor at position ``axis``.

    When ``axis`` is None no dimension is added and ``val`` is returned
    as received.
    """
    if axis is None:
        return val
    if get_backend() != "pytorch":
        import tensorflow as tf

        return tf.expand_dims(val, axis)

    import torch

    return torch.unsqueeze(val, axis)
@make_input_tensor
def squeeze(val):
    """Drop every size-1 dimension from a tensor."""
    if get_backend() == "pytorch":
        import torch

        squeezed = torch.squeeze(val)
    else:
        import tensorflow as tf

        squeezed = tf.squeeze(val)
    return squeezed
def ones(shape):
    """Create a tensor of the given shape filled with ones.

    The tensor's dtype is the one returned by ``get_datatype()``.
    """
    if get_backend() != "pytorch":
        import tensorflow as tf

        return tf.ones(shape, dtype=get_datatype())

    import torch

    return torch.ones(shape, dtype=get_datatype())
def zeros(shape):
    """Create a tensor of the given shape filled with zeros.

    The tensor's dtype is the one returned by ``get_datatype()``.
    """
    if get_backend() == "pytorch":
        import torch

        result = torch.zeros(shape, dtype=get_datatype())
    else:
        import tensorflow as tf

        result = tf.zeros(shape, dtype=get_datatype())
    return result
def full(shape, value):
    """Create a tensor of the given shape filled with ``value``.

    Parameters
    ----------
    shape
        Shape of the tensor to create (passed to ``torch.full`` /
        ``tf.fill``).
    value
        Value to fill the tensor with.

    Returns
    -------
    Tensor
        Tensor of dtype ``get_datatype()`` with every element equal to
        ``value``.
    """
    if get_backend() == "pytorch":
        import torch

        return torch.full(shape, value, dtype=get_datatype())
    else:
        import tensorflow as tf

        # Convert ``value`` to the target dtype *before* filling.  Filling
        # first and casting afterwards turned Python floats into float32,
        # so e.g. 0.1 lost precision when the datatype was float64.
        # ``dtype_hint`` is only a preference, so tensors of another dtype
        # still convert and are then cast explicitly.
        dtype = get_datatype()
        fill_value = tf.cast(tf.convert_to_tensor(value, dtype_hint=dtype), dtype)
        return tf.fill(shape, fill_value)
def randn(shape):
    """Sample a tensor of the given shape from a standard normal distribution.

    The samples have dtype ``get_datatype()``.
    """
    if get_backend() != "pytorch":
        import tensorflow as tf

        return tf.random.normal(shape, dtype=get_datatype())

    import torch

    return torch.randn(shape, dtype=get_datatype())
def rand_rademacher(shape):
    """Tensor full of random -1s or 1s (i.e. drawn from a Rademacher dist).

    Both backends return a tensor with dtype ``get_datatype()``.
    """
    if get_backend() == "pytorch":
        import torch

        # Draw 0/1 uniformly, then map {0, 1} -> {-1, 1}
        return 2 * torch.randint(0, 2, shape, dtype=get_datatype()) - 1
    else:
        import tensorflow_probability as tfp

        # Only the attribute lookup may fail on older tfp versions; keep the
        # sampling call itself outside the try so its errors are not masked.
        try:
            rademacher = tfp.random.rademacher
        except AttributeError:  # pragma: no cover
            rademacher = tfp.python.math.random_rademacher
        # Pass the configured dtype explicitly so this branch agrees with the
        # PyTorch branch (tfp's default dtype is float32).
        return rademacher(shape, dtype=get_datatype())
def shape(x):
    """Get a list representing this tensor's shape.

    Both backends expose an iterable ``.shape`` attribute (``torch.Size`` /
    ``tf.TensorShape``), so no backend dispatch is needed.  With TensorFlow,
    dimensions that are not statically known may appear as ``None``.
    """
    return list(x.shape)
def eye(dims):
    """Create a ``dims`` x ``dims`` identity matrix of dtype ``get_datatype()``."""
    if get_backend() != "pytorch":
        import tensorflow as tf

        return tf.eye(dims, dtype=get_datatype())

    import torch

    return torch.eye(dims, dtype=get_datatype())
def sum(val, axis=-1, keepdims=False):
    """The sum.

    Parameters
    ----------
    val : Tensor
        Values to sum.
    axis : int or None
        Axis to reduce over (default: the last one).  ``None`` reduces over
        all axes.
    keepdims : bool
        Keep reduced axes as size-1 dimensions.  Honoured for ``axis=None``
        on both backends.
    """
    if get_backend() == "pytorch":
        import torch

        if axis is None:
            # The all-axes form of torch.sum takes no ``keepdim``; restore
            # the size-1 axes by hand so the result matches tf.reduce_sum.
            total = torch.sum(val)
            return total.reshape([1] * val.dim()) if keepdims else total
        return torch.sum(val, axis, keepdim=keepdims)
    else:
        import tensorflow as tf

        return tf.reduce_sum(val, axis=axis, keepdims=keepdims)
def prod(val, axis=-1, keepdims=False):
    """The product.

    Parameters
    ----------
    val : Tensor
        Values to multiply.
    axis : int or None
        Axis to reduce over (default: the last one).  ``None`` reduces over
        all axes, matching ``tf.reduce_prod`` and :func:`sum`.
    keepdims : bool
        Keep reduced axes as size-1 dimensions.
    """
    if get_backend() == "pytorch":
        import torch

        if axis is None:
            # torch.prod's ``dim`` must be an int; full reduction uses the
            # no-dim form, which takes no ``keepdim``, so reshape by hand.
            total = torch.prod(val)
            return total.reshape([1] * val.dim()) if keepdims else total
        return torch.prod(val, dim=axis, keepdim=keepdims)
    else:
        import tensorflow as tf

        return tf.reduce_prod(val, axis=axis, keepdims=keepdims)
def mean(val, axis=-1, keepdims=False):
    """The mean.

    Parameters
    ----------
    val : Tensor
        Values to average.
    axis : int or None
        Axis to reduce over (default: the last one).  ``None`` reduces over
        all axes, matching ``tf.reduce_mean`` and :func:`sum`.
    keepdims : bool
        Keep reduced axes as size-1 dimensions.
    """
    if get_backend() == "pytorch":
        import torch

        if axis is None:
            # Full reduction without relying on ``dim=None`` support, which
            # older torch versions lack; restore size-1 axes if requested.
            total = torch.mean(val)
            return total.reshape([1] * val.dim()) if keepdims else total
        return torch.mean(val, dim=axis, keepdim=keepdims)
    else:
        import tensorflow as tf

        return tf.reduce_mean(val, axis=axis, keepdims=keepdims)
def std(val, axis=-1, keepdims=False):
    """The uncorrected sample standard deviation.

    Divides by N (no Bessel correction) on both backends, matching
    ``tf.math.reduce_std``.

    Parameters
    ----------
    val : Tensor
        Values whose spread to compute.
    axis : int or None
        Axis to reduce over (default: the last one).  ``None`` reduces over
        all axes.
    keepdims : bool
        Keep reduced axes as size-1 dimensions.
    """
    if get_backend() == "pytorch":
        import torch

        # torch.std applies Bessel's correction by default; unbiased=False
        # gives the uncorrected estimate the docstring promises.
        if axis is None:
            total = torch.std(val, unbiased=False)
            return total.reshape([1] * val.dim()) if keepdims else total
        return torch.std(val, dim=axis, unbiased=False, keepdim=keepdims)
    else:
        import tensorflow as tf

        return tf.math.reduce_std(val, axis=axis, keepdims=keepdims)
def round(val):
    """Round each element of a tensor to the nearest integer value."""
    if get_backend() != "pytorch":
        import tensorflow as tf

        return tf.math.round(val)

    import torch

    return torch.round(val)
def abs(val):
    """Element-wise absolute value of a tensor."""
    if get_backend() == "pytorch":
        import torch

        magnitude = torch.abs(val)
    else:
        import tensorflow as tf

        magnitude = tf.math.abs(val)
    return magnitude
def square(val):
    """Element-wise square (power of 2) of a tensor."""
    if get_backend() != "pytorch":
        import tensorflow as tf

        return tf.math.square(val)

    # PyTorch tensors support the ``**`` operator directly
    return val ** 2
def sqrt(val):
    """Element-wise square root of a tensor."""
    if get_backend() == "pytorch":
        import torch

        root = torch.sqrt(val)
    else:
        import tensorflow as tf

        root = tf.sqrt(val)
    return root
def exp(val):
    """Element-wise natural exponential, e ** val."""
    if get_backend() != "pytorch":
        import tensorflow as tf

        return tf.exp(val)

    import torch

    return torch.exp(val)
def relu(val):
    """Linear rectification, max(val, 0), applied element-wise."""
    if get_backend() != "pytorch":
        import tensorflow as tf

        return tf.nn.relu(val)

    import torch

    # Functional form of torch.nn.ReLU
    return torch.relu(val)
def softplus(val):
    """Softplus function, log(1 + exp(val)), applied element-wise.

    A smooth approximation of linear rectification; both backends use their
    default softplus parameters.
    """
    if get_backend() == "pytorch":
        import torch

        # Functional form of torch.nn.Softplus with its default beta/threshold
        return torch.nn.functional.softplus(val)
    else:
        import tensorflow as tf

        return tf.math.softplus(val)
def sigmoid(val):
    """Logistic sigmoid, 1 / (1 + exp(-val)), applied element-wise."""
    if get_backend() != "pytorch":
        import tensorflow as tf

        return tf.math.sigmoid(val)

    import torch

    # Functional form of torch.nn.Sigmoid
    return torch.sigmoid(val)
def gather(vals, inds, axis=0):
    """Select slices of ``vals`` along ``axis`` at the positions in ``inds``.

    For PyTorch, ``inds`` is first passed through ``to_tensor``.
    """
    if get_backend() != "pytorch":
        import tensorflow as tf

        return tf.gather(vals, inds, axis=axis)

    import torch

    index = to_tensor(inds)
    return torch.index_select(vals, axis, index)
def cat(vals, axis=0):
    """Join a sequence of tensors end-to-end along ``axis``."""
    if get_backend() == "pytorch":
        import torch

        joined = torch.cat(vals, dim=axis)
    else:
        import tensorflow as tf

        joined = tf.concat(vals, axis=axis)
    return joined
def insert_col_of(vals, val):
    """Add a column of a value to the left side of a tensor.

    Parameters
    ----------
    vals : Tensor
        Tensor to extend along its last axis.
    val
        Value the new column is filled with (multiplied into a tensor of
        ones of dtype ``get_datatype()``).

    Returns
    -------
    Tensor
        ``vals`` with one extra leading entry along the last axis, equal to
        ``val``; all other dimensions are unchanged.
    """
    if get_backend() == "pytorch":
        import torch

        col_shape = list(vals.shape[:-1]) + [1]
        # Create the column on the same device as ``vals`` so torch.cat
        # does not fail for tensors living on an accelerator.
        col = val * torch.ones(col_shape, dtype=get_datatype(), device=vals.device)
        return torch.cat([col, vals], dim=-1)
    else:
        import tensorflow as tf

        # Use the dynamic shape: the static ``vals.shape`` may contain
        # unknown (None) dimensions inside tf.function, which tf.concat
        # cannot convert to a tensor.
        col_shape = tf.concat([tf.shape(vals)[:-1], [1]], axis=-1)
        col = val * tf.ones(col_shape, dtype=get_datatype())
        return tf.concat([col, vals], axis=-1)
def new_variable(initial_values):
    """Create a backend variable initialised to ``initial_values``.

    PyTorch wraps the values in ``torch.nn.Parameter``; TensorFlow in
    ``tf.Variable``.
    """
    if get_backend() != "pytorch":
        import tensorflow as tf

        return tf.Variable(initial_values)

    import torch

    return torch.nn.Parameter(initial_values)
def transpose(x):
    """Transpose a matrix or batch of matrices.

    Swaps the last two axes of ``x``; any leading (batch) axes keep their
    position.
    """
    if get_backend() == "pytorch":
        import torch

        return torch.transpose(x, -1, -2)
    else:
        import tensorflow as tf

        # matrix_transpose swaps the two innermost axes for any rank >= 2
        # without building a permutation from ``x.ndim`` by hand.
        return tf.linalg.matrix_transpose(x)
def reshape(x, new_shape):
    """Return ``x`` with its elements rearranged into ``new_shape``."""
    if get_backend() != "pytorch":
        import tensorflow as tf

        return tf.reshape(x, new_shape)

    import torch

    # The target shape is handed to torch as a tuple
    target = tuple(new_shape)
    return torch.reshape(x, target)
def copy_tensor(x):
    """Copy a tensor, detaching the copy from gradient tracking.

    PyTorch uses ``x.detach().clone()``.  TensorFlow takes an identity copy
    and wraps it in ``tf.stop_gradient`` so that, like the PyTorch branch,
    no gradient flows back to ``x`` through the copy.
    """
    if get_backend() == "pytorch":
        return x.detach().clone()
    else:
        import tensorflow as tf

        # tf.identity alone still propagates gradients; stop_gradient
        # provides the detachment the PyTorch branch already has.
        return tf.stop_gradient(tf.identity(x))