def _rewrite_jaxpr(jaxpr: core.Jaxpr, has_input_token: bool,
                   has_output_token: bool) -> Tuple[core.Jaxpr, bool]:
  """Rewrite a Jaxpr to thread the token, if needed.

  Each equation whose primitive uses outfeed is handed to `_rewrite_eqn`
  together with the current token variable and a fresh output token variable.
  The output token then becomes the current token for the next such equation.

  Args:
    jaxpr: the Jaxpr to rewrite.
    has_input_token: if True, the initial token is appended as the last invar
      of the result; otherwise it is created inside the Jaxpr with
      `create_token_p`.
    has_output_token: if True, the final token is appended as the last outvar
      of the result. Requires `has_input_token`.

  Returns:
    A pair `(new_jaxpr, rewritten)`. If no input token is requested and the
    Jaxpr does not use outfeed, this is `(jaxpr, False)` with `jaxpr`
    unchanged; otherwise it is the rewritten Jaxpr and True.
  """
  assert has_input_token or not has_output_token
  if not has_input_token and not xla.jaxpr_uses_outfeed(jaxpr):
    return (jaxpr, False)
  # New variable ids start just above the largest id already defined.
  # `default=-1` keeps this working for a Jaxpr that defines no variables;
  # a bare `max` raises ValueError on an empty sequence.
  max_var_count = max(_jaxpr_var_defs(jaxpr), default=-1)
  mk_new_id = itertools.count(start=max_var_count + 1)

  def mk_new_var(aval: core.AbstractValue) -> core.Var:
    return core.Var(next(mk_new_id), '', aval)

  eqns: List[core.JaxprEqn] = []
  last_token_var = mk_new_var(core.abstract_token)  # store the incoming token
  if has_input_token:
    invars = jaxpr.invars + [last_token_var]
  else:
    invars = jaxpr.invars
    # Create the initial token in-line, passing the first input (if any) as
    # an operand. Slicing instead of indexing avoids an IndexError when the
    # Jaxpr has no inputs.
    eqns.append(
        core.new_jaxpr_eqn(jaxpr.invars[:1], [last_token_var],
                           lax.create_token_p, {}))
  for eqn in jaxpr.eqns:
    if not xla.primitive_uses_outfeed(eqn.primitive, eqn.params):
      eqns.append(eqn)
    else:
      output_token_var = mk_new_var(core.abstract_token)
      _rewrite_eqn(eqn, eqns, last_token_var, output_token_var)
      last_token_var = output_token_var

  outvars = jaxpr.outvars + ([last_token_var] if has_output_token else [])
  return (core.Jaxpr(jaxpr.constvars, invars, outvars, eqns), True)
def _rewrite_jaxpr(jaxpr: core.Jaxpr, has_input_token: bool,
                   has_output_token: bool) -> core.Jaxpr:
  """Rewrite a Jaxpr to thread the token, if needed.

  Each equation whose primitive uses outfeed is handed to `_rewrite_eqn`
  together with the current token variable and a fresh output token variable.
  The output token then becomes the current token for the next such equation.

  Args:
    jaxpr: the Jaxpr to rewrite.
    has_input_token: if True, the initial token is appended as the last invar
      of the result; otherwise it is created inside the Jaxpr with
      `create_token_p`.
    has_output_token: if True, the final token is appended as the last outvar
      of the result. Requires `has_input_token`.

  Returns:
    `jaxpr` itself when no input token is requested and it does not use
    outfeed; otherwise a new, rewritten Jaxpr.
  """
  assert has_input_token or not has_output_token
  if not has_input_token and not xla.jaxpr_uses_outfeed(jaxpr):
    return jaxpr

  mk_new_var = core.gensym([jaxpr])

  eqns: List[core.JaxprEqn] = []
  last_token_var = mk_new_var(core.abstract_token)  # store the incoming token
  if has_input_token:
    invars = jaxpr.invars + [last_token_var]
  else:
    invars = jaxpr.invars
    # Create the initial token in-line, passing the first input (if any) as
    # an operand. Slicing instead of indexing avoids an IndexError when the
    # Jaxpr has no inputs.
    eqns.append(
        core.new_jaxpr_eqn(jaxpr.invars[:1], [last_token_var],
                           lax.create_token_p, {},
                           source_info_util.current()))

  for eqn in jaxpr.eqns:
    if not xla.primitive_uses_outfeed(eqn.primitive, eqn.params):
      eqns.append(eqn)
    else:
      output_token_var = mk_new_var(core.abstract_token)
      _rewrite_eqn(eqn, eqns, last_token_var, output_token_var, mk_new_var)
      last_token_var = output_token_var

  outvars = jaxpr.outvars + ([last_token_var] if has_output_token else [])
  new_jaxpr = core.Jaxpr(jaxpr.constvars, invars, outvars, eqns)
  return new_jaxpr
