Exemplo n.º 1
0
def test_prior_predictive():
    """Sample the beta-Bernoulli prior predictive and check returned sites and shapes.

    ``Predictive`` is built without posterior samples and asked for 100 draws.
    The result must contain exactly the ``"beta"`` and ``"obs"`` sites. Each must
    carry a leading sample axis of length 100 in front of the shape of
    ``true_probs`` or ``data`` respectively.
    """
    model, data, true_probs = beta_bernoulli()
    # Call the Predictive object directly instead of via ``get_samples``, which
    # numpyro keeps only as a deprecated alias that forwards to ``__call__``.
    predictive_samples = Predictive(model, num_samples=100)(random.PRNGKey(1))
    assert predictive_samples.keys() == {"beta", "obs"}

    # check shapes
    assert predictive_samples["beta"].shape == (100,) + true_probs.shape
    assert predictive_samples["obs"].shape == (100,) + data.shape
Exemplo n.º 2
0
def test_log_likelihood(batch_shape):
    """Check ``Predictive`` and ``log_likelihood`` when samples carry extra batch axes.

    Steps:
    - Draw 200 ``"beta"`` values from the model.
    - Keep the first ``prod(batch_shape)`` of them and reshape to
      ``batch_shape + (1, -1)``.
    - Pass the result to both ``Predictive`` and ``log_likelihood``, with
      ``batch_ndims`` equal to ``len(batch_shape)``.

    Both outputs must have shape ``batch_shape + data.shape`` for ``"obs"``. The
    log-likelihood must also match a direct Bernoulli ``log_prob`` computation.
    """
    n_batch_dims = len(batch_shape)
    model, data, _ = beta_bernoulli()

    beta_predictive = Predictive(model, return_sites=["beta"], num_samples=200)
    beta_draws = beta_predictive(random.PRNGKey(1))['beta']
    # np.prod(()) is 1.0, so an empty batch_shape keeps a single draw.
    n_keep = int(np.prod(batch_shape))
    posterior = {'beta': beta_draws[:n_keep].reshape(batch_shape + (1, -1))}

    preds = Predictive(model, posterior, batch_ndims=n_batch_dims)(random.PRNGKey(2))
    loglik = log_likelihood(model, posterior, data, batch_ndims=n_batch_dims)
    assert preds.keys() == {"beta_sq", "obs"}
    assert loglik.keys() == {"obs"}

    # Every "obs" output should be the batch axes followed by the data axes.
    expected_shape = batch_shape + data.shape
    assert preds["obs"].shape == expected_shape
    assert loglik["obs"].shape == expected_shape

    expected_loglik = dist.Bernoulli(posterior["beta"]).log_prob(data)
    assert_allclose(loglik["obs"], expected_loglik)