def _scan_transpose(cts, *args, **kwargs):
  """Transpose a scan by binding `scan_p` on the output cotangents.

  The transposed scan runs in the opposite direction (`forward=not forward`)
  over a transposed jaxpr built by `_transpose_jaxpr`.

  Args:
    cts: output cotangents, carry cotangents first, then the ys cotangents.
    *args: the scan operands: consts, init carry, linear xs, then residuals.
    **kwargs: scan parameters (`forward`, `length`, `num_consts`,
      `num_carry`, `jaxpr`, `linear`).

  Returns:
    Cotangents for consts, init and linear xs, followed by one `None` per
    residual operand.

  Raises:
    NotImplementedError: if a const or carry operand is not linear, or if the
      linear xs do not form a prefix of the xs operands.
  """
  forward, length, num_consts, num_carry, jaxpr, linear = split_dict(
      kwargs,
      ["forward", "length", "num_consts", "num_carry", "jaxpr", "linear"])

  # We can only transpose scans for which the nonlinear values appear in xs
  # (and there only after all of the linear xs).
  consts_lin, init_lin, xs_lin = split_list(linear, [num_consts, num_carry])
  num_lin = sum(xs_lin)
  transposable = (
      all(consts_lin) and all(init_lin) and all(xs_lin[:num_lin]))
  if not transposable:
    raise NotImplementedError

  consts, init, xs, res = split_list(args, [num_consts, num_carry, num_lin])
  # Residuals must be known values, never placeholders for linear inputs.
  assert all(r is not ad.undefined_primal for r in res)

  carry_avals, y_avals = split_list(jaxpr.out_avals, [num_carry])
  ys_avals = _map(partial(_promote_aval_rank, length), y_avals)
  ct_carry, ct_ys = split_list(cts, [num_carry])
  ct_carry = _map(ad.instantiate_zeros_aval, carry_avals, ct_carry)
  ct_ys = _map(ad.instantiate_zeros_aval, ys_avals, ct_ys)
  ct_consts = _map(ad_util.zeros_like_aval, jaxpr.in_avals[:num_consts])

  #       jaxpr :: [T d] -> [T c] -> [T a, res] -> ([T c], [T b])
  # jaxpr_trans :: [] -> [CT d, CT c] -> [CT b, res] -> ([CT d, CT c], [CT a])
  jaxpr_trans = _transpose_jaxpr(num_consts, len(res), jaxpr)

  # Every cotangent operand is linear; the residuals are not.
  ct_operands = ct_consts + ct_carry + ct_ys
  linear_trans = [True] * len(ct_operands) + [False] * len(res)

  outs = scan_p.bind(
      *(ct_operands + res),
      forward=not forward,
      length=length,
      jaxpr=jaxpr_trans,
      num_consts=0,
      num_carry=num_consts + num_carry,
      linear=linear_trans)
  ct_consts, ct_init, ct_xs = split_list(outs, [num_consts, num_carry])
  return ct_consts + ct_init + ct_xs + [None] * len(res)
def _custom_linear_solve_impl(*args, **kwargs):
  """Run the `solve` jaxpr of a custom linear solve and check its output.

  Args:
    *args: flat operands, split into per-function parameters and the
      right-hand side `b` by `_split_linear_solve_args`.
    **kwargs: primitive parameters `const_lengths`, `jaxprs` and `tree`.

  Returns:
    The solution `x` produced by `jaxprs.solve`, after `_check_shapes` has
    compared it against `b`.
  """
  const_lengths, jaxprs, tree = split_dict(
      kwargs, ['const_lengths', 'jaxprs', 'tree'])
  params, b = _split_linear_solve_args(args, const_lengths)

  solve_fun = core.jaxpr_as_fun(jaxprs.solve)
  solve_operands = params.solve + b
  x = solve_fun(*solve_operands)

  _check_shapes('solve', 'b', x, b, tree)
  return x
def scan_bind(*args, **kwargs):
  """Sanity-check scan operands and parameters, then bind `scan_p`.

  Args:
    *args: scan operands: consts, then the initial carry, then xs.
    **kwargs: scan parameters (`forward`, `length`, `num_consts`,
      `num_carry`, `jaxpr`, `linear`).

  Returns:
    The result of `core.Primitive.bind(scan_p, ...)` with the same operands
    and parameters.
  """
  forward, length, num_consts, num_carry, jaxpr, linear = split_dict(
      kwargs,
      ["forward", "length", "num_consts", "num_carry", "jaxpr", "linear"])
  consts, init, xs = split_list(args, [num_consts, num_carry])
  # There must be exactly one linearity flag per operand.
  assert len(linear) == len(args)

  # Check that the operands match the jaxpr's input types.
  consts_avals, init_avals, x_avals = split_list(
      jaxpr.in_avals, [num_consts, num_carry])
  xs_avals = _map(partial(_promote_aval_rank, length), x_avals)
  consts_ok = _map(typecheck, consts_avals, consts)
  assert all(consts_ok)
  init_ok = _map(typecheck, init_avals, init)
  assert all(init_ok)
  # NOTE: the analogous typecheck of `xs` against `xs_avals` is disabled in
  # this version of the code.

  # Check that the output carry type matches the input carry type.
  carry_avals, _ = split_list(jaxpr.out_avals, [num_carry])
  carry_ok = _map(typematch, init_avals, carry_avals)
  assert all(carry_ok)

  # Check that the data flow is sensible.
  core.check_jaxpr(jaxpr.jaxpr)

  return core.Primitive.bind(
      scan_p,
      *args,
      forward=forward,
      length=length,
      jaxpr=jaxpr,
      num_consts=num_consts,
      num_carry=num_carry,
      linear=linear)
