Ejemplo n.º 1
0
def id_tap(func: Callable, arg, *, result=None, **kwargs):
    """Host-callback tap primitive, like identity function with a call to ``func``.

    **Experimental: please give feedback, and expect changes!**

    ``id_tap`` behaves semantically like the identity function but has the
    side-effect that a user-defined Python function is called with the runtime
    values of the argument.

    Args:
      * func: the tap function to call. It is checked for callability unless
        it is one of the internal consumers.
      * arg: the argument passed to the tap function, can be a pytree of JAX
        types.
      * result: if given, specifies the return value of ``id_tap``. By default,
        the return value is ``arg``.
      * kwargs: will be passed directly to the tap function. Can be anything
        that is hashable, these are kept in the host Python process until
        outfeeds are received.

    Returns:
      * ``arg``, or ``result`` if given.

    Raises:
      NotImplementedError: if the outfeed mechanism is ``"none"``.

    Tapping works even for code executed on accelerators and even for code under
    JAX transformations. Code that uses taps must be run embedded in
    :func:`outfeed_receiver`.

    For more details see the
    `module documentation <https://jax.readthedocs.io/en/latest/jax.experimental.host_callback.html>`_.
    """
    if _OUTFEED_MECHANISM == "none":
        raise NotImplementedError(
            "id_tap works only with jaxlib 0.1.47 and higher")
    if func not in (_end_consumer, _unknown_testing_consumer):
        api._check_callable(func)

    flat_args, arg_treedef = pytree.flatten(arg)
    # Use a distinct loop name so the ``arg`` parameter is not rebound.
    for flat_arg in flat_args:
        api._check_arg(flat_arg)

    # Pass a copy of kwargs to the primitive; see the definition of id_tap_p
    # for what parameters it takes.
    params = dict(kwargs)
    params["func"] = func
    params["arg_treedef"] = arg_treedef

    if result is None:
        flat_outs = id_tap_p.bind(*flat_args, **params)
        return arg_treedef.unflatten(flat_outs)

    flat_results, result_treedef = pytree.flatten(result)
    for flat_result in flat_results:
        api._check_arg(flat_result)
    nr_untapped = len(flat_results)
    params["nr_untapped"] = nr_untapped

    # The primitive returns all of its positional arguments; the untapped
    # results are the trailing ``nr_untapped`` entries.
    flat_outs = id_tap_p.bind(*flat_args, *flat_results, **params)
    # Slice from an explicit start index: ``flat_outs[-nr_untapped:]`` would
    # return the *whole* list when ``result`` has no leaves (``-0 == 0``).
    first_result = len(flat_outs) - nr_untapped
    return result_treedef.unflatten(flat_outs[first_result:])
Ejemplo n.º 2
0
def id_tap(tap_func: Callable, arg, *, result=None, **kwargs):
  """Host-callback tap primitive, like identity function with a call to ``tap_func``.

  **Experimental: please give feedback, and expect changes!**

  ``id_tap`` behaves semantically like the identity function but has the
  side-effect that a user-defined Python function is called with the runtime
  values of the argument.

  Args:
    * tap_func: the tap function to call. Must have a signature of the form
      ``tap_func(arg, *, transforms=None, **kwargs)`` where ``arg`` and
      ``kwargs`` are as described below and ``transforms`` is an optional
      sequence describing the applied JAX transformations.
    * arg: the argument passed to the tap function, can be a pytree of JAX
      types.
    * result: if given, specifies the return value of ``id_tap``. This value is
      not passed to the tap function, and in fact is not sent from the device to
      the host. If the ``result`` parameter is not specified then the return
      value of ``id_tap`` is ``arg``.
    * kwargs: will be passed directly to the tap function. Can be anything that
      is hashable, these are kept in the host Python process until outfeeds are
      received.

  Returns:
    * ``arg``, or ``result`` if given.

  The order of execution is by data dependency: after all the arguments and
  the value of ``result`` if present, are computed and before the returned
  value is used. At least one of the returned values of ``id_tap`` must be
  used in the rest of the computation, or else this operation has no effect.

  If you want to tap a constant value, you should use the ``result`` parameter
  to control when it is tapped, otherwise it will be tapped during tracing
  of the function::

    x = id_tap(tap_func, 42, result=x)

  Tapping works even for code executed on accelerators and even for code under
  JAX transformations. Code that uses taps must be run embedded in
  :func:`outfeed_receiver`.

  For more details see the
  `module documentation
  <https://jax.readthedocs.io/en/latest/jax.experimental.host_callback.html>`_.
  """
  _initialize_outfeed_receiver()  # Lazy initialization
  api._check_callable(tap_func)

  flat_args, arg_treedef = pytree.flatten(arg)
  # Use a distinct loop name so the ``arg`` parameter is not rebound.
  for flat_arg in flat_args:
    api._check_arg(flat_arg)

  # Pass a copy of kwargs to the primitive; see the definition of id_tap_p
  # for what parameters it takes.
  params = dict(kwargs)
  params["tap_func_"] = tap_func
  params["arg_treedef_"] = arg_treedef
  params["nr_tapped_args_"] = len(flat_args)

  if result is None:
    flat_outs = id_tap_p.bind(*flat_args, **params)
    return arg_treedef.unflatten(flat_outs)

  flat_results, result_treedef = pytree.flatten(result)
  for flat_result in flat_results:
    api._check_arg(flat_result)

  # The primitive returns all of its positional arguments; the untapped
  # results are the trailing ``len(flat_results)`` entries.
  flat_outs = id_tap_p.bind(*flat_args, *flat_results, **params)
  # Slice from an explicit start index: ``flat_outs[-nr_results:]`` would
  # return the *whole* list when ``result`` has no leaves (``-0 == 0``).
  first_result = len(flat_outs) - len(flat_results)
  return result_treedef.unflatten(flat_outs[first_result:])
