Example #1
0
  def test_gradient_nested(self):
    """Save and restore the custom gradient, when combined with other TF code.

    The JAX function gets a deliberately non-standard JVP (3 * x * x_dot
    rather than the true 2 * x * x_dot for x * x). Comparing the TF gradient
    of the restored model with `jax.grad` therefore checks that the custom
    rule, not the default derivative, is what survives the round trip, even
    when the converted function is wrapped in plain TF ops.
    """
    @jax.custom_jvp
    def f_jax(x):
      return x * x

    @f_jax.defjvp
    def f_jax_jvp(primals, tangents):
      # Intentionally not the true derivative of x * x: 3 * x * x_t.
      (primal,) = primals
      (tangent,) = tangents
      return f_jax(primal), primal * tangent * 3.

    def wrapped_tf(x):
      # After conversion, wrap the converted result with some pure TF code.
      converted = jax2tf.convert(f_jax, with_gradient=True)
      return tf.math.sin(converted(x))

    def f_jax_equiv(x):
      # Pure-JAX counterpart of `wrapped_tf`, used as the reference.
      return jnp.sin(f_jax(x))

    model = tf.Module()
    model.f = tf.function(wrapped_tf,
                          autograph=False,
                          input_signature=[tf.TensorSpec([], tf.float32)])
    x = np.array(0.7, dtype=jnp.float32)
    self.assertAllClose(model.f(x), f_jax_equiv(x))

    restored_model = tf_test_util.SaveAndLoadModel(model)
    self.assertAllClose(restored_model.f(x), f_jax_equiv(x))

    xv = tf.Variable(x)
    with tf.GradientTape() as tape:
      y = restored_model.f(xv)
    self.assertAllClose(tape.gradient(y, xv).numpy(),
                        jax.grad(f_jax_equiv)(x))
Example #2
0
  def test_gradient(self):
    """Save and restore the custom gradient.

    The custom JVP below (3 * x * x_dot) intentionally differs from the true
    derivative of x * x, so a matching gradient after reloading shows the
    custom rule was preserved in the SavedModel.
    """
    @jax.custom_jvp
    def f_jax(x):
      return x * x

    @f_jax.defjvp
    def f_jax_jvp(primals, tangents):
      # Intentionally not the true derivative of x * x: 3 * x * x_t.
      (primal,) = primals
      (tangent,) = tangents
      return f_jax(primal), primal * tangent * 3.

    converted = jax2tf.convert(f_jax, with_gradient=True)
    model = tf.Module()
    model.f = tf.function(converted,
                          autograph=False,
                          input_signature=[tf.TensorSpec([], tf.float32)])
    x = np.array(0.7, dtype=jnp.float32)
    self.assertAllClose(model.f(x), f_jax(x))

    restored_model = tf_test_util.SaveAndLoadModel(model)
    self.assertAllClose(restored_model.f(x), f_jax(x))

    xv = tf.Variable(x)
    with tf.GradientTape() as tape:
      y = restored_model.f(xv)
    self.assertAllClose(tape.gradient(y, xv).numpy(),
                        jax.grad(f_jax)(x))
Example #3
0
 def test_eval(self):
   """Forward evaluation of a jitted, converted function survives save/load."""
   f_jax = jax.jit(lambda x: jnp.sin(jnp.cos(x)))
   scalar_f32 = tf.TensorSpec([], tf.float32)

   model = tf.Module()
   model.f = tf.function(jax2tf.convert(f_jax),
                         autograph=False,
                         input_signature=[scalar_f32])

   x = np.array(0.7, dtype=jnp.float32)
   expected = f_jax(x)
   self.assertAllClose(model.f(x), expected)

   restored_model = tf_test_util.SaveAndLoadModel(model)
   self.assertAllClose(restored_model.f(x), expected)
Example #4
0
  def test_save_without_gradients(self):
    """A model saved with save_gradients=False still runs under a GradientTape.

    The function is converted with gradients enabled, but the model is saved
    with `save_gradients=False`. The final block only checks that calling the
    restored function while a tape is active does not raise; no gradient is
    requested.
    """
    f_jax = lambda x: x * x

    x = np.array(0.7, dtype=jnp.float32)
    spec = tf.TensorSpec(x.shape, x.dtype)
    model = tf.Module()
    model.f = tf.function(jax2tf.convert(f_jax, with_gradient=True),
                          autograph=False,
                          input_signature=[spec])
    self.assertAllClose(model.f(x), f_jax(x))

    restored_model = tf_test_util.SaveAndLoadModel(
        model, save_gradients=False)
    self.assertAllClose(restored_model.f(x), f_jax(x))

    xv = tf.Variable(x)
    with tf.GradientTape():
      # Forward pass only; the result is intentionally unused.
      restored_model.f(xv)
Example #5
0
  def test_gradient_disabled(self):
    """Calling a with_gradient=False function under a tape raises LookupError.

    Forward evaluation must work both before and after the SavedModel round
    trip. Invoking the restored function while a GradientTape is recording
    must fail with the "gradient explicitly disabled" message.
    """
    f_jax = lambda x: x * x
    no_grad_tf = jax2tf.convert(f_jax, with_gradient=False)

    model = tf.Module()
    model.f = tf.function(no_grad_tf,
                          autograph=False,
                          input_signature=[tf.TensorSpec([], tf.float32)])
    x = np.array(0.7, dtype=jnp.float32)
    self.assertAllClose(model.f(x), f_jax(x))

    restored_model = tf_test_util.SaveAndLoadModel(model)
    self.assertAllClose(restored_model.f(x), f_jax(x))

    xv = tf.Variable(0.7, dtype=jnp.float32)
    with self.assertRaisesRegex(LookupError,
                                "Gradient explicitly disabled.*The jax2tf-converted function does not support gradients"):
      with tf.GradientTape():
        restored_model.f(xv)
Example #6
0
  def test_save_without_embedding_params(self):
    """Save a converted model whose parameters are passed in as tf.Variables.

    The JAX model is called inside the TF function with `tf.Variable`
    parameters rather than NumPy constants. NOTE(review): assigning the
    flattened variables to `model._variables` is presumably what makes the
    module track and save them; confirm against tf.Module tracking rules.
    """
    def model_jax(params, inputs):
      # Affine model: params[0] is the offset, params[1] the scale.
      return params[0] + params[1] * inputs

    params = (np.array(1.0, dtype=jnp.float32),
              np.array(2.0, dtype=jnp.float32))
    params_vars = tf.nest.map_structure(tf.Variable, params)

    def prediction_tf(x):
      return jax2tf.convert(model_jax)(params_vars, x)

    model = tf.Module()
    model._variables = tf.nest.flatten(params_vars)
    model.f = tf.function(prediction_tf, jit_compile=True)

    x = np.array(0.7, dtype=jnp.float32)
    expected = model_jax(params, x)
    self.assertAllClose(model.f(x), expected)

    restored_model = tf_test_util.SaveAndLoadModel(
        model, save_gradients=False)
    self.assertAllClose(restored_model.f(x), expected)