Example #1
0
def test_dirichlet_categorical_x64(kernel_cls, dense_mass):
    """Check that an MCMC kernel recovers Dirichlet-Categorical class probabilities.

    A flat ``Dirichlet([1, 1, 1])`` prior is placed on the probabilities of a
    3-class Categorical likelihood. The model is conditioned on 2000 draws from
    known probabilities, and the posterior mean is compared with them.

    Args:
        kernel_cls: MCMC kernel class under test. ``BarkerMH`` is built without
            a ``trajectory_length``; every other class gets ``trajectory_length=1.0``.
        dense_mass: forwarded unchanged to the kernel constructor.

    When ``JAX_ENABLE_X64`` is set in the environment, the samples must also be
    float64.
    """
    warmup_iters = 100
    sample_iters = 20000
    expected = jnp.array([0.1, 0.6, 0.3])

    def model(data):
        p_latent = numpyro.sample(
            "p_latent", dist.Dirichlet(jnp.array([1.0, 1.0, 1.0]))
        )
        numpyro.sample("obs", dist.Categorical(p_latent), obs=data)
        return p_latent

    observations = dist.Categorical(expected).sample(random.PRNGKey(1), (2000,))

    # Only the non-BarkerMH kernels receive a trajectory_length here.
    kernel = (
        BarkerMH(model=model, dense_mass=dense_mass)
        if kernel_cls is BarkerMH
        else kernel_cls(model, trajectory_length=1.0, dense_mass=dense_mass)
    )

    mcmc = MCMC(
        kernel,
        num_warmup=warmup_iters,
        num_samples=sample_iters,
        progress_bar=False,
    )
    mcmc.run(random.PRNGKey(2), observations)
    p_draws = mcmc.get_samples()["p_latent"]
    assert_allclose(jnp.mean(p_draws, 0), expected, atol=0.02)

    if "JAX_ENABLE_X64" in os.environ:
        assert p_draws.dtype == jnp.float64
Example #2
0
def test_dense_mass(kernel_cls, rho):
    """Check that dense mass-matrix adaptation learns a correlated 2-D Gaussian.

    The target is a zero-mean ``MultivariateNormal`` with covariance
    ``[[10, rho], [rho, 0.1]]``. After warmup, the inverse of the adapted mass
    matrix is compared with that covariance, and the sample moments are
    checked against the true ones.

    Args:
        kernel_cls: one of ``HMC``, ``NUTS`` or ``BarkerMH``.
        rho: off-diagonal entry of the target covariance.

    Raises:
        ValueError: if ``kernel_cls`` is not one of the supported kernels.
            Without this check, ``kernel`` would be unbound and the test would
            fail with an opaque ``UnboundLocalError``.
    """
    num_warmup, num_samples = 20000, 10000

    true_cov = jnp.array([[10.0, rho], [rho, 0.1]])

    def model():
        numpyro.sample(
            "x", dist.MultivariateNormal(jnp.zeros(2), covariance_matrix=true_cov)
        )

    is_hamiltonian = kernel_cls in (HMC, NUTS)
    if is_hamiltonian:
        kernel = kernel_cls(model, trajectory_length=2.0, dense_mass=True)
    elif kernel_cls is BarkerMH:
        kernel = BarkerMH(model, dense_mass=True)
    else:
        raise ValueError(f"unsupported kernel class: {kernel_cls!r}")

    # Pass chain lengths by keyword, matching the other MCMC tests in this file.
    mcmc = MCMC(
        kernel,
        num_warmup=num_warmup,
        num_samples=num_samples,
        progress_bar=False,
    )
    mcmc.run(random.PRNGKey(0))

    mass_matrix_sqrt = mcmc.last_state.adapt_state.mass_matrix_sqrt
    if is_hamiltonian:
        # HMC/NUTS keep the square root keyed by the tuple of site names;
        # for BarkerMH the test uses the stored value directly.
        mass_matrix_sqrt = mass_matrix_sqrt[("x",)]
    mass_matrix = jnp.matmul(mass_matrix_sqrt, jnp.transpose(mass_matrix_sqrt))
    # The inverse of the adapted mass matrix should approximate the target covariance.
    estimated_cov = jnp.linalg.inv(mass_matrix)
    assert_allclose(estimated_cov, true_cov, rtol=0.10)

    samples = mcmc.get_samples()["x"]
    assert_allclose(jnp.mean(samples[:, 0]), jnp.array(0.0), atol=0.50)
    assert_allclose(jnp.mean(samples[:, 1]), jnp.array(0.0), atol=0.05)
    assert_allclose(jnp.mean(samples[:, 0] * samples[:, 1]), jnp.array(rho), atol=0.20)
    assert_allclose(jnp.var(samples, axis=0), jnp.array([10.0, 0.1]), rtol=0.20)
