コード例 #1
0
def test_mcmc_kernels(kernel, kwargs):
    """Check that a TFP-backed MCMC kernel recovers a known location value.

    The kernel class named by ``kernel`` is looked up in
    ``numpyro.contrib.tfp.mcmc``, built around a small NumPyro model, driven
    through a warmup phase and a sampling phase, and the posterior mean of
    ``loc`` is compared with the value used to generate the data.

    Args:
        kernel: name of the kernel class in ``numpyro.contrib.tfp.mcmc``
            (presumably supplied by a pytest parametrization outside this view).
        kwargs: extra keyword arguments forwarded to the kernel constructor.
    """
    # Alias the module: the local name ``mcmc`` is used below for the MCMC
    # driver instance, and binding both objects to one name was confusing.
    from numpyro.contrib.tfp import mcmc as tfp_mcmc
    kernel_class = getattr(tfp_mcmc, kernel)

    true_coef = 0.9
    num_warmup, num_samples = 1000, 1000
    num_data = 1000

    def model(data):
        alpha = numpyro.sample('alpha', dist.Uniform(0, 1))
        # loc = 0 + alpha * u with u ~ Uniform(0, 1). TransformReparam splits
        # the transformed site so inference runs on the base Uniform variable.
        with numpyro.handlers.reparam(config={'loc': TransformReparam()}):
            loc = numpyro.sample(
                'loc',
                dist.TransformedDistribution(dist.Uniform(0, 1),
                                             AffineTransform(0, alpha)))
        numpyro.sample('obs', dist.Normal(loc, 0.1), obs=data)

    # Observations are true_coef plus standard-normal noise. The likelihood
    # scale (0.1) intentionally need not match it: only the posterior mean of
    # ``loc`` is asserted below.
    data = true_coef + random.normal(random.PRNGKey(0), (num_data, ))
    tfp_kernel = kernel_class(model=model, **kwargs)
    mcmc = MCMC(tfp_kernel, num_warmup=num_warmup, num_samples=num_samples)
    # collect_warmup=True makes get_samples() return the warmup draws, which
    # lets the warmup length be checked separately from the sampling length.
    mcmc.warmup(random.PRNGKey(2), data, collect_warmup=True)
    warmup_samples = mcmc.get_samples()
    mcmc.run(random.PRNGKey(3), data)
    samples = mcmc.get_samples()
    assert len(warmup_samples['loc']) == num_warmup
    assert len(samples['loc']) == num_samples
    assert_allclose(jnp.mean(samples['loc'], 0), true_coef, atol=0.05)
コード例 #2
0
ファイル: test_tfp.py プロジェクト: pyro-ppl/numpyro
def test_mcmc_kernels(kernel, kwargs):
    """Check that a TFP-backed MCMC kernel recovers a known location value.

    The kernel class named by ``kernel`` is looked up in
    ``numpyro.contrib.tfp.mcmc``, built around a small NumPyro model, driven
    through a warmup phase and a sampling phase, and the posterior mean of
    ``loc`` is compared with the value used to generate the data.
    ``SliceSampler`` is skipped when the ``CI`` environment variable is set.

    Args:
        kernel: name of the kernel class in ``numpyro.contrib.tfp.mcmc``
            (presumably supplied by a pytest parametrization outside this view).
        kwargs: extra keyword arguments forwarded to the kernel constructor.
    """
    # Alias the module: the local name ``mcmc`` is used below for the MCMC
    # driver instance, and binding both objects to one name was confusing.
    from numpyro.contrib.tfp import mcmc as tfp_mcmc

    if ("CI" in os.environ) and kernel == "SliceSampler":
        # TODO: Look into this issue if some users are using SliceSampler
        # with NumPyro model.
        pytest.skip("SliceSampler freezes CI for unknown reason.")

    kernel_class = getattr(tfp_mcmc, kernel)

    true_coef = 0.9
    num_warmup, num_samples = 1000, 1000
    num_data = 1000

    def model(data):
        alpha = numpyro.sample("alpha", dist.Uniform(0, 1))
        # loc = 0 + alpha * u with u ~ Uniform(0, 1). TransformReparam splits
        # the transformed site so inference runs on the base Uniform variable.
        with numpyro.handlers.reparam(config={"loc": TransformReparam()}):
            loc = numpyro.sample(
                "loc",
                dist.TransformedDistribution(dist.Uniform(0, 1),
                                             AffineTransform(0, alpha)),
            )
        numpyro.sample("obs", dist.Normal(loc, 0.1), obs=data)

    # Observations are true_coef plus standard-normal noise. The likelihood
    # scale (0.1) intentionally need not match it: only the posterior mean of
    # ``loc`` is asserted below.
    data = true_coef + random.normal(random.PRNGKey(0), (num_data, ))
    tfp_kernel = kernel_class(model=model, **kwargs)
    mcmc = MCMC(tfp_kernel, num_warmup=num_warmup, num_samples=num_samples)
    # collect_warmup=True makes get_samples() return the warmup draws, which
    # lets the warmup length be checked separately from the sampling length.
    mcmc.warmup(random.PRNGKey(2), data, collect_warmup=True)
    warmup_samples = mcmc.get_samples()
    mcmc.run(random.PRNGKey(3), data)
    samples = mcmc.get_samples()
    assert len(warmup_samples["loc"]) == num_warmup
    assert len(samples["loc"]) == num_samples
    assert_allclose(jnp.mean(samples["loc"], 0), true_coef, atol=0.05)