def test_jarrett_jvps(self):
  """Checks that api.jarrett(f1) agrees with f1 on values and on VJPs.

  Exercised for a scalar input and a 1-D array input. Because f1 is a
  composition of elementwise np.sin calls, its output has the same shape as
  its input, so the input itself is reused as the output cotangent.

  NOTE(review): despite the "jvps" name, this test only compares VJPs.
  """
  def f1(x):
    return np.sin(np.sin(np.sin(x)))

  f2 = api.jarrett(f1)

  scalar_input = 3.
  vector_input = onp.array([2., 3., 4.])
  for x in (scalar_input, vector_input):
    # Forward (primal) outputs must match.
    self.assertAllClose(f1(x), f2(x), check_dtypes=True)

    # Pull the same cotangent back through both functions and compare.
    _primal_ref, pullback_ref = api.vjp(f1, x)
    _primal_jarrett, pullback_jarrett = api.vjp(f2, x)
    cotangent = x
    self.assertAllClose(pullback_ref(cotangent), pullback_jarrett(cotangent),
                        check_dtypes=True)
def test_jarrett_jvps2(self):
  """Two-argument variant: api.jarrett(f1) must agree with f1 on values and VJPs.

  Each case pairs arguments of equal shape, so the output shape matches y and
  y is reused as the output cotangent.

  NOTE(review): despite the "jvps" name, this test only compares VJPs.
  """
  def f1(x, y):
    return np.sin(x) * np.cos(y) * np.sin(x) * np.cos(y)

  f2 = api.jarrett(f1)

  # TODO(mattjj): doesn't work for (3., onp.array([4., 5.]))
  cases = [
      (3., 4.),
      (onp.array([5., 6.]), onp.array([7., 8.])),
  ]
  for x, y in cases:
    # Forward (primal) outputs must match.
    self.assertAllClose(f1(x, y), f2(x, y), check_dtypes=True)

    _primal_ref, pullback_ref = api.vjp(f1, x, y)
    _primal_jarrett, pullback_jarrett = api.vjp(f2, x, y)
    # y has the output's shape in every case above, so it doubles as the
    # cotangent fed to both pullbacks.
    cotangent = y
    self.assertAllClose(pullback_ref(cotangent), pullback_jarrett(cotangent),
                        check_dtypes=True)