def test_sampling():
    """Tests the Sampling context manager.

    Checks that ``settings.Sampling`` temporarily overrides the number of
    samples and the flipout flag, and that the defaults come back on exit.
    """

    def check_defaults():
        # Outside any Sampling context: tensorflow backend, no samples, no flipout
        assert settings.get_backend() == "tensorflow"
        assert settings.get_samples() is None
        assert settings.get_flipout() is False

    check_defaults()

    # Without kwargs the context uses a single sample and no flipout
    with settings.Sampling():
        assert settings.get_samples() == 1
        assert settings.get_flipout() is False

    check_defaults()

    # kwargs override both settings for the duration of the context
    with settings.Sampling(n=100, flipout=True):
        assert settings.get_samples() == 100
        assert settings.get_flipout() is True

    # __exit__ restores the defaults once more
    check_defaults()
Example #2
0
    def __init__(self, distributions, logits=None, probs=None):
        """Validate and store the mixture components and weights.

        Parameters
        ----------
        distributions : ProbFlow, TensorFlow Probability, or PyTorch distribution
            The mixture components.
        logits : tensor-like, optional
            Mixture weights, given as logits.
        probs : tensor-like, optional
            Mixture weights, given as probabilities.

        Raises
        ------
        ValueError
            If neither ``logits`` nor ``probs`` is passed.
        TypeError
            If ``distributions`` is neither a ProbFlow distribution nor a
            distribution of the active backend.
        """

        # At least one way of specifying the mixture weights is required
        if probs is None and logits is None:
            raise ValueError("must pass either logits or probs")

        # Validate whichever weights were supplied (probs is checked first)
        for weights, label in ((probs, "probs"), (logits, "logits")):
            if weights is not None:
                ensure_tensor_like(weights, label)

        # Non-ProbFlow components must be native distributions of the backend
        if not isinstance(distributions, BaseDistribution):
            if get_backend() == "pytorch":
                import torch.distributions as tod

                native_base = tod.Distribution
                message = "requires either a ProbFlow or PyTorch distribution"
            else:
                from tensorflow_probability import distributions as tfd

                native_base = tfd.Distribution
                message = "requires either a ProbFlow or TensorFlow distribution"
            if not isinstance(distributions, native_base):
                raise TypeError(message)

        # Keep the constructor arguments as attributes
        self.distributions = distributions
        self.logits = logits
        self.probs = probs
Example #3
0
def to_default_dtype(x):
    """Cast ``x`` to the datatype returned by ``get_datatype()``.

    Uses ``Tensor.type`` with the PyTorch backend and ``tf.cast`` otherwise.
    """
    if get_backend() != "pytorch":
        import tensorflow as tf

        return tf.cast(x, get_datatype())

    return x.type(get_datatype())
Example #4
0
 def train_step(self, x_data, y_data):
     """Perform one training step.

     Runs ``self._train_fn`` on one batch and adds the returned ELBO value to
     the running total stored in ``self._current_elbo``.

     Parameters
     ----------
     x_data
         Batch of input data passed to the training function.
     y_data
         Batch of target data passed to the training function.
     """
     elbo = self._train_fn(x_data, y_data)
     if get_backend() == "pytorch":
         # Tensor.numpy() only works on CPU tensors that do not require grad:
         # detach from the autograd graph and move to host memory first
         # (.cpu() is a no-op for tensors already on the CPU)
         self._current_elbo += elbo.detach().cpu().numpy()
     else:
         self._current_elbo += elbo.numpy()
