# Source code for probflow.utils.validation

"""
The utils.validation module contains functions for checking that inputs have
the correct type.

* :func:`.ensure_tensor_like`

----------

"""


import numpy as np

from probflow.utils.base import BaseParameter
from probflow.utils.settings import get_backend


def ensure_tensor_like(obj, name):
    """Determine whether an object can be cast to a Tensor.

    Plain Python numbers, lists, and NumPy arrays are accepted regardless of
    the backend.  Any other object must be an instance of one of the tensor
    types of the currently selected backend, or a ``BaseParameter``.

    Parameters
    ----------
    obj : object
        The object to check.
    name : str
        Label for ``obj``; it is prepended to the error message.

    Returns
    -------
    None
        Returns nothing when ``obj`` passes the check.

    Raises
    ------
    TypeError
        If ``obj`` is not an ``int``, ``float``, ``list``, ``np.ndarray``,
        a backend tensor type, or a ``BaseParameter``.
    """

    # Check for non-backend-dependent types first; these never need the
    # backend module to be imported.
    if isinstance(obj, (int, float, np.ndarray, list)):
        return

    # Check for backend-dependent types.  The backend module is imported
    # inside the branch (presumably so only the selected backend has to be
    # installed -- confirm against the project's packaging).
    if get_backend() == "pytorch":
        import torch

        tensor_types = (torch.Tensor, BaseParameter)
    else:
        import tensorflow as tf

        tensor_types = (tf.Tensor, tf.Variable, BaseParameter)

    if not isinstance(obj, tensor_types):
        raise TypeError(name + " must be Tensor-like")