Esempio n. 1
0
def test_linear_model_sigma(kernel_cls,
                            N=90,
                            P=40,
                            sigma=0.07,
                            warmup_steps=500,
                            num_samples=500):
    """Check that HMCGibbs recovers ``beta`` and ``sigma`` of a linear model.

    ``beta`` is updated by the Gibbs step ``_linear_regression_gibbs_fn``.
    ``sigma`` (HalfCauchy(1) prior) is left to the ``kernel_cls`` kernel.
    The synthetic data use only the first feature (true coefficient 1, all
    others 0) plus Gaussian noise of scale ``sigma``.

    Args:
        kernel_cls: HMC-style kernel class used for the non-Gibbs sites.
        N: number of observations.
        P: number of features.
        sigma: true noise scale used to generate ``Y``.
        warmup_steps: number of warmup iterations passed to ``MCMC``.
        num_samples: number of post-warmup samples passed to ``MCMC``.
    """
    np.random.seed(1)
    X = np.random.randn(N * P).reshape((N, P))
    # Sufficient statistics X^T X and X^T Y consumed by the Gibbs update.
    XX = np.matmul(np.transpose(X), X)
    Y = X[:, 0] + sigma * np.random.randn(N)
    XY = np.sum(X * Y[:, None], axis=0)

    def model(X, Y):
        _, P = X.shape

        sigma = numpyro.sample("sigma", dist.HalfCauchy(1.0))
        beta = numpyro.sample("beta", dist.Normal(jnp.zeros(P), jnp.ones(P)))
        mean = jnp.sum(beta * X, axis=-1)

        numpyro.sample("obs", dist.Normal(mean, sigma), obs=Y)

    gibbs_fn = partial(_linear_regression_gibbs_fn, X, XX, XY, Y)

    hmc_kernel = kernel_cls(model)
    kernel = HMCGibbs(hmc_kernel, gibbs_fn=gibbs_fn, gibbs_sites=['beta'])
    # Pass num_warmup / num_samples by keyword. Newer NumPyro releases make
    # them keyword-only, and the other tests in this file already do this.
    mcmc = MCMC(kernel,
                num_warmup=warmup_steps,
                num_samples=num_samples,
                progress_bar=False)

    mcmc.run(random.PRNGKey(0), X, Y)
    samples = mcmc.get_samples()

    beta_mean = np.mean(samples['beta'], axis=0)
    assert_allclose(beta_mean, np.array([1.0] + [0.0] * (P - 1)), atol=0.05)

    sigma_mean = np.mean(samples['sigma'], axis=0)
    assert_allclose(sigma_mean, sigma, atol=0.25)
Esempio n. 2
0
def test_linear_model_log_sigma(
    kernel_cls, N=100, P=50, sigma=0.11, num_warmup=500, num_samples=500
):
    """HMCGibbs on a linear model whose noise scale lives in log space.

    ``beta`` is refreshed by ``_linear_regression_gibbs_fn``. ``log_sigma``
    (Normal(1) prior) is sampled by the ``kernel_cls`` kernel. The
    regression mean is also recorded as the deterministic site ``"mean"``.

    Args:
        kernel_cls: HMC-style kernel class for the non-Gibbs sites.
        N: number of observations.
        P: number of features.
        sigma: true noise scale used to simulate the response.
        num_warmup: warmup iterations for ``MCMC``.
        num_samples: post-warmup samples for ``MCMC``.
    """
    np.random.seed(0)
    design = np.random.randn(N * P).reshape((N, P))
    response = design[:, 0] + sigma * np.random.randn(N)
    # Sufficient statistics handed to the Gibbs update.
    gram = np.matmul(np.transpose(design), design)
    cross = np.sum(design * response[:, None], axis=0)

    def model(X, Y):
        _, num_features = X.shape

        log_sigma = numpyro.sample("log_sigma", dist.Normal(1.0))
        beta = numpyro.sample(
            "beta",
            dist.Normal(jnp.zeros(num_features), jnp.ones(num_features)),
        )
        mean = jnp.sum(beta * X, axis=-1)
        numpyro.deterministic("mean", mean)

        numpyro.sample("obs", dist.Normal(mean, jnp.exp(log_sigma)), obs=Y)

    mcmc = MCMC(
        HMCGibbs(
            kernel_cls(model),
            gibbs_fn=partial(
                _linear_regression_gibbs_fn, design, gram, cross, response
            ),
            gibbs_sites=["beta"],
        ),
        num_warmup=num_warmup,
        num_samples=num_samples,
        progress_bar=False,
    )
    mcmc.run(random.PRNGKey(0), design, response)

    draws = mcmc.get_samples()
    expected_beta = np.array([1.0] + [0.0] * (P - 1))
    assert_allclose(np.mean(draws["beta"], axis=0), expected_beta, atol=0.05)

    # exp of the posterior mean of log_sigma, compared with the true scale.
    sigma_estimate = np.exp(np.mean(draws["log_sigma"], axis=0))
    assert_allclose(sigma_estimate, sigma, atol=0.25)
