Example #1
0
 def checked_fun(*args, **kwargs):
     """Flatten the inputs, run the checkified function, and rebuild outputs.

     Returns an (Error, outputs) pair; relies on `fun`, `errors`,
     `checkify_flat`, and `Error` from the enclosing scope.
     """
     flat_args, input_tree = tree_flatten((args, kwargs))
     flat_fun, out_tree_thunk = flatten_fun(lu.wrap_init(fun), input_tree)
     # checkify_flat yields the error triple plus the flat outputs, and the
     # accumulated error messages.
     (err, code, payload, flat_out), msgs = checkify_flat(flat_fun, errors,
                                                          *flat_args)
     result = tree_unflatten(out_tree_thunk(), flat_out)
     return Error(err, code, msgs, payload), result
Example #2
0
    def converted_fun(*args: TfVal) -> TfVal:
        """Evaluate the closed-over `fun` on TF values via the jax2tf interpreter.

        NOTE(review): reads `fun`, `with_gradient`, `convert`, and
        `_interpret_fun` from the enclosing scope. When `with_gradient` is
        set, a tf.custom_gradient is attached; otherwise gradients are
        blocked with tf.raw_ops.PreventGradient.
        """
        # This function may take pytrees of TfVals. We can only set
        # tf.custom_gradient on functions that take a flat argument list.
        args_flat, in_tree = tree_util.tree_flatten((args, {}))
        for a in args_flat:
            if not _is_tfvalorunit(a):
                msg = (
                    f"Argument {a} of type {type(a)} of jax2tf.convert(f) should "
                    "be NumPy array, scalar, tf.Variable, or tf.Tensor")
                raise TypeError(msg)

        f = lu.wrap_init(fun)
        # out_tree_thunk() will be the output tree, after running _interpret_fun.
        flat_fun, out_tree_thunk = flatten_fun(f, in_tree)

        # Prepare the grad_fn for tf.custom_gradient.
        def converted_grad_fn(*out_cts_flat: TfVal, **kwargs):
            # TODO(cl/318778369): change **kwargs with variables=None
            variables = kwargs.get("variables", [])
            if variables:
                raise ValueError(
                    "Unexpected variables used in forward pass. "
                    "This should not happen for first-order differentiation. "
                    f"variables={variables}")

            def fun_vjp_jax(args_jax, out_cts_jax):
                # One may think that we can get the pullback while we are converting
                # the main function in the first place. That is problematic, because the
                # pullback may contain captured tracers from the conversion of the
                # main function. Those tracers will confuse the conversion of the
                # pullback. So, we construct the vjp anew.
                _, pullback_jax = jax.vjp(fun, *args_jax)
                return pullback_jax(out_cts_jax)

            # Rebuild the cotangent pytree from its flat form; out_tree_thunk
            # is valid here because the forward pass has already run.
            out_cts = tree_util.tree_unflatten(out_tree_thunk(), out_cts_flat)
            # Convert the VJP itself without gradients (first-order only).
            in_cts = convert(fun_vjp_jax, with_gradient=False)(args, out_cts)
            return in_cts

        if with_gradient:

            @tf.custom_gradient
            def converted_fun_flat_with_custom_gradient(
                    *args_flat: TfVal) -> TfVal:
                return _interpret_fun(flat_fun, args_flat), converted_grad_fn

            out_flat = converted_fun_flat_with_custom_gradient(*args_flat)
        else:
            out_flat_raw = _interpret_fun(flat_fun, args_flat)
            message = (
                "The jax2tf-converted function does not support gradients. "
                "Use `with_gradient` parameter to enable gradients")
            # We use PreventGradient, which is propagated through a SavedModel.
            out_flat = [
                tf.raw_ops.PreventGradient(input=o, message=message)
                for o in out_flat_raw
            ]

        out = tree_util.tree_unflatten(out_tree_thunk(), out_flat)
        return out
Example #3
0
 def doit():
   """Validate, interpret via TF, and unflatten the result of `fun`.

   Reads `fun`, `args`, `_is_tfvalorunit`, and `_interpret_fun` from the
   enclosing scope.
   """
   wrapped = lu.wrap_init(fun)
   flat_args, in_tree = tree_util.tree_flatten((args, {}))
   # Reject anything that is not a TF-compatible value before interpreting.
   for a in flat_args:
     if _is_tfvalorunit(a):
       continue
     raise TypeError(f"Argument {a} of type {type(a)} of jax2tf.convert(f) should "
                     "be NumPy array, scalar, tf.Variable, or tf.Tensor")
   flat_fun, out_tree = flatten_fun(wrapped, in_tree)
   flat_out = _interpret_fun(flat_fun, flat_args)
   return tree_util.tree_unflatten(out_tree(), flat_out)
