def test_sampling():
    """Tests the Sampling context manager"""

    def check_defaults():
        """Assert the settings that hold outside any Sampling context."""
        assert settings.get_backend() == 'tensorflow'
        assert settings.get_samples() is None
        assert settings.get_flipout() is False

    # Defaults before any sampling context has been entered
    check_defaults()

    # With no arguments the context uses one sample and no flipout
    with settings.Sampling():
        assert settings.get_samples() == 1
        assert settings.get_flipout() is False

    # Leaving the context must restore the defaults
    check_defaults()

    # Sample count and flipout can be overridden with keyword arguments
    with settings.Sampling(n=100, flipout=True):
        assert settings.get_samples() == 100
        assert settings.get_flipout() is True

    # __exit__ must restore the defaults once more
    check_defaults()
示例#2
0
 def __call__(self):
     """Return the backend-native deterministic distribution.

     Returns
     -------
     tfp.distributions.Deterministic
         A deterministic distribution located at ``self.loc``, built when
         the backend is anything other than ``'pytorch'``.

     Raises
     ------
     NotImplementedError
         When the backend is ``'pytorch'``.
     """
     if get_backend() != 'pytorch':
         from tensorflow_probability import distributions as tfd
         return tfd.Deterministic(self.loc)
     raise NotImplementedError
示例#3
0
文件: ops.py 项目: hanxirui/probflow
def kl_divergence(P, Q):
    """Compute the Kullback–Leibler divergence KL(P || Q).

    Parameters
    ----------
    P : |tfp.Distribution|, |torch.Distribution|, or BaseDistribution
        The first distribution. A ``BaseDistribution`` is called first to
        obtain its backend-native distribution object.
    Q : |tfp.Distribution|, |torch.Distribution|, or BaseDistribution
        The second distribution, unwrapped the same way as ``P``.

    Returns
    -------
    kld : Tensor
        The Kullback–Leibler divergence between P and Q (KL(P || Q))
    """

    # Unwrap BaseDistribution wrappers (P first, then Q) into backend objects
    P, Q = (
        dist() if isinstance(dist, BaseDistribution) else dist
        for dist in (P, Q)
    )

    # Hand the actual computation to the active backend
    if get_backend() != 'pytorch':
        import tensorflow_probability as tfp
        return tfp.distributions.kl_divergence(P, Q)

    import torch
    return torch.distributions.kl.kl_divergence(P, Q)
示例#4
0
    def __call__(self, x):
        """Perform the forward pass.

        Parameters
        ----------
        x : Tensor
            Layer input. It is matrix-multiplied by (d_in, d_out) weights, so
            it is presumably shaped (batch, d_in) — TODO confirm.

        Returns
        -------
        Tensor
            ``x @ W + b``. Without flipout, ``W`` and ``b`` come from calling
            ``self.weights()`` and ``self.bias()``. With flipout, the weight
            and bias perturbations are Flipout estimates built from the
            ``'loc'`` and ``'scale'`` variables of ``self.weights`` and
            ``self.bias``.

        Raises
        ------
        NotImplementedError
            If flipout is enabled while the backend is ``'pytorch'``.
        """

        # Without Flipout: use the weight and bias samples directly
        if not get_flipout():
            return x @ self.weights() + self.bias()

        # Flipout is only implemented for the TensorFlow backend
        if get_backend() == 'pytorch':
            raise NotImplementedError

        import tensorflow as tf
        import tensorflow_probability as tfp

        # Use the dynamic batch size: the static x.shape[0] is None whenever
        # the batch dimension is not known at trace time, which would make
        # the sign-flip shape below invalid.
        batch_size = tf.shape(x)[0]
        out_shape = [batch_size, self.d_out]

        # Flipout-estimated weight samples: one shared Gaussian weight
        # perturbation, sign-flipped per input element (s) and per output
        # element (r)
        s = tfp.math.random_rademacher(tf.shape(x))
        r = tfp.math.random_rademacher(out_shape)
        norm_samples = tf.random.normal([self.d_in, self.d_out])
        w_samples = self.weights.variables['scale'] * norm_samples
        w_noise = r * ((x * s) @ w_samples)
        w_outputs = x @ self.weights.variables['loc'] + w_noise

        # Flipout-estimated bias samples: one shared Gaussian bias
        # perturbation with a random sign per row and output unit
        r = tfp.math.random_rademacher(out_shape)
        norm_samples = tf.random.normal([self.d_out])
        b_samples = self.bias.variables['scale'] * norm_samples
        b_outputs = self.bias.variables['loc'] + r * b_samples

        return w_outputs + b_outputs
