Example #1
def call_tf(func_tf: Callable) -> Callable:
    """Calls a TensorFlow function from JAX, with support for reverse autodiff.

  ``func_tf`` will be called with TensorFlow-compatible arguments
  (numpy.ndarray, ``tf.Tensor`` or ``tf.Variable``), or pytrees thereof, and
  must return results of the same kinds.

  If ``call_tf`` appears in a JAX staging context (:func:`jax.jit`,
  or :func:`jax.pmap`, or :func:`jax.xmap`, or a control-flow primitive) then
  ``func_tf`` will be compiled with ``tf.function(func_tf, jit_compile=True)``
  and the resulting XLA computation will be embedded in JAX's XLA computation.

  If ``call_tf`` appears outside a JAX staging context, it will be called inline
  using TensorFlow eager mode.

  ``call_tf`` supports JAX's reverse-mode autodiff, in which case ``func_tf``
  is differentiated using ``tf.GradientTape``. The gradient is therefore
  TensorFlow-accurate; e.g., it respects any custom gradients defined for the
  code in ``func_tf``.

  For an example and more details see the
  `README <https://github.com/google/jax/blob/master/jax/experimental/jax2tf/README.md#calling-tensorflow-functions-from-jax>`_.

  Args:
    func_tf: a TensorFlow Callable that can take a pytree of TensorFlow
      arguments.
  Returns: a JAX callable that can be invoked with JAX pytree arguments, in
    op-by-op mode or in a staged context. This callable can be used with
    JAX's reverse-mode autodiff (:func:`jax.grad`).
  """
    @jax.custom_vjp
    def make_call(*args_jax):
        """We wrap it all in `make_call` so that we can attach custom VJP."""
        def _dtype(x):
            return (getattr(x, "dtype", None) or np.asarray(x).dtype)

        args_jax_flat, args_jax_treedef = tree_util.tree_flatten(args_jax)
        args_tf_sig_flat = [
            tf.TensorSpec(np.shape(a_jax), _to_tf_dtype(_dtype(a_jax)))
            for a_jax in args_jax_flat
        ]
        args_tf_sig = tf.nest.map_structure(
            lambda a_jax: tf.TensorSpec(np.shape(a_jax),
                                        _to_tf_dtype(_dtype(a_jax))), args_jax)
        func_tf_concrete = tf.function(func_tf).get_concrete_function(
            *args_tf_sig)
        res_tf_sig_flat, res_treedef = tree_util.tree_flatten(
            func_tf_concrete.structured_outputs)

        res_jax_flat = call_tf_p.bind(*args_jax_flat,
                                      func_tf=func_tf,
                                      args_treedef=args_jax_treedef,
                                      args_tf_sig_flat=args_tf_sig_flat,
                                      res_treedef=res_treedef,
                                      res_tf_sig_flat=res_tf_sig_flat)
        # TODO(necula): check the expected result signature
        assert len(res_jax_flat) == len(res_tf_sig_flat)
        return res_treedef.unflatten(res_jax_flat)

    # Define the fwd and bwd custom_vjp functions
    def make_call_vjp_fwd(*args_jax):
        # Return the primal argument as the residual
        return make_call(*args_jax), args_jax

    def make_call_vjp_bwd(residual, ct_res):
        args_jax = residual  # residual is the primal argument

        def tf_vjp_fun(args, ct_res):
            """Invoke TF gradient."""
            with tf.GradientTape(persistent=True) as tape:
                tape.watch(args)
                res = func_tf(*args)

            tf.nest.assert_same_structure(res, ct_res)
            # If the result is not a scalar, we must accumulate the arguments' cotangents.
            accumulator = None  # Same structure as `args`

            def acc_ct(res_, ct_res_):
                dres_darg = tape.gradient(
                    res_,
                    sources=args,
                    unconnected_gradients=tf.UnconnectedGradients.ZERO)
                tf.nest.assert_same_structure(dres_darg, args)
                scaled_dres_darg = tf.nest.map_structure(
                    lambda d: d * ct_res_, dres_darg)
                nonlocal accumulator
                accumulator = (scaled_dres_darg if accumulator is None
                               else tf.nest.map_structure(
                                   lambda x, y: x + y, accumulator,
                                   scaled_dres_darg))

            tf.nest.map_structure(acc_ct, res, ct_res)
            return accumulator

        # Use call_tf to call the VJP function
        return call_tf(tf_vjp_fun)(args_jax, ct_res)

    make_call.defvjp(make_call_vjp_fwd, make_call_vjp_bwd)
    return util.wraps(func_tf)(make_call)
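
A minimal usage sketch for the example above (my illustration, not part of the original listing). It assumes call_tf is importable from jax.experimental.jax2tf; the names f_tf and f_jax are hypothetical.

import jax
import jax.numpy as jnp
import tensorflow as tf
from jax.experimental.jax2tf import call_tf

def f_tf(x):
    # An arbitrary TensorFlow computation, possibly with its own custom gradients.
    return tf.math.sin(x) * x

f_jax = call_tf(f_tf)                    # JAX callable wrapping f_tf
x = jnp.array(0.5, dtype=jnp.float32)
y = f_jax(x)                             # op-by-op: runs f_tf in TF eager mode
dy_dx = jax.grad(f_jax)(x)               # reverse-mode autodiff via tf.GradientTape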
Example #2
def call_tf(callable_tf: Callable) -> Callable:
    """Calls a TensorFlow function from JAX, with support for reverse autodiff.

  ``callable_tf`` will be called with TensorFlow-compatible arguments
  (numpy.ndarray, ``tf.Tensor`` or ``tf.Variable``), or pytrees thereof, and
  must return results of the same kinds.

  If ``call_tf`` appears in a JAX staging context (:func:`jax.jit`,
  or :func:`jax.pmap`, or :func:`jax.xmap`, or a control-flow primitive) then
  ``callable_tf`` will be compiled with ``tf.function(callable_tf, jit_compile=True)``
  and the resulting XLA computation will be embedded in JAX's XLA computation.

  If ``call_tf`` appears outside a JAX staging context, it will be called inline
  using TensorFlow eager mode.

  ``call_tf`` supports JAX's reverse-mode autodiff, in which case
  ``callable_tf`` is differentiated using ``tf.GradientTape``. The gradient is
  therefore TensorFlow-accurate; e.g., it respects any custom gradients defined
  for the code in ``callable_tf``.

  For an example and more details see the
  `README <https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#calling-tensorflow-functions-from-jax>`_.

  Args:
    callable_tf: a TensorFlow Callable that can take a pytree of TensorFlow
      arguments.
  Returns: a JAX callable that can be invoked with JAX pytree arguments, in
    op-by-op mode or in a staged context. This callable can be used with
    JAX's reverse-mode autodiff (:func:`jax.grad`).
  """
    @jax.custom_vjp
    def make_call(*args_jax):
        """We wrap it all in `make_call` so that we can attach custom VJP."""

        args_flat_jax, args_treedef = tree_util.tree_flatten(args_jax)

        # Canonicalize the arguments; e.g., makes them x32 if JAX is in 32-bit mode
        def canonical_arg(v):
            v = v if getattr(v, "dtype", None) else np.asarray(v)
            dtype = dtypes.canonicalize_dtype(v.dtype)
            if dtype != v.dtype:
                v = v.astype(dtype)
            return v

        args_flat_jax = tuple(map(canonical_arg, args_flat_jax))

        def make_tensorspec(a_jax):
            a_tf_dtype = jax2tf_internal._to_tf_dtype(a_jax.dtype)
            if any(not core.is_constant_dim(d) for d in a_jax.shape):
                msg = (
                    "call_tf cannot be applied to shape-polymorphic arguments. "
                    f"Found argument shape: {a_jax.shape}. "
                    "See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#limitations-of-call-tf for a discussion."
                )
                raise ValueError(msg)

            return tf.TensorSpec(a_jax.shape, a_tf_dtype)

        args_flat_sig_tf = tuple(map(make_tensorspec, args_flat_jax))

        res_treedef = None  # We'll store here the result treedef

        # The function below will be called at least once, either in eager
        # or in graph mode.
        def callable_flat_tf(*args_tf_flat: TfVal) -> Sequence[TfVal]:
            args_tf = args_treedef.unflatten(args_tf_flat)
            res_tf = callable_tf(*args_tf)
            nonlocal res_treedef
            res_tf_flat, res_treedef_now = tree_util.tree_flatten(res_tf)
            assert res_treedef is None or res_treedef == res_treedef_now, f"Subsequent calls had different results. Previous {res_treedef} and now {res_treedef_now}"
            res_treedef = res_treedef_now
            return res_tf_flat

        # Prepare a tf.function ahead of time, to cache the concrete functions. This
        # won't be used in op-by-op execution mode.
        function_flat_tf = tf.function(callable_flat_tf,
                                       autograph=False,
                                       jit_compile=True)

        res_jax_flat = call_tf_p.bind(
            *args_flat_jax,
            # Carry the actual function such that op-by-op call can call in TF eager mode.
            callable_flat_tf=callable_flat_tf,
            function_flat_tf=function_flat_tf,
            args_flat_sig_tf=args_flat_sig_tf)
        return res_treedef.unflatten(res_jax_flat)

    # Define the fwd and bwd custom_vjp functions
    def make_call_vjp_fwd(*args_jax):
        # Return the primal arguments as the residual
        return make_call(*args_jax), args_jax

    def make_call_vjp_bwd(residual_jax, ct_res_jax):
        args_jax = residual_jax  # residual is the primal argument

        def tf_vjp_fun(args_tf, ct_res_tf):
            """Invoke TF gradient."""

            # TF does not like us to watch non-float vars
            def replace_non_float(arg):
                if np.issubdtype(arg.dtype.as_numpy_dtype, np.inexact):
                    return arg
                else:
                    # When watched, this will be ignored. When used in results it
                    # will produce a floating-point 0. gradient, which JAX will
                    # ignore (and replace with a float0).
                    return tf.zeros((), dtype=tf.float32)

            watched_args_tf = tf.nest.map_structure(replace_non_float, args_tf)
            with tf.GradientTape(persistent=True) as tape:
                tape.watch(watched_args_tf)
                res = callable_tf(*args_tf)

            tf.nest.assert_same_structure(res, ct_res_tf)
            dres_darg = tape.gradient(
                tf.nest.map_structure(replace_non_float, res),
                sources=watched_args_tf,
                output_gradients=ct_res_tf,
                unconnected_gradients=tf.UnconnectedGradients.ZERO)

            tf.nest.assert_same_structure(dres_darg, args_tf)
            return dres_darg

        # Use call_tf to call the VJP function
        ct_args_jax = call_tf(tf_vjp_fun)(args_jax, ct_res_jax)

        # We must make the float0s that JAX expects
        def fix_float0(arg_jax, ct_arg_jax):
            arg_dtype = dtypes.result_type(arg_jax)  # May be scalar
            ct_arg_dtype = core.primal_dtype_to_tangent_dtype(arg_dtype)
            if ct_arg_dtype != ct_arg_jax.dtype:
                return ad_util.zeros_like_aval(
                    core.ShapedArray(np.shape(arg_jax), ct_arg_dtype))
            return ct_arg_jax

        ct_args_jax_fixed = tree_util.tree_map(fix_float0, args_jax,
                                               ct_args_jax)
        return ct_args_jax_fixed

    make_call.defvjp(make_call_vjp_fwd, make_call_vjp_bwd)
    return util.wraps(callable_tf)(make_call)
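
A sketch (hypothetical, my illustration) of the non-float handling in this version: replace_non_float keeps tf.GradientTape from watching the integer argument, and fix_float0 converts its dummy gradient into the float0 cotangent JAX expects, so differentiating with respect to the float argument works as usual. It assumes call_tf is importable from jax.experimental.jax2tf.

import jax
import jax.numpy as jnp
import tensorflow as tf
from jax.experimental.jax2tf import call_tf

def scale_tf(x, n):
    # n is an integer argument and cannot carry a meaningful gradient.
    return x * tf.cast(n, x.dtype)

scale_jax = call_tf(scale_tf)
x = jnp.array(2.0, dtype=jnp.float32)
n = jnp.array(3, dtype=jnp.int32)
dx = jax.grad(scale_jax, argnums=0)(x, n)  # gradient w.r.t. the float argument only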
Example #3
def call_tf(func_tf: Callable) -> Callable:
    """Calls a TensorFlow function from JAX, with support for reverse autodiff.

  ``func_tf`` will be called with TensorFlow-compatible arguments
  (numpy.ndarray, ``tf.Tensor`` or ``tf.Variable``), or pytrees thereof, and
  must return results of the same kinds.

  If ``call_tf`` appears in a JAX staging context (:func:`jax.jit`,
  or :func:`jax.pmap`, or :func:`jax.xmap`, or a control-flow primitive) then
  ``func_tf`` will be compiled with ``tf.function(func_tf, jit_compile=True)``
  and the resulting XLA computation will be embedded in JAX's XLA computation.

  If ``call_tf`` appears outside a JAX staging context, it will be called inline
  using TensorFlow eager mode.

  ``call_tf`` supports JAX's reverse-mode autodiff, in which case ``func_tf``
  is differentiated using ``tf.GradientTape``. The gradient is therefore
  TensorFlow-accurate; e.g., it respects any custom gradients defined for the
  code in ``func_tf``.

  For an example and more details see the
  `README <https://github.com/google/jax/blob/master/jax/experimental/jax2tf/README.md#calling-tensorflow-functions-from-jax>`_.

  Args:
    func_tf: a TensorFlow Callable that can take a pytree of TensorFlow
      arguments.
  Returns: a JAX callable that can be invoked with JAX pytree arguments, in
    op-by-op mode or in a staged context. This callable can be used with
    JAX's reverse-mode autodiff (:func:`jax.grad`).
  """
    @jax.custom_vjp
    def make_call(*args_jax):
        """We wrap it all in `make_call` so that we can attach custom VJP."""

        args_jax_flat, args_jax_treedef = tree_util.tree_flatten(args_jax)

        # Canonicalize the arguments; e.g., makes them x32 if JAX is in 32-bit mode
        def canonical_arg(v):
            v = v if getattr(v, "dtype", None) else np.asarray(v)
            dtype = dtypes.canonicalize_dtype(v.dtype)
            if dtype != v.dtype:
                v = v.astype(dtype)
            return v

        args_jax_flat = tuple(map(canonical_arg, args_jax_flat))
        args_tf_sig_flat = [
            tf.TensorSpec(a_jax.shape,
                          jax2tf_internal._to_tf_dtype(a_jax.dtype))
            for a_jax in args_jax_flat
        ]
        args_tf_sig = args_jax_treedef.unflatten(args_tf_sig_flat)

        # Trace once through the function to get the result shape
        with jax2tf_internal.inside_call_tf():
            func_tf_concrete = tf.function(func_tf).get_concrete_function(
                *args_tf_sig)

        res_tf_sig_flat, res_treedef = tree_util.tree_flatten(
            func_tf_concrete.structured_outputs)

        # Canonicalize the result signature; e.g., make the dtypes x32 if JAX is in 32-bit mode
        def res_sig_to_aval(res_sig: tf.TensorSpec) -> core.AbstractValue:
            return core.ShapedArray(
                res_sig.shape, jax2tf_internal._to_jax_dtype(res_sig.dtype))

        out_avals = tuple(map(res_sig_to_aval, res_tf_sig_flat))
        res_jax_flat = call_tf_p.bind(
            *args_jax_flat,
            # Carry the actual function such that op-by-op call can call in TF eager mode.
            func_tf=func_tf,
            func_tf_concrete=func_tf_concrete,
            args_treedef=args_jax_treedef,
            args_tf_sig_flat=args_tf_sig_flat,
            res_treedef=res_treedef,
            out_avals=out_avals)
        # TODO(necula): check the expected result signature
        assert len(res_jax_flat) == len(out_avals)
        return res_treedef.unflatten(res_jax_flat)

    # Define the fwd and bwd custom_vjp functions
    def make_call_vjp_fwd(*args_jax):
        # Return the primal arguments as the residual
        return make_call(*args_jax), args_jax

    def make_call_vjp_bwd(residual_jax, ct_res_jax):
        args_jax = residual_jax  # residual is the primal argument

        def tf_vjp_fun(args_tf, ct_res_tf):
            """Invoke TF gradient."""

            # TF does not like us to watch non-float vars
            def replace_non_float(arg):
                if np.issubdtype(arg.dtype.as_numpy_dtype, np.inexact):
                    return arg
                else:
                    # When watched, this will be ignored. When used in results it
                    # will produce a floating-point 0. gradient, which JAX will
                    # ignore (and replace with a float0).
                    return tf.zeros((), dtype=tf.float32)

            watched_args_tf = tf.nest.map_structure(replace_non_float, args_tf)
            with tf.GradientTape(persistent=True) as tape:
                tape.watch(watched_args_tf)
                res = func_tf(*args_tf)

            tf.nest.assert_same_structure(res, ct_res_tf)
            dres_darg = tape.gradient(
                tf.nest.map_structure(replace_non_float, res),
                sources=watched_args_tf,
                output_gradients=ct_res_tf,
                unconnected_gradients=tf.UnconnectedGradients.ZERO)

            tf.nest.assert_same_structure(dres_darg, args_tf)
            return dres_darg

        # Use call_tf to call the VJP function
        return call_tf(tf_vjp_fun)(args_jax, ct_res_jax)

    make_call.defvjp(make_call_vjp_fwd, make_call_vjp_bwd)
    return util.wraps(func_tf)(make_call)
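
A staged-mode sketch (hypothetical, my illustration): when the wrapped function is used under jax.jit, call_tf compiles it with tf.function(..., jit_compile=True) and embeds the resulting XLA computation in JAX's, so TF and JAX operations compose within a single jitted function. It assumes call_tf is importable from jax.experimental.jax2tf and that XLA can compile the TF function on the current backend.

import jax
import jax.numpy as jnp
import tensorflow as tf
from jax.experimental.jax2tf import call_tf

def cumsum_tf(x):
    return tf.math.cumsum(x)

cumsum_jax = call_tf(cumsum_tf)

@jax.jit
def mixed(x):
    # JAX ops and the staged TF call live in one XLA computation.
    return jnp.sum(cumsum_jax(x) ** 2)

print(mixed(jnp.arange(4, dtype=jnp.float32)))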