def test_elbo_plate_plate(backend, outer_dim, inner_dim):
    """Check TraceEnum_ELBO on four categorical sites over two crossed plates.

    Model and guide place the same categorical family at four sites:
    ``w`` outside any plate, ``x`` in the outer plate (dim -1), ``y`` in the
    inner plate (dim -2) and ``z`` in both.  The model uses the fixed
    probabilities ``p``; the guide uses the learnable param ``q`` and marks
    every site for parallel enumeration.  The expected loss is
    ``(1 + outer_dim + inner_dim + outer_dim * inner_dim)`` times
    ``kl(Categorical(q), Categorical(p))``; both that loss and its gradient
    with respect to ``q`` are compared against the ELBO's output.
    """
    with pyro_backend(backend):
        pyro.get_param_store().clear()
        q = pyro.param("q", torch.tensor([0.75, 0.25], requires_grad=True))
        # First entry chosen so that kl(Categorical(q), Categorical(p)) = 0.5.
        p_head = 0.2693204236205713
        p = torch.tensor([p_head, 1 - p_head])

        def model():
            site_dist = dist.Categorical(p)
            outer = pyro.plate("outer", outer_dim, dim=-1)
            inner = pyro.plate("inner", inner_dim, dim=-2)
            pyro.sample("w", site_dist)
            with outer:
                pyro.sample("x", site_dist)
            with inner:
                pyro.sample("y", site_dist)
            with outer, inner:
                pyro.sample("z", site_dist)

        def guide():
            site_dist = dist.Categorical(pyro.param("q"))
            outer = pyro.plate("outer", outer_dim, dim=-1)
            inner = pyro.plate("inner", inner_dim, dim=-2)
            pyro.sample("w", site_dist, infer={"enumerate": "parallel"})
            with outer:
                pyro.sample("x", site_dist, infer={"enumerate": "parallel"})
            with inner:
                pyro.sample("y", site_dist, infer={"enumerate": "parallel"})
            with outer, inner:
                pyro.sample("z", site_dist, infer={"enumerate": "parallel"})

        q_data = funsor.to_data(q)
        kl_per_term = kl_divergence(
            torch.distributions.Categorical(q_data),
            torch.distributions.Categorical(funsor.to_data(p)))
        # One KL term for w, outer_dim for x, inner_dim for y and
        # outer_dim * inner_dim for z, i.e. (1 + outer_dim) * (1 + inner_dim).
        num_terms = (1 + outer_dim) * (1 + inner_dim)
        expected_loss = num_terms * kl_per_term
        (expected_grad,) = grad(expected_loss, [q_data])

        elbo = infer.TraceEnum_ELBO(num_particles=1,
                                    vectorize_particles=True,
                                    strict_enumeration_warning=True)
        if backend == "pyro":
            # The pyro backend's ELBO object is not itself a differentiable
            # loss callable; use its differentiable_loss method instead.
            elbo = elbo.differentiable_loss
        actual_loss = funsor.to_data(elbo(model, guide))
        actual_loss.backward()
        actual_grad = funsor.to_data(pyro.param('q')).grad

        assert_close(actual_loss, expected_loss, atol=1e-5)
        assert_close(actual_grad, expected_grad, atol=1e-5)