Example #4
0
 def jaxpr_const_maker(*args, **kwargs):
     """Trace the closed-over `fun` to a jaxpr and its constants.

     Flattens the call arguments, abstracts them to partial values, and
     runs the partial-eval tracer. The output tree is discarded because
     only the jaxpr and constants are needed.
     """
     flat_args, in_tree = tree_util.tree_flatten((args, kwargs))
     flat_fun, _ = api_util.flatten_fun(lu.wrap_init(fun), in_tree)
     partial_vals = safe_map(pv_like, flat_args)
     jaxpr, _, consts = pe.trace_to_jaxpr(flat_fun, partial_vals)
     return jaxpr, consts
Example #5
0
 def f_jitted(*args, **kwargs):
   """Trace `fun` to a dynamic jaxpr and execute it via dynamic_xla_call_p."""
   flat_args, in_tree = tree_flatten((args, kwargs))
   flat_fun, out_tree = flatten_fun(lu.wrap_init(fun), in_tree)
   avals = [core.raise_to_shaped(core.get_aval(x)) for x in flat_args]
   jaxpr, consts, unconverted = trace_to_jaxpr_dynamic(flat_fun, avals)
   # Constants are prepended to the runtime arguments; the primitive needs
   # to know how many there are to split them back out.
   call_args = [*consts, *flat_args]
   dim_vals, call_args = _extract_dim_vals(jaxpr.in_dim_binders,
                                           jaxpr.in_binders,
                                           unconverted, call_args)
   flat_out = dynamic_xla_call_p.bind(*dim_vals, *call_args, jaxpr=jaxpr,
                                      num_consts=len(consts))
   return tree_unflatten(out_tree(), flat_out)
Example #6
0
 def wrapped_fun(*args: TfValOrUnit) -> TfValOrUnit:
   """Interpret the closed-over `fun` on TF values and rebuild its pytree.

   TODO(necula): remove the jit disabling once we handle all control-flow.
   Disabling the jit helps to avoid some unsupported jax primitives.
   E.g. scan will be statically unrolled.
   """
   flat_args, in_tree = tree_util.tree_flatten((args, {}))
   # Fail fast on any argument that is not a TF-compatible value.
   for a in flat_args:
     if _is_tfvalorunit(a):
       continue
     raise TypeError(f"Argument {a} of type {type(a)} of jax2tf.convert(f) should "
                     "be NumPy array, scalar, tf.Variable, or tf.Tensor")
   flat_fun, out_tree = flatten_fun(lu.wrap_init(fun), in_tree)
   flat_out = _interpret_fun(flat_fun, flat_args)
   return tree_util.tree_unflatten(out_tree(), flat_out)
Example #7
0
 def wrapped(*args, **kwargs):
     """Runs a function and binds it to a call primitive."""
     flat_args, in_tree = tree_util.tree_flatten((args, kwargs))
     flat_fun, out_tree = api_util.flatten_fun(lu.wrap_init(f), in_tree)
     # The output tree is unknown until after the bind; the thunk passed to
     # the primitive reads this cell lazily, after it has been filled in.
     resolved_out_tree = None
     flat_out = prim.bind(flat_fun,
                          *flat_args,
                          num_args=len(flat_args),
                          name=f.__name__,
                          in_tree=in_tree,
                          out_tree=lambda: resolved_out_tree,
                          **params)
     resolved_out_tree = out_tree()
     return tree_util.tree_unflatten(resolved_out_tree, flat_out)
Example #8
0
def _get_jax_objects(function, args, kwargs):
    """Trace `function` to a jaxpr and constants for the given call.

    Flattens (args, kwargs) into a flat argument list, abstracts each
    argument to a partial value, and runs the partial-eval tracer.
    Returns a (jaxpr, constants) tuple.
    """
    flat_args, in_tree = j_tree_util.tree_flatten((args, kwargs))
    flat_function, _ = j_api_util.flatten_fun(
        j_linear_util.wrap_init(function), in_tree)
    partial_values = j_util.safe_map(_get_partial_value, flat_args)
    # The middle element (the traced output partial values) is not needed.
    jaxpr, _, constants = ji_partial_eval.trace_to_jaxpr(
        flat_function, partial_values)
    return jaxpr, constants
Example #9
0
 def doit():
     """Flatten, interpret via TF, and unflatten the closed-over call."""
     flat_args, in_tree = tree_util.tree_flatten((args, {}))
     flat_fun, out_tree = flatten_fun(lu.wrap_init(fun), in_tree)
     flat_out = _interpret_fun(flat_fun, flat_args)
     return tree_util.tree_unflatten(out_tree(), flat_out)
Example #10
0
def make_djaxpr(fun, *args, **kwargs):
  """Trace `fun` on the given arguments into a dynamic jaxpr."""
  flat_args, in_tree = tree_flatten((args, kwargs))
  flat_fun, _ = flatten_fun(lu.wrap_init(fun), in_tree)
  avals = [core.raise_to_shaped(core.get_aval(x)) for x in flat_args]
  return trace_to_jaxpr_dynamic(flat_fun, avals)