def remat_partial_eval(trace, *tracers, jaxpr, **params):
  """Partially evaluate a remat application whose body is ``jaxpr``.

  ``jaxpr`` is split into a known part and an unknown (staged) part by
  ``pe._partial_eval_jaxpr_custom``. The known part is evaluated right away
  with ``core.eval_jaxpr``. The first outputs of that evaluation are the known
  primal outputs, one per False entry in ``out_unknowns``; every remaining
  output is a residual. The unknown part is staged as a new ``remat_p``
  equation with ``differentiated=True``. Its inputs are the residuals followed
  by the instantiated input tracers, filtered to those the DCE'd unknown
  jaxpr still uses.

  Args:
    trace: the partial-eval trace. Its ``new_instantiated_const`` and
      ``instantiate_const`` methods build the staged equation's inputs.
    *tracers: one input tracer per ``jaxpr`` invar (presumably
      ``pe.JaxprTracer``s -- confirm).
    jaxpr: the remat body. It must have no constvars (asserted).
    **params: the remaining remat params. ``params['policy']`` is handed to
      ``_partial_eval_jaxpr_custom``. If it is falsy, an always-False policy
      is used instead.

  Returns:
    One tracer per output, built by ``pe._zip_knowns`` from the known-output
    tracers and the staged-output tracers according to ``out_unknowns``.
  """
  assert not jaxpr.constvars
  # Falsy policy -> a policy that answers False for everything (presumably:
  # save nothing, recompute everything -- confirm against pe).
  policy = params['policy'] or (lambda *_, **__: False)

  # unzip into jaxpr_known and jaxpr_unknown
  in_unknowns = [not t.is_known() for t in tracers]
  jaxpr_known, jaxpr_unknown, out_unknowns, out_inst, _ = \
      pe._partial_eval_jaxpr_custom(jaxpr, in_unknowns, policy)
  # Keep every output of jaxpr_known (known outputs and all residuals); DCE
  # here only prunes inputs, reported back through in_used_known.
  jaxpr_known, in_used_known = pe.dce_jaxpr(jaxpr_known, [True] * len(jaxpr_known.outvars))
  # Build the output mask for DCE'ing jaxpr_unknown from the out_unknowns
  # flags of the instantiated outputs.
  _, used_outs_unknown = partition_list(out_inst, out_unknowns)
  jaxpr_unknown, in_used_unknown = pe.dce_jaxpr(jaxpr_unknown, used_outs_unknown)

  # compute known outputs and residuals (hoisted out of remat primitive)
  _, in_consts_ = unzip2(t.pval for t in tracers if t.pval.is_known())
  # Feed only the known inputs that survived DCE of jaxpr_known.
  _, in_consts = partition_list(in_used_known, in_consts_)
  out_consts = core.eval_jaxpr(jaxpr_known, (), *in_consts)
  out_consts_ = iter(out_consts)

  # form known outputs and collect residual tracers
  # The first (number of known outputs) values are the known primal outputs...
  out_known_tracers = [
      pe.JaxprTracer(trace, pe.PartialVal.known(next(out_consts_)), None)
      for uk in out_unknowns if not uk]
  # ...and whatever is left in the iterator is the residuals.
  residuals = list(out_consts_)

  # set up unknown outputs with a recipe to call remat
  res_tracers = map(trace.new_instantiated_const, residuals)
  # Candidate inputs of jaxpr_unknown: residuals first, then every original
  # input (instantiated); in_used_unknown then keeps only the ones DCE kept.
  in_jaxpr_tracers = [*res_tracers, *map(trace.instantiate_const, tracers)]
  _, in_jaxpr_tracers = partition_list(in_used_unknown, in_jaxpr_tracers)
  out_jaxpr_tracers = [pe.JaxprTracer(trace, pe.PartialVal.unknown(x.aval), None)
                       for x in jaxpr_unknown.outvars]
  new_params = dict(params, jaxpr=jaxpr_unknown, differentiated=True)
  # One recipe shared by every staged output.
  recipe = pe.new_eqn_recipe(in_jaxpr_tracers, out_jaxpr_tracers, remat_p,
                             new_params, source_info_util.current())
  for t in out_jaxpr_tracers: t.recipe = recipe

  # zip together known and unknown outputs
  return pe._zip_knowns(out_known_tracers, out_jaxpr_tracers, out_unknowns)
