def model4(key):
    """Sample site 'n' from a Bernoulli whose parameter is a 2x2 array of 0.5.

    Args:
        key: PRNG key forwarded unchanged to ``func.sample``.

    Returns:
        Whatever ``func.sample`` produces for site 'n'.
    """
    probs = jnp.full((2, 2), 0.5)
    site_dist = dist.bernoulli(probs)
    return func.sample('n', site_dist, key)
def model3(key):
    """Sample site 'n' from a Bernoulli with a scalar parameter of 1.0.

    Args:
        key: PRNG key forwarded unchanged to ``func.sample``.

    Returns:
        Whatever ``func.sample`` produces for site 'n'.
    """
    site_dist = dist.bernoulli(jnp.array(1.))
    return func.sample('n', site_dist, key)
def test_log_prob_bernoulli():
    """Check Bernoulli(0.8) log-probabilities of outcomes 0 and 1 against reference values."""
    bern = dist.bernoulli(jnp.array(0.8))
    # Evaluate both log-probabilities before any comparison is made.
    lp_zero, lp_one = bern.log_prob(0), bern.log_prob(1)
    tu.check_close(lp_zero, -1.609438)    # log(1 - 0.8) = log(0.2)
    tu.check_close(lp_one, -0.22314353)   # log(0.8)
def model(key):
    """Sample 'n1' from Bernoulli(0.5), then 'n2' from a binomial that depends on n1.

    The binomial for 'n2' is built as ``dist.binomial(jnp.array(1), n1)``.
    The two sample sites use independent subkeys split from ``key``.

    Args:
        key: PRNG key to split between the two sample sites.

    Returns:
        The value sampled at site 'n2'.
    """
    key_n1, key_n2 = jax.random.split(key)
    n1 = func.sample('n1', dist.bernoulli(jnp.array(0.5)), key_n1)
    return func.sample('n2', dist.binomial(jnp.array(1), n1), key_n2)
def model(key):
    """Sample site 'n1' from a Bernoulli built with parameter 2.0.

    Only the second subkey from ``jax.random.split(key)`` is used.

    Args:
        key: PRNG key to split.

    Returns:
        The value sampled at site 'n1'.
    """
    _, sample_key = jax.random.split(key)
    # NOTE(review): 2.0 lies outside [0, 1]; presumably deliberate, to exercise
    # invalid-parameter handling in the caller's test — confirm.
    return func.sample('n1', dist.bernoulli(jnp.array(2.0)), sample_key)
def model(key):
    """Sample 'n1' from Beta(0.5, 0.5) and use it as the Bernoulli parameter for 'n2'.

    Args:
        key: PRNG key, split into two subkeys, one per sample site.

    Returns:
        The value sampled at site 'n2'.
    """
    key_beta, key_bern = jax.random.split(key, 2)
    p = func.sample('n1', dist.beta(jnp.array(0.5), jnp.array(0.5)), key_beta)
    return func.sample('n2', dist.bernoulli(p), key_bern)