def test_kl_beta_beta_one_dimensional():
    """Scalar Beta/Beta KL: zero for identical parameters, a fixed value otherwise."""
    beta_p = dist.beta(jnp.array(1.0), jnp.array(0.5))
    beta_p_copy = dist.beta(jnp.array(1.0), jnp.array(0.5))
    beta_q = dist.beta(jnp.array(2.0), jnp.array(0.5))

    # Same parameters on both sides -> divergence vanishes.
    kl_same = func.compute_kl_div(beta_p, beta_p_copy)
    tu.check_close(kl_same, 0)

    # Only the first parameter differs.
    kl_diff = func.compute_kl_div(beta_p, beta_q)
    tu.check_close(kl_diff, 0.23342943)
def test_kl_dir_dir_one_dimensional():
    """Single Dirichlet/Dirichlet KL: zero for identical parameters, a fixed value otherwise."""
    dir_p = dist.dirichlet(jnp.array([1.0, 0.5, 0.5]))
    dir_p_copy = dist.dirichlet(jnp.array([1.0, 0.5, 0.5]))
    dir_q = dist.dirichlet(jnp.array([0.5, 2.0, 2.5]))

    kl_same = func.compute_kl_div(dir_p, dir_p_copy)
    tu.check_close(kl_same, 0.)

    kl_diff = func.compute_kl_div(dir_p, dir_q)
    tu.check_close(kl_diff, 4.386292)
def test_kl_normal_normal_multi_dimensional():
    """Element-wise KL between batched (2-element) normal distributions.

    Expected values: 0 for identical distributions, 0.5 when only the first
    parameter shifts by 1, and 0.09657359 for the (1, 1) -> (1, 2) pair.
    0.09657359 equals 0.5*ln(2) - 0.25, which is consistent with the second
    argument of ``dist.normal`` being a variance.
    NOTE(review): confirm dist.normal's parameterisation.
    """
    n1 = dist.normal(jnp.array([0., 0.]), jnp.array([1., 1.]))
    n2 = dist.normal(jnp.array([0., 0.]), jnp.array([1., 1.]))
    n3 = dist.normal(jnp.array([1., 1.]), jnp.array([1., 1.]))
    n4 = dist.normal(jnp.array([1., 1.]), jnp.array([2., 2.]))
    # Compare absolute differences: a bare `kl - expected < tol` would also
    # pass whenever kl is far *below* the expected value.
    assert jnp.all(jnp.abs(func.compute_kl_div(n1, n2)) < 1e-6)
    assert jnp.all(jnp.abs(func.compute_kl_div(n1, n3) - 0.5) < 1e-6)
    assert jnp.all(jnp.abs(func.compute_kl_div(n3, n4) - 0.09657359) < 1e-6)
def test_kl_binomial_binomial_one_dimensional():
    """Scalar binomial KL: mismatched trial counts raise, otherwise known values."""
    n1 = dist.binomial(jnp.array(10), jnp.array(0.5))
    n2 = dist.binomial(jnp.array(5), jnp.array(0.5))
    n3 = dist.binomial(jnp.array(10), jnp.array(0.5))
    n4 = dist.binomial(jnp.array(10), jnp.array(0.1))
    # Different trial counts (10 vs 5) must be rejected.
    with pytest.raises(ValueError):
        func.compute_kl_div(n1, n2)
    # Two-sided checks: `kl - expected < tol` alone would also accept values
    # far below the expected one.
    assert jnp.abs(func.compute_kl_div(n1, n3)) < 1e-6
    # 10 * (0.5*ln(0.5/0.1) + 0.5*ln(0.5/0.9)) = 5*ln(25/9) ~= 5.1082563
    assert jnp.abs(func.compute_kl_div(n1, n4) - 5.1082563) < 1e-6
