示例#1
0
def test_kl_beta_beta_one_dimensional():
    """Check compute_kl_div on beta distributions with scalar parameters.

    Comparing a distribution with an identically parameterised copy must
    give 0; changing the first parameter from 1.0 to 2.0 (second parameter
    fixed at 0.5) must give 0.23342943.
    """
    reference = dist.beta(jnp.array(1.0), jnp.array(0.5))
    identical = dist.beta(jnp.array(1.0), jnp.array(0.5))
    first_param_changed = dist.beta(jnp.array(2.0), jnp.array(0.5))

    # (other distribution, expected KL(reference || other)), checked in order.
    for other, expected in ((identical, 0), (first_param_changed, 0.23342943)):
        tu.check_close(func.compute_kl_div(reference, other), expected)
示例#2
0
def test_kl_dir_dir_one_dimensional():
    """Check compute_kl_div on single (unbatched) Dirichlet distributions.

    Identical concentration vectors must give 0; against
    [0.5, 2.0, 2.5] the expected value is 4.386292.
    """
    reference = dist.dirichlet(jnp.array([1.0, 0.5, 0.5]))
    identical = dist.dirichlet(jnp.array([1.0, 0.5, 0.5]))
    different = dist.dirichlet(jnp.array([0.5, 2.0, 2.5]))

    # (other distribution, expected KL(reference || other)), checked in order.
    for other, expected in ((identical, 0.), (different, 4.386292)):
        tu.check_close(func.compute_kl_div(reference, other), expected)
示例#3
0
def test_kl_normal_normal_multi_dimensional():
    """Check compute_kl_div element-wise on normal distributions with vector parameters.

    Every comparison uses the absolute error, so a result that lands far
    *below* the expected value fails too. The previous one-sided
    ``kl - expected < tol`` form accepted any value smaller than expected.
    """
    n1 = dist.normal(jnp.array([0., 0.]), jnp.array([1., 1.]))
    n2 = dist.normal(jnp.array([0., 0.]), jnp.array([1., 1.]))
    n3 = dist.normal(jnp.array([1., 1.]), jnp.array([1., 1.]))
    n4 = dist.normal(jnp.array([1., 1.]), jnp.array([2., 2.]))

    # Identical parameters: every component must be (numerically) zero.
    assert jnp.all(jnp.abs(func.compute_kl_div(n1, n2)) < 1e-6)
    # First parameter shifted by 1, second parameter equal to 1: expected 0.5.
    assert jnp.all(jnp.abs(func.compute_kl_div(n1, n3) - 0.5) < 1e-6)
    # NOTE(review): 0.09657359 == 0.5*ln(2) + 1/4 - 1/2, which matches the
    # closed-form normal KL only if the second argument of dist.normal is a
    # variance (a standard deviation of 2 would give ~0.3181) — confirm.
    assert jnp.all(jnp.abs(func.compute_kl_div(n3, n4) - 0.09657359) < 1e-6)
示例#4
0
def test_kl_binomial_binomial_one_dimensional():
    """Check compute_kl_div on binomial distributions with scalar parameters.

    Binomials with different trial counts are expected to be rejected with
    ValueError. The value checks use the absolute error, so results below the
    expected value are no longer silently accepted.
    """
    n1 = dist.binomial(jnp.array(10), jnp.array(0.5))
    n2 = dist.binomial(jnp.array(5), jnp.array(0.5))
    n3 = dist.binomial(jnp.array(10), jnp.array(0.5))
    n4 = dist.binomial(jnp.array(10), jnp.array(0.1))

    # Trial counts 10 vs 5: the pair is incompatible.
    with pytest.raises(ValueError):
        func.compute_kl_div(n1, n2)

    assert jnp.abs(func.compute_kl_div(n1, n3)) < jnp.array(1e-6)
    # 5.1082563 == 10 * (0.5*ln(0.5/0.1) + 0.5*ln(0.5/0.9)).
    assert jnp.abs(func.compute_kl_div(n1, n4) - 5.1082563) < jnp.array(1e-6)
示例#5
0
def test_kl_normal_normal_one_dimensional():
    """Check compute_kl_div on normal distributions with scalar parameters.

    The value checks use the absolute error. The previous
    ``kl - expected < tol`` form passed for any result below the expected
    value.
    """
    n1 = dist.normal(jnp.array(0.), jnp.array(1.))
    n2 = dist.normal(jnp.array(0.), jnp.array(1.))
    n3 = dist.normal(jnp.array(1.), jnp.array(1.))
    n4 = dist.normal(jnp.array(1.), jnp.array(2.))

    # Identical parameters: KL must be (numerically) zero.
    assert jnp.abs(func.compute_kl_div(n1, n2)) < 1e-6
    # First parameter shifted by 1, second parameter equal to 1: expected 0.5.
    assert jnp.abs(func.compute_kl_div(n1, n3) - jnp.array(0.5)) \
        < jnp.array(1e-6)
    # NOTE(review): 0.09657359 == 0.5*ln(2) + 1/4 - 1/2, consistent with the
    # second argument of dist.normal being a variance — confirm.
    assert jnp.abs(func.compute_kl_div(n3, n4)
                   - jnp.array(0.09657359028)) < jnp.array(1e-6)
