def foo(x):
    # Builds a custom_transforms function `bar` whose body and VJP rule both
    # close over the outer `x` instead of receiving it as a positional
    # argument, then calls it at `x`.
    # NOTE(review): the VJP rule ignores the cotangent `g` and returns x * y;
    # this looks deliberate (presumably to exercise how defvjp treats
    # closed-over values) — confirm against the enclosing test, which is not
    # visible here. Keep the inner name `bar` as-is in case the caller
    # matches on it.
    @api.custom_transforms
    def bar(y):
        return x * y

    # Register a VJP for `bar`'s single positional argument `y`.
    api.defvjp(bar, lambda g, ans, y: x * y)
    return bar(x)
def test_defvjp_higher_order(self):
    """Second derivatives must flow through a user-supplied VJP rule.

    The primal computes sin(2x), but the registered VJP claims the derivative
    is cos(x). Differentiating twice should therefore agree with the second
    derivative of plain ``sin``. That agreement shows the custom rule, not the
    primal's true derivative, is what gets differentiated the second time.
    """
    @api.custom_transforms
    def foo(x):
        return np.sin(2. * x)

    def sin_like_vjp(g, _ans, x):
        # Deliberately *not* the true derivative of sin(2x).
        return g * np.cos(x)

    api.defvjp(foo, sin_like_vjp)

    point = 2.
    second_derivative = api.grad(api.grad(foo))
    reference = api.grad(api.grad(np.sin))
    self.assertAllClose(second_derivative(point), reference(point),
                        check_dtypes=False)
def test_defvjp_use_ans(self):
    """A VJP rule may read the primal output (``ans``) it is passed."""
    @api.custom_transforms
    def foo(x, y):
        return np.sin(x * y)

    # No rule for x (None); the rule for y mixes the inputs with the output.
    api.defvjp(foo, None,
               lambda g, ans, x, y: g * x * y + np.cos(ans))

    x0, y0 = 3., 4.
    value, dy = api.value_and_grad(foo, 1)(x0, y0)

    self.assertAllClose(value, onp.sin(x0 * y0), check_dtypes=False)
    expected_dy = x0 * y0 + onp.cos(onp.sin(x0 * y0))
    self.assertAllClose(dy, expected_dy, check_dtypes=False)
def test_defvjp(self):
    """Passing ``None`` to defvjp for an argument gives it a zero gradient."""
    @api.custom_transforms
    def foo(x, y):
        return np.sin(x * y)

    api.defvjp(foo, None, lambda g, _, x, y: g * x * y)

    args = (3., 4.)
    product = args[0] * args[1]

    # Without argnums, differentiation is w.r.t. x, whose rule is None.
    value, dx = api.value_and_grad(foo)(*args)
    self.assertAllClose(value, onp.sin(product), check_dtypes=False)
    self.assertAllClose(dx, 0., check_dtypes=False)

    dx, dy = api.grad(foo, (0, 1))(*args)
    self.assertAllClose(dx, 0., check_dtypes=False)
    self.assertAllClose(dy, product, check_dtypes=False)