from typing import Dict, List
import probflow.utils.ops as O
from probflow.utils.base import BaseModule, BaseParameter
from probflow.utils.io import dump, dumps
class Module(BaseModule):
    r"""Abstract base class for Modules.

    A |Module| collects the |Parameters| reachable from its instance
    attributes: attributes which are |Parameters| themselves, attributes
    which are other Modules, and lists or dicts (arbitrarily nested) which
    contain either of those.  On top of that collection it provides
    aggregate properties (``parameters``, ``trainable_variables``,
    ``n_parameters``, ``n_variables``), Kullback-Leibler loss helpers, and
    serialization helpers (``dumps`` and ``save``).

    The batch KL-loss helpers read and write a ``_kl_losses`` list on each
    module.  This class does not create that attribute itself; it is
    presumably initialized by ``BaseModule.__init__`` (TODO confirm).
    """

    def _params(self, obj):
        """Recursively search for |Parameters| contained within an object

        Parameters
        ----------
        obj : Any
            Object to inspect.  Only |Parameters|, Modules, lists, and
            dicts are searched; any other type (tuples and sets included)
            yields no parameters.

        Returns
        -------
        list
            The |Parameters| found within ``obj`` (empty if none).
        """
        if isinstance(obj, BaseParameter):
            return [obj]
        elif isinstance(obj, BaseModule):
            return obj.parameters
        elif isinstance(obj, list):
            return self._list_params(obj)
        elif isinstance(obj, dict):
            return self._dict_params(obj)
        else:
            return []

    def _list_params(self, the_list: List):
        """Recursively search for |Parameters| contained in a list

        Parameters
        ----------
        the_list : List
            List whose elements are each searched with ``_params``.

        Returns
        -------
        list
            The |Parameters| found, in element order.
        """
        return [p for e in the_list for p in self._params(e)]

    def _dict_params(self, the_dict: Dict):
        """Recursively search for |Parameters| contained in a dict

        Only the values of the dict are searched; keys are ignored.

        Parameters
        ----------
        the_dict : Dict
            Dict whose values are each searched with ``_params``.

        Returns
        -------
        list
            The |Parameters| found, in the dict's iteration order.
        """
        return [p for e in the_dict.values() for p in self._params(e)]

    @property
    def parameters(self):
        """A list of |Parameters| in this |Module| and its sub-Modules."""
        # vars(self) is the instance __dict__, so only instance attributes
        # (not class attributes) are searched.
        return [p for a in vars(self).values() for p in self._params(a)]

    @property
    def modules(self):
        """A list of sub-Modules in this |Module|, including itself."""
        # NOTE(review): unlike ``parameters``, only attributes which are
        # directly Modules are inspected here; Modules stored inside lists
        # or dicts are not discovered.  Confirm whether that is intended.
        return [
            m
            for a in vars(self).values()
            if isinstance(a, BaseModule)
            for m in a.modules
        ] + [self]

    @property
    def trainable_variables(self):
        """A list of trainable backend variables within this |Module|"""
        # TODO: look for variables NOT in parameters too
        # so users can mix-n-match tf.Variables and pf.Parameters in modules
        return [v for p in self.parameters for v in p.trainable_variables]

    @property
    def n_parameters(self):
        """Get the number of independent parameters of this module"""
        return sum(p.n_parameters for p in self.parameters)

    @property
    def n_variables(self):
        """Get the number of underlying variables in this module"""
        return sum(p.n_variables for p in self.parameters)

    def bayesian_update(self):
        """Perform a Bayesian update of all |Parameters| in this module.  Sets
        the prior to the current variational posterior for all parameters.
        """
        for p in self.parameters:
            p.bayesian_update()

    def kl_loss(self):
        """Compute the sum of the Kullback-Leibler divergences between
        priors and their variational posteriors for all |Parameters| in this
        |Module| and its sub-Modules."""
        return sum(p.kl_loss() for p in self.parameters)

    def kl_loss_batch(self):
        """Compute the sum of additional Kullback-Leibler divergences due to
        data in this batch

        Sums every entry of ``_kl_losses`` over all modules returned by
        ``modules``.
        """
        return sum(e for m in self.modules for e in m._kl_losses)

    def reset_kl_loss(self):
        """Reset additional loss due to KL divergences

        Replaces ``_kl_losses`` with a new empty list on every module
        returned by ``modules``.
        """
        for m in self.modules:
            m._kl_losses = []

    def add_kl_loss(self, loss, d2=None):
        """Add additional loss due to KL divergences.

        Parameters
        ----------
        loss
            If ``d2`` is None, a value which is summed over all axes and
            recorded as-is.  Otherwise, the first argument passed to
            ``O.kl_divergence``.
        d2
            Optional second argument passed to ``O.kl_divergence``.  When
            given, ``O.kl_divergence(loss, d2)`` is summed over all axes and
            recorded instead.  Default = None
        """
        if d2 is None:
            self._kl_losses.append(O.sum(loss, axis=None))
        else:
            self._kl_losses.append(
                O.sum(O.kl_divergence(loss, d2), axis=None)
            )

    def dumps(self):
        """Serialize module object to bytes"""
        return dumps(self)

    def save(self, filename: str):
        """Save module object to file

        Parameters
        ----------
        filename : str
            Filename for file to which to save this object
        """
        dump(self, filename)