예제 #1
0
def test_elbo_plate_plate(backend, outer_dim, inner_dim):
    """Check TraceEnum_ELBO loss and gradient on two crossed plates.

    Model and guide share four Categorical sites, emitted in this order:
    "w" outside any plate, "x" in the outer plate, "y" in the inner plate
    and "z" in both. The guide uses the learnable param "q" with parallel
    enumeration; the model uses a fixed ``p`` chosen so that the per-site
    KL(Categorical(q) || Categorical(p)) is 0.5. The expected loss is that
    KL times 1 + outer_dim + inner_dim + outer_dim * inner_dim, written
    here in the factored form (1 + outer_dim) * (1 + inner_dim).
    """
    with pyro_backend(backend):
        pyro.get_param_store().clear()
        num_particles = 1
        q = pyro.param("q", torch.tensor([0.75, 0.25], requires_grad=True))
        # Chosen so that kl(Categorical(q), Categorical(p)) = 0.5.
        p_first = 0.2693204236205713
        p = torch.tensor([p_first, 1 - p_first])

        def _crossed_plate_sites(cat, enumerate_sites):
            # Model and guide must emit the same sites in the same order.
            outer = pyro.plate("outer", outer_dim, dim=-1)
            inner = pyro.plate("inner", inner_dim, dim=-2)

            def site(name):
                # A fresh ``infer`` dict per site, as in a hand-written guide.
                if enumerate_sites:
                    pyro.sample(name, cat, infer={"enumerate": "parallel"})
                else:
                    pyro.sample(name, cat)

            site("w")
            with outer:
                site("x")
            with inner:
                site("y")
            with outer, inner:
                site("z")

        def model():
            _crossed_plate_sites(dist.Categorical(p), enumerate_sites=False)

        def guide():
            _crossed_plate_sites(dist.Categorical(pyro.param("q")),
                                 enumerate_sites=True)

        kl_per_site = kl_divergence(
            torch.distributions.Categorical(funsor.to_data(q)),
            torch.distributions.Categorical(funsor.to_data(p)))
        # Equals 1 + outer_dim + inner_dim + outer_dim * inner_dim.
        num_elements = (1 + outer_dim) * (1 + inner_dim)
        expected_loss = num_elements * kl_per_site
        expected_grad = grad(expected_loss, [funsor.to_data(q)])[0]

        elbo = infer.TraceEnum_ELBO(num_particles=num_particles,
                                    vectorize_particles=True,
                                    strict_enumeration_warning=True)
        loss_fn = elbo.differentiable_loss if backend == "pyro" else elbo
        actual_loss = funsor.to_data(loss_fn(model, guide))
        actual_loss.backward()
        actual_grad = funsor.to_data(pyro.param('q')).grad

        assert_close(actual_loss, expected_loss, atol=1e-5)
        assert_close(actual_grad, expected_grad, atol=1e-5)
예제 #2
0
def test_elbo_enumerate_plate_7(backend):
    """Compare a plated model/guide pair with its hand-unrolled equivalent.

    ``auto_model``/``auto_guide`` use ``pyro.plate("data", 2)``, while
    ``hand_model``/``hand_guide`` unroll that plate into a Python loop with
    per-index site names. Both pairs read the same params and data. Their
    TraceEnum_ELBO losses are compared with ``_check_loss_and_grads``.
    """
    #  Guide    Model
    #    a -----> b
    #    |        |
    #  +-|--------|----------------+
    #  | V        V                |
    #  | c -----> d -----> e   N=2 |
    #  +---------------------------+
    # This tests a mixture of model and guide enumeration.
    with pyro_backend(backend):
        # Start from an empty param store so the init values below are the
        # ones actually used. pyro.param returns the stored value for an
        # already-registered name. This matches test_elbo_plate_plate.
        pyro.get_param_store().clear()
        pyro.param("model_probs_a",
                   torch.tensor([0.45, 0.55]),
                   constraint=constraints.simplex)
        pyro.param("model_probs_b",
                   torch.tensor([[0.6, 0.4], [0.4, 0.6]]),
                   constraint=constraints.simplex)
        pyro.param("model_probs_c",
                   torch.tensor([[0.75, 0.25], [0.55, 0.45]]),
                   constraint=constraints.simplex)
        pyro.param("model_probs_d",
                   torch.tensor([[[0.4, 0.6], [0.3, 0.7]],
                                 [[0.3, 0.7], [0.2, 0.8]]]),
                   constraint=constraints.simplex)
        pyro.param("model_probs_e",
                   torch.tensor([[0.75, 0.25], [0.55, 0.45]]),
                   constraint=constraints.simplex)
        # Must sum to 1 to lie on the declared simplex. The previous value
        # [0.35, 0.64] was off-simplex and got silently mapped onto it.
        pyro.param("guide_probs_a",
                   torch.tensor([0.35, 0.65]),
                   constraint=constraints.simplex)
        pyro.param(
            "guide_probs_c",
            torch.tensor([[0., 1.], [1., 0.]]),  # deterministic
            constraint=constraints.simplex)

        def auto_model(data):
            probs_a = pyro.param("model_probs_a")
            probs_b = pyro.param("model_probs_b")
            probs_c = pyro.param("model_probs_c")
            probs_d = pyro.param("model_probs_d")
            probs_e = pyro.param("model_probs_e")
            a = pyro.sample("a", dist.Categorical(probs_a))
            b = pyro.sample("b",
                            dist.Categorical(probs_b[a]),
                            infer={"enumerate": "parallel"})
            with pyro.plate("data", 2, dim=-1):
                c = pyro.sample("c", dist.Categorical(probs_c[a]))
                d = pyro.sample("d",
                                dist.Categorical(Vindex(probs_d)[b, c]),
                                infer={"enumerate": "parallel"})
                pyro.sample("obs", dist.Categorical(probs_e[d]), obs=data)

        def auto_guide(data):
            probs_a = pyro.param("guide_probs_a")
            probs_c = pyro.param("guide_probs_c")
            a = pyro.sample("a",
                            dist.Categorical(probs_a),
                            infer={"enumerate": "parallel"})
            with pyro.plate("data", 2, dim=-1):
                pyro.sample("c", dist.Categorical(probs_c[a]))

        def hand_model(data):
            # Same as auto_model with the "data" plate unrolled.
            probs_a = pyro.param("model_probs_a")
            probs_b = pyro.param("model_probs_b")
            probs_c = pyro.param("model_probs_c")
            probs_d = pyro.param("model_probs_d")
            probs_e = pyro.param("model_probs_e")
            a = pyro.sample("a", dist.Categorical(probs_a))
            b = pyro.sample("b",
                            dist.Categorical(probs_b[a]),
                            infer={"enumerate": "parallel"})
            for i in range(2):
                c = pyro.sample("c_{}".format(i), dist.Categorical(probs_c[a]))
                d = pyro.sample("d_{}".format(i),
                                dist.Categorical(Vindex(probs_d)[b, c]),
                                infer={"enumerate": "parallel"})
                pyro.sample("obs_{}".format(i),
                            dist.Categorical(probs_e[d]),
                            obs=data[i])

        def hand_guide(data):
            # Same as auto_guide with the "data" plate unrolled.
            probs_a = pyro.param("guide_probs_a")
            probs_c = pyro.param("guide_probs_c")
            a = pyro.sample("a",
                            dist.Categorical(probs_a),
                            infer={"enumerate": "parallel"})
            for i in range(2):
                pyro.sample("c_{}".format(i), dist.Categorical(probs_c[a]))

        data = torch.tensor([0, 0])
        # The plated version needs one plate dim; the unrolled one needs none.
        elbo = infer.TraceEnum_ELBO(max_plate_nesting=1)
        elbo = elbo.differentiable_loss if backend == "pyro" else elbo
        auto_loss = elbo(auto_model, auto_guide, data)
        elbo = infer.TraceEnum_ELBO(max_plate_nesting=0)
        elbo = elbo.differentiable_loss if backend == "pyro" else elbo
        hand_loss = elbo(hand_model, hand_guide, data)
        _check_loss_and_grads(hand_loss, auto_loss)
