예제 #1
0
def test_gof(continuous_dist):
    """Check that samples from each test distribution agree with its density.

    For every parameter set exposed by ``continuous_dist``, draw 50000
    samples and evaluate ``log_prob(...).exp()`` on them. Each flattened batch
    element is then checked separately: its ``auto_goodness_of_fit`` score
    must exceed ``TEST_FAILURE_RATE``.

    Args:
        continuous_dist: test fixture that provides ``pyro_dist``,
            ``get_num_test_data()`` and ``get_dist_params(i)``.
    """
    dist_cls = continuous_dist.pyro_dist
    if dist_cls in (dist.LKJ, dist.LKJCholesky):
        pytest.xfail(reason="incorrect submanifold scaling")

    num_samples = 50000
    cls_name = dist_cls.__name__
    sample_shape = torch.Size([num_samples])

    for idx in range(continuous_dist.get_num_test_data()):
        instance = dist_cls(**continuous_dist.get_dist_params(idx))
        draws = instance.sample(sample_shape)
        # NOTE(review): xfail_if_not_implemented is defined elsewhere;
        # presumably it xfails the test when log_prob is not implemented.
        with xfail_if_not_implemented():
            density = instance.log_prob(draws).exp()

        # For ProjectedNormal, pass an explicit dimension one less than the
        # trailing size of the samples. Otherwise let the helper pick it.
        dim = draws.size(-1) - 1 if "ProjectedNormal" in cls_name else None

        # Flatten every batch dimension into one trailing axis so that each
        # batch element can be checked on its own.
        density = density.reshape(num_samples, -1)
        draws = draws.reshape(density.shape + instance.event_shape)
        if "Dirichlet" in cls_name:
            # The Dirichlet density covers all but the last coordinate,
            # so drop that coordinate from the samples.
            draws = draws[..., :-1]

        num_batches = density.size(-1)
        for col in range(num_batches):
            score = auto_goodness_of_fit(draws[:, col], density[:, col], dim=dim)
            assert score > TEST_FAILURE_RATE
예제 #2
0
def test_von_mises_3d_gof(scale):
    """Goodness-of-fit check for ``VonMises3D``.

    The concentration vector is a random 3-vector rescaled so that its
    Euclidean norm equals ``scale``. The test draws 2000 samples and scores
    them against the density reported by ``log_prob``.
    """
    direction = torch.randn(3)
    concentration = direction * (scale / direction.norm(2))
    vm3 = VonMises3D(concentration, validate_args=True)

    with xfail_if_not_implemented():
        draws = vm3.sample(torch.Size([2000]))
    density = vm3.log_prob(draws).exp()

    # NOTE(review): dim=2 presumably reflects the intrinsic dimension of the
    # support, which sits inside R^3. Confirm against auto_goodness_of_fit.
    score = auto_goodness_of_fit(draws, density, dim=2)
    assert score > TEST_FAILURE_RATE
예제 #3
0
def test_von_mises_gof(loc, concentration):
    """Goodness-of-fit check for ``VonMises(loc, concentration)``.

    Draws 100000 samples and requires the ``auto_goodness_of_fit`` score
    (computed with ``dim=1``) to exceed ``TEST_FAILURE_RATE``.
    """
    vm = VonMises(loc, concentration)
    draws = vm.sample(torch.Size([100000]))
    density = vm.log_prob(draws).exp()
    assert auto_goodness_of_fit(draws, density, dim=1) > TEST_FAILURE_RATE