Ejemplo n.º 3
0
def id_tap(tap_func, arg, *, result=None, **kwargs):
  """Host-callback tap primitive, like identity function with a call to ``tap_func``.

  **Experimental: please give feedback, and expect changes!**

  ``id_tap`` behaves semantically like the identity function but has the
  side-effect that a user-defined Python function is called with the runtime
  value of the argument.

  Args:
    tap_func: tap function to call like ``tap_func(arg, transforms)``, with
      ``arg`` as described below and where ``transforms`` is the sequence of
      applied JAX transformations in the form ``(name, params)``.
    arg: the argument passed to the tap function, can be a pytree of JAX
      types.
    result: if given, specifies the return value of ``id_tap``. This value is
      not passed to the tap function, and in fact is not sent from the device to
      the host. If the ``result`` parameter is not specified then the return
      value of ``id_tap`` is ``arg``.
    **kwargs: deprecated. If given, a ``FutureWarning`` is issued and the
      keyword arguments are pre-applied to ``tap_func`` with
      ``functools.partial``. Prefer a closure or ``functools.partial``.

  Returns:
    ``arg``, or ``result`` if given.

  The order of execution is by data dependency: after all the arguments and
  the value of ``result`` if present, are computed and before the returned
  value is used. At least one of the returned values of ``id_tap`` must be
  used in the rest of the computation, or else this operation has no effect.

  If you want to tap a constant value, you should use the ``result`` parameter
  to control when it is tapped, otherwise it will be tapped during tracing
  of the function::

    x = id_tap(tap_func, 42, result=x)

  Tapping works even for code executed on accelerators and even for code under
  JAX transformations. Code that uses taps must be run embedded in
  :func:`outfeed_receiver`.

  For more details see the
  `module documentation
  <https://jax.readthedocs.io/en/latest/jax.experimental.host_callback.html>`_.
  """
  if kwargs:
    warnings.warn(
        "Support for **kwargs in ``id_tap`` is deprecated and will be removed "
        "in the future. Instead, pre-apply keyword arguments, either by using "
        "a closure or by passing ``functools.partial(tap_func, **kwargs)`` "
        "instead.",
        FutureWarning, stacklevel=2)
    tap_func = functools.partial(tap_func, **kwargs)
  _initialize_outfeed_receiver()  # Lazy initialization
  api._check_callable(tap_func)

  flat_args, arg_treedef = pytree.flatten(arg)
  # Use a distinct loop name so the ``arg`` parameter is not rebound.
  for flat_arg in flat_args:
    api._check_arg(flat_arg)

  # See definition of id_tap_p for what parameters it takes
  params = {
      "tap_func_": tap_func,
      "arg_treedef_": arg_treedef,
      "nr_tapped_args_": len(flat_args),
  }

  if result is None:
    flat_outs = id_tap_p.bind(*flat_args, **params)
    return arg_treedef.unflatten(flat_outs)

  flat_results, result_treedef = pytree.flatten(result)
  for flat_result in flat_results:
    api._check_arg(flat_result)

  # The primitive returns all of its positional arguments; the untapped
  # results are the trailing ``len(flat_results)`` entries.
  flat_outs = id_tap_p.bind(*flat_args, *flat_results, **params)
  # Slice from an explicit start index: ``flat_outs[-nr_results:]`` would
  # return the *whole* list when ``result`` has no leaves (``-0 == 0``).
  first_result = len(flat_outs) - len(flat_results)
  return result_treedef.unflatten(flat_outs[first_result:])