def test_gof(continuous_dist):
    """Check that samples agree with the density for a continuous distribution.

    For every parameter set exposed by ``continuous_dist``, draw 50000
    samples, evaluate ``log_prob(...).exp()`` on them, and run
    ``auto_goodness_of_fit`` separately on every flattened batch element.
    LKJ and LKJCholesky are marked xfail ("incorrect submanifold scaling").
    """
    dist_cls = continuous_dist.pyro_dist
    if dist_cls in [dist.LKJ, dist.LKJCholesky]:
        pytest.xfail(reason="incorrect submanifold scaling")

    n_draws = 50000
    cls_name = dist_cls.__name__
    for param_idx in range(continuous_dist.get_num_test_data()):
        instance = dist_cls(**continuous_dist.get_dist_params(param_idx))
        draws = instance.sample(torch.Size([n_draws]))
        with xfail_if_not_implemented():
            density = instance.log_prob(draws).exp()

        # ProjectedNormal gets a manifold dimension one less than its last
        # sample axis (presumably because samples lie on a sphere -- confirm).
        manifold_dim = draws.size(-1) - 1 if "ProjectedNormal" in cls_name else None

        # Flatten all batch dims into one axis so each batch element is
        # tested on its own.
        density = density.reshape(n_draws, -1)
        draws = draws.reshape(density.shape + instance.event_shape)
        if "Dirichlet" in cls_name:
            # Drop the last coordinate: the Dirichlet density is over all
            # but one of the probs.
            draws = draws[..., :-1]

        for batch_draws, batch_density in zip(draws.unbind(1), density.unbind(1)):
            score = auto_goodness_of_fit(batch_draws, batch_density, dim=manifold_dim)
            assert score > TEST_FAILURE_RATE
def test_von_mises_3d_gof(scale):
    """Goodness-of-fit test for ``VonMises3D`` with a random concentration.

    A random 3-vector is rescaled to have Euclidean norm ``scale`` (for
    positive ``scale``). Then 2000 samples are checked against their
    density. Unimplemented sampling results in xfail.
    """
    direction = torch.randn(3)
    vm3 = VonMises3D(direction * (scale / direction.norm(2)), validate_args=True)
    with xfail_if_not_implemented():
        draws = vm3.sample(torch.Size([2000]))
    density = vm3.log_prob(draws).exp()
    # dim=2 is presumably the intrinsic dimension of the sample space -- confirm.
    assert auto_goodness_of_fit(draws, density, dim=2) > TEST_FAILURE_RATE
def test_von_mises_gof(loc, concentration):
    """Goodness-of-fit test for ``VonMises(loc, concentration)``.

    Draws 100000 samples and checks them against ``log_prob(...).exp()``
    using ``auto_goodness_of_fit`` with ``dim=1``.
    """
    von_mises = VonMises(loc, concentration)
    draws = von_mises.sample(torch.Size([100000]))
    density = von_mises.log_prob(draws).exp()
    assert auto_goodness_of_fit(draws, density, dim=1) > TEST_FAILURE_RATE