Esempio n. 3
0
def test_gaussian_model(kernel_cls, D=2, num_warmup=5000, num_samples=5000):
    """HMCGibbs on a joint Gaussian split into two D-dimensional blocks.

    The model draws ``(x0, x1) ~ MVN(0, cov)`` by factorising it as
    ``x0 ~ MVN(0, cov00)`` and ``x1 | x0 ~ MVN(cov10 cov00^-1 x0, S1)``.
    ``x0`` is then refreshed exactly from ``x0 | x1`` by a Gibbs step, and
    ``x1`` is handled by ``kernel_cls``. The marginal means and standard
    deviations of both blocks are checked against ``cov``.
    """
    np.random.seed(0)
    root = np.random.randn(4 * D * D).reshape((2 * D, 2 * D))
    # Positive-definite covariance: root^T root plus a diagonal jitter.
    cov = jnp.matmul(jnp.transpose(root), root) + 0.25 * jnp.eye(2 * D)

    head, tail = slice(None, D), slice(D, None)
    cov00, cov01 = cov[head, head], cov[head, tail]
    cov10, cov11 = cov[tail, head], cov[tail, tail]

    # Regression matrices of each block on the other.
    gain0 = jnp.matmul(cov01, inv(cov11))
    gain1 = jnp.matmul(cov10, inv(cov00))

    # Schur complements: the conditional covariances of each block.
    cond_cov0 = cov00 - jnp.matmul(gain0, cov10)
    cond_cov1 = cov11 - jnp.matmul(gain1, cov01)

    def gaussian_gibbs_fn(rng_key, hmc_sites, gibbs_sites):
        cond_loc0 = jnp.matmul(gain0, hmc_sites["x1"])
        conditional = dist.MultivariateNormal(loc=cond_loc0,
                                              covariance_matrix=cond_cov0)
        return {"x0": conditional.sample(rng_key)}

    def model():
        prior0 = dist.MultivariateNormal(loc=jnp.zeros(D),
                                         covariance_matrix=cov00)
        x0 = numpyro.sample("x0", prior0)
        numpyro.sample(
            "x1",
            dist.MultivariateNormal(loc=jnp.matmul(gain1, x0),
                                    covariance_matrix=cond_cov1),
        )

    mcmc = MCMC(
        HMCGibbs(kernel_cls(model, dense_mass=True),
                 gibbs_fn=gaussian_gibbs_fn,
                 gibbs_sites=["x0"]),
        num_warmup=num_warmup,
        num_samples=num_samples,
        progress_bar=False,
    )
    mcmc.run(random.PRNGKey(0))

    draws = mcmc.get_samples()
    x0_draws, x1_draws = draws["x0"], draws["x1"]

    assert_allclose(np.mean(x0_draws, axis=0), np.zeros(D), atol=0.2)
    assert_allclose(np.mean(x1_draws, axis=0), np.zeros(D), atol=0.2)

    assert_allclose(np.std(x0_draws, axis=0),
                    np.sqrt(np.diagonal(cov00)), rtol=0.05)
    assert_allclose(np.std(x1_draws, axis=0),
                    np.sqrt(np.diagonal(cov11)), rtol=0.1)
