def _initial_style_open_jaxpr(fun: Callable, in_tree, in_avals,
                              primitive_name: Optional[str] = None):
    """Trace ``fun`` into an open jaxpr and return it with its constants.

    Args:
      fun: the callable to trace.
      in_tree: input pytree structure; forwarded to ``flatten_fun_nokwargs``
        and to ``pe.debug_info``.
      in_avals: abstract values handed to ``pe.trace_to_jaxpr_dynamic``.
      primitive_name: optional label recorded in the debug info. When it is
        ``None`` (or otherwise falsy) the label ``"<unknown>"`` is used.

    Returns:
      A ``(jaxpr, consts, out_tree)`` triple, where ``out_tree`` is obtained
      by calling the output-tree thunk after tracing has run.
    """
    label = primitive_name or "<unknown>"
    flat_fun, out_tree_thunk = flatten_fun_nokwargs(lu.wrap_init(fun), in_tree)
    dbg = pe.debug_info(fun, in_tree, False, label)
    # The middle element of the tracing result is not used here.
    open_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals, dbg)
    return open_jaxpr, consts, out_tree_thunk()
def fun_remat(*args, **kwargs):
    """Trace ``fun`` on the given arguments and bind the result to ``remat_p``.

    NOTE(review): ``fun``, ``prevent_cse`` and ``policy`` are not defined in
    this block; presumably they are captured from an enclosing scope.

    Args:
      *args: positional arguments for ``fun``.
      **kwargs: keyword arguments for ``fun``.

    Returns:
      The flat outputs of ``remat_p.bind`` unflattened with the output tree
      produced while tracing.
    """
    # Positional and keyword arguments are flattened together as one pytree.
    flat_args, in_tree = tree_flatten((args, kwargs))
    traceable, out_tree_thunk = flatten_fun(lu.wrap_init(fun), in_tree)
    avals = [core.raise_to_shaped(core.get_aval(leaf)) for leaf in flat_args]
    dbg = pe.debug_info(fun, in_tree, False, "checkpoint")
    open_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(traceable, avals, dbg)

    # Constants found during tracing are passed as leading operands, ahead of
    # the flattened user arguments.
    converted = pe.convert_constvars_jaxpr(open_jaxpr)
    flat_outs = remat_p.bind(
        *consts,
        *flat_args,
        jaxpr=converted,
        prevent_cse=prevent_cse,
        differentiated=False,
        policy=policy,
    )
    return tree_unflatten(out_tree_thunk(), flat_outs)
def make_jaxpr(fun: Callable, in_tree: PyTreeDef,
               in_avals: typing.Tuple[core.AbstractValue], # with DBIdx in them
               keep_inputs: typing.Tuple[bool]
               ) -> typing.Tuple[core.Jaxpr, List[Any], PyTreeDef]:
    """Trace ``fun`` to a jaxpr whose constants have become explicit inputs.

    Args:
      fun: the callable to trace.
      in_tree: input pytree structure; forwarded to ``flatten_fun`` and to
        ``pe.debug_info``.
      in_avals: abstract values for the flat inputs.
      keep_inputs: forwarded unchanged as the ``keep_inputs`` keyword of
        ``pe.trace_to_jaxpr_dynamic``.

    Returns:
      A ``(jaxpr, consts, out_tree)`` triple: the jaxpr after
      ``pe.convert_constvars_jaxpr``, the traced constants each passed through
      ``_canonicalize_arg``, and the output tree read from its thunk after
      tracing.
    """
    traceable, out_tree_thunk = flatten_fun(lu.wrap_init(fun), in_tree)
    dbg = pe.debug_info(fun, in_tree, False, "dex_jit")
    open_jaxpr, _, raw_consts = pe.trace_to_jaxpr_dynamic(
        traceable, in_avals, dbg, keep_inputs=keep_inputs)
    converted = pe.convert_constvars_jaxpr(open_jaxpr)
    canonical_consts = [_canonicalize_arg(const) for const in raw_consts]
    return converted, canonical_consts, out_tree_thunk()
def __call__(self, *args, **kwargs):
    """Trace ``self.fun`` on ``args`` and bind the result to ``custom_vmap_p``.

    Args:
      *args: positional arguments for ``self.fun``.
      **kwargs: must be empty; this is checked with an ``assert``.

    Returns:
      The flat outputs of ``custom_vmap_p.bind`` unflattened with the output
      tree produced while tracing.
    """
    assert not kwargs
    flat_args, in_tree = tree_flatten(args)
    traceable, out_tree_thunk = flatten_fun_nokwargs(
        lu.wrap_init(self.fun), in_tree)
    avals = [core.raise_to_shaped(core.get_aval(leaf)) for leaf in flat_args]
    dbg = pe.debug_info(self.fun, in_tree, False, "custom_vmap")
    open_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(traceable, avals, dbg)

    # Tracing must not have produced constants: the closed jaxpr below is
    # built with an empty constant tuple.
    assert not len(consts)
    converted = pe.convert_constvars_jaxpr(open_jaxpr)
    closed_call = core.ClosedJaxpr(converted, ())

    flat_outs = custom_vmap_p.bind(
        *consts,
        *flat_args,
        call=closed_call,
        rule=self.vmap_rule,
        in_tree=in_tree,
    )
    return tree_unflatten(out_tree_thunk(), flat_outs)
def __call__(self, residual_arg, linear_arg):
    """Trace ``self.fun`` and bind the result to ``custom_transpose_p``.

    Args:
      residual_arg: pytree passed as the first argument of ``self.fun``.
      linear_arg: pytree passed as the second argument of ``self.fun``.

    Returns:
      The flat outputs of ``custom_transpose_p.bind`` unflattened with the
      output tree produced while tracing.
    """
    # Only the tree structures of the individual arguments are needed here;
    # the flat leaves come from flattening the pair below.
    res_tree = tree_flatten(residual_arg)[1]
    lin_tree = tree_flatten(linear_arg)[1]
    flat_args, in_tree = tree_flatten((residual_arg, linear_arg))

    traceable, out_tree_thunk = flatten_fun_nokwargs(
        lu.wrap_init(self.fun), in_tree)
    avals = [core.raise_to_shaped(core.get_aval(leaf)) for leaf in flat_args]
    dbg = pe.debug_info(self.fun, in_tree, False, "custom_transpose")
    open_jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(traceable, avals, dbg)

    # Tracing must not have produced constants: the closed jaxpr below is
    # built with an empty constant tuple.
    assert not len(consts)
    converted = pe.convert_constvars_jaxpr(open_jaxpr)
    closed_call = core.ClosedJaxpr(converted, ())

    # Resolve the output tree once; it is both a bind parameter and the
    # structure used to unflatten the results.
    result_tree = out_tree_thunk()
    flat_outs = custom_transpose_p.bind(
        *consts,
        *flat_args,
        call=closed_call,
        rule=self.transpose,
        lin_tree=lin_tree,
        res_tree=res_tree,
        out_tree=result_tree,
    )
    return tree_unflatten(result_tree, flat_outs)