def _cond_translation_rule(c, axis_env, pred, *args, **kwargs):
  """Lower a two-branch conditional to an XLA `Conditional` op.

  Each branch jaxpr is wrapped in a computation that takes a single tuple
  parameter, unpacks it and calls the jaxpr's computation.

  Args:
    c: the computation builder to emit into.
    axis_env: forwarded to `xla.jaxpr_computation`.
    pred: the predicate operand.
    *args: true consts, true operands, false consts, false operands.
    **kwargs: `true_jaxpr`, `false_jaxpr`, `true_nconsts`, `false_nconsts`
      and an optional `backend` (defaults to None).

  Returns:
    The op produced by `c.Conditional`.
  """
  backend = kwargs.pop("backend", None)
  true_jaxpr, false_jaxpr, true_nconsts, false_nconsts = split_dict(
      kwargs, ["true_jaxpr", "false_jaxpr", "true_nconsts", "false_nconsts"])
  # The true jaxpr's inputs are its consts followed by its operands.
  true_nops = len(true_jaxpr.in_avals) - true_nconsts
  true_consts, true_ops, false_consts, false_ops = split_list(
      args, [true_nconsts, true_nops, false_nconsts])

  def make_computation(name, jaxpr, op_shape):
    # Build "lambda tup: jaxpr(*tup)" as a standalone XLA computation.
    builder = xb.make_computation_builder(name)
    packed = builder.ParameterWithShape(op_shape)
    num_inputs = len(jaxpr.in_avals)
    unpacked = [builder.GetTupleElement(packed, i) for i in range(num_inputs)]
    branch_comp = xla.jaxpr_computation(
        jaxpr.jaxpr, backend, axis_env, jaxpr.literals, (),
        *_map(builder.GetShape, unpacked))
    return builder.Build(builder.Call(branch_comp, unpacked))

  true_op = c.Tuple(*(true_consts + true_ops))
  true_c = make_computation("true_comp", true_jaxpr, c.GetShape(true_op))

  false_op = c.Tuple(*(false_consts + false_ops))
  false_c = make_computation("false_comp", false_jaxpr, c.GetShape(false_op))

  return c.Conditional(pred, true_op, true_c, false_op, false_c)
def process_parametrized(self, primitive, *flat_inputs, **kwargs):
  """Process `primitive` on flat inputs and return flat outputs.

  The flat inputs are unflattened with `in_tree`, handed to
  `self._process_parametrized_nonflat`, and the result is flattened again.

  Args:
    primitive: forwarded to `self._process_parametrized_nonflat`.
    *flat_inputs: flattened inputs matching `in_tree`.
    **kwargs: must provide `in_tree` and `out_tree_container`.

  Returns:
    The flattened outputs. The tree structure of the outputs is appended to
    `out_tree_container` as a side effect.
  """
  in_tree, out_tree_container = split_dict(
      kwargs, ['in_tree', 'out_tree_container'])

  structured_inputs = tree_unflatten(in_tree, flat_inputs)
  structured_outputs = self._process_parametrized_nonflat(
      primitive, *structured_inputs)

  flat_outputs, out_tree = tree_flatten(structured_outputs)
  # The output tree is handed back to the caller through the container.
  out_tree_container.append(out_tree)
  return flat_outputs
def _while_loop_translation_rule(c, axis_env, *args, **kwargs):
  """Lower a while loop to an XLA `While` op.

  Args:
    c: the computation builder to emit into.
    axis_env: forwarded to `xla.jaxpr_computation`.
    *args: cond consts, body consts, then the initial loop values.
    **kwargs: `cond_jaxpr`, `body_jaxpr`, `cond_nconsts`, `body_nconsts`.

  Returns:
    A tuple op holding the final loop values (without the consts).
  """
  cond_jaxpr, body_jaxpr, cond_nconsts, body_nconsts = split_dict(
      kwargs, ["cond_jaxpr", "body_jaxpr", "cond_nconsts", "body_nconsts"])
  cond_consts, body_consts, init_vals = split_list(
      args, [cond_nconsts, body_nconsts])
  pred_aval = cond_jaxpr.out_avals[0]
  # A predicate with a non-empty shape is treated as batched.
  batched = bool(pred_aval.shape)
  num_args = len(args)

  # Since jaxprs don't have tuples and have multiple return values, but we need
  # the HLO While loop to take a single tuple input and output a single boolean
  # (for the cond computation) or a single tuple output (for the body
  # computation), we build XLA computations that handle the tuple munging before
  # generating a Call into the computations formed from the jaxprs.
  init_carry = c.Tuple(*(cond_consts + body_consts + init_vals))

  # --- cond computation: carry tuple -> scalar predicate ---
  cond_c = xb.make_computation_builder("cond_computation")
  cond_carry = cond_c.ParameterWithShape(c.GetShape(init_carry))
  cond_carry_elts = [
      cond_c.GetTupleElement(cond_carry, i) for i in range(num_args)]
  c_cond_consts, _, c_vals = split_list(
      cond_carry_elts, [cond_nconsts, body_nconsts])
  cond_operands = c_cond_consts + c_vals
  cond_outs = cond_c.Call(
      xla.jaxpr_computation(cond_jaxpr.jaxpr, axis_env, cond_jaxpr.literals,
                            (), *_map(cond_c.GetShape, cond_operands)),
      cond_operands)
  pred = cond_c.GetTupleElement(cond_outs, 0)
  if batched:
    # Reduce the batched predicate with `or` over all of its dimensions.
    scalar = xla_client.Shape.array_shape(onp.dtype(onp.bool_), ())
    or_ = xla.primitive_computation(lax.or_p, scalar, scalar)
    pred = cond_c.Reduce(pred, cond_c.Constant(onp.array(False)), or_,
                         list(range(pred_aval.ndim)))

  # --- body computation: carry tuple -> carry tuple ---
  body_c = xb.make_computation_builder("body_computation")
  body_carry = body_c.ParameterWithShape(c.GetShape(init_carry))
  body_carry_elts = [
      body_c.GetTupleElement(body_carry, i) for i in range(num_args)]
  b_cond_consts, b_body_consts, b_vals = split_list(
      body_carry_elts, [cond_nconsts, body_nconsts])
  body_operands = b_body_consts + b_vals
  body_out = body_c.Call(
      xla.jaxpr_computation(body_jaxpr.jaxpr, axis_env, body_jaxpr.literals,
                            (), *_map(body_c.GetShape, body_operands)),
      body_operands)
  new_vals = [
      body_c.GetTupleElement(body_out, i) for i in range(len(init_vals))]
  if batched:
    # Re-evaluate the predicate inside the body and select between the new
    # and the old loop values with it.
    recheck_operands = b_cond_consts + b_vals
    body_cond_outs = body_c.Call(
        xla.jaxpr_computation(cond_jaxpr.jaxpr, axis_env, cond_jaxpr.literals,
                              (), *_map(body_c.GetShape, recheck_operands)),
        recheck_operands)
    body_pred = body_c.GetTupleElement(body_cond_outs, 0)
    new_vals = _map(partial(_pred_bcast_select, body_c, body_pred),
                    new_vals, b_vals)
    # The select must not have broadcast any of the loop values.
    assert _map(body_c.GetShape, new_vals) == _map(body_c.GetShape, b_vals)
  new_carry = body_c.Tuple(*(b_cond_consts + b_body_consts + new_vals))

  ans = c.While(cond_c.Build(pred), body_c.Build(new_carry), init_carry)
  ans_elts = [c.GetTupleElement(ans, i) for i in range(num_args)]
  _, _, final_vals = split_list(ans_elts, [cond_nconsts, body_nconsts])
  return c.Tuple(*final_vals)
