def test_normal_log_prob(self):
  """log_prob of a single standard-normal draw should match Normal(0, 1)."""
  def sample_normal(key):
    return random_normal(key)

  normal_lp = log_prob(sample_normal)
  for value in (0., 1.):
    self.assertEqual(normal_lp(value), bd.Normal(0., 1.).log_prob(value))
def test_log_normal_log_prob(self):
  """log_prob of exp(normal draw) should equal the Exp-transformed Normal."""
  def sample_log_normal(key):
    return np.exp(random_normal(key))

  expected = bd.TransformedDistribution(bd.Normal(0., 1.), bb.Exp())
  lp_fn = log_prob(sample_log_normal)
  self.assertEqual(lp_fn(2.), expected.log_prob(2.))
def test_multiple_sample(self):
  """log_prob of a sum of two independent draws is expected to raise."""
  def sum_of_normals(key):
    first_key, second_key = random.split(key)
    return random_normal(first_key) + random_normal(second_key)

  lp_fn = log_prob(sum_of_normals)
  with self.assertRaises(ValueError):
    lp_fn(0.1)
def test_conditional_log(self):
  """log_prob with an extra argument x: observing 0.1 with x=1.0 means a draw of -0.9."""
  def shifted_normal(key, x):
    return random_normal(key) + x

  lp_fn = log_prob(shifted_normal)
  observed, shift = 0.1, 1.0
  self.assertEqual(
      lp_fn(observed, shift), bd.Normal(0., 1.).log_prob(-0.9))
def test_log_prob_in_call(self):
  """log_prob should work on a program whose sample happens inside `call`."""
  def sample_via_call(key):
    named_sampler = call(lambda k: random_normal(k, name='z'))
    return named_sampler(key)

  lp_fn = log_prob(sample_via_call)
  sample = sample_via_call(random.PRNGKey(0))
  self.assertEqual(lp_fn(sample), bd.Normal(0., 1.).log_prob(sample))
def test_latent_variable(self):
  """log_prob of an output that depends on an unobserved latent draw should raise."""
  def model(key):
    latent_key, obs_key = random.split(key)
    latent = random_normal(latent_key)
    return random_normal(obs_key) + latent

  lp_fn = log_prob(model)
  with self.assertRaises(ValueError):
    lp_fn(0.1)
def test_unzip(self):
  """Computes log_prob of the `init` half of an unzipped program.

  There is no assertion: the test only checks that the call completes
  without raising.
  """
  def model(key):
    z_key, x_key = random.split(key)
    z = random_normal(z_key, name='z')
    return random_normal(x_key, name='x') + z

  unzipped = core.unzip(model, tag=state.VARIABLE)
  init, _ = unzipped(random.PRNGKey(0))
  lp_fn = log_prob(init)
  lp_fn(init(random.PRNGKey(0)))
def test_log_prob_should_fail_inside_of_make_jaxpr(self):
  """A jitted program with a non-invertible output should raise a clear error."""
  @jax.jit
  def thresholded(key):
    z = random_normal(key)
    # Thresholding to a boolean discards the sample value, so log_prob
    # has nothing to invert.
    return z > 0

  lp_fn = log_prob(thresholded)
  # The failure must be the "Cannot compute" ValueError, not some other
  # error raised by JAX while tracing.
  with self.assertRaisesRegex(ValueError,
                              'Cannot compute log_prob of function.'):
    lp_fn(True)
def function_log_prob(f: Program) -> LogProbFunction:
  """Implements `log_prob` for probabilistic programs (plain functions).

  Delegates directly to `lp.log_prob`; see `core.interpreters.log_prob` for
  details of the underlying implementation.

  Args:
    f: A probabilistic program.

  Returns:
    A function that computes the log probability of a sample from the program.
  """
  return lp.log_prob(f)
def module_log_prob(module, *args, **kwargs):
  """Forwards `module` and all extra arguments unchanged to `log_prob.log_prob`."""
  delegate = log_prob.log_prob
  return delegate(module, *args, **kwargs)
def module_log_prob(module, *args, **kwargs):
  """Forwards `module` and all extra arguments unchanged to `log_prob.log_prob`."""
  delegate = log_prob.log_prob
  return delegate(module, *args, **kwargs)  # pytype: disable=wrong-arg-count