示例#5
0
文件: ops.py 项目: hanxirui/probflow
def sigmoid(val):
    """Sigmoid function.

    Computes ``1 / (1 + exp(-val))`` elementwise with the active backend.

    Parameters
    ----------
    val : Tensor
        Input values.

    Returns
    -------
    Tensor
        The elementwise logistic sigmoid of ``val``.
    """
    if get_backend() == 'pytorch':
        import torch
        # Functional form: avoids constructing an nn.Sigmoid module per call
        return torch.sigmoid(val)
    else:
        import tensorflow as tf
        return tf.math.sigmoid(val)
示例#6
0
文件: ops.py 项目: hanxirui/probflow
def gather(vals, inds, axis=0):
    """Gather values by index along one axis.

    Parameters
    ----------
    vals : Tensor
        Tensor to select from.
    inds : Tensor, list, or array of int
        Indices to select along ``axis``. With the PyTorch backend these must
        form a 1-D index (a ``torch.index_select`` restriction).
    axis : int
        Axis along which to select. Default = 0.

    Returns
    -------
    Tensor
        The slices of ``vals`` at positions ``inds`` along ``axis``.
    """
    if get_backend() == 'pytorch':
        import torch
        # torch.index_select needs an integer index *tensor* on vals' device,
        # whereas tf.gather also accepts plain Python lists / arrays.
        # Normalise the index so both backends accept the same inputs;
        # as_tensor does not copy an index that already matches.
        inds = torch.as_tensor(inds, dtype=torch.long, device=vals.device)
        return torch.index_select(vals, axis, inds)
    else:
        import tensorflow as tf
        return tf.gather(vals, inds, axis=axis)
示例#7
0
def expand_dims(val, axis):
    """Insert a new dimension of size one into a tensor.

    Parameters
    ----------
    val : Tensor
        Input tensor.
    axis : int
        Position at which the singleton dimension is inserted.

    Returns
    -------
    Tensor
        ``val`` with one extra dimension of length one.
    """
    if get_backend() != 'pytorch':
        import tensorflow as tf
        return tf.expand_dims(val, axis)
    import torch
    return torch.unsqueeze(val, axis)
示例#8
0
文件: ops.py 项目: hanxirui/probflow
def softplus(val):
    """Softplus function, ``log(1 + exp(val))``, applied elementwise.

    Parameters
    ----------
    val : Tensor
        Input values.

    Returns
    -------
    Tensor
        The elementwise softplus of ``val``.
    """
    if get_backend() == 'pytorch':
        import torch
        # Functional form with the same defaults as nn.Softplus (beta=1,
        # threshold=20), without building a module on every call
        return torch.nn.functional.softplus(val)
    else:
        import tensorflow as tf
        return tf.math.softplus(val)
示例#9
0
 def __call__(self):
     """Return the backend-native Dirichlet distribution.

     Returns
     -------
     torch.distributions.Dirichlet or tfp.distributions.Dirichlet
         A Dirichlet distribution with concentration ``self.concentration``,
         from PyTorch when the backend is ``'pytorch'`` and from TensorFlow
         Probability otherwise.
     """
     if get_backend() == 'pytorch':
         from torch.distributions.dirichlet import Dirichlet
     else:
         from tensorflow_probability import distributions as tfd
         Dirichlet = tfd.Dirichlet
     return Dirichlet(self.concentration)
