def test_clear_grad(self):
    """A cleared gradient must not leak into the next backward pass.

    After d(2x)/dx is computed and then cleared, back-propagating 3x must
    leave exactly 3 in x.grad.
    """
    var = Variable(np.array(2.0))
    first = 2 * var
    first.backward()
    var.clear_grad()  # discard the gradient left by the first pass
    second = 3 * var
    second.backward()
    expected = np.array(3.0)
    self.assertEqual(var.grad.data, expected)
def test_two_order_diff(self):
    """Check first and second derivatives of f(x) = x**4 - 2*x**2 at x = 2.

    f'(x) = 4*x**3 - 4*x, which is 24 at x = 2.
    f''(x) = 12*x**2 - 4, which is 44 at x = 2.
    """
    var = Variable(np.array(2.0))
    out = var**4 - 2*var**2
    # create_graph=True is what allows the resulting gradient to be
    # back-propagated a second time below.
    out.backward(create_graph=True)
    first_grad = var.grad
    # Reset before the second-order pass so x.grad holds only f''(2)
    # (presumably gradients would otherwise accumulate).
    var.clear_grad()
    self.assertEqual(first_grad.data, np.array(24.0))
    first_grad.backward()
    self.assertEqual(var.grad.data, np.array(44.0))
def test_higher_derivative(self):
    """Differentiate f(x) = x**4 - 2*x**2 twice at x = 2.

    f'(2) = 4*2**3 - 4*2 = 24 and f''(2) = 12*2**2 - 4 = 44.
    """
    x = Variable(np.array(2.0))
    y = x**4 - 2 * x**2
    # gx is back-propagated again below, so the first pass must record a
    # graph for it. Pass create_graph=True explicitly; without it (NOTE(review):
    # assumes backward() defaults to create_graph=False), gx.backward() would
    # not reach x and x.grad would be unset.
    y.backward(create_graph=True)
    self.assertEqual(x.grad.data, 24.0)
    gx = x.grad
    x.clear_grad()  # start the second-order pass from a clean gradient
    gx.backward()
    self.assertEqual(x.grad.data, 44.0)