def test_log_likelihood():
    """Check that ``log_likelihood`` returns per-sample, per-datum log densities for ``obs``."""
    model, data, _ = beta_bernoulli()
    num_samples = 100
    # Call the Predictive instance directly instead of the deprecated
    # ``Predictive.get_samples`` alias.
    samples = Predictive(model, return_sites=["beta"], num_samples=num_samples)(random.PRNGKey(1))
    loglik = log_likelihood(model, samples, data)
    assert loglik.keys() == {"obs"}
    # check shapes
    assert loglik["obs"].shape == (num_samples,) + data.shape
    # Insert a singleton axis after the sample dim so each beta draw broadcasts
    # against the whole data array.
    expected = dist.Bernoulli(samples["beta"].reshape((num_samples, 1, -1))).log_prob(data)
    assert_allclose(loglik["obs"], expected)
def gibbs_fn(rng_key, gibbs_sites, hmc_sites):
    """Metropolis-within-Gibbs proposal for the subsample indices.

    For each subsampled plate in ``plate_sizes``, one of ``num_blocks`` blocks of
    the current subsample indices is replaced by fresh indices drawn uniformly
    from ``[0, size)``. The joint proposal over all plates is accepted with
    probability ``min(1, exp(loglik(proposed) - loglik(current)))``. Both log
    likelihoods are evaluated at ``hmc_sites`` with ``log_likelihood``.

    :param rng_key: JAX PRNG key used for the block choice, the new indices and
        the accept/reject draw.
    :param dict gibbs_sites: current subsample indices keyed by plate name; the
        key set must equal that of ``plate_sizes``.
    :param dict hmc_sites: current values of the HMC-updated latent sites.
    :return: the proposed indices if accepted, otherwise ``gibbs_sites``.
    """
    assert set(gibbs_sites) == set(plate_sizes)

    def _total_loglik(subsample_sites):
        # Sum of every returned log-likelihood term, using `subsample_sites` as
        # the subsample indices.
        loglik = log_likelihood(_wrap_model(model), hmc_sites, *model_args, batch_ndims=0,
                                **model_kwargs, _gibbs_sites=subsample_sites)
        return sum(v.sum() for v in loglik.values())

    u_new = {}
    for name in gibbs_sites:
        size, subsample_size = plate_sizes[name]
        rng_key, subkey, block_key = random.split(rng_key, 3)
        chosen_block = random.randint(block_key, shape=(), minval=0, maxval=num_blocks)
        new_idx = random.randint(subkey, minval=0, maxval=size, shape=(subsample_size,))
        # Position i belongs to block floor(i * num_blocks / subsample_size).
        # - When subsample_size is a multiple of num_blocks, this equals
        #   i // (subsample_size // num_blocks).
        # - Unlike a floor-divided block size, it also covers the trailing
        #   remainder positions, which previously could never be refreshed.
        # - It never divides by zero when subsample_size < num_blocks. In that
        #   case some blocks are simply empty, and choosing one proposes no change.
        block_ids = jnp.arange(subsample_size) * num_blocks // subsample_size
        u_new[name] = jnp.where(block_ids == chosen_block, new_idx, gibbs_sites[name])

    u_loglik = _total_loglik(gibbs_sites)
    u_new_loglik = _total_loglik(u_new)
    # jnp.minimum instead of jnp.clip(..., a_max=...): the a_max keyword is
    # deprecated in recent JAX. An overflow to inf still clamps to 1.0.
    accept_prob = jnp.minimum(jnp.exp(u_new_loglik - u_loglik), 1.0)
    return cond(random.bernoulli(rng_key, accept_prob), u_new, identity, gibbs_sites, identity)
def test_log_likelihood(batch_shape):
    """Check ``Predictive`` and ``log_likelihood`` when samples have ``len(batch_shape)`` batch dims."""
    model, data, _ = beta_bernoulli()
    n_batch_dims = len(batch_shape)

    draws = Predictive(model, return_sites=["beta"], num_samples=200)(random.PRNGKey(1))
    # Keep exactly as many draws as the batch holds, laid out as batch_shape + (1, -1).
    n_keep = int(np.prod(batch_shape))
    beta = draws['beta'][:n_keep].reshape(batch_shape + (1, -1))
    samples = {'beta': beta}

    preds = Predictive(model, samples, batch_ndims=n_batch_dims)(random.PRNGKey(2))
    loglik = log_likelihood(model, samples, data, batch_ndims=n_batch_dims)

    assert preds.keys() == {"beta_sq", "obs"}
    assert loglik.keys() == {"obs"}
    # check shapes: the batch dims lead, followed by the data dims
    expected_shape = batch_shape + data.shape
    assert preds["obs"].shape == expected_shape
    assert loglik["obs"].shape == expected_shape
    assert_allclose(loglik["obs"], dist.Bernoulli(samples["beta"]).log_prob(data))