Esempio n. 1
0
def test_unnormalized_normal_chain(kernel, kwargs, num_chains):
    """Check that a TFP kernel recovers the moments of an unnormalized normal.

    The target is specified only through ``potential_fn``, which is the
    negative log-density (up to an additive constant) of a normal
    distribution with mean ``true_mean`` and standard deviation ``true_std``.

    :param str kernel: name of the kernel class to look up in
        ``numpyro.contrib.tfp.mcmc``.
    :param dict kwargs: extra keyword arguments forwarded to the kernel
        constructor.
    :param int num_chains: number of chains to run; initial values are
        provided for the one-chain and two-chain cases only.
    """
    # Aliased so the module name is not shadowed by the MCMC runner below.
    from numpyro.contrib.tfp import mcmc as tfp_mcmc

    # TODO: remove when this issue is fixed upstream
    # https://github.com/tensorflow/probability/pull/1087
    if num_chains == 2 and kernel == "ReplicaExchangeMC":
        pytest.xfail(
            "ReplicaExchangeMC is not fully compatible with omnistaging yet.")

    kernel_class = getattr(tfp_mcmc, kernel)

    true_mean, true_std = 1., 0.5
    warmup_steps, num_samples = (1000, 8000)

    def potential_fn(z):
        # Quadratic potential: 0.5 * squared standardized distance.
        return 0.5 * ((z - true_mean) / true_std)**2

    # One scalar start for a single chain, one start per chain otherwise.
    init_params = jnp.array(0.) if num_chains == 1 else jnp.array([0., 2.])
    tfp_kernel = kernel_class(potential_fn=potential_fn, **kwargs)
    # Pass the step counts by keyword so the call does not depend on the
    # positional order of the MCMC constructor's parameters.
    sampler = MCMC(tfp_kernel,
                   num_warmup=warmup_steps,
                   num_samples=num_samples,
                   num_chains=num_chains,
                   progress_bar=False)
    sampler.run(random.PRNGKey(0), init_params=init_params)
    sampler.print_summary()
    hmc_states = sampler.get_samples()
    assert_allclose(jnp.mean(hmc_states), true_mean, rtol=0.07)
    assert_allclose(jnp.std(hmc_states), true_std, rtol=0.07)
Esempio n. 2
0
def test_mcmc_kernels(kernel, kwargs):
    """Run a TFP kernel on a small NumPyro model and check the ``loc`` draws.

    The kernel class named by ``kernel`` is looked up in
    ``numpyro.contrib.tfp.mcmc`` and built with the model plus ``kwargs``.
    Warmup draws are collected separately from the main run so that both
    sample counts can be checked, and the mean of the ``loc`` draws from the
    main run is compared with the coefficient used to generate the data.

    :param str kernel: name of the kernel class to test.
    :param dict kwargs: extra keyword arguments for the kernel constructor.
    """
    from numpyro.contrib.tfp import mcmc as tfp_mcmc

    sampler_cls = getattr(tfp_mcmc, kernel)

    coef = 0.9
    n_warmup = 1000
    n_draws = 1000

    def model(data):
        alpha = numpyro.sample('alpha', dist.Uniform(0, 1))
        # 'loc' is a Uniform(0, 1) variable scaled by alpha; the site is
        # sampled under the TransformReparam handler.
        with numpyro.handlers.reparam(config={'loc': TransformReparam()}):
            scaled_uniform = dist.TransformedDistribution(
                dist.Uniform(0, 1), AffineTransform(0, alpha))
            loc = numpyro.sample('loc', scaled_uniform)
        numpyro.sample('obs', dist.Normal(loc, 0.1), obs=data)

    noise = random.normal(random.PRNGKey(0), (1000, ))
    observations = coef + noise

    runner = MCMC(sampler_cls(model=model, **kwargs),
                  num_warmup=n_warmup,
                  num_samples=n_draws)

    # Warmup phase first; its draws are read back before the main run.
    runner.warmup(random.PRNGKey(2), observations, collect_warmup=True)
    warmup_draws = runner.get_samples()

    runner.run(random.PRNGKey(3), observations)
    draws = runner.get_samples()

    assert len(warmup_draws['loc']) == n_warmup
    assert len(draws['loc']) == n_draws
    assert_allclose(jnp.mean(draws['loc'], axis=0), coef, atol=0.05)