예제 #3
0
def test_elbo_enumerate_plates_1(backend):
    """Check two unrelated plates against their hand-unrolled equivalent.

    ``auto_model`` places (a, b) in a size-2 plate and (c, d) in a
    separate size-3 plate. ``hand_model`` unrolls both plates into loops
    with per-index site names. The guide is empty, so latents "a"/"c" are
    enumerated in the model. The two TraceEnum_ELBO losses are compared
    with ``_check_loss_and_grads``.
    """
    #  +-----------------+
    #  | a ----> b   M=2 |
    #  +-----------------+
    #  +-----------------+
    #  | c ----> d   N=3 |
    #  +-----------------+
    # This tests two unrelated plates.
    # Each should remain uncontracted.
    with pyro_backend(backend):
        simplex_inits = (
            ("probs_a", torch.tensor([0.45, 0.55])),
            ("probs_b", torch.tensor([[0.6, 0.4], [0.4, 0.6]])),
            ("probs_c", torch.tensor([0.75, 0.25])),
            ("probs_d", torch.tensor([[0.4, 0.6], [0.3, 0.7]])),
        )
        for param_name, init_value in simplex_inits:
            pyro.param(param_name, init_value, constraint=constraints.simplex)
        b_data = torch.tensor([0, 1])
        d_data = torch.tensor([0, 0, 1])

        def _read_params():
            # Returns (probs_a, probs_b, probs_c, probs_d), read in that order.
            return tuple(pyro.param(param_name)
                         for param_name, _ in simplex_inits)

        def auto_model():
            probs_a, probs_b, probs_c, probs_d = _read_params()
            with pyro.plate("a_axis", 2, dim=-1):
                a = pyro.sample("a", dist.Categorical(probs_a),
                                infer={"enumerate": "parallel"})
                pyro.sample("b", dist.Categorical(probs_b[a]), obs=b_data)
            with pyro.plate("c_axis", 3, dim=-1):
                c = pyro.sample("c", dist.Categorical(probs_c),
                                infer={"enumerate": "parallel"})
                pyro.sample("d", dist.Categorical(probs_d[c]), obs=d_data)

        def hand_model():
            probs_a, probs_b, probs_c, probs_d = _read_params()
            for i, b_obs in enumerate(b_data):
                a = pyro.sample("a_{}".format(i), dist.Categorical(probs_a),
                                infer={"enumerate": "parallel"})
                pyro.sample("b_{}".format(i), dist.Categorical(probs_b[a]),
                            obs=b_obs)
            for j, d_obs in enumerate(d_data):
                c = pyro.sample("c_{}".format(j), dist.Categorical(probs_c),
                                infer={"enumerate": "parallel"})
                pyro.sample("d_{}".format(j), dist.Categorical(probs_d[c]),
                            obs=d_obs)

        def guide():
            pass

        def _enum_loss(model, max_plate_nesting):
            elbo = infer.TraceEnum_ELBO(max_plate_nesting=max_plate_nesting)
            loss_fn = elbo.differentiable_loss if backend == "pyro" else elbo
            return loss_fn(model, guide)

        # The plated model needs one plate dim; the unrolled one needs none.
        auto_loss = _enum_loss(auto_model, max_plate_nesting=1)
        hand_loss = _enum_loss(hand_model, max_plate_nesting=0)
        _check_loss_and_grads(hand_loss, auto_loss)