def _bound_output_tracers(self, primitive, params, jaxpr, consts, env,
                          in_tracers, out_pvs, out_consts, out_keys, name,
                          is_map):
  """Takes a traced function and binds the Jaxpr to output tracers.

  Creates one ``UnzipTracer`` per entry of ``out_pvs``/``out_consts``/
  ``out_keys``. All of them get a single shared equation recipe for
  ``primitive``. That recipe's ``call_jaxpr`` is ``jaxpr`` with its constvars
  converted into leading invars. Its inputs are, in order: the instantiated
  ``consts``, the instantiated ``env`` entries, and then ``in_tracers``.

  Args:
    primitive: the call/map primitive recorded in the staged equation.
    params: the original primitive params. They are copied with ``dict(...)``
      and the copy is what gets modified.
    jaxpr: the traced body. Its constvars become leading invars.
    consts: the values backing ``jaxpr``'s constvars.
    env: entries instantiated as equation inputs after the consts.
    in_tracers: the remaining equation inputs.
    out_pvs, out_consts, out_keys: per-output partial-value components and
      key flags, zipped with ``safe_zip``.
    name: stored as the ``name`` param of the equation.
    is_map: if True, ``params['out_axes_thunk']`` is evaluated and replaced
      by an explicit all-zero ``out_axes`` tuple.

  Returns:
    The list of newly created output tracers.
  """
  lifted_jaxpr = pe.convert_constvars_jaxpr(jaxpr)
  const_tracers = safe_map(self.new_instantiated_const, consts)
  env_tracers = safe_map(self.instantiate_const, env)
  out_tracers = [
      UnzipTracer(self, pe.PartialVal((pv, const)), None, key)
      for pv, const, key in safe_zip(out_pvs, out_consts, out_keys)
  ]
  new_params = dict(params, name=name, call_jaxpr=lifted_jaxpr)
  if 'donated_invars' in params:
    # Const and env inputs are never donated. The caller's donation flags are
    # kept only for in_tracers whose pval is not known.
    # NOTE(review): the equation below receives *all* in_tracers, so this
    # tuple matches the invar count only if callers pass unknown in_tracers
    # exclusively -- confirm.
    new_donated_invars = (
        (False, ) * len(const_tracers) + (False, ) * len(env_tracers) +
        tuple(v for v, t in zip(params['donated_invars'], in_tracers)
              if not t.pval.is_known()))
    new_params['donated_invars'] = new_donated_invars
  if is_map:
    # Only all-zero out_axes are supported (asserted). The lazy thunk is
    # replaced by a concrete tuple on the copied params.
    out_axes = params['out_axes_thunk']()
    assert all(out_axis == 0 for out_axis in out_axes)
    new_params['out_axes'] = (0, ) * len(out_tracers)
    del new_params['out_axes_thunk']
  eqn = pe.new_eqn_recipe(
      tuple(const_tracers + env_tracers + in_tracers), out_tracers, primitive,
      new_params, source_info_util.current())  # pytype: disable=wrong-arg-types
  for t in out_tracers:
    t.recipe = eqn
  return out_tracers
