示例#1
0
    def test_solve(self):
        """Check batched Woodbury log-determinant and solve against dense results.

        Two batches of RBF covariance matrices are approximated with a rank-10
        pivoted Cholesky factor. ``woodbury.woodbury_factor`` is asked for the
        log-determinant of ``piv_chol @ piv_chol^T + I``, which is compared with
        a per-matrix dense ``logdet``. ``woodbury.woodbury_solve`` is then
        compared with an explicit dense inverse of ``covar_matrix + I``.

        NOTE(review): another ``test_solve`` follows at the same indentation in
        this file. If both live in the same class, the later definition rebinds
        the name and this batched test never runs. Consider renaming this one
        (e.g. ``test_batch_solve``); the name is left unchanged here.
        """
        size = 100
        # Number of independent problems; tied to the two linspace rows below.
        batch_size = 2
        train_x = torch.cat([
            torch.linspace(0, 1, size).unsqueeze(0),
            torch.linspace(0, 0.5, size).unsqueeze(0),
        ], 0).unsqueeze(-1)
        covar_matrix = RBFKernel()(train_x, train_x).evaluate()
        piv_chol = pivoted_cholesky.pivoted_cholesky(covar_matrix, 10)
        woodbury_factor, inv_scale, logdet = woodbury.woodbury_factor(
            piv_chol, piv_chol, torch.ones(batch_size, size), logdet=True)

        # Dense reference: log-determinant of the low-rank-plus-identity matrix,
        # computed one (size x size) matrix at a time.
        low_rank_plus_identity = (
            piv_chol @ piv_chol.transpose(-1, -2) + torch.eye(size)
        )
        actual_logdet = torch.stack([
            mat.logdet()
            for mat in low_rank_plus_identity.view(-1, size, size)
        ], 0).view(batch_size)
        self.assertTrue(approx_equal(logdet, actual_logdet, 2e-4))

        rhs_vector = torch.randn(batch_size, size, 5)
        # The reference uses the exact covariance (not the rank-10 factor), so
        # this check also relies on the pivoted Cholesky approximating
        # covar_matrix closely enough for the 2e-4 tolerance.
        shifted_covar_matrix = covar_matrix + torch.eye(size)
        real_solve = torch.stack([
            mat.inverse().matmul(rhs)
            for mat, rhs in zip(shifted_covar_matrix, rhs_vector)
        ], 0)
        scaled_inv_diag = (inv_scale / torch.ones(batch_size, size)).unsqueeze(-1)
        approx_solve = woodbury.woodbury_solve(rhs_vector,
                                               piv_chol * scaled_inv_diag,
                                               woodbury_factor,
                                               scaled_inv_diag, inv_scale)

        self.assertTrue(approx_equal(approx_solve, real_solve, 2e-4))

    def test_solve(self):
        """Check the unbatched Woodbury log-determinant and solve against dense results.

        A single RBF covariance matrix is approximated with a rank-10 pivoted
        Cholesky factor. ``woodbury.woodbury_factor`` is asked for the
        log-determinant of ``piv_chol @ piv_chol^T + I``, which is compared with
        a dense ``logdet``. ``woodbury.woodbury_solve`` is then compared with an
        explicit dense inverse of ``covar_matrix + I``.
        """
        size = 100
        train_x = torch.linspace(0, 1, size)
        covar_matrix = RBFKernel()(train_x, train_x).evaluate()
        piv_chol = pivoted_cholesky.pivoted_cholesky(covar_matrix, 10)
        woodbury_factor, inv_scale, logdet = woodbury.woodbury_factor(
            piv_chol, piv_chol, torch.ones(size), logdet=True
        )
        # Dense reference log-determinant of the low-rank-plus-identity matrix.
        actual_logdet = (
            piv_chol @ piv_chol.transpose(-1, -2) + torch.eye(size)
        ).logdet()
        self.assertTrue(approx_equal(logdet, actual_logdet, 2e-4))

        rhs_vector = torch.randn(size, 50)
        # The reference uses the exact covariance (not the rank-10 factor), so
        # this check also relies on the pivoted Cholesky approximating
        # covar_matrix closely enough for the 2e-4 tolerance.
        shifted_covar_matrix = covar_matrix + torch.eye(size)
        real_solve = shifted_covar_matrix.inverse().matmul(rhs_vector)
        scaled_inv_diag = (inv_scale / torch.ones(size)).unsqueeze(-1)
        approx_solve = woodbury.woodbury_solve(
            rhs_vector, piv_chol * scaled_inv_diag, woodbury_factor, scaled_inv_diag, inv_scale
        )

        self.assertTrue(approx_equal(approx_solve, real_solve, 2e-4))