Example #1
    def test_solve_qr(self, dtype=torch.float64, tol=1e-8):
        size = 50
        X = torch.rand((size, 2)).to(dtype=dtype)
        y = torch.sin(torch.sum(X, 1)).unsqueeze(-1).to(dtype=dtype)
        with settings.min_preconditioning_size(0):
            noise = torch.DoubleTensor(size).uniform_(
                math.log(1e-3), math.log(1e-1)).exp_().to(dtype=dtype)
            lazy_tsr = RBFKernel().to(
                dtype=dtype)(X).evaluate_kernel().add_diag(noise)
            precondition_qr, _, logdet_qr = lazy_tsr._preconditioner()

            F = lazy_tsr._piv_chol_self
            M = noise.diag() + F.matmul(F.t())

        x_exact = torch.solve(y, M)[0]
        x_qr = precondition_qr(y)

        self.assertTrue(approx_equal(x_exact, x_qr, tol))

        logdet = 2 * torch.cholesky(M).diag().log().sum(-1)
        self.assertTrue(approx_equal(logdet, logdet_qr, tol))
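The log-determinant comparison at the end of Example #1 rests on the standard identity log det M = 2 * sum(log diag(L)) for the Cholesky factor L of a symmetric positive-definite M. A minimal plain-PyTorch sketch of just that identity (independent of GPyTorch; the matrix here is made up for illustration):

import torch

# build an arbitrary symmetric positive-definite matrix
M = torch.randn(5, 5, dtype=torch.float64)
M = M @ M.T + 5 * torch.eye(5, dtype=torch.float64)

L = torch.linalg.cholesky(M)
logdet_from_chol = 2 * L.diagonal().log().sum()   # the identity used in the test
assert torch.allclose(logdet_from_chol, torch.logdet(M))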
Example #2
    def test_lanczos(self):
        size = 100
        matrix = torch.randn(size, size)
        matrix = matrix.matmul(matrix.transpose(-1, -2))
        matrix.div_(matrix.norm())
        matrix.add_(torch.ones(matrix.size(-1)).mul(1e-6).diag())
        q_mat, t_mat = lanczos_tridiag(
            matrix.matmul, max_iter=size, dtype=matrix.dtype, device=matrix.device, matrix_shape=matrix.shape
        )

        approx = q_mat.matmul(t_mat).matmul(q_mat.transpose(-1, -2))
        self.assertTrue(approx_equal(approx, matrix))
Example #3
    def lanczos_tridiag_test(self, matrix):
        size = matrix.shape[0]
        q_mat, t_mat = lanczos_tridiag(matrix.matmul,
                                       max_iter=size,
                                       dtype=matrix.dtype,
                                       device=matrix.device,
                                       matrix_shape=matrix.shape)

        self.assert_valid_sizes(size, t_mat, q_mat)
        self.assert_tridiagonally_positive(t_mat)
        approx = q_mat.matmul(t_mat).matmul(q_mat.transpose(-1, -2))
        self.assertTrue(approx_equal(approx, matrix))
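Examples #2 and #3 both check that lanczos_tridiag returns a basis q_mat with (approximately) orthonormal columns and a tridiagonal t_mat such that q_mat @ t_mat @ q_mat.T reconstructs the matrix when max_iter equals its size. The following is a minimal, self-contained Lanczos sketch in plain PyTorch that illustrates those two outputs; it is not GPyTorch's implementation (which handles batching, breakdown, and more), and lanczos_sketch is a hypothetical name.

import torch

def lanczos_sketch(A, num_iter):
    n = A.size(-1)
    Q = torch.zeros(n, num_iter, dtype=A.dtype)
    alpha = torch.zeros(num_iter, dtype=A.dtype)
    beta = torch.zeros(num_iter - 1, dtype=A.dtype)
    q = torch.randn(n, dtype=A.dtype)
    q = q / q.norm()
    for j in range(num_iter):
        Q[:, j] = q
        w = A @ q
        alpha[j] = q @ w                       # diagonal entry of T
        w = w - alpha[j] * q
        if j > 0:
            w = w - beta[j - 1] * Q[:, j - 1]
        # full reorthogonalization against all previous vectors, for stability
        w = w - Q[:, : j + 1] @ (Q[:, : j + 1].T @ w)
        if j < num_iter - 1:
            beta[j] = w.norm()                 # off-diagonal entry of T
            q = w / beta[j]
    T = torch.diag(alpha) + torch.diag(beta, 1) + torch.diag(beta, -1)
    return Q, T

A = torch.randn(20, 20, dtype=torch.float64)
A = A @ A.T + torch.eye(20, dtype=torch.float64)   # symmetric positive definite
Q, T = lanczos_sketch(A, num_iter=20)
assert torch.allclose(Q @ T @ Q.T, A, atol=1e-5)   # full reconstruction at max_iter = n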
Example #4
    def test_solve(self):
        size = 100
        train_x = torch.linspace(0, 1, size)
        covar_matrix = RBFKernel()(train_x, train_x).evaluate()
        piv_chol = pivoted_cholesky.pivoted_cholesky(covar_matrix, 10)
        woodbury_factor, inv_scale, logdet = woodbury.woodbury_factor(
            piv_chol, piv_chol, torch.ones(100), logdet=True)
        self.assertTrue(
            approx_equal(logdet, (piv_chol @ piv_chol.transpose(-1, -2) +
                                  torch.eye(100)).logdet(), 2e-4))

        rhs_vector = torch.randn(100, 50)
        shifted_covar_matrix = covar_matrix + torch.eye(size)
        real_solve = shifted_covar_matrix.inverse().matmul(rhs_vector)
        scaled_inv_diag = (inv_scale / torch.ones(100)).unsqueeze(-1)
        approx_solve = woodbury.woodbury_solve(rhs_vector,
                                               piv_chol * scaled_inv_diag,
                                               woodbury_factor,
                                               scaled_inv_diag, inv_scale)

        self.assertTrue(approx_equal(approx_solve, real_solve, 2e-4))
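Example #4 exercises a low-rank (Woodbury) solve: piv_chol is an n x k factor, so a system with D + V @ V.T can be reduced to a k x k solve. A plain-PyTorch sketch of the underlying identity, with illustrative names only (it does not mirror GPyTorch's woodbury_factor / woodbury_solve signatures):

import torch

n, k = 100, 10
V = torch.randn(n, k, dtype=torch.float64)            # low-rank factor
d = torch.rand(n, dtype=torch.float64) + 1.0          # diagonal of D
r = torch.randn(n, dtype=torch.float64)               # right-hand side

# exact solve with the full n x n matrix
exact = torch.linalg.solve(torch.diag(d) + V @ V.T, r)

# Woodbury: (D + V V^T)^-1 r = D^-1 r - D^-1 V (I + V^T D^-1 V)^-1 V^T D^-1 r
Dinv_r = r / d
Dinv_V = V / d.unsqueeze(-1)
capacitance = torch.eye(k, dtype=torch.float64) + V.T @ Dinv_V   # k x k system
woodbury_x = Dinv_r - Dinv_V @ torch.linalg.solve(capacitance, V.T @ Dinv_r)

assert torch.allclose(exact, woodbury_x)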
Example #5
    def test_interpolation(self):
        x = torch.linspace(0.01, 1, 100).unsqueeze(1)
        grid = torch.linspace(-0.05, 1.05, 50).unsqueeze(1)
        indices, values = Interpolation().interpolate(grid, x)
        indices = indices.squeeze_(0)
        values = values.squeeze_(0)
        test_func_grid = grid.squeeze(1).pow(2)
        test_func_x = x.pow(2).squeeze(-1)

        interp_func_x = left_interp(indices, values,
                                    test_func_grid.unsqueeze(1)).squeeze()

        self.assertTrue(approx_equal(interp_func_x, test_func_x))
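Example #5 checks interpolation via Interpolation().interpolate and left_interp. Assuming the usual sparse-row reading of (indices, values), row i of a dense interpolation matrix W carries values[i, j] in column indices[i, j], and applying W to a grid function is equivalent to a gather-and-sum over the indexed grid entries. A small plain-PyTorch sketch of that equivalence (random data, no GPyTorch calls):

import torch

n_target, n_grid, k = 8, 20, 4
# per-row distinct column indices, as real interpolation indices would be
indices = torch.stack([torch.randperm(n_grid)[:k] for _ in range(n_target)])
values = torch.randn(n_target, k)
rhs = torch.randn(n_grid, 3)

# dense interpolation matrix: W[i, indices[i, j]] = values[i, j]
W = torch.zeros(n_target, n_grid)
W.scatter_(1, indices, values)

dense_result = W @ rhs                                        # dense "left interp"
gathered = (values.unsqueeze(-1) * rhs[indices]).sum(dim=1)   # gather-and-sum form
assert torch.allclose(dense_result, gathered, atol=1e-6)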
Example #6
    def test_smoothed_box_prior_log_prob_log_transform(self, cuda=False):
        device = torch.device("cuda") if cuda else torch.device("cpu")
        a, b = torch.zeros(2, device=device), torch.ones(2, device=device)
        sigma = 0.1
        prior = SmoothedBoxPrior(a, b, sigma, transform=torch.exp)

        t = torch.tensor([0.5, 1.1], device=device).log()
        self.assertAlmostEqual(prior.log_prob(t).item(), -0.9473, places=4)
        t = torch.tensor([[0.5, 1.1], [0.1, 0.25]], device=device).log()
        log_prob_expected = torch.tensor([-0.947347, -0.447347], device=t.device)
        self.assertTrue(torch.all(approx_equal(prior.log_prob(t), log_prob_expected)))
        with self.assertRaises(RuntimeError):
            prior.log_prob(torch.ones(3, device=device))
Example #7
    def test_lkj_prior_log_prob(self, cuda=False):
        device = torch.device("cuda") if cuda else torch.device("cpu")
        prior = LKJPrior(2, torch.tensor(0.5, device=device))

        S = torch.eye(2, device=device)
        self.assertAlmostEqual(prior.log_prob(S).item(), -1.86942, places=4)
        S = torch.stack([S, torch.tensor([[1.0, 0.5], [0.5, 1]], device=S.device)])
        self.assertTrue(approx_equal(prior.log_prob(S), torch.tensor([-1.86942, -1.72558], device=S.device)))
        with self.assertRaises(ValueError):
            prior.log_prob(torch.eye(3, device=device))

        # For eta=1.0 log_prob is flat over all covariance matrices
        prior = LKJPrior(2, torch.tensor(1.0, device=device))
        self.assertTrue(torch.all(prior.log_prob(S) == prior.C))
Example #8
    def test_lkj_cholesky_factor_prior_log_prob(self, cuda=False):
        device = torch.device("cuda") if cuda else torch.device("cpu")
        prior = LKJCholeskyFactorPrior(2, torch.tensor(0.5, device=device))
        dist = LKJCholesky(2, torch.tensor(0.5, device=device))
        S = torch.eye(2, device=device)
        S_chol = torch.linalg.cholesky(S)
        self.assertAlmostEqual(prior.log_prob(S_chol),
                               dist.log_prob(S_chol),
                               places=4)
        S = torch.stack(
            [S, torch.tensor([[1.0, 0.5], [0.5, 1]], device=S_chol.device)])
        S_chol = torch.stack([torch.linalg.cholesky(Si) for Si in S])
        self.assertTrue(
            approx_equal(prior.log_prob(S_chol), dist.log_prob(S_chol)))
Example #9
    def test_toeplitz_matmul_batch(self):
        cols = torch.tensor([[1, 6, 4, 5], [2, 3, 1, 0], [1, 2, 3, 1]], dtype=torch.float)
        rows = torch.tensor([[1, 2, 1, 1], [2, 0, 0, 1], [1, 5, 1, 0]], dtype=torch.float)

        rhs_mats = torch.randn(3, 4, 2)

        # Actual
        lhs_mats = torch.zeros(3, 4, 4)
        for i, (col, row) in enumerate(zip(cols, rows)):
            lhs_mats[i].copy_(utils.toeplitz.toeplitz(col, row))
        actual = torch.matmul(lhs_mats, rhs_mats)

        # Fast toeplitz
        res = utils.toeplitz.toeplitz_matmul(cols, rows, rhs_mats)
        self.assertTrue(approx_equal(res, actual))
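Example #9 compares toeplitz_matmul against explicit Toeplitz matrix products. The usual fast path behind such a routine is circulant embedding: an n x n Toeplitz matrix embeds in a 2n circulant matrix, whose matrix-vector product is a pointwise multiplication in the FFT domain. A sketch of that trick in plain PyTorch for a single (non-batched) matrix (illustrative only; GPyTorch's implementation differs in details):

import torch

def toeplitz_matmul_fft(c, r, v):
    # c: first column, r: first row (c[0] == r[0]), v: vector to multiply
    n = c.numel()
    # first column of the 2n circulant embedding: [c, 0, reversed tail of r]
    circ_col = torch.cat([c, torch.zeros(1), r[1:].flip(0)])
    v_padded = torch.cat([v, torch.zeros(n)])
    prod = torch.fft.ifft(torch.fft.fft(circ_col) * torch.fft.fft(v_padded))
    return prod[:n].real

c = torch.tensor([1.0, 6.0, 4.0, 5.0])
r = torch.tensor([1.0, 2.0, 1.0, 1.0])
v = torch.randn(4)

# explicit Toeplitz matrix for comparison: T[i, j] = c[i - j] below, r[j - i] above
T = torch.empty(4, 4)
for i in range(4):
    for j in range(4):
        T[i, j] = c[i - j] if i >= j else r[j - i]
assert torch.allclose(toeplitz_matmul_fft(c, r, v), T @ v, atol=1e-5)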
Example #10
    def test_lkj_covariance_prior_log_prob_hetsd(self, cuda=False):
        device = torch.device("cuda") if cuda else torch.device("cpu")
        a = torch.tensor([exp(-1), exp(-2)], device=device)
        b = torch.tensor([exp(1), exp(2)], device=device)
        sd_prior = SmoothedBoxPrior(a, b)
        prior = LKJCovariancePrior(2, torch.tensor(0.5, device=device),
                                   sd_prior)
        S = torch.eye(2, device=device)
        self.assertAlmostEqual(prior.log_prob(S).item(), -4.71958, places=4)
        S = torch.stack(
            [S, torch.tensor([[1.0, 0.5], [0.5, 1]], device=S.device)])
        self.assertTrue(
            approx_equal(prior.log_prob(S),
                         torch.tensor([-4.71958, -4.57574], device=S.device)))
        with self.assertRaises(ValueError):
            prior.log_prob(torch.eye(3, device=device))

        # For eta=1.0 log_prob is flat over all covariance matrices
        prior = LKJCovariancePrior(2, torch.tensor(1.0, device=device),
                                   sd_prior)
        marginal_sd = torch.diagonal(S, dim1=-2, dim2=-1).sqrt()
        log_prob_expected = prior.correlation_prior.C + prior.sd_prior.log_prob(
            marginal_sd)
        self.assertTrue(approx_equal(prior.log_prob(S), log_prob_expected))
Example #11
    def test_lkj_prior_log_prob(self, cuda=False):
        device = torch.device("cuda") if cuda else torch.device("cpu")
        prior = LKJPrior(2, torch.tensor(0.5, device=device))
        dist = LKJCholesky(2, torch.tensor(0.5, device=device))

        S = torch.eye(2, device=device)
        S_chol = torch.linalg.cholesky(S)
        self.assertAlmostEqual(prior.log_prob(S),
                               dist.log_prob(S_chol),
                               places=4)
        S = torch.stack(
            [S, torch.tensor([[1.0, 0.5], [0.5, 1]], device=S.device)])
        S_chol = torch.linalg.cholesky(S)
        self.assertTrue(approx_equal(prior.log_prob(S), dist.log_prob(S_chol)))
        with self.assertRaises(ValueError):
            prior.log_prob(torch.eye(3, device=device))
Example #12
    def test_get_item_tensor_index_on_batch(self):
        # Tests the default LV.__getitem__ behavior
        lazy_tensor = ZeroLazyTensor(3, 5, 5)
        evaluated = lazy_tensor.evaluate()

        index = (torch.tensor([0, 1, 1, 0]), torch.tensor([0, 1, 0, 2]),
                 torch.tensor([1, 2, 0, 1]))
        self.assertTrue(approx_equal(lazy_tensor[index], evaluated[index]))
        index = (torch.tensor([0, 1, 1, 0]), torch.tensor([0, 1, 0, 2]),
                 slice(None, None, None))
        self.assertTrue(approx_equal(lazy_tensor[index], evaluated[index]))
        index = (torch.tensor([0, 1, 1]), slice(None, None, None),
                 torch.tensor([0, 1, 2]))
        self.assertTrue(approx_equal(lazy_tensor[index], evaluated[index]))
        index = (slice(None, None, None), torch.tensor([0, 1, 1, 0]),
                 torch.tensor([0, 1, 0, 2]))
        self.assertTrue(approx_equal(lazy_tensor[index], evaluated[index]))
        index = (torch.tensor([0, 0, 1, 1]), slice(None, None, None),
                 slice(None, None, None))
        self.assertTrue(
            approx_equal(lazy_tensor[index].evaluate(), evaluated[index]))
        index = (slice(None, None, None), torch.tensor([0, 0, 1, 2]),
                 torch.tensor([0, 0, 1, 1]))
        self.assertTrue(approx_equal(lazy_tensor[index], evaluated[index]))
        index = (torch.tensor([0, 1, 1, 0]), torch.tensor([0, 1, 0, 2]),
                 slice(None, None, None))
        self.assertTrue(approx_equal(lazy_tensor[index], evaluated[index]))
        index = (torch.tensor([0, 0, 1, 0]), slice(None, None, None),
                 torch.tensor([0, 0, 1, 1]))
        self.assertTrue(approx_equal(lazy_tensor[index], evaluated[index]))
        index = (Ellipsis, torch.tensor([0, 1, 1, 0]))
        self.assertTrue(
            approx_equal(lazy_tensor[index].evaluate(), evaluated[index]))
Example #13
    def test_lkj_covariance_prior_log_prob_hetsd(self, cuda=False):
        device = torch.device("cuda") if cuda else torch.device("cpu")
        a = torch.tensor([exp(-1), exp(-2)], device=device)
        b = torch.tensor([exp(1), exp(2)], device=device)
        sd_prior = SmoothedBoxPrior(a, b)
        prior = LKJCovariancePrior(2, torch.tensor(0.5, device=device),
                                   sd_prior)
        corr_dist = LKJCholesky(2, torch.tensor(0.5, device=device))

        S = torch.eye(2, device=device)
        dist_log_prob = corr_dist.log_prob(S) + sd_prior.log_prob(
            S.diag()).sum()
        self.assertAlmostEqual(prior.log_prob(S), dist_log_prob, places=4)

        S = torch.stack(
            [S, torch.tensor([[1.0, 0.5], [0.5, 1]], device=S.device)])
        S_chol = torch.linalg.cholesky(S)
        dist_log_prob = corr_dist.log_prob(S_chol) + sd_prior.log_prob(
            torch.diagonal(S, dim1=-2, dim2=-1))
        self.assertTrue(approx_equal(prior.log_prob(S), dist_log_prob))
Example #14
    def test_sample(self):
        a = torch.as_tensor(0.0)
        b = torch.as_tensor(1.0)
        sigma = 0.01

        gauss_max = 1 / (math.sqrt(2 * math.pi) * sigma)
        ratio_gaussian_mass = 1 / (gauss_max * (b - a) + 1)

        prior = SmoothedBoxPrior(a, b, sigma)

        n_samples = 50000
        samples = prior.sample((n_samples, ))

        gaussian_idx = (samples < a) | (samples > b)
        gaussian_samples = samples[gaussian_idx]
        n_gaussian = gaussian_samples.shape[0]

        self.assertTrue(
            torch.all(
                approx_equal(torch.as_tensor(n_gaussian / n_samples),
                             ratio_gaussian_mass,
                             epsilon=0.005)))
Example #15
    def test_add_diag(self):
        diag = torch.tensor(1.5)
        res = ZeroLazyTensor(5, 5).add_diag(diag).evaluate()
        actual = torch.eye(5).mul(1.5)
        self.assertTrue(approx_equal(res, actual))

        diag = torch.tensor([1.5])
        res = ZeroLazyTensor(5, 5).add_diag(diag).evaluate()
        actual = torch.eye(5).mul(1.5)
        self.assertTrue(approx_equal(res, actual))

        diag = torch.tensor([1.5, 1.3, 1.2, 1.1, 2.0])
        res = ZeroLazyTensor(5, 5).add_diag(diag).evaluate()
        actual = diag.diag()
        self.assertTrue(approx_equal(res, actual))

        diag = torch.tensor(1.5)
        res = ZeroLazyTensor(2, 5, 5).add_diag(diag).evaluate()
        actual = torch.eye(5).unsqueeze(0).repeat(2, 1, 1).mul(1.5)
        self.assertTrue(approx_equal(res, actual))

        diag = torch.tensor([1.5])
        res = ZeroLazyTensor(2, 5, 5).add_diag(diag).evaluate()
        actual = torch.eye(5).unsqueeze(0).repeat(2, 1, 1).mul(1.5)
        self.assertTrue(approx_equal(res, actual))

        diag = torch.tensor([1.5, 1.3, 1.2, 1.1, 2.0])
        res = ZeroLazyTensor(2, 5, 5).add_diag(diag).evaluate()
        actual = diag.diag().unsqueeze(0).repeat(2, 1, 1)
        self.assertTrue(approx_equal(res, actual))

        diag = torch.tensor([[1.5, 1.3, 1.2, 1.1, 2.0], [0, 1, 2, 1, 1]])
        res = ZeroLazyTensor(2, 5, 5).add_diag(diag).evaluate()
        actual = torch.cat(
            [diag[0].diag().unsqueeze(0), diag[1].diag().unsqueeze(0)])
        self.assertTrue(approx_equal(res, actual))
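Example #15 walks through the broadcasting rules for add_diag on a batched ZeroLazyTensor: a scalar or length-1 diagonal is shared everywhere, a length-n diagonal is shared across the batch, and a (batch, n) diagonal is applied per batch entry. A rough dense equivalent in plain PyTorch, using torch.diag_embed (illustrative only, not GPyTorch code):

import torch

batch, n = 2, 5
diag = torch.tensor([1.5, 1.3, 1.2, 1.1, 2.0])
# a single length-n diagonal broadcast across the batch
res = torch.zeros(batch, n, n) + torch.diag_embed(diag.expand(batch, n))
assert torch.equal(res[0], diag.diag()) and torch.equal(res[1], diag.diag())

# a (batch, n) diagonal gives a different diagonal matrix per batch entry
diag_batch = torch.tensor([[1.5, 1.3, 1.2, 1.1, 2.0], [0., 1., 2., 1., 1.]])
res = torch.zeros(batch, n, n) + torch.diag_embed(diag_batch)
assert torch.equal(res[1], diag_batch[1].diag())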
Example #16
    def test_left_interp_on_a_vector(self):
        vector = torch.randn(6)

        res = left_interp(self.interp_indices, self.interp_values, vector)
        actual = torch.matmul(self.interp_matrix, vector)
        self.assertTrue(approx_equal(res, actual))
Example #17
    def test_multidim_interpolation(self):
        x = torch.tensor([[0.25, 0.45, 0.65, 0.85], [0.35, 0.375, 0.4, 0.425],
                          [0.45, 0.5, 0.55, 0.6]]).t().contiguous()
        grid = torch.linspace(0.0, 1.0, 11).unsqueeze(1).repeat(1, 3)

        indices, values = Interpolation().interpolate(grid, x)

        actual_indices = torch.cat(
            [
                torch.tensor(
                    [
                        [
                            146, 147, 148, 149, 157, 158, 159, 160, 168, 169,
                            170, 171, 179
                        ],
                        [
                            389, 390, 391, 392, 400, 401, 402, 403, 411, 412,
                            413, 414, 422
                        ],
                        [
                            642, 643, 644, 645, 653, 654, 655, 656, 664, 665,
                            666, 667, 675
                        ],
                        [
                            885, 886, 887, 888, 896, 897, 898, 899, 907, 908,
                            909, 910, 918
                        ],
                    ],
                    dtype=torch.long,
                ),
                torch.tensor(
                    [
                        [
                            180, 181, 182, 267, 268, 269, 270, 278, 279, 280,
                            281, 289, 290
                        ],
                        [
                            423, 424, 425, 510, 511, 512, 513, 521, 522, 523,
                            524, 532, 533
                        ],
                        [
                            676, 677, 678, 763, 764, 765, 766, 774, 775, 776,
                            777, 785, 786
                        ],
                        [
                            919, 920, 921, 1006, 1007, 1008, 1009, 1017, 1018,
                            1019, 1020, 1028, 1029
                        ],
                    ],
                    dtype=torch.long,
                ),
                torch.tensor(
                    [
                        [
                            291, 292, 300, 301, 302, 303, 388, 389, 390, 391,
                            399, 400, 401
                        ],
                        [
                            534, 535, 543, 544, 545, 546, 631, 632, 633, 634,
                            642, 643, 644
                        ],
                        [
                            787, 788, 796, 797, 798, 799, 884, 885, 886, 887,
                            895, 896, 897
                        ],
                        [
                            1030, 1031, 1039, 1040, 1041, 1042, 1127, 1128,
                            1129, 1130, 1138, 1139, 1140
                        ],
                    ],
                    dtype=torch.long,
                ),
                torch.tensor(
                    [
                        [
                            402, 410, 411, 412, 413, 421, 422, 423, 424, 509,
                            510, 511, 512
                        ],
                        [
                            645, 653, 654, 655, 656, 664, 665, 666, 667, 752,
                            753, 754, 755
                        ],
                        [
                            898, 906, 907, 908, 909, 917, 918, 919, 920, 1005,
                            1006, 1007, 1008
                        ],
                        [
                            1141, 1149, 1150, 1151, 1152, 1160, 1161, 1162,
                            1163, 1248, 1249, 1250, 1251
                        ],
                    ],
                    dtype=torch.long,
                ),
                torch.tensor(
                    [
                        [
                            520, 521, 522, 523, 531, 532, 533, 534, 542, 543,
                            544, 545
                        ],
                        [
                            763, 764, 765, 766, 774, 775, 776, 777, 785, 786,
                            787, 788
                        ],
                        [
                            1016, 1017, 1018, 1019, 1027, 1028, 1029, 1030,
                            1038, 1039, 1040, 1041
                        ],
                        [
                            1259, 1260, 1261, 1262, 1270, 1271, 1272, 1273,
                            1281, 1282, 1283, 1284
                        ],
                    ],
                    dtype=torch.long,
                ),
            ],
            1,
        )
        self.assertTrue(approx_equal(indices, actual_indices))

        actual_values = torch.cat(
            [
                torch.tensor([
                    [
                        -0.0002, 0.0022, 0.0022, -0.0002, 0.0022, -0.0198,
                        -0.0198, 0.0022, 0.0022, -0.0198
                    ],
                    [
                        0.0000, 0.0015, 0.0000, 0.0000, -0.0000, -0.0142,
                        -0.0000, -0.0000, -0.0000, -0.0542
                    ],
                    [
                        0.0000, -0.0000, -0.0000, 0.0000, 0.0039, -0.0352,
                        -0.0352, 0.0039, 0.0000, -0.0000
                    ],
                    [
                        0.0000, 0.0044, 0.0000, 0.0000, -0.0000, -0.0542,
                        -0.0000, -0.0000, -0.0000, -0.0142
                    ],
                ]),
                torch.tensor([
                    [
                        -0.0198, 0.0022, -0.0002, 0.0022, 0.0022, -0.0002,
                        0.0022, -0.0198, -0.0198, 0.0022
                    ],
                    [
                        -0.0000, -0.0000, 0.0000, 0.0044, 0.0000, 0.0000,
                        -0.0000, -0.0132, -0.0000, -0.0000
                    ],
                    [
                        -0.0000, 0.0000, 0.0000, -0.0000, -0.0000, 0.0000,
                        -0.0000, 0.0000, 0.0000, -0.0000
                    ],
                    [
                        -0.0000, -0.0000, 0.0000, 0.0015, 0.0000, 0.0000,
                        -0.0000, -0.0396, -0.0000, -0.0000
                    ],
                ]),
                torch.tensor([
                    [
                        -0.0198, 0.1780, 0.1780, -0.0198, -0.0198, 0.1780,
                        0.1780, -0.0198, 0.0022, -0.0198
                    ],
                    [
                        0.0000, 0.1274, 0.0000, 0.0000, 0.0000, 0.4878, 0.0000,
                        0.0000, -0.0000, -0.0396
                    ],
                    [
                        -0.0352, 0.3164, 0.3164, -0.0352, -0.0000, 0.0000,
                        0.0000, -0.0000, -0.0000, 0.0000
                    ],
                    [
                        0.0000, 0.4878, 0.0000, 0.0000, 0.0000, 0.1274, 0.0000,
                        0.0000, -0.0000, -0.0132
                    ],
                ]),
                torch.tensor([
                    [
                        -0.0198, 0.0022, 0.0022, -0.0198, -0.0198, 0.0022,
                        -0.0198, 0.1780, 0.1780, -0.0198
                    ],
                    [
                        -0.0000, -0.0000, -0.0000, -0.0132, -0.0000, -0.0000,
                        0.0000, 0.1274, 0.0000, 0.0000
                    ],
                    [
                        0.0000, -0.0000, -0.0000, 0.0000, 0.0000, -0.0000,
                        -0.0352, 0.3164, 0.3164, -0.0352
                    ],
                    [
                        -0.0000, -0.0000, -0.0000, -0.0396, -0.0000, -0.0000,
                        0.0000, 0.4878, 0.0000, 0.0000
                    ],
                ]),
                torch.tensor([
                    [
                        -0.0198, 0.1780, 0.1780, -0.0198, 0.0022, -0.0198,
                        -0.0198, 0.0022, -0.0002, 0.0022
                    ],
                    [
                        0.0000, 0.4878, 0.0000, 0.0000, -0.0000, -0.0396,
                        -0.0000, -0.0000, 0.0000, 0.0015
                    ],
                    [
                        -0.0000, 0.0000, 0.0000, -0.0000, -0.0000, 0.0000,
                        0.0000, -0.0000, 0.0000, -0.0000
                    ],
                    [
                        0.0000, 0.1274, 0.0000, 0.0000, -0.0000, -0.0132,
                        -0.0000, -0.0000, 0.0000, 0.0044
                    ],
                ]),
                torch.tensor([
                    [
                        0.0022, -0.0002, 0.0022, -0.0198, -0.0198, 0.0022,
                        0.0022, -0.0198, -0.0198, 0.0022
                    ],
                    [
                        0.0000, 0.0000, -0.0000, -0.0142, -0.0000, -0.0000,
                        -0.0000, -0.0542, -0.0000, -0.0000
                    ],
                    [
                        -0.0000, 0.0000, 0.0039, -0.0352, -0.0352, 0.0039,
                        0.0000, -0.0000, -0.0000, 0.0000
                    ],
                    [
                        0.0000, 0.0000, -0.0000, -0.0542, -0.0000, -0.0000,
                        -0.0000, -0.0142, -0.0000, -0.0000
                    ],
                ]),
                torch.tensor([
                    [-0.0002, 0.0022, 0.0022, -0.0002],
                    [0.0000, 0.0044, 0.0000, 0.0000],
                    [0.0000, -0.0000, -0.0000, 0.0000],
                    [0.0000, 0.0015, 0.0000, 0.0000],
                ]),
            ],
            1,
        )
        self.assertTrue(approx_equal(values, actual_values))
Example #18
    def test_solve(self):
        size = 100
        train_x = torch.cat(
            [
                torch.linspace(0, 1, size).unsqueeze(0),
                torch.linspace(0, 0.5, size).unsqueeze(0),
                torch.linspace(0, 0.25, size).unsqueeze(0),
                torch.linspace(0, 1.25, size).unsqueeze(0),
                torch.linspace(0, 1.5, size).unsqueeze(0),
                torch.linspace(0, 1, size).unsqueeze(0),
                torch.linspace(0, 0.5, size).unsqueeze(0),
                torch.linspace(0, 0.25, size).unsqueeze(0),
                torch.linspace(0, 1.25, size).unsqueeze(0),
                torch.linspace(0, 1.25, size).unsqueeze(0),
                torch.linspace(0, 1.5, size).unsqueeze(0),
                torch.linspace(0, 1, size).unsqueeze(0),
            ],
            0,
        ).unsqueeze(-1)
        covar_matrix = RBFKernel()(train_x, train_x).evaluate().view(
            2, 2, 3, size, size)
        piv_chol = pivoted_cholesky.pivoted_cholesky(covar_matrix, 10)
        woodbury_factor, inv_scale, logdet = woodbury.woodbury_factor(
            piv_chol, piv_chol, torch.ones(2, 2, 3, 100), logdet=True)
        actual_logdet = torch.stack([
            mat.logdet() for mat in (piv_chol @ piv_chol.transpose(-1, -2) +
                                     torch.eye(100)).view(-1, 100, 100)
        ], 0).view(2, 2, 3)
        self.assertTrue(approx_equal(logdet, actual_logdet, 2e-4))

        rhs_vector = torch.randn(2, 2, 3, 100, 5)
        shifted_covar_matrix = covar_matrix + torch.eye(size)
        real_solve = torch.cat(
            [
                shifted_covar_matrix[0, 0, 0].inverse().matmul(
                    rhs_vector[0, 0, 0]).unsqueeze(0),
                shifted_covar_matrix[0, 0, 1].inverse().matmul(
                    rhs_vector[0, 0, 1]).unsqueeze(0),
                shifted_covar_matrix[0, 0, 2].inverse().matmul(
                    rhs_vector[0, 0, 2]).unsqueeze(0),
                shifted_covar_matrix[0, 1, 0].inverse().matmul(
                    rhs_vector[0, 1, 0]).unsqueeze(0),
                shifted_covar_matrix[0, 1, 1].inverse().matmul(
                    rhs_vector[0, 1, 1]).unsqueeze(0),
                shifted_covar_matrix[0, 1, 2].inverse().matmul(
                    rhs_vector[0, 1, 2]).unsqueeze(0),
                shifted_covar_matrix[1, 0, 0].inverse().matmul(
                    rhs_vector[1, 0, 0]).unsqueeze(0),
                shifted_covar_matrix[1, 0, 1].inverse().matmul(
                    rhs_vector[1, 0, 1]).unsqueeze(0),
                shifted_covar_matrix[1, 0, 2].inverse().matmul(
                    rhs_vector[1, 0, 2]).unsqueeze(0),
                shifted_covar_matrix[1, 1, 0].inverse().matmul(
                    rhs_vector[1, 1, 0]).unsqueeze(0),
                shifted_covar_matrix[1, 1, 1].inverse().matmul(
                    rhs_vector[1, 1, 1]).unsqueeze(0),
                shifted_covar_matrix[1, 1, 2].inverse().matmul(
                    rhs_vector[1, 1, 2]).unsqueeze(0),
            ],
            0,
        ).view_as(rhs_vector)
        scaled_inv_diag = (inv_scale / torch.ones(2, 3, 100)).unsqueeze(-1)
        approx_solve = woodbury.woodbury_solve(rhs_vector,
                                               piv_chol * scaled_inv_diag,
                                               woodbury_factor,
                                               scaled_inv_diag, inv_scale)

        self.assertTrue(approx_equal(approx_solve, real_solve, 2e-4))
Example #19
    def test_left_t_interp_on_a_vector(self):
        vector = torch.randn(9)

        res = left_t_interp(self.interp_indices, self.interp_values, vector, 6)
        actual = torch.matmul(self.interp_matrix.transpose(-1, -2), vector)
        self.assertTrue(approx_equal(res, actual))
Example #20
    def test_left_t_interp_on_a_matrix(self):
        matrix = torch.randn(9, 3)

        res = left_t_interp(self.interp_indices, self.interp_values, matrix, 6)
        actual = torch.matmul(self.interp_matrix.transpose(-1, -2), matrix)
        self.assertTrue(approx_equal(res, actual))
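Examples #19 and #20 test left_t_interp, which (under the same sparse-row reading of (indices, values) as in the Example #5 sketch) applies the transpose of the interpolation matrix, i.e. it scatter-adds weighted rows back onto the grid. A plain-PyTorch sketch of that equivalence, with made-up sizes and no GPyTorch calls:

import torch

n_target, n_grid, k = 9, 6, 3
indices = torch.stack([torch.randperm(n_grid)[:k] for _ in range(n_target)])
values = torch.randn(n_target, k)
rhs = torch.randn(n_target)

# dense interpolation matrix W, as in the earlier sketch
W = torch.zeros(n_target, n_grid)
W.scatter_(1, indices, values)
dense = W.t() @ rhs                              # transposed interpolation

# equivalent scatter-add of weighted rows onto the grid
scattered = torch.zeros(n_grid).index_add_(
    0, indices.reshape(-1), (values * rhs.unsqueeze(-1)).reshape(-1))
assert torch.allclose(dense, scattered, atol=1e-6)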