Esempio n. 4
0
def test_discrete_gibbs_enum():
    """Gibbs-update a Binomial site with ``discrete_gibbs_fn`` next to NUTS.

    ``y ~ Binomial(10, 0.3)`` is the Gibbs site, so its sample mean should
    approach 10 * 0.3 = 3. The Bernoulli site ``x`` is left to the inner
    NUTS kernel, and ``y2`` records ``y**2`` as a deterministic site.
    """
    def model():
        numpyro.sample("x", dist.Bernoulli(0.7))
        y = numpyro.sample("y", dist.Binomial(10, 0.3))
        numpyro.deterministic("y2", y**2)

    kernel = HMCGibbs(NUTS(model), discrete_gibbs_fn(model), gibbs_sites=["y"])
    # Keyword arguments: newer NumPyro makes num_warmup/num_samples
    # keyword-only, matching the other tests in this file.
    mcmc = MCMC(kernel, num_warmup=1000, num_samples=10000, progress_bar=False)
    mcmc.run(random.PRNGKey(0))
    samples = mcmc.get_samples()
    assert_allclose(jnp.mean(samples["y"], 0), 0.3 * 10, atol=0.1)
Esempio n. 5
0
def test_discrete_gibbs_bernoulli(random_walk, modified):
    """Gibbs-sample a single Bernoulli(0.8) site and check its marginal mean.

    Args:
        random_walk: forwarded to ``discrete_gibbs_fn``.
        modified: forwarded to ``discrete_gibbs_fn``.
    """
    def model():
        numpyro.sample("c", dist.Bernoulli(0.8))

    gibbs_fn = discrete_gibbs_fn(model,
                                 random_walk=random_walk,
                                 modified=modified)
    kernel = HMCGibbs(NUTS(model), gibbs_fn, gibbs_sites=["c"])
    # Keyword arguments: newer NumPyro makes num_warmup/num_samples
    # keyword-only, matching the other tests in this file.
    mcmc = MCMC(kernel,
                num_warmup=1000,
                num_samples=200000,
                progress_bar=False)
    mcmc.run(random.PRNGKey(0))
    samples = mcmc.get_samples()["c"]
    assert_allclose(jnp.mean(samples), 0.8, atol=0.05)
Esempio n. 6
0
def test_discrete_gibbs_multiple_sites():
    """Gibbs-update two discrete sites (one batched) in a single step.

    ``x`` is three independent Bernoulli(0.7) draws and ``y`` is
    Binomial(10, 0.3). Their sample means should approach 0.7 (per element)
    and 3 respectively.
    """
    def model():
        numpyro.sample("x", dist.Bernoulli(0.7).expand([3]))
        numpyro.sample("y", dist.Binomial(10, 0.3))

    kernel = HMCGibbs(NUTS(model),
                      discrete_gibbs_fn(model),
                      gibbs_sites=["x", "y"])
    # Keyword arguments: newer NumPyro makes num_warmup/num_samples
    # keyword-only, matching the other tests in this file.
    mcmc = MCMC(kernel,
                num_warmup=1000,
                num_samples=10000,
                progress_bar=False)
    mcmc.run(random.PRNGKey(0))
    samples = mcmc.get_samples()
    assert_allclose(jnp.mean(samples["x"], 0), 0.7 * jnp.ones(3), atol=0.01)
    assert_allclose(jnp.mean(samples["y"], 0), 0.3 * 10, atol=0.1)
