def test_unnormalized_normal_chain(kernel, kwargs, num_chains):
    """Check that a TFP kernel samples a 1-D Normal(1.0, 0.5) from its potential.

    The kernel is built from ``potential_fn``, the unnormalized negative log
    density ``0.5 * ((z - true_mean) / true_std) ** 2``. After warmup the
    pooled samples must have mean and standard deviation within 7% (relative)
    of the true values.

    Args:
        kernel: name of a kernel class in ``numpyro.contrib.tfp.mcmc``.
        kwargs: extra keyword arguments forwarded to the kernel constructor.
        num_chains: number of chains; 1 uses a scalar init, otherwise two
            initial values ``[0., 2.]`` are used.
    """
    from numpyro.contrib.tfp import mcmc

    # TODO: remove when this issue is fixed upstream
    # https://github.com/tensorflow/probability/pull/1087
    if num_chains == 2 and kernel == "ReplicaExchangeMC":
        pytest.xfail(
            "ReplicaExchangeMC is not fully compatible with omnistaging yet.")

    kernel_class = getattr(mcmc, kernel)

    true_mean, true_std = 1., 0.5
    num_warmup, num_samples = 1000, 8000

    def potential_fn(z):
        return 0.5 * ((z - true_mean) / true_std) ** 2

    # One scalar initial value per chain.
    init_params = jnp.array(0.) if num_chains == 1 else jnp.array([0., 2.])
    tfp_kernel = kernel_class(potential_fn=potential_fn, **kwargs)

    # Use a distinct name so the imported ``mcmc`` module is not shadowed, and
    # pass the sample counts by keyword like the other tests in this file.
    sampler = MCMC(
        tfp_kernel,
        num_warmup=num_warmup,
        num_samples=num_samples,
        num_chains=num_chains,
        progress_bar=False,
    )
    sampler.run(random.PRNGKey(0), init_params=init_params)
    sampler.print_summary()

    hmc_states = sampler.get_samples()
    assert_allclose(jnp.mean(hmc_states), true_mean, rtol=0.07)
    assert_allclose(jnp.std(hmc_states), true_std, rtol=0.07)
def test_mcmc_kernels(kernel, kwargs):
    """Fit a reparameterized NumPyro model with a TFP MCMC kernel.

    The ``loc`` site is an affinely transformed Uniform reparameterized with
    ``TransformReparam``. Observations are ``0.9`` plus standard-normal noise.
    The test checks the warmup and sampling sample counts, and that the
    posterior mean of ``loc`` is within 0.05 of 0.9.

    NOTE(review): this file defines ``test_mcmc_kernels`` more than once. A
    later definition with the same name rebinds it, so pytest only collects
    the last one and this body never runs. Confirm whether this copy should
    be removed.
    """
    from numpyro.contrib.tfp import mcmc

    kernel_class = getattr(mcmc, kernel)

    true_coef = 0.9
    num_warmup = 1000
    num_samples = 1000

    def model(data):
        alpha = numpyro.sample('alpha', dist.Uniform(0, 1))
        with numpyro.handlers.reparam(config={'loc': TransformReparam()}):
            transformed = dist.TransformedDistribution(
                dist.Uniform(0, 1), AffineTransform(0, alpha))
            loc = numpyro.sample('loc', transformed)
        numpyro.sample('obs', dist.Normal(loc, 0.1), obs=data)

    noise = random.normal(random.PRNGKey(0), (1000,))
    data = true_coef + noise

    sampler = MCMC(
        kernel_class(model=model, **kwargs),
        num_warmup=num_warmup,
        num_samples=num_samples,
    )

    # Collect warmup draws first, then continue with the sampling phase.
    sampler.warmup(random.PRNGKey(2), data, collect_warmup=True)
    warmup_samples = sampler.get_samples()
    sampler.run(random.PRNGKey(3), data)
    samples = sampler.get_samples()

    assert len(warmup_samples['loc']) == num_warmup
    assert len(samples['loc']) == num_samples
    assert_allclose(jnp.mean(samples['loc'], 0), true_coef, atol=0.05)
def test_mcmc_kernels(kernel, kwargs):
    """Fit a reparameterized NumPyro model with a TFP MCMC kernel.

    The ``loc`` site is an affinely transformed Uniform reparameterized with
    ``TransformReparam``. Observations are ``0.9`` plus standard-normal noise.
    The test checks the warmup and sampling sample counts, and that the
    posterior mean of ``loc`` is within 0.05 of 0.9.

    ``SliceSampler`` is skipped when a ``CI`` environment variable is set.
    """
    from numpyro.contrib.tfp import mcmc

    # Guard: skip before touching the kernel at all.
    if kernel == "SliceSampler" and "CI" in os.environ:
        # TODO: Look into this issue if some users are using SliceSampler
        # with NumPyro model.
        pytest.skip("SliceSampler freezes CI for unknown reason.")

    kernel_class = getattr(mcmc, kernel)

    true_coef = 0.9
    num_warmup = 1000
    num_samples = 1000

    def model(data):
        alpha = numpyro.sample("alpha", dist.Uniform(0, 1))
        with numpyro.handlers.reparam(config={"loc": TransformReparam()}):
            transformed = dist.TransformedDistribution(
                dist.Uniform(0, 1), AffineTransform(0, alpha))
            loc = numpyro.sample("loc", transformed)
        numpyro.sample("obs", dist.Normal(loc, 0.1), obs=data)

    noise = random.normal(random.PRNGKey(0), (1000,))
    data = true_coef + noise

    sampler = MCMC(
        kernel_class(model=model, **kwargs),
        num_warmup=num_warmup,
        num_samples=num_samples,
    )

    # Collect warmup draws first, then continue with the sampling phase.
    sampler.warmup(random.PRNGKey(2), data, collect_warmup=True)
    warmup_samples = sampler.get_samples()
    sampler.run(random.PRNGKey(3), data)
    samples = sampler.get_samples()

    assert len(warmup_samples["loc"]) == num_warmup
    assert len(samples["loc"]) == num_samples
    assert_allclose(jnp.mean(samples["loc"], 0), true_coef, atol=0.05)
def test_mcmc_unwrapped_tfp_distributions():
    """Run NUTS on a Beta-Bernoulli model built from raw TFP-on-JAX distributions.

    The prior is Beta(1, 1) and the observations are [0, 0, 1, 1, 1], i.e.
    three successes in five trials. The conjugate posterior is therefore
    Beta(4, 3), whose mean is 4/7. The posterior sample mean of ``p`` must be
    within 0.05 of that value.
    """
    from tensorflow_probability.substrates.jax import distributions as tfd

    def model(y):
        theta = numpyro.sample("p", tfd.Beta(1, 1))
        with numpyro.plate("plate", y.size):
            numpyro.sample("y", tfd.Bernoulli(probs=theta), obs=y)

    observations = jnp.array([0, 0, 1, 1, 1])
    expected_mean = 4 / 7

    sampler = MCMC(NUTS(model), num_warmup=1000, num_samples=1000)
    sampler.run(random.PRNGKey(0), observations)

    draws = sampler.get_samples()["p"]
    assert_allclose(jnp.mean(draws), expected_mean, atol=0.05)