def convert(fun):
  """Wraps a JAX function so it can be executed by TensorFlow.

  Args:
    fun: Function to be transformed. Its arguments and return value should be
      JAX arrays, or (nested) standard Python containers (tuple/list/dict)
      thereof.

  Returns:
    A version of `fun` that expects TfVals as arguments (or
    tuple/lists/dicts) thereof, and returns TfVals as outputs.
  """
  api._check_callable(fun)

  @disable_gradient
  def wrapped_fun(*args: TfValOrUnit) -> TfValOrUnit:
    # TODO(necula): remove the jit disabling once we handle all control-flow.
    # Disabling the jit helps to avoid some unsupported jax primitives.
    # E.g. scan will be statically unrolled.
    lu_fun = lu.wrap_init(fun)
    args_flat, in_tree = tree_util.tree_flatten((args, {}))
    # Validate every flat argument up front; reject anything that is not a
    # TfVal (or unit) with a clear error.
    for a in args_flat:
      if _is_tfvalorunit(a):
        continue
      raise TypeError(
          f"Argument {a} of type {type(a)} of jax2tf.convert(f) should "
          "be NumPy array, scalar, tf.Variable, or tf.Tensor")
    flat_fun, out_tree_thunk = flatten_fun(lu_fun, in_tree)
    flat_outputs = _interpret_fun(flat_fun, args_flat)
    return tree_util.tree_unflatten(out_tree_thunk(), flat_outputs)

  return wrapped_fun
def id_tap(func: Callable, arg, *, result=None, **kwargs):
  """Host-callback tap primitive, like identity function with a call to ``func``.

  **Experimental: please give feedback, and expect changes!**

  ``id_tap`` behaves semantically like the identity function but has the
  side-effect that a user-defined Python function is called with the runtime
  values of the argument.

  Args:
    func: the tap function, called on the host with the runtime value of
      ``arg`` (plus any ``kwargs``).
    arg: the argument passed to the tap function, can be a pytree of JAX
      types.
    result: if given, specifies the return value of ``id_tap``. By default,
      the return type is ``arg``.
    kwargs: will be passed directly to the tap function. Can be anything that
      is hashable, these are kept in the host Python process until outfeeds
      are received.

  Returns:
    ``arg``, or ``result`` if given.

  Tapping works even for code executed on accelerators and even for code
  under JAX transformations. Code that uses taps must be run embedded in
  :func:`outfeed_receiver`.

  For more details see the `module documentation
  <https://jax.readthedocs.io/en/latest/jax.experimental.host_callback.html>`_.

  Raises:
    NotImplementedError: if no outfeed mechanism is available (jaxlib too old).
  """
  if _OUTFEED_MECHANISM == "none":
    raise NotImplementedError(
        "id_tap works only with jaxlib 0.1.47 and higher")
  # The internal sentinel consumers are not user callables; skip the check.
  if func not in (_end_consumer, _unknown_testing_consumer):
    api._check_callable(func)
  flat_args, arg_treedef = pytree.flatten(arg)
  # Use a distinct loop variable: the original reused `arg`, clobbering the
  # parameter.
  for flat_arg in flat_args:
    api._check_arg(flat_arg)
  params = dict(kwargs)  # we pass a copy of params to the primitive
  # See definition of id_tap_p for what parameters it takes
  params["func"] = func
  params["arg_treedef"] = arg_treedef
  if result is not None:
    flat_results, result_treedef = pytree.flatten(result)
    # Likewise do not shadow `result`: the original rebound it here and then
    # tested `result is not None` below on the loop variable.
    for flat_result in flat_results:
      api._check_arg(flat_result)
    all_args = flat_args + flat_results
    params["nr_untapped"] = len(flat_results)
  else:
    all_args = flat_args
  flat_outs = id_tap_p.bind(*all_args, **params)  # Returns all_args
  if result is not None:
    # Slice from len(flat_args) rather than using -nr_untapped: when
    # `result` flattens to zero leaves, `flat_outs[-0:]` would wrongly
    # return ALL outputs instead of an empty list.
    return result_treedef.unflatten(flat_outs[len(flat_args):])
  else:
    return arg_treedef.unflatten(flat_outs)