def rearrange_binders(jaxpr: core.ClosedJaxpr, primals_in, tangents_in,
                      primals_out, tangents_out):
  """Return a copy of `jaxpr` whose input and output binders are permuted.

  The invars are reordered by `_perm(primals_in, tangents_in, ...)` and the
  outvars by `_perm(primals_out, tangents_out, ...)`. Constvars, equations,
  effects and closed-over consts are carried over unchanged.
  """
  inner = jaxpr.jaxpr
  permuted = core.Jaxpr(
      inner.constvars,
      _perm(primals_in, tangents_in, inner.invars),
      _perm(primals_out, tangents_out, inner.outvars),
      inner.eqns,
      inner.effects,
  )
  return core.ClosedJaxpr(permuted, jaxpr.consts)
def ignore_errors_jaxpr(jaxpr, error):
  """Constructs a jaxpr which takes two extra args but ignores them.

  Two new leading invars are added, typed after `error.err` and `error.code`
  (shaped via `core.raise_to_shaped`). No equation reads them. Constvars, the
  original invars, outvars, equations and consts of `jaxpr` are reused as is.
  """
  inner = jaxpr.jaxpr
  fresh_var = core.gensym([inner])
  err_var = fresh_var(core.raise_to_shaped(core.get_aval(error.err)))
  code_var = fresh_var(core.raise_to_shaped(core.get_aval(error.code)))
  widened = core.Jaxpr(inner.constvars, (err_var, code_var, *inner.invars),
                       inner.outvars, inner.eqns)
  return core.ClosedJaxpr(widened, jaxpr.consts)
def augment_jaxpr(jaxpr, res_indices):
  """Widen the leading residual inputs of `jaxpr` to the full residual list.

  The first `len(res_indices)` invars of `jaxpr` are treated as its residuals.
  They are substituted into `all_res_vars` at the positions given by
  `res_indices` (via `util.subvals`). The resulting list replaces them at the
  front of the invars, followed by the remaining invars unchanged.

  NOTE(review): `all_res_vars` is not defined in this function; presumably it
  comes from an enclosing scope (this looks like a nested helper). Confirm at
  the definition site.
  """
  inner = jaxpr.jaxpr
  split = len(res_indices)
  own_res, rest = inner.invars[:split], inner.invars[split:]
  widened_res = list(util.subvals(all_res_vars, zip(res_indices, own_res)))
  augmented = core.Jaxpr(inner.constvars, widened_res + rest, inner.outvars,
                         inner.eqns, inner.effects)
  return core.ClosedJaxpr(augmented, jaxpr.consts)
def _prune_unused_inputs(
    jaxpr: core.Jaxpr) -> Tuple[core.Jaxpr, Set[int], Set[int]]:
  """Remove constvars and invars of `jaxpr` that are never read.

  A binder counts as read when it appears among the outvars or among the
  inputs of any equation. Atoms that are not `core.Var` (e.g. literals) are
  skipped. Equations themselves are not pruned.

  Returns:
    `(new_jaxpr, kept_const_idx, kept_var_idx)`: the pruned Jaxpr plus the
    sets of original positions of the constvars and invars that were kept.
  """
  # TODO(zhangqiaorjc): Improve the DCE algorithm by also pruning primitive
  # applications that do not produce used outputs. Must handle side-effecting
  # primitives and nested jaxpr.
  referenced = list(jaxpr.outvars)
  for eqn in jaxpr.eqns:
    referenced.extend(eqn.invars)
  used = {v for v in referenced if isinstance(v, core.Var)}

  def _keep_used(binders):
    # (position, var) pairs for every binder that is referenced, unzipped.
    return util.unzip2(
        (idx, var) for idx, var in enumerate(binders) if var in used)

  kept_const_idx, new_constvars = _keep_used(jaxpr.constvars)
  kept_var_idx, new_invars = _keep_used(jaxpr.invars)
  pruned = core.Jaxpr(new_constvars, new_invars, jaxpr.outvars, jaxpr.eqns)
  return pruned, set(kept_const_idx), set(kept_var_idx)
