def test_root_decomposition(self):
    """Check that the root decomposition of ``self.mat`` reconstructs it.

    Forward: ``R @ R^T`` (with ``R`` the evaluated root) must match
    ``self.mat``.
    Backward: back-propagating the summed per-matrix traces through the
    decomposition must give the same gradient on ``self.mat`` as summing the
    traces of ``self.mat_clone`` directly.

    NOTE(review): ``self.mat`` / ``self.mat_clone`` are presumably a batch of
    square matrices and an independent leaf copy prepared in setUp — confirm.
    """
    evaluated_root = NonLazyTensor(self.mat).root_decomposition().root.evaluate()
    reconstructed = evaluated_root.matmul(evaluated_root.transpose(-1, -2))
    self.assertTrue(approx_equal(reconstructed, self.mat))

    # Iterating walks the leading dimension; each element is traced separately.
    decomposed_total = sum(matrix.trace() for matrix in reconstructed)
    decomposed_total.backward()
    reference_total = sum(matrix.trace() for matrix in self.mat_clone)
    reference_total.backward()

    self.assertTrue(approx_equal(self.mat.grad, self.mat_clone.grad))
def test_root_decomposition(self):
    """Check the root decomposition of the single matrix ``self.mat_var``.

    Forward: ``R @ R^T`` must reproduce ``self.mat_var``.
    Backward: the trace gradient through the decomposition must equal the
    trace gradient taken directly on ``self.mat_var_clone``.

    NOTE(review): the result of ``root_decomposition()`` is used directly as a
    tensor here (no ``.root.evaluate()``); presumably this targets an older
    API — confirm.
    """
    decomposed = NonLazyTensor(self.mat_var).root_decomposition()
    product = decomposed.matmul(decomposed.transpose(-1, -2))
    self.assertTrue(approx_equal(product, self.mat_var))

    # Gradients: through the decomposition first, then on the direct clone.
    product.trace().backward()
    self.mat_var_clone.trace().backward()

    grads_match = approx_equal(self.mat_var.grad, self.mat_var_clone.grad)
    self.assertTrue(grads_match)
def test_root_decomposition(self):
    """Check the root decomposition of a freshly created matrix.

    Forward: ``R @ R^T`` (with ``R`` the evaluated root) must reproduce
    ``mat``.
    Backward: summing the traces of every matrix in the flattened batch and
    back-propagating must give the same gradient on ``mat`` as doing the
    same directly on ``mat_clone``, an independent leaf copy.
    """
    mat = self._create_mat().detach().requires_grad_(True)
    mat_clone = mat.detach().clone().requires_grad_(True)
    # Trailing matrix dims; any leading dims are flattened into one axis below.
    n_rows, n_cols = mat.size(-2), mat.size(-1)

    # Forward
    root = NonLazyTensor(mat).root_decomposition().root.evaluate()
    res = root.matmul(root.transpose(-1, -2))
    self.assertAllClose(res, mat)

    # Backward. The comprehension variable is deliberately not called ``mat``:
    # that is the leaf tensor whose ``.grad`` is asserted below, and reusing
    # its name is easy to misread (under Python 2 scoping it would even rebind it).
    sum([block.trace() for block in res.view(-1, n_rows, n_cols)]).backward()
    sum([block.trace() for block in mat_clone.view(-1, n_rows, n_cols)]).backward()
    self.assertAllClose(mat.grad, mat_clone.grad)
def test_root_inv_decomposition(self):
    """Check the inverse root decomposition of the single matrix ``self.mat``.

    Forward: ``R @ R^T`` from ``root_inv_decomposition`` must match the
    explicit inverse of ``self.mat_clone``.
    Backward: trace gradients through both paths must agree.
    """
    # Random vectors handed straight to root_inv_decomposition.
    # NOTE(review): (4, 5) presumably means a 4x4 matrix with 5 vectors — confirm.
    probes, tests = torch.randn(4, 5), torch.randn(4, 5)

    inverse_root = NonLazyTensor(self.mat).root_inv_decomposition(probes, tests).root.evaluate()
    approx_inverse = inverse_root.matmul(inverse_root.transpose(-1, -2))
    exact_inverse = self.mat_clone.inverse()
    self.assertTrue(approx_equal(approx_inverse, exact_inverse))

    approx_inverse.trace().backward()
    exact_inverse.trace().backward()
    self.assertTrue(approx_equal(self.mat.grad, self.mat_clone.grad))
