Ejemplo n.º 1
0
  def test_jarrett_jvps(self):
    """Check that `api.jarrett` preserves values and VJPs of a nested sin.

    For a scalar input and a 1-D array input, the wrapped function must
    return the same value as the original, and the pullbacks obtained from
    `api.vjp` must agree when applied to the input itself as the cotangent.

    NOTE(review): despite "jvps" in the name, only `api.vjp` is exercised.
    """
    def f1(x):
      # Three nested applications of sin.
      inner = np.sin(x)
      middle = np.sin(inner)
      return np.sin(middle)
    f2 = api.jarrett(f1)

    test_inputs = [3., onp.array([2., 3., 4.])]
    for x in test_inputs:
      # Primal outputs must match before derivatives are compared.
      expected_out = f1(x)
      actual_out = f2(x)
      self.assertAllClose(expected_out, actual_out, check_dtypes=True)

      _, pullback_ref = api.vjp(f1, x)
      _, pullback_jarrett = api.vjp(f2, x)
      # The input doubles as the cotangent passed to each pullback.
      expected_ct = pullback_ref(x)
      actual_ct = pullback_jarrett(x)
      self.assertAllClose(expected_ct, actual_ct, check_dtypes=True)
Ejemplo n.º 2
0
  def test_jarrett_jvps2(self):
    """Check that `api.jarrett` preserves values and VJPs of a two-arg function.

    The function multiplies sin(x) and cos(y) together, each factor appearing
    twice. For a pair of scalars and a pair of equal-length arrays, the
    wrapped function must return the same value as the original, and the
    pullbacks obtained from `api.vjp` must agree when applied to `y` as the
    cotangent.

    NOTE(review): despite "jvps" in the name, only `api.vjp` is exercised.
    """
    def f1(x, y):
      # sin(x) and cos(y) are each recomputed on purpose, so the expression
      # contains the same sub-computations as the single-line product.
      product = np.sin(x) * np.cos(y)
      product = product * np.sin(x)
      return product * np.cos(y)
    f2 = api.jarrett(f1)

    # TODO(mattjj): doesn't work for (3., onp.array([4., 5.]))
    test_inputs = [
        (3., 4.),
        (onp.array([5., 6.]), onp.array([7., 8.])),
    ]
    for x, y in test_inputs:
      # Primal outputs must match before derivatives are compared.
      expected_out = f1(x, y)
      actual_out = f2(x, y)
      self.assertAllClose(expected_out, actual_out, check_dtypes=True)

      _, pullback_ref = api.vjp(f1, x, y)
      _, pullback_jarrett = api.vjp(f2, x, y)
      # `y` doubles as the cotangent passed to each pullback.
      expected_ct = pullback_ref(y)
      actual_ct = pullback_jarrett(y)
      self.assertAllClose(expected_ct, actual_ct, check_dtypes=True)