Example #5
0
def log_cholesky_transform(x):
    r"""Perform the log cholesky transform on a vector of values.

    This turns a vector of :math:`\frac{N(N+1)}{2}` unconstrained values into a
    valid :math:`N \times N` covariance matrix. The values fill a lower
    triangular matrix :math:`E`, the diagonal of :math:`E` is exponentiated
    (so it is strictly positive), and :math:`E E^T` is returned.

    Parameters
    ----------
    x : Tensor
        The :math:`\frac{N(N+1)}{2}` unconstrained values. With the
        TensorFlow backend, leading batch dimensions are also supported.

    Returns
    -------
    cov : Tensor
        The :math:`N \times N` covariance matrix. With TensorFlow, it is
        batched if ``x`` had leading batch dimensions.

    Notes
    -----
    NOTE(review): the backends fill the lower triangle in different orders.
    ``torch.tril_indices`` is row-major, while ``tfp.math.fill_triangular``
    uses its own ordering, so the same ``x`` gives different matrices on the
    two backends.

    References
    ----------

    - Jose C. Pinheiro & Douglas M. Bates.  `Unconstrained Parameterizations
      for Variance-Covariance Matrices
      <https://dx.doi.org/10.1007/BF00140873>`_ *Statistics and Computing*,
      1996.
    """

    if get_backend() == "pytorch":
        import numpy as np
        import torch

        # Solve numel(x) = N(N+1)/2 for the matrix size N
        N = int((np.sqrt(1 + 8 * torch.numel(x)) - 1) / 2)
        E = torch.zeros((N, N), dtype=get_datatype())
        tril_ix = torch.tril_indices(row=N, col=N, offset=0)
        E[..., tril_ix[0], tril_ix[1]] = x
        # Exponentiate the diagonal so it is strictly positive
        E[..., range(N), range(N)] = torch.exp(torch.diagonal(E))
        return E @ torch.transpose(E, -1, -2)
    else:
        import tensorflow as tf
        import tensorflow_probability as tfp

        E = tfp.math.fill_triangular(x)
        # diag_part and matrix_transpose act on the last two axes only, so
        # batched inputs work. tensor_diag_part and tf.transpose do not
        # handle them; for a single matrix the results are identical.
        E = tf.linalg.set_diag(E, tf.exp(tf.linalg.diag_part(E)))
        return E @ tf.linalg.matrix_transpose(E)
Example #6
0
    def __call__(self):
        """Get the distribution object from the backend.

        Builds a mixture whose components are ``self.distributions``. The
        mixing weights come from the stored logits (used when set) or probs,
        broadcast to the components' batch shape.

        Returns
        -------
        tfd.MixtureSameFamily
            The TensorFlow Probability mixture distribution.

        Raises
        ------
        NotImplementedError
            With the PyTorch backend, which is not supported here.
        """
        if get_backend() == "pytorch":
            raise NotImplementedError

        import tensorflow as tf
        from tensorflow_probability import distributions as tfd

        # Convert ProbFlow components into a local TFP distribution instead of
        # overwriting self.distributions. Each call then rebuilds the
        # components from the ProbFlow object rather than reusing a cached
        # backend object.
        components = self.distributions
        if isinstance(components, BaseDistribution):
            components = components()

        # Broadcast the mixture weights to the components' batch shape
        shape = components.batch_shape
        if self.logits is not None:
            weights = {"logits": tf.broadcast_to(self["logits"], shape)}
        else:
            weights = {"probs": tf.broadcast_to(self["probs"], shape)}

        # Return TFP distribution object
        return tfd.MixtureSameFamily(tfd.Categorical(**weights), components)
Example #7
0
def kl_divergence(P, Q):
    """Compute the Kullback–Leibler divergence between two distributions.

    ProbFlow distributions are first turned into backend distributions by
    calling them. The divergence itself is computed by the active backend.

    Parameters
    ----------
    P : ProbFlow distribution, |tfp.Distribution| or |torch.Distribution|
        The first distribution
    Q : ProbFlow distribution, |tfp.Distribution| or |torch.Distribution|
        The second distribution

    Returns
    -------
    kld : Tensor
        The Kullback–Leibler divergence between P and Q (KL(P || Q))
    """

    # Unwrap ProbFlow distributions into backend objects (P first, then Q)
    P, Q = [d() if isinstance(d, BaseDistribution) else d for d in (P, Q)]

    if get_backend() != "pytorch":
        import tensorflow_probability as tfp

        return tfp.distributions.kl_divergence(P, Q)

    import torch

    return torch.distributions.kl.kl_divergence(P, Q)
Example #8
0
    def __call__(self, x):
        """Perform the forward pass.

        Parameters
        ----------
        x : Tensor
            Layer input, multiplied on the right by the weights.
            NOTE(review): assumed to be ``(batch, d_in)``; confirm.

        Returns
        -------
        Tensor
            ``x @ weights + bias``. When ``get_flipout()`` is true, the weight
            and bias noise is estimated with the Flipout estimator instead of
            sampling the parameters directly.

        Raises
        ------
        NotImplementedError
            If flipout is enabled with the PyTorch backend.
        """

        # Without Flipout: sample the parameters and apply them directly
        if not get_flipout():
            return x @ self.weights() + self.bias()

        # Flipout is only implemented for TensorFlow
        if get_backend() == "pytorch":
            raise NotImplementedError

        import tensorflow as tf
        import tensorflow_probability as tfp

        # Use the runtime batch size, because the static x.shape[0] can be None
        # inside tf.function. Use x's dtype, because the sampling ops default to
        # float32 and would otherwise break float64 models.
        batch_size = tf.shape(x)[0]
        dtype = x.dtype

        # Flipout-estimated weight samples:
        #   x @ loc + r * ((x * s) @ (scale * eps))
        s = tfp.math.random_rademacher(tf.shape(x), dtype=dtype)
        r = tfp.math.random_rademacher([batch_size, self.d_out], dtype=dtype)
        norm_samples = tf.random.normal([self.d_in, self.d_out], dtype=dtype)
        w_samples = self.weights.variables["scale"] * norm_samples
        w_noise = r * ((x * s) @ w_samples)
        w_outputs = x @ self.weights.variables["loc"] + w_noise

        # Flipout-estimated bias samples
        r = tfp.math.random_rademacher([batch_size, self.d_out], dtype=dtype)
        norm_samples = tf.random.normal([self.d_out], dtype=dtype)
        b_samples = self.bias.variables["scale"] * norm_samples
        b_outputs = self.bias.variables["loc"] + r * b_samples

        return w_outputs + b_outputs
