def checked_fun(*args, **kwargs):
  """Run ``fun`` under checkify and return ``(Error, output)``.

  Positional and keyword arguments are flattened into one list of leaves, the
  flattened ``fun`` is run through ``checkify_flat`` with the enclosing
  ``errors``, and the flat outputs are rebuilt into ``fun``'s output pytree.

  Returns:
    A pair ``(Error(err, code, msgs, payload), out)``.
  """
  flat_inputs, input_treedef = tree_flatten((args, kwargs))
  wrapped = lu.wrap_init(fun)
  flat_wrapped, output_treedef_thunk = flatten_fun(wrapped, input_treedef)
  checked_result, messages = checkify_flat(flat_wrapped, errors, *flat_inputs)
  err, code, payload, flat_outputs = checked_result
  # The output treedef thunk is only valid after the flat function has run.
  result = tree_unflatten(output_treedef_thunk(), flat_outputs)
  return Error(err, code, messages, payload), result
def converted_fun(*args: TfVal) -> TfVal:
  """Run the JAX function ``fun`` on TF values and return TF values.

  ``args`` may be pytrees of TF values. They are flattened because
  ``tf.custom_gradient`` can only decorate a function of a flat argument list.
  If ``with_gradient`` is true, the gradient is computed by converting a fresh
  ``jax.vjp`` of ``fun``. Otherwise each output is wrapped in
  ``tf.raw_ops.PreventGradient``.

  Raises:
    TypeError: if a flattened argument is rejected by ``_is_tfvalorunit``.
  """
  args_flat, in_tree = tree_util.tree_flatten((args, {}))
  for leaf in args_flat:
    if _is_tfvalorunit(leaf):
      continue
    raise TypeError(
        f"Argument {leaf} of type {type(leaf)} of jax2tf.convert(f) should "
        "be NumPy array, scalar, tf.Variable, or tf.Tensor")

  # out_tree_thunk() yields the output tree only after _interpret_fun has run.
  flat_fun, out_tree_thunk = flatten_fun(lu.wrap_init(fun), in_tree)

  def converted_grad_fn(*out_cts_flat: TfVal, **kwargs):
    """Gradient function handed to ``tf.custom_gradient``."""
    # TODO(cl/318778369): change **kwargs with variables=None
    variables = kwargs.get("variables", [])
    if variables:
      raise ValueError(
          "Unexpected variables used in forward pass. "
          "This should not happen for first-order differentiation. "
          f"variables={variables}")

    def fun_vjp_jax(args_jax, out_cts_jax):
      # Rebuild the pullback here instead of capturing it while the primal
      # function is converted. A captured pullback may hold tracers from that
      # conversion, and those would confuse converting the pullback.
      _, pullback_jax = jax.vjp(fun, *args_jax)
      return pullback_jax(out_cts_jax)

    out_cts = tree_util.tree_unflatten(out_tree_thunk(), out_cts_flat)
    return convert(fun_vjp_jax, with_gradient=False)(args, out_cts)

  if not with_gradient:
    message = (
        "The jax2tf-converted function does not support gradients. "
        "Use `with_gradient` parameter to enable gradients")
    out_flat_raw = _interpret_fun(flat_fun, args_flat)
    # PreventGradient is used because it is propagated through a SavedModel.
    out_flat = [
        tf.raw_ops.PreventGradient(input=o, message=message)
        for o in out_flat_raw
    ]
  else:
    @tf.custom_gradient
    def converted_fun_flat_with_custom_gradient(
        *args_flat: TfVal) -> TfVal:
      return _interpret_fun(flat_fun, args_flat), converted_grad_fn

    out_flat = converted_fun_flat_with_custom_gradient(*args_flat)

  return tree_util.tree_unflatten(out_tree_thunk(), out_flat)
def doit():
  """Interpret ``fun`` on the enclosing ``args`` and return its output pytree.

  Raises:
    TypeError: if a flattened leaf of ``args`` is rejected by
      ``_is_tfvalorunit``.
  """
  wrapped = lu.wrap_init(fun)
  leaves, treedef = tree_util.tree_flatten((args, {}))
  for leaf in leaves:
    if not _is_tfvalorunit(leaf):
      raise TypeError(
          f"Argument {leaf} of type {type(leaf)} of jax2tf.convert(f) should "
          "be NumPy array, scalar, tf.Variable, or tf.Tensor")
  flat_fun, out_tree_thunk = flatten_fun(wrapped, treedef)
  # Interpret first: the output-tree thunk is only populated after the run.
  results = _interpret_fun(flat_fun, leaves)
  return tree_util.tree_unflatten(out_tree_thunk(), results)
def jaxpr_const_maker(*args, **kwargs):
  """Trace ``fun`` on the given arguments into a jaxpr.

  Returns:
    A ``(jaxpr, consts)`` pair taken from ``pe.trace_to_jaxpr``. The second
    value that function returns is discarded.
  """
  wrapped = lu.wrap_init(fun)
  # Flatten the (args, kwargs) pytree so fun can be traced on a flat list.
  leaves, treedef = tree_util.tree_flatten((args, kwargs))
  # The output-tree thunk is not needed by callers of this helper.
  flat_fun, _ = api_util.flatten_fun(wrapped, treedef)
  # One partial value per leaf, as produced by pv_like.
  partial_vals = safe_map(pv_like, leaves)
  jaxpr, _, consts = pe.trace_to_jaxpr(flat_fun, partial_vals)
  return jaxpr, consts
