Esempio n. 1
0
 def model(key):
     """Sample a latent 'weight', then a 'measurement' conditioned on it.

     The incoming key is split into two subkeys, one per sample site.
     'weight' is drawn from dist.normal(0., 1.) and 'measurement' from
     dist.normal(weight, 1.); the measurement is returned.
     """
     weight_key, measurement_key = jax.random.split(key)
     prior = dist.normal(jnp.array(0.), jnp.array(1.))
     weight = func.sample('weight', prior, weight_key)
     likelihood = dist.normal(weight, jnp.array(1.))
     return func.sample('measurement', likelihood, measurement_key)
Esempio n. 2
0
 def model(key):
     """Sample a 2x2 'n1', then a 2x2 'n2' whose location is n1.

     Two subkeys are derived from ``key``, one for each sample site;
     the value drawn for 'n2' is returned.
     """
     first_key, second_key = jax.random.split(key, 2)
     n1_dist = dist.normal(jnp.zeros((2, 2)), jnp.full((2, 2), 1.))
     n1 = func.sample('n1', n1_dist, first_key)
     n2_dist = dist.normal(n1, jnp.ones((2, 2)))
     return func.sample('n2', n2_dist, second_key)
Esempio n. 3
0
def test_kl_normal_normal_multi_dimensional():
    """Check func.compute_kl_div on element-wise (length-2) normal pairs.

    Every comparison uses an absolute difference against the expected value:
    a one-sided ``kl - expected < tol`` would also accept any KL that is far
    *below* the expected value (e.g. 0), so it would not catch a broken
    implementation.
    """
    n1 = dist.normal(jnp.array([0., 0.]), jnp.array([1., 1.]))
    n2 = dist.normal(jnp.array([0., 0.]), jnp.array([1., 1.]))
    n3 = dist.normal(jnp.array([1., 1.]), jnp.array([1., 1.]))
    n4 = dist.normal(jnp.array([1., 1.]), jnp.array([2., 2.]))
    tol = 1e-6

    # Identical distributions -> KL should be (numerically) zero.
    assert jnp.all(jnp.abs(func.compute_kl_div(n1, n2)) < tol)
    # Unit shift in the location with equal second parameter -> 0.5.
    assert jnp.all(jnp.abs(func.compute_kl_div(n1, n3) - 0.5) < tol)
    # NOTE(review): 0.09657359 == 0.5 * ln(2) - 0.25, the closed-form KL for
    # N(1, 1) || N(1, 2) when the second dist.normal argument is a
    # *variance*. If it is a standard deviation the closed form is
    # ln(2) - 0.375 ~= 0.318 instead; confirm dist.normal's convention.
    assert jnp.all(jnp.abs(func.compute_kl_div(n3, n4) - 0.09657359) < tol)
Esempio n. 4
0
def test_kl_normal_normal_one_dimensional():
    """Check func.compute_kl_div on scalar normal distributions.

    Comparisons use an absolute difference so that a KL far *below* the
    expected value fails the test instead of silently passing, which a
    one-sided ``kl - expected < tol`` would allow.
    """
    n1 = dist.normal(jnp.array(0.), jnp.array(1.))
    n2 = dist.normal(jnp.array(0.), jnp.array(1.))
    n3 = dist.normal(jnp.array(1.), jnp.array(1.))
    n4 = dist.normal(jnp.array(1.), jnp.array(2.))
    tol = jnp.array(1e-6)

    # Identical distributions -> KL should be (numerically) zero.
    assert jnp.abs(func.compute_kl_div(n1, n2)) < tol
    # Unit shift in the location with equal second parameter -> 0.5.
    assert jnp.abs(func.compute_kl_div(n1, n3) - jnp.array(0.5)) < tol
    # NOTE(review): 0.09657359028 == 0.5 * ln(2) - 0.25, which matches the
    # closed-form KL only if dist.normal's second argument is a variance;
    # with a standard deviation it would be ln(2) - 0.375 ~= 0.318.
    assert jnp.abs(
        func.compute_kl_div(n3, n4) - jnp.array(0.09657359028)) < tol
Esempio n. 5
0
 def q(params, key):
     """Draw 'n1' from dist.normal(params['n1_mean'], params['n1_std']).

     Returns a dict mapping the site name 'n1' to the sampled value.
     """
     mean, std = params['n1_mean'], params['n1_std']
     return {'n1': func.sample('n1', dist.normal(mean, std), key)}
Esempio n. 6
0
 def model(key):
     """Sample and return a scalar 'n1' from dist.normal(10., 10.)."""
     prior = dist.normal(jnp.array(10.), jnp.array(10.))
     return func.sample('n1', prior, key)
Esempio n. 7
0
 def model(key):
     """Sample scalar 'n1', then scalar 'n2' centred on n1; return n2.

     Each sample site receives its own subkey split from ``key``.
     """
     key_n1, key_n2 = jax.random.split(key)
     n1_dist = dist.normal(jnp.array(0.), jnp.array(1.))
     n1 = func.sample('n1', n1_dist, key_n1)
     n2_dist = dist.normal(n1, jnp.array(1.))
     return func.sample('n2', n2_dist, key_n2)
Esempio n. 8
0
 def proposal(key, **current):
     """Propose a new 'n1' from a normal centred on the current 'n1'.

     ``current`` must contain an 'n1' entry; the second parameter of the
     proposal distribution is fixed at 5. Returns {'n1': proposed_value}.
     """
     centre = current['n1']
     step_dist = dist.normal(centre, jnp.array(5.))
     return {'n1': func.sample('n1', step_dist, key)}
Esempio n. 9
0
 def proposal(key, **current):
     """Propose a new 'n4' from a normal centred on the current 'n4'.

     ``current`` must contain an 'n4' entry; the second parameter of the
     proposal distribution is fixed at 1. Returns {'n4': proposed_value}.
     """
     centre = current['n4']
     step_dist = dist.normal(centre, jnp.array(1.))
     return {'n4': func.sample('n4', step_dist, key)}
Esempio n. 10
0
 def model(key):
     """Sample scalar 'n1', then scalar 'n2' centred on n1; return n2.

     The key is split so each sample site gets an independent subkey.
     Passing the same key to both sites would make them consume identical
     randomness, so 'n2' would not be an independent draw given 'n1'.
     NOTE(review): if a caller relied on the reused key deliberately, the
     original single-key behaviour must be restored there.
     """
     key_n1, key_n2 = jax.random.split(key)
     n1 = func.sample('n1', dist.normal(jnp.array(0.), jnp.array(1.)),
                      key_n1)
     n2 = func.sample('n2', dist.normal(n1, jnp.array(1.)), key_n2)
     return n2
Esempio n. 11
0
 def model2(key):
     """Sample and return a scalar 'n' from dist.normal(10., 1.)."""
     prior = dist.normal(jnp.array(10.), jnp.array(1.))
     drawn = func.sample('n', prior, key)
     return drawn
Esempio n. 12
0
def test_log_prob_normal():
    """Check log_prob of dist.normal(0., 1.) at 0 and at 1.

    -0.9189385 equals -0.5 * ln(2 * pi), and -1.4189385 is that value
    minus 0.5 (the quadratic term at a unit distance from the mean).
    """
    standard = dist.normal(jnp.array(0.), jnp.array(1.))
    cases = ((0., -0.9189385), (1., -1.4189385))
    for point, expected in cases:
        tu.check_close(standard.log_prob(jnp.array(point)), expected)