Example #1
  def _compare_with_saved_model(self, f_jax, *args):
    # Certain ops are converted to ensure an XLA context, e.g., tf.gather,
    # so that the index-out-of-bounds behavior matches that of JAX. We check
    # that this behavior is preserved through a SavedModel round trip.
    f_tf = jax2tf.convert(f_jax)
    res = f_tf(*args)
    restored_f, _ = tf_test_util.SaveAndLoadFunction(f_tf, input_args=args)
    res_restored = restored_f(*args)
    self.assertAllClose(res, res_restored)
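
The `tf_test_util.SaveAndLoadFunction` helper used throughout these examples is not shown here. A minimal sketch of what such a helper might do, assuming it wraps the function in a `tf.Module`, writes it to a temporary SavedModel directory, and reloads it (the actual helper may differ):

import tempfile

import numpy as np
import tensorflow as tf

def save_and_load_function(f_tf, input_args):
  # Hypothetical stand-in for tf_test_util.SaveAndLoadFunction: returns the
  # restored callable and the restored model object.
  module = tf.Module()
  module.f = tf.function(
      f_tf, autograph=False,
      input_signature=[tf.TensorSpec(np.shape(a), np.asarray(a).dtype)
                       for a in input_args])
  model_dir = tempfile.mkdtemp()
  tf.saved_model.save(module, model_dir)
  restored_model = tf.saved_model.load(model_dir)
  return restored_model.f, restored_model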
Example #2
  def test_saved_model_simple(self):
    x = np.array([0.7, 0.8], dtype=np.float32)
    def f_jax(x):
      return jnp.sin(x)

    f_tf = jax2tf.convert(f_jax)
    restored_tf, _ = tf_test_util.SaveAndLoadFunction(f_tf, input_args=[x])
    restored_jax = jax2tf.call_tf(restored_tf)
    self.assertAllClose(f_jax(x), restored_jax(x))
Example #3
    def test_several_round_trips(self,
                                 f2_function=False,
                                 f2_saved_model=False,
                                 f4_function=False,
                                 f4_saved_model=False):
        x = np.array(.7, dtype=np.float32)

        # f(n)(x) = 2. * x^n
        def f(n):
            def fn(x):
                acc = np.array(2., dtype=x.dtype)
                for i in range(n):
                    acc *= x
                return acc

            return fn

        f2_tf = lambda x: x * jax2tf.convert(f(1))(x)
        if f2_function:
            f2_tf = tf.function(f2_tf, autograph=False)
        if f2_saved_model:
            f2_tf, _ = tf_test_util.SaveAndLoadFunction(f2_tf, input_args=[x])

        self.assertAllClose(f(2)(x), f2_tf(x).numpy())
        _, (g_f2_ft, ) = tf_test_util.ComputeTfValueAndGrad(f2_tf, [x])
        self.assertAllClose(jax.grad(f(2))(x), g_f2_ft.numpy())

        f3_jax = lambda x: x * jax2tf.call_tf(f2_tf)(x)
        self.assertAllClose(f(3)(x), f3_jax(x))
        self.assertAllClose(f(3)(x), jax.jit(f3_jax)(x))
        self.assertAllClose(jax.grad(f(3))(x), jax.grad(f3_jax)(x))

        f4_tf = lambda x: x * jax2tf.convert(f3_jax)(x)
        self.assertAllClose(f(4)(x), f4_tf(x).numpy())
        _, (g_f4_ft, ) = tf_test_util.ComputeTfValueAndGrad(f4_tf, [x])
        self.assertAllClose(jax.grad(f(4))(x), g_f4_ft.numpy())

        if f4_function:
            f4_tf = tf.function(f4_tf, autograph=False)
        if f4_saved_model:
            f4_tf, _ = tf_test_util.SaveAndLoadFunction(f4_tf, input_args=[x])
        self.assertAllClose(f(4)(x), f4_tf(x).numpy())
        _, (g_f4_ft, ) = tf_test_util.ComputeTfValueAndGrad(f4_tf, [x])
        self.assertAllClose(jax.grad(f(4))(x), g_f4_ft.numpy())
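
The default-False boolean arguments suggest the original suite drives this test through a parameterized decorator that enumerates the wrap/save combinations. A hypothetical sketch using absl's `parameterized` (the class name and decorator arrangement are assumptions, not the actual suite):

from absl.testing import parameterized

class RoundTripTest(parameterized.TestCase):  # hypothetical test class

    @parameterized.named_parameters(
        dict(testcase_name=f"_f2_func={f2f}_f2_sm={f2s}_f4_func={f4f}_f4_sm={f4s}",
             f2_function=f2f, f2_saved_model=f2s,
             f4_function=f4f, f4_saved_model=f4s)
        for f2f in (False, True) for f2s in (False, True)
        for f4f in (False, True) for f4s in (False, True))
    def test_several_round_trips(self, f2_function=False, f2_saved_model=False,
                                 f4_function=False, f4_saved_model=False):
        ...  # body as in the example above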
Example #4
    def test_round_trip_saved_model_no_gradients(self):
        # Save without gradients
        f_jax = jnp.sum

        x = np.array([0.7, 0.8], dtype=np.float32)
        f_tf, _ = tf_test_util.SaveAndLoadFunction(
            jax2tf.convert(f_jax, with_gradient=True),
            input_signature=[tf.TensorSpec(x.shape, dtype=x.dtype)],
            save_gradients=False)
        f_rt = jax2tf.call_tf(f_tf)

        self.assertAllClose(f_jax(x), f_rt(x))
Example #5
    def test_saved_model_variables(self):
        param = np.array([1., 2.], dtype=np.float32)
        x = np.array([0.7, 0.8], dtype=np.float32)

        def f_jax(param, x):
            return jnp.sin(x) + jnp.cos(param)

        param_v = tf.Variable(param)
        f_tf = jax2tf.convert(f_jax)
        _, restored_model = tf_test_util.SaveAndLoadFunction(
            lambda x: f_tf(param_v, x), input_args=[x], variables=[param_v])
        restored_jax = jax2tf.call_tf(restored_model.f)
        self.assertAllClose(f_jax(param, x), restored_jax(x))
        self.assertAllClose(f_jax(param, x), jax.jit(restored_jax)(x))
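
The `variables=[param_v]` argument matters because a SavedModel only serializes variables that are tracked by the object being saved. A minimal sketch of the manual equivalent, reusing the names defined in the example above and assuming the helper attaches everything to a `tf.Module` (the save path and the exact helper behavior are assumptions):

module = tf.Module()
module.param_v = param_v  # track the tf.Variable captured by the closure
module.f = tf.function(
    lambda x: f_tf(module.param_v, x), autograph=False,
    input_signature=[tf.TensorSpec(x.shape, x.dtype)])
tf.saved_model.save(module, "/tmp/saved_model_variables")  # hypothetical path
restored_model = tf.saved_model.load("/tmp/saved_model_variables")
restored_jax = jax2tf.call_tf(restored_model.f)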
Example #6
  def test_without_gradient_saved_model(self):
    # Explicitly with_gradient=False
    f_jax = jnp.sum

    x = np.array([0.7, 0.8], dtype=np.float32)
    f_tf, _ = tf_test_util.SaveAndLoadFunction(
        jax2tf.convert(f_jax, with_gradient=False),
        input_args=[x])
    f_rt = jax2tf.call_tf(f_tf)

    self.assertAllClose(f_jax(x), f_rt(x))
    with self.assertRaisesRegex(
        Exception,
        "Gradient explicitly disabled.*jax2tf-converted function does not "
        "support gradients. Use `with_gradient` parameter to enable gradients"):
      jax.grad(f_rt)(x)
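
For context, the error above comes from the gradient that jax2tf installs on the TF side when `with_gradient=False`. A minimal sketch of the general mechanism, assuming a `tf.custom_gradient` whose backward function raises (the actual jax2tf implementation may differ):

import tensorflow as tf

@tf.custom_gradient
def sum_no_grad(x):
  # Forward value is computed normally; the backward function refuses to run.
  def grad_fn(dy):
    raise LookupError(
        "Gradient explicitly disabled: this function does not support "
        "gradients (sketch only; the real error text comes from jax2tf)")
  return tf.reduce_sum(x), grad_fn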
Example #7
  def test_saved_model_no_gradients(self):
    # Save without gradients
    f_jax = jnp.sum

    x = np.array([0.7, 0.8], dtype=np.float32)
    f_tf, _ = tf_test_util.SaveAndLoadFunction(
        jax2tf.convert(f_jax, with_gradient=True), input_args=[x],
        save_gradients=False)
    f_rt = jax2tf.call_tf(f_tf)

    self.assertAllClose(f_jax(x), f_rt(x))
    # TODO(b/191117111): clean this up; it should fail with a clear error.
    # The following results in a confusing error:
    # TypeError: tf.Graph captured an external symbolic tensor.
    with self.assertRaises(TypeError):
      _ = jax.grad(f_rt)(x)
Example #8
  def test_repro_193754660(self):
    # Try to reproduce b/193754660. I can't.
    # We have to have tf.function(jax2tf.convert(jax2tf.call_tf(f_tf))).
    # The get_compiler_ir call will indeed fail for f_tf. Then we try to use
    # shape inference for f_tf.
    # I thought to use an f_tf that uses an op without shape inference, e.g.,
    # tfxla.gather. If we wash it through a saved_model I expected that shape
    # inference would not work on it. Instead, shape inference works!
    x = np.array([0, 1, 2, 3, 4, 5], dtype=np.int32)
    def f_jax(x):
      return x[1]
    f_tf = jax2tf.convert(f_jax)
    f_tf_rt, _ = tf_test_util.SaveAndLoadFunction(f_tf, input_args=[x])
    f_jax2 = jax2tf.call_tf(f_tf_rt)
    f_tf2 = jax2tf.convert(f_jax2)
    res = tf.function(f_tf2, autograph=False)(x)
    self.assertAllClose(res.numpy(), f_jax(x))
Example #9
  def test_saved_model(self):
    x = np.array([.7, .8], dtype=np.float32)
    def fun_tf(x):
      return tf.math.sin(x)
    def fun_jax(x):
      return jax2tf.call_tf(fun_tf)(x)

    # Now convert and save to SavedModel
    fun_tf_rt = jax2tf.convert(fun_jax)
    res = fun_tf_rt(x)
    self.assertAllClose(np.sin(x), res.numpy())

    res = tf.function(fun_tf_rt, autograph=False)(x)
    self.assertAllClose(np.sin(x), res.numpy())

    res = tf.function(fun_tf_rt, jit_compile=True, autograph=False)(x)
    self.assertAllClose(np.sin(x), res.numpy())

    reloaded_f, _ = tf_test_util.SaveAndLoadFunction(
        fun_tf_rt, input_args=[x])
    res = reloaded_f(x)
    self.assertAllClose(np.sin(x), res.numpy())
Example #10
  def test_custom_grad_saved_model(self):
    @jax.custom_vjp
    def f(x):
      return x * x

    # f_fwd: a -> (b, residual)
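    # Note: the residual 3. * x is deliberately not the true derivative of
    # x * x, so the gradient check below passes only if the custom VJP itself
    # (rather than autodiff of the primal) survives the SavedModel round trip.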
    def f_fwd(x):
      return f(x), np.float32(3.) * x
    # f_bwd: (residual, CT b) -> [CT a]
    def f_bwd(residual, ct_b):
      return residual * ct_b,

    f.defvjp(f_fwd, f_bwd)
    def g(x):
      return jnp.sum(f(x))

    g_tf, _ = tf_test_util.SaveAndLoadFunction(
        jax2tf.convert(g, with_gradient=True, polymorphic_shapes=["b, ..."]),
        input_signature=[tf.TensorSpec([None], dtype=tf.float32)])
    g_rt = jax2tf.call_tf(g_tf)
    x = np.array([0.7], dtype=np.float32)
    self.assertAllClose(g(x), g_rt(x))
    self.assertAllClose(jax.grad(g)(x), jax.grad(g_rt)(x))