def test_pivoted_cholesky(self, max_iter=3):
        """Check ``pivoted_cholesky`` against an explicit pivot + full Cholesky.

        A detached clone of the test matrix is permuted with the pivots that
        ``pivoted_cholesky`` returns. The clone is factored with
        ``psd_safe_cholesky`` and truncated to its first ``max_iter`` columns.
        Finally it is permuted back with the inverse pivots. The forward results
        and the gradients w.r.t. the input matrix must agree between the two
        paths.
        """
        source = self._create_mat().detach().requires_grad_(True)
        source.register_hook(_ensure_symmetric_grad)
        reference = source.detach().clone().requires_grad_(True)
        reference.register_hook(_ensure_symmetric_grad)

        # Path under test: low-rank pivoted Cholesky that also reports pivots.
        factor, pivots = pivoted_cholesky(source, rank=max_iter, return_pivots=True)

        # Reference path: pivot, factor fully, keep max_iter columns, un-pivot.
        undo_pivots = inverse_permutation(pivots)
        permuted = apply_permutation(reference, pivots, pivots)
        truncated = psd_safe_cholesky(permuted)[..., :max_iter]
        expected = apply_permutation(
            truncated,
            left_permutation=undo_pivots,
        )

        self.assertAllClose(factor, expected)

        # Feed the same upstream gradient through both graphs and compare.
        upstream = torch.randn_like(factor)
        factor.backward(gradient=upstream)
        expected.backward(gradient=upstream)
        self.assertAllClose(source.grad, reference.grad)
Ejemplo n.º 2
0
 def test_apply_permutation_right_only(self):
     """Apply only a right (column) permutation to the batched test matrix.

     The single 1-D permutation ``[1, 0]`` is expected to swap the two
     columns of every matrix in the batch.
     """
     source = self._gen_test_psd()
     swap_columns = torch.tensor([1, 0])
     expected = torch.tensor([
         [[-0.75, 0.25], [2.25, -0.75]],
         [[1.2, 1.0], [0.5, 1.2]],
     ])
     result = apply_permutation(source, right_permutation=swap_columns)
     self.assertAllClose(result, expected)
Ejemplo n.º 3
0
 def test_apply_permutation_left_only(self):
     """Apply only a left (row) permutation, one permutation per batch entry.

     Batch entry 0 uses the identity ``[0, 1]``. Batch entry 1 uses
     ``[1, 0]``, so only the second matrix should have its rows swapped.
     """
     source = self._gen_test_psd()
     per_batch_rows = torch.tensor([[0, 1], [1, 0]])
     expected = torch.tensor([
         [[0.25, -0.75], [-0.75, 2.25]],
         [[1.2, 0.5], [1.0, 1.2]],
     ])
     result = apply_permutation(source, left_permutation=per_batch_rows)
     self.assertAllClose(result, expected)
Ejemplo n.º 4
0
 def test_apply_permutation_left_and_right(self):
     """Apply per-batch row permutations and a shared column permutation.

     The rows are permuted by ``[[0, 1], [1, 0]]`` (identity for entry 0, swap
     for entry 1). The columns are swapped by ``[1, 0]`` in every entry.
     """
     source = self._gen_test_psd()
     rows = torch.tensor([[0, 1], [1, 0]])
     cols = torch.tensor([1, 0])
     expected = torch.tensor([
         [[-0.75, 0.25], [2.25, -0.75]],
         [[0.5, 1.2], [1.2, 1.0]],
     ])
     result = apply_permutation(source, rows, cols)
     self.assertAllClose(result, expected)