def test_function_factory(self):
    """Check ``DiagLazyVariable.matmul`` against a dense reference.

    For a 1-d right-hand side (a vector) and a 2-d right-hand side (a
    matrix), the lazy product ``diag_lv.matmul(rhs)`` is compared with
    ``torch.matmul(DiagLazyVariable(...).evaluate(), rhs)``. Both the forward
    values and the gradients with respect to the diagonal must agree.

    Relies on a module-level ``diag`` tensor defined elsewhere in the file
    (assumed to be 1-d of length 3 to match the 3-element / 3x3 right-hand
    sides -- TODO confirm).
    """
    for label, test_mat in (
        ("1d", torch.Tensor([3, 4, 5])),
        ("2d", torch.eye(3)),
    ):
        # Separate leaf Variables for the lazy and the dense path, so that
        # the gradient each path produces can be compared afterwards.
        diag_var1 = Variable(diag, requires_grad=True)
        diag_var2 = Variable(diag, requires_grad=True)
        diag_lv = DiagLazyVariable(diag_var1)
        diag_ev = DiagLazyVariable(diag_var2).evaluate()

        # Forward
        res = diag_lv.matmul(Variable(test_mat))
        actual = torch.matmul(diag_ev, Variable(test_mat))
        self.assertLess(
            torch.norm(res.data - actual.data),
            1e-4,
            msg=label + " forward",
        )

        # Backward: reducing with sum() gives a scalar to differentiate.
        res.sum().backward()
        actual.sum().backward()
        self.assertLess(
            torch.norm(diag_var1.grad.data - diag_var2.grad.data),
            1e-3,
            msg=label + " backward",
        )
def test_batch_function_factory():
    """Check batched ``DiagLazyVariable.matmul`` against a dense reference.

    ``diag`` is repeated 5 times along a new leading dimension, and each
    batch entry is multiplied by an identity matrix. The lazy product must
    match ``torch.matmul`` of the evaluated dense form, both in forward
    values and in gradients with respect to the batched diagonal.

    Relies on a module-level ``diag`` tensor defined elsewhere in the file
    (assumed to be 1-d of length 3 to match ``torch.eye(3)`` -- TODO confirm).
    """
    # 2d
    # Separate leaf Variables for the lazy and the dense path, so that the
    # gradient each path produces can be compared afterwards.
    diag_var1 = Variable(diag.repeat(5, 1), requires_grad=True)
    diag_var2 = Variable(diag.repeat(5, 1), requires_grad=True)
    test_mat = torch.eye(3).repeat(5, 1, 1)
    diag_lv = DiagLazyVariable(diag_var1)
    diag_ev = DiagLazyVariable(diag_var2).evaluate()

    # Forward
    res = diag_lv.matmul(Variable(test_mat))
    actual = torch.matmul(diag_ev, Variable(test_mat))
    assert torch.norm(res.data - actual.data) < 1e-4

    # Backward: reducing with sum() gives a scalar to differentiate.
    res.sum().backward()
    actual.sum().backward()
    assert torch.norm(diag_var1.grad.data - diag_var2.grad.data) < 1e-3