def test_defvjp_all(self):
  """defvjp_all's forward rule drives both the value and the gradient.

  The registered pullback multiplies the cotangent by ``x``, not by
  ``cos(x)``. At x = 3 the gradient is therefore 3, which can only happen
  if the custom rule replaced the default derivative of ``sin``.
  """
  @api.custom_transforms
  def foo(x):
    return np.sin(x)

  def foo_fwd(x):
    # Primal output paired with a hand-written pullback (cotangent * x).
    def pullback(g):
      return (g * x,)
    return np.sin(x), pullback

  api.defvjp_all(foo, foo_fwd)

  primal_out, cotangent_in = api.value_and_grad(foo)(3.)
  self.assertAllClose(primal_out, onp.sin(3.), check_dtypes=False)
  self.assertAllClose(cotangent_in, 3., check_dtypes=False)
def test_custom_vjp_zeros(self):
  """A custom VJP must cope with a zero cotangent for an unused output.

  ``f`` has two outputs, and the differentiated function uses only the
  first. The cotangent passed to the pullback is therefore (1, 0) in
  effect. The pullback returns ``(4 * ts[0], 5 * ts[1])``, so the gradient
  with respect to ``x`` must be 4. Checking that value, and not merely
  that the call does not crash, confirms the custom rule was applied.
  """
  @api.custom_transforms
  def f(x, y):
    return 2 * x, 3 * y

  def f_vjp(x, y):
    return (2 * x, 3 * y), lambda ts: (4 * ts[0], 5 * ts[1])

  api.defvjp_all(f, f_vjp)

  # Before: only "doesn't crash" was checked. Now pin the gradient produced
  # by the custom pullback (4 * cotangent of the first output, which is 1).
  grad_x = api.grad(lambda x, y: f(x, y)[0])(1., 2.)
  self.assertAllClose(grad_x, 4., check_dtypes=False)