def abstract_eval(self, *avals, **kwargs):
  """Compute flat abstract outputs by tracing the example-outputs function.

  Args:
    *avals: flat abstract input values.
    **kwargs: must provide `in_tree` and `out_tree_container`.

  Returns:
    The flat outputs obtained from `_instantiated_trace_to_jaxpr`. The output
    tree is appended to `out_tree_container` as a side effect.
  """
  in_tree, out_tree_container = split_dict(
      kwargs, ['in_tree', 'out_tree_container'])
  flat_outs_fun, out_tree_thunk = flatten_fun_nokwargs(
      self._wrapped_example_outputs_fun, in_tree)

  # Tracing populates out_tree_thunk, so that it returns the output tree.
  traced = _instantiated_trace_to_jaxpr(flat_outs_fun, avals)
  _, flat_outs, _ = traced

  # The output tree is returned to the caller via the container.
  out_tree = out_tree_thunk()
  out_tree_container.append(out_tree)
  return flat_outs
def _cond_impl(pred, *args, **kwargs):
  """Evaluate a two-branch conditional by running the selected branch jaxpr.

  Args:
    pred: the predicate; the true branch runs when it is truthy.
    *args: true consts, true operands, false consts, false operands.
    **kwargs: `true_jaxpr`, `false_jaxpr`, `true_nconsts`, `false_nconsts`.

  Returns:
    The outputs of the branch jaxpr that was run.
  """
  true_jaxpr, false_jaxpr, true_nconsts, false_nconsts = split_dict(
      kwargs, ["true_jaxpr", "false_jaxpr", "true_nconsts", "false_nconsts"])
  # true_jaxpr is applied to true_consts + true_ops, so its input count covers
  # both groups. Subtract the consts to get the number of true operands;
  # otherwise the true operands would swallow inputs of the false branch
  # whenever true_nconsts > 0.
  true_nops = len(true_jaxpr.in_avals) - true_nconsts
  true_consts, true_ops, false_consts, false_ops = split_list(
      args, [true_nconsts, true_nops, false_nconsts])

  if pred:
    return core.jaxpr_as_fun(true_jaxpr)(*(true_consts + true_ops))
  else:
    return core.jaxpr_as_fun(false_jaxpr)(*(false_consts + false_ops))
def _custom_linear_solve_transpose_rule(cotangent, *primals, **kwargs):
  """Transpose a custom linear solve by binding the transposed solve.

  Args:
    cotangent: cotangent of the solve output.
    *primals: flat operands; the `b` part must consist solely of
      `ad.undefined_primal` placeholders.
    **kwargs: primitive parameters `const_lengths`, `jaxprs` and `tree`.

  Returns:
    One `None` per constant operand followed by the cotangent for `b`.

  Raises:
    TypeError: if `jaxprs.transpose_solve` is None.
  """
  const_lengths, jaxprs, tree = split_dict(
      kwargs, ['const_lengths', 'jaxprs', 'tree'])

  if jaxprs.transpose_solve is None:
    raise TypeError('transpose_solve required for backwards mode automatic '
                    'differentiation of custom_linear_solve')

  params, b = _split_linear_solve_args(primals, const_lengths)
  # Only `b` is being differentiated, so every entry must be a placeholder.
  assert b == [ad.undefined_primal] * len(b)

  transposed_operands = _flatten(params.transpose()) + cotangent
  cotangent_b = custom_linear_solve_p.bind(
      *transposed_operands,
      const_lengths=const_lengths.transpose(),
      jaxprs=jaxprs.transpose(),
      tree=tree)

  no_const_cotangents = [None] * sum(const_lengths)
  return no_const_cotangents + cotangent_b
