def test_matmul(self):
    """Check RootLazyVariable.matmul against an explicit dense R @ R^T @ M.

    Both the forward product and the gradient that flows back into the root
    factor must match the dense reference computation.
    """
    # Lazy path: the root factor R whose gradient is checked at the end.
    lazy_root = torch.randn(5, 3, requires_grad=True)
    rhs = torch.eye(5)
    lazy_result = RootLazyVariable(lazy_root).matmul(rhs)

    # Dense reference path: independent leaf copies so gradients accumulate
    # separately from the lazy path. The rhs copy also tracks gradients,
    # although only the root gradient is compared below.
    dense_root = lazy_root.detach().clone().requires_grad_(True)
    dense_rhs = rhs.detach().clone().requires_grad_(True)
    dense_result = (dense_root @ dense_root.transpose(-1, -2)) @ dense_rhs
    self.assertTrue(approx_equal(lazy_result, dense_result))

    # Push the same random upstream gradient through both graphs.
    upstream = torch.randn(5, 5)
    dense_result.backward(gradient=upstream)
    lazy_result.backward(gradient=upstream)
    self.assertTrue(approx_equal(lazy_root.grad, dense_root.grad))