Example #1
0
def test_bernoulliprobs_sample(batch_shape, sample_inputs):
    """Sample from a probs-parameterized Bernoulli funsor and validate it.

    One named batch input ('i', 'j', 'k', in that order) is created per entry
    of ``batch_shape``, each with an integer (``bint``) domain of that size.
    Uniform random probabilities are attached to those inputs. The resulting
    ``dist.Bernoulli(probs=...)`` is handed to ``_check_sample`` (defined
    elsewhere) with ``atol=5e-2`` and ``num_samples=100000``.
    """
    names = ('i', 'j', 'k')[:len(batch_shape)]
    batch_inputs = OrderedDict()
    for name, size in zip(names, batch_shape):
        batch_inputs[name] = bint(size)

    raw_probs = torch.rand(batch_shape)
    bernoulli = dist.Bernoulli(probs=Tensor(raw_probs, batch_inputs))

    _check_sample(bernoulli, sample_inputs, batch_inputs,
                  atol=5e-2, num_samples=100000)
Example #2
0
    def loss_function(data, subsample_scale):
        """Return the negated ELBO for ``data``, scaled by ``subsample_scale``.

        ``encode`` and ``decode`` are not defined in this function; presumably
        they are closures from the enclosing scope (TODO confirm).
        ``encode(data)`` must return a ``(loc, scale)`` pair that can be
        indexed by ``'i'``.
        """
        # Lazily sample from the guide: an Independent Normal over 'z',
        # built from per-'i' components whose value is named 'z_i'.
        guide_loc, guide_scale = encode(data)
        guide = funsor.Independent(
            dist.Normal(guide_loc['i'], guide_scale['i'], value='z_i'),
            'z', 'i', 'z_i')

        # Model likelihood of the observed data at the (still lazy) value 'z',
        # summed over the 'x' and 'y' inputs.
        decoded = decode('z')
        likelihood = dist.Bernoulli(decoded['x', 'y'], value=data['x', 'y'])
        likelihood = likelihood.reduce(ops.add, {'x', 'y'})

        # Integrating (model - guide) against the guide forms the ELBO;
        # this is where sampling happens.
        elbo = funsor.Integrate(guide, likelihood - guide, frozenset(['z']))
        return -(elbo.reduce(ops.add, 'batch') * subsample_scale)
def test_bernoulli_logits_density(batch_shape, syntax):
    """Check the Bernoulli-with-logits log-density under several syntaxes.

    The expected density comes from wrapping
    ``torch.distributions.Bernoulli(logits=...).log_prob`` as a funsor
    function. ``syntax`` selects how the funsor under test is built:

    * ``'eager'``   -- ``dist.BernoulliLogits(logits, value)``
    * ``'lazy'``    -- ``dist.BernoulliLogits(logits, Variable('value', reals()))``
      followed by substituting ``value=value``
    * ``'generic'`` -- ``dist.Bernoulli(logits=logits)(value=value)``

    Raises:
        ValueError: if ``syntax`` is not one of the values listed above.
    """
    batch_dims = ('i', 'j', 'k')[:len(batch_shape)]
    inputs = OrderedDict((k, bint(v)) for k, v in zip(batch_dims, batch_shape))

    @funsor.torch.function(reals(), reals(), reals())
    def bernoulli(logits, value):
        return torch.distributions.Bernoulli(logits=logits).log_prob(value)

    check_funsor(bernoulli, {'logits': reals(), 'value': reals()}, reals())

    logits = Tensor(torch.rand(batch_shape), inputs)
    # Rounding uniform [0, 1) noise yields values in {0., 1.}, the Bernoulli support.
    value = Tensor(torch.rand(batch_shape).round(), inputs)
    expected = bernoulli(logits, value)
    check_funsor(expected, inputs, reals())

    if syntax == 'eager':
        actual = dist.BernoulliLogits(logits, value)
    elif syntax == 'lazy':
        # Build the distribution with a free 'value' variable, then substitute.
        d = Variable('value', reals())
        actual = dist.BernoulliLogits(logits, d)(value=value)
    elif syntax == 'generic':
        actual = dist.Bernoulli(logits=logits)(value=value)
    else:
        # Without this branch an unexpected value would surface as an
        # UnboundLocalError on ``actual`` below, hiding the real cause.
        raise ValueError('unknown syntax: {}'.format(syntax))
    check_funsor(actual, inputs, reals())
    assert_close(actual, expected)