def test_prob(nderivs):
    """Compare infer_discrete's particle estimate of log p against TraceEnum_ELBO.

    The reference value is ``-TraceEnum_ELBO.differentiable_loss`` of the
    enumerated model with an empty guide. The estimate comes from running the
    model under ``infer_discrete`` with ``num_particles`` particles stacked in
    plate dim -2 and reducing them with ``log_mean_prob(..., particle_dim=-2)``.

    :param int nderivs: 0 to compare the values themselves, 1 to compare
        their gradients with respect to the parameter ``p``.
    :raises ValueError: if ``nderivs`` is neither 0 nor 1. Previously such a
        value made the test pass vacuously after doing all the work.
    """
    #      +-------+
    #  z --|--> x  |
    #      +-------+
    if nderivs not in (0, 1):
        raise ValueError("unsupported nderivs: {}".format(nderivs))

    num_particles = 10000
    data = torch.tensor([0.5, 1., 1.5])
    # Register "p" up front so this exact tensor can be differentiated below.
    p = pyro.param("p", torch.tensor(0.25))

    @config_enumerate
    def model(num_particles):
        p = pyro.param("p")
        with pyro.plate("num_particles", num_particles, dim=-2):
            z = pyro.sample("z", dist.Bernoulli(p))
            with pyro.plate("data", 3):
                pyro.sample("x", dist.Normal(z, 1.), obs=data)

    def guide(num_particles):
        pass

    elbo = TraceEnum_ELBO(max_plate_nesting=2)
    expected_logprob = -elbo.differentiable_loss(model, guide, num_particles=1)

    # Plates occupy dims -1 and -2, so enumeration must start at dim -3.
    posterior_model = infer_discrete(config_enumerate(model, "parallel"),
                                     first_available_dim=-3)
    posterior_trace = poutine.trace(posterior_model).get_trace(
        num_particles=num_particles)
    actual_logprob = log_mean_prob(posterior_trace, particle_dim=-2)

    if nderivs == 0:
        assert_equal(expected_logprob, actual_logprob, prec=1e-3)
    else:
        # Each value has its own graph, so grad() can be called once on each.
        expected_grad = grad(expected_logprob, [p])[0]
        actual_grad = grad(actual_logprob, [p])[0]
        assert_equal(expected_grad, actual_grad, prec=1e-3)
def test_distribution_1(temperature):
    """Check infer_discrete's posterior over one binary latent ``z``.

    The model is run twice with ``z`` observed as 0 and as 1, and the log-joint
    of each conditioned trace is taken. With a truthy ``temperature``, the
    empirical mean of the sampled ``z`` is compared with
    ``1 / (1 + exp(log_p0 - log_p1))``. Otherwise it is compared with the
    indicator of the more probable value.
    """
    #      +-------+
    #  z --|--> x  |
    #      +-------+
    num_particles = 10000
    data = torch.tensor([1., 2., 3.])

    @config_enumerate
    def model(num_particles=1, z=None):
        prob = pyro.param("p", torch.tensor(0.25))
        with pyro.plate("num_particles", num_particles, dim=-2):
            z = pyro.sample("z", dist.Bernoulli(prob), obs=z)
            logger.info("z.shape = {}".format(z.shape))
            with pyro.plate("data", 3):
                pyro.sample("x", dist.Normal(z, 1.), obs=data)

    # Plates use dims -1 and -2; enumeration starts to their left.
    sampled_model = infer_discrete(model, first_available_dim=-3,
                                   temperature=temperature)
    sampled_trace = poutine.trace(sampled_model).get_trace(num_particles)

    # Log-joint of the model with z clamped to each of its two values.
    log_joint = {}
    for z_value in (0., 1.):
        clamped_trace = poutine.trace(model).get_trace(z=torch.tensor(z_value))
        log_joint[z_value] = clamped_trace.log_prob_sum()
    log_p0, log_p1 = log_joint[0.], log_joint[1.]

    # Check posterior over z.
    actual_z_mean = sampled_trace.nodes["z"]["value"].mean()
    if not temperature:
        expected_z_mean = (log_p1 > log_p0).float()
    else:
        expected_z_mean = 1 / (1 + (log_p0 - log_p1).exp())
    assert_equal(actual_z_mean, expected_z_mean, prec=1e-2)
