def test_gradient_nested(self):
  """Save and restore the custom gradient, when combined with other TF code."""
  @jax.custom_jvp
  def f_jax(x):
    return x * x

  @f_jax.defjvp
  def f_jax_jvp(primals, tangents):
    # Custom tangent 3 * x * x_dot; deliberately not the true derivative
    # 2 * x * x_dot, so losing the custom rule is detectable below.
    x, = primals
    x_dot, = tangents
    primal_out = f_jax(x)
    tangent_out = x * x_dot * 3.
    return primal_out, tangent_out

  model = tf.Module()
  # After conversion, we wrap the converted function with some pure TF code.
  model.f = tf.function(
      lambda x: tf.math.sin(jax2tf.convert(f_jax, with_gradient=True)(x)),
      autograph=False,
      input_signature=[tf.TensorSpec([], tf.float32)])
  f_jax_equiv = lambda x: jnp.sin(f_jax(x))
  x = np.array(0.7, dtype=jnp.float32)
  self.assertAllClose(model.f(x), f_jax_equiv(x))

  restored_model = tf_test_util.SaveAndLoadModel(model)
  xv = tf.Variable(x)
  self.assertAllClose(restored_model.f(x), f_jax_equiv(x))
  with tf.GradientTape() as tape:
    y = restored_model.f(xv)
  self.assertAllClose(tape.gradient(y, xv).numpy(),
                      jax.grad(f_jax_equiv)(x))
def test_gradient(self):
  """Save and restore the custom gradient."""
  @jax.custom_jvp
  def f_jax(x):
    return x * x

  @f_jax.defjvp
  def f_jax_jvp(primals, tangents):
    # Custom tangent 3 * x * x_dot (the true derivative would be 2 * x * x_dot).
    x, = primals
    x_dot, = tangents
    primal_out = f_jax(x)
    tangent_out = x * x_dot * 3.
    return primal_out, tangent_out

  model = tf.Module()
  model.f = tf.function(jax2tf.convert(f_jax, with_gradient=True),
                        autograph=False,
                        input_signature=[tf.TensorSpec([], tf.float32)])
  x = np.array(0.7, dtype=jnp.float32)
  self.assertAllClose(model.f(x), f_jax(x))

  restored_model = tf_test_util.SaveAndLoadModel(model)
  xv = tf.Variable(x)
  self.assertAllClose(restored_model.f(x), f_jax(x))
  with tf.GradientTape() as tape:
    y = restored_model.f(xv)
  self.assertAllClose(tape.gradient(y, xv).numpy(),
                      jax.grad(f_jax)(x))
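# A minimal standalone sketch (not part of the original test suite; names are
# illustrative) of why the factor 3. above makes the roundtrip observable:
# jax.grad picks up the custom rule 3 * x instead of the true derivative 2 * x.

def _custom_jvp_sketch():
  @jax.custom_jvp
  def g(x):
    return x * x

  @g.defjvp
  def g_jvp(primals, tangents):
    x, = primals
    x_dot, = tangents
    return g(x), 3. * x * x_dot

  # Under the custom rule, grad(g)(2.) == 3 * 2. == 6. (not the true 4.).
  assert jnp.allclose(jax.grad(g)(2.), 6.)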
def test_eval(self):
  f_jax = jax.jit(lambda x: jnp.sin(jnp.cos(x)))
  model = tf.Module()
  model.f = tf.function(jax2tf.convert(f_jax),
                        autograph=False,
                        input_signature=[tf.TensorSpec([], tf.float32)])
  x = np.array(0.7, dtype=jnp.float32)
  self.assertAllClose(model.f(x), f_jax(x))
  restored_model = tf_test_util.SaveAndLoadModel(model)
  self.assertAllClose(restored_model.f(x), f_jax(x))
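# tf_test_util.SaveAndLoadModel is used by these tests but not defined here.
# A minimal sketch of such a helper, assuming it roundtrips the module through
# a temporary SavedModel directory and that save_gradients maps to
# tf.saved_model.SaveOptions(experimental_custom_gradients=...):

def _save_and_load_model_sketch(model: tf.Module, save_gradients: bool = True):
  import tempfile  # illustrative only
  model_dir = tempfile.mkdtemp()
  options = tf.saved_model.SaveOptions(
      experimental_custom_gradients=save_gradients)
  tf.saved_model.save(model, model_dir, options=options)
  return tf.saved_model.load(model_dir)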
def test_save_without_gradients(self):
  f_jax = lambda x: x * x
  x = np.array(0.7, dtype=jnp.float32)
  model = tf.Module()
  model.f = tf.function(jax2tf.convert(f_jax, with_gradient=True),
                        autograph=False,
                        input_signature=[tf.TensorSpec(x.shape, x.dtype)])
  self.assertAllClose(model.f(x), f_jax(x))
  restored_model = tf_test_util.SaveAndLoadModel(model,
                                                 save_gradients=False)
  self.assertAllClose(restored_model.f(x), f_jax(x))

  # Calling the restored function under a tape must still work, even though
  # the gradients were not saved.
  xv = tf.Variable(x)
  with tf.GradientTape():
    _ = restored_model.f(xv)
def test_gradient_disabled(self):
  f_jax = lambda x: x * x
  model = tf.Module()
  model.f = tf.function(jax2tf.convert(f_jax, with_gradient=False),
                        autograph=False,
                        input_signature=[tf.TensorSpec([], tf.float32)])
  x = np.array(0.7, dtype=jnp.float32)
  self.assertAllClose(model.f(x), f_jax(x))
  restored_model = tf_test_util.SaveAndLoadModel(model)
  xv = tf.Variable(0.7, dtype=jnp.float32)
  self.assertAllClose(restored_model.f(x), f_jax(x))
  # With with_gradient=False, differentiating the restored function must fail.
  with self.assertRaisesRegex(
      LookupError,
      "Gradient explicitly disabled.*"
      "The jax2tf-converted function does not support gradients"):
    with tf.GradientTape():
      _ = restored_model.f(xv)
def test_save_without_embedding_params(self):
  def model_jax(params, inputs):
    return params[0] + params[1] * inputs

  params = (np.array(1.0, dtype=jnp.float32),
            np.array(2.0, dtype=jnp.float32))
  params_vars = tf.nest.map_structure(tf.Variable, params)
  # The converted function closes over the parameter variables; listing them
  # on the module below keeps them tracked so they are saved with the model.
  prediction_tf = lambda x: jax2tf.convert(model_jax)(params_vars, x)

  model = tf.Module()
  model._variables = tf.nest.flatten(params_vars)
  model.f = tf.function(prediction_tf, jit_compile=True)
  x = np.array(0.7, dtype=jnp.float32)
  self.assertAllClose(model.f(x), model_jax(params, x))

  restored_model = tf_test_util.SaveAndLoadModel(model,
                                                 save_gradients=False)
  self.assertAllClose(restored_model.f(x), model_jax(params, x))
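# Usage sketch for the embedded-params pattern (illustrative; reuses the
# hypothetical _save_and_load_model_sketch helper above): after the roundtrip
# the parameter values travel with the SavedModel, so the restored function
# can be called without re-supplying them.

def _embedded_params_usage_sketch(model: tf.Module):
  loaded = _save_and_load_model_sketch(model, save_gradients=False)
  # With params (1.0, 2.0): f(0.7) == 1.0 + 2.0 * 0.7 == 2.4.
  assert np.allclose(loaded.f(np.array(0.7, dtype=np.float32)), 2.4)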