def _scan_partial_eval(trace, *tracers, **kwargs):
  """Partial evaluation rule for scan.

  Splits the scan into a first scan that is bound immediately on the known
  inputs and a second scan that is staged out as a jaxpr equation.

  Args:
    trace: the partial-evaluation trace.
    *tracers: input tracers; a tracer counts as unknown when `pval[0]` is
      not None.
    **kwargs: scan parameters (`forward`, `length`, `num_consts`,
      `num_carry`, `jaxpr`, `linear`).

  Returns:
    The output tracers, all sharing the staged scan equation as their recipe.

  Raises:
    FixedPointError: if the carry unknown-ness does not stabilise within
      1000 iterations.
  """
  forward, length, num_consts, num_carry, jaxpr, linear = split_dict(
      kwargs,
      ["forward", "length", "num_consts", "num_carry", "jaxpr", "linear"])
  num_ys = len(jaxpr.out_avals) - num_carry

  initial_unknowns = [t.pval[0] is not None for t in tracers]
  const_uk, init_uk, xs_uk = split_list(
      initial_unknowns, [num_consts, num_carry])

  # Iterate to a fixed point: a carry component that comes out unknown must
  # also be treated as unknown on the way in.
  carry_uk = init_uk
  for _ in range(1000):
    unknowns = const_uk + carry_uk + xs_uk
    jaxpr_1, jaxpr_2, out_uk = pe.partial_eval_jaxpr(
        jaxpr, unknowns, instantiate=carry_uk + [False] * num_ys)
    carry_uk_out = out_uk[:num_carry]
    if carry_uk_out == carry_uk:
      break
    carry_uk = carry_uk_out
  else:
    raise FixedPointError

  # Unknown inputs are replaced by unit in the known scan; known inputs are
  # replaced by unit literals in the staged scan.
  in_consts = [core.unit if uk else t.pval[1]
               for uk, t in zip(unknowns, tracers)]
  new_tracers = [trace.instantiate_const(t) if uk
                 else trace.new_instantiated_literal(core.unit)
                 for uk, t in zip(unknowns, tracers)]

  carry_avals, y_avals = split_list(jaxpr.out_avals, [num_carry])
  ys_avals = _map(partial(_promote_aval_rank, length), y_avals)
  out_avals = carry_avals + ys_avals
  out_pvs = [aval if uk else None for aval, uk in zip(out_avals, out_uk)]

  # First scan: computes the known outputs and the residuals.
  linear_1 = [lin or uk for uk, lin in zip(unknowns, linear)]
  out_flat = scan_p.bind(
      *in_consts,
      forward=forward,
      length=length,
      jaxpr=jaxpr_1,
      num_consts=num_consts,
      num_carry=num_carry,
      linear=linear_1)
  out_carry, ys, residuals = split_list(out_flat, [num_carry, num_ys])
  out_consts = out_carry + ys
  residual_tracers = _map(trace.new_instantiated_const, residuals)
  out_tracers = [pe.JaxprTracer(trace, pe.PartialVal((pv, const)), None)
                 for pv, const in zip(out_pvs, out_consts)]

  # Second scan: staged out, consuming the residuals as extra xs.
  linear_2 = ([lin or not uk for uk, lin in zip(unknowns, linear)]
              + [False] * len(residual_tracers))
  staged_params = dict(
      forward=forward,
      length=length,
      jaxpr=jaxpr_2,
      num_consts=num_consts,
      num_carry=num_carry,
      linear=linear_2)
  eqn = pe.new_jaxpr_eqn(new_tracers + residual_tracers, out_tracers, scan_p,
                         (), staged_params)
  for out_tracer in out_tracers:
    out_tracer.recipe = eqn
  return out_tracers
def _custom_cell_scan_impl(flat_cell, *args, **kwargs):
  """lax_control_flow._scan_impl, but allowing for a custom cell function.

  The cell function is retraced on the consts, the initial carry and the
  index-0 slice of each xs; `scan_p` is then bound with the new jaxpr.

  Args:
    flat_cell: the flat cell function to trace in place of the scan body.
    *args: scan operands: consts, then the initial carry, then xs.
    **kwargs: scan parameters (`reverse`, `length`, `num_consts`,
      `num_carry`, `jaxpr`, `linear`, `unroll`).

  Returns:
    The result of `scan_p.bind` with the retraced jaxpr, its new consts and
    an all-False `linear`.
  """
  # Only some of the parameters are used here; the rest are forwarded
  # unchanged through `kwargs`.
  _, _, num_consts, num_carry, jaxpr, _, _ = split_dict(
      kwargs, ["reverse", "length", "num_consts", "num_carry", "jaxpr",
               "linear", "unroll"])
  consts, init, xs = split_list(args, [num_consts, num_carry])
  _, _, x_avals = split_list(jaxpr.in_avals, [num_consts, num_carry])

  # Materialise the mapped slices as a list: the builtin `map` returns an
  # iterator in Python 3, which cannot be concatenated onto a list. If the
  # module rebinds `map` to a list-returning helper this is a cheap copy.
  first_xs = list(map(partial(_index_array, 0), x_avals, xs))
  cell_args = consts + init + first_xs

  jaxpr, new_consts = _flat_initial_style_jaxpr(
      wrap_init(flat_cell), _abstractified(cell_args))

  args = list(new_consts) + init + xs
  kwargs['jaxpr'] = jaxpr
  kwargs['num_consts'] = len(new_consts)
  kwargs['linear'] = (False,) * len(args)
  return scan_p.bind(*args, **kwargs)
def _scan_impl(*args, **kwargs):
  """Implement scan with a `fori_loop` over the scanned index.

  Args:
    *args: scan operands: consts, then the initial carry, then xs.
    **kwargs: scan parameters (`forward`, `length`, `num_consts`,
      `num_carry`, `jaxpr`, `linear`).

  Returns:
    The result of the `fori_loop`, whose loop state is the carry followed by
    the ys accumulators.
  """
  forward, length, num_consts, num_carry, jaxpr, linear = split_dict(
      kwargs,
      ["forward", "length", "num_consts", "num_carry", "jaxpr", "linear"])
  consts, init, xs = split_list(args, [num_consts, num_carry])
  _, _, x_avals = split_list(jaxpr.in_avals, [num_consts, num_carry])
  _, y_avals = split_list(jaxpr.out_avals, [num_carry])

  def body_fun(i, vals):
    # A reversed scan visits the indices from length - 1 down to 0.
    if not forward:
      i = length - i - 1
    carry, ys = split_list(vals, [num_carry])
    x_slices = _map(partial(_index_array, i), x_avals, xs)
    step_out = core.jaxpr_as_fun(jaxpr)(*(consts + carry + x_slices))
    carry_out, y_updates = split_list(step_out, [num_carry])
    ys_out = _map(partial(_update_array, i), y_avals, ys, y_updates)
    return carry_out + ys_out

  ys_init = _map(partial(_empty_array, length), y_avals)
  initial_state = init + ys_init
  return fori_loop(0, length, body_fun, initial_state)