Example #9
0
def square(val):
    """Power of 2.

    Uses Python's ``**`` operator with PyTorch and ``tf.math.square`` with
    TensorFlow.
    """
    if get_backend() != "pytorch":
        import tensorflow as tf

        return tf.math.square(val)
    return val**2
Example #10
0
def copy_tensor(x):
    """Copy a tensor, detaching it from the gradient/backend/etc/etc.

    Both backends return a new tensor with the values of ``x`` through which
    no gradient flows back to ``x``.
    """
    if get_backend() == "pytorch":
        return x.detach().clone()
    else:
        import tensorflow as tf

        # tf.identity alone still propagates gradients. stop_gradient makes
        # the TensorFlow copy detached, matching the PyTorch branch.
        return tf.stop_gradient(tf.identity(x))
Example #11
0
 def set_learning_rate(self, lr):
     """Set the learning rate used by this model's optimizer.

     Parameters
     ----------
     lr : float
         The new learning rate, stored on ``self._learning_rate``.

     Raises
     ------
     TypeError
         If ``lr`` is not a ``float``.
     """
     if isinstance(lr, float):
         self._learning_rate = lr
     else:
         raise TypeError("lr must be a float")

     # With PyTorch, push the new rate into every optimizer parameter group
     if get_backend() == "pytorch":
         for group in self._optimizer.param_groups:
             group["lr"] = self._learning_rate
Example #12
0
def gather(vals, inds, axis=0):
    """Gather values by index.

    Selects the entries of ``vals`` at positions ``inds`` along ``axis``,
    using ``torch.index_select`` or ``tf.gather``.
    """
    if get_backend() != "pytorch":
        import tensorflow as tf

        return tf.gather(vals, inds, axis=axis)

    import torch

    return torch.index_select(vals, axis, inds)
Example #13
0
def sigmoid(val):
    r"""Sigmoid function, :math:`1 / (1 + e^{-x})`, applied element-wise."""
    if get_backend() == "pytorch":
        import torch

        # Call the functional op directly instead of constructing an
        # nn.Sigmoid module on every call (nn.Sigmoid delegates to it anyway)
        return torch.sigmoid(val)
    else:
        import tensorflow as tf

        return tf.math.sigmoid(val)
Example #14
0
def softplus(val):
    r"""Softplus function, :math:`\log(1 + e^x)`, applied element-wise.

    A smooth approximation of linear rectification (ReLU).
    """
    if get_backend() == "pytorch":
        import torch

        # Functional form with the same defaults as torch.nn.Softplus
        # (beta=1, threshold=20), without building a module on every call
        return torch.nn.functional.softplus(val)
    else:
        import tensorflow as tf

        return tf.math.softplus(val)
Example #15
0
def relu(val):
    r"""Linear rectification, :math:`\max(0, x)`, applied element-wise."""
    if get_backend() == "pytorch":
        import torch

        # Call the functional op directly instead of constructing an nn.ReLU
        # module on every call
        return torch.relu(val)
    else:
        import tensorflow as tf

        return tf.nn.relu(val)
Example #16
0
def exp(val):
    """The natural exponent, computed element-wise by the active backend."""
    if get_backend() != "pytorch":
        import tensorflow as tf

        return tf.exp(val)

    import torch

    return torch.exp(val)
Example #17
0
def sqrt(val):
    """The square root, computed element-wise by the active backend."""
    backend = get_backend()
    if backend == "pytorch":
        import torch

        result = torch.sqrt(val)
    else:
        import tensorflow as tf

        result = tf.sqrt(val)
    return result
Example #18
0
def abs(val):
    """Absolute value, computed element-wise by the active backend.

    Note that this top-level name shadows the builtin ``abs`` in this module.
    """
    if get_backend() != "pytorch":
        import tensorflow as tf

        return tf.math.abs(val)

    import torch

    return torch.abs(val)
