def test_sampling():
    """Check that ``settings.Sampling`` sets and then restores sampling state."""

    def check_defaults():
        # State expected whenever no Sampling context is active
        assert settings.get_backend() == 'tensorflow'
        assert settings.get_samples() is None
        assert settings.get_flipout() is False

    # Nothing has been entered yet, so the defaults apply
    check_defaults()

    # Entering with no arguments gives one sample and no flipout
    with settings.Sampling():
        assert settings.get_samples() == 1
        assert settings.get_flipout() is False

    # Leaving the context restores the defaults
    check_defaults()

    # Keyword arguments override samples and flipout inside the context
    with settings.Sampling(n=100, flipout=True):
        assert settings.get_samples() == 100
        assert settings.get_flipout() is True

    # ...and the defaults come back once more after __exit__
    check_defaults()
Exemple #2
0
 def __call__(self):
     """Build the backend distribution object for this distribution.

     Returns
     -------
     ``tfd.Deterministic(self.loc)`` when the backend is not PyTorch.

     Raises
     ------
     NotImplementedError
         If the backend is ``'pytorch'``.
     """
     # No PyTorch implementation is provided in this method
     if get_backend() == 'pytorch':
         raise NotImplementedError

     from tensorflow_probability import distributions as tfd
     return tfd.Deterministic(self.loc)
Exemple #3
0
def kl_divergence(P, Q):
    """Kullback–Leibler divergence KL(P || Q) between two distributions.

    Arguments that are ``BaseDistribution`` instances are first called to
    obtain their backend distribution object.

    Parameters
    ----------
    P : |tfp.Distribution| or |torch.Distribution|
        Distribution on the left-hand side of the divergence
    Q : |tfp.Distribution| or |torch.Distribution|
        Distribution on the right-hand side of the divergence

    Returns
    -------
    kld : Tensor
        KL(P || Q) as computed by the active backend
    """

    # Unwrap to backend distributions where necessary (P first, then Q)
    p_dist = P() if isinstance(P, BaseDistribution) else P
    q_dist = Q() if isinstance(Q, BaseDistribution) else Q

    # Delegate the actual computation to the backend library
    if get_backend() == 'pytorch':
        import torch
        return torch.distributions.kl.kl_divergence(p_dist, q_dist)

    import tensorflow_probability as tfp
    return tfp.distributions.kl_divergence(p_dist, q_dist)
Exemple #4
0
    def __call__(self, x):
        """Run the forward pass of the layer on ``x``.

        Without Flipout the result is ``x @ self.weights() + self.bias()``.
        With Flipout (TensorFlow backend only) the ``'loc'`` and ``'scale'``
        variables of the weights and bias are used directly: the output is
        the ``loc`` term plus scaled normal noise multiplied by random
        Rademacher signs.

        Raises
        ------
        NotImplementedError
            If Flipout is enabled and the backend is ``'pytorch'``.
        """

        # Plain path: Flipout not requested
        if not get_flipout():
            return x @ self.weights() + self.bias()

        # Flipout has no PyTorch implementation here
        if get_backend() == 'pytorch':
            raise NotImplementedError

        import tensorflow as tf
        import tensorflow_probability as tfp

        rademacher = tfp.python.math.random_rademacher

        # Weight contribution: loc term plus sign-flipped perturbation
        in_signs = rademacher(tf.shape(x))
        out_signs = rademacher([x.shape[0], self.d_out])
        w_eps = tf.random.normal([self.d_in, self.d_out])
        w_perturb = self.weights.variables['scale'] * w_eps
        w_flipped = out_signs * ((x * in_signs) @ w_perturb)
        w_term = x @ self.weights.variables['loc'] + w_flipped

        # Bias contribution: loc plus sign-flipped perturbation
        bias_signs = rademacher([x.shape[0], self.d_out])
        b_eps = tf.random.normal([self.d_out])
        b_perturb = self.bias.variables['scale'] * b_eps
        b_term = self.bias.variables['loc'] + bias_signs * b_perturb

        return w_term + b_term
Exemple #5
0
def sigmoid(val):
    """Apply the sigmoid function to ``val`` using the active backend.

    Parameters
    ----------
    val : Tensor
        Input handed to the backend's sigmoid

    Returns
    -------
    Tensor
        ``torch.nn.Sigmoid()(val)`` for the PyTorch backend, otherwise
        ``tf.math.sigmoid(val)``
    """
    if get_backend() == 'pytorch':
        import torch
        activation = torch.nn.Sigmoid()
        return activation(val)

    import tensorflow as tf
    return tf.math.sigmoid(val)