def _root_impl(*args, **kwargs):
  """Apply the user-provided `solve` to the jaxpr closed over its params.

  Args:
    *args: `num_consts` parameters followed by the initial guess values.
    **kwargs: `num_consts`, `jaxpr`, `solve` and `tangent_solve` (the last
      is not used here).

  Returns:
    Whatever `solve(f, *initial_guess)` returns, where `f` is the jaxpr
    function with the parameters bound as its leading arguments.
  """
  num_consts, jaxpr, solve, _ = split_dict(
      kwargs, ['num_consts', 'jaxpr', 'solve', 'tangent_solve'])
  params, initial_guess = split_list(args, [num_consts])

  jaxpr_fun = core.jaxpr_as_fun(jaxpr)
  f = partial(jaxpr_fun, *params)
  return solve(f, *initial_guess)
def _rewrite_eqn(eqn: core.JaxprEqn, eqns: List[core.JaxprEqn],
                 input_token_var: core.Var, output_token_var: core.Var,
                 mk_new_var: Callable[[core.AbstractValue], core.Var]):
  """Rewrite an `eqn` and append equations to `eqns`.

  Assume that the current token is in `input_token_var` and the resulting
  token must end in `output_token_var`.

  Args:
    eqn: the equation to rewrite so that it threads the token.
    eqns: the output list; rewritten equations are appended to it.
    input_token_var: the variable holding the token before `eqn`.
    output_token_var: the variable that must hold the token after `eqn`.
    mk_new_var: factory for fresh variables; only forwarded to
      `_rewrite_while_outfeed_cond`.

  Raises:
    NotImplementedError: if the primitive of `eqn` is not handled.
  """
  if eqn.primitive is id_tap_p:
    eqns.append(
        core.new_jaxpr_eqn(eqn.invars + [input_token_var],
                           eqn.outvars + [output_token_var], eqn.primitive,
                           eqn.params, eqn.source_info))
  elif eqn.primitive is lax.while_p:
    cond_jaxpr, cond_nconsts, body_jaxpr, body_nconsts = util.split_dict(
        eqn.params,
        ["cond_jaxpr", "cond_nconsts", "body_jaxpr", "body_nconsts"])
    if xla.jaxpr_uses_outfeed(cond_jaxpr.jaxpr):
      # An outfeed in the predicate needs a structurally different rewrite.
      _rewrite_while_outfeed_cond(eqn, eqns, input_token_var,
                                  output_token_var, mk_new_var)
      return

    eqns.append(
        core.new_jaxpr_eqn(
            eqn.invars + [input_token_var], eqn.outvars + [output_token_var],
            eqn.primitive,
            dict(eqn.params,
                 body_jaxpr=_rewrite_typed_jaxpr(body_jaxpr, True, True)[0],
                 cond_jaxpr=_rewrite_typed_jaxpr(cond_jaxpr, True,
                                                 False)[0]),
            eqn.source_info))
  elif eqn.primitive is lax.cond_p:
    branches, linear = util.split_dict(eqn.params, ["branches", "linear"])
    index, *operands = eqn.invars
    # The token is the last operand, so its linear flag goes last as well.
    new_invars = [index, *operands, input_token_var]
    eqns.append(
        core.new_jaxpr_eqn(
            new_invars, eqn.outvars + [output_token_var], eqn.primitive,
            dict(eqn.params,
                 branches=tuple(
                     _rewrite_typed_jaxpr(jaxpr, True, True)[0]
                     for jaxpr in branches),
                 linear=(*linear, False)), eqn.source_info))
  elif eqn.primitive is lax.scan_p:
    num_consts, num_carry, carry_jaxpr, linear, _, _ = util.split_dict(
        eqn.params,
        ["num_consts", "num_carry", "jaxpr", "linear", "reverse", "length"])
    # We add the token right at the end of carry
    nr_const_and_carry = num_consts + num_carry
    new_invars = eqn.invars[0:nr_const_and_carry] + [
        input_token_var
    ] + eqn.invars[nr_const_and_carry:]
    new_jaxpr = _rewrite_typed_jaxpr(carry_jaxpr, True, True)[0]
    # The rewrite has put the token at end, it has to be at end of carry
    new_jaxpr_invars = new_jaxpr.jaxpr.invars
    new_jaxpr_invars = (new_jaxpr_invars[0:nr_const_and_carry] +
                        [new_jaxpr_invars[-1]] +
                        new_jaxpr_invars[nr_const_and_carry:-1])
    new_jaxpr.jaxpr.invars = new_jaxpr_invars
    new_jaxpr.in_avals = [v.aval for v in new_jaxpr_invars]

    new_jaxpr_outvars = new_jaxpr.jaxpr.outvars
    new_jaxpr_outvars = (new_jaxpr_outvars[0:num_carry] +
                         [new_jaxpr_outvars[-1]] +
                         new_jaxpr_outvars[num_carry:-1])
    new_jaxpr.jaxpr.outvars = new_jaxpr_outvars
    new_jaxpr.out_avals = [v.aval for v in new_jaxpr_outvars]
    # `linear` has one flag per operand, in operand order. The token sits at
    # the end of the carry, so its flag must be inserted at that position;
    # appending it at the end would shift the flags of all xs operands.
    new_linear = (linear[:nr_const_and_carry] + (False,) +
                  linear[nr_const_and_carry:])
    eqns.append(
        core.new_jaxpr_eqn(
            new_invars,
            # Output token is at the end of carry result
            eqn.outvars[0:num_carry] + [output_token_var] +
            eqn.outvars[num_carry:],
            eqn.primitive,
            dict(eqn.params,
                 jaxpr=new_jaxpr,
                 num_carry=num_carry + 1,
                 linear=new_linear), eqn.source_info))
  elif eqn.primitive is xla.xla_call_p:
    call_jaxpr = cast(core.Jaxpr, eqn.params["call_jaxpr"])
    eqns.append(
        core.new_jaxpr_eqn(
            eqn.invars + [input_token_var], eqn.outvars + [output_token_var],
            eqn.primitive,
            dict(eqn.params,
                 call_jaxpr=_rewrite_jaxpr(call_jaxpr, True, True)[0]),
            eqn.source_info))
  else:
    raise NotImplementedError(f"outfeed rewrite {eqn.primitive}")
