def _rewrite_jaxpr(jaxpr: core.Jaxpr, has_input_token: bool,
                   has_output_token: bool) -> core.Jaxpr:
  """Thread an outfeed token through `jaxpr`, when one is needed.

  The jaxpr is returned untouched when no input token is requested and
  `xla.jaxpr_uses_outfeed(jaxpr)` is false. Otherwise a new `core.Jaxpr` is
  built in which every equation flagged by `xla.primitive_uses_outfeed` is
  rewritten by `_rewrite_eqn` so that it consumes the current token and
  produces the next one; all other equations are copied over as they are.

  Args:
    jaxpr: the jaxpr to rewrite.
    has_input_token: if true, a token variable is appended to the input
      variables; if false, the token is created inside the jaxpr with
      `lax.create_token_p`.
    has_output_token: if true, the last token variable is appended to the
      output variables. May only be set together with `has_input_token`.

  Returns:
    Either `jaxpr` itself (nothing to do) or the rewritten jaxpr.
  """
  assert has_input_token or not has_output_token

  if not (has_input_token or xla.jaxpr_uses_outfeed(jaxpr)):
    return jaxpr

  fresh_var = core.gensym([jaxpr])

  new_eqns: List[core.JaxprEqn] = []
  # Holds the most recent token; starts out as the incoming token.
  token_var = fresh_var(core.abstract_token)
  if has_input_token:
    new_invars = jaxpr.invars + [token_var]
  else:
    new_invars = jaxpr.invars
    # No token comes in, so make one; the first input variable is used as
    # the operand of the create-token equation.
    # NOTE(review): `jaxpr.invars[0]` fails with IndexError when the jaxpr
    # has no input variables -- confirm whether that case can occur.
    create_token_eqn = core.new_jaxpr_eqn(
        [jaxpr.invars[0]], [token_var], lax.create_token_p, {},
        source_info_util.current())
    new_eqns.append(create_token_eqn)

  for eqn in jaxpr.eqns:
    if xla.primitive_uses_outfeed(eqn.primitive, eqn.params):
      next_token_var = fresh_var(core.abstract_token)
      _rewrite_eqn(eqn, new_eqns, token_var, next_token_var, fresh_var)
      token_var = next_token_var
    else:
      new_eqns.append(eqn)

  trailing_outvars = [token_var] if has_output_token else []
  return core.Jaxpr(jaxpr.constvars, new_invars,
                    jaxpr.outvars + trailing_outvars, new_eqns)
def _rewrite_jaxpr(jaxpr: core.Jaxpr, has_input_token: bool,
                   has_output_token: bool) -> Tuple[core.Jaxpr, bool]:
  """Rewrite a Jaxpr to thread the token, if needed.

  Args:
    jaxpr: the jaxpr to rewrite.
    has_input_token: if true, a token variable is appended to the input
      variables; if false, the token is created inside the jaxpr with
      `lax.create_token_p`.
    has_output_token: if true, the last token variable is appended to the
      output variables. May only be set together with `has_input_token`.

  Returns:
    A pair `(jaxpr, changed)`: `(jaxpr, False)` with the original jaxpr when
    no input token is requested and the jaxpr does not use outfeed, otherwise
    `(new_jaxpr, True)`.
  """
  assert has_input_token or not has_output_token

  if not has_input_token and not xla.jaxpr_uses_outfeed(jaxpr):
    return (jaxpr, False)

  # Fresh variables are numbered after the largest number reported by
  # `_jaxpr_var_defs` (presumably the counts of the variables defined in
  # `jaxpr`). `default=-1` starts the numbering at 0 when nothing is
  # reported; a bare `max()` raises ValueError on an empty iterable.
  max_var_count = max(_jaxpr_var_defs(jaxpr), default=-1)
  mk_new_id = itertools.count(start=max_var_count + 1)

  def mk_new_var(aval: core.AbstractValue) -> core.Var:
    return core.Var(next(mk_new_id), '', aval)

  eqns: List[core.JaxprEqn] = []
  last_token_var = mk_new_var(core.abstract_token)  # store the incoming token
  if has_input_token:
    invars = jaxpr.invars + [last_token_var]
  else:
    invars = jaxpr.invars
    # No token comes in, so make one; the first input variable is used as
    # the operand of the create-token equation.
    # NOTE(review): `jaxpr.invars[0]` fails with IndexError when the jaxpr
    # has no input variables -- confirm whether that case can occur.
    eqns.append(
        core.new_jaxpr_eqn([jaxpr.invars[0]], [last_token_var],
                           lax.create_token_p, {}))

  for eqn in jaxpr.eqns:
    if not xla.primitive_uses_outfeed(eqn.primitive, eqn.params):
      eqns.append(eqn)
    else:
      # Each outfeed-using equation consumes the current token and defines
      # the next one, which chains these equations one after another.
      output_token_var = mk_new_var(core.abstract_token)
      _rewrite_eqn(eqn, eqns, last_token_var, output_token_var)
      last_token_var = output_token_var

  outvars = jaxpr.outvars + ([last_token_var] if has_output_token else [])
  return (core.Jaxpr(jaxpr.constvars, invars, outvars, eqns), True)