def tie_the_knot(typed_jaxpr):
  """Feed the outputs of `typed_jaxpr` back in place of its inputs.

  The input and output avals must agree pairwise (asserted). Every equation
  operand that is one of the original invars is replaced by the outvar at the
  same position. The result is a TypedJaxpr with no invars and no in_avals,
  keeping the original literals and out_avals.
  """
  jaxpr, _, in_avals, out_avals = typed_jaxpr
  assert all(a == b for a, b in zip(in_avals, out_avals))
  substitution = dict(zip(jaxpr.invars, jaxpr.outvars))

  def substitute(atom):
    # Only Vars are looked up; other atoms (e.g. literals) pass through.
    if isinstance(atom, jc.Var) and atom in substitution:
      return substitution[atom]
    return atom

  rewired = []
  for eqn in jaxpr.eqns:
    rewired.append(
        jc.JaxprEqn([substitute(a) for a in eqn.invars], eqn.outvars,
                    eqn.primitive, eqn.params, eqn.source_info))
  knotted = jc.Jaxpr(jaxpr.constvars, [], jaxpr.outvars, rewired)
  return jc.TypedJaxpr(knotted, typed_jaxpr.literals, [],
                       typed_jaxpr.out_avals)
def _rewrite_while_outfeed_cond(eqn: core.JaxprEqn, eqns: List[core.JaxprEqn],
                                input_token_var: core.Var,
                                output_token_var: core.Var,
                                mk_new_var: Callable):
  """Rewrite a while whose cond has outfeed.

  The while equation `eqn` is replaced by equations appended to `eqns`:

  1. The token-rewritten cond is called once before the loop ("cond_before")
     on the cond consts, the initial carry and `input_token_var`. This yields
     the initial predicate and token.
  2. A new `while_p` is emitted whose carry is (pred, carry..., token). Its
     cond just returns the carried pred. Its body calls the rewritten body and
     then the rewritten cond ("cond_body"), threading the token through both.
     All original cond and body consts become body consts (cond_nconsts=0).

  The new while's outputs are a fresh predicate var (not used afterwards
  here), the original `eqn.outvars`, and `output_token_var`.

  Args:
    eqn: the `while` equation to rewrite; its params provide cond_jaxpr,
      cond_nconsts, body_jaxpr and body_nconsts.
    eqns: list of equations to which the replacement equations are appended.
    input_token_var: token variable consumed by the rewritten loop.
    output_token_var: token variable produced by the rewritten loop.
    mk_new_var: callable mapping an aval to a new variable.
  """
  cond_jaxpr, cond_nconsts, body_jaxpr, body_nconsts = util.split_dict(
      eqn.params, ["cond_jaxpr", "cond_nconsts", "body_jaxpr", "body_nconsts"])
  transformed_cond_jaxpr = _rewrite_typed_jaxpr(cond_jaxpr, True, True)
  # Operand layout of a while eqn: cond consts, body consts, then the carry.
  carry_invars = eqn.invars[cond_nconsts + body_nconsts:]
  # pred1, token1 = rewrite(COND)(cond_consts, carry_invars, input_token)
  pred1_and_token1 = [
      mk_new_var(ov.aval) for ov in transformed_cond_jaxpr.jaxpr.outvars
  ]
  eqns.append(
      core.new_jaxpr_eqn(
          eqn.invars[0:cond_nconsts] + carry_invars + [input_token_var],
          pred1_and_token1, xla.xla_call_p,
          dict(
              call_jaxpr=transformed_cond_jaxpr.jaxpr,
              name="cond_before",
              donated_invars=(False,) * len(transformed_cond_jaxpr.in_avals)),
          eqn.source_info))
  # Make a new cond "lambda pred, carry, token: pred"
  new_cond_pred_invar = mk_new_var(cond_jaxpr.out_avals[0])
  new_cond_invars = ([new_cond_pred_invar] +
                     [mk_new_var(cv.aval) for cv in carry_invars] +
                     [mk_new_var(core.abstract_token)])
  new_cond_jaxpr = _mk_typed_jaxpr(
      core.Jaxpr([], new_cond_invars, [new_cond_pred_invar], []), [])
  # Make a new body:
  #   "lambda cond_constvars, body_constvars, pred, carry, token:
  #        carry2, token2 = rewrite(BODY)(body_constvars, carry, token)
  #        pred2, token3 = rewrite(COND)(cond_constvars, carry2, token2)
  #        (pred2, carry2, token3)
  transformed_body_jaxpr = _rewrite_typed_jaxpr(body_jaxpr, True, True)
  # Binders of the new body, in the order listed in the lambda above.
  new_body_invars_cond_constvars = [
      mk_new_var(v.aval) for v in eqn.invars[0:cond_nconsts]
  ]
  new_body_invars_body_constvars = [
      mk_new_var(v.aval)
      for v in eqn.invars[cond_nconsts:cond_nconsts + body_nconsts]
  ]
  new_body_invars_pred = mk_new_var(cond_jaxpr.out_avals[0])
  new_body_invars_carry = [mk_new_var(cv.aval) for cv in carry_invars]
  new_body_invars_token = mk_new_var(core.abstract_token)

  # Intermediate and output variables of the new body.
  new_body_carry2 = [mk_new_var(cv.aval) for cv in carry_invars]
  new_body_token2 = mk_new_var(core.abstract_token)
  new_body_pred2 = mk_new_var(cond_jaxpr.out_avals[0])
  new_body_token3 = mk_new_var(core.abstract_token)

  new_body_eqns = [
      # carry2, token2 = rewrite(BODY)(body_constvars, carry, token)
      core.new_jaxpr_eqn(
          new_body_invars_body_constvars + new_body_invars_carry +
          [new_body_invars_token], new_body_carry2 + [new_body_token2],
          xla.xla_call_p,
          dict(
              call_jaxpr=transformed_body_jaxpr.jaxpr,
              name="body",
              donated_invars=(False,) * len(transformed_body_jaxpr.in_avals)),
          eqn.source_info),
      # pred2, token3 = rewrite(COND)(cond_constvars, carry2, token2)
      core.new_jaxpr_eqn(
          new_body_invars_cond_constvars + new_body_carry2 + [new_body_token2],
          [new_body_pred2, new_body_token3], xla.xla_call_p,
          dict(
              call_jaxpr=transformed_cond_jaxpr.jaxpr,
              name="cond_body",
              donated_invars=(False,) * len(transformed_cond_jaxpr.in_avals)),
          eqn.source_info)
  ]
  new_body_jaxpr = _mk_typed_jaxpr(
      core.Jaxpr([], (new_body_invars_cond_constvars +
                      new_body_invars_body_constvars + [new_body_invars_pred] +
                      new_body_invars_carry + [new_body_invars_token]),
                 ([new_body_pred2] + new_body_carry2 + [new_body_token3]),
                 new_body_eqns), [])

  # Final predicate output of the new while; it is not referenced again here.
  pred_out = mk_new_var(cond_jaxpr.out_avals[0])
  eqns.append(
      core.new_jaxpr_eqn(
          # Operands: all original consts (as body consts), then the initial
          # carry (pred1, carry..., token1).
          (eqn.invars[0:cond_nconsts + body_nconsts] + [pred1_and_token1[0]] +
           carry_invars + [pred1_and_token1[1]]),
          ([pred_out] + eqn.outvars + [output_token_var]), lax.while_p,
          dict(
              cond_jaxpr=new_cond_jaxpr,
              cond_nconsts=0,
              body_jaxpr=new_body_jaxpr,
              body_nconsts=cond_nconsts + body_nconsts),
          eqn.source_info))