Exemple #6
0
def gather(vals, inds, axis=0):
    """Select entries of ``vals`` at indices ``inds`` along ``axis``.

    Parameters
    ----------
    vals : Tensor
        Values to pick from
    inds : Tensor
        Indices of the entries to pick
    axis : int
        Dimension along which to index.  Default = 0

    Returns
    -------
    Tensor
        Result of ``torch.index_select`` (PyTorch backend) or
        ``tf.gather`` (otherwise)
    """
    if get_backend() == 'pytorch':
        import torch
        # index_select takes the dimension before the indices
        return torch.index_select(vals, axis, inds)

    import tensorflow as tf
    return tf.gather(vals, inds, axis=axis)
Exemple #7
0
def expand_dims(val, axis):
    """Insert a new dimension into ``val`` at position ``axis``.

    Parameters
    ----------
    val : Tensor
        Tensor to expand
    axis : int
        Position of the new dimension

    Returns
    -------
    Tensor
        Result of ``torch.unsqueeze`` (PyTorch backend) or
        ``tf.expand_dims`` (otherwise)
    """
    if get_backend() == 'pytorch':
        import torch
        expand = torch.unsqueeze
    else:
        import tensorflow as tf
        expand = tf.expand_dims
    return expand(val, axis)
Exemple #8
0
def softplus(val):
    """Apply the softplus function to ``val`` using the active backend.

    Parameters
    ----------
    val : Tensor
        Input handed to the backend's softplus

    Returns
    -------
    Tensor
        ``torch.nn.Softplus()(val)`` for the PyTorch backend, otherwise
        ``tf.math.softplus(val)``
    """
    if get_backend() == 'pytorch':
        import torch
        activation = torch.nn.Softplus()
        return activation(val)

    import tensorflow as tf
    return tf.math.softplus(val)
Exemple #9
0
 def __call__(self):
     """Build the backend Dirichlet distribution from ``self.concentration``.

     Returns
     -------
     ``torch.distributions.dirichlet.Dirichlet`` for the PyTorch backend,
     otherwise ``tfd.Dirichlet``.
     """
     # Pick the backend class first, then construct it once
     if get_backend() == 'pytorch':
         import torch.distributions as tod
         dist_cls = tod.dirichlet.Dirichlet
     else:
         from tensorflow_probability import distributions as tfd
         dist_cls = tfd.Dirichlet
     return dist_cls(self.concentration)
Exemple #10
0
def exp(val):
    """Compute the exponential of ``val`` using the active backend.

    Parameters
    ----------
    val : Tensor
        Input handed to the backend's ``exp``

    Returns
    -------
    Tensor
        ``torch.exp(val)`` for the PyTorch backend, otherwise
        ``tf.exp(val)``
    """
    if get_backend() == 'pytorch':
        import torch
        backend_exp = torch.exp
    else:
        import tensorflow as tf
        backend_exp = tf.exp
    return backend_exp(val)
Exemple #11
0
def squeeze(val):
    """Drop singleton dimensions from ``val`` using the active backend.

    Parameters
    ----------
    val : Tensor
        Tensor to squeeze

    Returns
    -------
    Tensor
        ``torch.squeeze(val)`` for the PyTorch backend, otherwise
        ``tf.squeeze(val)``
    """
    if get_backend() == 'pytorch':
        import torch
        return torch.squeeze(val)

    import tensorflow as tf
    return tf.squeeze(val)
Exemple #12
0
def zeros(shape):
    """Create a zero-filled tensor of the given ``shape``.

    Parameters
    ----------
    shape
        Shape forwarded unchanged to the backend's ``zeros``

    Returns
    -------
    Tensor
        ``torch.zeros(shape)`` for the PyTorch backend, otherwise
        ``tf.zeros(shape)``
    """
    if get_backend() == 'pytorch':
        import torch
        make_zeros = torch.zeros
    else:
        import tensorflow as tf
        make_zeros = tf.zeros
    return make_zeros(shape)
Exemple #13
0
def std(val, axis=-1):
    """The uncorrected sample standard deviation.

    Both backends compute the standard deviation without Bessel's
    correction (i.e. dividing by N rather than N - 1).

    Parameters
    ----------
    val : Tensor
        Values to compute the standard deviation of
    axis : int
        Dimension to reduce over.  Default = -1

    Returns
    -------
    Tensor
        Standard deviation of ``val`` along ``axis``
    """
    if get_backend() == 'pytorch':
        import torch
        # torch.std applies Bessel's correction by default; turn it off so
        # the result matches tf.math.reduce_std and the documented contract
        return torch.std(val, dim=axis, unbiased=False)
    else:
        import tensorflow as tf
        return tf.math.reduce_std(val, axis=axis)
