def _rewrite_jaxpr(jaxpr: core.Jaxpr, has_input_token: bool,
                   has_output_token: bool) -> core.Jaxpr:
  """Returns a copy of `jaxpr` with a token threaded through outfeed equations.

  The input is returned unchanged when no token is passed in and the jaxpr does
  not use outfeed. Otherwise a new Jaxpr is built in which:

  * if `has_input_token`, a fresh token variable is appended to the invars;
    else a token is minted with `lax.create_token_p` applied to the first invar;
  * every equation that uses outfeed is handed to `_rewrite_eqn` together with
    the current token variable and a fresh output token variable, and that
    output token becomes the current token for the next such equation;
  * if `has_output_token`, the final token variable is appended to the outvars.

  Args:
    jaxpr: the Jaxpr to rewrite.
    has_input_token: whether the rewritten Jaxpr takes a trailing token input.
    has_output_token: whether the rewritten Jaxpr returns a trailing token.
      Requires `has_input_token`.

  Returns:
    The rewritten Jaxpr (or `jaxpr` itself when no rewrite is needed).
  """
  assert has_input_token or not has_output_token
  if not (has_input_token or xla.jaxpr_uses_outfeed(jaxpr)):
    return jaxpr

  new_var = core.gensym([jaxpr])
  new_eqns: List[core.JaxprEqn] = []
  # The token variable that the next outfeed-using equation consumes.
  token = new_var(core.abstract_token)
  if has_input_token:
    new_invars = jaxpr.invars + [token]
  else:
    new_invars = jaxpr.invars
    # NOTE(review): indexes jaxpr.invars[0], so this path assumes the jaxpr has
    # at least one input -- confirm callers guarantee that.
    new_eqns.append(
        core.new_jaxpr_eqn([jaxpr.invars[0]], [token], lax.create_token_p, {},
                           source_info_util.current()))

  for eqn in jaxpr.eqns:
    if xla.primitive_uses_outfeed(eqn.primitive, eqn.params):
      next_token = new_var(core.abstract_token)
      # _rewrite_eqn appends the rewritten equation(s) into new_eqns.
      _rewrite_eqn(eqn, new_eqns, token, next_token, new_var)
      token = next_token
    else:
      new_eqns.append(eqn)

  extra_outvars = [token] if has_output_token else []
  return core.Jaxpr(jaxpr.constvars, new_invars, jaxpr.outvars + extra_outvars,
                    new_eqns)
def _bound_output_tracers(self, primitive, params, jaxpr, consts, env,
                          in_tracers, out_pvs, out_consts, out_keys, name,
                          is_map):
  """Creates output tracers for a traced call and attaches one shared recipe.

  The call's Jaxpr has its constvars converted to leading invars. A single
  equation recipe is then built whose inputs are, in order: instantiated
  tracers for `consts`, instantiated tracers for `env`, then `in_tracers`.
  Every returned `UnzipTracer` points at that recipe.

  Args:
    primitive: the call/map primitive recorded in the equation.
    params: the primitive's original params; copied, not mutated.
    jaxpr: the traced Jaxpr of the called function.
    consts: constants of `jaxpr`, turned into instantiated-const tracers.
    env: values instantiated as tracers and placed after the consts.
    in_tracers: the remaining input tracers.
    out_pvs, out_consts, out_keys: per-output partial-value parts and key flags
      used to build the output tracers.
    name: stored as the `name` param of the equation.
    is_map: if true, a `mapped_invars` param is added to the equation.

  Returns:
    The list of output `UnzipTracer`s.
  """
  call_jaxpr = pe.convert_constvars_jaxpr(jaxpr)
  const_tracers = safe_map(self.new_instantiated_const, consts)
  env_tracers = safe_map(self.instantiate_const, env)
  out_tracers = [
      UnzipTracer(self, pe.PartialVal((out_pv, out_const)), None, out_key)
      for out_pv, out_const, out_key in safe_zip(out_pvs, out_consts, out_keys)
  ]

  num_leading = len(const_tracers) + len(env_tracers)
  new_params = dict(params, name=name, call_jaxpr=call_jaxpr)
  if is_map:
    # Flags are True for const tracers, False for env tracers, True for inputs.
    new_params['mapped_invars'] = tuple(
        [True] * len(const_tracers) + [False] * len(env_tracers) +
        [True] * len(in_tracers))
  if 'donated_invars' in params:
    # Leading const/env inputs are never donated; of the original flags, only
    # those paired with an input whose partial value is unknown are kept.
    kept_flags = tuple(
        flag for flag, tracer in zip(params['donated_invars'], in_tracers)
        if not tracer.pval.is_known())
    new_params['donated_invars'] = (False,) * num_leading + kept_flags

  eqn_inputs = tuple(const_tracers + env_tracers + in_tracers)
  recipe = pe.new_eqn_recipe(eqn_inputs, out_tracers, primitive, new_params,
                             source_info_util.current())  # pytype: disable=wrong-arg-types
  for tracer in out_tracers:
    tracer.recipe = recipe
  return out_tracers