def _append_invars(jaxpr, avals):
  """Return a copy of `jaxpr` with one new trailing invar per aval.

  The new variables come from `core.gensym([jaxpr])`. No equation refers to
  them, so the extra inputs are accepted but ignored.

  Args:
    jaxpr: the core.Jaxpr to extend.
    avals: abstract values for the new trailing inputs, in order.

  Returns:
    A new core.Jaxpr sharing constvars, outvars and equations with `jaxpr`.
  """
  newvar = core.gensym([jaxpr])
  # Use a list comprehension rather than `map`: the builtin `map` returns an
  # iterator, and `list + map(...)` raises TypeError unless `map` has been
  # rebound to a list-returning variant elsewhere in the module.
  new_invars = [newvar(aval) for aval in avals]
  return core.Jaxpr(jaxpr.constvars, jaxpr.invars + new_invars, jaxpr.outvars,
                    jaxpr.eqns)
def _tracers_to_jaxpr(in_tracers, out_tracers):
  """Constructs a Jaxpr given tracers for inputs and outputs.

  Copied from jax.interpreters.partial_eval.tracers_to_jaxpr but modified to
  raise a VariableError when unknown in_tracers are found, rather than the
  default AssertionError.

  Args:
    in_tracers: the tracers that were created for the function inputs.
    out_tracers: the tracers that were output by the function.

  Returns:
    A triple of a `Jaxpr`, a sequence of constant values corresponding to the
    `constvars` in the returned Jaxpr, and a sequence of environment values.
    The vars for the environment values have been prepended to the Jaxpr's
    `invars`.

  Raises:
    VariableError: if an unknown input tracer is found.
    TypeError: if a tracer carries a recipe of an unrecognized kind.
  """
  newvar = jax_core.gensym(None)
  # Tracers are keyed by object identity, so each tracer maps to one Var.
  t_to_var = {}

  def getvar(t):
    var = t_to_var.get(id(t))
    if var is None:
      var = newvar(t.pval.get_aval())
      t_to_var[id(t)] = var
    return var

  sorted_tracers = jax_util.toposort(out_tracers)
  invars = safe_map(getvar, in_tracers)
  eqns = []
  env = {}  # Var -> free-variable value (becomes leading invars)
  consts = {}  # Var -> constant value (becomes constvars)
  const_to_var = {}

  def getconstvar(c):
    # Constants are deduplicated by object identity.
    var = const_to_var.get(id(c))
    if var is None:
      var = newvar(jax_core.get_aval(c))
      const_to_var[id(c)] = var
    return var

  processed_eqn_ids = set()
  for t in sorted_tracers:
    recipe = t.recipe
    if isinstance(recipe, pe.JaxprEqnRecipe):
      # Several output tracers may share one equation recipe; emit it once.
      if recipe.eqn_id not in processed_eqn_ids:
        eqns.append(pe.recipe_to_eqn(getvar, recipe))
        processed_eqn_ids.add(recipe.eqn_id)
    elif isinstance(recipe, pe.LambdaBinding):
      # A lambda-bound tracer must be one of the declared inputs.
      if not any(t is in_tracer for in_tracer in in_tracers):
        raise VariableError('Found unknown input tracer.')
      # NOTE(review): with empty in_tracers the VariableError above is raised
      # first, so this assert cannot fail.
      assert in_tracers, 'Lambda binding with no args'
    elif isinstance(recipe, pe.FreeVar):
      env[getvar(t)] = recipe.val
    elif isinstance(recipe, pe.ConstVar):
      v = t_to_var[id(t)] = getconstvar(recipe.val)
      consts[v] = recipe.val
    elif isinstance(recipe, jax_core.Literal):
      t_to_var[id(t)] = recipe
    elif recipe is jax_core.unit:
      t_to_var[id(t)] = jax_core.unitvar
    else:
      raise TypeError(recipe)

  env_vars, env_vals = jax_util.unzip2(env.items())
  const_vars, const_vals = jax_util.unzip2(consts.items())
  # The env_vars are pre-pended to the invars
  jaxpr = jax_core.Jaxpr(const_vars, list(it.chain(env_vars, invars)),
                         safe_map(getvar, out_tracers), eqns)
  return jaxpr, const_vals, env_vals