def test_grad_pytree(self, with_jit=False):
  def fun_tf(x: Dict, y: Tuple) -> Tuple:
    return (x["first"] * x["second"] + 3. * y[0] + 4. * y[1])

  x = dict(first=np.float32(3.), second=np.float32(4.))
  y = (np.float32(5.), np.float32(6.))
  grad_x = _maybe_jit(with_jit, jax.grad(jax2tf.call_tf(fun_tf)))(x, y)
  self.assertAllClose(
      dict(first=np.float32(4.), second=np.float32(3.)), grad_x)

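# NOTE: `_maybe_jit` is a module-level helper defined elsewhere in this test
# file. Judging from its uses in these tests, it presumably wraps the function
# in jax.jit only when `with_jit` is True; a minimal sketch (an assumption,
# not the actual definition) would be:
#
#   def _maybe_jit(with_jit: bool, func: Callable) -> Callable:
#     return jax.jit(func) if with_jit else func
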
def test_with_value_capture(self, with_jit=True):
  outer_val = np.array(3., dtype=np.float32)

  def fun_tf(x):
    return x * outer_val + 1.

  x = np.float32(2.)
  res = _maybe_jit(with_jit, jax2tf.call_tf(fun_tf))(x)
  self.assertAllClose(x * 3. + 1., res, check_dtypes=False)

def test_eval_devicearray_no_copy(self):
  if jtu.device_under_test() != "cpu":
    # TODO(necula): add tests for GPU and TPU
    raise unittest.SkipTest("no_copy test works only on CPU")
  # For DeviceArray zero-copy works even if not aligned
  x = jnp.ones((3, 3), dtype=np.float32)
  res = jax2tf.call_tf(lambda x: x)(x)
  self.assertAllClose(x, res)
  self.assertTrue(np.shares_memory(x, res))

def test_bool(self, with_jit=False):
  def fun_tf(x, y):
    return tf.math.logical_and(x, y)

  x = np.array([True, False, True, False], dtype=np.bool_)
  y = np.array([True, True, False, False], dtype=np.bool_)
  res = _maybe_jit(with_jit, jax2tf.call_tf(fun_tf))(x, y)
  self.assertAllClose(
      np.array([True, False, False, False], dtype=np.bool_), res)

def test_with_tensor_capture_x64(self, with_jit=True):
  outer_tensor = tf.constant(3., dtype=np.float64)

  def fun_tf(x):
    return x * tf.cast(outer_tensor * 3.14, tf.float32) + 1.

  x = np.float32(2.)
  res = _maybe_jit(with_jit, jax2tf.call_tf(fun_tf))(x)
  self.assertAllClose(x * 3. * 3.14 + 1., res, check_dtypes=False)

def test_with_tensor_capture(self, with_jit=False):
  outer_tensor = tf.constant(3., dtype=np.float32)

  def fun_tf(x):
    return x * outer_tensor + 1.

  x = np.float32(2.)
  res = _maybe_jit(with_jit, jax2tf.call_tf(fun_tf))(x)
  self.assertAllClose(x * 3. + 1., res, check_dtypes=False)

def test_saved_model_simple(self):
  x = np.array([0.7, 0.8], dtype=np.float32)

  def f_jax(x):
    return jnp.sin(x)

  f_tf = jax2tf.convert(f_jax)
  restored_tf, _ = tf_test_util.SaveAndLoadFunction(f_tf, input_args=[x])
  restored_jax = jax2tf.call_tf(restored_tf)
  self.assertAllClose(f_jax(x), restored_jax(x))

def test_eval_pytree(self, with_jit=True):
  def fun_tf(x: Dict, y: Tuple) -> Tuple:
    return (x["first"] * x["second"], y[0] + y[1])

  x = dict(first=np.float32(3.), second=np.float32(4.))
  y = (np.float64(5.), np.float64(6.))
  fun_jax = _maybe_jit(with_jit, jax2tf.call_tf(fun_tf))
  res = fun_jax(x, y)
  self.assertAllClose((np.float32(12.), np.float64(11.)), res)

def test_function_compile_time_constant_inputs(self):
  # Call a function for which shape inference does not give an output shape.
  x = np.array([1, 2, 3], dtype=np.int32)

  def fun_tf(x):  # x: i32[3]
    # Indexing with a dynamic slice makes the TF shape inference return
    # a partially known shape.
    end_idx = x[1]
    res = x[0:end_idx]
    return res

  # Call in eager mode. Should work!
  res1 = jax2tf.call_tf(fun_tf)(x)
  self.assertAllClose(x[0:x[1]], res1)

  # Now under jit, should fail because the function is not compileable.
  with self.assertRaisesRegex(
      ValueError,
      "Compiled TensorFlow function has unexpected parameter types"):
    fun_jax = jax.jit(jax2tf.call_tf(fun_tf))
    fun_jax(x)

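# NOTE: `_call_tf_dynamic_shape_error` (used in the tests below) is a
# module-level regexp constant defined elsewhere in this file. It presumably
# matches the error raised when the compiled TF function has a
# dynamically-shaped output; the exact wording is an assumption, e.g.:
#
#   _call_tf_dynamic_shape_error = "Compiled TensorFlow function has dynamic output shape"
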
def test_eval_non_compileable_dynamic_shape(self):
  # Check that in op-by-op we call a function in eager mode.
  def f_tf_non_compileable(x):
    return tf.cond(x[0], lambda: x[1:], lambda: x)

  f_jax = jax2tf.call_tf(f_tf_non_compileable)
  x = np.array([True, False], dtype=np.bool_)
  self.assertAllClose(f_tf_non_compileable(x), f_jax(x))

  with self.assertRaisesRegex(ValueError, _call_tf_dynamic_shape_error):
    jax.jit(f_jax)(x)

def test_with_var_read(self, with_jit=True):
  if jtu.device_under_test() == "gpu":
    raise unittest.SkipTest("Test fails on GPU")
  outer_var = tf.Variable(3., dtype=np.float32)

  def fun_tf(x):
    return x * outer_var + 1.

  x = np.float32(2.)
  res = _maybe_jit(with_jit, jax2tf.call_tf(fun_tf))(x)
  self.assertAllClose(x * 3. + 1., res, check_dtypes=False)

def test_round_trip_saved_model_no_gradients(self):
  # Save without gradients
  f_jax = jnp.sum
  x = np.array([0.7, 0.8], dtype=np.float32)
  f_tf, _ = tf_test_util.SaveAndLoadFunction(
      jax2tf.convert(f_jax, with_gradient=True),
      input_signature=[tf.TensorSpec(x.shape, dtype=x.dtype)],
      save_gradients=False)
  f_rt = jax2tf.call_tf(f_tf)
  self.assertAllClose(f_jax(x), f_rt(x))

def test_with_var_write_error(self, with_jit=True):
  if with_jit:
    raise unittest.SkipTest("variable writes not yet working")
  outer_var = tf.Variable(3., dtype=np.float32)

  def fun_tf(x):
    outer_var.assign(tf.constant(4.))
    return x * outer_var + 1.

  x = np.float32(2.)
  res = _maybe_jit(with_jit, jax2tf.call_tf(fun_tf))(x)
  self.assertAllClose(x * 4. + 1., res, check_dtypes=False)

def test_with_var_read_x64(self, with_jit=True):
  if jtu.device_under_test() == "gpu":
    raise unittest.SkipTest("Test fails on GPU")
  outer_var_array = np.array([3., 4.], dtype=np.float64)
  outer_var = tf.Variable(outer_var_array)

  def fun_tf(x):
    return x * tf.cast(outer_var, x.dtype) + 1.

  x = np.array([2., 5.], dtype=np.float32)
  res = _maybe_jit(with_jit, jax2tf.call_tf(fun_tf))(x)
  self.assertAllClose(x * outer_var_array + 1., res, check_dtypes=False)

def test_shape_polymorphism_error(self):
  x = np.array([.7, .8], dtype=np.float32)

  def fun_tf(x):
    return tf.math.sin(x)

  fun_jax = jax2tf.call_tf(fun_tf)
  fun_tf_rt = jax2tf.convert(fun_jax, polymorphic_shapes=["b, ..."])
  with self.assertRaisesRegex(
      ValueError,
      "call_tf cannot be applied to shape-polymorphic arguments"):
    fun_tf_rt(x)

def test_with_multiple_capture(self, with_jit=True):
  if jtu.device_under_test() == "gpu":
    raise unittest.SkipTest("Test fails on GPU")
  v2 = tf.Variable(2., dtype=np.float32)
  v3 = tf.Variable(3., dtype=np.float32)
  t4 = tf.constant(4., dtype=np.float32)
  t5 = tf.constant(5., dtype=np.float32)

  def fun_tf(x):
    return (x * v2 + t4) * v3 + t5

  x = np.float32(2.)
  res = _maybe_jit(with_jit, jax2tf.call_tf(fun_tf))(x)
  self.assertAllClose((x * 2. + 4.) * 3. + 5., res, check_dtypes=False)

def test_grad_custom(self, with_jit=False):
  @tf.custom_gradient
  def func_square_tf(x):
    # Like x ** 2, but with custom grad 3. * x
    def grad(dy, variables=None):
      # dy, = dys
      return 3. * x * dy,

    return x * x, grad

  x = np.float32(4.)
  grad_x = _maybe_jit(with_jit, jax.grad(jax2tf.call_tf(func_square_tf)))(x)
  self.assertAllClose(np.float32(3.) * x, grad_x)

def test_saved_model_variables(self):
  param = np.array([1., 2.], dtype=np.float32)
  x = np.array([0.7, 0.8], dtype=np.float32)

  def f_jax(param, x):
    return jnp.sin(x) + jnp.cos(param)

  param_v = tf.Variable(param)
  f_tf = jax2tf.convert(f_jax)
  _, restored_model = tf_test_util.SaveAndLoadFunction(
      lambda x: f_tf(param_v, x), input_args=[x], variables=[param_v])
  restored_jax = jax2tf.call_tf(restored_model.f)
  self.assertAllClose(f_jax(param, x), restored_jax(x))
  self.assertAllClose(f_jax(param, x), jax.jit(restored_jax)(x))

def test_without_gradient_saved_model(self):
  # Explicitly with_gradient=False
  f_jax = jnp.sum
  x = np.array([0.7, 0.8], dtype=np.float32)
  f_tf, _ = tf_test_util.SaveAndLoadFunction(
      jax2tf.convert(f_jax, with_gradient=False), input_args=[x])
  f_rt = jax2tf.call_tf(f_tf)

  self.assertAllClose(f_jax(x), f_rt(x))
  with self.assertRaisesRegex(
      Exception,
      "Gradient explicitly disabled.*jax2tf-converted function does not "
      "support gradients. Use `with_gradient` parameter to enable gradients"):
    jax.grad(f_rt)(x)

def test_function_dynamic_shape(self):
  # Call a function for which shape inference does not give an output shape.
  x = np.array([-1, 0, 1], dtype=np.int32)

  def fun_tf(x):  # x: i32[3]
    # The shape depends on the value of x.
    return tf.cond(x[0] >= 0, lambda: x, lambda: x[1:])

  # Call in eager mode. Should work!
  res1 = jax2tf.call_tf(fun_tf)(x)
  expected = x[1:]
  self.assertAllClose(expected, res1, check_dtypes=False)

  # Now under jit, should fail because the function is not compileable.
  with self.assertRaisesRegex(ValueError, _call_tf_dynamic_shape_error):
    fun_jax = jax.jit(jax2tf.call_tf(fun_tf))
    fun_jax(x)

  # TODO(necula): this should work in op-by-op mode, but it fails because
  # jax2tf.convert does abstract evaluation.
  with self.assertRaisesRegex(ValueError, _call_tf_dynamic_shape_error):
    fun_tf_rt = jax2tf.convert(jax2tf.call_tf(fun_tf))
    fun_tf_rt(x)

def test_eval_non_compileable_strings(self):
  # Check that in op-by-op we call a function in eager mode.
  def f_tf_non_compileable(x):
    return tf.strings.length(tf.strings.format("Hello {}!", [x]))

  f_jax = jax2tf.call_tf(f_tf_non_compileable)
  x = np.float32(0.7)
  self.assertAllClose(f_tf_non_compileable(x).numpy(), f_jax(x))

  with self.assertRaisesRegex(ValueError,
                              CallTfTest.call_tf_non_compileable):
    jax.jit(f_jax)(x)

  with self.assertRaisesRegex(ValueError,
                              CallTfTest.call_tf_non_compileable):
    lax.cond(True, lambda x: f_jax(x), lambda x: f_jax(x), x)

def test_with_var_different_shape(self):
  # See https://github.com/google/jax/issues/6050
  if jtu.device_under_test() == "gpu":
    raise unittest.SkipTest("Test fails on GPU")

  v = tf.Variable((4., 2.), dtype=tf.float32)

  def tf_func(x):
    return v + x

  x = np.float32(123.)
  tf_out = tf_func(x)

  jax_func = jax.jit(jax2tf.call_tf(tf_func))
  jax_out = jax_func(x)

  self.assertAllClose(tf_out, jax_out, check_dtypes=False)

def test_grad_int_argument(self):
  # Similar to https://github.com/google/jax/issues/6975
  # `state` is a pytree that contains an integer and a boolean; it is passed
  # through unchanged, so only `param` and `x` contribute to the
  # differentiable output.
  def f(param, state, x):
    return param * x, state

  param = np.array([0.7, 0.9], dtype=np.float32)
  state = dict(array=np.float32(1.), counter=7, truth=True)
  x = np.float32(3.)

  # tf.function is important, without it the bug does not appear
  f_call_tf = jax2tf.call_tf(tf.function(f, autograph=False))
  g_call_tf = jax.grad(lambda *args: jnp.sum(f_call_tf(*args)[0]))(param, state, x)
  g = jax.grad(lambda *args: jnp.sum(f(*args)[0]))(param, state, x)
  self.assertAllClose(g_call_tf, g)

def test_saved_model_no_gradients(self):
  # Save without gradients
  f_jax = jnp.sum
  x = np.array([0.7, 0.8], dtype=np.float32)
  f_tf, _ = tf_test_util.SaveAndLoadFunction(
      jax2tf.convert(f_jax, with_gradient=True), input_args=[x],
      save_gradients=False)
  f_rt = jax2tf.call_tf(f_tf)

  self.assertAllClose(f_jax(x), f_rt(x))
  # TODO: clean this up b/191117111: it should fail with a clear error
  # The following results in a confusing error:
  # TypeError: tf.Graph captured an external symbolic tensor.
  with self.assertRaises(TypeError):
    _ = jax.grad(f_rt)(x)

def test_repro_193754660(self):
  # Try to reproduce b/193754660. I can't.
  # We have to have tf.function(jax2tf.convert(jax2tf.call_tf(f_tf))).
  # The get_compiler_ir will indeed fail for f_tf. Then we try to use
  # shape inference for f_tf.
  # I thought to use a f_tf that uses an op without shape inference, e.g.,
  # tfxla.gather. If we wash it through a saved_model I expect that shape
  # inference would not work on it. Instead, shape inference works!!!
  x = np.array([0, 1, 2, 3, 4, 5], dtype=np.int32)

  def f_jax(x):
    return x[1]

  f_tf = jax2tf.convert(f_jax)
  f_tf_rt, _ = tf_test_util.SaveAndLoadFunction(f_tf, input_args=[x])
  f_jax2 = jax2tf.call_tf(f_tf_rt)
  f_tf2 = jax2tf.convert(f_jax2)
  res = tf.function(f_tf2, autograph=False)(x)
  self.assertAllClose(res.numpy(), f_jax(x))

def test_custom_grad(self):
  @jax.custom_vjp
  def f(x):
    return x * x

  # f_fwd: a -> (b, residual)
  def f_fwd(x):
    return f(x), np.float32(3.) * x

  # f_bwd: (residual, CT b) -> [CT a]
  def f_bwd(residual, ct_b):
    return residual * ct_b,

  f.defvjp(f_fwd, f_bwd)

  f_rt = jax2tf.call_tf(jax2tf.convert(f, with_gradient=True))
  x = np.float32(0.7)
  self.assertAllClose(f(x), f_rt(x))
  self.assertAllClose(jax.grad(f)(x), jax.grad(f_rt)(x))

def test_eval_non_compileable_dynamic_shape_where(self):
  # Check that in op-by-op we call a function in eager mode.
  def f_tf_non_compileable(x):
    return tf.where(x)

  f_jax = jax2tf.call_tf(f_tf_non_compileable)
  x = np.array([True, False], dtype=np.bool_)
  self.assertAllClose(f_tf_non_compileable(x).numpy(), f_jax(x))

  if jtu.device_under_test() == "tpu":
    # TODO: This works on TPU!!!
    self.assertAllClose(
        f_tf_non_compileable(x).numpy(), jax.jit(f_jax)(x))
  else:
    with self.assertRaisesRegex(ValueError,
                                CallTfTest.call_tf_non_compileable):
      jax.jit(f_jax)(x)

def test_several_round_trips(self, f2_function=False, f2_saved_model=False,
                             f4_function=False, f4_saved_model=False):
  x = np.array(.7, dtype=np.float32)

  # f(n)(x) = 2. * x^n
  def f(n):
    def fn(x):
      acc = np.array(2., dtype=x.dtype)
      for _ in range(n):
        acc *= x
      return acc
    return fn

  f2_tf = lambda x: x * jax2tf.convert(f(1))(x)
  if f2_function:
    f2_tf = tf.function(f2_tf, autograph=False)
  if f2_saved_model:
    f2_tf, _ = tf_test_util.SaveAndLoadFunction(f2_tf, input_args=[x])

  self.assertAllClose(f(2)(x), f2_tf(x).numpy())
  _, (g_f2_ft,) = tf_test_util.ComputeTfValueAndGrad(f2_tf, [x])
  self.assertAllClose(jax.grad(f(2))(x), g_f2_ft.numpy())

  f3_jax = lambda x: x * jax2tf.call_tf(f2_tf)(x)
  self.assertAllClose(f(3)(x), f3_jax(x))
  self.assertAllClose(f(3)(x), jax.jit(f3_jax)(x))
  self.assertAllClose(jax.grad(f(3))(x), jax.grad(f3_jax)(x))

  f4_tf = lambda x: x * jax2tf.convert(f3_jax)(x)
  self.assertAllClose(f(4)(x), f4_tf(x).numpy())
  _, (g_f4_ft,) = tf_test_util.ComputeTfValueAndGrad(f4_tf, [x])
  self.assertAllClose(jax.grad(f(4))(x), g_f4_ft.numpy())

  if f4_function:
    f4_tf = tf.function(f4_tf, autograph=False)
  if f4_saved_model:
    f4_tf, _ = tf_test_util.SaveAndLoadFunction(f4_tf, input_args=[x])
  self.assertAllClose(f(4)(x), f4_tf(x).numpy())
  _, (g_f4_ft,) = tf_test_util.ComputeTfValueAndGrad(f4_tf, [x])
  self.assertAllClose(jax.grad(f(4))(x), g_f4_ft.numpy())