コード例 #1
0
def _check_loss_and_grads(expected_loss, actual_loss, atol=1e-4, rtol=1e-4):
    """Assert that two losses and their per-parameter gradients agree.

    Adapted from the Pyro test helper of the same name.

    Both losses are converted with ``funsor.to_data`` and compared with
    ``ops.allclose``.  Each loss is then differentiated with respect to the
    unconstrained value of every parameter currently in the param store, and
    the two gradients of each parameter are compared with the same
    tolerances.  A parameter is skipped when either gradient is ``None``.

    :param expected_loss: reference loss value.
    :param actual_loss: loss value under test.
    :param float atol: absolute tolerance forwarded to ``ops.allclose``.
    :param float rtol: relative tolerance forwarded to ``ops.allclose``.
    :raises AssertionError: if the losses, or the gradients of any compared
        parameter, differ by more than the tolerances.
    """
    expected_loss = funsor.to_data(expected_loss)
    actual_loss = funsor.to_data(actual_loss)
    assert ops.allclose(actual_loss, expected_loss, atol=atol, rtol=rtol), \
        "Expected:\n{}\nActual:\n{}".format(expected_loss, actual_loss)

    # Snapshot the names once: they are iterated twice below (to collect the
    # params and again to label the gradients), and both passes must see the
    # same ordering.
    names = list(pyro.get_param_store().keys())
    params = [funsor.to_data(pyro.param(name)).unconstrained() for name in names]

    # NOTE(review): `grad` is presumably torch.autograd.grad -- confirm against
    # the file's imports.  retain_graph=True keeps the graphs usable for the
    # second call and for the caller.
    actual_grads = grad(actual_loss, params, allow_unused=True, retain_graph=True)
    expected_grads = grad(expected_loss, params, allow_unused=True, retain_graph=True)

    for name, actual_grad, expected_grad in zip(names, actual_grads, expected_grads):
        if actual_grad is None or expected_grad is None:
            continue
        # Name the offending parameter so a failure is actionable.
        assert ops.allclose(actual_grad, expected_grad, atol=atol, rtol=rtol), \
            "bad gradient for param {!r}\nExpected:\n{}\nActual:\n{}".format(
                name, expected_grad, actual_grad)
コード例 #2
0
def test_elbo_plate_plate(backend, outer_dim, inner_dim):
    """Compare the TraceEnum_ELBO loss and gradient against an analytic KL.

    The model and guide sample four Categorical sites: ``w`` outside any
    plate, ``x`` in the ``outer`` plate, ``y`` in the ``inner`` plate and
    ``z`` in both.  The expected loss is the per-site KL between the guide
    and model Categoricals, scaled by
    ``1 + outer_dim + inner_dim + outer_dim * inner_dim``.
    """
    with pyro_backend(backend):
        pyro.get_param_store().clear()
        particles = 1
        guide_probs = pyro.param("q", torch.tensor([0.75, 0.25], requires_grad=True))
        # Value for which kl(Categorical(q), Categorical(p)) = 0.5.
        first_prob = 0.2693204236205713
        model_probs = torch.tensor([first_prob, 1 - first_prob])

        def _parallel():
            # Fresh dict per site, exactly as in a literal at each call.
            return {"enumerate": "parallel"}

        def model():
            prior = dist.Categorical(model_probs)
            outer = pyro.plate("outer", outer_dim, dim=-1)
            inner = pyro.plate("inner", inner_dim, dim=-2)
            pyro.sample("w", prior)
            with outer:
                pyro.sample("x", prior)
            with inner:
                pyro.sample("y", prior)
            with outer:
                with inner:
                    pyro.sample("z", prior)

        def guide():
            posterior = dist.Categorical(pyro.param("q"))
            outer = pyro.plate("outer", outer_dim, dim=-1)
            inner = pyro.plate("inner", inner_dim, dim=-2)
            pyro.sample("w", posterior, infer=_parallel())
            with outer:
                pyro.sample("x", posterior, infer=_parallel())
            with inner:
                pyro.sample("y", posterior, infer=_parallel())
            with outer:
                with inner:
                    pyro.sample("z", posterior, infer=_parallel())

        # Analytic reference: one KL term scaled by the combined site count.
        site_kl = kl_divergence(
            torch.distributions.Categorical(funsor.to_data(guide_probs)),
            torch.distributions.Categorical(funsor.to_data(model_probs)))
        scale = 1 + outer_dim + inner_dim + outer_dim * inner_dim
        expected_loss = scale * site_kl
        expected_grad = grad(expected_loss, [funsor.to_data(guide_probs)])[0]

        loss_fn = infer.TraceEnum_ELBO(num_particles=particles,
                                       vectorize_particles=True,
                                       strict_enumeration_warning=True)
        if backend == "pyro":
            # The "pyro" backend is driven through differentiable_loss.
            loss_fn = loss_fn.differentiable_loss
        actual_loss = funsor.to_data(loss_fn(model, guide))
        actual_loss.backward()
        actual_grad = funsor.to_data(pyro.param('q')).grad

        assert ops.allclose(actual_loss, expected_loss, atol=1e-5)
        assert ops.allclose(actual_grad, expected_grad, atol=1e-5)
コード例 #3
0
def test_rng_seed(backend):
    """Check that two draws made under the same ``rng_seed`` are close."""

    def draw():
        return pyro.sample("x", dist.Normal(0, 1))

    with pyro_backend(backend):
        samples = []
        # Draw twice, each time inside a fresh seed handler with seed 0.
        for _ in range(2):
            with handlers.seed(rng_seed=0):
                samples.append(draw())
        expected, actual = samples
        assert ops.allclose(actual, expected)
コード例 #4
0
def test_local_param_ok(backend, jit):
    """Exercise a param declared inside a plate with ``event_dim=1``.

    Runs the model/guide pair through the ``assert_ok`` helper with either
    ``JitTrace_ELBO`` or ``Trace_ELBO`` (selected by ``jit``), then checks
    that ``pyro.param("p")`` called without an initial value is close to the
    value the guide returns.
    """
    data = torch.randn(10)
    num_obs = len(data)

    def model():
        locs = pyro.param("locs", torch.tensor([-1., 0., 1.]))
        with pyro.plate("plate", num_obs, dim=-1):
            component = pyro.sample("x", dist.Categorical(torch.ones(3) / 3))
            pyro.sample("obs", dist.Normal(locs[component], 1.), obs=data)

    def guide():
        with pyro.plate("plate", num_obs, dim=-1):
            probs = pyro.param("p", torch.ones(num_obs, 3) / 3, event_dim=1)
            pyro.sample("x", dist.Categorical(probs))
        return probs

    with pyro_backend(backend):
        if jit:
            elbo_cls = infer.JitTrace_ELBO
        else:
            elbo_cls = infer.Trace_ELBO
        objective = elbo_cls(ignore_jit_warnings=True)
        assert_ok(model, guide, objective)

        # pyro.param() must be callable by name alone, without init_value.
        from_guide = guide()
        from_store = pyro.param("p")
        assert ops.allclose(from_store, from_guide)