def model(key):
    """Two-site model: a latent 'weight' and a 'measurement' centred on it.

    The incoming PRNG key is split in two so that each sample site receives
    its own subkey. Returns the value drawn at the 'measurement' site.
    NOTE(review): func.sample presumably records a named random site and
    draws from the given distribution with the given key; confirm in func.
    """
    weight_key, measurement_key = jax.random.split(key)
    prior = dist.normal(jnp.array(0.), jnp.array(1.))
    weight = func.sample('weight', prior, weight_key)
    # The measurement distribution is centred on the sampled weight.
    likelihood = dist.normal(weight, jnp.array(1.))
    return func.sample('measurement', likelihood, measurement_key)
def model(key):
    """Two-site model over 2x2 arrays: 'n1' from a zero-centred normal, then 'n2' centred on n1.

    The PRNG key is split into two subkeys, one per sample site.
    Returns the value drawn at site 'n2'.
    """
    first_key, second_key = jax.random.split(key, 2)
    loc = jnp.zeros((2, 2))
    spread = jnp.full((2, 2), 1.)
    n1 = func.sample('n1', dist.normal(loc, spread), first_key)
    # n2 is conditioned on n1 by using it as the location parameter.
    n2 = func.sample('n2', dist.normal(n1, jnp.ones((2, 2))), second_key)
    return n2
def test_kl_normal_normal_multi_dimensional():
    """Check func.compute_kl_div element-wise on two-element normal distributions.

    Every comparison uses the absolute difference from the expected value.
    The previous `kl - expected < tol` form was one-sided: any KL smaller than
    the expected value (including 0 or a negative number) passed silently.
    """
    tol = 1e-6
    n1 = dist.normal(jnp.array([0., 0.]), jnp.array([1., 1.]))
    n2 = dist.normal(jnp.array([0., 0.]), jnp.array([1., 1.]))
    n3 = dist.normal(jnp.array([1., 1.]), jnp.array([1., 1.]))
    n4 = dist.normal(jnp.array([1., 1.]), jnp.array([2., 2.]))
    # Identical parameters: KL divergence is zero.
    assert jnp.all(jnp.abs(func.compute_kl_div(n1, n2)) < tol)
    # Means differ by 1 with unit second parameter: KL is 0.5.
    assert jnp.all(jnp.abs(func.compute_kl_div(n1, n3) - 0.5) < tol)
    # NOTE(review): 0.09657359 == 0.5*ln(2) - 0.25, the closed-form Gaussian KL
    # when dist.normal's second argument is a variance (1 vs 2). If it were a
    # standard deviation, the expected value would be ln(2) - 0.375 ~= 0.318.
    # Confirm the intended parameterization.
    assert jnp.all(jnp.abs(func.compute_kl_div(n3, n4) - 0.09657359) < tol)
def test_kl_normal_normal_one_dimensional():
    """Check func.compute_kl_div on scalar normal distributions.

    Every comparison uses the absolute difference from the expected value.
    The previous `kl - expected < tol` form was one-sided: any KL smaller than
    the expected value (including 0 or a negative number) passed silently.
    """
    tol = jnp.array(1e-6)
    n1 = dist.normal(jnp.array(0.), jnp.array(1.))
    n2 = dist.normal(jnp.array(0.), jnp.array(1.))
    n3 = dist.normal(jnp.array(1.), jnp.array(1.))
    n4 = dist.normal(jnp.array(1.), jnp.array(2.))
    # Identical parameters: KL divergence is zero.
    assert jnp.abs(func.compute_kl_div(n1, n2)) < tol
    # Means differ by 1 with unit second parameter: KL is 0.5.
    assert jnp.abs(func.compute_kl_div(n1, n3) - jnp.array(0.5)) < tol
    # NOTE(review): 0.09657359028 == 0.5*ln(2) - 0.25, the closed-form Gaussian
    # KL when dist.normal's second argument is a variance (1 vs 2); confirm.
    assert jnp.abs(
        func.compute_kl_div(n3, n4) - jnp.array(0.09657359028)
    ) < tol
def q(params, key):
    """Guide that samples site 'n1' from a normal built from `params`.

    Reads params['n1_mean'] and params['n1_std'] as the two arguments of
    dist.normal. Returns a dict mapping 'n1' to the sampled value.
    Raises KeyError if either parameter entry is missing.
    """
    loc, spread = params['n1_mean'], params['n1_std']
    return {'n1': func.sample('n1', dist.normal(loc, spread), key)}
def model(key):
    """Single-site model: draw 'n1' from dist.normal(10., 10.) and return it."""
    prior = dist.normal(jnp.array(10.), jnp.array(10.))
    return func.sample('n1', prior, key)
def model(key):
    """Two-site chain: 'n1' from dist.normal(0., 1.), then 'n2' centred on n1.

    The PRNG key is split so each site has its own subkey.
    Returns the value drawn at site 'n2'.
    """
    site_keys = jax.random.split(key)
    root = func.sample('n1', dist.normal(jnp.array(0.), jnp.array(1.)), site_keys[0])
    child = dist.normal(root, jnp.array(1.))
    return func.sample('n2', child, site_keys[1])
def proposal(key, **current):
    """Random-walk proposal for site 'n1'.

    Draws a new 'n1' from a normal centred on the current value current['n1']
    with second parameter 5. Returns {'n1': new_value}.
    Raises KeyError if 'n1' is not among the keyword arguments.
    """
    centre = current['n1']
    step = dist.normal(centre, jnp.array(5.))
    return {'n1': func.sample('n1', step, key)}
def proposal(key, **current):
    """Random-walk proposal for site 'n4'.

    Draws a new 'n4' from a normal centred on current['n4'] with second
    parameter 1. Returns {'n4': new_value}.
    Raises KeyError if 'n4' is not among the keyword arguments.
    """
    centre = current['n4']
    step = dist.normal(centre, jnp.array(1.))
    return {'n4': func.sample('n4', step, key)}
def model(key):
    """Two-site chain: 'n1' from dist.normal(0., 1.), then 'n2' centred on n1.

    Returns the value drawn at site 'n2'.
    The incoming key is split so each sample site gets its own subkey.
    Passing the same JAX key to both sites would feed them identical random
    bits, so n2's noise would not be independent of n1's.
    """
    keys = jax.random.split(key)
    n1 = func.sample('n1', dist.normal(jnp.array(0.), jnp.array(1.)), keys[0])
    n2 = func.sample('n2', dist.normal(n1, jnp.array(1.)), keys[1])
    return n2
def model2(key):
    """Single-site model: draw 'n' from dist.normal(10., 1.) and return it."""
    distribution = dist.normal(jnp.array(10.), jnp.array(1.))
    return func.sample('n', distribution, key)
def test_log_prob_normal():
    """Check dist.normal(0., 1.).log_prob at the mean and one unit away from it."""
    standard = dist.normal(jnp.array(0.), jnp.array(1.))
    # At the mean: -0.5 * ln(2*pi) ~= -0.9189385.
    tu.check_close(standard.log_prob(jnp.array(0.)), -0.9189385)
    # One unit from the mean with unit second parameter: a further -0.5.
    tu.check_close(standard.log_prob(jnp.array(1.)), -1.4189385)