def _rewrite_while_outfeed_cond(eqn: core.JaxprEqn, eqns: List[core.JaxprEqn],
                                input_token_var: core.Var,
                                output_token_var: core.Var,
                                mk_new_var: Callable):
  """Rewrite a while loop whose conditional uses outfeed.

  The conditional is evaluated once before the loop and again at the end of
  every body iteration, so that the loop's own conditional only has to read
  a predicate carried through the loop state. Two equations are appended to
  `eqns`: an `xla_call` named "cond_before", then the new `while` whose carry
  is `(pred, *carry, token)`.

  Args:
    eqn: the `while` equation to rewrite.
    eqns: the list to which the new equations are appended.
    input_token_var: variable holding the token on entry.
    output_token_var: variable that receives the token produced by the loop.
    mk_new_var: callable making a fresh variable from an abstract value.
  """
  cond_jaxpr, cond_nconsts, body_jaxpr, body_nconsts = util.split_dict(
      eqn.params, ["cond_jaxpr", "cond_nconsts", "body_jaxpr", "body_nconsts"])
  cond_with_token = _rewrite_typed_jaxpr(cond_jaxpr, True, True)
  n_consts = cond_nconsts + body_nconsts
  outer_cond_consts = eqn.invars[:cond_nconsts]
  outer_carry = eqn.invars[n_consts:]

  def fresh_like(variables):
    # One new variable per given variable, with the same abstract value.
    return [mk_new_var(v.aval) for v in variables]

  def fresh_pred():
    return mk_new_var(cond_jaxpr.out_avals[0])

  def fresh_token():
    return mk_new_var(core.abstract_token)

  def call_params(typed_jaxpr, name):
    # Parameters for an xla_call that runs `typed_jaxpr` without donation.
    return {
        "call_jaxpr": typed_jaxpr.jaxpr,
        "name": name,
        "donated_invars": (False,) * len(typed_jaxpr.in_avals),
    }

  # Before the loop:
  #   pred1, token1 = rewrite(COND)(cond_consts, carry_invars, input_token)
  first_cond_outs = fresh_like(cond_with_token.jaxpr.outvars)
  eqns.append(
      core.new_jaxpr_eqn(
          outer_cond_consts + outer_carry + [input_token_var],
          first_cond_outs, xla.xla_call_p,
          call_params(cond_with_token, "cond_before"), eqn.source_info))

  # The new conditional is just "lambda pred, carry, token: pred".
  loop_cond_pred = fresh_pred()
  loop_cond_invars = [loop_cond_pred]
  loop_cond_invars.extend(fresh_like(outer_carry))
  loop_cond_invars.append(fresh_token())
  loop_cond = _mk_typed_jaxpr(
      core.Jaxpr([], loop_cond_invars, [loop_cond_pred], []), [])

  # The new body is:
  #   lambda cond_constvars, body_constvars, pred, carry, token:
  #     carry2, token2 = rewrite(BODY)(body_constvars, carry, token)
  #     pred2, token3 = rewrite(COND)(cond_constvars, carry2, token2)
  #     (pred2, carry2, token3)
  body_with_token = _rewrite_typed_jaxpr(body_jaxpr, True, True)
  in_cond_consts = fresh_like(outer_cond_consts)
  in_body_consts = fresh_like(eqn.invars[cond_nconsts:n_consts])
  in_pred = fresh_pred()
  in_carry = fresh_like(outer_carry)
  in_token = fresh_token()

  carry2 = fresh_like(outer_carry)
  token2 = fresh_token()
  pred2 = fresh_pred()
  token3 = fresh_token()

  run_body_eqn = core.new_jaxpr_eqn(
      in_body_consts + in_carry + [in_token], carry2 + [token2],
      xla.xla_call_p, call_params(body_with_token, "body"), eqn.source_info)
  run_cond_eqn = core.new_jaxpr_eqn(
      in_cond_consts + carry2 + [token2], [pred2, token3],
      xla.xla_call_p, call_params(cond_with_token, "cond_body"),
      eqn.source_info)
  loop_body_invars = (in_cond_consts + in_body_consts + [in_pred] +
                      in_carry + [in_token])
  loop_body_outvars = [pred2] + carry2 + [token3]
  loop_body = _mk_typed_jaxpr(
      core.Jaxpr([], loop_body_invars, loop_body_outvars,
                 [run_body_eqn, run_cond_eqn]), [])

  # The loop itself: all original constants now belong to the body, and the
  # carry is widened with the predicate in front and the token at the end.
  final_pred = fresh_pred()
  while_invars = (eqn.invars[:n_consts] + [first_cond_outs[0]] +
                  outer_carry + [first_cond_outs[1]])
  while_outvars = [final_pred] + eqn.outvars + [output_token_var]
  while_params = {
      "cond_jaxpr": loop_cond,
      "cond_nconsts": 0,
      "body_jaxpr": loop_body,
      "body_nconsts": n_consts,
  }
  eqns.append(
      core.new_jaxpr_eqn(while_invars, while_outvars, lax.while_p,
                         while_params, eqn.source_info))
