def test_matmul(self):
    """Check MatmulLazyVariable.matmul against a dense reference (legacy Variable API).

    Builds a lazy (5x3)@(3x4) product, multiplies it by a (4x10) matrix, and
    verifies both the forward values and the gradients w.r.t. the two factors
    match an ordinary dense ``matmul`` chain on cloned leaves.
    """
    lhs = Variable(torch.randn(5, 3), requires_grad=True)
    rhs = Variable(torch.randn(3, 4), requires_grad=True)
    covar = MatmulLazyVariable(lhs, rhs)
    mat = Variable(torch.randn(4, 10))
    res = covar.matmul(mat)

    # Dense reference path on independent leaf Variables with identical data.
    lhs_dense = Variable(lhs.data.clone(), requires_grad=True)
    rhs_dense = Variable(rhs.data.clone(), requires_grad=True)
    mat_dense = Variable(mat.data.clone())
    actual = lhs_dense.matmul(rhs_dense).matmul(mat_dense)

    # Forward values agree (approx_equal presumably allows float tolerance — project helper).
    self.assertTrue(approx_equal(res.data, actual.data))

    # Backprop a scalar through both paths; factor gradients must agree.
    actual.sum().backward()
    res.sum().backward()
    self.assertTrue(approx_equal(lhs.grad.data, lhs_dense.grad.data))
    self.assertTrue(approx_equal(rhs.grad.data, rhs_dense.grad.data))
def test_matmul(self):
    """Check MatmulLazyVariable.matmul against a dense reference.

    Builds a lazy (5x3)@(3x4) product, multiplies it by a (4x10) matrix, and
    verifies both the forward values and the gradients w.r.t. the two factors
    match an ordinary dense ``matmul`` chain on cloned leaf tensors.
    """
    lhs = torch.randn(5, 3, requires_grad=True)
    rhs = torch.randn(3, 4, requires_grad=True)
    covar = MatmulLazyVariable(lhs, rhs)
    mat = torch.randn(4, 10)
    res = covar.matmul(mat)

    # Dense reference path: detach-then-clone (the recommended order) yields
    # independent leaves; requires_grad_ makes them grad-tracking leaves.
    # The original used clone().detach(), which needlessly records the clone
    # in the autograd graph before discarding it.
    lhs_clone = lhs.detach().clone().requires_grad_(True)
    rhs_clone = rhs.detach().clone().requires_grad_(True)
    # `mat` carries no grad in the lazy path, so its clone shouldn't either.
    # (The original inconsistently set mat_clone.requires_grad = True even
    # though mat_clone.grad is never checked.)
    mat_clone = mat.detach().clone()
    actual = lhs_clone.matmul(rhs_clone).matmul(mat_clone)

    # Forward values agree (approx_equal presumably allows float tolerance — project helper).
    self.assertTrue(approx_equal(res, actual))

    # Backprop a scalar through both paths; factor gradients must agree.
    actual.sum().backward()
    res.sum().backward()
    self.assertTrue(approx_equal(lhs.grad, lhs_clone.grad))
    self.assertTrue(approx_equal(rhs.grad, rhs_clone.grad))