def id_tap(tap_func: Callable, arg, *, result=None, **kwargs):
  """Host-callback tap primitive, like identity function with a call to ``tap_func``.

  **Experimental: please give feedback, and expect changes!**

  ``id_tap`` behaves semantically like the identity function but has the
  side-effect that a user-defined Python function is called with the runtime
  values of the argument.

  Args:
    tap_func: the tap function to call. Must have a signature of the form
      ``tap_func(arg, *, transforms=None, **kwargs)`` where ``arg`` and
      ``kwargs`` are as described below and ``transforms`` is an optional
      sequence describing the applied JAX transformations.
    arg: the argument passed to the tap function, can be a pytree of JAX
      types.
    result: if given, specifies the return value of ``id_tap``. This value is
      not passed to the tap function, and in fact is not sent from the device
      to the host. If the ``result`` parameter is not specified then the
      return value of ``id_tap`` is ``arg``.
    kwargs: will be passed directly to the tap function. Can be anything that
      is hashable, these are kept in the host Python process until outfeeds
      are received.

  Returns:
    ``arg``, or ``result`` if given.

  The order of execution is by data dependency: after all the arguments and
  the value of ``result`` if present, are computed and before the returned
  value is used. At least one of the returned values of ``id_tap`` must be
  used in the rest of the computation, or else this operation has no effect.

  If you want to tap a constant value, you should use the ``result`` parameter
  to control when it is tapped, otherwise it will be tapped during tracing of
  the function::

    x = id_tap(42, result=x)

  Tapping works even for code executed on accelerators and even for code
  under JAX transformations. Code that uses taps must be run embedded in
  :func:`outfeed_receiver`.

  For more details see the `module documentation
  <https://jax.readthedocs.io/en/latest/jax.experimental.host_callback.html>`_.
  """
  _initialize_outfeed_receiver()  # Lazy initialization
  api._check_callable(tap_func)
  flat_args, arg_treedef = pytree.flatten(arg)
  # Use a distinct loop variable: the original reused `arg`, clobbering the
  # parameter.
  for flat_arg in flat_args:
    api._check_arg(flat_arg)
  params = dict(kwargs)  # we pass a copy of params to the primitive
  # See definition of id_tap_p for what parameters it takes
  params["tap_func_"] = tap_func
  params["arg_treedef_"] = arg_treedef
  params["nr_tapped_args_"] = len(flat_args)
  if result is not None:
    flat_results, result_treedef = pytree.flatten(result)
    # Likewise do not shadow `result`: the original rebound it here and then
    # tested `result is not None` below on the loop variable.
    for flat_result in flat_results:
      api._check_arg(flat_result)
    all_args = flat_args + flat_results
  else:
    all_args = flat_args
  flat_outs = id_tap_p.bind(*all_args, **params)  # Returns all_args
  if result is not None:
    # Slice from len(flat_args) rather than using -nr_results: when `result`
    # flattens to zero leaves, `flat_outs[-0:]` would wrongly return ALL
    # outputs instead of an empty list.
    return result_treedef.unflatten(flat_outs[len(flat_args):])
  else:
    return arg_treedef.unflatten(flat_outs)
def id_tap(tap_func, arg, *, result=None, **kwargs):
  """Host-callback tap primitive, like identity function with a call to ``tap_func``.

  **Experimental: please give feedback, and expect changes!**

  ``id_tap`` behaves semantically like the identity function but has the
  side-effect that a user-defined Python function is called with the runtime
  value of the argument.

  Args:
    tap_func: tap function to call like ``tap_func(arg, transforms)``, with
      ``arg`` as described below and where ``transforms`` is the sequence of
      applied JAX transformations in the form ``(name, params)``.
    arg: the argument passed to the tap function, can be a pytree of JAX
      types.
    result: if given, specifies the return value of ``id_tap``. This value is
      not passed to the tap function, and in fact is not sent from the device
      to the host. If the ``result`` parameter is not specified then the
      return value of ``id_tap`` is ``arg``.

  Returns:
    ``arg``, or ``result`` if given.

  The order of execution is by data dependency: after all the arguments and
  the value of ``result`` if present, are computed and before the returned
  value is used. At least one of the returned values of ``id_tap`` must be
  used in the rest of the computation, or else this operation has no effect.

  If you want to tap a constant value, you should use the ``result`` parameter
  to control when it is tapped, otherwise it will be tapped during tracing of
  the function::

    x = id_tap(42, result=x)

  Tapping works even for code executed on accelerators and even for code
  under JAX transformations. Code that uses taps must be run embedded in
  :func:`outfeed_receiver`.

  For more details see the `module documentation
  <https://jax.readthedocs.io/en/latest/jax.experimental.host_callback.html>`_.
  """
  if kwargs:
    # **kwargs support is deprecated; fold them into the tap function now.
    warnings.warn(
        "Support for **kwargs in ``id_tap`` is deprecated and will be removed "
        "in the future. Instead, pre-apply keyword arguments, either by using "
        "a closure or by passing ``functools.partial(tap_func, **kwargs)`` "
        "instead.",
        FutureWarning, stacklevel=2)
    tap_func = functools.partial(tap_func, **kwargs)
  _initialize_outfeed_receiver()  # Lazy initialization
  api._check_callable(tap_func)
  flat_args, arg_treedef = pytree.flatten(arg)
  # Use a distinct loop variable: the original reused `arg`, clobbering the
  # parameter.
  for flat_arg in flat_args:
    api._check_arg(flat_arg)
  # See definition of id_tap_p for what parameters it takes
  params = {}
  params["tap_func_"] = tap_func
  params["arg_treedef_"] = arg_treedef
  params["nr_tapped_args_"] = len(flat_args)
  if result is not None:
    flat_results, result_treedef = pytree.flatten(result)
    # Likewise do not shadow `result`; the original rebound it in this loop.
    for flat_result in flat_results:
      api._check_arg(flat_result)
    all_args = flat_args + flat_results
    flat_outs = id_tap_p.bind(*all_args, **params)  # Returns all_args
    # Slice from len(flat_args) rather than using -nr_results: when `result`
    # flattens to zero leaves, `flat_outs[-0:]` would wrongly return ALL
    # outputs instead of an empty list.
    return result_treedef.unflatten(flat_outs[len(flat_args):])
  else:
    flat_outs = id_tap_p.bind(*flat_args, **params)
    return arg_treedef.unflatten(flat_outs)
