def call_tf(func_tf: Callable) -> Callable:
  """Calls a TensorFlow function from JAX, with support for reverse autodiff.

  The ``func_tf`` will be called with TensorFlow-compatible arguments
  (numpy.ndarray, ``tf.Tensor`` or ``tf.Variable``) or pytrees thereof. The
  function must return the same type of results.

  If ``call_tf`` appears in a JAX staging context (:func:`jax.jit`,
  or :func:`jax.pmap`, or :func:`jax.xmap`, or a control-flow primitive) then
  ``func_tf`` will be compiled with ``tf.function(func_tf, jit_compile=True)``
  and the resulting XLA computation will be embedded in JAX's XLA computation.

  If ``call_tf`` appears outside a JAX staging context, it will be called
  inline using TensorFlow eager mode.

  The ``call_tf`` supports JAX's reverse-mode autodiff, in which case the
  ``func_tf`` will be differentiated using ``tf.GradientTape``. This means
  that the gradient will be TensorFlow-accurate, e.g., will respect the
  custom gradients that may be defined for the code in ``func_tf``.

  For an example and more details see the
  `README <https://github.com/google/jax/blob/master/jax/experimental/jax2tf/README.md#calling-tensorflow-functions-from-jax>`_.

  Args:
    func_tf: a TensorFlow Callable that can take a pytree of TensorFlow
      arguments.

  Returns: a JAX callable that can be invoked with JAX pytree arguments, in
    op-by-op mode or in a staged context. This callable can be used with
    JAX's reverse-mode autodiff (:func:`jax.grad`).
  """
  @jax.custom_vjp
  def make_call(*args_jax):
    """We wrap it all in `make_call` so that we can attach a custom VJP."""
    def _dtype(x):
      return (getattr(x, "dtype", None) or np.asarray(x).dtype)

    args_jax_flat, args_jax_treedef = tree_util.tree_flatten(args_jax)
    args_tf_sig_flat = [
        tf.TensorSpec(np.shape(a_jax), _to_tf_dtype(_dtype(a_jax)))
        for a_jax in args_jax_flat
    ]
    args_tf_sig = tf.nest.map_structure(
        lambda a_jax: tf.TensorSpec(np.shape(a_jax),
                                    _to_tf_dtype(_dtype(a_jax))),
        args_jax)
    func_tf_concrete = tf.function(func_tf).get_concrete_function(*args_tf_sig)
    res_tf_sig_flat, res_treedef = tree_util.tree_flatten(
        func_tf_concrete.structured_outputs)

    res_jax_flat = call_tf_p.bind(*args_jax_flat,
                                  func_tf=func_tf,
                                  args_treedef=args_jax_treedef,
                                  args_tf_sig_flat=args_tf_sig_flat,
                                  res_treedef=res_treedef,
                                  res_tf_sig_flat=res_tf_sig_flat)
    # TODO(necula): check the expected result signature
    assert len(res_jax_flat) == len(res_tf_sig_flat)
    return res_treedef.unflatten(res_jax_flat)

  # Define the fwd and bwd custom_vjp functions
  def make_call_vjp_fwd(*args_jax):
    # Return the primal arguments as the residual
    return make_call(*args_jax), args_jax

  def make_call_vjp_bwd(residual, ct_res):
    args_jax = residual  # residual is the primal argument

    def tf_vjp_fun(args, ct_res):
      """Invoke TF gradient."""
      with tf.GradientTape(persistent=True) as tape:
        tape.watch(args)
        res = func_tf(*args)

      tf.nest.assert_same_structure(res, ct_res)
      # If the result is not a scalar, we must accumulate the argument
      # cotangents over all result leaves.
      accumulator = None  # Same structure as "args"

      def acc_ct(res_, ct_res_):
        dres_darg = tape.gradient(
            res_, sources=args,
            unconnected_gradients=tf.UnconnectedGradients.ZERO)
        tf.nest.assert_same_structure(dres_darg, args)
        scaled_dres_darg = tf.nest.map_structure(lambda d: d * ct_res_,
                                                 dres_darg)
        nonlocal accumulator
        accumulator = (scaled_dres_darg
                       if accumulator is None
                       else tf.nest.map_structure(lambda x, y: x + y,
                                                  accumulator,
                                                  scaled_dres_darg))

      tf.nest.map_structure(acc_ct, res, ct_res)
      return accumulator

    # Use call_tf to call the VJP function
    return call_tf(tf_vjp_fun)(args_jax, ct_res)

  make_call.defvjp(make_call_vjp_fwd, make_call_vjp_bwd)
  return util.wraps(func_tf)(make_call)
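# Illustrative usage sketch (not part of the original module): a minimal
# example of how `call_tf` is meant to be used, both op-by-op and under
# reverse-mode autodiff. The helper name `_example_call_tf_grad` and the use
# of the public `jax.experimental.jax2tf.call_tf` entry point are assumptions
# of this sketch.
def _example_call_tf_grad():
  import jax
  import jax.numpy as jnp
  import tensorflow as tf
  from jax.experimental import jax2tf

  # Wrap a TF computation so that it accepts and returns JAX values.
  f_jax = jax2tf.call_tf(lambda x: tf.math.sin(tf.math.cos(x)))

  x = jnp.float32(0.7)
  y = f_jax(x)                 # outside jit: runs the TF function eagerly
  dy_dx = jax.grad(f_jax)(x)   # reverse-mode AD, computed via tf.GradientTape
  return y, dy_dx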
def call_tf(callable_tf: Callable) -> Callable:
  """Calls a TensorFlow function from JAX, with support for reverse autodiff.

  The ``callable_tf`` will be called with TensorFlow-compatible arguments
  (numpy.ndarray, ``tf.Tensor`` or ``tf.Variable``) or pytrees thereof. The
  function must return the same type of results.

  If ``call_tf`` appears in a JAX staging context (:func:`jax.jit`,
  or :func:`jax.pmap`, or :func:`jax.xmap`, or a control-flow primitive) then
  ``callable_tf`` will be compiled with
  ``tf.function(callable_tf, jit_compile=True)`` and the resulting XLA
  computation will be embedded in JAX's XLA computation.

  If ``call_tf`` appears outside a JAX staging context, it will be called
  inline using TensorFlow eager mode.

  The ``call_tf`` supports JAX's reverse-mode autodiff, in which case the
  ``callable_tf`` will be differentiated using ``tf.GradientTape``. This means
  that the gradient will be TensorFlow-accurate, e.g., will respect the
  custom gradients that may be defined for the code in ``callable_tf``.

  For an example and more details see the
  `README <https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#calling-tensorflow-functions-from-jax>`_.

  Args:
    callable_tf: a TensorFlow Callable that can take a pytree of TensorFlow
      arguments.

  Returns: a JAX callable that can be invoked with JAX pytree arguments, in
    op-by-op mode or in a staged context. This callable can be used with
    JAX's reverse-mode autodiff (:func:`jax.grad`).
  """
  @jax.custom_vjp
  def make_call(*args_jax):
    """We wrap it all in `make_call` so that we can attach a custom VJP."""
    args_flat_jax, args_treedef = tree_util.tree_flatten(args_jax)

    # Canonicalize the arguments; e.g., makes them x32 if JAX is in 32-bit mode
    def canonical_arg(v):
      v = v if getattr(v, "dtype", None) else np.asarray(v)
      dtype = dtypes.canonicalize_dtype(v.dtype)
      if dtype != v.dtype:
        v = v.astype(dtype)
      return v

    args_flat_jax = tuple(map(canonical_arg, args_flat_jax))

    def make_tensorspec(a_jax):
      a_tf_dtype = jax2tf_internal._to_tf_dtype(a_jax.dtype)
      if any(not core.is_constant_dim(d) for d in a_jax.shape):
        msg = ("call_tf cannot be applied to shape-polymorphic arguments. "
               f"Found argument shape: {a_jax.shape}. "
               "See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#limitations-of-call-tf for a discussion.")
        raise ValueError(msg)
      return tf.TensorSpec(a_jax.shape, a_tf_dtype)
    args_flat_sig_tf = tuple(map(make_tensorspec, args_flat_jax))

    res_treedef = None  # We'll store here the result treedef

    # The function below will be called at least once, either in eager
    # or in graph mode.
    def callable_flat_tf(*args_tf_flat: TfVal) -> Sequence[TfVal]:
      args_tf = args_treedef.unflatten(args_tf_flat)
      res_tf = callable_tf(*args_tf)
      nonlocal res_treedef
      res_tf_flat, res_treedef_now = tree_util.tree_flatten(res_tf)
      assert res_treedef is None or res_treedef == res_treedef_now, (
          f"Subsequent calls had different results. "
          f"Previous: {res_treedef}; now: {res_treedef_now}")
      res_treedef = res_treedef_now
      return res_tf_flat

    # Prepare a tf.function ahead of time, to cache the concrete functions.
    # This won't be used in op-by-op execution mode.
    function_flat_tf = tf.function(callable_flat_tf, autograph=False,
                                   jit_compile=True)

    res_jax_flat = call_tf_p.bind(
        *args_flat_jax,
        # Carry the actual function so that op-by-op execution can call it in
        # TF eager mode.
        callable_flat_tf=callable_flat_tf,
        function_flat_tf=function_flat_tf,
        args_flat_sig_tf=args_flat_sig_tf)
    return res_treedef.unflatten(res_jax_flat)

  # Define the fwd and bwd custom_vjp functions
  def make_call_vjp_fwd(*args_jax):
    # Return the primal arguments as the residual
    return make_call(*args_jax), args_jax

  def make_call_vjp_bwd(residual_jax, ct_res_jax):
    args_jax = residual_jax  # residual is the primal argument

    def tf_vjp_fun(args_tf, ct_res_tf):
      """Invoke TF gradient."""

      # TF does not like us to watch non-float vars
      def replace_non_float(arg):
        if np.issubdtype(arg.dtype.as_numpy_dtype, np.inexact):
          return arg
        else:
          # When watched, this will be ignored. When used in results it will
          # produce a floating 0. gradient, which JAX will ignore (and
          # replace with a float0)
          return tf.zeros((), dtype=tf.float32)

      watched_args_tf = tf.nest.map_structure(replace_non_float, args_tf)
      with tf.GradientTape(persistent=True) as tape:
        tape.watch(watched_args_tf)
        res = callable_tf(*args_tf)

      tf.nest.assert_same_structure(res, ct_res_tf)
      dres_darg = tape.gradient(
          tf.nest.map_structure(replace_non_float, res),
          sources=watched_args_tf,
          output_gradients=ct_res_tf,
          unconnected_gradients=tf.UnconnectedGradients.ZERO)
      tf.nest.assert_same_structure(dres_darg, args_tf)
      return dres_darg

    # Use call_tf to call the VJP function
    ct_args_jax = call_tf(tf_vjp_fun)(args_jax, ct_res_jax)

    # We must make the float0s that JAX expects
    def fix_float0(arg_jax, ct_arg_jax):
      arg_dtype = dtypes.result_type(arg_jax)  # May be scalar
      ct_arg_dtype = core.primal_dtype_to_tangent_dtype(arg_dtype)
      if ct_arg_dtype != ct_arg_jax.dtype:
        return ad_util.zeros_like_aval(core.ShapedArray(np.shape(arg_jax),
                                                        ct_arg_dtype))
      return ct_arg_jax

    ct_args_jax_fixed = tree_util.tree_map(fix_float0, args_jax, ct_args_jax)
    return ct_args_jax_fixed

  make_call.defvjp(make_call_vjp_fwd, make_call_vjp_bwd)
  return util.wraps(callable_tf)(make_call)
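# Illustrative usage sketch (not part of the original module): under
# `jax.jit` the wrapped callable is compiled with
# `tf.function(..., jit_compile=True)` and its XLA computation is embedded in
# JAX's computation; pytree (here dict) arguments and results are supported.
# The helper name `_example_call_tf_in_jit` and the use of the public
# `jax.experimental.jax2tf.call_tf` entry point are assumptions of this sketch.
def _example_call_tf_in_jit():
  import jax
  import jax.numpy as jnp
  import tensorflow as tf
  from jax.experimental import jax2tf

  # A TF function taking a pytree (dict) of arguments and returning a pytree.
  def tf_fn(params):
    return {"out": tf.linalg.matvec(params["w"], params["x"]) + params["b"]}

  call = jax2tf.call_tf(tf_fn)

  @jax.jit
  def step(params):
    # Staged context: tf_fn is XLA-compiled and inlined into this computation.
    return call(params)["out"].sum()

  params = {"w": jnp.ones((3, 2), jnp.float32),
            "x": jnp.ones((2,), jnp.float32),
            "b": jnp.zeros((3,), jnp.float32)}
  return step(params)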
def call_tf(func_tf: Callable) -> Callable:
  """Calls a TensorFlow function from JAX, with support for reverse autodiff.

  The ``func_tf`` will be called with TensorFlow-compatible arguments
  (numpy.ndarray, ``tf.Tensor`` or ``tf.Variable``) or pytrees thereof. The
  function must return the same type of results.

  If ``call_tf`` appears in a JAX staging context (:func:`jax.jit`,
  or :func:`jax.pmap`, or :func:`jax.xmap`, or a control-flow primitive) then
  ``func_tf`` will be compiled with ``tf.function(func_tf, jit_compile=True)``
  and the resulting XLA computation will be embedded in JAX's XLA computation.

  If ``call_tf`` appears outside a JAX staging context, it will be called
  inline using TensorFlow eager mode.

  The ``call_tf`` supports JAX's reverse-mode autodiff, in which case the
  ``func_tf`` will be differentiated using ``tf.GradientTape``. This means
  that the gradient will be TensorFlow-accurate, e.g., will respect the
  custom gradients that may be defined for the code in ``func_tf``.

  For an example and more details see the
  `README <https://github.com/google/jax/blob/master/jax/experimental/jax2tf/README.md#calling-tensorflow-functions-from-jax>`_.

  Args:
    func_tf: a TensorFlow Callable that can take a pytree of TensorFlow
      arguments.

  Returns: a JAX callable that can be invoked with JAX pytree arguments, in
    op-by-op mode or in a staged context. This callable can be used with
    JAX's reverse-mode autodiff (:func:`jax.grad`).
  """
  @jax.custom_vjp
  def make_call(*args_jax):
    """We wrap it all in `make_call` so that we can attach a custom VJP."""
    args_jax_flat, args_jax_treedef = tree_util.tree_flatten(args_jax)

    # Canonicalize the arguments; e.g., makes them x32 if JAX is in 32-bit mode
    def canonical_arg(v):
      v = v if getattr(v, "dtype", None) else np.asarray(v)
      dtype = dtypes.canonicalize_dtype(v.dtype)
      if dtype != v.dtype:
        v = v.astype(dtype)
      return v

    args_jax_flat = tuple(map(canonical_arg, args_jax_flat))
    args_tf_sig_flat = [
        tf.TensorSpec(a_jax.shape, jax2tf_internal._to_tf_dtype(a_jax.dtype))
        for a_jax in args_jax_flat
    ]
    args_tf_sig = args_jax_treedef.unflatten(args_tf_sig_flat)

    # Trace once through the function to get the result shape
    with jax2tf_internal.inside_call_tf():
      func_tf_concrete = tf.function(func_tf).get_concrete_function(
          *args_tf_sig)
    res_tf_sig_flat, res_treedef = tree_util.tree_flatten(
        func_tf_concrete.structured_outputs)

    # Canonicalize the result signature; e.g., makes the results x32 if JAX
    # is in 32-bit mode
    def res_sig_to_aval(res_sig: tf.TensorSpec) -> core.AbstractValue:
      return core.ShapedArray(res_sig.shape,
                              jax2tf_internal._to_jax_dtype(res_sig.dtype))
    out_avals = tuple(map(res_sig_to_aval, res_tf_sig_flat))

    res_jax_flat = call_tf_p.bind(
        *args_jax_flat,
        # Carry the actual function so that op-by-op execution can call it in
        # TF eager mode.
        func_tf=func_tf,
        func_tf_concrete=func_tf_concrete,
        args_treedef=args_jax_treedef,
        args_tf_sig_flat=args_tf_sig_flat,
        res_treedef=res_treedef,
        out_avals=out_avals)
    # TODO(necula): check the expected result signature
    assert len(res_jax_flat) == len(out_avals)
    return res_treedef.unflatten(res_jax_flat)

  # Define the fwd and bwd custom_vjp functions
  def make_call_vjp_fwd(*args_jax):
    # Return the primal arguments as the residual
    return make_call(*args_jax), args_jax

  def make_call_vjp_bwd(residual_jax, ct_res_jax):
    args_jax = residual_jax  # residual is the primal argument

    def tf_vjp_fun(args_tf, ct_res_tf):
      """Invoke TF gradient."""

      # TF does not like us to watch non-float vars
      def replace_non_float(arg):
        if np.issubdtype(arg.dtype.as_numpy_dtype, np.inexact):
          return arg
        else:
          # When watched, this will be ignored. When used in results it will
          # produce a floating 0. gradient, which JAX will ignore (and
          # replace with a float0)
          return tf.zeros((), dtype=tf.float32)

      watched_args_tf = tf.nest.map_structure(replace_non_float, args_tf)
      with tf.GradientTape(persistent=True) as tape:
        tape.watch(watched_args_tf)
        res = func_tf(*args_tf)

      tf.nest.assert_same_structure(res, ct_res_tf)
      dres_darg = tape.gradient(
          tf.nest.map_structure(replace_non_float, res),
          sources=watched_args_tf,
          output_gradients=ct_res_tf,
          unconnected_gradients=tf.UnconnectedGradients.ZERO)
      tf.nest.assert_same_structure(dres_darg, args_tf)
      return dres_darg

    # Use call_tf to call the VJP function
    return call_tf(tf_vjp_fun)(args_jax, ct_res_jax)

  make_call.defvjp(make_call_vjp_fwd, make_call_vjp_bwd)
  return util.wraps(func_tf)(make_call)
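# Illustrative usage sketch (not part of the original module): because the
# backward pass runs through `tf.GradientTape`, a `tf.custom_gradient` defined
# inside the wrapped function is respected by `jax.grad`. The helper name
# `_example_call_tf_custom_gradient` and the use of the public
# `jax.experimental.jax2tf.call_tf` entry point are assumptions of this sketch.
def _example_call_tf_custom_gradient():
  import jax
  import jax.numpy as jnp
  import tensorflow as tf
  from jax.experimental import jax2tf

  @tf.custom_gradient
  def tf_square_custom_grad(x):
    def grad_fn(dy):
      # Deliberately not the true derivative (2 * x), to make it visible that
      # the TF-defined gradient is the one that is used.
      return dy * 3.0 * x
    return x * x, grad_fn

  f_jax = jax2tf.call_tf(tf_square_custom_grad)
  x = jnp.float32(2.0)
  # Expected to follow the custom 3 * x rule (6.0), not the analytic 4.0.
  return jax.grad(f_jax)(x)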