コード例 #1
0
def test_uniform_normal(forward_mode_differentiation):
    """Run NUTS on a small uniform/normal model and check the recovered ``loc``.

    The model draws ``alpha`` from ``Uniform(0, 1)``, draws ``loc`` from a
    unit Uniform pushed through ``AffineTransform(0, alpha)`` while the
    ``loc`` site is configured with ``TransformReparam``, and observes
    ``data`` under ``Normal(loc, 0.1)``.

    Three stages are checked, in order:

    1. ``warmup(..., collect_warmup=True)`` leaves ``post_warmup_state`` set
       and ``get_samples()`` returns ``num_warmup`` draws of ``loc``.
    2. A following ``run`` returns ``num_samples`` draws whose mean ``loc``
       is within 0.05 of ``true_coef``.
    3. After assigning ``last_state`` to ``post_warmup_state``, another
       ``run`` satisfies the same two checks.

    ``forward_mode_differentiation`` is forwarded unchanged to the NUTS
    kernel (presumably supplied by a pytest fixture or parametrization that
    is not visible here).
    """
    true_coef = 0.9
    num_warmup = 1000
    num_samples = 1000

    def model(data):
        alpha = numpyro.sample("alpha", dist.Uniform(0, 1))
        with numpyro.handlers.reparam(config={"loc": TransformReparam()}):
            # Unit Uniform scaled by ``alpha`` via an affine transform.
            scaled_uniform = dist.TransformedDistribution(
                dist.Uniform(0, 1), AffineTransform(0, alpha)
            )
            loc = numpyro.sample("loc", scaled_uniform)
        numpyro.sample("obs", dist.Normal(loc, 0.1), obs=data)

    def check_posterior(draws):
        # Shared by both sampling runs: sample count, then posterior mean.
        assert len(draws["loc"]) == num_samples
        assert_allclose(jnp.mean(draws["loc"], 0), true_coef, atol=0.05)

    noise = random.normal(random.PRNGKey(0), (1000,))
    data = true_coef + noise

    mcmc = MCMC(
        NUTS(
            model=model,
            forward_mode_differentiation=forward_mode_differentiation,
        ),
        num_warmup=num_warmup,
        num_samples=num_samples,
    )

    # Stage 1: warmup only, keeping the warmup draws.
    mcmc.warmup(random.PRNGKey(2), data, collect_warmup=True)
    assert mcmc.post_warmup_state is not None
    warmup_draws = mcmc.get_samples()

    # Stage 2: sampling run after the explicit warmup.
    mcmc.run(random.PRNGKey(3), data)
    posterior_draws = mcmc.get_samples()
    assert len(warmup_draws["loc"]) == num_warmup
    check_posterior(posterior_draws)

    # Stage 3: run again starting from the state the previous run ended in.
    mcmc.post_warmup_state = mcmc.last_state
    mcmc.run(random.PRNGKey(3), data)
    check_posterior(mcmc.get_samples())
コード例 #2
0
def test_uniform_normal():
    """Run NUTS on a small uniform/normal model and check the recovered ``loc``.

    The model draws ``alpha`` from ``Uniform(0, 1)``, draws ``loc`` from
    ``Uniform(0, alpha)`` while the ``loc`` site is configured with
    ``TransformReparam``, and observes ``data`` under ``Normal(loc, 0.1)``.

    Three stages are checked, in order:

    1. ``warmup(..., collect_warmup=True)`` leaves ``post_warmup_state`` set
       and ``get_samples()`` returns ``num_warmup`` draws of ``loc``.
    2. A following ``run`` returns ``num_samples`` draws whose mean ``loc``
       is within 0.05 of ``true_coef``.
    3. After assigning ``last_state`` to ``post_warmup_state``, another
       ``run`` satisfies the same two checks.

    NOTE(review): ``TransformReparam`` is applied to a plain ``dist.Uniform``
    site here; whether that is accepted presumably depends on the installed
    numpyro version -- confirm before reusing this snippet.
    """
    true_coef = 0.9
    num_warmup = 1000
    num_samples = 1000

    def model(data):
        alpha = numpyro.sample('alpha', dist.Uniform(0, 1))
        with numpyro.handlers.reparam(config={'loc': TransformReparam()}):
            # Upper bound of the Uniform is the sampled ``alpha``.
            bounded_uniform = dist.Uniform(0, alpha)
            loc = numpyro.sample('loc', bounded_uniform)
        numpyro.sample('obs', dist.Normal(loc, 0.1), obs=data)

    def check_posterior(draws):
        # Shared by both sampling runs: sample count, then posterior mean.
        assert len(draws['loc']) == num_samples
        assert_allclose(jnp.mean(draws['loc'], 0), true_coef, atol=0.05)

    noise = random.normal(random.PRNGKey(0), (1000,))
    data = true_coef + noise

    mcmc = MCMC(
        NUTS(model=model),
        num_warmup=num_warmup,
        num_samples=num_samples,
    )

    # Stage 1: warmup only, keeping the warmup draws.
    mcmc.warmup(random.PRNGKey(2), data, collect_warmup=True)
    assert mcmc.post_warmup_state is not None
    warmup_draws = mcmc.get_samples()

    # Stage 2: sampling run after the explicit warmup.
    mcmc.run(random.PRNGKey(3), data)
    posterior_draws = mcmc.get_samples()
    assert len(warmup_draws['loc']) == num_warmup
    check_posterior(posterior_draws)

    # Stage 3: run again starting from the state the previous run ended in.
    mcmc.post_warmup_state = mcmc.last_state
    mcmc.run(random.PRNGKey(3), data)
    check_posterior(mcmc.get_samples())