def custom_rule(*tracers, **params):
  """Run the custom partial-eval rule for ``key`` and rewrap its outputs.

  ``key`` is a free variable from the enclosing scope (not visible here). It
  selects an entry of ``pe.custom_partial_eval_rules``. Each tracer that rule
  returns is rewrapped as an ``UnzipTracer`` (with ``False`` as the key flag).
  Its equation recipe is then rebuilt so the equation's outputs are the new
  ``UnzipTracer``s rather than the rule's original tracers.

  Returns:
    The list of rewrapped output tracers.
  """
  out_jaxpr_tracers = pe.custom_partial_eval_rules[key](*tracers, **params)
  out_tracers = [UnzipTracer(
      out_tracer._trace, out_tracer.pval, out_tracer.recipe,  # pylint: disable=protected-access
      False, None) for out_tracer in out_jaxpr_tracers]
  # Build exactly one replacement recipe per distinct original recipe and
  # share it among that equation's outputs. This matches how every other rule
  # attaches a single recipe to all of an equation's outputs. Building a
  # fresh recipe per output tracer (the previous behaviour) made each output
  # look like a separate equation, presumably staging the primitive once per
  # output when the jaxpr is assembled.
  # Keyed by id(); the original recipes stay alive through
  # out_jaxpr_tracers for the whole loop, so ids cannot be reused.
  new_recipes = {}
  for out_tracer in out_tracers:
    recipe = out_tracer.recipe
    new_recipe = new_recipes.get(id(recipe))
    if new_recipe is None:
      new_recipe = pe.new_eqn_recipe(recipe.invars, out_tracers,
                                     recipe.primitive, recipe.params,
                                     recipe.source_info)  # pytype: disable=wrong-arg-types
      new_recipes[id(recipe)] = new_recipe
    out_tracer.recipe = new_recipe
  return out_tracers
def _axis_index_bind(*, axis_name):
  """Stage an ``axis_index_p`` equation into the pmap trace for ``axis_name``.

  The trace is the ``pmap_trace`` of the frame stored under ``axis_name`` in
  ``pxla._thread_local_state.dynamic_axis_env``. The staged equation has no
  inputs.

  Returns:
    A ``pe.JaxprTracer`` with an unknown scalar int32 value whose recipe is
    the new equation.
  """
  pmap_trace = pxla._thread_local_state.dynamic_axis_env[axis_name].pmap_trace
  index_tracer = pe.JaxprTracer(
      pmap_trace, pe.PartialVal.unknown(ShapedArray((), np.int32)), None)
  index_tracer.recipe = pe.new_eqn_recipe(
      [], [index_tracer], axis_index_p, {'axis_name': axis_name},
      source_info_util.current())
  return index_tracer
def remat_partial_eval(trace, *tracers, jaxpr, **params):
  """Partially evaluate a remat application whose body is ``jaxpr``.

  ``jaxpr`` is split by ``pe.partial_eval_jaxpr_custom`` into a known jaxpr
  and a staged jaxpr, plus ``num_res`` residuals. The staged jaxpr is DCE'd
  down to its instantiated unknown outputs, and residuals it no longer reads
  are also dropped from the known jaxpr's outputs. The known jaxpr is then
  evaluated immediately. Its outputs are the known primal outputs followed by
  the surviving residuals. The staged jaxpr becomes a new ``remat_p``
  equation (``differentiated=True``) whose inputs are the residuals followed
  by the used input tracers, instantiated.

  Args:
    trace: the partial-eval trace used to build staged inputs and outputs.
    *tracers: one input tracer per ``jaxpr`` invar (presumably
      ``pe.JaxprTracer``s -- confirm).
    jaxpr: the remat body. It must have no constvars (asserted).
    **params: the remaining remat params. If ``params['policy']`` is falsy,
      ``nothing_saveable`` is used as the policy.

  Returns:
    One entry per output, merged by ``merge_lists`` according to
    ``out_unknowns``: known values in the known slots, staged
    ``pe.JaxprTracer``s in the unknown slots.
  """
  assert not jaxpr.constvars
  policy = params['policy'] or nothing_saveable
  in_unknowns = [not t.is_known() for t in tracers]
  # The positional [True]*n, False, False arguments are presumably
  # in_inst / ensure_out_unknowns / ensure_out_inst -- confirm against pe.
  jaxpr_known, jaxpr_staged, out_unknowns, out_inst, num_res = \
      pe.partial_eval_jaxpr_custom(
          jaxpr, in_unknowns, [True] * len(in_unknowns), False, False, policy)

  # DCE jaxpr_staged, keeping only instantiated outputs which are unknown
  _, out_inst_unknown = partition_list(out_inst, out_unknowns)
  jaxpr_unknown, in_used_staged = pe.dce_jaxpr(jaxpr_staged, out_inst_unknown)
  # The staged jaxpr's inputs are [residuals..., original inputs...].
  used_res, in_used_staged = split_list(in_used_staged, [num_res])

  # DCE jaxpr_known, keeping all known outputs but discarding dce'd res
  out_used_known = [True ] * (len(out_unknowns) - sum(out_unknowns)) + used_res
  jaxpr_known, in_used_known = pe.dce_jaxpr(jaxpr_known, out_used_known)
  # From here on, num_res counts only the residuals that survived DCE.
  num_res = sum(used_res)

  # compute known outputs and residuals (hoisted out of remat primitive)
  _, in_consts_ = unzip2(t.pval for t in tracers if t.pval.is_known())
  _, in_consts = partition_list(in_used_known, in_consts_)
  out_consts = core.eval_jaxpr(jaxpr_known, (), *in_consts)
  out_knowns, residuals = split_list(out_consts, [len(out_consts) - num_res])

  # set up unknown outputs with a recipe to call remat
  # NOTE: `res_tracers + map(...)` relies on `map` returning a list
  # (presumably JAX's safe_map rather than the builtin).
  res_tracers = map(trace.new_instantiated_const, residuals)
  _, tracers_staged = partition_list(in_used_staged, tracers)
  in_jaxpr_tracers = res_tracers + map(trace.instantiate_const, tracers_staged)
  out_jaxpr_tracers = [
      pe.JaxprTracer(trace, pe.PartialVal.unknown(x.aval), None)
      for x in jaxpr_unknown.outvars
  ]
  new_params = dict(params, jaxpr=jaxpr_unknown, differentiated=True)
  recipe = pe.new_eqn_recipe(in_jaxpr_tracers, out_jaxpr_tracers, remat_p,
                             new_params, jaxpr_unknown.effects,
                             source_info_util.current())
  for t in out_jaxpr_tracers: t.recipe = recipe

  # zip together known and unknown outputs
  return merge_lists(out_unknowns, out_knowns, out_jaxpr_tracers)
