コード例 #1
0
  def test_vmapping_distribution_reduces_to_scalar_log_prob(self):
    """Checks that log_prob of a vmapped random variable sums over the batch.

    The model draws from a standard normal under `jax.vmap`, one draw per
    split key. Its log-prob at `[0., 1.]` must equal the sum of the
    per-element normal log-probs.
    """

    def model(key):
      keys = random.split(key)
      sampler = ppl.rv(tfd.Normal(0., 1.))
      return jax.vmap(sampler)(keys)

    values = jnp.arange(2.)
    actual = ppl.log_prob(model)(values)
    expected = tfd.Normal(0., 1.).log_prob(values).sum()
    np.testing.assert_allclose(expected, actual)
コード例 #2
0
  def test_plate_reduces_over_named_axes(self):
    """Checks that a plated random variable reduces over its named axis.

    The random variable is tagged with plate 'foo' and its log-prob is
    vmapped over an axis named 'foo' with `out_axes=None`. The result must
    equal the sum of the per-element normal log-probs of `[0., 1., 2.]`.
    """
    values = jnp.arange(3.)

    plated_rv = ppl.rv(tfd.Normal(0., 1.), plate='foo')
    log_prob_fn = ppl.log_prob(plated_rv)
    mapped_log_prob = jax.vmap(log_prob_fn, axis_name='foo', out_axes=None)
    actual = mapped_log_prob(values)

    expected = tfd.Normal(0., 1.).log_prob(values).sum()
    np.testing.assert_allclose(expected, actual)
コード例 #3
0
ファイル: kernels.py プロジェクト: tensorflow/probability
 def step(key, state):
     """Run one Metropolis-Hastings accept/reject step around `inner_step`.

     A proposal is drawn with `inner_step`, and the log acceptance ratio is
     formed from the target log-probabilities of the two states together with
     the forward and backward transition log-probabilities of `inner_step`.

     Args:
       key: PRNG key; split into one key for the proposal and one for the
         uniform draw used in the acceptance test.
       state: Current state, a pytree accepted by `inner_step` and
         `unnormalized_log_prob`.

     Returns:
       A pytree built leaf-wise with `np.where`: leaves of the proposed state
       where the proposal is accepted, leaves of `state` otherwise.
     """
     transition_key, accept_key = random.split(key)
     next_state = inner_step(transition_key, state)
     # Log-density of moving state -> next_state and of the reverse move.
     forward_transition_log_prob = ppl.log_prob(inner_step)(state,
                                                            next_state)
     backward_transition_log_prob = ppl.log_prob(inner_step)(next_state,
                                                             state)
     # TODO(sharadmv): add log probabilities to the state to avoid recalculation.
     state_log_prob = unnormalized_log_prob(state)
     next_state_log_prob = unnormalized_log_prob(next_state)
     log_unclipped_accept_prob = (next_state_log_prob +
                                  backward_transition_log_prob -
                                  state_log_prob -
                                  forward_transition_log_prob)
     accept_prob = np.clip(np.exp(log_unclipped_accept_prob), 0., 1.)
     # NOTE(review): accept_prob is only consumed by tie_in; presumably this
     # keeps it attached to the uniform draw in the traced computation --
     # confirm against the `primitive.tie_in` implementation.
     u = primitive.tie_in(accept_prob, random.uniform(accept_key))
     # Compare in log space against the unclipped ratio.
     accept = np.log(u) < log_unclipped_accept_prob
     # `tree_multimap` is deprecated and removed in newer JAX releases;
     # `tree_map` accepts multiple trees and is the drop-in replacement.
     return tree_util.tree_map(lambda n, s: np.where(accept, n, s),
                               next_state, state)
コード例 #4
0
  def test_log_prob_transformation(self, dist, args, kwargs, out, flat):
    """Checks `ppl.log_prob` of a sampler against the distribution's log_prob.

    Builds `dist(*args, **kwargs)`, draws one sample with a fixed PRNG key,
    and asserts that the transformed sampler's log-prob at that sample equals
    the distribution's own `log_prob`.
    """
    del out, flat  # Not used by this test.
    distribution = dist(*args, **kwargs)

    def sample(key):
      return ppl.random_variable(distribution)(key)

    # The key is fixed, so a single draw can serve both sides of the check.
    draw = sample(random.PRNGKey(0))
    expected = distribution.log_prob(draw)
    actual = ppl.log_prob(sample)(draw)
    self.assertEqual(expected, actual)
コード例 #5
0
  def test_can_map_over_batches_with_vmap_and_reduce_to_scalar_log_prob(self):
    """Checks log_prob of a model that vmaps a sampler over batched inputs.

    Each batch element draws from `Normal(x, 1.)` with its own split key. The
    model's log-prob, given locations `[0., 1.]` and values `[0., 2.]`, must
    equal the sum of the per-element normal log-probs.
    """

    def sample_one(key, x):
      return ppl.rv(tfd.Normal(x, 1.))(key)

    def model(key, xs):
      keys = random.split(key)
      return jax.vmap(sample_one)(keys, xs)

    locs = jnp.arange(2.)
    values = 2 * jnp.arange(2.)
    actual = ppl.log_prob(model)(locs, values)
    expected = tfd.Normal(jnp.arange(2.), 1.).log_prob(values).sum()
    np.testing.assert_allclose(expected, actual)
コード例 #6
0
 def inner_log_prob(state_momentum):
     """Return the joint log-probability of a (state, momentum) pair.

     Args:
       state_momentum: Two-element sequence unpacked as `(state, momentum)`.

     Returns:
       `unnormalized_log_prob(state)` plus the log-probability of `momentum`
       under `momentum_distribution`.
     """
     state_part, momentum_part = state_momentum
     momentum_log_prob = ppl.log_prob(momentum_distribution)(momentum_part)
     state_log_prob = unnormalized_log_prob(state_part)
     return state_log_prob + momentum_log_prob