コード例 #1
0
ファイル: jax2tf.py プロジェクト: girving/jax
def convert(fun):
    """Transform ``fun`` so that it can be executed by TensorFlow.

    Args:
      fun: the JAX function to convert. Its arguments and return value should
        be JAX arrays, or (nested) standard Python containers
        (tuple/list/dict) thereof.

    Returns:
      A version of ``fun`` that expects TfVals (or tuples/lists/dicts thereof)
      as arguments and returns TfVals as outputs. The returned function raises
      ``TypeError`` when one of its flattened arguments fails the
      ``_is_tfvalorunit`` check.
    """
    api._check_callable(fun)

    @disable_gradient
    def wrapped_fun(*args: TfValOrUnit) -> TfValOrUnit:
        # TODO(necula): remove the jit disabling once we handle all control-flow.
        # Disabling the jit helps to avoid some unsupported jax primitives.
        # E.g. scan will be statically unrolled.
        # NOTE(review): no jit disabling happens in this function itself; the
        # TODO above may be stale or refer to code elsewhere -- confirm.
        wrapped = lu.wrap_init(fun)
        leaves, input_tree = tree_util.tree_flatten((args, {}))
        # Raise on the first leaf that is not an acceptable TF value.
        for leaf in (x for x in leaves if not _is_tfvalorunit(x)):
            raise TypeError(
                f"Argument {leaf} of type {type(leaf)} of jax2tf.convert(f) should "
                "be NumPy array, scalar, tf.Variable, or tf.Tensor")
        flat_fun, out_tree_thunk = flatten_fun(wrapped, input_tree)
        results_flat = _interpret_fun(flat_fun, leaves)
        # The output tree is only known once flat_fun has been run above.
        return tree_util.tree_unflatten(out_tree_thunk(), results_flat)

    return wrapped_fun
コード例 #2
0
def id_tap(func: Callable, arg, *, result=None, **kwargs):
    """Host-callback tap primitive, like identity function with a call to ``func``.

    **Experimental: please give feedback, and expect changes!**

    ``id_tap`` behaves semantically like the identity function but has the
    side-effect that a user-defined Python function is called with the runtime
    values of the argument.

    Args:
      * func: the tap function to call. It is checked to be callable, except
        for ``_end_consumer`` and ``_unknown_testing_consumer``, which are
        exempt from that check.
      * arg: the argument passed to the tap function, can be a pytree of JAX
        types.
      * result: if given (not ``None``), specifies the return value of
        ``id_tap``. By default, the return value is ``arg``.
      * kwargs: will be passed directly to the tap function. Can be anything
        that is hashable, these are kept in the host Python process until
        outfeeds are received.

    Returns:
      * ``arg``, or ``result`` if given.

    Raises:
      NotImplementedError: if no outfeed mechanism is available
        (``_OUTFEED_MECHANISM == "none"``).

    Tapping works even for code executed on accelerators and even for code
    under JAX transformations. Code that uses taps must be run embedded in
    :func:`outfeed_receiver`.

    For more details see the
    `module documentation <https://jax.readthedocs.io/en/latest/jax.experimental.host_callback.html>`_.
    """
    if _OUTFEED_MECHANISM == "none":
        raise NotImplementedError(
            "id_tap works only with jaxlib 0.1.47 and higher")
    if func not in (_end_consumer, _unknown_testing_consumer):
        api._check_callable(func)
    flat_args, arg_treedef = pytree.flatten(arg)
    for leaf in flat_args:
        api._check_arg(leaf)
    params = dict(kwargs)  # we pass a copy of params to the primitive
    # See definition of id_tap_p for what parameters it takes
    params["func"] = func
    params["arg_treedef"] = arg_treedef
    if result is not None:
        flat_results, result_treedef = pytree.flatten(result)
        for leaf in flat_results:
            api._check_arg(leaf)
        all_args = flat_args + flat_results
        params["nr_untapped"] = len(flat_results)
    else:
        all_args = flat_args
    flat_outs = id_tap_p.bind(*all_args, **params)  # Returns all_args
    if result is not None:
        # The untapped results are the last ``nr_untapped`` outputs. Slice from
        # an explicit start index: a negative slice ``[-0:]`` would select
        # every output when ``result`` flattens to no leaves (e.g. ``()``).
        start = len(flat_outs) - params["nr_untapped"]
        return result_treedef.unflatten(
            flat_outs[start:])  # type: ignore[unsupported-operands]
    return arg_treedef.unflatten(flat_outs)
