def test_kl_binomial_binomial_one_dimensional():
    """Check ``func.compute_kl_div`` on scalar binomial distributions.

    - Distributions with different total counts (10 vs 5) must be rejected
      with ``ValueError``.
    - A distribution compared with an identical one must give a divergence
      of (numerically) zero.
    - For equal counts, KL(Bin(n, p) || Bin(n, q)) equals
      n * (p*log(p/q) + (1-p)*log((1-p)/(1-q))).
      With n=10, p=0.5, q=0.1 that is 5*log(25/9), approximately 5.1082563.
    """
    n1 = dist.binomial(jnp.array(10), jnp.array(0.5))
    n2 = dist.binomial(jnp.array(5), jnp.array(0.5))
    n3 = dist.binomial(jnp.array(10), jnp.array(0.5))
    n4 = dist.binomial(jnp.array(10), jnp.array(0.1))
    with pytest.raises(ValueError):
        func.compute_kl_div(n1, n2)
    # Compare absolute errors: a signed difference (the previous form) let
    # any value below the expected one pass, including negative results.
    assert jnp.abs(func.compute_kl_div(n1, n3)) < 1e-6
    assert jnp.abs(func.compute_kl_div(n1, n4) - 5.1082563) < 1e-6
def test_kl_normal_normal_multi_dimensional():
    """Check ``func.compute_kl_div`` on batched (two-element) binomials.

    NOTE(review): despite "normal" in the name, every distribution built
    here is a binomial. The name is left unchanged so the pytest test id
    stays stable.
    """
    reference = dist.binomial(jnp.array([10, 5]), jnp.array([0.5, 0.5]))
    wrong_counts = dist.binomial(jnp.array([5, 5]), jnp.array([0.5, 0.5]))
    duplicate = dist.binomial(jnp.array([10, 5]), jnp.array([0.5, 0.5]))
    shifted = dist.binomial(jnp.array([10, 5]), jnp.array([0.1, 0.5]))

    # The first component's total count differs (10 vs 5), which must be
    # rejected.
    with pytest.raises(ValueError):
        func.compute_kl_div(reference, wrong_counts)

    self_kl = func.compute_kl_div(reference, duplicate)
    assert jnp.all(jnp.abs(self_kl) < 1e-6)

    # Only the first component differs (p=0.5 vs p=0.1 with n=10). The
    # second component is identical and contributes zero.
    shifted_kl = func.compute_kl_div(reference, shifted)
    expected = jnp.array([5.1082563, 0.0])
    assert jnp.all(jnp.abs(shifted_kl - expected) < 1e-6)
def model4(key):
    """Sample site ``'n'`` from a Binomial(10, 0.5) and return the draw.

    ``key`` is forwarded unchanged to ``func.sample``. It is presumably a
    JAX PRNG key; confirm against ``func.sample``.
    """
    ten_trials = dist.binomial(jnp.array(10), jnp.array(0.5))
    return func.sample('n', ten_trials, key)
def model(key):
    """Build a two-site hierarchical model and return the value of site ``'n2'``.

    Site ``'n1'`` is a Bernoulli(0.5) draw. That draw is then used as the
    success probability of a single-trial binomial at site ``'n2'``.
    ``key`` is split into two independent subkeys, one per site.
    """
    first_key, second_key = jax.random.split(key)
    coin = func.sample('n1', dist.bernoulli(jnp.array(0.5)), first_key)
    return func.sample('n2', dist.binomial(jnp.array(1), coin), second_key)