def remat_partial_eval(trace, *tracers, jaxpr, **params):
  """Partial-evaluation rule for ``remat_p`` built on ``pe._partial_eval_jaxpr_custom``.

  The known part of ``jaxpr`` is evaluated immediately on the known input
  values (hoisted out of the remat). The unknown part is staged out as a new
  ``remat_p`` equation. That equation consumes the residuals produced by the
  known part, plus the instantiated input tracers it still uses after
  dead-code elimination.

  Args:
    trace: trace used to create residual/const tracers and the new unknown
      output tracers.
    *tracers: input tracers; ``is_known()`` / ``pval`` say whether each value
      is available now.
    jaxpr: the remat body; must have no constvars (asserted).
    **params: remaining ``remat_p`` params, forwarded to the staged equation
      with ``jaxpr`` replaced and ``differentiated=True``. A falsy
      ``params['policy']`` is replaced by a policy that returns False for
      every query.

  Returns:
    The result of ``pe._zip_knowns`` over the known-value tracers and the
    staged unknown tracers, keyed by ``out_unknowns`` (presumably one tracer
    per output of ``jaxpr``, in original order — confirm against
    ``_zip_knowns``).

  NOTE(review): another ``remat_partial_eval`` is defined later in this
  source; if both definitions live in the same module the later one shadows
  this one — confirm which is intended.
  """
  assert not jaxpr.constvars
  policy = params['policy'] or (lambda *_, **__: False)
  # unzip into jaxpr_known and jaxpr_unknown
  in_unknowns = [not t.is_known() for t in tracers]
  # The fifth result is unused here (presumably the residual count — verify
  # against pe._partial_eval_jaxpr_custom).
  jaxpr_known, jaxpr_unknown, out_unknowns, out_inst, _ = \
      pe._partial_eval_jaxpr_custom(jaxpr, in_unknowns, policy)
  # Keep every output of jaxpr_known (known outputs and residuals alike); only
  # its unused inputs are pruned. in_used_known is aligned with the known inputs.
  jaxpr_known, in_used_known = pe.dce_jaxpr(jaxpr_known, [True] * len(jaxpr_known.outvars))
  # Among the instantiated outputs, keep only those that are unknown.
  _, used_outs_unknown = partition_list(out_inst, out_unknowns)
  jaxpr_unknown, in_used_unknown = pe.dce_jaxpr(jaxpr_unknown, used_outs_unknown)

  # compute known outputs and residuals (hoisted out of remat primitive)
  _, in_consts_ = unzip2(t.pval for t in tracers if t.pval.is_known())
  # Drop known input values that the pruned jaxpr_known no longer reads.
  _, in_consts = partition_list(in_used_known, in_consts_)
  out_consts = core.eval_jaxpr(jaxpr_known, (), *in_consts)
  # jaxpr_known's outputs are consumed as: one value per known output, in
  # order, then everything left over is a residual.
  out_consts_ = iter(out_consts)
  # form known outputs and collect residual tracers
  out_known_tracers = [
      pe.JaxprTracer(trace, pe.PartialVal.known(next(out_consts_)), None)
      for uk in out_unknowns if not uk]
  residuals = list(out_consts_)

  # set up unknown outputs with a recipe to call remat
  res_tracers = map(trace.new_instantiated_const, residuals)
  # jaxpr_unknown's inputs are taken to be [*residuals, *all original inputs]
  # (partition_list requires the lengths to match). Note that every input is
  # instantiated before pruning, so tracers for unused inputs are created and
  # then discarded.
  in_jaxpr_tracers = [*res_tracers, *map(trace.instantiate_const, tracers)]
  _, in_jaxpr_tracers = partition_list(in_used_unknown, in_jaxpr_tracers)
  # One fresh unknown tracer per output of the pruned staged jaxpr.
  out_jaxpr_tracers = [pe.JaxprTracer(trace, pe.PartialVal.unknown(x.aval), None)
                       for x in jaxpr_unknown.outvars]
  new_params = dict(params, jaxpr=jaxpr_unknown, differentiated=True)
  recipe = pe.new_eqn_recipe(in_jaxpr_tracers, out_jaxpr_tracers, remat_p,
                             new_params, source_info_util.current())
  # All unknown outputs share the single staged remat_p equation.
  for t in out_jaxpr_tracers: t.recipe = recipe

  # zip together known and unknown outputs
  return pe._zip_knowns(out_known_tracers, out_jaxpr_tracers, out_unknowns)