def test_elbo_enumerate_plate_7(backend):
    """Check a mixture of model-side and guide-side enumeration in a plate.

    In the guide, ``a`` is enumerated in parallel and ``c`` is sampled from a
    deterministic table indexed by ``a``.  In the model, ``b`` and ``d`` are
    enumerated in parallel and ``obs`` is observed.  ``auto_model`` /
    ``auto_guide`` use a vectorized ``pyro.plate("data", 2)``;
    ``hand_model`` / ``hand_guide`` unroll the same plate into separately
    named sites ``c_i``, ``d_i`` and ``obs_i``.  The two losses are compared
    with ``_check_loss_and_grads``.
    """
    #   Guide    Model
    #     a -----> b
    #     |        |
    #   +-|--------|----------------+
    #   | V        V                |
    #   | c -----> d -----> e   N=2 |
    #   +---------------------------+
    # This tests a mixture of model and guide enumeration.
    with pyro_backend(backend):
        # Start from an empty param store: pyro.param only uses its init
        # value the first time a name is registered, so values left behind
        # by an earlier test sharing these names would silently override the
        # ones below.
        pyro.get_param_store().clear()
        pyro.param("model_probs_a",
                   torch.tensor([0.45, 0.55]),
                   constraint=constraints.simplex)
        pyro.param("model_probs_b",
                   torch.tensor([[0.6, 0.4], [0.4, 0.6]]),
                   constraint=constraints.simplex)
        pyro.param("model_probs_c",
                   torch.tensor([[0.75, 0.25], [0.55, 0.45]]),
                   constraint=constraints.simplex)
        pyro.param("model_probs_d",
                   torch.tensor([[[0.4, 0.6], [0.3, 0.7]],
                                 [[0.3, 0.7], [0.2, 0.8]]]),
                   constraint=constraints.simplex)
        pyro.param("model_probs_e",
                   torch.tensor([[0.75, 0.25], [0.55, 0.45]]),
                   constraint=constraints.simplex)
        # Was [0.35, 0.64], which does not sum to 1 despite the simplex
        # constraint.
        pyro.param("guide_probs_a",
                   torch.tensor([0.35, 0.65]),
                   constraint=constraints.simplex)
        pyro.param("guide_probs_c",
                   torch.tensor([[0., 1.], [1., 0.]]),  # deterministic
                   constraint=constraints.simplex)

        def auto_model(data):
            probs_a = pyro.param("model_probs_a")
            probs_b = pyro.param("model_probs_b")
            probs_c = pyro.param("model_probs_c")
            probs_d = pyro.param("model_probs_d")
            probs_e = pyro.param("model_probs_e")
            a = pyro.sample("a", dist.Categorical(probs_a))
            b = pyro.sample("b", dist.Categorical(probs_b[a]),
                            infer={"enumerate": "parallel"})
            with pyro.plate("data", 2, dim=-1):
                c = pyro.sample("c", dist.Categorical(probs_c[a]))
                d = pyro.sample("d", dist.Categorical(Vindex(probs_d)[b, c]),
                                infer={"enumerate": "parallel"})
                pyro.sample("obs", dist.Categorical(probs_e[d]), obs=data)

        def auto_guide(data):
            probs_a = pyro.param("guide_probs_a")
            probs_c = pyro.param("guide_probs_c")
            a = pyro.sample("a", dist.Categorical(probs_a),
                            infer={"enumerate": "parallel"})
            with pyro.plate("data", 2, dim=-1):
                pyro.sample("c", dist.Categorical(probs_c[a]))

        def hand_model(data):
            probs_a = pyro.param("model_probs_a")
            probs_b = pyro.param("model_probs_b")
            probs_c = pyro.param("model_probs_c")
            probs_d = pyro.param("model_probs_d")
            probs_e = pyro.param("model_probs_e")
            a = pyro.sample("a", dist.Categorical(probs_a))
            b = pyro.sample("b", dist.Categorical(probs_b[a]),
                            infer={"enumerate": "parallel"})
            # Unrolled counterpart of the "data" plate in auto_model.
            for i in range(2):
                c = pyro.sample("c_{}".format(i), dist.Categorical(probs_c[a]))
                d = pyro.sample("d_{}".format(i),
                                dist.Categorical(Vindex(probs_d)[b, c]),
                                infer={"enumerate": "parallel"})
                pyro.sample("obs_{}".format(i), dist.Categorical(probs_e[d]),
                            obs=data[i])

        def hand_guide(data):
            probs_a = pyro.param("guide_probs_a")
            probs_c = pyro.param("guide_probs_c")
            a = pyro.sample("a", dist.Categorical(probs_a),
                            infer={"enumerate": "parallel"})
            for i in range(2):
                pyro.sample("c_{}".format(i), dist.Categorical(probs_c[a]))

        data = torch.tensor([0, 0])

        # The vectorized version has one plate; the hand-unrolled one has none.
        elbo = infer.TraceEnum_ELBO(max_plate_nesting=1)
        elbo = elbo.differentiable_loss if backend == "pyro" else elbo
        auto_loss = elbo(auto_model, auto_guide, data)
        elbo = infer.TraceEnum_ELBO(max_plate_nesting=0)
        elbo = elbo.differentiable_loss if backend == "pyro" else elbo
        hand_loss = elbo(hand_model, hand_guide, data)
        _check_loss_and_grads(hand_loss, auto_loss)
