# NOTE: a duplicate, earlier definition of test_lkj_covariance_prior_log_prob_hetsd
# was removed here. A class body keeps only the last ``def`` of a given name, so it
# was silently shadowed by the later method of the same name and never executed.
# NOTE: a duplicate, earlier definition of test_lkj_covariance_prior_batch_log_prob
# was removed here. A class body keeps only the last ``def`` of a given name, so it
# was silently shadowed by the later method of the same name and never executed.
def test_lkj_covariance_prior_batch_log_prob(self, cuda=False):
    """Check a batched-eta LKJCovariancePrior against a reference log-density.

    The reference is ``LKJCholesky.log_prob(chol(corr)) + sd_prior.log_prob(sd)``,
    where ``sd`` are the marginal standard deviations of the covariance ``S`` and
    ``corr = S / (sd sd^T)`` is its correlation matrix.
    """
    device = torch.device("cuda") if cuda else torch.device("cpu")
    v = torch.ones(2, 1, device=device)
    sd_prior = SmoothedBoxPrior(exp(-1) * v, exp(1) * v)
    eta = torch.tensor([0.5, 1.5], device=device)
    prior = LKJCovariancePrior(2, eta, sd_prior)
    corr_dist = LKJCholesky(2, eta)

    def reference_log_prob(cov):
        # LKJCholesky is defined over Cholesky factors of *correlation* matrices,
        # and the sd prior over standard deviations (not variances), so derive
        # both explicitly rather than relying on S having a unit diagonal.
        sd = torch.diagonal(cov, dim1=-2, dim2=-1).sqrt()
        corr = cov / (sd.unsqueeze(-1) * sd.unsqueeze(-2))
        return corr_dist.log_prob(torch.linalg.cholesky(corr)) + sd_prior.log_prob(sd)

    S = torch.eye(2, device=device)
    err = (prior.log_prob(S) - reference_log_prob(S)).abs().sum()
    self.assertLessEqual(err.item(), 1e-4)

    S = torch.stack([S, torch.tensor([[1.0, 0.5], [0.5, 1]], device=S.device)])
    err = (prior.log_prob(S) - reference_log_prob(S)).abs().sum()
    self.assertLessEqual(err.item(), 1e-4)
def test_lkj_covariance_prior_log_prob_hetsd(self, cuda=False):
    """Check LKJCovariancePrior with per-dimension sd bounds against a reference.

    The sd prior has different box bounds for each dimension. The reference is
    ``LKJCholesky.log_prob(chol(corr)) + sd_prior.log_prob(sd)``, where ``sd`` are
    the marginal standard deviations of ``S`` and ``corr = S / (sd sd^T)``.
    """
    device = torch.device("cuda") if cuda else torch.device("cpu")
    a = torch.tensor([exp(-1), exp(-2)], device=device)
    b = torch.tensor([exp(1), exp(2)], device=device)
    sd_prior = SmoothedBoxPrior(a, b)
    eta = torch.tensor(0.5, device=device)
    prior = LKJCovariancePrior(2, eta, sd_prior)
    corr_dist = LKJCholesky(2, eta)

    def reference_log_prob(cov):
        # LKJCholesky is defined over Cholesky factors of *correlation* matrices,
        # and the sd prior over standard deviations (not variances), so derive
        # both explicitly rather than relying on S having a unit diagonal.
        sd = torch.diagonal(cov, dim1=-2, dim2=-1).sqrt()
        corr = cov / (sd.unsqueeze(-1) * sd.unsqueeze(-2))
        return corr_dist.log_prob(torch.linalg.cholesky(corr)) + sd_prior.log_prob(sd)

    S = torch.eye(2, device=device)
    # assertAlmostEqual is meant for plain numbers; compare Python floats so the
    # check does not depend on tensor ``==``/``round`` semantics.
    self.assertAlmostEqual(
        prior.log_prob(S).item(), reference_log_prob(S).item(), places=4
    )

    S = torch.stack([S, torch.tensor([[1.0, 0.5], [0.5, 1]], device=S.device)])
    self.assertTrue(approx_equal(prior.log_prob(S), reference_log_prob(S)))