示例#10
0
文件: ops.py 项目: hanxirui/probflow
def exp(val):
    """Elementwise natural exponential, ``e ** val``.

    Parameters
    ----------
    val : Tensor
        Exponents.

    Returns
    -------
    Tensor
        ``e`` raised to each element of ``val``.
    """
    if get_backend() == 'pytorch':
        import torch
        exp_fn = torch.exp
    else:
        import tensorflow as tf
        exp_fn = tf.exp
    return exp_fn(val)
示例#11
0
文件: ops.py 项目: hanxirui/probflow
def squeeze(val):
    """Drop every dimension of length one from a tensor.

    Parameters
    ----------
    val : Tensor
        Input tensor.

    Returns
    -------
    Tensor
        ``val`` without its singleton dimensions.
    """
    if get_backend() != 'pytorch':
        import tensorflow as tf
        return tf.squeeze(val)
    import torch
    return torch.squeeze(val)
示例#12
0
文件: ops.py 项目: hanxirui/probflow
def zeros(shape):
    """Create a tensor filled with zeros.

    Parameters
    ----------
    shape
        Shape of the result, passed unchanged to the backend's ``zeros``.

    Returns
    -------
    Tensor
        A zero-filled tensor of the requested shape.
    """
    if get_backend() != 'pytorch':
        import tensorflow as tf
        return tf.zeros(shape)

    import torch
    return torch.zeros(shape)
示例#13
0
文件: ops.py 项目: hanxirui/probflow
def std(val, axis=-1):
    """The uncorrected sample standard deviation.

    Both backends divide by N (the number of elements along ``axis``), not
    by N - 1.

    Parameters
    ----------
    val : Tensor
        Input values.
    axis : int
        Axis along which to compute the standard deviation. Default = -1.

    Returns
    -------
    Tensor
        The standard deviation of ``val`` reduced over ``axis``.
    """
    if get_backend() == 'pytorch':
        import torch
        # torch.std applies Bessel's correction (divides by N - 1) by
        # default, while tf.math.reduce_std divides by N. Disable it so both
        # backends return the same, documented, uncorrected statistic.
        return torch.std(val, dim=axis, unbiased=False)
    else:
        import tensorflow as tf
        return tf.math.reduce_std(val, axis=axis)
示例#14
0
文件: ops.py 项目: hanxirui/probflow
def round(val):
    """Round each element of a tensor to the nearest integer value.

    Parameters
    ----------
    val : Tensor
        Values to round.

    Returns
    -------
    Tensor
        The rounded values.
    """
    use_torch = get_backend() == 'pytorch'
    if use_torch:
        import torch
        rounded = torch.round(val)
    else:
        import tensorflow as tf
        rounded = tf.math.round(val)
    return rounded
示例#15
0
文件: ops.py 项目: hanxirui/probflow
def mean(val, axis=-1):
    """Arithmetic mean of a tensor along one axis.

    Parameters
    ----------
    val : Tensor
        Input values.
    axis : int
        Axis to average over. Default = -1 (the last axis).

    Returns
    -------
    Tensor
        ``val`` averaged over ``axis``.
    """
    if get_backend() == 'pytorch':
        import torch
        result = torch.mean(val, dim=axis)
    else:
        import tensorflow as tf
        result = tf.reduce_mean(val, axis=axis)
    return result
示例#16
0
文件: ops.py 项目: hanxirui/probflow
def prod(val, axis=-1):
    """Product of a tensor's elements along one axis.

    Parameters
    ----------
    val : Tensor
        Input values.
    axis : int
        Axis to multiply over. Default = -1 (the last axis).

    Returns
    -------
    Tensor
        ``val`` reduced by multiplication over ``axis``.
    """
    if get_backend() != 'pytorch':
        import tensorflow as tf
        return tf.reduce_prod(val, axis=axis)
    import torch
    return torch.prod(val, dim=axis)
