def test_vmapping_distribution_reduces_to_scalar_log_prob(self):
  """vmapping a random variable gives the sum of per-element log probs."""

  def model(key):
    # random.split's default produces two keys, so vmap draws two samples.
    sub_keys = random.split(key)
    normal_rv = ppl.rv(tfd.Normal(0., 1.))
    return jax.vmap(normal_rv)(sub_keys)

  values = jnp.arange(2.)
  actual = ppl.log_prob(model)(values)
  expected = tfd.Normal(0., 1.).log_prob(values).sum()
  np.testing.assert_allclose(expected, actual)
def test_plate_reduces_over_named_axes(self):
  """The log prob of a plated rv is reduced across its named vmap axis."""
  plated_rv = ppl.rv(tfd.Normal(0., 1.), plate='foo')
  # out_axes=None tells vmap the output must not vary along the mapped
  # 'foo' axis, so a single (unbatched) log prob comes back.
  batched_log_prob = jax.vmap(
      ppl.log_prob(plated_rv), axis_name='foo', out_axes=None)
  observations = jnp.arange(3.)
  result = batched_log_prob(observations)
  expected = tfd.Normal(0., 1.).log_prob(observations).sum()
  np.testing.assert_allclose(expected, result)
def step(key, state):
  """Runs one Metropolis-Hastings step around ``inner_step``.

  A proposal ``next_state = inner_step(transition_key, state)`` is drawn and
  accepted when ``log(u) < log_unclipped_accept_prob`` for ``u`` uniform.
  The log acceptance ratio combines ``unnormalized_log_prob`` at both states
  with the forward and backward proposal log densities, both obtained from
  ``ppl.log_prob(inner_step)``.

  Args:
    key: PRNG key, split into a proposal key and an acceptance key.
    state: the current state (a pytree; it is mapped leaf-wise below).

  Returns:
    A pytree with the structure of ``state``. Each leaf comes from the
    proposal when it is accepted and from ``state`` otherwise.
  """
  transition_key, accept_key = random.split(key)
  next_state = inner_step(transition_key, state)
  forward_transition_log_prob = ppl.log_prob(inner_step)(state, next_state)
  backward_transition_log_prob = ppl.log_prob(inner_step)(next_state, state)
  # TODO(sharadmv): add log probabilities to the state to avoid recalculation.
  state_log_prob = unnormalized_log_prob(state)
  next_state_log_prob = unnormalized_log_prob(next_state)
  log_unclipped_accept_prob = (
      next_state_log_prob + backward_transition_log_prob
      - state_log_prob - forward_transition_log_prob)
  accept_prob = np.clip(np.exp(log_unclipped_accept_prob), 0., 1.)
  # NOTE(review): accept_prob is used only here; tie_in presumably ties the
  # uniform draw to it under tracing. Confirm this is still needed.
  u = primitive.tie_in(accept_prob, random.uniform(accept_key))
  # Compare in log space. Since u < 1, log(u) < r is equivalent to
  # u < min(1, exp(r)), so the ratio needs no clipping here.
  accept = np.log(u) < log_unclipped_accept_prob
  # tree_map maps over several trees at once. tree_multimap was deprecated
  # and later removed from JAX.
  return tree_util.tree_map(lambda n, s: np.where(accept, n, s),
                            next_state, state)
def test_log_prob_transformation(self, dist, args, kwargs, out, flat):
  """ppl.log_prob of a sampler agrees with the distribution's own log_prob."""
  del out, flat  # Parameterized inputs this test does not need.
  distribution = dist(*args, **kwargs)

  def sample(key):
    return ppl.random_variable(distribution)(key)

  # The key is fixed, so one draw serves as the evaluation point for both sides.
  draw = sample(random.PRNGKey(0))
  expected = distribution.log_prob(draw)
  actual = ppl.log_prob(sample)(draw)
  self.assertEqual(expected, actual)
def test_can_map_over_batches_with_vmap_and_reduce_to_scalar_log_prob(self):
  """vmapping per-example rvs over a batch yields one summed log prob."""

  def draw_one(key, loc):
    return ppl.rv(tfd.Normal(loc, 1.))(key)

  def model(key, xs):
    per_example_keys = random.split(key)
    return jax.vmap(draw_one)(per_example_keys, xs)

  locs = jnp.arange(2.)
  observations = 2 * jnp.arange(2.)
  result = ppl.log_prob(model)(locs, observations)
  expected = tfd.Normal(locs, 1.).log_prob(observations).sum()
  np.testing.assert_allclose(expected, result)
def inner_log_prob(state_momentum):
  """Returns the target log prob of the state plus the momentum's log prob.

  Args:
    state_momentum: a ``(state, momentum)`` pair.
  """
  state, momentum = state_momentum
  momentum_log_prob = ppl.log_prob(momentum_distribution)(momentum)
  target_log_prob = unnormalized_log_prob(state)
  return target_log_prob + momentum_log_prob