def test_lkj_cholesky_factor_prior_log_prob_matches_lkj_cholesky(self, cuda=False):
    """Check that ``LKJCholeskyFactorPrior.log_prob`` agrees with ``LKJCholesky.log_prob``.

    Renamed from ``test_lkj_cholesky_factor_prior_log_prob``: another test method with
    that exact name was defined after this one, which shadowed it so it never ran.

    Both objects are built with dimension 2 and eta 0.5. They are compared on the
    Cholesky factor of the 2x2 identity, then on a stacked batch of two factors.

    Args:
        cuda: if True, all tensors are created on the CUDA device.
    """
    device = torch.device("cuda") if cuda else torch.device("cpu")
    prior = LKJCholeskyFactorPrior(2, torch.tensor(0.5, device=device))
    dist = LKJCholesky(2, torch.tensor(0.5, device=device))

    S = torch.eye(2, device=device)
    S_chol = torch.linalg.cholesky(S)
    # Compare Python floats; assertAlmostEqual would otherwise call round() on a tensor.
    self.assertAlmostEqual(
        prior.log_prob(S_chol).item(), dist.log_prob(S_chol).item(), places=4
    )

    S = torch.stack([S, torch.tensor([[1.0, 0.5], [0.5, 1.0]], device=device)])
    # torch.linalg.cholesky factors every matrix in the stacked batch at once.
    S_chol = torch.linalg.cholesky(S)
    self.assertTrue(approx_equal(prior.log_prob(S_chol), dist.log_prob(S_chol)))
def test_lkj_cholesky_factor_prior_log_prob(self, cuda=False):
    """Check ``LKJCholeskyFactorPrior.log_prob`` against fixed reference values.

    The test covers four cases:

    - a single 2x2 Cholesky factor;
    - a stacked batch of two factors;
    - the ValueError raised for a 3x3 input;
    - the eta=1.0 case, where every log-probability must equal ``prior.C``.

    Args:
        cuda: if True, all tensors are created on the CUDA device.
    """
    device = torch.device("cuda") if cuda else torch.device("cpu")
    prior = LKJCholeskyFactorPrior(2, torch.tensor(0.5, device=device))

    S = torch.eye(2, device=device)
    # torch.linalg.cholesky replaces the deprecated torch.cholesky (same lower factor).
    S_chol = torch.linalg.cholesky(S)
    self.assertAlmostEqual(prior.log_prob(S_chol).item(), -1.86942, places=4)

    S = torch.stack([S, torch.tensor([[1.0, 0.5], [0.5, 1]], device=device)])
    # Batched factorization: one Cholesky factor per stacked 2x2 matrix.
    S_chol = torch.linalg.cholesky(S)
    self.assertTrue(
        approx_equal(
            prior.log_prob(S_chol), torch.tensor([-1.86942, -1.72558], device=device)
        )
    )

    # A 3x3 input to a prior constructed with 2 must be rejected.
    with self.assertRaises(ValueError):
        prior.log_prob(torch.eye(3, device=device))

    # For eta=1.0 log_prob is flat over all covariance matrices
    prior = LKJCholeskyFactorPrior(2, torch.tensor(1.0, device=device))
    self.assertTrue(torch.all(prior.log_prob(S_chol) == prior.C))
def test_lkj_cholesky_factor_prior_batch_log_prob(self, cuda=False):
    """Check ``log_prob`` for a prior built with a batch of two eta values (0.5, 1.5).

    A single Cholesky factor is scored under both eta values. A stacked batch of two
    factors is then scored with one factor per eta. A 3x3 input must raise ValueError.

    Args:
        cuda: if True, all tensors are created on the CUDA device.
    """
    device = torch.device("cuda") if cuda else torch.device("cpu")
    prior = LKJCholeskyFactorPrior(2, torch.tensor([0.5, 1.5], device=device))

    S = torch.eye(2, device=device)
    # torch.linalg.cholesky replaces the deprecated torch.cholesky (same lower factor).
    S_chol = torch.linalg.cholesky(S)
    self.assertTrue(
        approx_equal(
            prior.log_prob(S_chol), torch.tensor([-1.86942, -0.483129], device=device)
        )
    )

    S = torch.stack([S, torch.tensor([[1.0, 0.5], [0.5, 1]], device=device)])
    # Batched factorization: one Cholesky factor per stacked 2x2 matrix.
    S_chol = torch.linalg.cholesky(S)
    self.assertTrue(
        approx_equal(
            prior.log_prob(S_chol), torch.tensor([-1.86942, -0.62697], device=device)
        )
    )

    # A 3x3 input to a prior constructed with 2 must be rejected.
    with self.assertRaises(ValueError):
        prior.log_prob(torch.eye(3, device=device))
def test_lkj_cholesky_factor_prior_validate_args(self):
    """With ``validate_args=True``, valid arguments construct and invalid ones raise."""
    # (2, 1.0) is expected to be accepted without error.
    LKJCholeskyFactorPrior(2, 1.0, validate_args=True)

    # Each pair must raise: a fractional first argument and a negative second argument.
    invalid_cases = ((1.5, 1.0), (2, -1.0))
    for first_arg, second_arg in invalid_cases:
        with self.assertRaises(ValueError):
            LKJCholeskyFactorPrior(first_arg, second_arg, validate_args=True)
def test_lkj_cholesky_factor_prior_to_gpu(self):
    """``.cuda()`` moves the prior's ``eta`` and ``C`` tensors to a CUDA device.

    The test does nothing (and passes) when CUDA is unavailable.
    """
    if not torch.cuda.is_available():
        return

    gpu_prior = LKJCholeskyFactorPrior(2, 1.0).cuda()
    for moved in (gpu_prior.eta, gpu_prior.C):
        self.assertEqual(moved.device.type, "cuda")
def test_lkj_prior_sample(self):
    """Drawing six samples yields valid Cholesky factors stacked as a (6, 2, 2) tensor."""
    num_draws = 6
    lkj = LKJCholeskyFactorPrior(2, 0.5)
    draws = lkj.sample(torch.Size([num_draws]))

    self.assertTrue(_is_valid_correlation_matrix_cholesky_factor(draws))
    self.assertEqual(draws.shape, torch.Size([num_draws, 2, 2]))