def _compare_with_saved_model(self, f_jax, *args):
    """Check that a converted function gives the same result after a SavedModel round trip.

    Certain ops (e.g. tf.gather) are converted so as to ensure an XLA context,
    which makes their index-out-of-bounds behavior match JAX. This helper checks
    that such information is preserved through a SavedModel.

    Args:
      f_jax: the JAX function to convert with jax2tf.
      *args: concrete arguments used both to call the function and as the
        example inputs for saving.
    """
    converted = jax2tf.convert(f_jax)
    before_save = converted(*args)
    reloaded, _ = tf_test_util.SaveAndLoadFunction(converted, input_args=args)
    self.assertAllClose(before_save, reloaded(*args))
def test_saved_model_simple(self):
    """jnp.sin -> jax2tf -> SavedModel -> call_tf yields the same values as jnp.sin."""
    x = np.array([0.7, 0.8], dtype=np.float32)

    def f_jax(x):
        return jnp.sin(x)

    reloaded_tf, _ = tf_test_util.SaveAndLoadFunction(
        jax2tf.convert(f_jax), input_args=[x])
    # Bring the reloaded TF function back into JAX.
    back_in_jax = jax2tf.call_tf(reloaded_tf)
    self.assertAllClose(f_jax(x), back_in_jax(x))
def test_several_round_trips(self, f2_function=False, f2_saved_model=False,
                             f4_function=False, f4_saved_model=False):
    """Chain JAX -> TF -> JAX -> TF conversions and check values and gradients.

    With f(n)(x) = 2. * x**n, each stage multiplies by x once more:
      f2_tf  = x * convert(f(1))(x)    (TF)
      f3_jax = x * call_tf(f2_tf)(x)   (JAX)
      f4_tf  = x * convert(f3_jax)(x)  (TF)

    Args:
      f2_function: wrap f2_tf in tf.function before using it.
      f2_saved_model: round-trip f2_tf through a SavedModel before using it.
      f4_function: wrap f4_tf in tf.function before the final check.
      f4_saved_model: round-trip f4_tf through a SavedModel before the final check.
    """
    x = np.array(.7, dtype=np.float32)

    def f(n):
        # Builds fn(x) = 2. * x**n by repeated multiplication.
        def fn(x):
            acc = np.array(2., dtype=x.dtype)
            for _ in range(n):
                acc *= x
            return acc
        return fn

    def check_tf(fn_tf, n):
        # Compare the TF function's value and gradient at x against JAX's f(n).
        self.assertAllClose(f(n)(x), fn_tf(x).numpy())
        _, (grad_tf,) = tf_test_util.ComputeTfValueAndGrad(fn_tf, [x])
        self.assertAllClose(jax.grad(f(n))(x), grad_tf.numpy())

    f2_tf = lambda x: x * jax2tf.convert(f(1))(x)
    if f2_function:
        f2_tf = tf.function(f2_tf, autograph=False)
    if f2_saved_model:
        f2_tf, _ = tf_test_util.SaveAndLoadFunction(f2_tf, input_args=[x])
    check_tf(f2_tf, 2)

    f3_jax = lambda x: x * jax2tf.call_tf(f2_tf)(x)
    self.assertAllClose(f(3)(x), f3_jax(x))
    self.assertAllClose(f(3)(x), jax.jit(f3_jax)(x))
    self.assertAllClose(jax.grad(f(3))(x), jax.grad(f3_jax)(x))

    f4_tf = lambda x: x * jax2tf.convert(f3_jax)(x)
    check_tf(f4_tf, 4)

    if f4_function:
        f4_tf = tf.function(f4_tf, autograph=False)
    if f4_saved_model:
        f4_tf, _ = tf_test_util.SaveAndLoadFunction(f4_tf, input_args=[x])
    check_tf(f4_tf, 4)
def test_round_trip_saved_model_no_gradients(self):
    """Save a gradient-enabled conversion without gradients and call it from JAX.

    Unlike the input_args-based variants, this test builds the SavedModel from an
    explicit tf.TensorSpec input signature. It checks only the forward value
    after the round trip.
    """
    f_jax = jnp.sum
    x = np.array([0.7, 0.8], dtype=np.float32)
    # SaveAndLoadFunction returns a (restored_function, restored_model) pair and
    # takes the signature by keyword. Unpack the pair so that call_tf receives
    # the function itself rather than a tuple.
    f_tf, _ = tf_test_util.SaveAndLoadFunction(
        jax2tf.convert(f_jax, with_gradient=True),
        input_signature=[tf.TensorSpec(x.shape, dtype=x.dtype)],
        save_gradients=False)
    f_rt = jax2tf.call_tf(f_tf)
    self.assertAllClose(f_jax(x), f_rt(x))
def test_saved_model_variables(self):
    """A tf.Variable captured by the converted function is saved and restored with it."""
    param = np.array([1., 2.], dtype=np.float32)
    x = np.array([0.7, 0.8], dtype=np.float32)

    def f_jax(param, x):
        return jnp.sin(x) + jnp.cos(param)

    param_v = tf.Variable(param)
    f_tf = jax2tf.convert(f_jax)
    # Close over the variable so that only `x` is an input to the saved function.
    _, restored_model = tf_test_util.SaveAndLoadFunction(
        lambda x: f_tf(param_v, x), input_args=[x], variables=[param_v])
    # The reloaded function is taken from the `f` attribute of the returned model.
    reloaded_jax = jax2tf.call_tf(restored_model.f)

    expected = f_jax(param, x)
    self.assertAllClose(expected, reloaded_jax(x))
    self.assertAllClose(expected, jax.jit(reloaded_jax)(x))