コード例 #3
0
ファイル: host_callback.py プロジェクト: zohbahk/jax
def id_tap(tap_func: Callable, arg, *, result=None, **kwargs):
  """Host-callback tap primitive, like identity function with a call to ``tap_func``.

  **Experimental: please give feedback, and expect changes!**

  ``id_tap`` behaves semantically like the identity function but has the
  side-effect that a user-defined Python function is called with the runtime
  values of the argument.

  Args:
    * tap_func: the tap function to call. Must have a signature of the form
      ``tap_func(arg, *, transforms=None, **kwargs)`` where ``arg`` and
      ``kwargs`` are as described below and ``transforms`` is an optional
      sequence describing the applied JAX transformations.
    * arg: the argument passed to the tap function, can be a pytree of JAX
      types.
    * result: if given (not ``None``), specifies the return value of
      ``id_tap``. This value is not passed to the tap function, and in fact is
      not sent from the device to the host. If the ``result`` parameter is not
      specified then the return value of ``id_tap`` is ``arg``.
    * kwargs: will be passed directly to the tap function. Can be anything that
      is hashable, these are kept in the host Python process until outfeeds are
      received.

  Returns:
    * ``arg``, or ``result`` if given.

  The order of execution is by data dependency: after all the arguments and
  the value of ``result`` if present, are computed and before the returned
  value is used. At least one of the returned values of ``id_tap`` must be
  used in the rest of the computation, or else this operation has no effect.

  If you want to tap a constant value, you should use the ``result`` parameter
  to control when it is tapped, otherwise it will be tapped during tracing
  of the function::

    x = id_tap(tap_func, 42, result=x)

  Tapping works even for code executed on accelerators and even for code under
  JAX transformations. Code that uses taps must be run embedded in
  :func:`outfeed_receiver`.

  For more details see the
  `module documentation
  <https://jax.readthedocs.io/en/latest/jax.experimental.host_callback.html>`_.
  """
  _initialize_outfeed_receiver()  # Lazy initialization
  api._check_callable(tap_func)
  flat_args, arg_treedef = pytree.flatten(arg)
  for leaf in flat_args:
    api._check_arg(leaf)
  params = dict(kwargs)  # we pass a copy of params to the primitive
  # See definition of id_tap_p for what parameters it takes
  params["tap_func_"] = tap_func
  params["arg_treedef_"] = arg_treedef
  params["nr_tapped_args_"] = len(flat_args)
  if result is not None:
    flat_results, result_treedef = pytree.flatten(result)
    for leaf in flat_results:
      api._check_arg(leaf)
    all_args = flat_args + flat_results
  else:
    all_args = flat_args
  flat_outs = id_tap_p.bind(*all_args, **params)  # Returns all_args
  if result is not None:
    # The untapped results follow the ``nr_tapped_args_`` tapped outputs.
    # Slicing from that index (instead of ``[-nr_results:]``) stays correct
    # when ``result`` flattens to no leaves, where ``[-0:]`` would select
    # every output.
    flat_results = flat_outs[params["nr_tapped_args_"]:]  # type: ignore[unsupported-operands]
    return result_treedef.unflatten(flat_results)
  return arg_treedef.unflatten(flat_outs)
コード例 #4
0
ファイル: host_callback.py プロジェクト: yuejiesong1900/jax
def id_tap(tap_func, arg, *, result=None, **kwargs):
  """Host-callback tap primitive, like identity function with a call to ``tap_func``.

  **Experimental: please give feedback, and expect changes!**

  ``id_tap`` behaves semantically like the identity function but has the
  side-effect that a user-defined Python function is called with the runtime
  value of the argument.

  Args:
    tap_func: tap function to call like ``tap_func(arg, transforms)``, with
      ``arg`` as described below and where ``transforms`` is the sequence of
      applied JAX transformations in the form ``(name, params)``.
    arg: the argument passed to the tap function, can be a pytree of JAX
      types.
    result: if given (not ``None``), specifies the return value of ``id_tap``.
      This value is not passed to the tap function, and in fact is not sent
      from the device to the host. If the ``result`` parameter is not
      specified then the return value of ``id_tap`` is ``arg``.
    **kwargs: deprecated. If any are given, a ``FutureWarning`` is emitted and
      they are pre-applied to ``tap_func`` with ``functools.partial``.

  Returns:
    ``arg``, or ``result`` if given.

  The order of execution is by data dependency: after all the arguments and
  the value of ``result`` if present, are computed and before the returned
  value is used. At least one of the returned values of ``id_tap`` must be
  used in the rest of the computation, or else this operation has no effect.

  If you want to tap a constant value, you should use the ``result`` parameter
  to control when it is tapped, otherwise it will be tapped during tracing
  of the function::

    x = id_tap(tap_func, 42, result=x)

  Tapping works even for code executed on accelerators and even for code under
  JAX transformations. Code that uses taps must be run embedded in
  :func:`outfeed_receiver`.

  For more details see the
  `module documentation
  <https://jax.readthedocs.io/en/latest/jax.experimental.host_callback.html>`_.
  """
  if kwargs:
    warnings.warn(
        "Support for **kwargs in ``id_tap`` is deprecated and will be removed "
        "in the future. Instead, pre-apply keyword arguments, either by using "
        "a closure or by passing ``functools.partial(tap_func, **kwargs)`` "
        "instead.",
        FutureWarning, stacklevel=2)
    tap_func = functools.partial(tap_func, **kwargs)
  _initialize_outfeed_receiver()  # Lazy initialization
  api._check_callable(tap_func)
  flat_args, arg_treedef = pytree.flatten(arg)
  for leaf in flat_args:
    api._check_arg(leaf)
  # See definition of id_tap_p for what parameters it takes
  params = {}
  params["tap_func_"] = tap_func
  params["arg_treedef_"] = arg_treedef
  params["nr_tapped_args_"] = len(flat_args)
  if result is None:
    flat_outs = id_tap_p.bind(*flat_args, **params)
    return arg_treedef.unflatten(flat_outs)

  flat_results, result_treedef = pytree.flatten(result)
  for leaf in flat_results:
    api._check_arg(leaf)
  all_args = flat_args + flat_results
  flat_outs = id_tap_p.bind(*all_args, **params)  # Returns all_args
  # The untapped results follow the tapped arguments. Slicing from
  # ``len(flat_args)`` (instead of ``[-nr_results:]``) stays correct when
  # ``result`` flattens to no leaves, where ``[-0:]`` would select every output.
  flat_results = flat_outs[len(flat_args):]  # type: ignore[unsupported-operands]
  return result_treedef.unflatten(flat_results)
