def test_matmul(self):
        """Compare MatmulLazyVariable.matmul with a dense reference (legacy Variable API).

        Builds a lazy product of two random leaf Variables, multiplies it by a
        third random matrix, and checks that:
          * the forward result matches an explicit ``lhs @ rhs @ mat`` computed
            on detached copies of the same data, and
          * the gradients reaching ``lhs`` and ``rhs`` after summing the output
            match the gradients of the dense reference.

        NOTE(review): another method named ``test_matmul`` is defined right
        after this one. If both live in the same class, this definition is
        shadowed and never runs; consider renaming or removing it.
        """
        left = Variable(torch.randn(5, 3), requires_grad=True)
        right = Variable(torch.randn(3, 4), requires_grad=True)
        lazy = MatmulLazyVariable(left, right)
        other = Variable(torch.randn(4, 10))
        lazy_result = lazy.matmul(other)

        # Independent leaf copies so the dense reference builds its own graph
        # and accumulates its own gradients.
        left_ref, right_ref = (
            Variable(src.data.clone(), requires_grad=True) for src in (left, right)
        )
        other_ref = Variable(other.data.clone())
        expected = left_ref.matmul(right_ref).matmul(other_ref)

        self.assertTrue(approx_equal(lazy_result.data, expected.data))

        expected.sum().backward()
        lazy_result.sum().backward()

        for param, ref in ((left, left_ref), (right, right_ref)):
            self.assertTrue(approx_equal(param.grad.data, ref.grad.data))
    def test_matmul(self):
        """Compare MatmulLazyVariable.matmul with a dense reference (tensor API).

        Multiplies a lazy ``lhs @ rhs`` product by a random matrix and checks:
          * the forward result equals an explicit ``lhs @ rhs @ mat`` computed
            on detached leaf copies, and
          * after summing both outputs and back-propagating, the gradients on
            ``lhs`` and ``rhs`` agree with those of the dense reference.
        """
        left = torch.randn(5, 3, requires_grad=True)
        right = torch.randn(3, 4, requires_grad=True)
        lazy = MatmulLazyVariable(left, right)
        other = torch.randn(4, 10)
        lazy_result = lazy.matmul(other)

        # Fresh leaf tensors that share values but not autograd history.
        # NOTE(review): ``other_ref`` also requires grad although ``other``
        # does not; its gradient is never checked.
        left_ref, right_ref, other_ref = (
            src.clone().detach().requires_grad_(True) for src in (left, right, other)
        )
        expected = left_ref.matmul(right_ref).matmul(other_ref)

        self.assertTrue(approx_equal(lazy_result, expected))

        expected.sum().backward()
        lazy_result.sum().backward()

        for param, ref in ((left, left_ref), (right, right_ref)):
            self.assertTrue(approx_equal(param.grad, ref.grad))