def test_log_likelihood():
    model, data, _ = beta_bernoulli()
    samples = Predictive(model, return_sites=["beta"], num_samples=100)(random.PRNGKey(1))
    loglik = log_likelihood(model, samples, data)
    assert loglik.keys() == {"obs"}
    # check shapes: one log-likelihood value per posterior sample and per observation
    assert loglik["obs"].shape == (100,) + data.shape
    # reshape beta so it broadcasts against the observation plate of `data`
    assert_allclose(
        loglik["obs"],
        dist.Bernoulli(samples["beta"].reshape((100, 1, -1))).log_prob(data),
    )
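
Both tests in this section rely on a `beta_bernoulli()` helper that is defined elsewhere in the test module and not shown here. A minimal sketch that is consistent with the assertions in these tests (two-column Bernoulli data, a deterministic "beta_sq" site, and an observation plate on dim -2) could look like the following; the names and sizes are assumptions, not the original fixture:

import jax.numpy as jnp
from jax import random

import numpyro
import numpyro.distributions as dist


def beta_bernoulli():
    # Hypothetical stand-in for the fixture: shapes are chosen so that the
    # assertions in the surrounding tests hold; they are not copied from the
    # original test file.
    N = 800
    true_probs = jnp.array([0.2, 0.7])
    data = dist.Bernoulli(true_probs).sample(random.PRNGKey(0), (N,))  # shape (N, 2)

    def model(data=None):
        with numpyro.plate("dim", 2):
            beta = numpyro.sample("beta", dist.Beta(1.0, 1.0))
        numpyro.deterministic("beta_sq", beta ** 2)
        # plate on dim -2 so beta (shape (2,)) broadcasts across the N rows
        with numpyro.plate("obs_plate", N, dim=-2):
            numpyro.sample("obs", dist.Bernoulli(beta), obs=data)

    return model, data, true_probs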
Example #2
    def gibbs_fn(rng_key, gibbs_sites, hmc_sites):
        # Block-wise Gibbs/Metropolis update of the subsample index sites: one
        # randomly chosen block of indices is resampled per call, then accepted
        # or rejected based on the subsampled log likelihood at `hmc_sites`.
        assert set(gibbs_sites) == set(plate_sizes)
        u_new = {}
        for name in gibbs_sites:
            size, subsample_size = plate_sizes[name]
            rng_key, subkey, block_key = random.split(rng_key, 3)
            block_size = subsample_size // num_blocks

            # Pick one block at random and propose fresh indices for it; indices
            # in the other blocks are kept unchanged.
            chosen_block = random.randint(block_key, shape=(), minval=0, maxval=num_blocks)
            new_idx = random.randint(subkey, shape=(subsample_size,), minval=0, maxval=size)
            block_mask = jnp.arange(subsample_size) // block_size == chosen_block

            u_new[name] = jnp.where(block_mask, new_idx, gibbs_sites[name])

        # Metropolis-Hastings correction: compare the subsampled log likelihood
        # under the current indices with the log likelihood under the proposal.
        u_loglik = log_likelihood(_wrap_model(model),
                                  hmc_sites,
                                  *model_args,
                                  batch_ndims=0,
                                  **model_kwargs,
                                  _gibbs_sites=gibbs_sites)
        u_loglik = sum(v.sum() for v in u_loglik.values())
        u_new_loglik = log_likelihood(_wrap_model(model),
                                      hmc_sites,
                                      *model_args,
                                      batch_ndims=0,
                                      **model_kwargs,
                                      _gibbs_sites=u_new)
        u_new_loglik = sum(v.sum() for v in u_new_loglik.values())
        # Accept the proposed indices with probability min(1, likelihood ratio).
        accept_prob = jnp.clip(jnp.exp(u_new_loglik - u_loglik), a_max=1.0)
        # old-style cond(pred, true_operand, true_fun, false_operand, false_fun):
        # return u_new if the proposal is accepted, else keep gibbs_sites
        return cond(random.bernoulli(rng_key, accept_prob), u_new, identity,
                    gibbs_sites, identity)
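
The gibbs_fn above is an inner helper: it closes over model, plate_sizes, num_blocks, _wrap_model, model_args and model_kwargs from the enclosing scope and implements the block update of subsample indices used by NumPyro's HMC-within-Gibbs machinery. As a rough usage sketch, the same block-update behaviour is reached through the public HMCECS kernel when a model declares a subsample plate; the model, data and sizes below are illustrative only:

from jax import random

import numpyro
import numpyro.distributions as dist
from numpyro.infer import HMCECS, MCMC, NUTS


def model(data):
    mean = numpyro.sample("mean", dist.Normal(0.0, 1.0))
    # Only `subsample_size` points enter each likelihood evaluation; a block-wise
    # Gibbs move like the one above refreshes part of the indices at each step.
    with numpyro.plate("N", data.shape[0], subsample_size=100):
        batch = numpyro.subsample(data, event_dim=0)
        numpyro.sample("obs", dist.Normal(mean, 1.0), obs=batch)


data = random.normal(random.PRNGKey(0), (10_000,)) + 1.0
kernel = HMCECS(NUTS(model), num_blocks=10)  # 100 subsample indices split into 10 blocks
mcmc = MCMC(kernel, num_warmup=500, num_samples=500)
mcmc.run(random.PRNGKey(1), data)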
Example #3
def test_log_likelihood(batch_shape):
    # `batch_shape` is supplied via pytest parametrization (e.g. (), (100,), (2, 50)).
    model, data, _ = beta_bernoulli()
    samples = Predictive(model, return_sites=["beta"], num_samples=200)(random.PRNGKey(1))
    batch_size = int(np.prod(batch_shape))
    samples = {"beta": samples["beta"][:batch_size].reshape(batch_shape + (1, -1))}

    preds = Predictive(model, samples, batch_ndims=len(batch_shape))(random.PRNGKey(2))
    loglik = log_likelihood(model, samples, data, batch_ndims=len(batch_shape))
    assert preds.keys() == {"beta_sq", "obs"}
    assert loglik.keys() == {"obs"}
    # check shapes
    assert preds["obs"].shape == batch_shape + data.shape
    assert loglik["obs"].shape == batch_shape + data.shape
    assert_allclose(loglik["obs"], dist.Bernoulli(samples["beta"]).log_prob(data))
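
The batch_ndims argument tells log_likelihood (and Predictive) how many leading dimensions of each sample array are batch dimensions. A short sketch of the common multi-chain case, reusing the beta_bernoulli fixture sketched earlier (so the shapes rest on that assumption), could be:

from jax import random

from numpyro.infer import MCMC, NUTS, log_likelihood

model, data, _ = beta_bernoulli()  # assumed helper, sketched above
mcmc = MCMC(
    NUTS(model),
    num_warmup=100,
    num_samples=100,
    num_chains=2,
    chain_method="sequential",
)
mcmc.run(random.PRNGKey(0), data)
# group_by_chain=True keeps the leading (num_chains, num_samples) dimensions,
# so batch_ndims=2 is needed for the shapes to line up.
chain_samples = mcmc.get_samples(group_by_chain=True)
loglik = log_likelihood(model, chain_samples, data, batch_ndims=2)
assert loglik["obs"].shape == (2, 100) + data.shape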