def _rewrite_eqn(eqn: core.JaxprEqn, eqns: List[core.JaxprEqn],
                 input_token_var: core.Var, output_token_var: core.Var,
                 mk_new_var: Callable[[core.AbstractValue], core.Var]):
  """Rewrite an `eqn` and append equations to `eqns`.

  Assume that the current token is in `input_token_var` and the resulting
  token must end in `output_token_var`.

  Args:
    eqn: the equation to rewrite so that it threads the token.
    eqns: the output list; rewritten equations are appended to it.
    input_token_var: the variable holding the token before `eqn`.
    output_token_var: the variable that must hold the token after `eqn`.
    mk_new_var: factory for fresh variables; only forwarded to
      `_rewrite_while_outfeed_cond`.

  Raises:
    NotImplementedError: if the primitive of `eqn` is not handled.
  """
  if eqn.primitive is id_tap_p:
    assert "has_token_" not in eqn.params
    eqns.append(
        core.new_jaxpr_eqn(eqn.invars + [input_token_var],
                           eqn.outvars + [output_token_var], eqn.primitive,
                           dict(eqn.params, has_token_=True),
                           eqn.source_info))
  elif eqn.primitive is lax.while_p:
    cond_jaxpr, _, body_jaxpr, _ = util.split_dict(
        eqn.params,
        ["cond_jaxpr", "cond_nconsts", "body_jaxpr", "body_nconsts"])
    if xla.jaxpr_uses_outfeed(cond_jaxpr.jaxpr):
      # An outfeed in the predicate needs a structurally different rewrite.
      _rewrite_while_outfeed_cond(eqn, eqns, input_token_var,
                                  output_token_var, mk_new_var)
      return

    eqns.append(
        core.new_jaxpr_eqn(
            eqn.invars + [input_token_var], eqn.outvars + [output_token_var],
            eqn.primitive,
            dict(
                eqn.params,
                body_jaxpr=_rewrite_typed_jaxpr(body_jaxpr, True, True),
                cond_jaxpr=_rewrite_typed_jaxpr(cond_jaxpr, True, False)),
            eqn.source_info))
  elif eqn.primitive is lax.cond_p:
    branches, linear = util.split_dict(eqn.params, ["branches", "linear"])
    index, *operands = eqn.invars
    # The token is the last operand, so its linear flag goes last as well.
    new_invars = [index, *operands, input_token_var]
    eqns.append(
        core.new_jaxpr_eqn(
            new_invars, eqn.outvars + [output_token_var], eqn.primitive,
            dict(
                eqn.params,
                branches=tuple(
                    _rewrite_typed_jaxpr(jaxpr, True, True)
                    for jaxpr in branches),
                linear=(*linear, False)), eqn.source_info))
  elif eqn.primitive is lax.scan_p:
    num_consts, num_carry, carry_jaxpr, linear, _, _, _ = util.split_dict(
        eqn.params,
        ["num_consts", "num_carry", "jaxpr", "linear", "reverse", "length",
         "unroll"])
    # We add the token right at the end of carry
    nr_const_and_carry = num_consts + num_carry
    new_invars = eqn.invars[0:nr_const_and_carry] + [
        input_token_var
    ] + eqn.invars[nr_const_and_carry:]
    new_jaxpr = _rewrite_typed_jaxpr(carry_jaxpr, True, True)
    # The rewrite has put the token at end, it has to be at end of carry
    new_jaxpr_invars = new_jaxpr.jaxpr.invars
    new_jaxpr_invars = (
        new_jaxpr_invars[0:nr_const_and_carry] + [new_jaxpr_invars[-1]] +
        new_jaxpr_invars[nr_const_and_carry:-1])
    new_jaxpr.jaxpr.invars = new_jaxpr_invars
    new_jaxpr.in_avals = [v.aval for v in new_jaxpr_invars]

    new_jaxpr_outvars = new_jaxpr.jaxpr.outvars
    new_jaxpr_outvars = (
        new_jaxpr_outvars[0:num_carry] + [new_jaxpr_outvars[-1]] +
        new_jaxpr_outvars[num_carry:-1])
    new_jaxpr.jaxpr.outvars = new_jaxpr_outvars
    new_jaxpr.out_avals = [v.aval for v in new_jaxpr_outvars]
    # `linear` has one flag per operand, in operand order. The token sits at
    # the end of the carry, so its flag must be inserted at that position;
    # appending it at the end would shift the flags of all xs operands.
    new_linear = (linear[:nr_const_and_carry] + (False,) +
                  linear[nr_const_and_carry:])
    eqns.append(
        core.new_jaxpr_eqn(
            new_invars,
            # Output token is at the end of carry result
            eqn.outvars[0:num_carry] + [output_token_var] +
            eqn.outvars[num_carry:],
            eqn.primitive,
            dict(
                eqn.params,
                jaxpr=new_jaxpr,
                num_carry=num_carry + 1,
                linear=new_linear), eqn.source_info))
  elif eqn.primitive is xla.xla_call_p:
    call_jaxpr = cast(core.Jaxpr, eqn.params["call_jaxpr"])
    eqns.append(
        core.new_jaxpr_eqn(
            eqn.invars + [input_token_var], eqn.outvars + [output_token_var],
            eqn.primitive,
            dict(
                eqn.params,
                call_jaxpr=_rewrite_jaxpr(call_jaxpr, True, True),
                # The extra token operand is not donated.
                donated_invars=eqn.params["donated_invars"] + (False,)
            ),
            eqn.source_info))
  elif eqn.primitive is custom_derivatives.custom_jvp_call_jaxpr_p:
    fun_jaxpr = eqn.params["fun_jaxpr"]
    new_invars = [*eqn.invars, input_token_var]

    def unreachable_thunk():
      assert False, "Should not be reached"

    eqns.append(
        core.new_jaxpr_eqn(
            new_invars, eqn.outvars + [output_token_var], eqn.primitive,
            dict(
                eqn.params,
                fun_jaxpr=_rewrite_typed_jaxpr(fun_jaxpr, True, True),
                jvp_jaxpr_thunk=unreachable_thunk
            ),
            eqn.source_info))
  elif eqn.primitive is custom_derivatives.custom_vjp_call_jaxpr_p:
    fun_jaxpr = eqn.params["fun_jaxpr"]
    new_invars = [*eqn.invars, input_token_var]

    def unreachable_thunk():
      assert False, "Should not be reached"

    eqns.append(
        core.new_jaxpr_eqn(
            new_invars, eqn.outvars + [output_token_var], eqn.primitive,
            dict(
                eqn.params,
                fun_jaxpr=_rewrite_typed_jaxpr(fun_jaxpr, True, True),
                fwd_jaxpr_thunk=unreachable_thunk,
                # The following are illegal values for the parameters, they
                # should not be needed because this rewrite is just before
                # compilation to XLA, which does not use those parameters.
                bwd="illegal param",
                out_trees="illegal param"
            ),
            eqn.source_info))
  else:
    raise NotImplementedError(f"outfeed rewrite {eqn.primitive}")
