Exemplo n.º 1
0
def test_beta_binomial_log_prob(total_count, shape):
    """Check ``BetaBinomial.log_prob`` against a Monte Carlo mixture estimate.

    A Beta-Binomial pmf is a Binomial pmf averaged over ``p ~ Beta(c1, c0)``.
    The reference value is built by drawing many Beta samples, evaluating the
    Binomial log-pmf of every count in ``0..total_count`` under each draw,
    and taking a log-mean-exp over the draw axis.

    Args:
        total_count: Number of Binomial trials.
        shape: Batch shape of the random concentrations. Presumably its
            trailing dim is 1 so the count support broadcasts against it --
            TODO confirm against the test parametrization.
    """
    # Keep this draw order (concentration0 first, then concentration1) so a
    # seeded run reproduces the same concentrations.
    c0 = onp.exp(onp.random.normal(size=shape))
    c1 = onp.exp(onp.random.normal(size=shape))
    support = np.arange(total_count + 1)

    n_draws = 100000
    p_draws = onp.random.beta(c1, c0, size=(n_draws,) + shape)
    per_draw = dist.Binomial(total_count, p_draws).log_prob(support)
    # log(mean(exp(.))) over axis 0, i.e. over the Monte Carlo draws.
    mc_estimate = logsumexp(per_draw, 0) - np.log(n_draws)

    exact = dist.BetaBinomial(c1, c0, total_count).log_prob(support)
    assert_allclose(exact, mc_estimate, rtol=0.02)
Exemplo n.º 2
0
def model_PMD(z, N, y=None, phi_prior=1 / 1000):
    """Beta-Binomial damage model with a position-dependent damage fraction.

    The damage fraction at position ``z`` is
    ``Dz = clip(A * (1 - q) ** (|z| - 1) + c, 0, 1)``. The observed counts
    ``y`` out of ``N`` follow a Beta-Binomial with concentrations
    ``alpha = Dz * phi`` and ``beta = (1 - Dz) * phi``. Its mean is therefore
    ``Dz``, and ``phi = delta + 2`` acts as the concentration.

    Args:
        z: Positions; only ``abs(z)`` is used.
        N: Trial counts, passed as ``total_count`` to the Beta-Binomial.
        y: Observed counts, or ``None`` to sample ``obs`` from the model.
        phi_prior: Rate of the exponential prior on ``delta``.

    Sites, in order: ``q``, ``A``, ``c``, ``Dz``, ``D_max``, ``delta``,
    ``phi``, ``alpha``, ``beta``, ``obs``.
    """
    z = jnp.abs(z)

    q = numpyro.sample("q", dist.Beta(2, 3))  # mean = 0.4, shape = 5
    A = numpyro.sample("A", dist.Beta(2, 3))  # mean = 0.4, shape = 5
    c = numpyro.sample("c", dist.Beta(1, 9))  # mean = 0.1, shape = 10

    # Clip *before* recording so the "Dz" site stores exactly the value the
    # likelihood uses (previously the unclipped value was recorded while the
    # clipped one was used).
    Dz = numpyro.deterministic(
        "Dz", jnp.clip(A * (1 - q) ** (z - 1) + c, 0, 1)
    )
    # Unclipped Dz formula evaluated at |z| == 1, where (1 - q) ** 0 == 1.
    numpyro.deterministic("D_max", A + c)

    delta = numpyro.sample("delta", dist.Exponential(phi_prior))
    phi = numpyro.deterministic("phi", delta + 2)

    # NOTE(review): when Dz is clipped to exactly 1, beta becomes 0, which is
    # not a positive concentration for BetaBinomial. Confirm whether the clip
    # bounds should sit strictly inside (0, 1).
    alpha = numpyro.deterministic("alpha", Dz * phi)
    beta = numpyro.deterministic("beta", (1 - Dz) * phi)

    numpyro.sample("obs", dist.BetaBinomial(alpha, beta, N), obs=y)
Exemplo n.º 3
0
def model_null(z, N, y=None, phi_prior=1 / 1000):
    """Position-independent Beta-Binomial model.

    A single fraction ``q`` is shared by every position. ``z`` is accepted
    but unused, presumably for signature parity with ``model_PMD``. Counts
    ``y`` out of ``N`` follow a Beta-Binomial with concentrations
    ``q * phi`` and ``(1 - q) * phi``, where ``phi = delta + 2``.

    Args:
        z: Unused.
        N: Trial counts, passed as ``total_count`` to the Beta-Binomial.
        y: Observed counts, or ``None`` to sample ``obs`` from the model.
        phi_prior: Rate of the exponential prior on ``delta``.
    """
    q = numpyro.sample("q", dist.Beta(2, 3))  # mean = 0.4, shape = 5
    # Recorded only as a site; the returned value is not needed here.
    numpyro.deterministic("D_max", q)

    delta = numpyro.sample("delta", dist.Exponential(phi_prior))
    phi = numpyro.deterministic("phi", delta + 2)

    concentration1 = q * phi
    concentration0 = (1 - q) * phi
    likelihood = dist.BetaBinomial(concentration1, concentration0, N)
    numpyro.sample("obs", likelihood, obs=y)
Exemplo n.º 4
0
 def model2():
     """Beta-Binomial likelihood with parameter-valued concentrations.

     ``c1`` and ``c0`` are declared with ``numpyro.param`` (initial values
     0.5 and 1.5) and constrained to be positive. ``total_count`` and
     ``data`` are read from an enclosing scope that is not visible here.
     """
     positive = dist.constraints.positive
     c1 = numpyro.param("c1", 0.5, constraint=positive)
     c0 = numpyro.param("c0", 1.5, constraint=positive)
     likelihood = dist.BetaBinomial(c1, c0, total_count)
     numpyro.sample("obs", likelihood, obs=data)