def test_solve(self):
    """Check the Woodbury log-determinant and solve on a batch of two kernels.

    Two RBF covariance matrices are built from inputs on [0, 1] and
    [0, 0.5]. The log-determinant returned by ``woodbury_factor`` is
    compared with a dense ``logdet`` of ``L @ L^T + I`` (``L`` being the
    pivoted Cholesky factor), and ``woodbury_solve`` is compared with an
    explicit inverse of ``K + I`` applied to random right-hand sides.
    """
    size = 100
    num_batch = 2

    # Stack the two 1-D grids into a batch and add a trailing feature dim.
    grids = [torch.linspace(0, 1, size), torch.linspace(0, 0.5, size)]
    train_x = torch.stack(grids, 0).unsqueeze(-1)

    covar_matrix = RBFKernel()(train_x, train_x).evaluate()
    # NOTE(review): the second argument (10) is presumably the maximum
    # rank of the pivoted Cholesky factor; confirm against its signature.
    piv_chol = pivoted_cholesky.pivoted_cholesky(covar_matrix, 10)

    woodbury_factor, inv_scale, logdet = woodbury.woodbury_factor(
        piv_chol,
        piv_chol,
        torch.ones(num_batch, size),
        logdet=True,
    )

    # Dense reference: log-determinant of each (L L^T + I) in the batch.
    low_rank_plus_eye = piv_chol @ piv_chol.transpose(-1, -2) + torch.eye(size)
    per_matrix_logdets = [
        mat.logdet() for mat in low_rank_plus_eye.view(-1, size, size)
    ]
    actual_logdet = torch.stack(per_matrix_logdets, 0).view(num_batch)
    self.assertTrue(approx_equal(logdet, actual_logdet, 2e-4))

    rhs_vector = torch.randn(num_batch, size, 5)
    shifted_covar_matrix = covar_matrix + torch.eye(size)

    # Dense reference solve, done one batch element at a time.
    real_solve = torch.stack(
        [
            shifted_covar_matrix[idx].inverse().matmul(rhs_vector[idx])
            for idx in range(num_batch)
        ],
        0,
    )

    scaled_inv_diag = (inv_scale / torch.ones(num_batch, size)).unsqueeze(-1)
    scaled_piv_chol = piv_chol * scaled_inv_diag
    approx_solve = woodbury.woodbury_solve(
        rhs_vector,
        scaled_piv_chol,
        woodbury_factor,
        scaled_inv_diag,
        inv_scale,
    )
    self.assertTrue(approx_equal(approx_solve, real_solve, 2e-4))
def test_solve(self):
    """Check the Woodbury log-determinant and solve on a single kernel.

    An RBF covariance matrix is built from ``size`` points on [0, 1].
    The log-determinant returned by ``woodbury_factor`` is compared with
    a dense ``logdet`` of ``L @ L^T + I`` (``L`` being the pivoted
    Cholesky factor), and ``woodbury_solve`` is compared with an explicit
    inverse of ``K + I`` applied to a random right-hand side.
    """
    size = 100
    train_x = torch.linspace(0, 1, size)

    covar_matrix = RBFKernel()(train_x, train_x).evaluate()
    # NOTE(review): the second argument (10) is presumably the maximum
    # rank of the pivoted Cholesky factor; confirm against its signature.
    piv_chol = pivoted_cholesky.pivoted_cholesky(covar_matrix, 10)

    woodbury_factor, inv_scale, logdet = woodbury.woodbury_factor(
        piv_chol,
        piv_chol,
        torch.ones(size),
        logdet=True,
    )

    # Dense reference: log-determinant of (L L^T + I).
    low_rank_plus_eye = piv_chol @ piv_chol.transpose(-1, -2) + torch.eye(size)
    expected_logdet = low_rank_plus_eye.logdet()
    self.assertTrue(approx_equal(logdet, expected_logdet, 2e-4))

    rhs_vector = torch.randn(size, 50)

    # Dense reference solve against the identity-shifted covariance.
    shifted_covar_matrix = covar_matrix + torch.eye(size)
    shifted_inverse = shifted_covar_matrix.inverse()
    real_solve = shifted_inverse.matmul(rhs_vector)

    scaled_inv_diag = (inv_scale / torch.ones(size)).unsqueeze(-1)
    scaled_piv_chol = piv_chol * scaled_inv_diag
    approx_solve = woodbury.woodbury_solve(
        rhs_vector,
        scaled_piv_chol,
        woodbury_factor,
        scaled_inv_diag,
        inv_scale,
    )
    self.assertTrue(approx_equal(approx_solve, real_solve, 2e-4))