def test_predictive_with_guide():
    """Check predictive samples drawn through an SVI-fitted guide.

    A Beta guide is fitted with SVI (Adam, step size 0.1, 1000 updates) to
    ten binary observations, eight of which are ones.  ``Predictive`` is then
    built from the model, the guide and the optimized parameters, and the
    mean of 1000 sampled ``obs`` values must be within 0.05 of 0.8.
    """
    n_ones, n_zeros = 8, 2
    data = np.array([1] * n_ones + [0] * n_zeros)

    # NOTE: ``model`` and ``guide`` must keep ``data`` as their parameter
    # name because ``get_samples`` below passes it by keyword.
    def model(data):
        success_prob = numpyro.sample("beta", dist.Beta(1., 1.))
        with numpyro.plate("plate", 10):
            numpyro.sample("obs", dist.Bernoulli(success_prob), obs=data)

    def guide(data):
        q_alpha = numpyro.param("alpha_q", 1.0, constraint=constraints.positive)
        q_beta = numpyro.param("beta_q", 1.0, constraint=constraints.positive)
        numpyro.sample("beta", dist.Beta(q_alpha, q_beta))

    svi = SVI(model, guide, optim.Adam(0.1), ELBO())

    def step(_, state):
        # ``svi.update`` returns a pair; only the new state is carried on.
        next_state, _loss = svi.update(state, data)
        return next_state

    fitted_state = lax.fori_loop(
        0, 1000, step, svi.init(random.PRNGKey(1), data)
    )

    predictive = Predictive(
        model,
        guide=guide,
        params=svi.get_params(fitted_state),
        num_samples=1000,
    )
    # ``data=None`` means the "obs" site receives ``obs=None`` in the model.
    drawn = predictive.get_samples(random.PRNGKey(2), data=None)
    assert_allclose(np.mean(drawn["obs"]), 0.8, atol=0.05)
def test_predictive(parallel):
    """Check ``Predictive`` output sites, shapes and means from MCMC samples.

    NUTS is run for 100 warmup and 100 sampling steps on the model returned
    by ``beta_bernoulli()``.  The resulting samples feed ``Predictive``,
    which is queried twice with the same PRNG key: first with its default
    return sites (only ``"obs"`` is expected), then with ``return_sites`` set
    to ``["beta", "obs"]``.

    :param parallel: forwarded unchanged to ``Predictive(..., parallel=...)``.
        The values are supplied by the test harness (the parametrization or
        fixture is not visible in this block).
    """
    model, data, true_probs = beta_bernoulli()

    sampler = MCMC(NUTS(model), num_warmup=100, num_samples=100)
    sampler.run(random.PRNGKey(0), data)
    posterior = sampler.get_samples()

    predictive = Predictive(model, posterior, parallel=parallel)

    # With the default configuration only the "obs" site is returned.
    default_draws = predictive.get_samples(random.PRNGKey(1))
    assert default_draws.keys() == {"obs"}

    # Request both sites explicitly and draw again with the same key.
    predictive.return_sites = ["beta", "obs"]
    draws = predictive.get_samples(random.PRNGKey(1))

    # check shapes
    expected_beta_shape = (100,) + true_probs.shape
    expected_obs_shape = (100,) + data.shape
    assert draws["beta"].shape == expected_beta_shape
    assert draws["obs"].shape == expected_obs_shape

    # check sample mean
    flat_obs = draws["obs"].reshape((-1,) + true_probs.shape)
    assert_allclose(flat_obs.mean(0), true_probs, rtol=0.1)