def test_kl_dir_dir_one_dimensional():
    """Check KL divergence between Dirichlets built from 1-D concentration vectors.

    Comparing a distribution with an identically parameterised copy must give
    zero. Comparing it with a differently parameterised one must give the
    expected reference value 4.386292.
    """
    p = dist.dirichlet(jnp.array([1.0, 0.5, 0.5]))
    p_copy = dist.dirichlet(jnp.array([1.0, 0.5, 0.5]))
    q = dist.dirichlet(jnp.array([0.5, 2.0, 2.5]))

    # Self-divergence: same parameters on both sides.
    tu.check_close(func.compute_kl_div(p, p_copy), 0.)
    # Divergence against a different concentration vector.
    tu.check_close(func.compute_kl_div(p, q), 4.386292)
def test_log_prob_dir():
    """Check the Dirichlet log-density at two points whose components sum to one."""
    d = dist.dirichlet(jnp.array([0.4, 5.0, 15.0]))

    lp_first = d.log_prob(jnp.array([0.2, 0.2, 0.6]))
    lp_second = d.log_prob(jnp.array([0.1, 0.3, 0.6]))

    # The second value is positive: a density can exceed 1 at a point.
    tu.check_close(lp_first, -1.257432)
    tu.check_close(lp_second, 0.7803173)
def model(key):
    """Draw sample site ``'n1'`` from a two-entry Dirichlet and return the draw.

    Args:
        key: randomness key forwarded unchanged to ``func.sample``.
    """
    prior = dist.dirichlet(jnp.array([1.0, 0.5]))
    return func.sample('n1', prior, key)
def test_kl_dir_dir_multi_dimensional():
    """Check KL divergence when the Dirichlet parameters are a 2 x 3 array.

    The result is compared against a length-2 array, one entry per row. Row 0
    has identical parameters on both sides, so its expected value is 0. Row 1
    differs, with expected value 0.69314706 (approximately ln 2).
    """
    p = dist.dirichlet(jnp.array([[1.0, 0.5, 0.5],
                                  [1.0, 0.5, 0.5]]))
    q = dist.dirichlet(jnp.array([[1.0, 0.5, 0.5],
                                  [0.5, 0.5, 1.0]]))

    kl = func.compute_kl_div(p, q)
    tu.check_close(kl, jnp.array([0, 0.69314706]))
def model3(key):
    """Draw sample site ``'n'`` from a Dirichlet whose parameters are a (2, 3) array of 0.5.

    Args:
        key: randomness key forwarded unchanged to ``func.sample``.
    """
    concentration = jnp.full((2, 3), 0.5)
    prior = dist.dirichlet(concentration)
    return func.sample('n', prior, key)
def model2(key):
    """Draw sample site ``'n'`` from a three-entry Dirichlet with parameters (2, 3, 0.5).

    Args:
        key: randomness key forwarded unchanged to ``func.sample``.
    """
    prior = dist.dirichlet(jnp.array([2., 3., 0.5]))
    return func.sample('n', prior, key)
def model1(key):
    """Draw sample site ``'n'`` from a two-entry Dirichlet with parameters (0.5, 0.5).

    Args:
        key: randomness key forwarded unchanged to ``func.sample``.
    """
    prior = dist.dirichlet(jnp.array([0.5, 0.5]))
    return func.sample('n', prior, key)