Source code for probflow.modules.batch_normalization

from typing import Callable, Dict, List, Type, Union

import probflow.utils.ops as O
from probflow.distributions import Deterministic, Normal
from probflow.parameters import Parameter
from probflow.utils.base import BaseDistribution
from probflow.utils.initializers import xavier

from .module import Module


class BatchNormalization(Module):
    r"""A layer which normalizes its inputs.

    Batch normalization is a technique which normalizes, re-scales, and
    offsets the output of one layer before passing it on to another layer
    [1]_.  It often leads to faster training of neural networks and better
    generalization error, by stabilizing the change in the layers' input
    distributions, or perhaps by smoothing the optimization landscape [2]_.

    Given a set of tensors for this batch, where :math:`x_{ij}` is the
    :math:`i`-th element of the :math:`j`-th sample in this batch, this layer
    returns an elementwise transformation of the input tensors according to:

    .. math::

        \text{BatchNorm}(x_{ij}) =
            \gamma_i \left( \frac{x_{ij} - \mu_i}{\sigma_i} \right) + \beta_i

    where :math:`\mu_i` is the mean of the :math:`i`-th element across the
    batch:

    .. math::

        \mu_i = \frac{1}{N} \sum_{k=1}^{N} x_{ik}

    and :math:`\sigma_i` is the standard deviation of the :math:`i`-th
    element across the batch:

    .. math::

        \sigma_i = \sqrt{\frac{1}{N} \sum_{k=1}^{N} (x_{ik} - \mu_i)^2}

    and :math:`\gamma_i` and :math:`\beta_i` are two free parameters for each
    element.


    Parameters
    ----------
    shape : int or list of int or |ndarray|
        Shape of the tensor to be batch-normalized.
    name : str
        Name for this layer.
        Default = 'BatchNormalization'
    weight_posterior : |Distribution|
        Probability distribution class to use to approximate the posterior
        for the weight parameter(s) (:math:`\gamma`).
        Default = :class:`.Deterministic`
    bias_posterior : |Distribution|
        Probability distribution class to use to approximate the posterior
        for the bias parameter(s) (:math:`\beta`).
        Default = :class:`.Deterministic`
    weight_prior : |None| or a |Distribution| object
        Prior probability distribution for the weight parameter(s)
        (:math:`\gamma`).  |None| or a |Distribution| object which has been
        instantiated with parameters.
        Default = :class:`.Normal` ``(0, 1)``
    bias_prior : |None| or a |Distribution| object
        Prior probability distribution for the bias parameter(s)
        (:math:`\beta`).  |None| or a |Distribution| object which has been
        instantiated with parameters.
        Default = :class:`.Normal` ``(0, 1)``
    weight_initializer : dict of callables
        Initializer functions to use for each variable of the variational
        posterior distribution for the weights (:math:`\gamma`).  Keys
        correspond to variable names (arguments to the distribution), and
        values contain functions to initialize those variables given
        ``shape`` as the single argument.
    bias_initializer : dict of callables
        Initializer functions to use for each variable of the variational
        posterior distribution for the biases (:math:`\beta`).  Keys
        correspond to variable names (arguments to the distribution), and
        values contain functions to initialize those variables given
        ``shape`` as the single argument.


    Examples
    --------

    Batch normalize the output of a :class:`.Dense` layer:

    .. code-block:: python

        import probflow as pf
        import tensorflow as tf

        network = pf.Sequential([
            pf.Dense(d_in=7, d_out=100, bias=False),
            pf.BatchNormalization(100),
            tf.nn.relu,
            pf.Dense(d_in=100, d_out=1),
        ])
        ...


    References
    ----------
    .. [1] Sergey Ioffe and Christian Szegedy.
        Batch Normalization: Accelerating Deep Network Training by Reducing
        Internal Covariate Shift.
        *arXiv preprint*, 2015. http://arxiv.org/abs/1502.03167
    .. [2] Shibani Santurkar, Dimitris Tsipras, Andrew Ilyas, and
        Aleksander Madry.
        How Does Batch Normalization Help Optimization?
        *arXiv preprint*, 2018. http://arxiv.org/abs/1805.11604
    """

    def __init__(
        self,
        shape: Union[int, List[int]],
        weight_posterior: Type[BaseDistribution] = Deterministic,
        bias_posterior: Type[BaseDistribution] = Deterministic,
        weight_prior: BaseDistribution = Normal(0, 1),
        bias_prior: BaseDistribution = Normal(0, 1),
        weight_initializer: Dict[str, Callable] = {"loc": xavier},
        bias_initializer: Dict[str, Callable] = {"loc": xavier},
        name="BatchNormalization",
    ):

        # Add the batch dimension
        if isinstance(shape, int):
            shape = [1, shape]
        else:
            shape = [1] + shape

        # Create the weight (gamma) and bias (beta) parameters
        self.weight = Parameter(
            shape=shape,
            posterior=weight_posterior,
            prior=weight_prior,
            initializer=weight_initializer,
            name=name + "_weight",
        )
        self.bias = Parameter(
            shape=shape,
            posterior=bias_posterior,
            prior=bias_prior,
            initializer=bias_initializer,
            name=name + "_bias",
        )

    def __call__(self, x):
        """Perform the forward pass"""
        # Normalize over the batch (second-to-last) dimension,
        # then re-scale by the weights and offset by the biases
        mean = O.mean(x, axis=-2, keepdims=True)
        std = O.std(x, axis=-2, keepdims=True)
        return self.weight() * (x - mean) / std + self.bias()
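

# A minimal numerical sketch (not part of the probflow source): it mirrors the
# transform computed in ``__call__`` above using plain numpy in place of the
# backend ops in ``O``, with a single point estimate standing in for the weight
# and bias parameters.  The batch size, feature dimension, and parameter values
# below are illustrative assumptions.
if __name__ == "__main__":

    import numpy as np

    x = np.random.randn(32, 100)           # batch of 32 samples, 100 features
    mu = x.mean(axis=-2, keepdims=True)     # per-feature batch mean, mu_i
    sigma = x.std(axis=-2, keepdims=True)   # per-feature batch std, sigma_i
    gamma = np.ones((1, 100))               # weight (gamma): identity scale
    beta = np.zeros((1, 100))               # bias (beta): zero offset
    out = gamma * (x - mu) / sigma + beta   # BatchNorm(x), as in the docstring

    # Each feature of the output now has ~zero mean and ~unit std over the batch
    print(out.mean(axis=0)[:3], out.std(axis=0)[:3])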