def default_process_primitive(self, primitive, tracers, params):
  """Partially evaluate primitives and saves variable recipes.

  If every input's pval is known, ``primitive`` is bound directly on the
  known values. Otherwise:

  1. All inputs are instantiated.
  2. Output avals come from ``primitive.abstract_eval``.
  3. One ``UnzipTracer`` is created per output, all sharing one equation
     recipe.

  If every input is a key tracer and the primitive is ``harvest.sow_p`` with
  a ``tag`` equal to the current dynamic-context ``settings.tag``, a
  ``VariableRecipe`` built by ``unzip_registry[primitive]`` is also attached
  to every output.

  Returns:
    A list of tracers if ``primitive.multiple_results``, else one tracer. In
    the all-known case, whatever ``primitive.bind`` returns.
  """
  pvs, consts = jax_util.unzip2(t.pval for t in tracers)
  # A pval whose first component is None is fully known.
  if all(pv is None for pv in pvs):
    return primitive.bind(*consts, **params)
  settings = trace_util.get_dynamic_context(self).settings
  tracers = safe_map(self.instantiate_const, tracers)
  # NOTE(review): `assert False` is stripped under `python -O`, so this check
  # silently disappears in optimized runs.
  if any(not isinstance(t, UnzipTracer) for t in tracers):
    assert False
  # Outputs are key tracers only when every (instantiated) input is one.
  key = all(t.is_key() for t in tracers)
  avals = [t.aval for t in tracers]
  ans = primitive.abstract_eval(*avals, **params)
  if not primitive.multiple_results:
    ans = [ans]
  out_tracers = [
      UnzipTracer(self, pe.PartialVal((aval, jax_core.unit)), None, key)
      for aval in ans
  ]
  # Passing in UnzipTracer, which pytype does not recognize as JaxprTracer
  eqn = pe.new_eqn_recipe(tracers, out_tracers, primitive, params,
                          source_info_util.current())  # pytype: disable=wrong-arg-types
  for t in out_tracers:
    t.recipe = eqn

  is_variable = (
      key and primitive is harvest.sow_p and params['tag'] == settings.tag)
  # This block is where UnzipTrace mainly differs from pe.JaxprTrace. Where
  # JaxprTrace will just return out_tracers, UnzipTrace will record an
  # additional VariableRecipe into the tracers, which will be used after
  # the trace is complete to construct init/apply Jaxprs.
  if is_variable:
    name, var_in_tracers, var_out_tracers = unzip_registry[primitive](
        tracers, out_tracers, **params)
    variable_recipe = VariableRecipe(name, var_in_tracers, var_out_tracers)
    for t in out_tracers:
      t.variable_recipe = variable_recipe

  if primitive.multiple_results:
    return out_tracers
  return out_tracers[0]
