# Source code for probflow.utils.settings

"""
The utils.settings module contains global settings about the backend to use,
how to sample from parameter posteriors, and the default datatype.


Backend
-------

Which backend to use.  Can be either
`TensorFlow 2.0 <https://www.tensorflow.org/>`_
or `PyTorch <https://pytorch.org/>`_.

* :func:`.get_backend`
* :func:`.set_backend`


Datatype
--------

Which datatype to use as the default for parameters.  Depending on your model,
you might have to set the default datatype to match the datatype of your data.

* :func:`.get_datatype`
* :func:`.set_datatype`


Samples
-------

Whether and how many samples to draw from parameter posterior distributions.
If ``None``, the maximum a posteriori estimate of each parameter will be used.
If an integer greater than 0, that many samples from each parameter's posterior
distribution will be used.

* :func:`.get_samples`
* :func:`.set_samples`


Flipout
-------

Whether to use `Flipout <https://arxiv.org/abs/1803.04386>`_ where possible.

* :func:`.get_flipout`
* :func:`.set_flipout`


Static posterior sampling
-------------------------

Whether or not to use static posterior sampling (i.e., take a random sample
from the posterior, but take the same random sample on repeated calls), and
control the UUID of the current static sampling regime.

* :func:`.get_static_sampling_uuid`
* :func:`.set_static_sampling_uuid`


Sampling context manager
------------------------

A context manager which controls how |Parameters| sample from their
variational distributions while inside the context manager.

* :class:`.Sampling`


"""


import uuid

# Public API of this module: getter/setter pairs for each global setting,
# followed by the Sampling context manager.
__all__ = [
    # backend
    "get_backend",
    "set_backend",
    # default datatype
    "get_datatype",
    "set_datatype",
    # number of posterior samples
    "get_samples",
    "set_samples",
    # flipout
    "get_flipout",
    "set_flipout",
    # static posterior sampling
    "get_static_sampling_uuid",
    "set_static_sampling_uuid",
    # context manager
    "Sampling",
]


class _Settings:
    """Class to store ProbFlow global settings

    Attributes
    ----------
    _BACKEND : str {'tensorflow' or 'pytorch'}
        What backend to use
    _SAMPLES : |None| or int > 0
        How many samples to take from |Parameter| variational posteriors.
        If |None|, will use MAP estimates.
    _FLIPOUT : bool
        Whether to use flipout where possible
    _DATATYPE : tf.dtype or torch.dtype
        Default datatype to use for tensors
    _STATIC_SAMPLING_UUID : None or uuid.UUID
        UUID of the current static sampling regime
    """

    def __init__(self):
        self._BACKEND = "tensorflow"
        self._SAMPLES = None
        self._FLIPOUT = False
        self._DATATYPE = None
        self._STATIC_SAMPLING_UUID = None


# Global ProbFlow settings
__SETTINGS__ = _Settings()


def get_backend():
    """Return the name of the backend ProbFlow is currently using.

    Returns
    -------
    backend : str {'tensorflow' or 'pytorch'}
        Name of the active backend
    """
    current_backend = __SETTINGS__._BACKEND
    return current_backend
def set_backend(backend):
    """Choose which backend ProbFlow should use.

    Parameters
    ----------
    backend : str {'tensorflow' or 'pytorch'}
        Name of the backend to switch to

    Raises
    ------
    TypeError
        If ``backend`` is not a string
    ValueError
        If ``backend`` is a string other than 'tensorflow' or 'pytorch'
    """
    # Guard clauses: validate the type first, then the value.
    if not isinstance(backend, str):
        raise TypeError("backend must be a string")
    if backend not in ("tensorflow", "pytorch"):
        raise ValueError("backend must be either tensorflow or pytorch")
    __SETTINGS__._BACKEND = backend
