Exemplo n.º 1
0
def test_arrowhead_mass():
    """Check NUTS mass-matrix adaptation for dense, diagonal and arrowhead layouts.

    The model draws four latent sites under a Normal(0, 1000) prior: "w"
    (size 2), "x" (size 1), "y" (size 1) and "z" (size 2). It then scores
    their concatenation, taken in the order w, y, x, z, against a zero-mean
    MultivariateNormal with a random precision matrix.

    First, a one-step smoke run checks the rank of the inverse mass matrix
    with ``full_mass`` set to True and to False. Then a long warmup with an
    arrowhead adapter checks the following:
    - the dense block covers ("w", "y", "x") and the diagonal block covers "z";
    - the adapted mass matrix approximates the precision matrix.
    """

    def wide_prior(size):
        # Very wide prior, so the likelihood dominates the posterior.
        return dist.Normal(0, 1000).expand([size]).to_event(1)

    def model(precision):
        # Sites are sampled in the order w, x, y, z ...
        sites = {
            name: pyro.sample(name, wide_prior(size))
            for name, size in (("w", 2), ("x", 1), ("y", 1), ("z", 2))
        }
        # ... but concatenated in the order w, y, x, z to match `precision`.
        stacked = torch.cat([sites[name] for name in ("w", "y", "x", "z")])
        likelihood = dist.MultivariateNormal(torch.zeros(6), precision_matrix=precision)
        pyro.sample("obs", likelihood, obs=stacked)

    factor = torch.randn(6, 12)
    precision = factor @ factor.t() * 0.1

    def fit(full_mass, warmup_steps, use_arrowhead=False):
        # Build a NUTS kernel, optionally swap in the arrowhead adapter, run
        # one-sample MCMC on `precision`, and return the adapted kernel.
        kernel = NUTS(model, jit_compile=True, ignore_jit_warnings=True, full_mass=full_mass)
        if use_arrowhead:
            kernel.mass_matrix_adapter = ArrowheadMassMatrix()
        MCMC(kernel, num_samples=1, warmup_steps=warmup_steps).run(precision)
        return kernel

    # smoke tests: a dense layout yields a 2-D inverse mass matrix, a diagonal one a 1-D matrix.
    for dense_mass in (True, False):
        kernel = fit(dense_mass, warmup_steps=1)
        expected_dim = 2 if dense_mass else 1
        assert kernel.inverse_mass_matrix[("w", "x", "y", "z")].dim() == expected_dim

    # Arrowhead layout: dense over w, then (y, x); the remaining site z is diagonal.
    arrowhead_key = ("w", "y", "x", "z")
    kernel = fit([("w",), ("y", "x")], warmup_steps=1000, use_arrowhead=True)
    assert arrowhead_key in kernel.inverse_mass_matrix
    adapted = kernel.mass_matrix_adapter.mass_matrix[arrowhead_key]
    # 4 dense coordinates (w: 2, y: 1, x: 1) span all 6 columns; z adds 2 diagonal entries.
    assert adapted.top.shape == (4, 6)
    assert adapted.bottom_diag.shape == (2,)
    tolerance = dict(atol=0.2, rtol=0.2)
    assert_close(adapted.top, precision[:4], **tolerance)
    assert_close(adapted.bottom_diag, precision.diag()[4:], **tolerance)
Exemplo n.º 2
0
def test_dirichlet_categorical_grad_adapt():
    """Recover categorical probabilities with NUTS and an arrowhead mass-matrix adapter.

    The test draws 2000 observations from a Categorical with fixed
    probabilities. It then fits a Dirichlet(1, 1, 1) prior on ``p_latent``
    with NUTS, using 100 warmup steps and 200 samples. The posterior mean
    must match the true probabilities within ``prec=0.02``.
    """

    def model(observations):
        prior = dist.Dirichlet(torch.tensor([1.0, 1.0, 1.0]))
        probs = pyro.sample("p_latent", prior)
        pyro.sample("obs", dist.Categorical(probs), obs=observations)
        return probs

    true_probs = torch.tensor([0.1, 0.6, 0.3])
    data = dist.Categorical(true_probs).sample(sample_shape=torch.Size([2000]))

    kernel = NUTS(model, jit_compile=True, ignore_jit_warnings=True)
    kernel.mass_matrix_adapter = ArrowheadMassMatrix()
    sampler = MCMC(kernel, num_samples=200, warmup_steps=100)
    sampler.run(data)

    posterior_mean = sampler.get_samples()["p_latent"].mean(0)
    assert_equal(posterior_mean, true_probs, prec=0.02)