def f_jitted(*args, **kwargs):
  """Trace ``fun`` dynamically and execute it through ``dynamic_xla_call_p``.

  The flattened arguments are turned into shaped avals and traced with
  ``trace_to_jaxpr_dynamic``. The traced constants are then prepended to the
  operands, dimension values are split off with ``_extract_dim_vals``, and the
  flat result is rebuilt into ``fun``'s output pytree.
  """
  flat_args, in_tree = tree_flatten((args, kwargs))
  flat_fun, out_tree = flatten_fun(lu.wrap_init(fun), in_tree)
  avals = [core.raise_to_shaped(core.get_aval(leaf)) for leaf in flat_args]
  jaxpr, consts, unconverted_binders = trace_to_jaxpr_dynamic(flat_fun, avals)
  num_consts = len(consts)
  operands = [*consts, *flat_args]
  dim_vals, operands = _extract_dim_vals(
      jaxpr.in_dim_binders, jaxpr.in_binders, unconverted_binders, operands)
  out_flat = dynamic_xla_call_p.bind(
      *dim_vals, *operands, jaxpr=jaxpr, num_consts=num_consts)
  return tree_unflatten(out_tree(), out_flat)
def wrapped_fun(*args: TfValOrUnit) -> TfValOrUnit:
  """Interpret ``fun`` on TF-value arguments and return its output pytree.

  Raises:
    TypeError: if a flattened argument is rejected by ``_is_tfvalorunit``.
  """
  # NOTE(review): an earlier TODO(necula) described disabling jit to sidestep
  # unsupported control-flow primitives (e.g. unrolling scan). This body does
  # not disable jit, so confirm whether that TODO is obsolete or whether a
  # disable_jit context was dropped.
  leaves, treedef = tree_util.tree_flatten((args, {}))
  for leaf in leaves:
    if _is_tfvalorunit(leaf):
      continue
    raise TypeError(
        f"Argument {leaf} of type {type(leaf)} of jax2tf.convert(f) should "
        "be NumPy array, scalar, tf.Variable, or tf.Tensor")
  flat_fun, out_tree_thunk = flatten_fun(lu.wrap_init(fun), treedef)
  flat_results = _interpret_fun(flat_fun, leaves)
  return tree_util.tree_unflatten(out_tree_thunk(), flat_results)
def wrapped(*args, **kwargs):
  """Runs a function and binds it to a call primitive.

  The arguments are flattened, ``f`` is wrapped into a flat function, and the
  result of ``prim.bind`` is rebuilt into ``f``'s output pytree. The primitive
  receives the input tree directly. It receives the output tree as a
  zero-argument callable, because the output tree is only known after the
  flat function has run.
  """
  fun = lu.wrap_init(f)
  flat_args, in_tree = tree_util.tree_flatten((args, kwargs))
  flat_fun, out_tree = api_util.flatten_fun(fun, in_tree)
  # Callables such as functools.partial or instances with __call__ have no
  # __name__; fall back to a placeholder instead of raising AttributeError.
  fun_name = getattr(f, "__name__", "<unnamed function>")
  # Assigned after bind returns. The lambda below reads this variable lazily
  # (late binding), so it returns None if invoked before then.
  out_tree_dest = None
  out = prim.bind(flat_fun, *flat_args, num_args=len(flat_args),
                  name=fun_name, in_tree=in_tree,
                  out_tree=lambda: out_tree_dest, **params)
  out_tree_dest = out_tree()
  return tree_util.tree_unflatten(out_tree_dest, out)
def _get_jax_objects(function, args, kwargs):
    """Trace ``function`` on ``args``/``kwargs`` and return ``(jaxpr, constants)``.

    The arguments are flattened into leaves and each leaf is mapped through
    ``_get_partial_value``. The flattened function is then traced with
    ``ji_partial_eval.trace_to_jaxpr``, whose second return value is discarded.
    """
    wrapped = j_linear_util.wrap_init(function)
    leaves, treedef = j_tree_util.tree_flatten((args, kwargs))
    # Only the flat function is needed; the output-tree thunk is ignored.
    flat_function, _ = j_api_util.flatten_fun(wrapped, treedef)
    partial_values = j_util.safe_map(_get_partial_value, leaves)
    jaxpr, _, constants = ji_partial_eval.trace_to_jaxpr(
        flat_function, partial_values
    )
    return jaxpr, constants
def doit():
  """Interpret ``fun`` on the enclosing ``args`` and rebuild its output pytree.

  The flattened leaves are passed to ``_interpret_fun`` without any type
  check.
  """
  wrapped = lu.wrap_init(fun)
  leaves, treedef = tree_util.tree_flatten((args, {}))
  flat_fun, out_tree_thunk = flatten_fun(wrapped, treedef)
  # Run before querying the thunk: the output tree is recorded during the run.
  flat_results = _interpret_fun(flat_fun, leaves)
  return tree_util.tree_unflatten(out_tree_thunk(), flat_results)
def make_djaxpr(fun, *args, **kwargs):
  """Trace ``fun`` at the shaped abstract values of its arguments.

  ``args``/``kwargs`` are flattened. Each leaf's aval (``core.get_aval``
  followed by ``core.raise_to_shaped``) is passed to
  ``trace_to_jaxpr_dynamic``, and that function's result is returned
  unchanged. The output tree is not returned.
  """
  leaves, treedef = tree_flatten((args, kwargs))
  flat_fun, _ = flatten_fun(lu.wrap_init(fun), treedef)
  avals = [core.raise_to_shaped(core.get_aval(leaf)) for leaf in leaves]
  return trace_to_jaxpr_dynamic(flat_fun, avals)