Example #3
0
def test_beta_bernoulli_x64(kernel_cls):
    """Check that an MCMC kernel recovers two independent Beta-Bernoulli rates.

    Each of the two Bernoulli success probabilities gets a ``Beta(1.1, 1.1)``
    prior. The model is conditioned on 1000 joint draws with true rates
    ``[0.9, 0.1]``, and the posterior mean is compared with them.

    Args:
        kernel_cls: MCMC kernel class under test. ``SA`` runs 100000 warmup
            and 100000 sampling steps; all other kernels run 500 and 20000.
            Only non-SA, non-BarkerMH kernels receive ``trajectory_length=0.1``.

    When ``JAX_ENABLE_X64`` is set in the environment, the samples must also be
    float64.
    """

    def model(data):
        shape = jnp.array([1.1, 1.1])
        p_latent = numpyro.sample("p_latent", dist.Beta(shape, shape))
        numpyro.sample("obs", dist.Bernoulli(p_latent), obs=data)
        return p_latent

    true_probs = jnp.array([0.9, 0.1])
    data = dist.Bernoulli(true_probs).sample(random.PRNGKey(1), (1000,))

    n_warmup, n_draws = 500, 20000
    if kernel_cls is SA:
        n_warmup = n_draws = 100000
        kernel = SA(model=model)
    elif kernel_cls is BarkerMH:
        kernel = BarkerMH(model=model)
    else:
        kernel = kernel_cls(model=model, trajectory_length=0.1)

    mcmc = MCMC(kernel, num_warmup=n_warmup, num_samples=n_draws, progress_bar=False)
    mcmc.run(random.PRNGKey(2), data)
    mcmc.print_summary()
    p_draws = mcmc.get_samples()["p_latent"]
    assert_allclose(jnp.mean(p_draws, 0), true_probs, atol=0.05)

    if "JAX_ENABLE_X64" in os.environ:
        assert p_draws.dtype == jnp.float64
Example #4
0
def test_logistic_regression_x64(kernel_cls):
    """Check that an MCMC kernel recovers logistic-regression coefficients.

    Labels for 3000 standard-normal feature rows are drawn from a logistic
    model with coefficients ``[1, 2, 3]``. A ``Normal(0, 1)`` prior is placed
    on the coefficients, and the posterior mean is compared with reference
    values obtained by MAP inference.

    Args:
        kernel_cls: MCMC kernel class under test. ``SA`` and ``BarkerMH`` get
            their own constructor arguments and chain lengths; any other class
            is built with ``trajectory_length=8`` and a heuristic step size.

    The deterministic ``logits`` site must be recorded once per sample. When
    ``JAX_ENABLE_X64`` is set in the environment, the coefficient samples must
    also be float64.
    """
    N, dim = 3000, 3
    if kernel_cls is SA:
        num_warmup, num_samples = (100000, 100000)
    elif kernel_cls is BarkerMH:
        num_warmup, num_samples = (2000, 12000)
    else:
        num_warmup, num_samples = (1000, 8000)
    data = random.normal(random.PRNGKey(0), (N, dim))
    true_coefs = jnp.arange(1., dim + 1.)
    logits = jnp.sum(true_coefs * data, axis=-1)
    labels = dist.Bernoulli(logits=logits).sample(random.PRNGKey(1))

    def model(labels):
        # `data` is captured from the enclosing scope; only labels are passed in.
        coefs = numpyro.sample('coefs', dist.Normal(jnp.zeros(dim), jnp.ones(dim)))
        logits = numpyro.deterministic('logits', jnp.sum(coefs * data, axis=-1))
        return numpyro.sample('obs', dist.Bernoulli(logits=logits), obs=labels)

    if kernel_cls is SA:
        kernel = SA(model=model, adapt_state_size=9)
    elif kernel_cls is BarkerMH:
        kernel = BarkerMH(model=model)
    else:
        kernel = kernel_cls(model=model, trajectory_length=8, find_heuristic_step_size=True)
    # Pass chain lengths by keyword, matching the other MCMC tests in this file.
    mcmc = MCMC(
        kernel,
        num_warmup=num_warmup,
        num_samples=num_samples,
        progress_bar=False,
    )
    mcmc.run(random.PRNGKey(2), labels)
    mcmc.print_summary()
    samples = mcmc.get_samples()
    assert samples['logits'].shape == (num_samples, N)
    # those coefficients are found by doing MAP inference using AutoDelta
    expected_coefs = jnp.array([0.97, 2.05, 3.18])
    assert_allclose(jnp.mean(samples['coefs'], 0), expected_coefs, atol=0.1)

    if 'JAX_ENABLE_X64' in os.environ:
        assert samples['coefs'].dtype == jnp.float64