Exemple #14
0
def round(val):
    """Round ``val`` to integer values using the active backend.

    Parameters
    ----------
    val : Tensor
        Values to round

    Returns
    -------
    Tensor
        ``torch.round(val)`` for the PyTorch backend, otherwise
        ``tf.math.round(val)``
    """
    if get_backend() == 'pytorch':
        import torch
        return torch.round(val)

    import tensorflow as tf
    return tf.math.round(val)
Exemple #15
0
def mean(val, axis=-1):
    """Average ``val`` along ``axis`` using the active backend.

    Parameters
    ----------
    val : Tensor
        Values to average
    axis : int
        Dimension to reduce over.  Default = -1

    Returns
    -------
    Tensor
        ``torch.mean(val, dim=axis)`` for the PyTorch backend, otherwise
        ``tf.reduce_mean(val, axis=axis)``
    """
    if get_backend() == 'pytorch':
        import torch
        return torch.mean(val, dim=axis)

    import tensorflow as tf
    return tf.reduce_mean(val, axis=axis)
Exemple #16
0
def prod(val, axis=-1):
    """Multiply the entries of ``val`` along ``axis``.

    Parameters
    ----------
    val : Tensor
        Values to multiply together
    axis : int
        Dimension to reduce over.  Default = -1

    Returns
    -------
    Tensor
        ``torch.prod(val, dim=axis)`` for the PyTorch backend, otherwise
        ``tf.reduce_prod(val, axis=axis)``
    """
    if get_backend() == 'pytorch':
        import torch
        return torch.prod(val, dim=axis)

    import tensorflow as tf
    return tf.reduce_prod(val, axis=axis)
Exemple #17
0
def eye(dims):
    """Create an identity matrix using the active backend.

    Parameters
    ----------
    dims
        Size argument forwarded unchanged to the backend's ``eye``

    Returns
    -------
    Tensor
        ``torch.eye(dims)`` for the PyTorch backend, otherwise
        ``tf.eye(dims)``
    """
    if get_backend() == 'pytorch':
        import torch
        make_eye = torch.eye
    else:
        import tensorflow as tf
        make_eye = tf.eye
    return make_eye(dims)
Exemple #18
0
def cat(vals, axis=0):
    """Join the tensors in ``vals`` along ``axis``.

    Parameters
    ----------
    vals
        Tensors to concatenate, forwarded unchanged to the backend
    axis : int
        Dimension along which to concatenate.  Default = 0

    Returns
    -------
    Tensor
        ``torch.cat(vals, dim=axis)`` for the PyTorch backend, otherwise
        ``tf.concat(vals, axis=axis)``
    """
    if get_backend() == 'pytorch':
        import torch
        return torch.cat(vals, dim=axis)

    import tensorflow as tf
    return tf.concat(vals, axis=axis)
Exemple #19
0
 def __call__(self):
     """Build the backend Normal distribution from ``loc`` and ``scale``.

     The parameters are read via ``self['loc']`` and ``self['scale']``.

     Returns
     -------
     ``torch.distributions.normal.Normal`` for the PyTorch backend,
     otherwise ``tfd.Normal``.
     """
     if get_backend() == 'pytorch':
         import torch.distributions as tod
         loc, scale = self['loc'], self['scale']
         return tod.normal.Normal(loc, scale)

     from tensorflow_probability import distributions as tfd
     loc, scale = self['loc'], self['scale']
     return tfd.Normal(loc, scale)
Exemple #20
0
 def __call__(self):
     """Build the backend Cauchy distribution from ``self.loc`` and
     ``self.scale``.

     Returns
     -------
     ``torch.distributions.cauchy.Cauchy`` for the PyTorch backend,
     otherwise ``tfd.Cauchy``.
     """
     # Pick the backend class first, then construct it once
     if get_backend() == 'pytorch':
         import torch.distributions as tod
         dist_cls = tod.cauchy.Cauchy
     else:
         from tensorflow_probability import distributions as tfd
         dist_cls = tfd.Cauchy
     return dist_cls(self.loc, self.scale)
Exemple #21
0
 def __call__(self):
     """Build the backend Gamma distribution from ``concentration`` and
     ``rate``.

     The parameters are read via ``self['concentration']`` and
     ``self['rate']``.

     Returns
     -------
     ``torch.distributions.gamma.Gamma`` for the PyTorch backend,
     otherwise ``tfd.Gamma``.
     """
     if get_backend() == 'pytorch':
         import torch.distributions as tod
         concentration, rate = self['concentration'], self['rate']
         return tod.gamma.Gamma(concentration, rate)

     from tensorflow_probability import distributions as tfd
     concentration, rate = self['concentration'], self['rate']
     return tfd.Gamma(concentration, rate)