def test_kl_normal_normal_one_dimensional():
    """Scalar normal/normal KL against known reference values.

    0.09657359028 equals 0.5*ln(2) - 0.25, which is consistent with the second
    argument of ``dist.normal`` being a variance.
    NOTE(review): confirm dist.normal's parameterisation.
    """
    n1 = dist.normal(jnp.array(0.), jnp.array(1.))
    n2 = dist.normal(jnp.array(0.), jnp.array(1.))
    n3 = dist.normal(jnp.array(1.), jnp.array(1.))
    n4 = dist.normal(jnp.array(1.), jnp.array(2.))
    # Two-sided tolerance checks; the one-sided form `kl - expected < tol`
    # would accept any kl below the expected value.
    assert jnp.abs(func.compute_kl_div(n1, n2)) < 1e-6
    assert jnp.abs(func.compute_kl_div(n1, n3) - jnp.array(0.5)) \
        < jnp.array(1e-6)
    assert jnp.abs(func.compute_kl_div(n3, n4)
                   - jnp.array(0.09657359028)) < jnp.array(1e-6)
def test_kl_binomial_binomial_multi_dimensional():
    """Element-wise KL between batched (2-element) binomial distributions.

    The name is distinct from the normal/normal multi-dimensional test. A
    duplicate module-level name silently replaces the earlier definition, so
    pytest would collect only one of the two tests.
    """
    n1 = dist.binomial(jnp.array([10, 5]), jnp.array([0.5, 0.5]))
    n2 = dist.binomial(jnp.array([5, 5]), jnp.array([0.5, 0.5]))
    n3 = dist.binomial(jnp.array([10, 5]), jnp.array([0.5, 0.5]))
    n4 = dist.binomial(jnp.array([10, 5]), jnp.array([0.1, 0.5]))
    # Trial counts differ in the first element (10 vs 5) -> rejected.
    with pytest.raises(ValueError):
        func.compute_kl_div(n1, n2)
    assert jnp.all(jnp.abs(func.compute_kl_div(n1, n3)) < 1e-6)
    # First element: 5*ln(25/9); second element: identical parameters -> 0.
    assert jnp.all(
        jnp.abs(func.compute_kl_div(n1, n4) - jnp.array([5.1082563, 0.0]))
        < 1e-6)
def test_kl_multinomial_multinomial_multi_dimensional():
    """Batched multinomial KL, including rejection of incompatible pairs."""
    even_probs = jnp.array([[0.5, 0.5], [0.5, 0.5]])
    reference = dist.multinomial(jnp.array([10, 5]), even_probs)
    # Trial counts differ from `reference` in the first row.
    fewer_trials = dist.multinomial(jnp.array([5, 5]), even_probs)
    # Three categories instead of two.
    extra_category = dist.multinomial(
        jnp.array([5, 5]), jnp.array([[0.1, 0.8, 0.1], [0.1, 0.8, 0.1]]))
    skewed = dist.multinomial(
        jnp.array([10, 5]), jnp.array([[0.1, 0.9], [0.1, 0.9]]))

    for incompatible in (fewer_trials, extra_category):
        with pytest.raises(ValueError):
            func.compute_kl_div(reference, incompatible)

    kl = func.compute_kl_div(reference, skewed)
    tu.check_close(kl, jnp.array([5.1082563, 2.5541282]))
def test_kl_uniform_uniform():
    """KL between two uniforms built with no arguments is exactly zero."""
    kl = func.compute_kl_div(dist.uniform(), dist.uniform())
    assert kl == 0
def test_kl_dir_dir_multi_dimensional():
    """Batched Dirichlet KL: row 0 identical, row 1 with permuted parameters."""
    params_p = jnp.array([[1.0, 0.5, 0.5], [1.0, 0.5, 0.5]])
    params_q = jnp.array([[1.0, 0.5, 0.5], [0.5, 0.5, 1.0]])
    kl = func.compute_kl_div(dist.dirichlet(params_p), dist.dirichlet(params_q))
    tu.check_close(kl, jnp.array([0, 0.69314706]))
def test_kl_beta_beta_multi_dimensional():
    """Batched Beta KL where the two rows swap the first parameter."""
    first_params = jnp.array([1.0, 2.0])
    second_params = jnp.array([0.5, 0.5])
    lhs = dist.beta(first_params, second_params)
    rhs = dist.beta(jnp.array([2.0, 1.0]), jnp.array([0.5, 0.5]))
    expected = jnp.array([0.23342943, 0.23268747])
    tu.check_close(func.compute_kl_div(lhs, rhs), expected)