def test_root_inv_decomposition(self):
    """Check the inverse root decomposition on a batch of three matrices.

    Forward: ``R @ R^T`` must match the per-matrix inverses of
    ``self.mat_clone``.
    Backward: gradients of the summed traces must agree on both paths.
    """
    # Random vectors handed straight to root_inv_decomposition (batch of 3).
    probes = torch.randn(3, 4, 5)
    tests = torch.randn(3, 4, 5)

    inverse_root = NonLazyTensor(self.mat).root_inv_decomposition(probes, tests).root.evaluate()
    approx_inverse = inverse_root.matmul(inverse_root.transpose(-1, -2))

    # Invert each matrix along the leading dim and restack along dim 0.
    exact_inverse = torch.stack([m.inverse() for m in self.mat_clone])
    self.assertTrue(approx_equal(approx_inverse, exact_inverse))

    sum(m.trace() for m in approx_inverse).backward()
    sum(m.trace() for m in exact_inverse).backward()
    self.assertTrue(approx_equal(self.mat.grad, self.mat_clone.grad))
def test_root_inv_decomposition(self):
    """Check the inverse root decomposition of a freshly created matrix.

    Forward: ``R @ R^T`` from ``root_inv_decomposition`` must match
    ``mat_clone.inverse()``.
    Backward: summing the traces of every matrix in the flattened batch and
    back-propagating must give the same gradient on ``mat`` as on the
    independent leaf copy ``mat_clone``.
    """
    mat = self._create_mat().detach().requires_grad_(True)
    mat_clone = mat.detach().clone().requires_grad_(True)
    # Trailing matrix dims; any leading dims are flattened into one axis below.
    n_rows, n_cols = mat.size(-2), mat.size(-1)

    # Forward. Probe/test vectors share mat's leading (batch) dims.
    # NOTE(review): trailing (4, 5) is hard-coded; presumably n=4 with 5 vectors — confirm.
    probe_vectors = torch.randn(*mat.shape[:-2], 4, 5)
    test_vectors = torch.randn(*mat.shape[:-2], 4, 5)
    root = NonLazyTensor(mat).root_inv_decomposition(probe_vectors, test_vectors).root.evaluate()
    res = root.matmul(root.transpose(-1, -2))
    actual = mat_clone.inverse()
    self.assertAllClose(res, actual)

    # Backward. The comprehension variable is deliberately not called ``mat``:
    # that is the leaf tensor whose ``.grad`` is asserted below, and reusing
    # its name is easy to misread (under Python 2 scoping it would even rebind it).
    sum([block.trace() for block in res.view(-1, n_rows, n_cols)]).backward()
    sum([block.trace() for block in actual.view(-1, n_rows, n_cols)]).backward()
    self.assertAllClose(mat.grad, mat_clone.grad)
def test_root_inv_decomposition(self):
    """Check the inverse root decomposition on a (2, 3)-batched matrix.

    Forward: ``R @ R^T`` must match the inverse of every matrix in
    ``self.mat_clone``, computed one matrix at a time.
    Backward: gradients of the summed traces must agree on both paths.
    """
    # Random vectors handed straight to root_inv_decomposition (2 x 3 batch).
    probes = torch.randn(2, 3, 4, 5)
    tests = torch.randn(2, 3, 4, 5)

    lazy = NonLazyTensor(self.mat)
    inverse_root = lazy.root_inv_decomposition(probes, tests).root.evaluate()
    approx_inverse = inverse_root.matmul(inverse_root.transpose(-1, -2))

    # Flatten the batch dims, invert each matrix, then restore the batch shape.
    clone_matrix_shape = self.mat_clone.shape[-2:]
    flat_clone = self.mat_clone.view(-1, *clone_matrix_shape)
    inverses = [m.inverse().unsqueeze(0) for m in flat_clone]
    exact_inverse = torch.cat(inverses).view_as(self.mat_clone)
    self.assertTrue(approx_equal(approx_inverse, exact_inverse))

    matrix_shape = self.mat.shape[-2:]
    sum([m.trace() for m in approx_inverse.view(-1, *matrix_shape)]).backward()
    sum([m.trace() for m in exact_inverse.view(-1, *matrix_shape)]).backward()
    self.assertTrue(approx_equal(self.mat.grad, self.mat_clone.grad))