Example #1
0
def test_relaxed_beta_binomial():
    """Compare the exact and relaxed forms returned by ``beta_binomial_dist``.

    Outside ``set_relaxed_distributions()`` the helper should return an
    ``ExtendedBetaBinomial``. Inside that context it should return a
    ``Normal`` with two properties: its mean equals the exact mean, and its
    variance equals the exact variance clamped from below at
    ``_RELAX_MIN_VARIANCE``.
    """
    counts = torch.arange(1, 17)  # total counts 1..16
    # c1 and c0 hold the same 8 log-spaced values (0.1 .. 100) at different
    # ranks: c1 is (8, 1) and c0 is (8, 1, 1). Broadcasting them against
    # ``counts`` (16,) presumably yields an 8 x 8 x 16 grid of
    # parameter combinations.
    c1 = torch.logspace(-1, 2, 8)[:, None]
    c0 = c1[..., None]

    exact = beta_binomial_dist(c1, c0, counts)
    assert isinstance(exact, dist.ExtendedBetaBinomial)

    with set_relaxed_distributions():
        relaxed = beta_binomial_dist(c1, c0, counts)
    assert isinstance(relaxed, dist.Normal)

    # The relaxed form must match the first two moments. Its variance is
    # floored, so compare against the exact variance clamped the same way.
    assert_close(relaxed.mean, exact.mean)
    floored_variance = exact.variance.clamp(min=_RELAX_MIN_VARIANCE)
    assert_close(relaxed.variance, floored_variance)
Example #2
0
def test_beta_binomial(concentration1, concentration0, total_count):
    """Check ``beta_binomial_dist`` against ``dist.BetaBinomial``.

    With a small ``overdispersion`` (0.01), the helper's distribution is
    expected to be close to the plain BetaBinomial. Closeness is measured
    as the mean squared difference of the two CDFs over the support
    ``0..total_count``.
    """
    reference = dist.BetaBinomial(concentration1, concentration0, total_count)
    approx = beta_binomial_dist(
        concentration1, concentration0, total_count, overdispersion=0.01
    )

    support = torch.arange(0., total_count + 1.)

    def _cdf(d):
        # Discrete CDF: cumulative sum of the pmf over the support points.
        return d.log_prob(support).exp().cumsum(-1)

    # CRPS-style score, related to the Cramer-von Mises criterion:
    # https://en.wikipedia.org/wiki/Cram%C3%A9r%E2%80%93von_Mises_criterion
    gap = _cdf(reference) - _cdf(approx)
    crps = (gap ** 2).mean()
    assert crps < 0.01
Example #3
0
def test_overdispersed_beta_binomial(probs, total_count, overdispersion):
    """Check that high-concentration ``beta_binomial_dist`` tracks ``binomial_dist``.

    The total concentration is fixed at 100, which leaves very little
    uncertainty around ``probs``. Given the same ``overdispersion``, the two
    helpers are expected to produce nearly identical CDFs over
    ``0..total_count``.
    """
    strength = 100.  # high concentration => very little uncertainty
    alpha = strength * probs
    beta = strength * (1 - probs)

    binom = binomial_dist(total_count, probs, overdispersion=overdispersion)
    beta_binom = beta_binomial_dist(
        alpha, beta, total_count, overdispersion=overdispersion
    )

    support = torch.arange(0., total_count + 1.)

    def _cdf(d):
        # Discrete CDF: cumulative sum of the pmf over the support points.
        return d.log_prob(support).exp().cumsum(-1)

    # CRPS-style score, related to the Cramer-von Mises criterion:
    # https://en.wikipedia.org/wiki/Cram%C3%A9r%E2%80%93von_Mises_criterion
    gap = _cdf(binom) - _cdf(beta_binom)
    crps = (gap ** 2).mean()
    assert crps < 0.01