def fun_remat(*args, **kwargs):
    """Trace the closed-over ``fun`` to a jaxpr and run it via ``remat_p``.

    Flattens the call arguments, traces ``fun`` at their shaped avals,
    hoists the traced constants into leading explicit arguments, binds the
    remat primitive, and rebuilds the structured output.

    NOTE(review): ``fun``, ``prevent_cse`` and ``policy`` are free variables
    captured from the enclosing scope — presumably a ``checkpoint``/``remat``
    decorator body.
    """
    leaves, treedef = tree_flatten((args, kwargs))
    dbg = pe.debug_info(fun, treedef, False, "checkpoint")
    flattened_fun, out_treedef = flatten_fun(lu.wrap_init(fun), treedef)
    avals = [core.raise_to_shaped(core.get_aval(leaf)) for leaf in leaves]
    jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(flattened_fun, avals, dbg)
    # convert_constvars_jaxpr turns the consts into ordinary leading
    # parameters, so they are passed positionally ahead of the real args.
    results = remat_p.bind(
        *consts,
        *leaves,
        jaxpr=pe.convert_constvars_jaxpr(jaxpr),
        prevent_cse=prevent_cse,
        differentiated=False,
        policy=policy,
    )
    return tree_unflatten(out_treedef(), results)
def wrapped(*args, **kwargs):
    """sharded_jit entry point: partition args, trace ``fun``, dispatch.

    Validates that only positional arguments are used, strips any static
    arguments, flattens the remaining arguments, matches the user-supplied
    partitioning specs against the argument and output pytrees, and invokes
    ``sharded_call`` on the flattened function.

    Raises:
      NotImplementedError: if any keyword arguments are supplied.
      ValueError: if a ``static_argnums`` index is out of range for ``args``.

    NOTE(review): ``fun``, ``static_argnums``, ``*_parts``, ``nparts`` and
    ``local_nparts`` are free variables captured from the enclosing
    ``sharded_jit`` scope.
    """
    if kwargs:
        raise NotImplementedError(
            "sharded_jit over kwargs not yet supported")

    f = lu.wrap_init(fun)
    if static_argnums:
        if max(static_argnums) >= len(args):
            raise ValueError(
                f"jitted function has static_argnums={static_argnums}"
                f" but was called with only {len(args)} positional "
                # Fix: pluralize for any count != 1 (the original used
                # `> 1`, which produced "0 positional argument").
                f"argument{'s' if len(args) != 1 else ''}. "
                "All static broadcasted arguments must be passed positionally."
            )
        dyn_argnums = [
            i for i in range(len(args)) if i not in static_argnums
        ]
        f, args = argnums_partial(f, dyn_argnums, args)

    args_flat, in_tree = tree_flatten((args, kwargs))
    # in_tree is the treedef of ((args...), {}); children()[0] is the
    # positional-args subtree that the partition specs are matched against.
    in_parts_flat = tuple(
        flatten_axes("sharded_jit in_parts", in_tree.children()[0],
                     in_parts))
    if local_in_parts is not None:
        local_in_parts_flat = tuple(
            flatten_axes("sharded_jit local_in_parts",
                         in_tree.children()[0], local_in_parts))
    else:
        local_in_parts_flat = None

    flat_fun, out_tree = flatten_fun(f, in_tree)
    # TODO(skye): having a function-typed param in a primitive seems dicey, is
    # there a better way?
    # out_tree() only becomes available after tracing, hence the thunks.
    out_parts_thunk = HashableFunction(
        lambda: tuple(
            flatten_axes("sharded_jit out_parts", out_tree(), out_parts)),
        closure=out_parts)
    if local_out_parts:
        local_out_parts_thunk = HashableFunction(
            lambda: tuple(
                flatten_axes("sharded_jit local_out_parts", out_tree(),
                             local_out_parts)),
            closure=local_out_parts)
    else:
        local_out_parts_thunk = HashableFunction(lambda: None, closure=None)

    out = sharded_call(
        flat_fun,
        *args_flat,
        nparts=nparts,
        in_parts=in_parts_flat,
        out_parts_thunk=out_parts_thunk,
        local_in_parts=local_in_parts_flat,
        local_out_parts_thunk=local_out_parts_thunk,
        local_nparts=local_nparts,
        name=flat_fun.__name__)
    return tree_unflatten(out_tree(), out)