def _rewrite_while_outfeed_cond(eqn: core.JaxprEqn, eqns: List[core.JaxprEqn],
                                input_token_var: core.Var,
                                output_token_var: core.Var,
                                mk_new_var: Callable):
  """Rewrite a while loop whose predicate uses outfeed.

  The predicate is evaluated once before the loop and again at the end of
  each body iteration; the loop itself carries the predicate value and the
  token, and its new cond just returns the carried predicate.

  Args:
    eqn: the `while` equation to rewrite.
    eqns: the output list; two equations are appended (the initial predicate
      evaluation and the new while).
    input_token_var: the variable holding the token before the loop.
    output_token_var: the variable that must hold the token after the loop.
    mk_new_var: called with an abstract value to create a fresh variable.
  """
  cond_jaxpr, cond_nconsts, body_jaxpr, body_nconsts = util.split_dict(
      eqn.params, ["cond_jaxpr", "cond_nconsts", "body_jaxpr", "body_nconsts"])
  total_nconsts = cond_nconsts + body_nconsts
  cond_with_token = _rewrite_typed_jaxpr(cond_jaxpr, True, True)
  carry_invars = eqn.invars[total_nconsts:]

  # pred1, token1 = rewrite(COND)(cond_consts, carry_invars, input_token)
  pred1_and_token1 = [
      mk_new_var(ov.aval) for ov in cond_with_token.jaxpr.outvars
  ]
  cond_before_invars = (
      eqn.invars[0:cond_nconsts] + carry_invars + [input_token_var])
  eqns.append(
      core.new_jaxpr_eqn(
          cond_before_invars, pred1_and_token1, xla.xla_call_p,
          dict(
              call_jaxpr=cond_with_token.jaxpr,
              name="cond_before",
              donated_invars=(False,) * len(cond_with_token.in_avals)),
          eqn.source_info))

  # Make a new cond "lambda pred, carry, token: pred"
  new_cond_pred_invar = mk_new_var(cond_jaxpr.out_avals[0])
  new_cond_invars = ([new_cond_pred_invar] +
                     [mk_new_var(cv.aval) for cv in carry_invars] +
                     [mk_new_var(core.abstract_token)])
  new_cond_jaxpr = _mk_typed_jaxpr(
      core.Jaxpr([], new_cond_invars, [new_cond_pred_invar], []), [])

  # Make a new body:
  #   "lambda cond_constvars, body_constvars, pred, carry, token:
  #        carry2, token2 = rewrite(BODY)(body_constvars, carry, token)
  #        pred2, token3 = rewrite(COND)(cond_constvars, carry2, token2)
  #        (pred2, carry2, token3)
  body_with_token = _rewrite_typed_jaxpr(body_jaxpr, True, True)
  # Fresh variables for the new body's inputs and intermediates.
  in_cond_consts = [mk_new_var(v.aval) for v in eqn.invars[0:cond_nconsts]]
  in_body_consts = [
      mk_new_var(v.aval) for v in eqn.invars[cond_nconsts:total_nconsts]
  ]
  in_pred = mk_new_var(cond_jaxpr.out_avals[0])
  in_carry = [mk_new_var(cv.aval) for cv in carry_invars]
  in_token = mk_new_var(core.abstract_token)
  carry2 = [mk_new_var(cv.aval) for cv in carry_invars]
  token2 = mk_new_var(core.abstract_token)
  pred2 = mk_new_var(cond_jaxpr.out_avals[0])
  token3 = mk_new_var(core.abstract_token)

  run_body_eqn = core.new_jaxpr_eqn(
      in_body_consts + in_carry + [in_token],
      carry2 + [token2],
      xla.xla_call_p,
      dict(
          call_jaxpr=body_with_token.jaxpr,
          name="body",
          donated_invars=(False,) * len(body_with_token.in_avals)),
      eqn.source_info)
  run_cond_eqn = core.new_jaxpr_eqn(
      in_cond_consts + carry2 + [token2],
      [pred2, token3],
      xla.xla_call_p,
      dict(
          call_jaxpr=cond_with_token.jaxpr,
          name="cond_body",
          donated_invars=(False,) * len(cond_with_token.in_avals)),
      eqn.source_info)
  new_body_eqns = [run_body_eqn, run_cond_eqn]

  new_body_invars = (
      in_cond_consts + in_body_consts + [in_pred] + in_carry + [in_token])
  new_body_outvars = [pred2] + carry2 + [token3]
  new_body_jaxpr = _mk_typed_jaxpr(
      core.Jaxpr([], new_body_invars, new_body_outvars, new_body_eqns), [])

  # The new while carries (pred, carry, token). All of the original consts
  # become body consts because the new cond needs none.
  pred_out = mk_new_var(cond_jaxpr.out_avals[0])
  while_invars = (
      eqn.invars[0:total_nconsts] + [pred1_and_token1[0]] + carry_invars +
      [pred1_and_token1[1]])
  while_outvars = [pred_out] + eqn.outvars + [output_token_var]
  eqns.append(
      core.new_jaxpr_eqn(
          while_invars,
          while_outvars,
          lax.while_p,
          dict(
              cond_jaxpr=new_cond_jaxpr,
              cond_nconsts=0,
              body_jaxpr=new_body_jaxpr,
              body_nconsts=cond_nconsts + body_nconsts),
          eqn.source_info))
