def test_relaxed_beta_binomial():
    """Relaxed mode swaps the exact beta-binomial for a moment-matched Normal.

    Over a broadcast grid of concentrations and total counts, the default
    ``beta_binomial_dist`` must return an ``ExtendedBetaBinomial``. Inside
    ``set_relaxed_distributions()`` it must return a ``Normal`` whose mean equals
    the exact mean and whose variance equals the exact variance clamped from
    below at ``_RELAX_MIN_VARIANCE``.
    """
    total = torch.arange(1, 17)
    # (8, 1) and (8, 1, 1) share values but differ in shape, so together with
    # ``total`` of shape (16,) they broadcast to an (8, 8, 16) parameter grid.
    conc1 = torch.logspace(-1, 2, 8).unsqueeze(-1)
    conc0 = conc1.unsqueeze(-1)

    exact = beta_binomial_dist(conc1, conc0, total)
    assert isinstance(exact, dist.ExtendedBetaBinomial)

    with set_relaxed_distributions():
        relaxed = beta_binomial_dist(conc1, conc0, total)
    assert isinstance(relaxed, dist.Normal)

    assert_close(relaxed.mean, exact.mean)
    floored_variance = exact.variance.clamp(min=_RELAX_MIN_VARIANCE)
    assert_close(relaxed.variance, floored_variance)
def test_beta_binomial(concentration1, concentration0, total_count):
    """With tiny overdispersion, ``beta_binomial_dist`` should match BetaBinomial.

    Builds ``dist.BetaBinomial`` and ``beta_binomial_dist(...,
    overdispersion=0.01)`` with identical parameters. It then checks that the
    mean squared difference of their CDFs, evaluated at every integer
    k in [0, total_count], stays below 0.01.
    """
    reference = dist.BetaBinomial(concentration1, concentration0, total_count)
    candidate = beta_binomial_dist(
        concentration1, concentration0, total_count, overdispersion=0.01
    )

    # The mean squared CDF gap is a discrete CRPS / Cramer-von Mises distance.
    # https://en.wikipedia.org/wiki/Cram%C3%A9r%E2%80%93von_Mises_criterion
    support = torch.arange(0.0, total_count + 1.0)
    ref_cdf, cand_cdf = [
        d.log_prob(support).exp().cumsum(-1) for d in (reference, candidate)
    ]
    gap = (ref_cdf - cand_cdf).pow(2).mean()
    assert gap < 0.01
def test_overdispersed_beta_binomial(probs, total_count, overdispersion):
    """At high concentration, ``beta_binomial_dist`` should approach ``binomial_dist``.

    The Beta prior is parameterized as ``(c * probs, c * (1 - probs))`` with a
    large concentration ``c``. Its mean is ``probs`` and it carries very little
    uncertainty, so with the same ``overdispersion`` the beta-binomial should
    behave almost like a binomial with success probability ``probs``. The check
    requires the mean squared CDF difference over k in [0, total_count] to stay
    below 0.01.
    """
    # For high concentration, beta_binomial_dist is close to binomial_dist.
    concentration = 100.  # very little uncertainty around ``probs``
    concentration1 = concentration * probs
    concentration0 = concentration * (1 - probs)

    d1 = binomial_dist(total_count, probs, overdispersion=overdispersion)
    d2 = beta_binomial_dist(
        concentration1, concentration0, total_count, overdispersion=overdispersion
    )

    # CRPS is equivalent to the Cramer-von Mises test.
    # https://en.wikipedia.org/wiki/Cram%C3%A9r%E2%80%93von_Mises_criterion
    k = torch.arange(0., total_count + 1.)
    cdf1 = d1.log_prob(k).exp().cumsum(-1)
    cdf2 = d2.log_prob(k).exp().cumsum(-1)
    crps = (cdf1 - cdf2).pow(2).mean()
    assert crps < 0.01