Пример #1
0
def _while_loop_translation_rule(c, axis_env, init_val, cond_consts,
                                 body_consts, aval_out, cond_jaxpr, body_jaxpr):
  """Lower a while loop to a ``c.While`` operation.

  ``c.While`` receives a single initial value, so the loop value, the
  condition's constants and the body's constants are packed into one tuple
  ``(init_val, cond_consts, body_consts)``. Both jaxprs are rewritten to take
  that tuple as their only input and to unpack it before running their
  original equations. The body additionally re-packs its result together with
  the same constant variables, so its output has the input's tuple layout.

  Args:
    c: builder the loop is emitted into.
    axis_env: forwarded unchanged to ``xla._jaxpr_computation``.
    init_val: initial loop value (an operand accepted by ``c.Tuple``).
    cond_consts: constants for ``cond_jaxpr`` (packed via ``c.Tuple``).
    body_consts: constants for ``body_jaxpr`` (packed via ``c.Tuple``).
    aval_out: not used by this function (presumably kept so the signature
      matches other translation rules -- confirm).
    cond_jaxpr: condition jaxpr; must have exactly one input variable.
    body_jaxpr: body jaxpr; must have exactly one input variable.

  Returns:
    Element 0 of the ``c.While`` result, i.e. the final loop value.
  """
  # One tuple carries the loop value and both constant sets through the loop.
  carry = c.Tuple(init_val, cond_consts, body_consts)
  carry_shape = c.GetShape(carry)

  carry_in = pe.Var(0, "loop_carry")
  carry_out = pe.Var(0, "loop_carry_out")
  cond_consts_var = pe.Var(0, "cond_consts")
  body_consts_var = pe.Var(0, "body_consts")

  def rewrite_for_carry(jaxpr, own_consts_var):
    # Return a copy of `jaxpr` whose sole input is the packed carry tuple.
    # Its prologue splits the tuple into (loop value, cond consts, body
    # consts), then splits `own_consts_var` into the jaxpr's constvars.
    assert len(jaxpr.invars) == 1
    rewritten = jaxpr.copy()
    rewritten.constvars = []
    rewritten.invars = [carry_in]
    prologue = [
        _unpack_eqn(carry_in,
                    [jaxpr.invars[0], cond_consts_var, body_consts_var]),
        _unpack_eqn(own_consts_var, jaxpr.constvars),
    ]
    rewritten.eqns = prologue + list(jaxpr.eqns)
    return rewritten

  cond_rewritten = rewrite_for_carry(cond_jaxpr, cond_consts_var)

  body_rewritten = rewrite_for_carry(body_jaxpr, body_consts_var)
  body_rewritten.outvar = carry_out
  # Re-pack the body's result with the same constant variables so the body
  # output has the same tuple structure as its input.
  body_rewritten.eqns.append(
      _pack_eqn([body_jaxpr.outvar, cond_consts_var, body_consts_var],
                carry_out))

  cond_c = xla._jaxpr_computation(
      cond_rewritten, axis_env, (), (), carry_shape)
  body_c = xla._jaxpr_computation(
      body_rewritten, axis_env, (), (), carry_shape)
  final_carry = c.While(cond_c, body_c, carry)
  # Keep only the loop value; the constants are dropped.
  return c.GetTupleElement(final_carry, 0)
Пример #2
0
 def make_computation(jaxpr, operand):
   """Build a computation for `jaxpr` that takes one packed input.

   The rewritten jaxpr's only input is a pair ``(arg, consts)``. Its first
   equation splits that pair, and its second splits ``consts`` into the
   jaxpr's constvars; then the original equations run. ``c.GetShape(operand)``
   is passed to ``xla._jaxpr_computation`` (presumably as the input shape).
   Uses ``c`` and ``axis_env`` from the enclosing scope.
   """
   assert len(jaxpr.invars) == 1
   packed_in = pe.Var(0, "arg")
   packed_consts = pe.Var(0, "consts")
   rewritten = jaxpr.copy()
   rewritten.constvars = []
   rewritten.invars = [packed_in]
   prologue = [
       _unpack_eqn(packed_in, [jaxpr.invars[0], packed_consts]),
       _unpack_eqn(packed_consts, jaxpr.constvars),
   ]
   rewritten.eqns = prologue + list(jaxpr.eqns)
   input_shape = c.GetShape(operand)
   return xla._jaxpr_computation(rewritten, axis_env, (), (), input_shape)