def _revise_cond_jaxpr(new_pval, old_pval, jaxpr, consts): new_pv, new_const = new_pval old_pv, old_const = old_pval if new_pv == old_pv: # we didn't move up the lattice by joining with the other side return jaxpr, consts elif old_pv is None: # we moved up the lattice from totally-known, so make a new jaxpr that # returns a single constant JaxTuple with elements that are constants # drawn from consts where new_pv is unknown assert not jaxpr.eqns and not consts outvar = pe.Var(0, "_cond") new_jaxpr = jaxpr.copy() new_jaxpr.constvars = [outvar] new_jaxpr.outvar = outvar new_consts = (core.pack([ core.unit if pv is None else old_c for pv, old_c in zip(new_pv, old_const) ]), ) return new_jaxpr, new_consts else: # we moved up the lattice, but not from totally-constant, so adapt the # japxr to return some new constants in places that are now unknown but # weren't before eqn = jaxpr.eqns[-1] assert eqn.primitive == core.pack_p assert len(eqn.outvars) == 1 and eqn.outvars[0] == jaxpr.outvar newvar = pe.gensym("_cond") new_constvars, new_constvals = unzip2([ (newvar(), c) for new, old, c in zip(new_pv, old_pv, old_const) if old is None and new is not None ]) new_consts = consts + tuple(new_constvals) new_jaxpr = jaxpr.copy() new_jaxpr.constvars = tuple(jaxpr.constvars) + tuple(new_constvars) newvars = iter(new_constvars) new_invars = [ next(newvars) if old is None and new is not None else (core.unitvar if new is None and old is None else v) for new, old, v in zip(new_pv, old_pv, eqn.invars) ] new_jaxpr.eqns = (list(jaxpr.eqns[:-1]) + [_pack_eqn(new_invars, jaxpr.outvar)]) return new_jaxpr, new_consts
out = scan_p.bind( core.unit, carry_ct, core.pack((ct_bs, res)), forward=not forward, length=length, jaxpr=jaxpr_trans) (ct_init, ct_consts), ct_as = out return ct_consts, ct_init, (ct_as, None) def rearrange_binders(f, typed_jaxpr): jaxpr = typed_jaxpr.jaxpr.copy() jaxpr.invars = f(*jaxpr.invars) in_avals = f(*typed_jaxpr.in_avals) core.skip_checks or core.check_jaxpr(jaxpr) return core.TypedJaxpr(jaxpr, typed_jaxpr.literals, in_avals, typed_jaxpr.out_aval) _scan_newvar = pe.gensym('_scan') def _move_stuff_and_add_add(typed_jaxpr): # jaxpr_lifted_trans :: res -> (CT c, CT b) -> (CT d, CT c, CT a) # jaxpr_trans :: * -> (CT c, CT d) -> (CT b, res) -> ((CT c, CT d), CT a) res_aval, (CTc_aval, CTb_aval) = typed_jaxpr.in_avals CTd_aval, CTc_aval2, CTa_aval = typed_jaxpr.out_aval assert CTc_aval == CTc_aval2 in_avals = (core.AbstractTuple(()), core.AbstractTuple((CTc_aval, CTd_aval)), core.AbstractTuple((CTb_aval, res_aval))) out_aval = core.AbstractTuple((core.AbstractTuple((CTc_aval, CTd_aval)), CTa_aval)) jaxpr = typed_jaxpr.jaxpr.copy() # assume the jaxpr isn't restructuring any inputs
def __init__(self):
  """Give this instance its own name generator from `pe.gensym('_scan')`."""
  generator = pe.gensym('_scan')
  self.scan_newvar = generator