def test_hmm_smoke(temperature, length):
    """Smoke-test infer_discrete on a discrete-state HMM of ``length`` steps.

    A trajectory is drawn from the generative model and then decoded with
    ``infer_discrete``. Only the sequence lengths are asserted; both state
    sequences are logged for manual inspection.
    """

    # This should match the example in the infer_discrete docstring.
    def hmm(data, hidden_dim=10):
        transition = 0.3 / hidden_dim + 0.7 * torch.eye(hidden_dim)
        means = torch.arange(float(hidden_dim))
        states = [0]
        for t in pyro.markov(range(len(data))):
            states.append(
                pyro.sample("states_{}".format(t),
                            dist.Categorical(transition[states[-1]])))
            data[t] = pyro.sample("obs_{}".format(t),
                                  dist.Normal(means[states[-1]], 1.),
                                  obs=data[t])
        return states, data

    # Passing all-None observations makes hmm sample its own observations.
    true_states, observed = hmm([None] * length)
    assert len(observed) == length
    assert len(true_states) == len(observed) + 1  # includes the initial state 0

    enumerated_hmm = config_enumerate(hmm)
    decoder = infer_discrete(enumerated_hmm, first_available_dim=-1,
                             temperature=temperature)
    inferred_states = decoder(observed)[0]
    assert len(inferred_states) == len(true_states)

    logger.info("true states: {}".format([int(s) for s in true_states]))
    logger.info("inferred states: {}".format([int(s) for s in inferred_states]))
def test_warning():
    """Check enumeration warnings from infer_discrete.

    Wrapping a model with no site configured for enumeration is expected to
    emit at least one warning when called. The variant with
    ``strict_enumeration_warning=False`` and the ``config_enumerate``-wrapped
    variant are only required to run.
    """
    data = torch.randn(4)

    def model():
        x = pyro.sample("x", dist.Categorical(torch.ones(3)))
        with pyro.plate("data", len(data)):
            pyro.sample("obs", dist.Normal(x.float(), 1), obs=data)

    noisy_model = infer_discrete(model, first_available_dim=-2)
    silenced_model = infer_discrete(model, first_available_dim=-2,
                                    strict_enumeration_warning=False)
    enumerated_model = infer_discrete(config_enumerate(model),
                                      first_available_dim=-2)

    # The un-enumerated, unsilenced model should emit warnings.
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")  # don't let earlier dedup hide them
        noisy_model()
    assert caught, 'No warnings were raised'

    # Both remaining variants should simply run.
    for runnable in (silenced_model, enumerated_model):
        runnable()
def test_distribution_3(temperature):
    """Check infer_discrete's joint posterior over ``(z1, z2[0], z2[1])``.

    The model is conditioned on each of the 8 assignments, and the exponentiated
    log-joint of each trace is tabulated. The table is compared with the
    empirical frequencies of ``num_particles`` samples. With a truthy
    ``temperature``, the table is normalized. Otherwise it becomes a one-hot
    table at its largest entry.
    """
    #    +---------+  +---------------+
    # z1 --|--> x1 |  | z2 ---> x2    |
    #    |       3 |  |             2 |
    #    +---------+  +---------------+
    num_particles = 10000
    data = [torch.tensor([-1., -1., 0.]), torch.tensor([-1., 1.])]

    @config_enumerate
    def model(num_particles=1, z1=None, z2=None):
        probs = pyro.param("p", torch.tensor([0.25, 0.75]))
        loc = pyro.param("loc", torch.tensor([-1., 1.]))
        with pyro.plate("num_particles", num_particles, dim=-2):
            z1 = pyro.sample("z1", dist.Categorical(probs), obs=z1)
            with pyro.plate("data[0]", 3):
                pyro.sample("x1", dist.Normal(loc[z1], 1.), obs=data[0])
            with pyro.plate("data[1]", 2):
                z2 = pyro.sample("z2", dist.Categorical(probs), obs=z2)
                pyro.sample("x2", dist.Normal(loc[z2], 1.), obs=data[1])

    # Plates use dims -1 and -2; enumeration starts to their left.
    sampled_model = infer_discrete(model, first_available_dim=-3,
                                   temperature=temperature)
    sampled_trace = poutine.trace(sampled_model).get_trace(num_particles)
    z1_draws = sampled_trace.nodes["z1"]["value"]
    z2_draws = sampled_trace.nodes["z2"]["value"]

    # Check joint posterior over (z1, z2[0], z2[1]).
    actual_probs = torch.empty(2, 2, 2)
    expected_probs = torch.empty(2, 2, 2)
    for z1 in (0, 1):
        for z20 in (0, 1):
            for z21 in (0, 1):
                clamped = poutine.trace(model).get_trace(
                    z1=torch.tensor(z1), z2=torch.tensor([z20, z21]))
                expected_probs[z1, z20, z21] = clamped.log_prob_sum().exp()
                matches = ((z1_draws == z1)
                           & (z2_draws[..., :1] == z20)
                           & (z2_draws[..., 1:] == z21))
                actual_probs[z1, z20, z21] = matches.float().mean()

    if not temperature:
        # MAP decoding: all mass on the single most probable assignment.
        flat_expected = expected_probs.reshape(-1)  # view; torch.empty is contiguous
        _, argmax = flat_expected.max(0)
        flat_expected.zero_()
        flat_expected[argmax] = 1
    else:
        expected_probs = expected_probs.div(expected_probs.sum())
    assert_equal(expected_probs.reshape(-1), actual_probs.reshape(-1), prec=1e-2)