def _rewrite_eqn(eqn: core.JaxprEqn, eqns: List[core.JaxprEqn],
                 input_token_var: core.Var, output_token_var: core.Var):
  """Rewrite an `eqn` and append equations to `eqns`.

  Assume that the current token is in `input_token_var` and the resulting
  token must end in `output_token_var`.

  Args:
    eqn: the equation to rewrite so that it threads the token.
    eqns: the output list; the rewritten equation is appended to it.
    input_token_var: the variable holding the token before `eqn`.
    output_token_var: the variable that must hold the token after `eqn`.

  Raises:
    NotImplementedError: if the primitive of `eqn` is not handled, or if the
      predicate of a while loop uses outfeed.
  """
  if eqn.primitive is id_tap_p:
    eqns.append(
        core.new_jaxpr_eqn(eqn.invars + [input_token_var],
                           eqn.outvars + [output_token_var], eqn.primitive,
                           eqn.params))
  elif eqn.primitive is lax.while_p:
    cond_jaxpr, cond_nconsts, body_jaxpr, body_nconsts = util.split_dict(
        eqn.params,
        ["cond_jaxpr", "cond_nconsts", "body_jaxpr", "body_nconsts"])
    if xla.jaxpr_uses_outfeed(cond_jaxpr.jaxpr):
      # TODO(necula): implement tapping from the conditional of a while
      raise NotImplementedError(
          "outfeed not supported in the conditional of a while")

    eqns.append(
        core.new_jaxpr_eqn(
            eqn.invars + [input_token_var], eqn.outvars + [output_token_var],
            eqn.primitive,
            dict(eqn.params,
                 body_jaxpr=_rewrite_typed_jaxpr(body_jaxpr, True, True)[0],
                 cond_jaxpr=_rewrite_typed_jaxpr(cond_jaxpr, True,
                                                 False)[0])))
  elif eqn.primitive is lax.cond_p:
    true_jaxpr, false_jaxpr, linear = util.split_dict(
        eqn.params, ["true_jaxpr", "false_jaxpr", "linear"])
    nr_true_invars = len(true_jaxpr.jaxpr.invars)
    pred, true_invars, false_invars = util.split_list(
        eqn.invars, [1, nr_true_invars])
    # Each branch receives the token as its last operand.
    new_invars = pred + true_invars + [input_token_var] + false_invars + [
        input_token_var
    ]
    # Keep each linear flag aligned with its operand: one token flag goes
    # after the true operands and one after the false operands. This assumes
    # `linear` has one flag per non-predicate operand, which the original
    # length arithmetic (two flags added for two tokens) implies.
    new_linear = (linear[:nr_true_invars] + (False,) +
                  linear[nr_true_invars:] + (False,))
    eqns.append(
        core.new_jaxpr_eqn(
            new_invars, eqn.outvars + [output_token_var], eqn.primitive,
            dict(eqn.params,
                 true_jaxpr=_rewrite_typed_jaxpr(true_jaxpr, True, True)[0],
                 false_jaxpr=_rewrite_typed_jaxpr(false_jaxpr, True, True)[0],
                 linear=new_linear)))
  elif eqn.primitive is lax.scan_p:
    num_consts, num_carry, carry_jaxpr, linear, _, _ = util.split_dict(
        eqn.params,
        ["num_consts", "num_carry", "jaxpr", "linear", "reverse", "length"])
    # We add the token right at the end of carry
    nr_const_and_carry = num_consts + num_carry
    new_invars = eqn.invars[0:nr_const_and_carry] + [
        input_token_var
    ] + eqn.invars[nr_const_and_carry:]
    new_jaxpr = _rewrite_typed_jaxpr(carry_jaxpr, True, True)[0]
    # The rewrite has put the token at end, it has to be at end of carry
    new_jaxpr_invars = new_jaxpr.jaxpr.invars
    new_jaxpr_invars = (new_jaxpr_invars[0:nr_const_and_carry] +
                        [new_jaxpr_invars[-1]] +
                        new_jaxpr_invars[nr_const_and_carry:-1])
    new_jaxpr.jaxpr.invars = new_jaxpr_invars
    new_jaxpr.in_avals = [v.aval for v in new_jaxpr_invars]

    new_jaxpr_outvars = new_jaxpr.jaxpr.outvars
    new_jaxpr_outvars = (new_jaxpr_outvars[0:num_carry] +
                         [new_jaxpr_outvars[-1]] +
                         new_jaxpr_outvars[num_carry:-1])
    new_jaxpr.jaxpr.outvars = new_jaxpr_outvars
    new_jaxpr.out_avals = [v.aval for v in new_jaxpr_outvars]
    # `linear` has one flag per operand, in operand order. The token sits at
    # the end of the carry, so its flag must be inserted at that position;
    # appending it at the end would shift the flags of all xs operands.
    new_linear = (linear[:nr_const_and_carry] + (False,) +
                  linear[nr_const_and_carry:])
    eqns.append(
        core.new_jaxpr_eqn(
            new_invars,
            # Output token is at the end of carry result
            eqn.outvars[0:num_carry] + [output_token_var] +
            eqn.outvars[num_carry:],
            eqn.primitive,
            dict(eqn.params,
                 jaxpr=new_jaxpr,
                 num_carry=num_carry + 1,
                 linear=new_linear)))
  elif eqn.primitive is xla.xla_call_p:
    call_jaxpr = eqn.params["call_jaxpr"]
    eqns.append(
        core.new_jaxpr_eqn(
            eqn.invars + [input_token_var], eqn.outvars + [output_token_var],
            eqn.primitive,
            dict(eqn.params,
                 call_jaxpr=_rewrite_jaxpr(call_jaxpr, True, True)[0])))
  else:
    raise NotImplementedError(f"outfeed rewrite {eqn.primitive}")