def _rewrite_eqn(eqn: core.JaxprEqn, eqns: List[core.JaxprEqn],
                 input_token_var: core.Var, output_token_var: core.Var,
                 mk_new_var: Callable[[core.AbstractValue], core.Var]):
  """Append to `eqns` a version of `eqn` that threads the token.

  The current token is read from `input_token_var`, and the rewritten
  equation(s) must leave the resulting token in `output_token_var`.

  Args:
    eqn: the equation to rewrite.
    eqns: the list to which the rewritten equation(s) are appended.
    input_token_var: variable holding the token on entry.
    output_token_var: variable that receives the token on exit.
    mk_new_var: callable making a fresh variable from an abstract value;
      only passed on to the rewrite of a `while` whose conditional uses
      outfeed.

  Raises:
    NotImplementedError: if there is no rewrite for the primitive of `eqn`.
  """
  primitive = eqn.primitive

  def emit(new_invars, new_outvars, new_params):
    # Every rewritten equation keeps the primitive and source info of `eqn`.
    eqns.append(
        core.new_jaxpr_eqn(new_invars, new_outvars, primitive, new_params,
                           eqn.source_info))

  def unreachable_thunk():
    assert False, "Should not be reached"

  def with_last_moved_to(variables, position):
    # Moves the final element (the token) so that it sits at `position`.
    return variables[0:position] + [variables[-1]] + variables[position:-1]

  if primitive is id_tap_p:
    assert "has_token_" not in eqn.params
    emit(eqn.invars + [input_token_var], eqn.outvars + [output_token_var],
         dict(eqn.params, has_token_=True))
    return

  if primitive is lax.while_p:
    cond_jaxpr, _, body_jaxpr, _ = util.split_dict(
        eqn.params,
        ["cond_jaxpr", "cond_nconsts", "body_jaxpr", "body_nconsts"])
    if xla.jaxpr_uses_outfeed(cond_jaxpr.jaxpr):
      _rewrite_while_outfeed_cond(eqn, eqns, input_token_var,
                                  output_token_var, mk_new_var)
      return
    # The body takes and returns the token; the conditional only takes it.
    emit(eqn.invars + [input_token_var], eqn.outvars + [output_token_var],
         dict(eqn.params,
              body_jaxpr=_rewrite_typed_jaxpr(body_jaxpr, True, True),
              cond_jaxpr=_rewrite_typed_jaxpr(cond_jaxpr, True, False)))
    return

  if primitive is lax.cond_p:
    branches, linear = util.split_dict(eqn.params, ["branches", "linear"])
    index, *operands = eqn.invars
    emit([index, *operands, input_token_var],
         eqn.outvars + [output_token_var],
         dict(eqn.params,
              branches=tuple(
                  _rewrite_typed_jaxpr(branch, True, True)
                  for branch in branches),
              linear=(*linear, False)))
    return

  if primitive is lax.scan_p:
    num_consts, num_carry, carry_jaxpr, linear, _, _, _ = util.split_dict(
        eqn.params,
        ["num_consts", "num_carry", "jaxpr", "linear", "reverse", "length",
         "unroll"])
    # The token goes right at the end of the carry.
    token_position = num_consts + num_carry
    scan_invars = (eqn.invars[0:token_position] + [input_token_var] +
                   eqn.invars[token_position:])
    scan_jaxpr = _rewrite_typed_jaxpr(carry_jaxpr, True, True)
    # The rewrite has put the token last; move it to the end of the carry,
    # for the inputs and then for the outputs.
    reordered_invars = with_last_moved_to(scan_jaxpr.jaxpr.invars,
                                          token_position)
    scan_jaxpr.jaxpr.invars = reordered_invars
    scan_jaxpr.in_avals = [v.aval for v in reordered_invars]

    reordered_outvars = with_last_moved_to(scan_jaxpr.jaxpr.outvars,
                                           num_carry)
    scan_jaxpr.jaxpr.outvars = reordered_outvars
    scan_jaxpr.out_avals = [v.aval for v in reordered_outvars]
    # The output token is at the end of the carry results.
    scan_outvars = (eqn.outvars[0:num_carry] + [output_token_var] +
                    eqn.outvars[num_carry:])
    emit(scan_invars, scan_outvars,
         dict(eqn.params,
              jaxpr=scan_jaxpr,
              num_carry=num_carry + 1,
              linear=linear + (False,)))
    return

  if primitive is xla.xla_call_p:
    call_jaxpr = cast(core.Jaxpr, eqn.params["call_jaxpr"])
    emit(eqn.invars + [input_token_var], eqn.outvars + [output_token_var],
         dict(eqn.params,
              call_jaxpr=_rewrite_jaxpr(call_jaxpr, True, True),
              donated_invars=eqn.params["donated_invars"] + (False,)))
    return

  if primitive is custom_derivatives.custom_jvp_call_jaxpr_p:
    fun_jaxpr = eqn.params["fun_jaxpr"]
    emit([*eqn.invars, input_token_var], eqn.outvars + [output_token_var],
         dict(eqn.params,
              fun_jaxpr=_rewrite_typed_jaxpr(fun_jaxpr, True, True),
              jvp_jaxpr_thunk=unreachable_thunk))
    return

  if primitive is custom_derivatives.custom_vjp_call_jaxpr_p:
    fun_jaxpr = eqn.params["fun_jaxpr"]
    emit([*eqn.invars, input_token_var], eqn.outvars + [output_token_var],
         dict(eqn.params,
              fun_jaxpr=_rewrite_typed_jaxpr(fun_jaxpr, True, True),
              fwd_jaxpr_thunk=unreachable_thunk,
              # The following are illegal values for the parameters; they
              # should not be needed because this rewrite is just before
              # compilation to XLA, which does not use those parameters.
              bwd="illegal param",
              out_trees="illegal param"))
    return

  raise NotImplementedError(f"outfeed rewrite {eqn.primitive}")
