Beispiel #1
0
def test_log_prob_d2(concentration):
    """Compare 2x2 LKJCholesky log-densities against a shifted Beta reference.

    Draws 100 Cholesky factors, evaluates their ``log_prob``, and checks that
    it matches the log-density of ``-1 + 2 * Beta(concentration, concentration)``
    evaluated at the ``[1, 0]`` entry of each factor (to 1e-3).
    """
    lkj = LKJCholesky(2, torch.tensor([concentration]))
    # AffineTransform computes loc + scale * x, so this maps Beta's unit
    # interval onto the interval from -1 to 1.
    reference = TransformedDistribution(
        Beta(concentration, concentration),
        AffineTransform(loc=-1., scale=2.0),
    )

    draws = lkj.sample(torch.Size([100]))
    actual = lkj.log_prob(draws)
    off_diag = draws[..., 1, 0]
    expected = reference.log_prob(off_diag)

    # LKJ keeps its log_prob finite while the reference may yield +inf;
    # copy those +inf positions into `actual` so the two line up.
    actual[expected == math.inf] = math.inf
    assert_tensors_equal(actual, expected, prec=1e-3)
Beispiel #2
0
def test_log_prob_conc1(dim):
    """Check LKJCholesky ``log_prob`` when ``concentration == 1``.

    For ``dim == 2`` every sample must have log-density ``-log(2)``.
    Otherwise, ``log_prob`` minus a weighted sum of the log-diagonal entries
    (weights running from ``dim - 1`` down to ``0``) must be the same
    constant for all samples, up to a total absolute deviation of 1e-4.
    """
    lkj = LKJCholesky(dim, torch.tensor([1.]))

    draws = lkj.sample(torch.Size([100]))
    log_density = lkj.log_prob(draws)

    if dim == 2:
        expected = log_density.new_full(log_density.size(), -math.log(2))
        assert_equal(log_density, expected)
        return

    log_diag = draws.diagonal(dim1=-2, dim2=-1).log()
    weights = torch.linspace(start=dim - 1,
                             end=0,
                             steps=dim,
                             device=draws.device,
                             dtype=draws.dtype)
    # presumably the log-abs-det-Jacobian term -- name kept from the original
    ladj = (log_diag * weights).sum(-1)
    residual = log_density - ladj
    # every residual should coincide with the smallest one up to float noise
    assert (residual - residual.min()).abs().sum() < 1e-4
Beispiel #3
0
def test_sample_batch():
    """Check batch/event shapes and sample shapes of an expanded LKJCholesky.

    Regression test for https://github.com/pyro-ppl/pyro/issues/2615, where
    samples had the wrong shape when ``sample_shape`` was non-unit.
    """
    expanded = LKJCholesky(3, concentration=torch.ones(())).expand([12])

    # batch shape and event shape are as you'd expect
    assert expanded.batch_shape == torch.Size([12])
    assert expanded.event_shape == torch.Size([3, 3])

    # (sample_shape, expected full shape): the empty case and the
    # non-unit case that previously produced the wrong shape
    cases = (
        ((), [12, 3, 3]),
        ((4, ), [4, 12, 3, 3]),
    )
    for sample_shape, full_shape in cases:
        requested = torch.Size(sample_shape)
        wanted = torch.Size(full_shape)
        assert expanded.shape(requested) == wanted
        assert expanded.sample(requested).shape == wanted
Beispiel #4
0
 def __init__(self, dim, concentration=1., validate_args=None):
     """Build the distribution by transforming an ``LKJCholesky`` base.

     Args:
         dim: forwarded to ``LKJCholesky``.
         concentration: forwarded to ``LKJCholesky`` (default ``1.``).
         validate_args: forwarded to the parent constructor only; the base
             ``LKJCholesky`` is constructed without it.
     """
     cholesky_base = LKJCholesky(dim, concentration)
     # copy the attributes from the base distribution rather than the raw
     # constructor arguments
     self.dim = cholesky_base.dim
     self.concentration = cholesky_base.concentration
     # inverse of CorrMatrixCholeskyTransform -- presumably maps Cholesky
     # factors back to correlation matrices; confirm against its definition
     cholesky_to_corr = CorrMatrixCholeskyTransform().inv
     super(LKJ, self).__init__(cholesky_base,
                               cholesky_to_corr,
                               validate_args=validate_args)