def test_normal_log_prob(self):
        """log_prob of a bare random_normal draw should match bd.Normal(0., 1.)."""
        def draw_standard_normal(key):
            return random_normal(key)

        standard_normal_lp = log_prob(draw_standard_normal)
        # Probe the same two points as before, in the same order.
        for point in (0., 1.):
            self.assertEqual(standard_normal_lp(point),
                             bd.Normal(0., 1.).log_prob(point))
    def test_log_normal_log_prob(self):
        """log_prob of exp(normal draw) should match Normal(0, 1) pushed through bb.Exp."""
        def draw_log_normal(key):
            normal_draw = random_normal(key)
            return np.exp(normal_draw)

        reference = bd.TransformedDistribution(bd.Normal(0., 1.), bb.Exp())
        log_normal_lp = log_prob(draw_log_normal)
        self.assertEqual(log_normal_lp(2.), reference.log_prob(2.))
Example #3
0
 def test_multiple_sample(self):
   """log_prob should raise ValueError when the output sums two separate draws."""
   def sum_of_two_draws(key):
     first_key, second_key = random.split(key)
     return random_normal(first_key) + random_normal(second_key)

   # Build the log_prob function outside the context manager: only the
   # evaluation is expected to raise.
   two_draws_lp = log_prob(sum_of_two_draws)
   with self.assertRaises(ValueError):
     two_draws_lp(0.1)
Example #4
0
 def test_conditional_log(self):
   """Extra args are conditioned on: observing 0.1 with x=1.0 implies a draw of -0.9."""
   def shifted_draw(key, x):
     noise = random_normal(key)
     return noise + x

   shifted_lp = log_prob(shifted_draw)
   actual = shifted_lp(0.1, 1.0)
   self.assertEqual(
       actual,
       bd.Normal(0., 1.).log_prob(-0.9))
    def test_log_prob_in_call(self):
        """log_prob should handle a sampler named 'z' invoked through `call`."""
        def draw_through_call(key):
            return call(lambda k: random_normal(k, name='z'))(key)

        call_lp = log_prob(draw_through_call)
        # Score a concrete sample drawn from the same program.
        sample = draw_through_call(random.PRNGKey(0))
        self.assertEqual(call_lp(sample), bd.Normal(0., 1.).log_prob(sample))
Example #6
0
 def test_latent_variable(self):
   """log_prob should raise ValueError when a latent draw is added to the output."""
   def latent_plus_noise(key):
     latent_key, noise_key = random.split(key)
     latent = random_normal(latent_key)
     noise = random_normal(noise_key)
     return noise + latent

   latent_lp = log_prob(latent_plus_noise)
   with self.assertRaises(ValueError):
     latent_lp(0.1)
Example #7
0
 def test_unzip(self):
   """Smoke test: log_prob of the `init` half of core.unzip should run without error."""
   def model(key):
     z_key, x_key = random.split(key)
     z = random_normal(z_key, name='z')
     return random_normal(x_key, name='x') + z

   seed = random.PRNGKey(0)
   init, _ = core.unzip(model, tag=state.VARIABLE)(seed)
   init_lp = log_prob(init)
   # The result is not checked; the test only requires that this call succeeds.
   init_lp(init(seed))
    def test_log_prob_should_fail_inside_of_make_jaxpr(self):
        """A jitted program with a boolean (noninvertible) output must fail in log_prob.

        Note: despite the test name, the program is wrapped with `jax.jit`.
        """
        @jax.jit
        def is_positive(key):
            # Thresholding the draw is noninvertible, which breaks log_prob.
            return random_normal(key) > 0

        positive_lp = log_prob(is_positive)
        # The expected failure is log_prob's own "Cannot compute" error,
        # not some other JAX error.
        with self.assertRaisesRegex(ValueError,
                                    'Cannot compute log_prob of function.'):
            positive_lp(True)
Example #9
0
def function_log_prob(f: Program) -> LogProbFunction:
    """Provides the `log_prob` implementation for probabilistic programs.

    The work is delegated to `lp.log_prob`; see `core.interpreters.log_prob`
    for details of this function's implementation.

    Args:
      f: A probabilistic program.

    Returns:
      A function that computes the log probability of a sample from the
      program.
    """
    return lp.log_prob(f)
Example #10
0
def module_log_prob(module, *args, **kwargs):
    """Computes `log_prob` for `module` by delegating to `log_prob.log_prob`.

    Any extra positional and keyword arguments are forwarded unchanged.
    """
    delegate = log_prob.log_prob
    return delegate(module, *args, **kwargs)
Example #11
0
def module_log_prob(module, *args, **kwargs):
    """Computes `log_prob` for `module` by delegating to `log_prob.log_prob`.

    Any extra positional and keyword arguments are forwarded unchanged. The
    pytype pragma on the call line is kept from the original.
    """
    delegate = log_prob.log_prob
    return delegate(module, *args, **kwargs)  # pytype: disable=wrong-arg-count