def id_tap(func: Callable, arg, *, result=None, **kwargs):
  """Host-callback tap primitive, like identity function with a call to ``func``.

  **Experimental: please give feedback, and expect changes!**

  ``id_tap`` behaves semantically like the identity function but has the
  side-effect that a user-defined Python function is called with the runtime
  values of the argument.

  Args:
    * func: the tap function to call with the runtime values of ``arg``.
    * arg: the argument passed to the tap function, can be a pytree of JAX
      types.
    * result: if given, specifies the return value of ``id_tap``. By default,
      the return value is ``arg``.
    * kwargs: will be passed directly to the tap function. Can be anything that
      is hashable, these are kept in the host Python process until outfeeds are
      received.

  Returns:
    * ``arg``, or ``result`` if given.

  Raises:
    NotImplementedError: if no outfeed mechanism is available (the message
      states that jaxlib 0.1.47 or higher is required).

  Tapping works even for code executed on accelerators and even for code under
  JAX transformations. Code that uses taps must be run embedded in
  :func:`outfeed_receiver`.

  For more details see the
  `module documentation
  <https://jax.readthedocs.io/en/latest/jax.experimental.host_callback.html>`_.
  """
  if _OUTFEED_MECHANISM == "none":
    raise NotImplementedError(
        "id_tap works only with jaxlib 0.1.47 and higher")
  # These two module-level consumers are exempt from the callable check.
  if func not in (_end_consumer, _unknown_testing_consumer):
    api._check_callable(func)
  flat_args, arg_treedef = pytree.flatten(arg)
  # Use a distinct loop name so the ``arg``/``result`` parameters are not
  # shadowed.
  for leaf in flat_args:
    api._check_arg(leaf)
  params = dict(kwargs)  # we pass a copy of params to the primitive
  # See definition of id_tap_p for what parameters it takes
  params["func"] = func
  params["arg_treedef"] = arg_treedef

  if result is None:
    flat_outs = id_tap_p.bind(*flat_args, **params)  # Returns all_args
    return arg_treedef.unflatten(flat_outs)

  flat_results, result_treedef = pytree.flatten(result)
  for leaf in flat_results:
    api._check_arg(leaf)
  params["nr_untapped"] = len(flat_results)
  # The primitive returns all of its operands: the tapped ``arg`` leaves come
  # first and the untapped ``result`` leaves follow.
  flat_outs = id_tap_p.bind(*flat_args, *flat_results, **params)
  # Slice by position. ``flat_outs[-nr_untapped:]`` would select *every*
  # output when ``result`` has no leaves, because ``-0 == 0``.
  return result_treedef.unflatten(flat_outs[len(flat_args):])
def id_tap(tap_func: Callable, arg, *, result=None, **kwargs):
  """Host-callback tap primitive, like identity function with a call to ``tap_func``.

  **Experimental: please give feedback, and expect changes!**

  ``id_tap`` behaves semantically like the identity function but has the
  side-effect that a user-defined Python function is called with the runtime
  values of the argument.

  Args:
    * tap_func: the tap function to call. Must have a signature of the form
      ``tap_func(arg, *, transforms=None, **kwargs)`` where ``arg`` and
      ``kwargs`` are as described below and ``transforms`` is an optional
      sequence describing the applied JAX transformations.
    * arg: the argument passed to the tap function, can be a pytree of JAX
      types.
    * result: if given, specifies the return value of ``id_tap``. This value is
      not passed to the tap function, and in fact is not sent from the device
      to the host. If the ``result`` parameter is not specified then the return
      value of ``id_tap`` is ``arg``.
    * kwargs: will be passed directly to the tap function. Can be anything that
      is hashable, these are kept in the host Python process until outfeeds are
      received.

  Returns:
    * ``arg``, or ``result`` if given.

  The order of execution is determined by data dependency: the tap happens
  after all the arguments (and the value of ``result``, if present) are
  computed, and before the returned value is used. At least one of the
  returned values of ``id_tap`` must be used in the rest of the computation,
  or else this operation has no effect.

  If you want to tap a constant value, you should use the ``result`` parameter
  to control when it is tapped, otherwise it will be tapped during tracing of
  the function::

    x = id_tap(tap_func, 42, result=x)

  Tapping works even for code executed on accelerators and even for code under
  JAX transformations. Code that uses taps must be run embedded in
  :func:`outfeed_receiver`.

  For more details see the
  `module documentation
  <https://jax.readthedocs.io/en/latest/jax.experimental.host_callback.html>`_.
  """
  _initialize_outfeed_receiver()  # Lazy initialization
  api._check_callable(tap_func)
  flat_args, arg_treedef = pytree.flatten(arg)
  # Use a distinct loop name so the ``arg``/``result`` parameters are not
  # shadowed.
  for leaf in flat_args:
    api._check_arg(leaf)
  params = dict(kwargs)  # we pass a copy of params to the primitive
  # See definition of id_tap_p for what parameters it takes
  params["tap_func_"] = tap_func
  params["arg_treedef_"] = arg_treedef
  params["nr_tapped_args_"] = len(flat_args)

  if result is None:
    flat_outs = id_tap_p.bind(*flat_args, **params)  # Returns all_args
    return arg_treedef.unflatten(flat_outs)

  flat_results, result_treedef = pytree.flatten(result)
  for leaf in flat_results:
    api._check_arg(leaf)
  # The primitive returns all of its operands: the first ``nr_tapped_args_``
  # outputs are the ``arg`` leaves, and the ``result`` leaves follow.
  flat_outs = id_tap_p.bind(*flat_args, *flat_results, **params)
  # Slice by position. ``flat_outs[-nr_results:]`` would select *every*
  # output when ``result`` has no leaves, because ``-0 == 0``.
  return result_treedef.unflatten(flat_outs[len(flat_args):])