def convert(fun, with_gradient=False):
  """Transforms `fun` to be executed by TensorFlow.

  Args:
    fun: Function to be transformed. Its arguments and return value should be
      JAX arrays, or (nested) standard Python containers (tuple/list/dict)
      thereof.
    with_gradient: if set, will add a tf.custom_gradient to the converted
      function, by converting the ``jax.vjp(fun)``. Only first-order
      differentiation is supported for now. If the converted function is
      saved in a SavedModel, the custom gradients are currently lost and
      an error will be raised if a gradient computation is attempted.

  Returns:
    A version of `fun` that expects TfVals as arguments (or
    tuple/lists/dicts) thereof, and returns TfVals as outputs.
  """
  api._check_callable(fun)

  def converted_fun(*args: TfVal) -> TfVal:
    # This function may take pytrees of TfVals. We can only set
    # tf.custom_gradient on functions that take a flat argument list.
    args_flat, in_tree = tree_util.tree_flatten((args, {}))
    # Reject any flat argument that is not a TfVal (or unit).
    for a in args_flat:
      if not _is_tfvalorunit(a):
        msg = (f"Argument {a} of type {type(a)} of jax2tf.convert(f) should "
               "be NumPy array, scalar, tf.Variable, or tf.Tensor")
        raise TypeError(msg)
    f = lu.wrap_init(fun)
    # out_tree_thunk() will be the output tree, after running _interpret_fun.
    flat_fun, out_tree_thunk = flatten_fun(f, in_tree)

    # Prepare the grad_fn for tf.custom_gradient.
    def converted_grad_fn(*out_cts_flat: TfVal, **kwargs):
      # TODO(cl/318778369): change **kwargs with variables=None
      variables = kwargs.get("variables", [])
      if variables:
        raise ValueError(
            "Unexpected variables used in forward pass. "
            "This should not happen for first-order differentiation. "
            f"variables={variables}")

      def fun_vjp_jax(args_jax, out_cts_jax):
        # One may think that we can get the pullback while we are converting
        # the main function in the first place. That is problematic, because
        # the pullback may contain captured tracers from the conversion of
        # the main function. Those tracers will confuse the conversion of
        # the pullback. So, we construct the vjp anew.
        _, pullback_jax = jax.vjp(fun, *args_jax)
        return pullback_jax(out_cts_jax)

      # Unflatten the cotangents to the output pytree shape, then convert the
      # freshly-built JAX vjp function (without gradient, to avoid recursion)
      # and apply it to the original (pytree) args and the cotangents.
      out_cts = tree_util.tree_unflatten(out_tree_thunk(), out_cts_flat)
      in_cts = convert(fun_vjp_jax, with_gradient=False)(args, out_cts)
      return in_cts

    if with_gradient:

      @tf.custom_gradient
      def converted_fun_flat_with_custom_gradient(*args_flat: TfVal) -> TfVal:
        # tf.custom_gradient expects (outputs, grad_fn).
        return _interpret_fun(flat_fun, args_flat), converted_grad_fn

      out_flat = converted_fun_flat_with_custom_gradient(*args_flat)
    else:
      out_flat_raw = _interpret_fun(flat_fun, args_flat)
      message = ("The jax2tf-converted function does not support gradients. "
                 "Use `with_gradient` parameter to enable gradients")
      # We use PreventGradient, which is propagated through a SavedModel.
      out_flat = [
          tf.raw_ops.PreventGradient(input=o, message=message)
          for o in out_flat_raw
      ]

    out = tree_util.tree_unflatten(out_tree_thunk(), out_flat)
    return out

  return converted_fun