def _xla_call_partial_eval_custom_params_updater( unks_in: Sequence[bool], inst_in: Sequence[bool], kept_outs_known: Sequence[bool], kept_outs_staged: Sequence[bool], num_res: int, params_known: dict, params_staged: dict) -> Tuple[dict, dict]: # pruned inputs to jaxpr_known according to unks_in, so prune donated_invars donated_known, _ = partition_list(unks_in, params_known['donated_invars']) new_params_known = dict(params_known, donated_invars=tuple(donated_known)) # added num_res new inputs to jaxpr_staged, so extend donated_invars _, donated_staged_ = partition_list(inst_in, params_staged['donated_invars']) donated_staged = [False] * num_res + donated_staged_ new_params_staged = dict(params_staged, donated_invars=tuple(donated_staged)) return new_params_known, new_params_staged
def remat_partial_eval(trace, *tracers, jaxpr, **params):
  """Partial-evaluation rule for ``remat_p``.

  Splits ``jaxpr`` with ``pe.partial_eval_jaxpr_custom`` into a known part and
  a staged part, then prunes both with ``pe.dce_jaxpr``. The known part is
  evaluated right here on the known input values. The staged part is emitted
  as a fresh ``remat_p`` equation fed by the surviving residuals and the
  input tracers it still reads.

  Args:
    trace: trace used to create residual/const tracers and the new unknown
      output tracers.
    *tracers: input tracers; ``is_known()`` / ``pval`` say whether each value
      is available now.
    jaxpr: the remat body; must have no constvars (asserted).
    **params: remaining ``remat_p`` params, forwarded to the staged equation
      with ``jaxpr`` replaced and ``differentiated=True``. A falsy
      ``params['policy']`` is replaced by ``nothing_saveable``.

  Returns:
    A list merged by ``out_unknowns``: known outputs are the concrete values
    returned by ``core.eval_jaxpr``, unknown outputs are new tracers whose
    recipe is the staged ``remat_p`` equation.
  """
  assert not jaxpr.constvars
  saveable = params['policy'] or nothing_saveable
  unknown_ins = [not tracer.is_known() for tracer in tracers]
  (jaxpr_known, jaxpr_staged, out_unknowns, out_inst,
   num_res) = pe.partial_eval_jaxpr_custom(
       jaxpr, unknown_ins, [True] * len(unknown_ins), False, False, saveable)

  # Prune the staged jaxpr to the outputs that are both instantiated and
  # unknown. Its inputs are the num_res residuals followed by every original
  # input, so split the usage mask accordingly.
  staged_keep = partition_list(out_inst, out_unknowns)[1]
  jaxpr_unknown, staged_in_used = pe.dce_jaxpr(jaxpr_staged, staged_keep)
  res_used, staged_in_used = split_list(staged_in_used, [num_res])

  # Prune the known jaxpr: keep all known outputs, but only the residuals that
  # the pruned staged jaxpr still reads.
  num_known_outs = len(out_unknowns) - sum(out_unknowns)
  jaxpr_known, known_in_used = pe.dce_jaxpr(
      jaxpr_known, [True] * num_known_outs + res_used)
  num_res = sum(res_used)

  # Run the known half now (hoisted out of the remat primitive); its outputs
  # are the known results followed by the residuals.
  known_vals = unzip2(t.pval for t in tracers if t.pval.is_known())[1]
  known_vals = partition_list(known_in_used, known_vals)[1]
  results = core.eval_jaxpr(jaxpr_known, (), *known_vals)
  out_knowns, residuals = split_list(results, [len(results) - num_res])

  # Stage the unknown half as a new remat_p equation.
  res_tracers = map(trace.new_instantiated_const, residuals)
  staged_tracers = partition_list(staged_in_used, tracers)[1]
  in_jaxpr_tracers = res_tracers + map(trace.instantiate_const, staged_tracers)
  out_jaxpr_tracers = [
      pe.JaxprTracer(trace, pe.PartialVal.unknown(v.aval), None)
      for v in jaxpr_unknown.outvars]
  staged_params = dict(params, jaxpr=jaxpr_unknown, differentiated=True)
  eqn_recipe = pe.new_eqn_recipe(
      in_jaxpr_tracers, out_jaxpr_tracers, remat_p, staged_params,
      jaxpr_unknown.effects, source_info_util.current())
  for out_tracer in out_jaxpr_tracers:
    out_tracer.recipe = eqn_recipe

  # Interleave known values and unknown tracers back into output order.
  return merge_lists(out_unknowns, out_knowns, out_jaxpr_tracers)
def tree_flatten(self):
  """Flatten this object into ``(children, aux_data)`` pytree form.

  Presumably used as the pytree flatten hook for a captured stack frame.
  Only array-like leaves of ``self.locals`` (``core.Tracer``,
  ``jnp.ndarray`` or ``np.ndarray`` instances) become children. All other
  leaves travel in the auxiliary data, together with the validity mask, the
  locals treedef and the position/source metadata.

  Returns:
    ``(valid_locals, (is_valid, invalid_locals, locals_tree, filename,
    code_context, source, lineno, offset))``.
  """
  leaves, treedef = tree_util.tree_flatten(self.locals)
  array_types = (core.Tracer, jnp.ndarray, np.ndarray)
  mask = [isinstance(leaf, array_types) for leaf in leaves]
  # partition_list returns (entries where mask is False, where it is True).
  invalid, valid = util.partition_list(mask, leaves)
  aux_data = (mask, invalid, treedef, self.filename, self.code_context,
              self.source, self.lineno, self.offset)
  return valid, aux_data