def _dynamic_xla_call_pe(trace, *tracers, jaxpr, num_consts):
  """Partial-eval rule for ``dynamic_xla_call_p``.

  The first ``len(jaxpr.in_dim_binders)`` tracers are dimension inputs. They
  must all be known; otherwise ``NotImplementedError`` is raised. The known
  part of ``jaxpr`` (``jaxpr1``) is bound immediately on the dimension values
  and known operands, producing known outputs followed by ``num_res``
  residuals. The unknown part (``jaxpr2``) is staged as a new
  ``dynamic_xla_call_p`` equation whose inputs are the dimension values, the
  residuals, and the unknown operand tracers, in that order.

  Returns:
    A list with one entry per flag in ``out_unknowns``: a known value where
    the flag is False, a staged ``pe.JaxprTracer`` where it is True.

  Raises:
    NotImplementedError: if any dimension input is unknown.
  """
  in_dim_tracers, tracers = split_list(tracers, [len(jaxpr.in_dim_binders)])
  if any(not t.pval.is_known() for t in in_dim_tracers):
    raise NotImplementedError
  in_unknowns = [not t.pval.is_known() for t in tracers]
  jaxpr1, jaxpr2, out_unknowns, num_res = partial_eval_jaxpr(jaxpr, in_unknowns)

  known_tracers, unknown_tracers = partition_list(in_unknowns, tracers)
  known_vals = [t.pval.get_known() for t in known_tracers]
  in_dim_vals = [t.pval.get_known() for t in in_dim_tracers]
  outs1_res = dynamic_xla_call_p.bind(*in_dim_vals, *known_vals, jaxpr=jaxpr1,
                                      num_consts=num_consts)
  outs1, res = split_list(outs1_res, [len(jaxpr1.outs) - num_res])

  # Stage the dimension inputs from their known *values*. The previous code
  # passed the tracers themselves, which made this trace's own tracers the
  # constants of the staged equation. The known values are already
  # extracted in in_dim_vals, and new_instantiated_const is likewise given
  # concrete values (res) on the next line.
  in_dim_tracers = map(trace.new_instantiated_const, in_dim_vals)
  res_tracers = map(trace.new_instantiated_const, res)
  outs2 = [pe.JaxprTracer(trace, pe.PartialVal.unknown(v.aval), None)
           for v in jaxpr2.outs]
  eqn = pe.new_eqn_recipe(in_dim_tracers + res_tracers + unknown_tracers,
                          outs2, dynamic_xla_call_p,
                          dict(jaxpr=jaxpr2, num_consts=0), None)
  for t in outs2:
    t.recipe = eqn
  # Interleave known and staged outputs in their original order.
  outs1, outs2 = iter(outs1), iter(outs2)
  return [next(outs2) if uk else next(outs1) for uk in out_unknowns]
