示例#1
0
def _while_loop_translation_rule(c, init_val, cond_consts, body_consts,
                                 aval_out, cond_jaxpr, body_jaxpr):
  """Emit a While op whose carry is the tuple (init_val, cond_consts, body_consts).

  Copies of the condition and body jaxprs are rewritten so that each takes
  the packed carry as its only input and has no constvars: a leading
  unpack equation splits the carry into the original input variable and
  the two constant bundles, and a second one splits the relevant bundle
  into the jaxpr's original constvars. The body copy additionally ends with
  a pack equation that bundles its original output with both constant
  bundles into a new output variable.

  Args:
    c: builder object; its Tuple, GetShape, While and GetTupleElement
      methods are called here.
    init_val: first element of the loop carry.
    cond_consts: second element of the loop carry; unpacked into the
      condition jaxpr's constvars.
    body_consts: third element of the loop carry; unpacked into the body
      jaxpr's constvars.
    aval_out: not used by this function.
    cond_jaxpr: condition jaxpr; must have exactly one invar.
    body_jaxpr: body jaxpr; must have exactly one invar.

  Returns:
    Element 0 of the While result.
  """
  carry = c.Tuple(init_val, cond_consts, body_consts)
  carry_shape = c.GetShape(carry)

  carry_in = pe.Var(0, "loop_carry")
  carry_out = pe.Var(0, "loop_carry_out")
  cond_bundle = pe.Var(0, "cond_consts")
  body_bundle = pe.Var(0, "body_consts")

  # Condition: split the carry, then split the condition's own constants.
  assert len(cond_jaxpr.invars) == 1
  cond_wrapped = cond_jaxpr.copy()
  cond_wrapped.constvars = []
  cond_wrapped.invars = [carry_in]
  cond_prologue = [
      _unpack_eqn(carry_in, [cond_jaxpr.invars[0], cond_bundle, body_bundle]),
      _unpack_eqn(cond_bundle, cond_jaxpr.constvars),
  ]
  cond_wrapped.eqns = cond_prologue + list(cond_jaxpr.eqns)

  # Body: same splitting, plus a trailing pack that rebuilds a carry tuple
  # from the body's original output and the two constant bundles.
  assert len(body_jaxpr.invars) == 1
  body_wrapped = body_jaxpr.copy()
  body_wrapped.constvars = []
  body_wrapped.invars = [carry_in]
  body_wrapped.outvar = carry_out
  body_prologue = [
      _unpack_eqn(carry_in, [body_jaxpr.invars[0], cond_bundle, body_bundle]),
      _unpack_eqn(body_bundle, body_jaxpr.constvars),
  ]
  body_eqns = body_prologue + list(body_jaxpr.eqns)
  body_eqns.append(
      _pack_eqn([body_jaxpr.outvar, cond_bundle, body_bundle], carry_out))
  body_wrapped.eqns = body_eqns

  cond_comp = xla.jaxpr_computation(cond_wrapped, (), (), carry_shape)
  body_comp = xla.jaxpr_computation(body_wrapped, (), (), carry_shape)
  loop_result = c.While(cond_comp, body_comp, carry)
  return c.GetTupleElement(loop_result, 0)
示例#2
0
 def make_computation(jaxpr, operand):
   """Build a computation from a single-input jaxpr taking a packed argument.

   A copy of `jaxpr` is rewritten to have no constvars and one input
   variable. Two leading unpack equations split that input into the
   original input variable plus a constants bundle, and then split the
   bundle into the jaxpr's original constvars. The computation is built
   against the shape of `operand`, obtained via `c.GetShape`; `c` is a free
   variable supplied by the enclosing scope.
   """
   assert len(jaxpr.invars) == 1
   packed_arg = pe.Var(0, "arg")
   consts_bundle = pe.Var(0, "consts")
   wrapped = jaxpr.copy()
   wrapped.constvars = []
   wrapped.invars = [packed_arg]
   prologue = [
       _unpack_eqn(packed_arg, [jaxpr.invars[0], consts_bundle]),
       _unpack_eqn(consts_bundle, jaxpr.constvars),
   ]
   wrapped.eqns = prologue + list(jaxpr.eqns)
   operand_shape = c.GetShape(operand)
   return xla.jaxpr_computation(wrapped, (), (), operand_shape)
示例#3
0
def _revise_cond_jaxpr(new_pval, old_pval, jaxpr, consts):
    new_pv, new_const = new_pval
    old_pv, old_const = old_pval
    if new_pv == old_pv:
        # we didn't move up the lattice by joining with the other side
        return jaxpr, consts
    elif old_pv is None:
        # we moved up the lattice from totally-known, so make a new jaxpr that
        # returns a single constant JaxTuple with elements that are constants
        # drawn from consts where new_pv is unknown
        assert not jaxpr.eqns and not consts
        outvar = pe.Var(0, "_cond")
        new_jaxpr = jaxpr.copy()
        new_jaxpr.constvars = [outvar]
        new_jaxpr.outvar = outvar
        new_consts = (core.pack([
            core.unit if pv is None else old_c
            for pv, old_c in zip(new_pv, old_const)
        ]), )
        return new_jaxpr, new_consts
    else:
        # we moved up the lattice, but not from totally-constant, so adapt the
        # japxr to return some new constants in places that are now unknown but
        # weren't before
        eqn = jaxpr.eqns[-1]
        assert eqn.primitive == core.pack_p
        assert len(eqn.outvars) == 1 and eqn.outvars[0] == jaxpr.outvar
        newvar = pe.gensym("_cond")
        new_constvars, new_constvals = unzip2([
            (newvar(), c) for new, old, c in zip(new_pv, old_pv, old_const)
            if old is None and new is not None
        ])
        new_consts = consts + tuple(new_constvals)
        new_jaxpr = jaxpr.copy()
        new_jaxpr.constvars = tuple(jaxpr.constvars) + tuple(new_constvars)
        newvars = iter(new_constvars)
        new_invars = [
            next(newvars) if old is None and new is not None else
            (core.unitvar if new is None and old is None else v)
            for new, old, v in zip(new_pv, old_pv, eqn.invars)
        ]
        new_jaxpr.eqns = (list(jaxpr.eqns[:-1]) +
                          [_pack_eqn(new_invars, jaxpr.outvar)])
        return new_jaxpr, new_consts