def wrapper(f_orig):
    """Partially evaluate ``f_orig``, baking in its constant arguments.

    NOTE(review): ``constant_arg_nums``, ``reorder``, ``dark`` and
    ``num_args_remove`` are free variables from the enclosing scope.
    ``dark`` appears to hold sample arguments in the reordered layout
    (constants first) and ``reorder`` the permutation that restores the
    original order — TODO confirm against the enclosing definition.

    Returns ``f_orig`` unchanged when there are no constant args; otherwise
    returns a function of only the non-constant args that evaluates the
    pre-traced jaxpr.
    """
    if len(constant_arg_nums) > 0:
        # Reordering args so the ones to remove are given first
        # This will allow us to return a function that has completely removed those args
        # Moreover, we do it here so this reordering will be optimized by the compiler
        def f(*args):
            new_args = tuple(args[k] for k in reorder)
            return f_orig(*new_args)

        # Create the partial args needed by trace_to_jaxpr
        def get_arg(a, unknown):
            # Unknown (non-constant) args are flattened to a LIST of
            # unknown PartialVals; known (constant) args stay as a single
            # known PartialVal — the caller below handles both shapes.
            if unknown:
                return tree_flatten(
                    (
                        tree_map(
                            lambda x: PartialVal.unknown(get_aval(x).at_least_vspace()),
                            a
                        ),
                        {},
                    )
                )[0]
            else:
                return PartialVal.known(a)

        part_args = []
        # The first num_args_remove entries of dark are the constants to
        # bake in (known); the rest become unknown trace-time inputs.
        for k, a in enumerate(dark):
            temp = get_arg(a, k >= num_args_remove)
            if isinstance(temp, list):
                part_args += temp
            else:
                part_args.append(temp)
        part_args = tuple(part_args)

        # Create jaxpr
        wrap = lu.wrap_init(f)
        _, in_tree = tree_flatten((dark, {}))
        wrap_flat, out_tree = flatten_fun(wrap, in_tree)
        jaxpr, _, const = trace_to_jaxpr(wrap_flat, part_args)

        # Create new, partially evaluated function
        if out_tree().num_leaves == 1 and out_tree().num_nodes == 1:
            # out_tree() is PyTreeDef(*), so just return the value. Since eval_jaxpr returns a list,
            # this is just value [0]
            f_removed = lambda *args: eval_jaxpr(
                jaxpr, const, *tree_flatten((*dark[0:num_args_remove], *args, {}))[0]
            )[0]
        else:
            # Use out_tree() to reshape the args correctly.
            f_removed = lambda *args: out_tree().unflatten(
                eval_jaxpr(
                    jaxpr, const, *tree_flatten((*dark[0:num_args_remove], *args, {}))[0]
                )
            )
        return f_removed
    else:
        return f_orig
def make_jaxpr(fun: Callable, in_tree: PyTreeDef,
               in_avals: typing.Tuple[core.AbstractValue],  # with DBIdx in them
               keep_inputs: typing.Tuple[bool]
               ) -> typing.Tuple[core.Jaxpr, List[Any], PyTreeDef]:
    """Trace ``fun`` at ``in_avals`` and return ``(jaxpr, consts, out_tree)``.

    The returned jaxpr has its constvars converted into leading explicit
    parameters, and every traced constant is passed through
    ``_canonicalize_arg`` before being returned.
    """
    flattened, out_tree_thunk = flatten_fun(lu.wrap_init(fun), in_tree)
    dbg = pe.debug_info(fun, in_tree, False, "dex_jit")
    traced, _, raw_consts = pe.trace_to_jaxpr_dynamic(
        flattened, in_avals, dbg, keep_inputs=keep_inputs)
    canonical_consts = [_canonicalize_arg(c) for c in raw_consts]
    return pe.convert_constvars_jaxpr(traced), canonical_consts, out_tree_thunk()
def linearize(traceable, *primals, **kwargs):
    """Partially evaluate the jvp of ``traceable`` at ``primals``.

    Primal inputs are marked known and tangent inputs unknown, so partial
    evaluation computes the primal outputs eagerly and leaves behind a
    jaxpr for the linearized (tangent) computation.

    Returns ``(primal_outs, tangent_pvals, jaxpr, consts)``, plus the aux
    value when called with ``has_aux=True``.
    """
    has_aux = kwargs.pop('has_aux', False)
    if has_aux:
        jvpfun, aux = jvp(traceable, has_aux=True)
    else:
        jvpfun = jvp(traceable)

    known = tuple(pe.PartialVal.known(p) for p in primals)
    unknown = tuple(
        pe.PartialVal.unknown(get_aval(p).at_least_vspace()) for p in primals)
    in_pvals = known + unknown

    # The jvp function takes (primals, tangents); both subtrees share the
    # structure of `primals`.
    _, in_tree = tree_flatten(((primals, primals), {}))
    flat_jvpfun, out_tree = flatten_fun(jvpfun, in_tree)
    jaxpr, out_pvals, consts = pe.trace_to_jaxpr_nounits(flat_jvpfun, in_pvals)

    primal_pvals, tangent_pvals = tree_unflatten(out_tree(), out_pvals)
    # With every primal input known, all primal outputs must be known too.
    assert all(pval.is_known() for pval in primal_pvals)
    primal_outs = [pval.get_known() for pval in primal_pvals]
    if has_aux:
        return primal_outs, tangent_pvals, jaxpr, consts, aux()
    return primal_outs, tangent_pvals, jaxpr, consts