Example #19
0
def cat(vals, axis=0):
    """Concatenate tensors.

    Parameters
    ----------
    vals : sequence of Tensor
        The tensors to join.
    axis : int
        Axis along which to concatenate (passed as ``dim`` to PyTorch).
    """
    backend = get_backend()
    if backend == "pytorch":
        import torch

        joined = torch.cat(vals, dim=axis)
    else:
        import tensorflow as tf

        joined = tf.concat(vals, axis=axis)
    return joined
Example #20
0
def prod(val, axis=-1):
    """The product of ``val`` along ``axis`` (the last axis by default)."""
    if get_backend() != "pytorch":
        import tensorflow as tf

        return tf.reduce_prod(val, axis=axis)

    import torch

    return torch.prod(val, dim=axis)
Example #21
0
def eye(dims):
    """Identity matrix.

    Parameters
    ----------
    dims : int
        Number of rows (and columns) of the square identity matrix.
    """
    if get_backend() != "pytorch":
        import tensorflow as tf

        return tf.eye(dims)

    import torch

    return torch.eye(dims)
Example #22
0
def zeros(shape):
    """Tensor full of zeros with the given ``shape``, from the active backend."""
    backend = get_backend()
    if backend == "pytorch":
        import torch

        out = torch.zeros(shape)
    else:
        import tensorflow as tf

        out = tf.zeros(shape)
    return out
Example #23
0
def squeeze(val):
    """Remove all singleton (size-1) dimensions from ``val``."""
    if get_backend() != "pytorch":
        import tensorflow as tf

        return tf.squeeze(val)

    import torch

    return torch.squeeze(val)
Example #24
0
def test_backend():
    """Tests setting and getting the backend."""

    # Default should be tensorflow
    assert settings.get_backend() == "tensorflow"

    # Switching to pytorch and then back to tensorflow should each take effect
    for backend in ("pytorch", "tensorflow"):
        settings.set_backend(backend)
        assert settings.get_backend() == backend

    # Unknown backend names are rejected...
    with pytest.raises(ValueError):
        settings.set_backend("lalala")

    # ...and so are non-str arguments
    with pytest.raises(TypeError):
        settings.set_backend(1)
Example #25
0
def new_variable(initial_values):
    """Get a new variable with the current backend, and initialize it.

    Returns a ``torch.nn.Parameter`` with PyTorch, or a ``tf.Variable``
    otherwise, holding ``initial_values``.
    """
    if get_backend() != "pytorch":
        import tensorflow as tf

        return tf.Variable(initial_values)

    import torch

    return torch.nn.Parameter(initial_values)
Example #26
0
def reshape(x, new_shape):
    """Reshape a tensor ``x`` to ``new_shape`` with the active backend."""
    if get_backend() != "pytorch":
        import tensorflow as tf

        return tf.reshape(x, new_shape)

    import torch

    # Pass the shape to torch.reshape as a tuple
    target = tuple(new_shape)
    return torch.reshape(x, target)
Example #27
0
def mean(val, axis=-1):
    """The mean of ``val`` along ``axis`` (the last axis by default)."""
    backend = get_backend()
    if backend == "pytorch":
        import torch

        avg = torch.mean(val, dim=axis)
    else:
        import tensorflow as tf

        avg = tf.reduce_mean(val, axis=axis)
    return avg
Example #28
0
def round(val):
    """Round to the closest integer, element-wise.

    Note that this top-level name shadows the builtin ``round`` in this module.
    """
    if get_backend() != "pytorch":
        import tensorflow as tf

        return tf.math.round(val)

    import torch

    return torch.round(val)
Example #29
0
    def __call__(self):
        """Get the distribution object from the backend.

        Returns a Dirichlet distribution parameterized by
        ``self["concentration"]``, built with ``torch.distributions`` or
        TensorFlow Probability according to the active backend.
        """
        if get_backend() != "pytorch":
            from tensorflow_probability import distributions as tfd

            return tfd.Dirichlet(self["concentration"])

        import torch.distributions as tod

        return tod.dirichlet.Dirichlet(self["concentration"])
Example #30
0
def std(val, axis=-1):
    """The uncorrected sample standard deviation.

    On both backends, the variance is normalized by the number of elements
    :math:`N` along ``axis`` (not :math:`N - 1`).

    Parameters
    ----------
    val : Tensor
        Input values.
    axis : int
        Axis to reduce over (the last axis by default).
    """
    if get_backend() == "pytorch":
        import torch

        # torch.std applies Bessel's correction (N - 1) by default. Disable it
        # so the result matches tf.math.reduce_std and this docstring.
        return torch.std(val, dim=axis, unbiased=False)
    else:
        import tensorflow as tf

        return tf.math.reduce_std(val, axis=axis)