def test_prior_predictive():
    """Prior predictive sampling returns every model site with a leading sample axis.

    ``Predictive`` is given no posterior samples, so it runs the model forward
    ``num_samples`` times and every returned site is stacked along axis 0.
    """
    model, data, true_probs = beta_bernoulli()
    num_samples = 100
    # Call the Predictive object directly (as test_log_likelihood does) rather
    # than going through the older ``get_samples`` method.
    predictive_samples = Predictive(model, num_samples=num_samples)(random.PRNGKey(1))
    # test_log_likelihood shows Predictive also emits "beta_sq" for this model;
    # in prior mode no site is supplied up front, so none is filtered out and
    # "beta_sq" must be present alongside "beta" and "obs".
    assert predictive_samples.keys() == {"beta", "beta_sq", "obs"}

    # check shapes: one leading axis of size num_samples on top of each site's shape
    assert predictive_samples["beta"].shape == (num_samples,) + true_probs.shape
    assert predictive_samples["obs"].shape == (num_samples,) + data.shape
def test_log_likelihood(batch_shape):
    """Posterior predictive and log-likelihood both honour ``batch_ndims``.

    A slice of prior draws of ``beta`` is laid out as ``batch_shape + (1, -1)``.
    ``Predictive`` and ``log_likelihood`` are then told that the leading
    ``len(batch_shape)`` axes are batch axes. The returned "obs" arrays must
    carry exactly those leading axes, and the log-likelihood must match
    evaluating ``dist.Bernoulli`` on the same ``beta`` directly.
    """
    model, data, _ = beta_bernoulli()
    n_batch_dims = len(batch_shape)
    n_draws = int(np.prod(batch_shape))

    # Only the "beta" site is requested from the prior draws; keep as many
    # draws as the batch shape needs and spread them over the batch axes.
    prior_draws = Predictive(model, return_sites=["beta"], num_samples=200)(random.PRNGKey(1))
    posterior = {'beta': prior_draws['beta'][:n_draws].reshape(batch_shape + (1, -1))}

    preds = Predictive(model, posterior, batch_ndims=n_batch_dims)(random.PRNGKey(2))
    loglik = log_likelihood(model, posterior, data, batch_ndims=n_batch_dims)

    # "beta" was supplied in ``posterior``, so it is expected to be absent here.
    assert preds.keys() == {"beta_sq", "obs"}
    assert loglik.keys() == {"obs"}

    # check shapes: batch axes followed by the observed data's shape
    expected_shape = batch_shape + data.shape
    for site_value in (preds["obs"], loglik["obs"]):
        assert site_value.shape == expected_shape

    reference = dist.Bernoulli(posterior["beta"]).log_prob(data)
    assert_allclose(loglik["obs"], reference)