コード例 #5
0
ファイル: jax2tf.py プロジェクト: tachytelicdetonation/jax
def convert(fun, with_gradient=False):
    """Transforms `fun` to be executed by TensorFlow.

    Args:
      fun: Function to be transformed. Its arguments and return value should be
        JAX arrays, or (nested) standard Python containers (tuple/list/dict)
        thereof.
      with_gradient: if set, will add a tf.custom_gradient to the converted
        function, by converting the ``jax.vjp(fun)``. Only first-order
        differentiation is supported for now. If the converted function is
        saved in a SavedModel, the custom gradients are currently lost and
        an error will be raised if a gradient computation is attempted.
        If not set, every output is wrapped in ``tf.raw_ops.PreventGradient``
        so that requesting a gradient reports an error.

    Returns:
      A version of `fun` that expects TfVals (or tuples/lists/dicts thereof) as
      arguments, and returns TfVals as outputs, in the container structure of
      `fun`'s outputs. The returned function raises ``TypeError`` when one of
      its flattened arguments fails the ``_is_tfvalorunit`` check.
    """
    api._check_callable(fun)

    def converted_fun(*args: TfVal) -> TfVal:
        # This function may take pytrees of TfVals. We can only set
        # tf.custom_gradient on functions that take a flat argument list.
        args_flat, in_tree = tree_util.tree_flatten((args, {}))
        for a in args_flat:
            if not _is_tfvalorunit(a):
                msg = (
                    f"Argument {a} of type {type(a)} of jax2tf.convert(f) should "
                    "be NumPy array, scalar, tf.Variable, or tf.Tensor")
                raise TypeError(msg)

        f = lu.wrap_init(fun)
        # out_tree_thunk() will be the output tree, after running _interpret_fun.
        flat_fun, out_tree_thunk = flatten_fun(f, in_tree)

        # Prepare the grad_fn for tf.custom_gradient.
        def converted_grad_fn(*out_cts_flat: TfVal, **kwargs):
            # TODO(cl/318778369): change **kwargs with variables=None
            # tf.custom_gradient may pass ``variables=...``; only the case where
            # no variables were used in the forward pass is supported.
            variables = kwargs.get("variables", [])
            if variables:
                raise ValueError(
                    "Unexpected variables used in forward pass. "
                    "This should not happen for first-order differentiation. "
                    f"variables={variables}")

            def fun_vjp_jax(args_jax, out_cts_jax):
                # One may think that we can get the pullback while we are converting
                # the main function in the first place. That is problematic, because the
                # pullback may contain captured tracers from the conversion of the
                # main function. Those tracers will confuse the conversion of the
                # pullback. So, we construct the vjp anew.
                _, pullback_jax = jax.vjp(fun, *args_jax)
                return pullback_jax(out_cts_jax)

            out_cts = tree_util.tree_unflatten(out_tree_thunk(), out_cts_flat)
            in_cts = convert(fun_vjp_jax, with_gradient=False)(args, out_cts)
            # tf.custom_gradient expects exactly one gradient per flat input,
            # in the order of ``args_flat``. Flatten with JAX's tree_util (the
            # same flattening that produced ``args_flat``) rather than leaving
            # it to TF: TF's flattening treats a ``None`` node (the cotangent of
            # a ``None`` argument) as a leaf and would then report more
            # gradients than there are flat inputs.
            in_cts_flat, _ = tree_util.tree_flatten(in_cts)
            return in_cts_flat

        if with_gradient:

            @tf.custom_gradient
            def converted_fun_flat_with_custom_gradient(
                    *args_flat: TfVal) -> TfVal:
                return _interpret_fun(flat_fun, args_flat), converted_grad_fn

            out_flat = converted_fun_flat_with_custom_gradient(*args_flat)
        else:
            out_flat_raw = _interpret_fun(flat_fun, args_flat)
            message = (
                "The jax2tf-converted function does not support gradients. "
                "Use `with_gradient` parameter to enable gradients")
            # We use PreventGradient, which is propagated through a SavedModel.
            out_flat = [
                tf.raw_ops.PreventGradient(input=o, message=message)
                for o in out_flat_raw
            ]

        out = tree_util.tree_unflatten(out_tree_thunk(), out_flat)
        return out

    return converted_fun