Esempio n. 7
0
def test_discrete_gibbs_gmm_1d(modified):
    """Sample a 1-D Gaussian mixture: Gibbs on the component, NUTS on ``x``.

    Expected moments, with ``p = probs`` and ``m = locs``:

    * ``E[c] = sum_k k p_k = 1.65``
    * ``Var[c] = sum_k k^2 p_k - 1.65^2 = 1.0275``
    * ``E[x] = sum_k p_k m_k = 1.3``
    * ``Var[x] = sum_k p_k m_k^2 - 1.3^2 + 0.5^2 = 4.36``

    Args:
        modified: forwarded to ``discrete_gibbs_fn``.
    """
    def model(probs, locs):
        c = numpyro.sample("c", dist.Categorical(probs))
        numpyro.sample("x", dist.Normal(locs[c], 0.5))

    probs = jnp.array([0.15, 0.3, 0.3, 0.25])
    locs = jnp.array([-2, 0, 2, 4])
    gibbs_fn = discrete_gibbs_fn(model, (probs, locs), modified=modified)
    kernel = HMCGibbs(NUTS(model), gibbs_fn, gibbs_sites=["c"])
    # Keyword arguments: newer NumPyro makes num_warmup/num_samples
    # keyword-only, matching the other tests in this file.
    mcmc = MCMC(kernel,
                num_warmup=1000,
                num_samples=200000,
                progress_bar=False)
    mcmc.run(random.PRNGKey(0), probs, locs)
    samples = mcmc.get_samples()
    assert_allclose(jnp.mean(samples["x"]), 1.3, atol=0.1)
    assert_allclose(jnp.var(samples["x"]), 4.36, atol=0.1)
    assert_allclose(jnp.mean(samples["c"]), 1.65, atol=0.1)
    assert_allclose(jnp.var(samples["c"]), 1.03, atol=0.1)
Esempio n. 8
0
def sample_posterior_gibbs(rng_key: random.PRNGKey,
                           model,
                           data: np.ndarray,
                           Nsamples: int = 1000,
                           alpha: float = 1,
                           sigma: float = 0,
                           T: int = 10,
                           gibbs_fn=None,
                           gibbs_sites=None):
    """Run NUTS-within-Gibbs on ``model`` and return the posterior ``z`` draws.

    Args:
        rng_key: JAX PRNG key passed to ``MCMC.run``.
        model: NumPyro model. It is run with keyword arguments ``data``,
            ``alpha``, ``sigma`` and ``T`` and must contain a site ``"z"``.
        data: observations. ``len(data)`` is taken as the number of points.
        Nsamples: number of post-warmup samples to draw.
        alpha: forwarded unchanged to ``model``.
        sigma: forwarded unchanged to ``model``.
        T: forwarded unchanged to ``model``.
        gibbs_fn: Gibbs update function handed to ``HMCGibbs``. Required.
        gibbs_sites: names of the sites updated by ``gibbs_fn``. Required.

    Returns:
        The ``"z"`` samples, checked to have shape ``(Nsamples, len(data))``.

    Raises:
        ValueError: if ``gibbs_fn`` or ``gibbs_sites`` is not provided.
    """
    # Explicit checks rather than ``assert``, so they survive ``python -O``.
    if gibbs_fn is None:
        raise ValueError("gibbs_fn must be provided")
    if gibbs_sites is None:
        raise ValueError("gibbs_sites must be provided")

    num_points = len(data)

    inner_kernel = NUTS(model)
    kernel = HMCGibbs(inner_kernel, gibbs_fn=gibbs_fn, gibbs_sites=gibbs_sites)
    # NUM_WARMUP is presumably a module-level constant defined elsewhere.
    mcmc = MCMC(kernel, num_samples=Nsamples, num_warmup=NUM_WARMUP)
    mcmc.run(rng_key, data=data, alpha=alpha, sigma=sigma, T=T)

    z = mcmc.get_samples()['z']
    # Sanity check on the sampler output: one z per data point per sample.
    assert z.shape == (Nsamples, num_points)

    return z