def _cond_partial_eval(trace, *tracers, branches, linear):
  """Partial-eval rule for ``cond_p``.

  ``tracers[0]`` is the branch index and the rest are operands. If the index
  is unknown, the whole ``cond`` is staged via
  ``trace.default_process_primitive``. Otherwise:

  1. Each branch is split into known and unknown parts with
     ``pe.partial_eval_jaxpr_nounits``. An output counts as unknown if it is
     unknown in any branch.
  2. The branches' residuals are merged (``_merge_branch_residuals``) so all
     known branches share one output signature.
  3. The known branches are bound right away as a ``cond_p``, giving the
     known outputs followed by the residuals.
  4. The unknown branches are staged as a new ``cond_p`` equation on the
     index, the residuals, and the unknown operands.

  Returns:
    Known values and staged tracers merged by ``util.merge_lists`` according
    to ``out_uks``.
  """
  # A pval whose first component is not None is unknown.
  in_unknowns = [t.pval[0] is not None for t in tracers]
  index_uk, *ops_uk = in_unknowns

  if index_uk:
    # When the branch index is unknown, we stage out the whole cond.
    # TODO(mattjj): remove this path when old remat is removed
    params = dict(branches=branches, linear=linear)
    return trace.default_process_primitive(cond_p, tracers, params)

  # First pass: find each branch's unknown outputs without instantiating, so
  # that every branch can then be made to agree on the combined mask.
  branches_out_uks = []
  for branch_jaxpr in branches:
    _, _, out_uks, _ = pe.partial_eval_jaxpr_nounits(branch_jaxpr, ops_uk,
                                                     instantiate=False)
    branches_out_uks.append(out_uks)
  out_uks = [any(uks) for uks in zip(*branches_out_uks)]

  # Second pass: split every branch using the combined out_uks mask.
  branches_known, branches_unknown, branch_res_avals = [], [], []
  for branch_jaxpr in branches:
    branch_jaxpr_known, branch_jaxpr_unknown, _, res_avals = \
        pe.partial_eval_jaxpr_nounits(branch_jaxpr, ops_uk, instantiate=out_uks)
    branches_known.append(branch_jaxpr_known)
    branches_unknown.append(branch_jaxpr_unknown)
    branch_res_avals.append(res_avals)

  all_res_avals, res_avals_per_branch = _merge_branch_residuals(
      branch_res_avals)
  num_res = len(all_res_avals)

  num_known_outs = len(out_uks) - sum(out_uks)
  branches_known = _join_cond_outputs(branches_known, all_res_avals,
                                      res_avals_per_branch, num_known_outs)
  branches_unknown = _join_cond_pe_staged_jaxpr_inputs(
      branches_unknown, all_res_avals, res_avals_per_branch)
  # After joining, every known branch must produce the same output types.
  assert all(
      all(_map(core.typematch, j.out_avals, branches_known[0].out_avals))
      for j in branches_known[1:])

  # The index is known here, so in_consts starts with the index value.
  in_consts = [t.pval.get_known() for t in tracers if t.pval.is_known()]
  linear_known = [l for l, uk in zip(linear, ops_uk) if not uk]
  out_consts_res = cond_p.bind(*in_consts, branches=branches_known,
                               linear=tuple(linear_known))
  out_consts, res = split_list(out_consts_res, [len(out_consts_res) - num_res])

  index_tracer = trace.instantiate_const(tracers[0])
  ops_tracers = [
      trace.instantiate_const(t)
      for uk, t in zip(in_unknowns[1:], tracers[1:])
      if uk
  ]
  res_tracers = _map(trace.new_instantiated_const, res)
  out_tracers = [
      pe.JaxprTracer(trace, pe.PartialVal.unknown(aval), None)
      for aval in branches_unknown[0].out_avals
  ]
  # Residual inputs are never linear; unknown operands keep their flags.
  linear_unknown = ([False] * num_res +
                    [l for l, uk in zip(linear, in_unknowns[1:]) if uk])
  params = dict(branches=branches_unknown, linear=tuple(linear_unknown))
  # Drop the part of the current name stack that the trace already carries.
  name_stack = source_info_util.current_name_stack()[len(trace.name_stack):]
  source = source_info_util.current().replace(name_stack=name_stack)
  # NOTE(review): the staged equation is recorded with core.no_effects no
  # matter what branches_unknown contain -- confirm branches cannot carry
  # effects on this path.
  eqn = pe.new_eqn_recipe([index_tracer] + res_tracers + ops_tracers,
                          out_tracers, cond_p, params, core.no_effects, source)
  for t in out_tracers: t.recipe = eqn
  return util.merge_lists(out_uks, out_consts, out_tracers)