def test_several_round_trips(self,
                             f2_function=False, f2_saved_model=False,
                             f4_function=False, f4_saved_model=False):
  x = np.array(.7, dtype=np.float32)
  # f(n)(x) = 2. * x^n
  def f(n):
    def fn(x):
      acc = np.array(2., dtype=x.dtype)
      for _ in range(n):
        acc *= x
      return acc
    return fn

  f2_tf = lambda x: x * jax2tf.convert(f(1))(x)
  if f2_function:
    f2_tf = tf.function(f2_tf, autograph=False)
  if f2_saved_model:
    f2_tf, _ = tf_test_util.SaveAndLoadFunction(f2_tf, input_args=[x])

  self.assertAllClose(f(2)(x), f2_tf(x).numpy())
  _, (g_f2_ft,) = tf_test_util.ComputeTfValueAndGrad(f2_tf, [x])
  self.assertAllClose(jax.grad(f(2))(x), g_f2_ft.numpy())

  f3_jax = lambda x: x * jax2tf.call_tf(f2_tf)(x)
  self.assertAllClose(f(3)(x), f3_jax(x))
  self.assertAllClose(f(3)(x), jax.jit(f3_jax)(x))
  self.assertAllClose(jax.grad(f(3))(x), jax.grad(f3_jax)(x))

  f4_tf = lambda x: x * jax2tf.convert(f3_jax)(x)
  self.assertAllClose(f(4)(x), f4_tf(x).numpy())
  _, (g_f4_ft,) = tf_test_util.ComputeTfValueAndGrad(f4_tf, [x])
  self.assertAllClose(jax.grad(f(4))(x), g_f4_ft.numpy())

  if f4_function:
    f4_tf = tf.function(f4_tf, autograph=False)
  if f4_saved_model:
    f4_tf, _ = tf_test_util.SaveAndLoadFunction(f4_tf, input_args=[x])

  self.assertAllClose(f(4)(x), f4_tf(x).numpy())
  _, (g_f4_ft,) = tf_test_util.ComputeTfValueAndGrad(f4_tf, [x])
  self.assertAllClose(jax.grad(f(4))(x), g_f4_ft.numpy())
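
# A minimal sketch (not part of the original suite; the name
# _sketch_minimal_round_trip is illustrative) of the pattern exercised
# above: jax2tf.convert takes a JAX function to TF, jax2tf.call_tf embeds
# a TF callable back into JAX, and both values and gradients should
# survive the round trip.
def _sketch_minimal_round_trip(self):
  f_jax = lambda x: 2. * x                  # plain JAX function
  f_tf = jax2tf.convert(f_jax)              # JAX -> TF
  f_back = jax2tf.call_tf(f_tf)             # TF -> JAX again
  x = np.float32(3.)
  self.assertAllClose(f_jax(x), f_back(x))  # values agree
  self.assertAllClose(jax.grad(f_jax)(x),   # gradients flow through
                      jax.grad(f_back)(x))  # call_tf via TF's AD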
def test_alternate(self):
  # Alternate sin/cos with sin in TF and cos in JAX
  f_tf_inner = tf.math.sin
  def f_jax(x_jax):
    y_jax = jnp.cos(x_jax)
    z_jax = jax2tf.call_tf(f_tf_inner)(y_jax)
    return jnp.cos(z_jax)
  def f_tf_outer(x_tf):
    y_tf = tf.math.sin(x_tf)
    z_tf = jax2tf.convert(f_jax)(y_tf)
    return tf.math.sin(z_tf)

  x = np.float32(0.7)
  self.assertAllClose(np.sin(np.cos(np.sin(np.cos(np.sin(x))))),
                      f_tf_outer(x).numpy())

  xv = tf.Variable(x)
  with tf.GradientTape() as tape:
    res = f_tf_outer(xv)
  g_tf = tape.gradient(res, xv)
  _, gf = tf_test_util.ComputeTfValueAndGrad(f_tf_outer, (x,))

  # Eager
  expected_res = np.sin(np.cos(np.sin(np.cos(np.sin(x)))))
  self.assertAllClose(expected_res, f_tf_outer(x).numpy())

  # Gradient: the chain rule through sin(cos(sin(cos(sin(x)))))
  expected_grad = (np.cos(np.cos(np.sin(np.cos(np.sin(x))))) *
                   np.sin(np.sin(np.cos(np.sin(x)))) *
                   np.cos(np.cos(np.sin(x))) *
                   np.sin(np.sin(x)) *
                   np.cos(x))
  self.assertAllClose(expected_grad, g_tf.numpy())

  # Graph
  self.assertAllClose(
      expected_res,
      tf.function(f_tf_outer, autograph=False)(x).numpy())

  # Compiled
  self.assertAllClose(
      expected_res,
      tf.function(f_tf_outer, autograph=False, jit_compile=True)(x).numpy())
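
# A minimal sketch (hypothetical; _sketch_alternate_grad is not in the
# original suite) of the key point of test_alternate: a JAX function that
# internally calls TF via call_tf can itself be converted back to TF and
# differentiated with tf.GradientTape, because jax2tf.convert attaches a
# TF custom gradient.
def _sketch_alternate_grad(self):
  f_jax = lambda x: jnp.cos(jax2tf.call_tf(tf.math.sin)(x))  # JAX around TF
  x = np.float32(0.7)
  xv = tf.Variable(x)
  with tf.GradientTape() as tape:
    res = jax2tf.convert(f_jax)(xv)  # TF around JAX around TF
  # d/dx cos(sin(x)) = -sin(sin(x)) * cos(x)
  self.assertAllClose(-np.sin(np.sin(x)) * np.cos(x),
                      tape.gradient(res, xv).numpy())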
def test_gradients_int_argument(self, with_function=True):
  # https://github.com/google/jax/issues/6975
  # An expanded version of test_gradients_unused_argument
  state = dict(
      float_used=np.array([0.7, 0.9], dtype=np.float32),
      float_passthrough=np.float16(1.),
      float_unused=np.array([1.1, 2.2, 3.3], dtype=np.float32),
      int_used=np.int16(5),
      int_passthrough=np.int8(7),
      int_unused=np.array([1, 2, 3], dtype=np.uint32),
      bool_used=np.array([True, False, False, True], dtype=np.bool_),
      bool_passthrough=np.array([True, False, False, True, False],
                                dtype=np.bool_),
      bool_unused=np.array([[True, False], [False, True]], dtype=np.bool_),
  )

  def jax_f(state):
    res = dict(state,
               float_used=2. * state["float_used"],
               int_used=3 * state["int_used"],
               bool_used=(state["bool_used"] == state["bool_used"]))
    del res["float_unused"]
    del res["int_unused"]
    del res["bool_unused"]
    return res

  args = (state,)
  res_jax = jax_f(*args)

  # Native JAX AD
  vjp_jax_fun, args_vjp = tf_test_util.TransformJaxVJP(jax_f, args, res_jax)
  grad_jax, = vjp_jax_fun(*args_vjp)

  def compare_with_overrides(*, what, expected, **expected_overrides):
    what_keys = set(what.keys())
    expected_keys = set(expected.keys())
    self.assertEqual(what_keys, expected_keys)
    for k, w in what.items():
      e = expected[k]
      if k in expected_overrides:
        if expected_overrides[k] == "ZERO":
          e = np.zeros_like(w)
        elif expected_overrides[k] == "ZERO_INT32":
          e = np.zeros(np.shape(w), dtype=np.int32)
        elif expected_overrides[k] == "ONE":
          e = np.ones_like(w)
        else:
          e = expected_overrides[k]

      if e is None:
        self.assertIsNone(w, msg=k)
      else:
        self.assertIsNotNone(w, msg=k)
        w = w.numpy() if isinstance(w, tf.Tensor) else w
        e = e.numpy() if isinstance(e, tf.Tensor) else e
        try:
          self.assertAllClose(e, w, err_msg=k)
        except:
          print(f"Failed at {k}")
          raise

  # Expected native JAX gradients (gradients for int/bool arguments have
  # the special dtype float0):
  # compare_with_overrides(
  #     what=grad_jax, expected={},
  #     bool_passthrough=np.zeros(state["bool_passthrough"].shape,
  #                               dtype=dtypes.float0),
  #     bool_unused=np.zeros(state["bool_unused"].shape, dtype=dtypes.float0),
  #     bool_used=np.zeros(state["bool_used"].shape, dtype=dtypes.float0),
  #     float_passthrough=np.ones_like(state["float_passthrough"]),
  #     float_unused=np.zeros_like(state["float_unused"]),
  #     float_used=np.ones_like(state["float_used"]) * np.array(
  #         2., dtype=state["float_used"].dtype),
  #     int_passthrough=np.zeros(state["int_passthrough"].shape,
  #                              dtype=dtypes.float0),
  #     int_unused=np.zeros(state["int_unused"].shape, dtype=dtypes.float0),
  #     int_used=np.zeros(state["int_used"].shape, dtype=dtypes.float0))

  # Now native TF gradients, only to test how native TF AD works
  _, (grad_tf_0,) = tf_test_util.ComputeTfValueAndGrad(
      jax_f, args, unconnected_gradients=tf.UnconnectedGradients.ZERO)
  compare_with_overrides(
      what=grad_tf_0, expected=grad_jax,
      float_unused="ZERO",
      bool_used="ZERO", bool_passthrough="ONE", bool_unused="ZERO",
      int_used="ZERO", int_passthrough="ONE", int_unused="ZERO")

  _, (grad_tf_None,) = tf_test_util.ComputeTfValueAndGrad(
      jax_f, args, unconnected_gradients=tf.UnconnectedGradients.NONE)
  compare_with_overrides(
      what=grad_tf_None, expected=grad_tf_0,
      float_unused=None,
      int_used=None, int_unused=None,
      bool_used=None, bool_unused=None)

  f_tf_jax = jax2tf.convert(jax_f)
  if with_function:
    f_tf_jax = tf.function(f_tf_jax, autograph=False)
  _, (grad_tf_jax_0,) = tf_test_util.ComputeTfValueAndGrad(f_tf_jax, args)
  # Same results as TF native AD with tf.UnconnectedGradients.ZERO
  compare_with_overrides(
      what=grad_tf_jax_0, expected=grad_tf_0,
      int_passthrough="ZERO", bool_passthrough="ZERO")

  _, (grad_tf_jax_None,) = tf_test_util.ComputeTfValueAndGrad(
      f_tf_jax, args, unconnected_gradients=tf.UnconnectedGradients.NONE)
  compare_with_overrides(
      what=grad_tf_jax_None, expected=grad_tf_0,
      int_used=None, int_passthrough=None, int_unused=None,
      bool_unused=None, bool_used=None, bool_passthrough=None)

  # Now convert the JAX gradient function itself
  tf_vjp_jax_fun = jax2tf.convert(vjp_jax_fun)
  grad_tf_vjp_jax, = tf_vjp_jax_fun(*args_vjp)
  compare_with_overrides(
      what=grad_tf_vjp_jax, expected=grad_tf_0,
      bool_passthrough="ZERO_INT32", bool_unused="ZERO_INT32",
      bool_used="ZERO_INT32",
      int_passthrough="ZERO_INT32", int_unused="ZERO_INT32",
      int_used="ZERO_INT32")
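
# A minimal sketch (hypothetical; _sketch_int_arg_gradient is not in the
# original suite, and uses only native TF) of the convention probed above:
# gradients with respect to integer arguments do not exist, so TF reports
# them as None by default and as zeros of the argument's dtype with
# tf.UnconnectedGradients.ZERO.
def _sketch_int_arg_gradient(self):
  def f(x_float, x_int):
    # The cast from an integer type is non-differentiable, so x_int
    # contributes no gradient.
    return 2. * x_float + tf.cast(x_int, tf.float32)
  args = (np.float32(3.), np.int32(5))
  _, (g_f, g_i) = tf_test_util.ComputeTfValueAndGrad(
      f, args, unconnected_gradients=tf.UnconnectedGradients.ZERO)
  self.assertAllClose(np.float32(2.), g_f.numpy())
  self.assertAllClose(np.int32(0), g_i.numpy())  # int grads become zeros
  _, (_, g_i_none) = tf_test_util.ComputeTfValueAndGrad(
      f, args, unconnected_gradients=tf.UnconnectedGradients.NONE)
  self.assertIsNone(g_i_none)                    # ... or None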