from collections import OrderedDict

import torch

import funsor
import funsor.distributions as dist
import funsor.ops as ops
from funsor.domains import bint, reals
from funsor.terms import Variable
from funsor.torch import Tensor
from funsor.testing import assert_close, check_funsor

# Note: the imports above are reconstructed from the API used below (bint/reals
# and funsor.torch.function belong to the early funsor releases); _check_sample
# is a local test helper not shown here.


def test_bernoulliprobs_sample(batch_shape, sample_inputs):
    batch_dims = ('i', 'j', 'k')[:len(batch_shape)]
    inputs = OrderedDict((k, bint(v)) for k, v in zip(batch_dims, batch_shape))

    probs = Tensor(torch.rand(batch_shape), inputs)
    funsor_dist = dist.Bernoulli(probs=probs)

    # Compare empirical statistics of 100k samples against the exact distribution.
    _check_sample(funsor_dist, sample_inputs, inputs, atol=5e-2, num_samples=100000)
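

# A minimal sketch of the sampling path the test above exercises, assuming
# funsor's Funsor.sample(sampled_vars, sample_inputs) API; the 'n' sample
# dimension and the 1000-sample count are illustrative, not from the source.
def bernoulli_sample_sketch():
    inputs = OrderedDict(i=bint(2))
    probs = Tensor(torch.rand((2,)), inputs)
    funsor_dist = dist.Bernoulli(probs=probs)
    sample_inputs = OrderedDict(n=bint(1000))
    # Returns a funsor carrying a Delta over 'value', batched over 'i' and 'n'.
    return funsor_dist.sample(frozenset(['value']), sample_inputs)

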
def loss_function(data, subsample_scale):
    # Lazily sample from the guide.
    loc, scale = encode(data)
    q = funsor.Independent(dist.Normal(loc['i'], scale['i'], value='z_i'),
                           'z', 'i', 'z_i')

    # Evaluate the model likelihood at the lazy value z.
    probs = decode('z')
    p = dist.Bernoulli(probs['x', 'y'], value=data['x', 'y'])
    p = p.reduce(ops.add, {'x', 'y'})

    # Construct an elbo. This is where sampling happens.
    elbo = funsor.Integrate(q, p - q, frozenset(['z']))
    elbo = elbo.reduce(ops.add, 'batch') * subsample_scale
    loss = -elbo
    return loss
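

# A hedged sketch of driving loss_function from a torch training loop. The
# MNIST-style 28x28 shape, `train_loader`, and `optimizer` are assumptions
# (encode/decode are an amortized encoder/decoder pair defined elsewhere);
# loss.data exposes the underlying torch scalar for backprop.
def train_epoch(train_loader, optimizer):
    for batch, _ in train_loader:
        optimizer.zero_grad()
        batch_size = len(batch)
        # Wrap the raw tensor as a funsor Tensor with a named 'batch' input
        # and a reals(28, 28) output, matching data['x', 'y'] above.
        data = Tensor(batch.reshape(batch_size, 28, 28),
                      OrderedDict(batch=bint(batch_size)))
        loss = loss_function(data, subsample_scale=batch_size / len(train_loader.dataset))
        loss.data.backward()
        optimizer.step()

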
def test_bernoulli_logits_density(batch_shape, syntax):
    batch_dims = ('i', 'j', 'k')[:len(batch_shape)]
    inputs = OrderedDict((k, bint(v)) for k, v in zip(batch_dims, batch_shape))

    # Reference density: wrap torch's Bernoulli log_prob as a funsor function.
    @funsor.torch.function(reals(), reals(), reals())
    def bernoulli(logits, value):
        return torch.distributions.Bernoulli(logits=logits).log_prob(value)

    check_funsor(bernoulli, {'logits': reals(), 'value': reals()}, reals())

    logits = Tensor(torch.rand(batch_shape), inputs)
    value = Tensor(torch.rand(batch_shape).round(), inputs)
    expected = bernoulli(logits, value)
    check_funsor(expected, inputs, reals())

    # Three equivalent ways to construct the same density funsor.
    d = Variable('value', reals())
    if syntax == 'eager':
        actual = dist.BernoulliLogits(logits, value)
    elif syntax == 'lazy':
        actual = dist.BernoulliLogits(logits, d)(value=value)
    elif syntax == 'generic':
        actual = dist.Bernoulli(logits=logits)(value=value)

    check_funsor(actual, inputs, reals())
    assert_close(actual, expected)
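

# The tests above take batch_shape/syntax/sample_inputs arguments, which are
# presumably supplied by pytest parametrization in the source; a plausible
# harness (the particular parameter values here are illustrative only):
import pytest

@pytest.mark.parametrize('syntax', ['eager', 'lazy', 'generic'])
@pytest.mark.parametrize('batch_shape', [(), (5,), (2, 3)], ids=str)
def test_bernoulli_logits_density_parametrized(batch_shape, syntax):
    test_bernoulli_logits_density(batch_shape, syntax)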