Ejemplo n.º 1
0
def test_relaxed_overdispersed_binomial(overdispersion):
    """Check that the relaxed ``binomial_dist`` moment-matches the exact one.

    Outside ``set_relaxed_distributions()`` the helper is expected to return
    an ``ExtendedBetaBinomial``. Inside it, the helper is expected to return
    a ``Normal``. The ``Normal`` must have the same mean, and its variance
    must equal the exact variance clamped below at ``_RELAX_MIN_VARIANCE``.

    Args:
        overdispersion: value forwarded to ``binomial_dist``. It is presumably
            supplied by a parametrize decorator not visible here (TODO confirm).
    """
    counts = torch.arange(1, 33)
    # A (16, 1) column of success probabilities, which broadcasts against the
    # 32 counts above.
    p = torch.linspace(0.1, 0.9, 16)[:, None]

    def build():
        # The same call is made in both modes so only the relaxation differs.
        return binomial_dist(counts, p, overdispersion=overdispersion)

    exact = build()
    assert isinstance(exact, dist.ExtendedBetaBinomial)

    with set_relaxed_distributions():
        relaxed = build()
    assert isinstance(relaxed, dist.Normal)

    # The first moment must agree exactly. The expected variance is floored
    # at _RELAX_MIN_VARIANCE.
    assert_close(relaxed.mean, exact.mean)
    assert_close(relaxed.variance, exact.variance.clamp(min=_RELAX_MIN_VARIANCE))
Ejemplo n.º 2
0
def test_relaxed_beta_binomial():
    """Check that the relaxed ``beta_binomial_dist`` moment-matches the exact one.

    Without relaxation the helper is expected to yield an
    ``ExtendedBetaBinomial``. Under ``set_relaxed_distributions()`` it is
    expected to yield a ``Normal``. That ``Normal`` must share the exact mean
    and must have the exact variance clamped below at ``_RELAX_MIN_VARIANCE``.
    """
    n = torch.arange(1, 17)
    # c1 has shape (8, 1). c0 holds the same values with one more trailing
    # axis, giving shape (8, 1, 1), so the two concentrations vary along
    # different dims. Presumably the helper broadcasts them into a grid of
    # (concentration0, concentration1, total_count) combinations (TODO confirm).
    c1 = torch.logspace(-1, 2, 8)[:, None]
    c0 = c1[..., None]

    def build():
        return beta_binomial_dist(c1, c0, n)

    exact = build()
    assert isinstance(exact, dist.ExtendedBetaBinomial)

    with set_relaxed_distributions():
        relaxed = build()
    assert isinstance(relaxed, dist.Normal)

    assert_close(relaxed.mean, exact.mean)
    assert_close(relaxed.variance, exact.variance.clamp(min=_RELAX_MIN_VARIANCE))