Example 1
0
    def test_root_decomposition(self):
        """Check that the root decomposition reconstructs the batch of
        matrices (forward) and produces the same gradients as the dense
        clone (backward)."""
        # Forward: L @ L^T should reproduce the original matrices.
        decomposed = NonLazyTensor(self.mat).root_decomposition()
        root = decomposed.root.evaluate()
        reconstruction = root.matmul(root.transpose(-1, -2))
        self.assertTrue(approx_equal(reconstruction, self.mat))

        # Backward: differentiate the sum of per-matrix traces on both paths.
        sum(m.trace() for m in reconstruction).backward()
        sum(m.trace() for m in self.mat_clone).backward()
        self.assertTrue(approx_equal(self.mat.grad, self.mat_clone.grad))
Example 2
0
    def test_root_decomposition(self):
        """Verify root decomposition of ``self.mat_var``: forward
        reconstruction and gradient agreement with the clone."""
        # Forward
        # NOTE(review): unlike the sibling tests, this uses the
        # root_decomposition() result directly (no .root.evaluate()) —
        # presumably an older API variant; confirm intended.
        decomposed = NonLazyTensor(self.mat_var).root_decomposition()
        reconstruction = decomposed.matmul(decomposed.transpose(-1, -2))
        self.assertTrue(approx_equal(reconstruction, self.mat_var))

        # Backward: trace is a scalar, so .backward() needs no argument.
        reconstruction.trace().backward()
        self.mat_var_clone.trace().backward()
        self.assertTrue(
            approx_equal(self.mat_var.grad, self.mat_var_clone.grad))
Example 3
0
    def test_root_decomposition(self):
        """Root decomposition on a freshly created matrix: reconstruction
        and gradient agreement with an independent clone."""
        mat = self._create_mat().detach().requires_grad_(True)
        mat_clone = mat.detach().clone().requires_grad_(True)

        # Forward: the evaluated root times its transpose should equal mat.
        root = NonLazyTensor(mat).root_decomposition().root.evaluate()
        reconstruction = root.matmul(root.transpose(-1, -2))
        self.assertAllClose(reconstruction, mat)

        # Backward: flatten any batch dims, sum per-matrix traces.
        n_rows, n_cols = mat.size(-2), mat.size(-1)
        sum(m.trace() for m in reconstruction.view(-1, n_rows, n_cols)).backward()
        sum(m.trace() for m in mat_clone.view(-1, n_rows, n_cols)).backward()
        self.assertAllClose(mat.grad, mat_clone.grad)
Example 4
0
    def test_root_inv_decomposition(self):
        """Inverse root decomposition: L @ L^T should approximate the
        inverse of ``self.mat``, with matching gradients on the clone."""
        # Forward
        probes = torch.randn(4, 5)
        solves = torch.randn(4, 5)
        inv_root = NonLazyTensor(self.mat).root_inv_decomposition(
            probes, solves).root.evaluate()
        reconstruction = inv_root.matmul(inv_root.transpose(-1, -2))
        expected = self.mat_clone.inverse()
        self.assertTrue(approx_equal(reconstruction, expected))

        # Backward: trace is a scalar loss on each path.
        reconstruction.trace().backward()
        expected.trace().backward()
        self.assertTrue(approx_equal(self.mat.grad, self.mat_clone.grad))
Example 5
0
    def test_root_inv_decomposition(self):
        """Batched inverse root decomposition: per-matrix reconstruction
        of the inverse, plus gradient agreement with the clone."""
        # Forward
        probes = torch.randn(3, 4, 5)
        solves = torch.randn(3, 4, 5)
        inv_root = NonLazyTensor(self.mat).root_inv_decomposition(
            probes, solves).root.evaluate()
        reconstruction = inv_root.matmul(inv_root.transpose(-1, -2))
        # Invert each matrix in the batch; stack == cat of unsqueeze(0).
        expected = torch.stack([m.inverse() for m in self.mat_clone])
        self.assertTrue(approx_equal(reconstruction, expected))

        # Backward: sum of per-matrix traces as the scalar loss.
        sum(m.trace() for m in reconstruction).backward()
        sum(m.trace() for m in expected).backward()
        self.assertTrue(approx_equal(self.mat.grad, self.mat_clone.grad))
Example 6
0
    def test_root_inv_decomposition(self):
        """Inverse root decomposition on a freshly created matrix:
        reconstruction of the inverse and matching gradients."""
        mat = self._create_mat().detach().requires_grad_(True)
        mat_clone = mat.detach().clone().requires_grad_(True)

        # Forward: probe/test vectors share mat's batch shape.
        probe_vectors = torch.randn(*mat.shape[:-2], 4, 5)
        test_vectors = torch.randn(*mat.shape[:-2], 4, 5)
        decomposition = NonLazyTensor(mat).root_inv_decomposition(probe_vectors, test_vectors)
        inv_root = decomposition.root.evaluate()
        reconstruction = inv_root.matmul(inv_root.transpose(-1, -2))
        expected = mat_clone.inverse()
        self.assertAllClose(reconstruction, expected)

        # Backward: flatten batch dims and sum per-matrix traces.
        n_rows, n_cols = mat.size(-2), mat.size(-1)
        sum(m.trace() for m in reconstruction.view(-1, n_rows, n_cols)).backward()
        sum(m.trace() for m in expected.view(-1, n_rows, n_cols)).backward()
        self.assertAllClose(mat.grad, mat_clone.grad)
Example 7
0
    def test_root_inv_decomposition(self):
        """Multi-batch (2x3) inverse root decomposition: reconstruction of
        the per-matrix inverses and gradient agreement with the clone."""
        # Forward
        probes = torch.randn(2, 3, 4, 5)
        solves = torch.randn(2, 3, 4, 5)
        inv_root = NonLazyTensor(self.mat).root_inv_decomposition(
            probes, solves).root.evaluate()
        reconstruction = inv_root.matmul(inv_root.transpose(-1, -2))
        # Invert each matrix in the flattened batch, then restore the
        # original batch shape; stack == cat of unsqueeze(0).
        flattened = self.mat_clone.view(-1, *self.mat_clone.shape[-2:])
        expected = torch.stack(
            [m.inverse() for m in flattened]).view_as(self.mat_clone)
        self.assertTrue(approx_equal(reconstruction, expected))

        # Backward: flatten batch dims and sum per-matrix traces.
        matrix_shape = self.mat.shape[-2:]
        sum(m.trace() for m in reconstruction.view(-1, *matrix_shape)).backward()
        sum(m.trace() for m in expected.view(-1, *matrix_shape)).backward()
        self.assertTrue(approx_equal(self.mat.grad, self.mat_clone.grad))