Exemplo n.º 1
0
    def test_smoothed_box_prior_log_prob_log_transform(self, cuda=False):
        """Check ``SmoothedBoxPrior.log_prob`` when built with ``transform=torch.exp``.

        The prior is constructed with bounds ``zeros(2)`` / ``ones(2)`` and
        ``sigma = 0.1``. Every input is passed as the log of the value of
        interest, and the results are compared against hard-coded reference
        log-probabilities.

        Args:
            cuda: When true, all tensors are created on the CUDA device;
                otherwise on the CPU.
        """
        device = torch.device("cuda" if cuda else "cpu")
        lower = torch.zeros(2, device=device)
        upper = torch.ones(2, device=device)
        prior = SmoothedBoxPrior(lower, upper, 0.1, transform=torch.exp)

        # Single 2-element input: log_prob yields one value.
        point = torch.log(torch.tensor([0.5, 1.1], device=device))
        self.assertAlmostEqual(prior.log_prob(point).item(), -0.9473, places=4)

        # 2x2 input: one log-probability per row.
        batch = torch.log(torch.tensor([[0.5, 1.1], [0.1, 0.25]], device=device))
        expected = torch.tensor([-0.947347, -0.447347], device=batch.device)
        matches = approx_equal(prior.log_prob(batch), expected)
        self.assertTrue(torch.all(matches))

        # A 3-element input must be rejected (presumably because it does not
        # match the 2-element bounds).
        with self.assertRaises(RuntimeError):
            prior.log_prob(torch.ones(3, device=device))
Exemplo n.º 2
0
    def test_lkj_covariance_prior_batch_log_prob(self, cuda=False):
        """Check ``LKJCovariancePrior.log_prob`` with a batch of two ``eta`` values.

        The expected value is assembled by hand as
        ``LKJCholesky.log_prob(...) + sd_prior.log_prob(diagonal)`` and the
        summed absolute difference to ``prior.log_prob`` must not exceed 1e-4.
        Both a single 2x2 matrix and a stack of two 2x2 matrices are checked.

        Args:
            cuda: When true, all tensors are created on the CUDA device;
                otherwise on the CPU.
        """
        device = torch.device("cuda" if cuda else "cpu")
        etas = [0.5, 1.5]
        unit_col = torch.ones(2, 1, device=device)

        sd_prior = SmoothedBoxPrior(exp(-1) * unit_col, exp(1) * unit_col)
        prior = LKJCovariancePrior(2, torch.tensor(etas, device=device), sd_prior)
        corr_dist = LKJCholesky(2, torch.tensor(etas, device=device))

        # Single matrix: the identity is handed to corr_dist as-is (it is its
        # own Cholesky factor).
        identity = torch.eye(2, device=device)
        expected = corr_dist.log_prob(identity) + sd_prior.log_prob(identity.diag())
        total_error = (prior.log_prob(identity) - expected).abs().sum()
        self.assertLessEqual(total_error, 1e-4)

        # Stack of two matrices: identity plus one with 0.5 off-diagonals.
        correlated = torch.tensor([[1.0, 0.5], [0.5, 1]], device=identity.device)
        mat_batch = torch.stack([identity, correlated])
        chol_batch = torch.linalg.cholesky(mat_batch)
        diag_batch = torch.diagonal(mat_batch, dim1=-2, dim2=-1)
        expected = corr_dist.log_prob(chol_batch) + sd_prior.log_prob(diag_batch)
        total_error = (prior.log_prob(mat_batch) - expected).abs().sum()
        self.assertLessEqual(total_error, 1e-4)
Exemplo n.º 3
0
    def test_lkj_covariance_prior_log_prob_hetsd(self, cuda=False):
        """Check ``LKJCovariancePrior.log_prob`` with different bounds per dimension.

        The ``SmoothedBoxPrior`` used as ``sd_prior`` gets bounds
        ``[exp(-1), exp(-2)]`` / ``[exp(1), exp(2)]``. The expected value is
        assembled by hand from ``LKJCholesky.log_prob`` plus
        ``sd_prior.log_prob`` of the matrix diagonal, first for a single 2x2
        identity and then for a stack of two 2x2 matrices.

        Args:
            cuda: When true, all tensors are created on the CUDA device;
                otherwise on the CPU.
        """
        device = torch.device("cuda") if cuda else torch.device("cpu")
        a = torch.tensor([exp(-1), exp(-2)], device=device)
        b = torch.tensor([exp(1), exp(2)], device=device)
        sd_prior = SmoothedBoxPrior(a, b)
        prior = LKJCovariancePrior(2, torch.tensor(0.5, device=device),
                                   sd_prior)
        corr_dist = LKJCholesky(2, torch.tensor(0.5, device=device))

        S = torch.eye(2, device=device)
        dist_log_prob = corr_dist.log_prob(S) + sd_prior.log_prob(
            S.diag()).sum()
        # Compare plain Python floats: assertAlmostEqual applies its
        # ``places`` tolerance via round(), so handing it tensors makes the
        # check depend on tensor truthiness and round() support rather than
        # on the numeric difference.
        self.assertAlmostEqual(
            prior.log_prob(S).item(), dist_log_prob.item(), places=4)

        S = torch.stack(
            [S, torch.tensor([[1.0, 0.5], [0.5, 1]], device=S.device)])
        S_chol = torch.linalg.cholesky(S)
        dist_log_prob = corr_dist.log_prob(S_chol) + sd_prior.log_prob(
            torch.diagonal(S, dim1=-2, dim2=-1))
        self.assertTrue(approx_equal(prior.log_prob(S), dist_log_prob))