示例#1
0
    def test_lkj_prior_log_prob(self, cuda=False):
        """Check LKJPrior.log_prob against reference values for a scalar eta.

        Covers a single 2x2 correlation matrix, a stack of two 2x2
        correlation matrices, rejection of a 3x3 input with ValueError, and
        the eta=1 case in which log_prob must equal ``prior.C`` everywhere.
        """
        device = torch.device("cuda" if cuda else "cpu")
        prior = LKJPrior(2, torch.tensor(0.5, device=device))

        identity = torch.eye(2, device=device)
        self.assertAlmostEqual(prior.log_prob(identity).item(), -1.86942, places=4)

        correlated = torch.tensor([[1.0, 0.5], [0.5, 1]], device=identity.device)
        batch = torch.stack([identity, correlated])
        expected = torch.tensor([-1.86942, -1.72558], device=batch.device)
        self.assertTrue(approx_equal(prior.log_prob(batch), expected))

        # Inputs whose size does not match n=2 are rejected.
        with self.assertRaises(ValueError):
            prior.log_prob(torch.eye(3, device=device))

        # With eta=1.0 the density is flat over correlation matrices, so
        # log_prob reduces to the constant term prior.C for every input.
        flat_prior = LKJPrior(2, torch.tensor(1.0, device=device))
        self.assertTrue(torch.all(flat_prior.log_prob(batch) == flat_prior.C))
示例#2
0
    def test_lkj_prior_batch_log_prob(self, cuda=False):
        """Check LKJPrior.log_prob when eta is a batch of two values.

        With eta = [0.5, 1.5], the expected outputs have one entry per eta
        value. They are checked for the 2x2 identity and for a stack of two
        2x2 correlation matrices. A 3x3 input must raise ValueError.
        """
        device = torch.device("cuda" if cuda else "cpu")
        prior = LKJPrior(2, torch.tensor([0.5, 1.5], device=device))

        identity = torch.eye(2, device=device)
        expected_single = torch.tensor([-1.86942, -0.483129], device=identity.device)
        self.assertTrue(approx_equal(prior.log_prob(identity), expected_single))

        correlated = torch.tensor([[1.0, 0.5], [0.5, 1]], device=identity.device)
        stacked = torch.stack([identity, correlated])
        expected_stacked = torch.tensor([-1.86942, -0.62697], device=stacked.device)
        self.assertTrue(approx_equal(prior.log_prob(stacked), expected_stacked))

        # A matrix whose size differs from n=2 is rejected.
        with self.assertRaises(ValueError):
            prior.log_prob(torch.eye(3, device=device))
示例#3
0
    def test_lkj_prior_sample(self, seed=0):
        """Draw a batch of samples from LKJPrior and sanity-check them.

        Seeds the global torch RNG, draws 8 samples from a prior with n=5 and
        eta=0.5, and checks three things: the draws are valid correlation
        matrices, each draw is symmetric to within 1e-4, and the result has
        shape (8, 5, 5).
        """
        torch.random.manual_seed(seed)

        n_samples, dim = 8, 5
        prior = LKJPrior(n=dim, eta=0.5)
        draws = prior.sample(torch.Size([n_samples]))
        self.assertTrue(_is_valid_correlation_matrix(draws))

        # Largest absolute entry-wise difference between a draw and its transpose.
        asymmetry = (draws - draws.transpose(-2, -1)).abs().max()
        self.assertLess(asymmetry, 1e-4)

        self.assertEqual(draws.shape, torch.Size((n_samples, dim, dim)))
示例#4
0
    def test_lkj_prior_batch_log_prob(self, cuda=False):
        """Check that batched LKJPrior.log_prob agrees with LKJCholesky.

        For each input correlation matrix (or stack of them), the prior's
        log_prob on the matrix must approximately equal the LKJCholesky
        distribution's log_prob on its Cholesky factor, with eta = [0.5, 1.5].
        A 3x3 input to the prior must raise ValueError.
        """
        device = torch.device("cuda" if cuda else "cpu")
        prior = LKJPrior(2, torch.tensor([0.5, 1.5], device=device))
        dist = LKJCholesky(2, torch.tensor([0.5, 1.5], device=device))

        def check_against_cholesky(corr):
            # LKJCholesky is parameterized by the lower Cholesky factor.
            chol = torch.linalg.cholesky(corr)
            self.assertTrue(approx_equal(prior.log_prob(corr), dist.log_prob(chol)))

        single = torch.eye(2, device=device)
        check_against_cholesky(single)

        correlated = torch.tensor([[1.0, 0.5], [0.5, 1]], device=single.device)
        check_against_cholesky(torch.stack([single, correlated]))

        with self.assertRaises(ValueError):
            prior.log_prob(torch.eye(3, device=device))
示例#5
0
 def test_lkj_prior_validate_args(self):
     """Check argument validation of LKJPrior when validate_args=True.

     n=2 with eta=1.0 must construct without error. A non-integer n (1.5)
     and a negative eta (-1.0) must each raise ValueError.
     """
     LKJPrior(2, 1.0, validate_args=True)
     for bad_n, bad_eta in ((1.5, 1.0), (2, -1.0)):
         with self.assertRaises(ValueError):
             LKJPrior(bad_n, bad_eta, validate_args=True)
示例#6
0
 def test_lkj_prior_to_gpu(self):
     """Check that ``.cuda()`` moves the prior's ``eta`` and ``C`` tensors to the GPU.

     If CUDA is unavailable, the test is reported as skipped rather than
     silently passing without running any assertion.
     """
     if not torch.cuda.is_available():
         self.skipTest("CUDA is not available")
     prior = LKJPrior(2, 1.0).cuda()
     self.assertEqual(prior.eta.device.type, "cuda")
     self.assertEqual(prior.C.device.type, "cuda")