示例#17
0
文件: ops.py 项目: hanxirui/probflow
def eye(dims):
    """Square identity matrix.

    Parameters
    ----------
    dims : int
        Number of rows (and columns) of the matrix.

    Returns
    -------
    Tensor
        A ``dims`` x ``dims`` identity matrix.
    """
    if get_backend() == 'pytorch':
        import torch
        make_identity = torch.eye
    else:
        import tensorflow as tf
        make_identity = tf.eye
    return make_identity(dims)
示例#18
0
文件: ops.py 项目: hanxirui/probflow
def cat(vals, axis=0):
    """Join a sequence of tensors along an existing axis.

    Parameters
    ----------
    vals : sequence of Tensor
        Tensors to concatenate.
    axis : int
        Axis along which to concatenate. Default = 0.

    Returns
    -------
    Tensor
        The concatenated tensor.
    """
    if get_backend() != 'pytorch':
        import tensorflow as tf
        return tf.concat(vals, axis=axis)
    import torch
    return torch.cat(vals, dim=axis)
示例#19
0
 def __call__(self):
     """Return the backend-native normal distribution.

     Returns
     -------
     torch.distributions.Normal or tfp.distributions.Normal
         Parameterized by ``self['loc']`` and ``self['scale']``; PyTorch's
         when the backend is ``'pytorch'``, TensorFlow Probability's
         otherwise.
     """
     if get_backend() != 'pytorch':
         from tensorflow_probability import distributions as tfd
         return tfd.Normal(self['loc'], self['scale'])
     import torch.distributions as tod
     return tod.normal.Normal(self['loc'], self['scale'])
示例#20
0
 def __call__(self):
     """Return the backend-native Cauchy distribution.

     Returns
     -------
     torch.distributions.Cauchy or tfp.distributions.Cauchy
         Parameterized by ``self.loc`` and ``self.scale``; PyTorch's when
         the backend is ``'pytorch'``, TensorFlow Probability's otherwise.
     """
     if get_backend() == 'pytorch':
         from torch.distributions.cauchy import Cauchy
     else:
         from tensorflow_probability import distributions as tfd
         Cauchy = tfd.Cauchy
     return Cauchy(self.loc, self.scale)
示例#21
0
 def __call__(self):
     """Return the backend-native gamma distribution.

     Returns
     -------
     torch.distributions.Gamma or tfp.distributions.Gamma
         Parameterized by ``self['concentration']`` and ``self['rate']``;
         PyTorch's when the backend is ``'pytorch'``, TensorFlow
         Probability's otherwise.
     """
     if get_backend() != 'pytorch':
         from tensorflow_probability import distributions as tfd
         return tfd.Gamma(self['concentration'], self['rate'])
     import torch.distributions as tod
     return tod.gamma.Gamma(self['concentration'], self['rate'])
示例#22
0
文件: ops.py 项目: hanxirui/probflow
def abs(val):
    """Elementwise absolute value.

    Parameters
    ----------
    val : Tensor
        Input values.

    Returns
    -------
    Tensor
        ``|val|`` for every element.
    """
    if get_backend() != 'pytorch':
        import tensorflow as tf
        return tf.math.abs(val)
    import torch
    return torch.abs(val)
示例#23
0
 def __call__(self):
     """Return the backend-native deterministic distribution.

     Returns
     -------
     Distribution
         A deterministic distribution at ``self['loc']``. With the
         ``'pytorch'`` backend it is the class returned by
         ``get_TorchDeterministic()``; otherwise ``tfp.distributions``'
         ``Deterministic``.
     """
     if get_backend() != 'pytorch':
         from tensorflow_probability import distributions as tfd
         return tfd.Deterministic(self['loc'])
     return get_TorchDeterministic()(self['loc'])
