def test_sample_batch():
    """Regression test for https://github.com/pyro-ppl/pyro/issues/2615.

    Sampling with a non-unit sample_shape used to produce tensors with the
    wrong leading dimensions for a batched LKJCholesky distribution.
    """
    dist = LKJCholesky(3, concentration=torch.ones(())).expand([12])
    batch, event = torch.Size([12]), torch.Size([3, 3])
    # Expanding gives the expected batch shape; the event is a 3x3 factor.
    assert dist.batch_shape == batch
    assert dist.event_shape == event
    # Both the unit and the non-unit sample_shape must prepend correctly;
    # the latter was the broken case in the linked issue.
    for sample_shape in (torch.Size(()), torch.Size((4,))):
        expected = sample_shape + batch + event
        assert dist.shape(sample_shape) == expected
        assert dist.sample(sample_shape).shape == expected
def test_log_prob_d2(concentration):
    """For dim=2, check LKJCholesky's log_prob against a closed form.

    The single off-diagonal entry of a 2x2 correlation Cholesky factor is
    distributed as Beta(c, c) rescaled from (0, 1) to (-1, 1).
    """
    dist = LKJCholesky(2, torch.tensor([concentration]))
    reference = TransformedDistribution(
        Beta(concentration, concentration),
        AffineTransform(loc=-1., scale=2.0),
    )
    samples = dist.sample(torch.Size([100]))
    actual = dist.log_prob(samples)
    off_diag = samples[..., 1, 0]
    expected = reference.log_prob(off_diag)
    # LKJ prevents inf values in log_prob
    actual[expected == math.inf] = math.inf  # substitute inf for comparison
    assert_tensors_equal(actual, expected, prec=1e-3)
def test_log_prob_conc1(dim): dist = LKJCholesky(dim, torch.tensor([1.])) a_sample = dist.sample(torch.Size([100])) lp = dist.log_prob(a_sample) if dim == 2: assert_equal(lp, lp.new_full(lp.size(), -math.log(2))) else: ladj = a_sample.diagonal(dim1=-2, dim2=-1).log().mul( torch.linspace(start=dim - 1, end=0, steps=dim, device=a_sample.device, dtype=a_sample.dtype)).sum(-1) lps_less_ladj = lp - ladj assert (lps_less_ladj - lps_less_ladj.min()).abs().sum() < 1e-4