def test_without_gradient_saved_model(self):
    """With with_gradient=False, the reloaded function runs forward but jax.grad raises."""
    f_jax = jnp.sum
    x = np.array([0.7, 0.8], dtype=np.float32)
    # Gradients are explicitly disabled at conversion time.
    converted = jax2tf.convert(f_jax, with_gradient=False)
    f_tf, _ = tf_test_util.SaveAndLoadFunction(converted, input_args=[x])
    f_rt = jax2tf.call_tf(f_tf)
    self.assertAllClose(f_jax(x), f_rt(x))

    expected_error = "Gradient explicitly disabled.*jax2tf-converted function does not support gradients. Use `with_gradient` parameter to enable gradients"
    with self.assertRaisesRegex(Exception, expected_error):
        jax.grad(f_rt)(x)
def test_saved_model_no_gradients(self):
    """Convert with gradients but save without them; the forward pass still works."""
    f_jax = jnp.sum
    x = np.array([0.7, 0.8], dtype=np.float32)
    f_tf, _ = tf_test_util.SaveAndLoadFunction(
        jax2tf.convert(f_jax, with_gradient=True),
        input_args=[x],
        save_gradients=False)  # drop gradients in the SavedModel
    f_rt = jax2tf.call_tf(f_tf)
    self.assertAllClose(f_jax(x), f_rt(x))

    # TODO: clean this up b/191117111: it should fail with a clear error.
    # Today the gradient attempt fails with a confusing
    #   TypeError: tf.Graph captured an external symbolic tensor.
    with self.assertRaises(TypeError):
        jax.grad(f_rt)(x)
def test_repro_193754660(self):
    """Attempted reproduction of b/193754660. This test does not reproduce it.

    Reproducing it requires tf.function(jax2tf.convert(jax2tf.call_tf(f_tf))).
    get_compiler_ir does fail for f_tf, so shape inference is used for f_tf
    instead. The idea was to use an f_tf containing an op without shape
    inference (e.g. tfxla.gather), expecting that passing it through a
    SavedModel would break shape inference. In practice, shape inference still
    works.
    """
    x = np.array([0, 1, 2, 3, 4, 5], dtype=np.int32)

    def f_jax(x):
        return x[1]

    # JAX -> TF, then a SavedModel round trip.
    f_tf_rt, _ = tf_test_util.SaveAndLoadFunction(
        jax2tf.convert(f_jax), input_args=[x])
    # TF -> JAX -> TF again, executed under tf.function.
    f_tf2 = jax2tf.convert(jax2tf.call_tf(f_tf_rt))
    res = tf.function(f_tf2, autograph=False)(x)
    self.assertAllClose(res.numpy(), f_jax(x))
def test_saved_model(self):
    """call_tf nested inside jax2tf.convert works in several execution modes.

    Checked modes: eagerly, under tf.function (with and without jit_compile),
    and after a SavedModel round trip.
    """
    x = np.array([.7, .8], dtype=np.float32)

    def fun_tf(x):
        return tf.math.sin(x)

    def fun_jax(x):
        return jax2tf.call_tf(fun_tf)(x)

    fun_tf_rt = jax2tf.convert(fun_jax)
    expected = np.sin(x)
    # Eager, graph, and XLA-compiled graph execution.
    for runner in (fun_tf_rt,
                   tf.function(fun_tf_rt, autograph=False),
                   tf.function(fun_tf_rt, jit_compile=True, autograph=False)):
        self.assertAllClose(expected, runner(x).numpy())

    # Save to a SavedModel, reload, and check once more.
    reloaded_f, _ = tf_test_util.SaveAndLoadFunction(
        fun_tf_rt, input_args=[x])
    self.assertAllClose(expected, reloaded_f(x).numpy())
def test_custom_grad_saved_model(self):
    """A jax.custom_vjp rule survives convert -> SavedModel (polymorphic batch) -> call_tf."""

    @jax.custom_vjp
    def f(x):
        return x * x

    def f_fwd(x):
        # Returns (primal output, residual). The residual is 3*x, so the custom
        # gradient 3*x*ct differs from the true derivative 2*x. Presumably this
        # makes it observable that the custom rule was applied.
        return f(x), np.float32(3.) * x

    def f_bwd(residual, ct_b):
        # (residual, output cotangent) -> 1-tuple of input cotangents.
        return (residual * ct_b,)

    f.defvjp(f_fwd, f_bwd)

    def g(x):
        return jnp.sum(f(x))

    # Save with a shape-polymorphic leading dimension "b".
    g_tf, _ = tf_test_util.SaveAndLoadFunction(
        jax2tf.convert(g, with_gradient=True, polymorphic_shapes=["b, ..."]),
        input_signature=[tf.TensorSpec([None], dtype=tf.float32)])
    g_rt = jax2tf.call_tf(g_tf)

    x = np.array([0.7], dtype=np.float32)
    self.assertAllClose(g(x), g_rt(x))
    self.assertAllClose(jax.grad(g)(x), jax.grad(g_rt)(x))