Esempio n. 3
0
def test_mcmc_kernels(kernel, kwargs):
    """Run a TFP kernel on a small NumPyro model and check the ``loc`` draws.

    The test is skipped for ``SliceSampler`` when the ``CI`` environment
    variable is set. Otherwise the kernel class named by ``kernel`` is looked
    up in ``numpyro.contrib.tfp.mcmc``, warmup draws and main-run draws are
    collected separately, their counts are checked, and the mean of the
    ``loc`` draws is compared with the coefficient used to generate the data.

    :param str kernel: name of the kernel class to test.
    :param dict kwargs: extra keyword arguments for the kernel constructor.
    """
    from numpyro.contrib.tfp import mcmc as tfp_mcmc

    running_on_ci = "CI" in os.environ
    if running_on_ci and kernel == "SliceSampler":
        # TODO: Look into this issue if some users are using SliceSampler
        # with NumPyro model.
        pytest.skip("SliceSampler freezes CI for unknown reason.")

    sampler_cls = getattr(tfp_mcmc, kernel)

    coef = 0.9
    n_warmup = 1000
    n_draws = 1000

    def model(data):
        alpha = numpyro.sample("alpha", dist.Uniform(0, 1))
        # "loc" is a Uniform(0, 1) variable scaled by alpha; the site is
        # sampled under the TransformReparam handler.
        with numpyro.handlers.reparam(config={"loc": TransformReparam()}):
            scaled_uniform = dist.TransformedDistribution(
                dist.Uniform(0, 1), AffineTransform(0, alpha))
            loc = numpyro.sample("loc", scaled_uniform)
        numpyro.sample("obs", dist.Normal(loc, 0.1), obs=data)

    noise = random.normal(random.PRNGKey(0), (1000, ))
    observations = coef + noise

    runner = MCMC(sampler_cls(model=model, **kwargs),
                  num_warmup=n_warmup,
                  num_samples=n_draws)

    # Warmup phase first; its draws are read back before the main run.
    runner.warmup(random.PRNGKey(2), observations, collect_warmup=True)
    warmup_draws = runner.get_samples()

    runner.run(random.PRNGKey(3), observations)
    draws = runner.get_samples()

    assert len(warmup_draws["loc"]) == n_warmup
    assert len(draws["loc"]) == n_draws
    assert_allclose(jnp.mean(draws["loc"], axis=0), coef, atol=0.05)
Esempio n. 4
0
def test_mcmc_unwrapped_tfp_distributions():
    """Check that NUTS works when TFP distributions are passed to ``sample``.

    The model hands ``tfd.Beta`` and ``tfd.Bernoulli`` objects straight to
    ``numpyro.sample`` and the posterior mean of ``p`` is compared with its
    closed-form value.
    """
    from tensorflow_probability.substrates.jax import distributions as tfd

    def model(y):
        success_prob = numpyro.sample("p", tfd.Beta(1, 1))

        with numpyro.plate("plate", y.size):
            numpyro.sample("y", tfd.Bernoulli(probs=success_prob), obs=y)

    observed = jnp.array([0, 0, 1, 1, 1])
    # Beta(1, 1) prior with 3 ones and 2 zeros gives a Beta(4, 3) posterior,
    # whose mean is 4 / 7.
    expected_mean = 4 / 7

    sampler = MCMC(NUTS(model), num_warmup=1000, num_samples=1000)
    sampler.run(random.PRNGKey(0), observed)
    posterior_p = sampler.get_samples()["p"]

    assert_allclose(jnp.mean(posterior_p), expected_mean, atol=0.05)