def _rewrite_eqn(eqn: core.JaxprEqn, eqns: List[core.JaxprEqn],
                 input_token_var: core.Var, output_token_var: core.Var):
  """Append to `eqns` a version of `eqn` that threads the token.

  The current token is read from `input_token_var`, and the rewritten
  equation must leave the resulting token in `output_token_var`.

  Args:
    eqn: the equation to rewrite.
    eqns: the list to which the rewritten equation is appended.
    input_token_var: variable holding the token on entry.
    output_token_var: variable that receives the token on exit.

  Raises:
    NotImplementedError: if the conditional of a `while` uses outfeed, or
      if there is no rewrite for the primitive of `eqn`.
  """
  primitive = eqn.primitive

  # Each branch works out the pieces of the replacement equation; the
  # equation itself is built and appended once, at the end.
  if primitive is id_tap_p:
    new_invars = eqn.invars + [input_token_var]
    new_outvars = eqn.outvars + [output_token_var]
    new_params = eqn.params

  elif primitive is lax.while_p:
    cond_jaxpr, _, body_jaxpr, _ = util.split_dict(
        eqn.params,
        ["cond_jaxpr", "cond_nconsts", "body_jaxpr", "body_nconsts"])
    if xla.jaxpr_uses_outfeed(cond_jaxpr.jaxpr):
      # TODO(necula): implement tapping from the conditional of a while
      raise NotImplementedError(
          "outfeed not supported in the conditional of a while")
    new_invars = eqn.invars + [input_token_var]
    new_outvars = eqn.outvars + [output_token_var]
    # The body takes and returns the token; the conditional only takes it.
    body_with_token = _rewrite_typed_jaxpr(body_jaxpr, True, True)[0]
    cond_with_token = _rewrite_typed_jaxpr(cond_jaxpr, True, False)[0]
    new_params = dict(eqn.params,
                      body_jaxpr=body_with_token,
                      cond_jaxpr=cond_with_token)

  elif primitive is lax.cond_p:
    true_jaxpr, false_jaxpr, linear = util.split_dict(
        eqn.params, ["true_jaxpr", "false_jaxpr", "linear"])
    pred, true_invars, false_invars = util.split_list(
        eqn.invars, [1, len(true_jaxpr.jaxpr.invars)])
    # The same input token is handed to both branches, after the operands
    # of each branch.
    new_invars = [*pred, *true_invars, input_token_var,
                  *false_invars, input_token_var]
    new_outvars = eqn.outvars + [output_token_var]
    true_with_token = _rewrite_typed_jaxpr(true_jaxpr, True, True)[0]
    false_with_token = _rewrite_typed_jaxpr(false_jaxpr, True, True)[0]
    # NOTE(review): the two flags for the tokens are appended at the end of
    # `linear`, while the first token is inserted before the false-branch
    # operands -- confirm that `linear` stays aligned with the operands.
    new_params = dict(eqn.params,
                      true_jaxpr=true_with_token,
                      false_jaxpr=false_with_token,
                      linear=linear + (False, False))

  elif primitive is lax.scan_p:
    num_consts, num_carry, carry_jaxpr, linear, _, _ = util.split_dict(
        eqn.params,
        ["num_consts", "num_carry", "jaxpr", "linear", "reverse", "length"])
    # The token goes right at the end of the carry.
    token_position = num_consts + num_carry
    new_invars = (eqn.invars[0:token_position] + [input_token_var] +
                  eqn.invars[token_position:])
    scan_jaxpr = _rewrite_typed_jaxpr(carry_jaxpr, True, True)[0]
    # The rewrite has put the token last; move it to the end of the carry,
    # for the inputs and then for the outputs.
    inner_invars = scan_jaxpr.jaxpr.invars
    inner_invars = (inner_invars[0:token_position] + [inner_invars[-1]] +
                    inner_invars[token_position:-1])
    scan_jaxpr.jaxpr.invars = inner_invars
    scan_jaxpr.in_avals = [v.aval for v in inner_invars]

    inner_outvars = scan_jaxpr.jaxpr.outvars
    inner_outvars = (inner_outvars[0:num_carry] + [inner_outvars[-1]] +
                     inner_outvars[num_carry:-1])
    scan_jaxpr.jaxpr.outvars = inner_outvars
    scan_jaxpr.out_avals = [v.aval for v in inner_outvars]
    # The output token is at the end of the carry results.
    new_outvars = (eqn.outvars[0:num_carry] + [output_token_var] +
                   eqn.outvars[num_carry:])
    new_params = dict(eqn.params,
                      jaxpr=scan_jaxpr,
                      num_carry=num_carry + 1,
                      linear=linear + (False, ))

  elif primitive is xla.xla_call_p:
    new_invars = eqn.invars + [input_token_var]
    new_outvars = eqn.outvars + [output_token_var]
    new_params = dict(
        eqn.params,
        call_jaxpr=_rewrite_jaxpr(eqn.params["call_jaxpr"], True, True)[0])

  else:
    raise NotImplementedError(f"outfeed rewrite {eqn.primitive}")

  eqns.append(
      core.new_jaxpr_eqn(new_invars, new_outvars, primitive, new_params))
