Ejemplo n.º 1
0
    def test_lkj_cholesky_factor_prior_batch_log_prob(self, cuda=False):
        """Check ``log_prob`` of an ``LKJCholeskyFactorPrior`` with a batch of etas.

        The prior is built with ``eta = [0.5, 1.5]``. A single Cholesky factor
        therefore yields one log-probability per eta. A stack of two factors
        yields two values; the reference numbers are consistent with factor
        ``i`` being scored under eta ``i``. A 3x3 input must raise
        ``ValueError``, since the prior was constructed with ``2`` as its size
        argument.
        """
        device = torch.device("cuda") if cuda else torch.device("cpu")
        prior = LKJCholeskyFactorPrior(2, torch.tensor([0.5, 1.5], device=device))

        # The Cholesky factor of the identity matrix is the identity itself.
        S = torch.eye(2, device=device)
        # torch.cholesky is deprecated; torch.linalg.cholesky returns the same
        # lower-triangular factor.
        S_chol = torch.linalg.cholesky(S)
        expected = torch.tensor([-1.86942, -0.483129], device=S_chol.device)
        self.assertTrue(approx_equal(prior.log_prob(S_chol), expected))

        # Batch of two matrices: torch.linalg.cholesky factors each matrix of a
        # batched input, so no per-matrix Python loop is needed.
        S = torch.stack([S, torch.tensor([[1.0, 0.5], [0.5, 1]], device=S.device)])
        S_chol = torch.linalg.cholesky(S)
        expected = torch.tensor([-1.86942, -0.62697], device=S_chol.device)
        self.assertTrue(approx_equal(prior.log_prob(S_chol), expected))

        # An input whose size differs from the prior's size argument is rejected.
        with self.assertRaises(ValueError):
            prior.log_prob(torch.eye(3, device=device))
Ejemplo n.º 2
0
 def test_lkj_cholesky_factor_prior_log_prob(self, cuda=False):
     """Check that ``LKJCholeskyFactorPrior.log_prob`` matches ``LKJCholesky``.

     Both objects are built with size ``2`` and ``eta = 0.5``. Their
     log-probabilities are compared for a single identity Cholesky factor and
     for a batch of two factors.
     """
     device = torch.device("cuda") if cuda else torch.device("cpu")
     prior = LKJCholeskyFactorPrior(2, torch.tensor(0.5, device=device))
     dist = LKJCholesky(2, torch.tensor(0.5, device=device))

     S = torch.eye(2, device=device)
     S_chol = torch.linalg.cholesky(S)
     # Compare plain Python floats so unittest's ==/abs()/round() arithmetic
     # inside assertAlmostEqual does not operate on (possibly CUDA) tensors.
     self.assertAlmostEqual(prior.log_prob(S_chol).item(),
                            dist.log_prob(S_chol).item(),
                            places=4)

     S = torch.stack(
         [S, torch.tensor([[1.0, 0.5], [0.5, 1]], device=S_chol.device)])
     # torch.linalg.cholesky factors every matrix of a batched input at once.
     S_chol = torch.linalg.cholesky(S)
     self.assertTrue(
         approx_equal(prior.log_prob(S_chol), dist.log_prob(S_chol)))
Ejemplo n.º 3
0
    def test_lkj_cholesky_factor_prior_log_prob(self, cuda=False):
        """Check ``LKJCholeskyFactorPrior.log_prob`` against reference values.

        The test does three things:

        * uses ``eta = 0.5`` on a single identity Cholesky factor and on a
          batch of two factors;
        * checks that a 3x3 input raises ``ValueError``;
        * checks that with ``eta = 1.0`` every input scores exactly
          ``prior.C``.
        """
        device = torch.device("cuda") if cuda else torch.device("cpu")
        prior = LKJCholeskyFactorPrior(2, torch.tensor(0.5, device=device))

        S = torch.eye(2, device=device)
        # torch.cholesky is deprecated; torch.linalg.cholesky returns the same
        # lower-triangular factor.
        S_chol = torch.linalg.cholesky(S)
        self.assertAlmostEqual(prior.log_prob(S_chol).item(), -1.86942, places=4)

        S = torch.stack([S, torch.tensor([[1.0, 0.5], [0.5, 1]], device=S_chol.device)])
        # torch.linalg.cholesky factors every matrix of a batched input at once.
        S_chol = torch.linalg.cholesky(S)
        expected = torch.tensor([-1.86942, -1.72558], device=S_chol.device)
        self.assertTrue(approx_equal(prior.log_prob(S_chol), expected))

        # An input whose size differs from the prior's size argument is rejected.
        with self.assertRaises(ValueError):
            prior.log_prob(torch.eye(3, device=device))

        # For eta=1.0 the LKJ density is flat over all correlation matrices, so
        # log_prob equals prior.C (presumably the normalizing constant) for
        # every input.
        prior = LKJCholeskyFactorPrior(2, torch.tensor(1.0, device=device))
        self.assertTrue(torch.all(prior.log_prob(S_chol) == prior.C))