def get_datatype():
    """Return the default datatype used for Tensors.

    If no datatype has been set explicitly, float32 of the active backend
    is returned.

    Returns
    -------
    dtype : tf.dtype or torch.dtype
        The current default datatype
    """
    dtype = __SETTINGS__._DATATYPE
    if dtype is not None:
        return dtype
    # Nothing set explicitly: fall back to the backend's float32 type.
    if get_backend() == "pytorch":
        import torch

        return torch.float32
    import tensorflow as tf

    return tf.dtypes.float32
def set_datatype(datatype):
    """Set the default datatype to use for Tensors.

    The datatype is validated against the backend active at call time.

    Parameters
    ----------
    datatype : None or tf.dtype or torch.dtype
        The default datatype to use, or None to fall back to the
        backend's float32 default

    Raises
    ------
    TypeError
        If ``datatype`` is neither None nor a dtype of the active backend
    """
    # Pick the dtype class (and matching error message) for the backend.
    if get_backend() == "pytorch":
        import torch

        expected_type = torch.dtype
        message = "datatype must be a torch.dtype"
    else:
        import tensorflow as tf

        expected_type = tf.dtypes.DType
        message = "datatype must be a tf.dtypes.DType"

    if datatype is not None and not isinstance(datatype, expected_type):
        raise TypeError(message)
    __SETTINGS__._DATATYPE = datatype
def get_samples():
    """Return how many samples (if any) are drawn from parameter posteriors.

    Returns
    -------
    n : None or int > 0
        Number of samples to draw from each parameter's posterior, or
        None (the initial setting) to use the maximum a posteriori
        estimate instead.
    """
    n_samples = __SETTINGS__._SAMPLES
    return n_samples
def set_samples(samples):
    """Set how many samples (if any) to draw from parameter posteriors

    Parameters
    ----------
    samples : None or int > 0
        Number of samples (if any) to draw from parameters' posteriors.
        If None, the maximum a posteriori estimate is used instead.

    Raises
    ------
    TypeError
        If ``samples`` is neither None nor an int (bools are rejected too)
    ValueError
        If ``samples`` is an int less than 1
    """
    if samples is not None:
        # bool is a subclass of int, so isinstance(True, int) is True.
        # Reject bools explicitly so set_samples(True) is not stored as a
        # sample count and set_samples(False) gets a TypeError rather than
        # a misleading "must be positive" error.
        if isinstance(samples, bool) or not isinstance(samples, int):
            raise TypeError("samples must be an int or None")
        if samples < 1:
            raise ValueError("samples must be positive")
    __SETTINGS__._SAMPLES = samples
def get_flipout():
    """Report whether flipout is currently used where possible.

    Returns
    -------
    flipout : bool
        True if flipout is used where possible while sampling during
        training, False otherwise.
    """
    use_flipout = __SETTINGS__._FLIPOUT
    return use_flipout
def set_flipout(flipout):
    """Choose whether to use flipout where possible while sampling in training.

    Parameters
    ----------
    flipout : bool
        True to use flipout where possible while sampling during training,
        False to disable it.

    Raises
    ------
    TypeError
        If ``flipout`` is not a bool
    """
    if not isinstance(flipout, bool):
        raise TypeError("flipout must be True or False")
    __SETTINGS__._FLIPOUT = flipout
def get_static_sampling_uuid():
    """Return the UUID identifying the current static sampling regime.

    Returns
    -------
    uuid_value : None or uuid.UUID
        UUID of the active static sampling regime, or None when no regime
        has been set.
    """
    current_uuid = __SETTINGS__._STATIC_SAMPLING_UUID
    return current_uuid
def set_static_sampling_uuid(uuid_value):
    """Set the UUID identifying the current static sampling regime.

    Parameters
    ----------
    uuid_value : None or uuid.UUID
        UUID of the new static sampling regime, or None to clear it

    Raises
    ------
    TypeError
        If ``uuid_value`` is neither None nor a ``uuid.UUID``
    """
    is_valid = uuid_value is None or isinstance(uuid_value, uuid.UUID)
    if not is_valid:
        raise TypeError("must be a uuid or None")
    __SETTINGS__._STATIC_SAMPLING_UUID = uuid_value
