def test_arrowhead_mass():
    """Check NUTS mass-matrix adaptation against a known Gaussian precision.

    The model concatenates four latent vector sites (6 scalars in total) and
    observes the result under a zero-mean ``MultivariateNormal`` with the given
    precision matrix.  The test first smoke-tests the dense and diagonal
    ``full_mass`` settings, then runs a long warmup with
    ``ArrowheadMassMatrix`` and compares the adapted mass matrix with the
    corresponding parts of the true precision matrix.
    """
    arrow_key = ("w", "y", "x", "z")
    jit_opts = dict(jit_compile=True, ignore_jit_warnings=True)

    def model(prec):
        # Sites are sampled in the order w, x, y, z ...
        latents = {
            name: pyro.sample(name, dist.Normal(0, 1000).expand([size]).to_event(1))
            for name, size in (("w", 2), ("x", 1), ("y", 1), ("z", 2))
        }
        # ... but concatenated in the order w, y, x, z before being observed.
        wyxz = torch.cat([latents[name] for name in arrow_key])
        likelihood = dist.MultivariateNormal(torch.zeros(6), precision_matrix=prec)
        pyro.sample("obs", likelihood, obs=wyxz)

    # Random 6x6 precision matrix built as a scaled Gram matrix of a 6x12 factor.
    factor = torch.randn(6, 12)
    prec = (factor @ factor.t()) * 0.1

    # smoke tests
    for use_dense in [True, False]:
        kernel = NUTS(model, full_mass=use_dense, **jit_opts)
        sampler = MCMC(kernel, num_samples=1, warmup_steps=1)
        sampler.run(prec)
        # Dense adaptation is expected to give a 2-D tensor, diagonal a 1-D one.
        expected_dim = 2 if use_dense else 1
        assert kernel.inverse_mass_matrix[("w", "x", "y", "z")].dim() == expected_dim

    kernel = NUTS(model, full_mass=[("w",), ("y", "x")], **jit_opts)
    kernel.mass_matrix_adapter = ArrowheadMassMatrix()
    sampler = MCMC(kernel, num_samples=1, warmup_steps=1000)
    sampler.run(prec)

    assert arrow_key in kernel.inverse_mass_matrix
    adapted = kernel.mass_matrix_adapter.mass_matrix[arrow_key]
    assert adapted.top.shape == (4, 6)
    assert adapted.bottom_diag.shape == (2,)
    # The top rows and the trailing diagonal should approximate the true precision.
    assert_close(adapted.top, prec[:4], atol=0.2, rtol=0.2)
    assert_close(adapted.bottom_diag, prec.diag()[4:], atol=0.2, rtol=0.2)
def test_dirichlet_categorical_grad_adapt():
    """Check a Dirichlet-Categorical posterior sampled with arrowhead adaptation.

    Draws 2000 categorical observations from known probabilities, runs NUTS
    (JIT-compiled, with an ``ArrowheadMassMatrix`` adapter) for 100 warmup
    steps and 200 samples, and asserts that the posterior mean of
    ``p_latent`` matches the true probabilities to within ``prec=0.02``.
    """

    def model(data):
        # Dirichlet prior with all three concentration parameters equal to 1.
        prior = dist.Dirichlet(torch.tensor([1.0, 1.0, 1.0]))
        p_latent = pyro.sample("p_latent", prior)
        pyro.sample("obs", dist.Categorical(p_latent), obs=data)
        return p_latent

    true_probs = torch.tensor([0.1, 0.6, 0.3])
    obs_shape = torch.Size((2000,))
    data = dist.Categorical(true_probs).sample(sample_shape=obs_shape)

    kernel = NUTS(model, jit_compile=True, ignore_jit_warnings=True)
    kernel.mass_matrix_adapter = ArrowheadMassMatrix()
    sampler = MCMC(kernel, num_samples=200, warmup_steps=100)
    sampler.run(data)

    # Average the draws over the sample dimension.
    posterior_mean = sampler.get_samples()["p_latent"].mean(0)
    assert_equal(posterior_mean, true_probs, prec=0.02)