def test_uniform_normal(forward_mode_differentiation):
    """Check NUTS on a uniform-prior / normal-likelihood model, including resumed runs.

    The model draws ``alpha`` from ``Uniform(0, 1)``, draws ``loc`` from a
    ``Uniform(0, 1)`` pushed through ``AffineTransform(0, alpha)`` (sampled
    under ``TransformReparam``), and observes the data under
    ``Normal(loc, 0.1)``.

    The test verifies that:

    * warmup with ``collect_warmup=True`` leaves ``post_warmup_state`` set and
      yields ``num_warmup`` collected draws;
    * a subsequent ``run`` yields ``num_samples`` draws whose mean ``loc`` is
      within ``0.05`` of ``true_coef``;
    * after assigning ``last_state`` to ``post_warmup_state``, another ``run``
      satisfies the same checks.

    :param forward_mode_differentiation: passed straight through to ``NUTS``
        under the keyword of the same name; how this argument is supplied
        (fixture or parametrization) is not visible here.
    """
    # NOTE(review): another function with this exact name is defined later in
    # the visible source; if both live in one module, the later definition
    # replaces this one at import time -- confirm that is intended.
    true_coef = 0.9
    num_warmup, num_samples = 1000, 1000

    def model(data):
        alpha = numpyro.sample("alpha", dist.Uniform(0, 1))
        with numpyro.handlers.reparam(config={"loc": TransformReparam()}):
            scaled_uniform = dist.TransformedDistribution(
                dist.Uniform(0, 1), AffineTransform(0, alpha)
            )
            loc = numpyro.sample("loc", scaled_uniform)
        numpyro.sample("obs", dist.Normal(loc, 0.1), obs=data)

    def check_posterior(draws):
        # Shared checks for every post-warmup run: draw count and posterior mean.
        assert len(draws["loc"]) == num_samples
        assert_allclose(jnp.mean(draws["loc"], 0), true_coef, atol=0.05)

    noise = random.normal(random.PRNGKey(0), (1000,))
    data = true_coef + noise

    mcmc = MCMC(
        NUTS(model=model, forward_mode_differentiation=forward_mode_differentiation),
        num_warmup=num_warmup,
        num_samples=num_samples,
    )

    # Stage 1: warmup only, keeping the warmup draws for inspection.
    mcmc.warmup(random.PRNGKey(2), data, collect_warmup=True)
    assert mcmc.post_warmup_state is not None
    warmup_draws = mcmc.get_samples()

    # Stage 2: sampling that starts from the stored post-warmup state.
    mcmc.run(random.PRNGKey(3), data)
    assert len(warmup_draws["loc"]) == num_warmup
    check_posterior(mcmc.get_samples())

    # Stage 3: continue from where the previous run stopped.
    mcmc.post_warmup_state = mcmc.last_state
    mcmc.run(random.PRNGKey(3), data)
    check_posterior(mcmc.get_samples())
def test_uniform_normal():
    """Check NUTS on a uniform-prior / normal-likelihood model, including resumed runs.

    The model draws ``alpha`` from ``Uniform(0, 1)``, draws ``loc`` from
    ``Uniform(0, alpha)`` (sampled under ``TransformReparam``), and observes
    the data under ``Normal(loc, 0.1)``.

    The test verifies that:

    * warmup with ``collect_warmup=True`` leaves ``post_warmup_state`` set and
      yields ``num_warmup`` collected draws;
    * a subsequent ``run`` yields ``num_samples`` draws whose mean ``loc`` is
      within ``0.05`` of ``true_coef``;
    * after assigning ``last_state`` to ``post_warmup_state``, another ``run``
      satisfies the same checks.
    """
    # NOTE(review): another function with this exact name is defined earlier in
    # the visible source; if both live in one module, this definition replaces
    # the earlier one at import time -- confirm that is intended.
    true_coef = 0.9
    num_warmup, num_samples = 1000, 1000

    def model(data):
        alpha = numpyro.sample('alpha', dist.Uniform(0, 1))
        with numpyro.handlers.reparam(config={'loc': TransformReparam()}):
            loc_prior = dist.Uniform(0, alpha)
            loc = numpyro.sample('loc', loc_prior)
        numpyro.sample('obs', dist.Normal(loc, 0.1), obs=data)

    def check_posterior(draws):
        # Shared checks for every post-warmup run: draw count and posterior mean.
        assert len(draws['loc']) == num_samples
        assert_allclose(jnp.mean(draws['loc'], 0), true_coef, atol=0.05)

    noise = random.normal(random.PRNGKey(0), (1000,))
    data = true_coef + noise

    mcmc = MCMC(
        NUTS(model=model),
        num_warmup=num_warmup,
        num_samples=num_samples,
    )

    # Stage 1: warmup only, keeping the warmup draws for inspection.
    mcmc.warmup(random.PRNGKey(2), data, collect_warmup=True)
    assert mcmc.post_warmup_state is not None
    warmup_draws = mcmc.get_samples()

    # Stage 2: sampling that starts from the stored post-warmup state.
    mcmc.run(random.PRNGKey(3), data)
    assert len(warmup_draws['loc']) == num_warmup
    check_posterior(mcmc.get_samples())

    # Stage 3: continue from where the previous run stopped.
    mcmc.post_warmup_state = mcmc.last_state
    mcmc.run(random.PRNGKey(3), data)
    check_posterior(mcmc.get_samples())