コード例 #1
0
def test_kl_binomial_binomial_one_dimensional():
    """Check ``func.compute_kl_div`` on scalar binomial distributions.

    Asserts that:
    * distributions with different first arguments (10 vs 5) are rejected
      with ``ValueError``;
    * identical distributions give a KL divergence of (approximately) zero;
    * differing second arguments (0.5 vs 0.1) give the expected value
      5.1082563.

    Assumes ``dist.binomial(total_count, probs)`` argument order — under
    that reading the expected value equals the closed form
    10 * [0.5*log(0.5/0.1) + 0.5*log(0.5/0.9)] = 5*log(25/9).
    """
    n1 = dist.binomial(jnp.array(10), jnp.array(0.5))
    n2 = dist.binomial(jnp.array(5), jnp.array(0.5))
    n3 = dist.binomial(jnp.array(10), jnp.array(0.5))
    n4 = dist.binomial(jnp.array(10), jnp.array(0.1))

    # Mismatched first arguments (10 vs 5) must be rejected.
    with pytest.raises(ValueError):
        func.compute_kl_div(n1, n2)

    # Compare absolute differences: a one-sided `value - expected < tol`
    # check silently passes for any value far *below* the expected one.
    assert jnp.abs(func.compute_kl_div(n1, n3)) < 1e-6
    assert jnp.abs(func.compute_kl_div(n1, n4) - 5.1082563) < 1e-6
コード例 #2
0
def test_kl_normal_normal_multi_dimensional():
    """Check ``func.compute_kl_div`` element-wise on two-element binomials.

    Builds binomial distributions from length-2 arrays and asserts that:
    * differing first arguments ([10, 5] vs [5, 5]) raise ``ValueError``;
    * identical distributions give a KL of ~0 in every element;
    * changing only the first element's second argument (0.5 -> 0.1) gives
      ``[5.1082563, 0.0]``.

    NOTE(review): despite its name, this test exercises ``dist.binomial``,
    not normal distributions. Consider renaming it (e.g. to
    ``test_kl_binomial_binomial_multi_dimensional``) so it cannot shadow a
    genuine normal/normal test with the same name in this module.
    """
    tol = 1e-6
    base = dist.binomial(jnp.array([10, 5]), jnp.array([0.5, 0.5]))
    other_counts = dist.binomial(jnp.array([5, 5]), jnp.array([0.5, 0.5]))
    same = dist.binomial(jnp.array([10, 5]), jnp.array([0.5, 0.5]))
    other_probs = dist.binomial(jnp.array([10, 5]), jnp.array([0.1, 0.5]))

    # Distributions whose first arguments differ must be rejected.
    with pytest.raises(ValueError):
        func.compute_kl_div(base, other_counts)

    kl_same = func.compute_kl_div(base, same)
    assert jnp.all(jnp.abs(kl_same) < tol)

    kl_diff = func.compute_kl_div(base, other_probs)
    expected = jnp.array([5.1082563, 0.0])
    assert jnp.all(jnp.abs(kl_diff - expected) < tol)
コード例 #3
0
 def model4(key):
     """Sample site ``'n'`` from a fixed binomial distribution.

     The distribution is ``dist.binomial(jnp.array(10), jnp.array(0.5))``
     (presumably total_count=10, probs=0.5 — confirm against ``dist``).

     Args:
         key: random key forwarded unchanged to ``func.sample``.

     Returns:
         Whatever ``func.sample`` returns for the ``'n'`` site.
     """
     binomial_dist = dist.binomial(jnp.array(10), jnp.array(0.5))
     return func.sample('n', binomial_dist, key)
コード例 #4
0
 def model(key):
     """Two-stage model: a Bernoulli draw used as a binomial's second argument.

     Splits ``key`` into two sub-keys. It samples site ``'n1'`` from
     ``dist.bernoulli(0.5)``, then site ``'n2'`` from
     ``dist.binomial(1, n1)``, where ``n1`` is the value just sampled.

     NOTE(review): if ``dist.binomial`` takes (total_count, probs) and the
     Bernoulli sample is 0/1, then ``n2`` is presumably always equal to
     ``n1``. Confirm this is the intended dependency structure.

     Returns:
         The value sampled at site ``'n2'``.
     """
     key_n1, key_n2 = jax.random.split(key)
     coin = func.sample('n1', dist.bernoulli(jnp.array(0.5)), key_n1)
     return func.sample('n2', dist.binomial(jnp.array(1), coin), key_n2)