コード例 #1
0
def test_kl_dir_dir_one_dimensional():
    n1 = dist.dirichlet(jnp.array([1.0, 0.5, 0.5]))
    n2 = dist.dirichlet(jnp.array([1.0, 0.5, 0.5]))
    n3 = dist.dirichlet(jnp.array([0.5, 2.0, 2.5]))

    tu.check_close(func.compute_kl_div(n1, n2), 0.)
    tu.check_close(func.compute_kl_div(n1, n3), 4.386292)
コード例 #2
0
def test_log_prob_dir():
    n1 = dist.dirichlet(jnp.array([0.4, 5.0, 15.0]))
    log_prob_1 = n1.log_prob(jnp.array([0.2, 0.2, 0.6]))
    log_prob_2 = n1.log_prob(jnp.array([0.1, 0.3, 0.6]))
    tu.check_close(log_prob_1, -1.257432)
    tu.check_close(log_prob_2, 0.7803173)
コード例 #3
0
 def model(key):
     n1 = func.sample('n1', dist.dirichlet(jnp.array([1.0, 0.5])), key)
     return n1
コード例 #4
0
def test_kl_dir_dir_multi_dimensional():
    n1 = dist.dirichlet(jnp.array([[1.0, 0.5, 0.5], [1.0, 0.5, 0.5]]))
    n2 = dist.dirichlet(jnp.array([[1.0, 0.5, 0.5], [0.5, 0.5, 1.0]]))

    tu.check_close(func.compute_kl_div(n1, n2), jnp.array([0, 0.69314706]))
コード例 #5
0
 def model3(key):
     return func.sample('n', dist.dirichlet(jnp.full((2, 3), 0.5)), key)
コード例 #6
0
 def model2(key):
     return func.sample('n', dist.dirichlet(jnp.array([2., 3., 0.5])), key)
コード例 #7
0
 def model1(key):
     return func.sample('n', dist.dirichlet(jnp.array([0.5, 0.5])), key)