def _rewrite_eqn(eqn: core.JaxprEqn, eqns: List[core.JaxprEqn],
                 input_token_var: core.Var, output_token_var: core.Var,
                 mk_new_var: Callable[[core.AbstractValue], core.Var]):
  """Append to `eqns` a version of `eqn` that threads the token.

  The current token is read from `input_token_var`, and the rewritten
  equation(s) must leave the resulting token in `output_token_var`.

  Args:
    eqn: the equation to rewrite.
    eqns: the list to which the rewritten equation(s) are appended.
    input_token_var: variable holding the token on entry.
    output_token_var: variable that receives the token on exit.
    mk_new_var: callable making a fresh variable from an abstract value;
      only passed on to the rewrite of a `while` whose conditional uses
      outfeed.

  Raises:
    NotImplementedError: if there is no rewrite for the primitive of `eqn`.
  """

  def append_eqn(new_invars, new_outvars, new_params):
    # The replacement keeps the primitive and source info of `eqn`.
    eqns.append(
        core.new_jaxpr_eqn(new_invars, new_outvars, eqn.primitive,
                           new_params, eqn.source_info))

  def append_with_trailing_token(new_params):
    # The common shape: the token is simply the last input and last output.
    append_eqn(eqn.invars + [input_token_var],
               eqn.outvars + [output_token_var], new_params)

  if eqn.primitive is id_tap_p:
    append_with_trailing_token(eqn.params)

  elif eqn.primitive is lax.while_p:
    cond_jaxpr, _, body_jaxpr, _ = util.split_dict(
        eqn.params,
        ["cond_jaxpr", "cond_nconsts", "body_jaxpr", "body_nconsts"])
    if xla.jaxpr_uses_outfeed(cond_jaxpr.jaxpr):
      _rewrite_while_outfeed_cond(eqn, eqns, input_token_var,
                                  output_token_var, mk_new_var)
    else:
      # The body takes and returns the token; the conditional only takes it.
      append_with_trailing_token(
          dict(eqn.params,
               body_jaxpr=_rewrite_typed_jaxpr(body_jaxpr, True, True)[0],
               cond_jaxpr=_rewrite_typed_jaxpr(cond_jaxpr, True, False)[0]))

  elif eqn.primitive is lax.cond_p:
    branches, linear = util.split_dict(eqn.params, ["branches", "linear"])
    index, *operands = eqn.invars
    branches_with_token = tuple(
        _rewrite_typed_jaxpr(branch, True, True)[0] for branch in branches)
    append_eqn([index, *operands, input_token_var],
               eqn.outvars + [output_token_var],
               dict(eqn.params,
                    branches=branches_with_token,
                    linear=(*linear, False)))

  elif eqn.primitive is lax.scan_p:
    num_consts, num_carry, carry_jaxpr, linear, _, _ = util.split_dict(
        eqn.params,
        ["num_consts", "num_carry", "jaxpr", "linear", "reverse", "length"])
    # The token goes right at the end of the carry.
    carry_end = num_consts + num_carry
    scan_invars = (eqn.invars[0:carry_end] + [input_token_var] +
                   eqn.invars[carry_end:])
    scan_jaxpr = _rewrite_typed_jaxpr(carry_jaxpr, True, True)[0]
    # The rewrite has put the token last; it has to be at the end of the
    # carry, both among the inputs and among the outputs.
    inner_in = scan_jaxpr.jaxpr.invars
    inner_in = (inner_in[0:carry_end] + [inner_in[-1]] +
                inner_in[carry_end:-1])
    scan_jaxpr.jaxpr.invars = inner_in
    scan_jaxpr.in_avals = [v.aval for v in inner_in]

    inner_out = scan_jaxpr.jaxpr.outvars
    inner_out = (inner_out[0:num_carry] + [inner_out[-1]] +
                 inner_out[num_carry:-1])
    scan_jaxpr.jaxpr.outvars = inner_out
    scan_jaxpr.out_avals = [v.aval for v in inner_out]
    # The output token is at the end of the carry results.
    scan_outvars = (eqn.outvars[0:num_carry] + [output_token_var] +
                    eqn.outvars[num_carry:])
    append_eqn(scan_invars, scan_outvars,
               dict(eqn.params,
                    jaxpr=scan_jaxpr,
                    num_carry=num_carry + 1,
                    linear=linear + (False, )))

  elif eqn.primitive is xla.xla_call_p:
    call_jaxpr = cast(core.Jaxpr, eqn.params["call_jaxpr"])
    append_with_trailing_token(
        dict(eqn.params,
             call_jaxpr=_rewrite_jaxpr(call_jaxpr, True, True)[0]))

  else:
    raise NotImplementedError(f"outfeed rewrite {eqn.primitive}")