示例#6
0
def test_kl_binomial_binomial_multi_dimensional():
    """Check compute_kl_div element-wise on binomial distributions with vector parameters.

    This was previously named ``test_kl_normal_normal_multi_dimensional``. That
    name duplicated the normal-distribution test, so one module-level
    definition shadowed the other and only one of them could be collected.

    Mismatched trial-count vectors ([10, 5] vs [5, 5]) are expected to raise
    ValueError.
    """
    n1 = dist.binomial(jnp.array([10, 5]), jnp.array([0.5, 0.5]))
    n2 = dist.binomial(jnp.array([5, 5]), jnp.array([0.5, 0.5]))
    n3 = dist.binomial(jnp.array([10, 5]), jnp.array([0.5, 0.5]))
    n4 = dist.binomial(jnp.array([10, 5]), jnp.array([0.1, 0.5]))

    with pytest.raises(ValueError):
        func.compute_kl_div(n1, n2)

    assert jnp.all(jnp.abs(func.compute_kl_div(n1, n3)) < 1e-6)
    # Only the first component's probability differs (0.5 vs 0.1), so the
    # second component's divergence is expected to be 0.
    assert jnp.all(
        jnp.abs(func.compute_kl_div(n1, n4) -
                jnp.array([5.1082563, 0.0])) < 1e-6)
示例#7
0
def test_kl_multinomial_multinomial_multi_dimensional():
    """Check compute_kl_div on batched multinomial distributions.

    Two pairs are expected to raise ValueError:
      * trial counts [5, 5] vs [10, 5];
      * trial counts [5, 5] with three categories vs [10, 5] with two.
    For a compatible pair, per-row results are compared with tu.check_close.
    """
    base = dist.multinomial(jnp.array([10, 5]),
                            jnp.array([[0.5, 0.5], [0.5, 0.5]]))
    other_counts = dist.multinomial(jnp.array([5, 5]),
                                    jnp.array([[0.5, 0.5], [0.5, 0.5]]))
    other_categories = dist.multinomial(
        jnp.array([5, 5]),
        jnp.array([[0.1, 0.8, 0.1], [0.1, 0.8, 0.1]]))
    skewed = dist.multinomial(jnp.array([10, 5]),
                              jnp.array([[0.1, 0.9], [0.1, 0.9]]))

    for incompatible in (other_counts, other_categories):
        with pytest.raises(ValueError):
            func.compute_kl_div(base, incompatible)

    tu.check_close(func.compute_kl_div(base, skewed),
                   jnp.array([5.1082563, 2.5541282]))
示例#8
0
def test_kl_uniform_uniform():
    """Two default-constructed uniform distributions must have KL exactly 0."""
    divergence = func.compute_kl_div(dist.uniform(), dist.uniform())
    assert divergence == 0
示例#9
0
def test_kl_dir_dir_multi_dimensional():
    """Check compute_kl_div row-wise on batched Dirichlet distributions.

    Row 0 has identical concentrations in both distributions, so its expected
    value is 0. Row 1 compares [1.0, 0.5, 0.5] with [0.5, 0.5, 1.0], and its
    expected value is 0.69314706.
    """
    shared_row = [1.0, 0.5, 0.5]
    p = dist.dirichlet(jnp.array([shared_row, shared_row]))
    q = dist.dirichlet(jnp.array([shared_row, [0.5, 0.5, 1.0]]))

    expected = jnp.array([0, 0.69314706])
    tu.check_close(func.compute_kl_div(p, q), expected)
示例#10
0
def test_kl_beta_beta_multi_dimensional():
    """Check compute_kl_div element-wise on beta distributions with vector parameters.

    The first parameters are swapped between the two distributions ([1, 2]
    vs [2, 1]). The results are not symmetric: 0.23342943 and 0.23268747.
    """
    second_params = [0.5, 0.5]
    p = dist.beta(jnp.array([1.0, 2.0]), jnp.array(second_params))
    q = dist.beta(jnp.array([2.0, 1.0]), jnp.array(second_params))

    expected = jnp.array([0.23342943, 0.23268747])
    tu.check_close(func.compute_kl_div(p, q), expected)