def test_dirichlet_categorical_x64(kernel_cls, dense_mass):
    num_warmup, num_samples = 100, 20000

    def model(data):
        concentration = jnp.array([1.0, 1.0, 1.0])
        p_latent = numpyro.sample("p_latent", dist.Dirichlet(concentration))
        numpyro.sample("obs", dist.Categorical(p_latent), obs=data)
        return p_latent

    true_probs = jnp.array([0.1, 0.6, 0.3])
    data = dist.Categorical(true_probs).sample(random.PRNGKey(1), (2000,))
    if kernel_cls is BarkerMH:
        kernel = BarkerMH(model=model, dense_mass=dense_mass)
    else:
        kernel = kernel_cls(model, trajectory_length=1.0, dense_mass=dense_mass)
    mcmc = MCMC(
        kernel, num_warmup=num_warmup, num_samples=num_samples, progress_bar=False
    )
    mcmc.run(random.PRNGKey(2), data)
    samples = mcmc.get_samples()
    assert_allclose(jnp.mean(samples["p_latent"], 0), true_probs, atol=0.02)

    if "JAX_ENABLE_X64" in os.environ:
        assert samples["p_latent"].dtype == jnp.float64


def test_dense_mass(kernel_cls, rho):
    warmup_steps, num_samples = 20000, 10000
    true_cov = jnp.array([[10.0, rho], [rho, 0.1]])

    def model():
        numpyro.sample(
            "x", dist.MultivariateNormal(jnp.zeros(2), covariance_matrix=true_cov)
        )

    if kernel_cls is HMC or kernel_cls is NUTS:
        kernel = kernel_cls(model, trajectory_length=2.0, dense_mass=True)
    elif kernel_cls is BarkerMH:
        kernel = BarkerMH(model, dense_mass=True)
    mcmc = MCMC(
        kernel, num_warmup=warmup_steps, num_samples=num_samples, progress_bar=False
    )
    mcmc.run(random.PRNGKey(0))

    mass_matrix_sqrt = mcmc.last_state.adapt_state.mass_matrix_sqrt
    if kernel_cls is HMC or kernel_cls is NUTS:
        mass_matrix_sqrt = mass_matrix_sqrt[("x",)]
    mass_matrix = jnp.matmul(mass_matrix_sqrt, jnp.transpose(mass_matrix_sqrt))
    estimated_cov = jnp.linalg.inv(mass_matrix)
    assert_allclose(estimated_cov, true_cov, rtol=0.10)

    samples = mcmc.get_samples()["x"]
    assert_allclose(jnp.mean(samples[:, 0]), jnp.array(0.0), atol=0.50)
    assert_allclose(jnp.mean(samples[:, 1]), jnp.array(0.0), atol=0.05)
    assert_allclose(jnp.mean(samples[:, 0] * samples[:, 1]), jnp.array(rho), atol=0.20)
    assert_allclose(jnp.var(samples, axis=0), jnp.array([10.0, 0.1]), rtol=0.20)


def test_beta_bernoulli_x64(kernel_cls):
    num_warmup, num_samples = (100000, 100000) if kernel_cls is SA else (500, 20000)

    def model(data):
        alpha = jnp.array([1.1, 1.1])
        beta = jnp.array([1.1, 1.1])
        p_latent = numpyro.sample("p_latent", dist.Beta(alpha, beta))
        numpyro.sample("obs", dist.Bernoulli(p_latent), obs=data)
        return p_latent

    true_probs = jnp.array([0.9, 0.1])
    data = dist.Bernoulli(true_probs).sample(random.PRNGKey(1), (1000,))
    if kernel_cls is SA:
        kernel = SA(model=model)
    elif kernel_cls is BarkerMH:
        kernel = BarkerMH(model=model)
    else:
        kernel = kernel_cls(model=model, trajectory_length=0.1)
    mcmc = MCMC(
        kernel, num_warmup=num_warmup, num_samples=num_samples, progress_bar=False
    )
    mcmc.run(random.PRNGKey(2), data)
    mcmc.print_summary()
    samples = mcmc.get_samples()
    assert_allclose(jnp.mean(samples["p_latent"], 0), true_probs, atol=0.05)

    if "JAX_ENABLE_X64" in os.environ:
        assert samples["p_latent"].dtype == jnp.float64


def test_logistic_regression_x64(kernel_cls):
    N, dim = 3000, 3
    if kernel_cls is SA:
        warmup_steps, num_samples = (100000, 100000)
    elif kernel_cls is BarkerMH:
        warmup_steps, num_samples = (2000, 12000)
    else:
        warmup_steps, num_samples = (1000, 8000)
    data = random.normal(random.PRNGKey(0), (N, dim))
    true_coefs = jnp.arange(1.0, dim + 1.0)
    logits = jnp.sum(true_coefs * data, axis=-1)
    labels = dist.Bernoulli(logits=logits).sample(random.PRNGKey(1))

    def model(labels):
        coefs = numpyro.sample("coefs", dist.Normal(jnp.zeros(dim), jnp.ones(dim)))
        logits = numpyro.deterministic("logits", jnp.sum(coefs * data, axis=-1))
        return numpyro.sample("obs", dist.Bernoulli(logits=logits), obs=labels)

    if kernel_cls is SA:
        kernel = SA(model=model, adapt_state_size=9)
    elif kernel_cls is BarkerMH:
        kernel = BarkerMH(model=model)
    else:
        kernel = kernel_cls(
            model=model, trajectory_length=8, find_heuristic_step_size=True
        )
    mcmc = MCMC(
        kernel, num_warmup=warmup_steps, num_samples=num_samples, progress_bar=False
    )
    mcmc.run(random.PRNGKey(2), labels)
    mcmc.print_summary()
    samples = mcmc.get_samples()
    assert samples["logits"].shape == (num_samples, N)
    # those coefficients are found by doing MAP inference using AutoDelta
    expected_coefs = jnp.array([0.97, 2.05, 3.18])
    assert_allclose(jnp.mean(samples["coefs"], 0), expected_coefs, atol=0.1)

    if "JAX_ENABLE_X64" in os.environ:
        assert samples["coefs"].dtype == jnp.float64
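

# The kernel_cls / dense_mass / rho arguments above are supplied by pytest
# parametrization defined alongside these tests, and the code relies on the
# usual module-level imports (os, jax.numpy as jnp, `from jax import random`,
# numpyro, numpyro.distributions as dist, the MCMC/HMC/NUTS/SA/BarkerMH
# kernels, and numpy.testing.assert_allclose). A hedged sketch of how one of
# these tests might be parametrized -- the kernel and flag lists below are
# illustrative assumptions, not the module's actual configuration:
#
#     @pytest.mark.parametrize("kernel_cls", [HMC, NUTS, BarkerMH])
#     @pytest.mark.parametrize("dense_mass", [False, True])
#     def test_dirichlet_categorical_x64(kernel_cls, dense_mass):
#         ...
#
# The float64 dtype assertions only run when JAX's 64-bit mode is enabled via
# the environment, e.g. JAX_ENABLE_X64=1 pytest -k x64 (assumed invocation).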