def id_tap(tap_func, arg, *, result=None, **kwargs):
  """Host-callback tap primitive, like identity function with a call to ``tap_func``.

  **Experimental: please give feedback, and expect changes!**

  ``id_tap`` behaves semantically like the identity function but has the
  side-effect that a user-defined Python function is called with the runtime
  value of the argument.

  Args:
    tap_func: tap function to call like ``tap_func(arg, transforms)``, with
      ``arg`` as described below and where ``transforms`` is the sequence of
      applied JAX transformations in the form ``(name, params)``.
    arg: the argument passed to the tap function, can be a pytree of JAX
      types.
    result: if given, specifies the return value of ``id_tap``. This value is
      not passed to the tap function, and in fact is not sent from the device
      to the host. If the ``result`` parameter is not specified then the return
      value of ``id_tap`` is ``arg``.
    **kwargs: deprecated. If given, a ``FutureWarning`` is issued and the
      keyword arguments are pre-applied to ``tap_func`` with
      ``functools.partial``.

  Returns:
    ``arg``, or ``result`` if given.

  The order of execution is determined by data dependency: the tap happens
  after all the arguments (and the value of ``result``, if present) are
  computed, and before the returned value is used. At least one of the
  returned values of ``id_tap`` must be used in the rest of the computation,
  or else this operation has no effect.

  If you want to tap a constant value, you should use the ``result`` parameter
  to control when it is tapped, otherwise it will be tapped during tracing of
  the function::

    x = id_tap(tap_func, 42, result=x)

  Tapping works even for code executed on accelerators and even for code under
  JAX transformations. Code that uses taps must be run embedded in
  :func:`outfeed_receiver`.

  For more details see the
  `module documentation
  <https://jax.readthedocs.io/en/latest/jax.experimental.host_callback.html>`_.
  """
  if kwargs:
    warnings.warn(
        "Support for **kwargs in ``id_tap`` is deprecated and will be removed "
        "in the future. Instead, pre-apply keyword arguments, either by using "
        "a closure or by passing ``functools.partial(tap_func, **kwargs)`` "
        "instead.",
        FutureWarning, stacklevel=2)
    tap_func = functools.partial(tap_func, **kwargs)
  _initialize_outfeed_receiver()  # Lazy initialization
  api._check_callable(tap_func)
  flat_args, arg_treedef = pytree.flatten(arg)
  # Use a distinct loop name so the ``arg``/``result`` parameters are not
  # shadowed.
  for leaf in flat_args:
    api._check_arg(leaf)
  # See definition of id_tap_p for what parameters it takes
  params = {
      "tap_func_": tap_func,
      "arg_treedef_": arg_treedef,
      "nr_tapped_args_": len(flat_args),
  }

  if result is None:
    flat_outs = id_tap_p.bind(*flat_args, **params)
    return arg_treedef.unflatten(flat_outs)

  flat_results, result_treedef = pytree.flatten(result)
  for leaf in flat_results:
    api._check_arg(leaf)
  # The primitive returns all of its operands: the first ``nr_tapped_args_``
  # outputs are the ``arg`` leaves, and the ``result`` leaves follow.
  flat_outs = id_tap_p.bind(*flat_args, *flat_results, **params)
  # Slice by position. ``flat_outs[-nr_results:]`` would select *every*
  # output when ``result`` has no leaves, because ``-0 == 0``.
  return result_treedef.unflatten(flat_outs[len(flat_args):])