Exemple #22
0
def abs(val):
    """Take the absolute value of ``val`` using the active backend.

    Parameters
    ----------
    val : Tensor
        Input handed to the backend's ``abs``

    Returns
    -------
    Tensor
        ``torch.abs(val)`` for the PyTorch backend, otherwise
        ``tf.math.abs(val)``
    """
    if get_backend() == 'pytorch':
        import torch
        backend_abs = torch.abs
    else:
        import tensorflow as tf
        backend_abs = tf.math.abs
    return backend_abs(val)
Exemple #23
0
 def __call__(self):
     """Build the backend Deterministic distribution from ``self['loc']``.

     Returns
     -------
     An instance of the class returned by ``get_TorchDeterministic()`` for
     the PyTorch backend, otherwise ``tfd.Deterministic``.
     """
     if get_backend() == 'pytorch':
         # The class is obtained from a helper rather than imported here
         return get_TorchDeterministic()(self['loc'])

     from tensorflow_probability import distributions as tfd
     return tfd.Deterministic(self['loc'])
Exemple #24
0
 def __call__(self):
     """Build the backend InverseGamma distribution from
     ``self.concentration`` and ``self.scale``.

     Returns
     -------
     ``tfd.InverseGamma`` when the backend is not PyTorch.

     Raises
     ------
     NotImplementedError
         If the backend is ``'pytorch'`` (after importing
         ``torch.distributions``).
     """
     if get_backend() == 'pytorch':
         # The import is unused, but it still runs before the raise
         import torch.distributions as tod
         raise NotImplementedError

     from tensorflow_probability import distributions as tfd
     return tfd.InverseGamma(self.concentration, self.scale)
Exemple #25
0
def square(val):
    """Raise ``val`` to the power of 2.

    Parameters
    ----------
    val : Tensor
        Values to square

    Returns
    -------
    Tensor
        ``val**2`` for the PyTorch backend, otherwise
        ``tf.math.square(val)``
    """
    if get_backend() == 'pytorch':
        # torch is imported, but the power operator does the actual work
        import torch
        return val ** 2

    import tensorflow as tf
    return tf.math.square(val)
Exemple #26
0
 def __call__(self):
     """Build the backend Poisson distribution from ``self.rate``.

     Returns
     -------
     ``torch.distributions.poisson.Poisson`` for the PyTorch backend,
     otherwise ``tfd.Poisson``.
     """
     # Pick the backend class first, then construct it once
     if get_backend() == 'pytorch':
         import torch.distributions as tod
         dist_cls = tod.poisson.Poisson
     else:
         from tensorflow_probability import distributions as tfd
         dist_cls = tfd.Poisson
     return dist_cls(self.rate)
Exemple #27
0
def sqrt(val):
    """Take the square root of ``val`` using the active backend.

    Parameters
    ----------
    val : Tensor
        Input handed to the backend's ``sqrt``

    Returns
    -------
    Tensor
        ``torch.sqrt(val)`` for the PyTorch backend, otherwise
        ``tf.sqrt(val)``
    """
    if get_backend() == 'pytorch':
        import torch
        return torch.sqrt(val)

    import tensorflow as tf
    return tf.sqrt(val)
Exemple #28
0
def relu(val):
    """Apply linear rectification (ReLU) to ``val``.

    Parameters
    ----------
    val : Tensor
        Input handed to the backend's ReLU

    Returns
    -------
    Tensor
        ``torch.nn.ReLU()(val)`` for the PyTorch backend, otherwise
        ``tf.nn.relu(val)``
    """
    if get_backend() == 'pytorch':
        import torch
        activation = torch.nn.ReLU()
        return activation(val)

    import tensorflow as tf
    return tf.nn.relu(val)
Exemple #29
0
def sum(val, axis=-1):
    """Add up the entries of ``val`` along ``axis``.

    Parameters
    ----------
    val : Tensor
        Values to add together
    axis : int
        Dimension to reduce over.  Default = -1

    Returns
    -------
    Tensor
        ``torch.sum(val, dim=axis)`` for the PyTorch backend, otherwise
        ``tf.reduce_sum(val, axis=axis)``
    """
    if get_backend() == 'pytorch':
        import torch
        return torch.sum(val, dim=axis)

    import tensorflow as tf
    return tf.reduce_sum(val, axis=axis)
Exemple #30
0
 def __call__(self):
     """Build the backend Categorical distribution.

     Both ``self.logits`` and ``self.probs`` are forwarded as keyword
     arguments to the backend class.

     Returns
     -------
     ``torch.distributions.categorical.Categorical`` for the PyTorch
     backend, otherwise ``tfd.Categorical``.
     """
     # Pick the backend class first, then construct it once
     if get_backend() == 'pytorch':
         import torch.distributions as tod
         dist_cls = tod.categorical.Categorical
     else:
         from tensorflow_probability import distributions as tfd
         dist_cls = tfd.Categorical
     return dist_cls(logits=self.logits, probs=self.probs)