示例#24
0
 def __call__(self):
     """Return the backend-native inverse-gamma distribution.

     Returns
     -------
     tfp.distributions.InverseGamma
         Built from ``self.concentration`` and ``self.scale`` when the
         backend is not ``'pytorch'``.

     Raises
     ------
     NotImplementedError
         When the backend is ``'pytorch'``.
     """
     if get_backend() == 'pytorch':
         # No PyTorch implementation is provided, so raise right away;
         # nothing from torch is needed to do that.
         raise NotImplementedError(
             'InverseGamma is not implemented for the pytorch backend')
     else:
         from tensorflow_probability import distributions as tfd
         # Parameters are passed positionally (concentration, scale)
         return tfd.InverseGamma(self.concentration, self.scale)
示例#25
0
文件: ops.py 项目: hanxirui/probflow
def square(val):
    """Power of 2, applied elementwise.

    Parameters
    ----------
    val : Tensor
        Input values.

    Returns
    -------
    Tensor
        ``val ** 2``.
    """
    if get_backend() == 'pytorch':
        # The power operator squares torch tensors elementwise (and also
        # works on plain Python numbers), so no torch import is needed.
        return val**2
    else:
        import tensorflow as tf
        return tf.math.square(val)
示例#26
0
 def __call__(self):
     """Return the backend-native Poisson distribution.

     Returns
     -------
     torch.distributions.Poisson or tfp.distributions.Poisson
         Parameterized by ``self.rate``; PyTorch's when the backend is
         ``'pytorch'``, TensorFlow Probability's otherwise.
     """
     if get_backend() == 'pytorch':
         from torch.distributions.poisson import Poisson
     else:
         from tensorflow_probability import distributions as tfd
         Poisson = tfd.Poisson
     return Poisson(self.rate)
示例#27
0
文件: ops.py 项目: hanxirui/probflow
def sqrt(val):
    """Elementwise square root.

    Parameters
    ----------
    val : Tensor
        Input values.

    Returns
    -------
    Tensor
        The square root of every element of ``val``.
    """
    if get_backend() != 'pytorch':
        import tensorflow as tf
        return tf.sqrt(val)
    import torch
    return torch.sqrt(val)
示例#28
0
文件: ops.py 项目: hanxirui/probflow
def relu(val):
    """Linear rectification, ``max(val, 0)`` applied elementwise.

    Parameters
    ----------
    val : Tensor
        Input values.

    Returns
    -------
    Tensor
        ``val`` with negative entries replaced by zero.
    """
    if get_backend() == 'pytorch':
        import torch
        # Functional form: avoids constructing an nn.ReLU module per call
        return torch.relu(val)
    else:
        import tensorflow as tf
        return tf.nn.relu(val)
示例#29
0
def sum(val, axis=-1):
    """Sum of a tensor's elements along one axis.

    Parameters
    ----------
    val : Tensor
        Input values.
    axis : int
        Axis to sum over. Default = -1 (the last axis).

    Returns
    -------
    Tensor
        ``val`` reduced by addition over ``axis``.
    """
    if get_backend() == 'pytorch':
        import torch
        total = torch.sum(val, dim=axis)
    else:
        import tensorflow as tf
        total = tf.reduce_sum(val, axis=axis)
    return total
示例#30
0
 def __call__(self):
     """Return the backend-native categorical distribution.

     Returns
     -------
     torch.distributions.Categorical or tfp.distributions.Categorical
         Built from ``self.logits`` and ``self.probs`` (both passed by
         keyword); PyTorch's when the backend is ``'pytorch'``, TensorFlow
         Probability's otherwise.
     """
     if get_backend() == 'pytorch':
         from torch.distributions.categorical import Categorical
     else:
         from tensorflow_probability import distributions as tfd
         Categorical = tfd.Categorical
     return Categorical(logits=self.logits, probs=self.probs)