def _while_loop_translation_rule(c, init_val, cond_consts, body_consts,
                                 aval_out, cond_jaxpr, body_jaxpr):
  """Lower a while loop to ``c.While`` over a single packed loop carry.

  ``c.While`` is given one initial value, so the carry is built as the 3-tuple
  ``(init_val, cond_consts, body_consts)``. Both ``cond_jaxpr`` and
  ``body_jaxpr`` are rewritten so that this tuple is their only input:
  - each rewritten jaxpr first unpacks the tuple;
  - it then unpacks its own constants component into its original constvars;
  - those constvars are cleared on the rewritten copy.

  The body also re-packs its result together with the two unchanged constants
  components. This keeps the carry's structure identical on every iteration.

  ``aval_out`` is accepted for interface compatibility but is not used here.

  Returns:
    Element 0 of the final carry, i.e. the loop value after the last iteration.
  """
  carry = c.Tuple(init_val, cond_consts, body_consts)
  carry_shape = c.GetShape(carry)

  carry_in = pe.Var(0, "loop_carry")
  carry_out = pe.Var(0, "loop_carry_out")
  cond_consts_var = pe.Var(0, "cond_consts")
  body_consts_var = pe.Var(0, "body_consts")

  def _take_carry(jaxpr, own_consts_var):
    # Return a copy of `jaxpr` whose sole input is the whole carry tuple and
    # which binds its former constvars from `own_consts_var`.
    assert len(jaxpr.invars) == 1
    converted = jaxpr.copy()
    converted.constvars = []
    converted.invars = [carry_in]
    prologue = [
        _unpack_eqn(carry_in,
                    [jaxpr.invars[0], cond_consts_var, body_consts_var]),
        _unpack_eqn(own_consts_var, jaxpr.constvars),
    ]
    converted.eqns = prologue + list(jaxpr.eqns)
    return converted

  cond_converted = _take_carry(cond_jaxpr, cond_consts_var)

  body_converted = _take_carry(body_jaxpr, body_consts_var)
  # The body's output must rebuild the full carry: new loop value plus the
  # untouched constants components.
  body_converted.outvar = carry_out
  body_converted.eqns = body_converted.eqns + [
      _pack_eqn([body_jaxpr.outvar, cond_consts_var, body_consts_var],
                carry_out)]

  cond_computation = xla.jaxpr_computation(cond_converted, (), (), carry_shape)
  body_computation = xla.jaxpr_computation(body_converted, (), (), carry_shape)
  final_carry = c.While(cond_computation, body_computation, carry)
  return c.GetTupleElement(final_carry, 0)
def make_computation(jaxpr, operand):
  """Build an XLA computation from a one-input jaxpr over a packed operand.

  The computation takes a single argument shaped like ``operand``. That
  argument is unpacked as ``(arg, consts)``:
  - ``arg`` binds the jaxpr's only input variable;
  - ``consts`` is further unpacked into the jaxpr's constvars, which are
    cleared on the rewritten copy.

  NOTE(review): ``c`` is not a parameter here and is not defined in this
  function; it must be resolved from an enclosing scope (presumably the XLA
  builder of a surrounding translation rule). Confirm this function is nested
  where ``c`` is in scope.
  """
  assert len(jaxpr.invars) == 1
  packed_arg = pe.Var(0, "arg")
  packed_consts = pe.Var(0, "consts")

  converted = jaxpr.copy()
  converted.constvars = []
  converted.invars = [packed_arg]
  prologue = [
      _unpack_eqn(packed_arg, [jaxpr.invars[0], packed_consts]),
      _unpack_eqn(packed_consts, jaxpr.constvars),
  ]
  converted.eqns = prologue + list(jaxpr.eqns)

  operand_shape = c.GetShape(operand)
  return xla.jaxpr_computation(converted, (), (), operand_shape)
def _revise_cond_jaxpr(new_pval, old_pval, jaxpr, consts): new_pv, new_const = new_pval old_pv, old_const = old_pval if new_pv == old_pv: # we didn't move up the lattice by joining with the other side return jaxpr, consts elif old_pv is None: # we moved up the lattice from totally-known, so make a new jaxpr that # returns a single constant JaxTuple with elements that are constants # drawn from consts where new_pv is unknown assert not jaxpr.eqns and not consts outvar = pe.Var(0, "_cond") new_jaxpr = jaxpr.copy() new_jaxpr.constvars = [outvar] new_jaxpr.outvar = outvar new_consts = (core.pack([ core.unit if pv is None else old_c for pv, old_c in zip(new_pv, old_const) ]), ) return new_jaxpr, new_consts else: # we moved up the lattice, but not from totally-constant, so adapt the # japxr to return some new constants in places that are now unknown but # weren't before eqn = jaxpr.eqns[-1] assert eqn.primitive == core.pack_p assert len(eqn.outvars) == 1 and eqn.outvars[0] == jaxpr.outvar newvar = pe.gensym("_cond") new_constvars, new_constvals = unzip2([ (newvar(), c) for new, old, c in zip(new_pv, old_pv, old_const) if old is None and new is not None ]) new_consts = consts + tuple(new_constvals) new_jaxpr = jaxpr.copy() new_jaxpr.constvars = tuple(jaxpr.constvars) + tuple(new_constvars) newvars = iter(new_constvars) new_invars = [ next(newvars) if old is None and new is not None else (core.unitvar if new is None and old is None else v) for new, old, v in zip(new_pv, old_pv, eqn.invars) ] new_jaxpr.eqns = (list(jaxpr.eqns[:-1]) + [_pack_eqn(new_invars, jaxpr.outvar)]) return new_jaxpr, new_consts