예제 #1
0
def _revise_cond_jaxpr(new_pval, old_pval, jaxpr, consts):
    """Adapt one branch jaxpr of a ``cond`` to a joined partial value.

    ``old_pval`` is the ``(pv, const)`` pair the branch was traced with and
    ``new_pval`` is the pair obtained after joining with the other branch.
    A ``pv`` of ``None`` means "known". Positions that were known under
    ``old_pv`` but are unknown under ``new_pv`` must now be produced by the
    jaxpr, so their old known constants are threaded in as extra constants.

    Args:
      new_pval: ``(pv, const)`` pair after the join; only ``pv`` is read.
      old_pval: ``(pv, const)`` pair the branch was traced with.
      jaxpr: the branch jaxpr. When a revision is needed, edits are made on
        ``jaxpr.copy()`` rather than by reassigning attributes of ``jaxpr``.
      consts: constant values for ``jaxpr.constvars`` (concatenated with a
        tuple below, so expected to be a tuple).

    Returns:
      ``(jaxpr, consts)`` unchanged when ``new_pv == old_pv``; otherwise a
      revised jaxpr copy and its matching constants.
    """
    new_pv, _ = new_pval
    old_pv, old_const = old_pval

    # Joining with the other branch did not move us up the lattice.
    if new_pv == old_pv:
        return jaxpr, consts

    if old_pv is None:
        # The branch was totally known, so it has no equations and no
        # constants. Build a jaxpr whose output is a single constant tuple:
        # the old known value wherever the joined pv is unknown, and a unit
        # placeholder wherever it is still known.
        assert not jaxpr.eqns and not consts
        tuple_var = pe.Var(0, "_cond")
        revised = jaxpr.copy()
        revised.constvars = [tuple_var]
        revised.outvar = tuple_var
        packed = core.pack([old_c if pv is not None else core.unit
                            for pv, old_c in zip(new_pv, old_const)])
        return revised, (packed,)

    # The branch was partially known: its last equation packs the output
    # tuple. Rewire that pack so that positions which went from known to
    # unknown read from fresh constvars holding the old known constants.
    last_eqn = jaxpr.eqns[-1]
    assert last_eqn.primitive == core.pack_p
    assert len(last_eqn.outvars) == 1 and last_eqn.outvars[0] == jaxpr.outvar

    fresh_var = pe.gensym("_cond")
    added_vars, added_vals = [], []
    for new, old, c in zip(new_pv, old_pv, old_const):
        if old is None and new is not None:
            added_vars.append(fresh_var())
            added_vals.append(c)

    revised_consts = consts + tuple(added_vals)
    revised = jaxpr.copy()
    revised.constvars = tuple(jaxpr.constvars) + tuple(added_vars)

    pending = iter(added_vars)
    pack_inputs = []
    for new, old, v in zip(new_pv, old_pv, last_eqn.invars):
        if old is not None:
            # Already unknown before the join: keep the traced input.
            pack_inputs.append(v)
        elif new is None:
            # Still known after the join: the jaxpr outputs a unit here.
            pack_inputs.append(core.unitvar)
        else:
            # Newly unknown: read the value from a fresh constvar.
            pack_inputs.append(next(pending))

    revised.eqns = (list(jaxpr.eqns[:-1]) +
                    [_pack_eqn(pack_inputs, jaxpr.outvar)])
    return revised, revised_consts
예제 #2
0
  out = scan_p.bind(
      core.unit, carry_ct, core.pack((ct_bs, res)),
      forward=not forward, length=length, jaxpr=jaxpr_trans)
  (ct_init, ct_consts), ct_as = out
  return ct_consts, ct_init, (ct_as, None)

def rearrange_binders(f, typed_jaxpr):
  """Return a TypedJaxpr whose input binders have been rearranged by ``f``.

  ``f`` is called twice, with positional arguments each time:
  - with the jaxpr's ``invars``, to get the new ``invars``;
  - with ``typed_jaxpr.in_avals``, to get the new input avals.

  The same ``f`` is applied to both, so variables and avals stay aligned.
  The edited jaxpr is a copy, and it is validated with ``core.check_jaxpr``
  unless ``core.skip_checks`` is set. Literals and ``out_aval`` are carried
  over unchanged.
  """
  permuted = typed_jaxpr.jaxpr.copy()
  permuted.invars = f(*permuted.invars)
  permuted_avals = f(*typed_jaxpr.in_avals)
  if not core.skip_checks:
    core.check_jaxpr(permuted)
  return core.TypedJaxpr(permuted, typed_jaxpr.literals, permuted_avals,
                         typed_jaxpr.out_aval)

# Module-level factory for fresh "_scan" variables. `pe.gensym` returns a
# callable that is invoked to mint new jaxpr variables.
_scan_newvar = pe.gensym("_scan")

def _move_stuff_and_add_add(typed_jaxpr):
  # jaxpr_lifted_trans :: res -> (CT c, CT b) -> (CT d, CT c, CT a)
  # jaxpr_trans :: * -> (CT c, CT d) -> (CT b, res) -> ((CT c, CT d), CT a)

  res_aval, (CTc_aval, CTb_aval) = typed_jaxpr.in_avals
  CTd_aval, CTc_aval2, CTa_aval = typed_jaxpr.out_aval
  assert CTc_aval == CTc_aval2
  in_avals = (core.AbstractTuple(()), core.AbstractTuple((CTc_aval, CTd_aval)),
              core.AbstractTuple((CTb_aval, res_aval)))
  out_aval = core.AbstractTuple((core.AbstractTuple((CTc_aval, CTd_aval)),
                                 CTa_aval))

  jaxpr = typed_jaxpr.jaxpr.copy()
  # assume the jaxpr isn't restructuring any inputs
예제 #3
0
 def __init__(self):
     """Set up this instance's fresh-variable factory for scan jaxprs.

     Stores the result of ``pe.gensym("_scan")`` on ``self.scan_newvar``;
     that callable is presumably invoked to mint new ``_scan`` variables.
     """
     self.scan_newvar = pe.gensym("_scan")