def test_elbo_enumerate_plates_1(backend):
    """Compare two unrelated vectorized plates against hand-unrolled loops.

    ``auto_model`` places ``a -> b`` in plate ``a_axis`` (size 2) and
    ``c -> d`` in a separate plate ``c_axis`` (size 3).  ``a`` and ``c`` are
    enumerated in parallel, and ``b`` and ``d`` are observed.
    ``hand_model`` unrolls both plates into sites ``a_i``/``b_i`` and
    ``c_j``/``d_j``.  The guide is empty.  The two losses are compared with
    ``_check_loss_and_grads``.
    """
    #  +-----------------+
    #  | a ----> b   M=2 |
    #  +-----------------+
    #  +-----------------+
    #  | c ----> d   N=3 |
    #  +-----------------+
    # This tests two unrelated plates.
    # Each should remain uncontracted.
    with pyro_backend(backend):
        # Start from an empty param store: pyro.param only uses its init
        # value the first time a name is registered, and these generic names
        # (probs_a ... probs_d) could otherwise carry over from earlier tests.
        pyro.get_param_store().clear()
        pyro.param("probs_a",
                   torch.tensor([0.45, 0.55]),
                   constraint=constraints.simplex)
        pyro.param("probs_b",
                   torch.tensor([[0.6, 0.4], [0.4, 0.6]]),
                   constraint=constraints.simplex)
        pyro.param("probs_c",
                   torch.tensor([0.75, 0.25]),
                   constraint=constraints.simplex)
        pyro.param("probs_d",
                   torch.tensor([[0.4, 0.6], [0.3, 0.7]]),
                   constraint=constraints.simplex)
        b_data = torch.tensor([0, 1])
        d_data = torch.tensor([0, 0, 1])

        def auto_model():
            probs_a = pyro.param("probs_a")
            probs_b = pyro.param("probs_b")
            probs_c = pyro.param("probs_c")
            probs_d = pyro.param("probs_d")
            # Both plates use dim=-1: they are siblings, never nested.
            with pyro.plate("a_axis", 2, dim=-1):
                a = pyro.sample("a", dist.Categorical(probs_a),
                                infer={"enumerate": "parallel"})
                pyro.sample("b", dist.Categorical(probs_b[a]), obs=b_data)
            with pyro.plate("c_axis", 3, dim=-1):
                c = pyro.sample("c", dist.Categorical(probs_c),
                                infer={"enumerate": "parallel"})
                pyro.sample("d", dist.Categorical(probs_d[c]), obs=d_data)

        def hand_model():
            probs_a = pyro.param("probs_a")
            probs_b = pyro.param("probs_b")
            probs_c = pyro.param("probs_c")
            probs_d = pyro.param("probs_d")
            for i in range(2):
                a = pyro.sample("a_{}".format(i), dist.Categorical(probs_a),
                                infer={"enumerate": "parallel"})
                pyro.sample("b_{}".format(i), dist.Categorical(probs_b[a]),
                            obs=b_data[i])
            for j in range(3):
                c = pyro.sample("c_{}".format(j), dist.Categorical(probs_c),
                                infer={"enumerate": "parallel"})
                pyro.sample("d_{}".format(j), dist.Categorical(probs_d[c]),
                            obs=d_data[j])

        def guide():
            pass

        # The vectorized version has one plate level; the unrolled one none.
        elbo = infer.TraceEnum_ELBO(max_plate_nesting=1)
        elbo = elbo.differentiable_loss if backend == "pyro" else elbo
        auto_loss = elbo(auto_model, guide)
        elbo = infer.TraceEnum_ELBO(max_plate_nesting=0)
        elbo = elbo.differentiable_loss if backend == "pyro" else elbo
        hand_loss = elbo(hand_model, guide)
        _check_loss_and_grads(hand_loss, auto_loss)