def _axis_index_bind(*, axis_name):
  """Records an `axis_index_p` equation on the pmap trace bound to `axis_name`.

  Looks up `axis_name` in pxla's thread-local dynamic axis environment. It then
  creates an unknown scalar int32 `JaxprTracer` on that frame's pmap trace and
  gives it an input-less equation recipe for `axis_index_p`.

  Returns:
    The newly created output tracer.
  """
  axis_env = pxla._thread_local_state.dynamic_axis_env
  pmap_trace = axis_env[axis_name].pmap_trace
  scalar_index_aval = ShapedArray((), np.int32)
  tracer = pe.JaxprTracer(pmap_trace, pe.PartialVal.unknown(scalar_index_aval),
                          None)
  tracer.recipe = pe.new_eqn_recipe([], [tracer], axis_index_p,
                                    {'axis_name': axis_name},
                                    source_info_util.current())
  return tracer
def default_process_primitive(self, primitive, tracers, params):
  """Partially evaluates a primitive and saves equation/variable recipes.

  If every input's partial value is known (its `pv` is None), the primitive is
  bound directly on the known constants. Otherwise the inputs are instantiated,
  abstract evaluation determines the output avals, and fresh `UnzipTracer`
  outputs share a single equation recipe. When all inputs are keys and the
  primitive is `harvest.sow_p` whose `tag` matches the active context's tag,
  a `VariableRecipe` (built via `unzip_registry`) is also attached to each
  output tracer.

  Args:
    primitive: the primitive being processed.
    tracers: its input tracers.
    params: the primitive's params.

  Returns:
    The result of `primitive.bind` when all inputs are known. Otherwise, the
    list of output tracers if `primitive.multiple_results`, else the single
    output tracer.

  Raises:
    TypeError: if an instantiated input is not an `UnzipTracer`.
  """
  pvs, consts = jax_util.unzip2(t.pval for t in tracers)
  if all(pv is None for pv in pvs):
    return primitive.bind(*consts, **params)
  settings = trace_util.get_dynamic_context(self).settings
  tracers = safe_map(self.instantiate_const, tracers)
  # Raise explicitly instead of `assert False`. Asserts are stripped under
  # `python -O`, which would let a foreign tracer fail obscurely at
  # `t.is_key()` below.
  foreign = [
      type(t).__name__ for t in tracers if not isinstance(t, UnzipTracer)
  ]
  if foreign:
    raise TypeError(
        f'UnzipTrace expected only UnzipTracer inputs for {primitive}, got: '
        f'{", ".join(foreign)}')
  key = all(t.is_key() for t in tracers)
  avals = [t.aval for t in tracers]
  ans = primitive.abstract_eval(*avals, **params)
  if not primitive.multiple_results:
    ans = [ans]
  out_tracers = [
      UnzipTracer(self, pe.PartialVal((aval, jax_core.unit)), None, key)
      for aval in ans
  ]
  # Passing in UnzipTracer, which pytype does not recognize as JaxprTracer.
  eqn = pe.new_eqn_recipe(tracers, out_tracers, primitive, params,
                          source_info_util.current())  # pytype: disable=wrong-arg-types
  for t in out_tracers:
    t.recipe = eqn

  is_variable = (key and primitive is harvest.sow_p and
                 params['tag'] == settings.tag)
  # This block is where UnzipTrace mainly differs from pe.JaxprTrace. Where
  # JaxprTrace would just return out_tracers, UnzipTrace also records a
  # VariableRecipe on the tracers. That recipe is used after the trace is
  # complete to construct the init/apply Jaxprs.
  if is_variable:
    name, var_in_tracers, var_out_tracers = unzip_registry[primitive](
        tracers, out_tracers, **params)
    variable_recipe = VariableRecipe(name, var_in_tracers, var_out_tracers)
    for t in out_tracers:
      t.variable_recipe = variable_recipe

  if primitive.multiple_results:
    return out_tracers
  return out_tracers[0]