Beispiel #1
0
def test_kl_beta_beta_one_dimensional():
    """KL divergence between scalar Beta distributions.

    The divergence of a Beta distribution from an identically parameterised
    one must be zero; the divergence from a Beta whose first parameter differs
    is pinned to a reference constant.
    """
    base = dist.beta(jnp.array(1.0), jnp.array(0.5))
    same = dist.beta(jnp.array(1.0), jnp.array(0.5))
    shifted = dist.beta(jnp.array(2.0), jnp.array(0.5))

    # Identical parameters -> zero divergence.
    tu.check_close(func.compute_kl_div(base, same), 0)

    # NOTE(review): the textbook closed form for KL(Beta(1, 0.5) || Beta(2, 0.5))
    # is ~0.2082; confirm how compute_kl_div arrives at 0.23342943 (estimator,
    # tolerance, or a stale constant).
    tu.check_close(func.compute_kl_div(base, shifted), 0.23342943)
Beispiel #2
0
def test_log_prob_beta():
    """Log-density of Beta(0.5, 0.5) at two points mirrored around 0.5."""
    arcsine = dist.beta(jnp.array(0.5), jnp.array(0.5))
    # Beta(a, a) is symmetric about 0.5, so both points should yield
    # (up to float precision) the same log-probability.
    cases = ((0.1, 0.05924129), (0.9, 0.05924118))
    for point, expected in cases:
        tu.check_close(arcsine.log_prob(jnp.array(point)), expected)
Beispiel #3
0
 def model(key):
     """Sample site ``'n1'`` from Beta(0.5, 0.5) and return the draw."""
     prior = dist.beta(jnp.array(0.5), jnp.array(0.5))
     return func.sample('n1', prior, key)
Beispiel #4
0
def test_kl_beta_beta_multi_dimensional():
    """Per-element KL divergence between batched Beta distributions."""
    first = dist.beta(jnp.array([1.0, 2.0]), jnp.array([0.5, 0.5]))
    second = dist.beta(jnp.array([2.0, 1.0]), jnp.array([0.5, 0.5]))

    divergence = func.compute_kl_div(first, second)
    # One expected value per batch element.
    # NOTE(review): the textbook closed forms for these two pairs are ~0.2082
    # and ~0.1251; confirm how compute_kl_div arrives at the values below.
    tu.check_close(divergence, jnp.array([0.23342943, 0.23268747]))
Beispiel #5
0
 def model3(key):
     """Sample site ``'n'`` from a 2x2 batch of Beta(0.5, 0.5) distributions."""
     first_param = jnp.full((2, 2), 0.5)
     second_param = jnp.full((2, 2), 0.5)
     return func.sample('n', dist.beta(first_param, second_param), key)
Beispiel #6
0
 def model2(key):
     """Sample site ``'n'`` from Beta(2, 5) and return the draw."""
     beta_2_5 = dist.beta(jnp.array(2.), jnp.array(5.))
     return func.sample('n', beta_2_5, key)
Beispiel #7
0
 def model1(key):
     """Sample site ``'n'`` from Beta(0.5, 0.5) and return the draw."""
     prior = dist.beta(jnp.array(0.5), jnp.array(0.5))
     return func.sample('n', prior, key)
Beispiel #8
0
 def proposal(key, **current):
     """Propose a fresh ``'n1'`` by resampling it from Beta(0.5, 0.5).

     ``current`` (the current state) is accepted for interface compatibility
     but not used: the proposal is independent of it.
     """
     # The prior itself serves as the proposal distribution.
     prior = dist.beta(jnp.array(0.5), jnp.array(0.5))
     return {'n1': func.sample('n1', prior, key)}
Beispiel #9
0
 def model(key):
     """Two-site model: ``n1`` from Beta(0.5, 0.5), then ``n2`` from
     ``dist.bernoulli(n1)``.

     Returns the draw for ``n2``.
     """
     # One independent subkey per sample site.
     key_n1, key_n2 = jax.random.split(key, 2)
     prior = dist.beta(jnp.array(0.5), jnp.array(0.5))
     latent = func.sample('n1', prior, key_n1)
     return func.sample('n2', dist.bernoulli(latent), key_n2)