class Sampling:
    """Use sampling while within this context manager.

    Keyword Arguments
    -----------------
    n : None or int > 0
        Number of samples (if any) to draw from parameters' posteriors.
        Default = None (leave the current number of samples unchanged)
    flipout : None or bool
        Whether to use flipout where possible while sampling during
        training.  Default = None (leave the current flipout setting
        unchanged)
    static : None or bool
        If True, use static sampling: draw a random sample from the
        posterior, but return the same sample on repeated calls while
        inside this context manager.  If False, disable static sampling
        inside this context manager.  Default = None (leave the current
        static sampling setting unchanged)

    When the context manager exits, every setting it changed is restored
    to the value it had on entry, so ``Sampling`` contexts can be nested.

    Example
    -------

    To use maximum a posteriori estimates of the parameter values, don't use
    the sampling context manager:

    .. code-block:: pycon

        >>> import probflow as pf
        >>> param = pf.Parameter()
        >>> param()
        [0.07226744]
        >>> param()  # MAP estimate is always the same
        [0.07226744]

    To use a single sample, use the sampling context manager with ``n=1``:

    .. code-block:: pycon

        >>> with pf.Sampling(n=1):
        ...     param()
        [-2.2228503]
        >>> with pf.Sampling(n=1):
        ...     param()  # samples are different
        [1.3473024]

    To use multiple samples, use the sampling context manager and set the
    number of samples to take with the ``n`` keyword argument:

    .. code-block:: pycon

        >>> with pf.Sampling(n=3):
        ...     param()
        [[ 0.10457394]
         [ 0.14018342]
         [-1.8649881 ]]
        >>> with pf.Sampling(n=5):
        ...     param()
        [[ 2.1035051]
         [-2.641631 ]
         [-2.9091313]
         [ 3.5294306]
         [ 1.6596333]]

    To use static samples - that is, to always return the same samples while
    in the same context manager - use the sampling context manager with the
    ``static`` keyword argument set to ``True``:

    .. code-block:: pycon

        >>> with pf.Sampling(static=True):
        ...     param()
        ...     param()  # repeated samples yield the same value
        [0.10457394]
        [0.10457394]
        >>> with pf.Sampling(static=True):
        ...     param()  # under a new context manager they yield new samples
        ...     param()  # but remain the same while under the same context
        [-2.641631]
        [-2.641631]
    """

    def __init__(self, n=None, flipout=None, static=None):
        self._n = n
        self._flipout = flipout
        self._static = static
        # Stack of (samples, flipout, static_uuid) tuples captured on each
        # __enter__, so that re-entering the same instance restores correctly.
        self._saved = []

    def __enter__(self):
        """Begin sampling, remembering the settings active on entry."""
        self._saved.append(
            (get_samples(), get_flipout(), get_static_sampling_uuid())
        )
        try:
            if self._n is not None:
                set_samples(self._n)
            if self._flipout is not None:
                set_flipout(self._flipout)
            if self._static is not None:
                # A fresh UUID starts a new static sampling regime; a falsy
                # ``static`` explicitly turns static sampling off instead.
                set_static_sampling_uuid(
                    uuid.uuid4() if self._static else None
                )
        except BaseException:
            # __exit__ is not called when __enter__ raises, so undo any
            # settings that were already changed before re-raising.
            self._restore(self._saved.pop())
            raise
        return self

    def __exit__(self, _type, _val, _tb):
        """End sampling and restore the settings that were active on entry.

        Returns None, so exceptions raised inside the context propagate.
        """
        self._restore(self._saved.pop())

    def _restore(self, saved):
        """Reset each setting this context manager controls to ``saved``."""
        samples, flipout, static_uuid = saved
        if self._n is not None:
            set_samples(samples)
        if self._flipout is not None:
            set_flipout(flipout)
        if self._static is not None:
            set_static_sampling_uuid(static_uuid)