Example #1
    def test_several_round_trips(self,
                                 f2_function=False,
                                 f2_saved_model=False,
                                 f4_function=False,
                                 f4_saved_model=False):
        x = np.array(.7, dtype=np.float32)

        # f(n)(x) = 2. * x^n
        def f(n):
            def fn(x):
                acc = np.array(2., dtype=x.dtype)
                for i in range(n):
                    acc *= x
                return acc

            return fn

        f2_tf = lambda x: x * jax2tf.convert(f(1))(x)
        if f2_function:
            f2_tf = tf.function(f2_tf, autograph=False)
        if f2_saved_model:
            f2_tf, _ = tf_test_util.SaveAndLoadFunction(f2_tf, input_args=[x])

        self.assertAllClose(f(2)(x), f2_tf(x).numpy())
        _, (g_f2_ft, ) = tf_test_util.ComputeTfValueAndGrad(f2_tf, [x])
        self.assertAllClose(jax.grad(f(2))(x), g_f2_ft.numpy())

        f3_jax = lambda x: x * jax2tf.call_tf(f2_tf)(x)
        self.assertAllClose(f(3)(x), f3_jax(x))
        self.assertAllClose(f(3)(x), jax.jit(f3_jax)(x))
        self.assertAllClose(jax.grad(f(3))(x), jax.grad(f3_jax)(x))

        f4_tf = lambda x: x * jax2tf.convert(f3_jax)(x)
        self.assertAllClose(f(4)(x), f4_tf(x).numpy())
        _, (g_f4_ft, ) = tf_test_util.ComputeTfValueAndGrad(f4_tf, [x])
        self.assertAllClose(jax.grad(f(4))(x), g_f4_ft.numpy())

        if f4_function:
            f4_tf = tf.function(f4_tf, autograph=False)
        if f4_saved_model:
            f4_tf, _ = tf_test_util.SaveAndLoadFunction(f4_tf, input_args=[x])
        self.assertAllClose(f(4)(x), f4_tf(x).numpy())
        _, (g_f4_ft, ) = tf_test_util.ComputeTfValueAndGrad(f4_tf, [x])
        self.assertAllClose(jax.grad(f(4))(x), g_f4_ft.numpy())
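
The test above nests conversions several levels deep. As a minimal standalone
sketch of a single round trip, assuming only the public jax2tf API (the names
f_jax, f_tf, f_back are illustrative, not from the test suite):

import numpy as np
import jax
from jax.experimental import jax2tf

def f_jax(x):
    return 2.0 * x * x          # plain JAX function; f(2) in the test's notation

f_tf = jax2tf.convert(f_jax)    # JAX -> TF: callable from TensorFlow
f_back = jax2tf.call_tf(f_tf)   # TF -> JAX: usable again under jax.jit/jax.grad

x = np.float32(0.7)
assert np.allclose(f_jax(x), f_back(x))
assert np.allclose(jax.grad(f_jax)(x), jax.grad(f_back)(x))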
Example #2
    def test_alternate(self):
        # Alternate sin/cos with sin in TF and cos in JAX
        f_tf_inner = tf.math.sin

        def f_jax(x_jax):
            y_jax = jnp.cos(x_jax)
            z_jax = jax2tf.call_tf(f_tf_inner)(y_jax)
            return jnp.cos(z_jax)

        def f_tf_outer(x_tf):
            y_tf = tf.math.sin(x_tf)
            z_tf = jax2tf.convert(f_jax)(y_tf)
            return tf.math.sin(z_tf)

        x = np.float32(0.7)

        xv = tf.Variable(x)
        with tf.GradientTape() as tape:
            res = f_tf_outer(xv)
        g_tf = tape.gradient(res, xv)
        _, (g_f,) = tf_test_util.ComputeTfValueAndGrad(f_tf_outer, (x,))
        # Eager
        expected_res = np.sin(np.cos(np.sin(np.cos(np.sin(x)))))
        self.assertAllClose(expected_res, f_tf_outer(x).numpy())

        # Gradient
        expected_grad = (np.cos(np.cos(np.sin(np.cos(np.sin(x))))) *
                         np.sin(np.sin(np.cos(np.sin(x)))) *
                         np.cos(np.cos(np.sin(x))) * np.sin(np.sin(x)) *
                         np.cos(x))
        self.assertAllClose(expected_grad, g_tf.numpy())
        self.assertAllClose(expected_grad, g_f.numpy())

        # Graph
        self.assertAllClose(
            expected_res,
            tf.function(f_tf_outer, autograph=False)(x).numpy())

        # Compiled
        self.assertAllClose(
            expected_res,
            tf.function(f_tf_outer, autograph=False,
                        jit_compile=True)(x).numpy())
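
End to end, f_tf_outer computes sin(cos(sin(cos(sin(x))))), alternating between
TF and JAX at each level. For cross-checking the expected_res and expected_grad
values above, here is a hedged pure-JAX reference of the same composition
(f_ref is an illustrative name, not part of the test suite):

import numpy as np
import jax
import jax.numpy as jnp

def f_ref(x):
    # the same five-fold composition, entirely in JAX
    return jnp.sin(jnp.cos(jnp.sin(jnp.cos(jnp.sin(x)))))

x = np.float32(0.7)
print(f_ref(x))            # matches expected_res
print(jax.grad(f_ref)(x))  # matches expected_grad, by the chain rule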
Example #3
  def test_gradients_int_argument(self, with_function=True):
    # https://github.com/google/jax/issues/6975
    # An expanded version of test_gradients_unused_argument.
    state = dict(
        float_used=np.array([0.7, 0.9], dtype=np.float32),
        float_passthrough=np.float16(1.),
        float_unused=np.array([1.1, 2.2, 3.3], dtype=np.float32),
        int_used=np.int16(5),
        int_passthrough=np.int8(7),
        int_unused=np.array([1, 2, 3], dtype=np.uint32),
        bool_used=np.array([True, False, False, True], dtype=np.bool_),
        bool_passthrough=np.array([True, False, False, True, False], dtype=np.bool_),
        bool_unused=np.array([[True, False], [False, True]], dtype=np.bool_),
    )
    def jax_f(state):
      res = dict(state,
                 float_used=2. * state["float_used"],
                 int_used=3 * state["int_used"],
                 bool_used=(state["bool_used"] == state["bool_used"]))
      del res["float_unused"]
      del res["int_unused"]
      del res["bool_unused"]
      return res

    args = (state,)
    res_jax = jax_f(*args)
    # Native JAX AD
    vjp_jax_fun, args_vjp = tf_test_util.TransformJaxVJP(jax_f, args, res_jax)
    grad_jax, = vjp_jax_fun(*args_vjp)

    def compare_with_overrides(*, what, expected, **expected_overrides):
      what_keys = set(what.keys())
      expected_keys = set(expected.keys())
      self.assertEqual(what_keys, expected_keys)
      for k, w in what.items():
        e = expected[k]
        if k in expected_overrides:
          if expected_overrides[k] == "ZERO":
            e = np.zeros_like(w)
          elif expected_overrides[k] == "ZERO_INT32":
            e = np.zeros(np.shape(w), dtype=np.int32)
          elif expected_overrides[k] == "ONE":
            e = np.ones_like(w)
          else:
            e = expected_overrides[k]

        if e is None:
          self.assertIsNone(w, msg=k)
        else:
          self.assertIsNotNone(w, msg=k)
          w = w.numpy() if isinstance(w, tf.Tensor) else w
          e = e.numpy() if isinstance(e, tf.Tensor) else e
          try:
            self.assertAllClose(e, w, err_msg=k)
          except AssertionError:
            print(f"Failed at {k}")
            raise


    # Native JAX AD: tangents of int/bool inputs have the trivial dtype
    # dtypes.float0 (requires `from jax import dtypes`):
    # compare_with_overrides(what=grad_jax, expected={},
    #   bool_passthrough=np.zeros(state["bool_passthrough"].shape, dtype=dtypes.float0),
    #   bool_unused=np.zeros(state["bool_unused"].shape, dtype=dtypes.float0),
    #   bool_used=np.zeros(state["bool_used"].shape, dtype=dtypes.float0),
    #   float_passthrough=np.ones_like(state["float_passthrough"]),
    #   float_unused=np.zeros_like(state["float_unused"]),
    #   float_used=np.ones_like(state["float_used"]) * np.array(2., dtype=state["float_used"].dtype),
    #   int_passthrough=np.zeros(state["int_passthrough"].shape, dtype=dtypes.float0),
    #   int_unused=np.zeros(state["int_unused"].shape, dtype=dtypes.float0),
    #   int_used=np.zeros(state["int_used"].shape, dtype=dtypes.float0))
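    # (Hedged aside: this float0 behavior can be observed directly in JAX:
    #   jax.grad(lambda x: jnp.sum(x.astype(np.float32)), allow_int=True)(
    #       np.array([1, 2, 3], dtype=np.int32))
    # returns zeros with dtype dtypes.float0; allow_int=True is required to
    # differentiate with respect to integer inputs at all.)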


    # Now native TF gradients, only to test how native TF AD works
    _, (grad_tf_0,) = tf_test_util.ComputeTfValueAndGrad(
        jax_f, args, unconnected_gradients=tf.UnconnectedGradients.ZERO)
    compare_with_overrides(what=grad_tf_0,
                           expected=grad_jax,
                           float_unused="ZERO",
                           bool_used="ZERO", bool_passthrough="ONE", bool_unused="ZERO",
                           int_used="ZERO", int_passthrough="ONE", int_unused="ZERO")

    _, (grad_tf_None,) = tf_test_util.ComputeTfValueAndGrad(
        jax_f, args,
        unconnected_gradients=tf.UnconnectedGradients.NONE)
    compare_with_overrides(what=grad_tf_None,
                           expected=grad_tf_0,
                           float_unused=None, int_used=None, int_unused=None,
                           bool_used=None, bool_unused=None)

    f_tf_jax = jax2tf.convert(jax_f)
    if with_function:
      f_tf_jax = tf.function(f_tf_jax, autograph=False)

    _, (grad_tf_jax_0,) = tf_test_util.ComputeTfValueAndGrad(f_tf_jax, args)
    # Same results as TF native AD with tf.UnconnectedGradients.ZERO
    compare_with_overrides(what=grad_tf_jax_0,
                           expected=grad_tf_0,
                           int_passthrough="ZERO", bool_passthrough="ZERO")

    _, (grad_tf_jax_None,) = tf_test_util.ComputeTfValueAndGrad(
        f_tf_jax, args,
        unconnected_gradients=tf.UnconnectedGradients.NONE)
    compare_with_overrides(what=grad_tf_jax_None,
                           expected=grad_tf_0,
                           int_used=None, int_passthrough=None, int_unused=None,
                           bool_unused=None, bool_used=None, bool_passthrough=None)

    # Now convert the JAX gradient function
    tf_vjp_jax_fun = jax2tf.convert(vjp_jax_fun)
    grad_tf_vjp_jax, = tf_vjp_jax_fun(*args_vjp)
    compare_with_overrides(what=grad_tf_vjp_jax,
                           expected=grad_tf_0,
                           bool_passthrough="ZERO_INT32",
                           bool_unused="ZERO_INT32", bool_used="ZERO_INT32",
                           int_passthrough="ZERO_INT32", int_unused="ZERO_INT32",
                           int_used="ZERO_INT32")