def _while_loop_translation_rule(c, axis_env, init_val, cond_consts,
                                 body_consts, aval_out, cond_jaxpr, body_jaxpr):
  """Lower a while loop to ``c.While`` over a packed loop-carry tuple.

  The running value and both constant groups travel together as the tuple
  ``(init_val, cond_consts, body_consts)``. Each single-input jaxpr is
  rewritten to accept that tuple:
    * it first unpacks the tuple into its original input plus the two
      constant slots;
    * it then unpacks its own constant slot into its original ``constvars``.
  The body jaxpr additionally re-packs
  ``(body output, cond_consts, body_consts)`` as its new output. This keeps
  the constants threaded through every iteration unchanged.

  Args:
    c: builder used to emit the tuple, the loop and the final projection
      (presumably an XLA computation builder -- confirm against callers).
    axis_env: forwarded unchanged to ``xla._jaxpr_computation``.
    init_val: builder operand placed in slot 0 of the loop carry.
    cond_consts: builder operand placed in slot 1 of the loop carry.
    body_consts: builder operand placed in slot 2 of the loop carry.
    aval_out: not used by this rule.
    cond_jaxpr: predicate jaxpr; must have exactly one input variable.
    body_jaxpr: loop-body jaxpr; must have exactly one input variable.

  Returns:
    Element 0 of the ``c.While`` result, i.e. the slot that the body
    overwrites with its output on each iteration.
  """
  carry = c.Tuple(init_val, cond_consts, body_consts)
  carry_shape = c.GetShape(carry)

  # Fresh variables naming the packed carry (in and out) and the two
  # constant slots inside it.
  carry_in = pe.Var(0, "loop_carry")
  carry_out = pe.Var(0, "loop_carry_out")
  cond_slot = pe.Var(0, "cond_consts")
  body_slot = pe.Var(0, "body_consts")

  def tupled(jaxpr, own_slot):
    """Return a copy of ``jaxpr`` whose sole input is the packed carry."""
    assert len(jaxpr.invars) == 1
    rewritten = jaxpr.copy()
    rewritten.constvars = []
    rewritten.invars = [carry_in]
    split_carry = _unpack_eqn(carry_in,
                              [jaxpr.invars[0], cond_slot, body_slot])
    # Only this jaxpr's own constant slot is expanded into its constvars.
    split_consts = _unpack_eqn(own_slot, jaxpr.constvars)
    rewritten.eqns = [split_carry, split_consts] + list(jaxpr.eqns)
    return rewritten

  cond_rewritten = tupled(cond_jaxpr, cond_slot)

  body_rewritten = tupled(body_jaxpr, body_slot)
  body_rewritten.outvar = carry_out
  # Re-pack so the next iteration sees the same (value, cond, body) layout.
  body_rewritten.eqns = body_rewritten.eqns + [
      _pack_eqn([body_jaxpr.outvar, cond_slot, body_slot], carry_out)]

  cond_comp = xla._jaxpr_computation(cond_rewritten, axis_env, (), (),
                                     carry_shape)
  body_comp = xla._jaxpr_computation(body_rewritten, axis_env, (), (),
                                     carry_shape)
  loop_result = c.While(cond_comp, body_comp, carry)
  return c.GetTupleElement(loop_result, 0)
def make_computation(jaxpr, operand):
  """Build a computation for ``jaxpr`` that takes one packed tuple argument.

  The returned computation's single argument is unpacked into
  ``(original input, consts)``. ``consts`` is then unpacked into the
  jaxpr's original ``constvars``.

  NOTE(review): ``c`` and ``axis_env`` are not parameters. They are resolved
  from an enclosing scope that is not visible here (presumably an enclosing
  translation rule -- confirm).

  Args:
    jaxpr: jaxpr with exactly one input variable.
    operand: builder operand; only its shape (via ``c.GetShape``) is used, as
      the argument shape of the built computation.

  Returns:
    Whatever ``xla._jaxpr_computation`` returns for the rewritten jaxpr.
  """
  assert len(jaxpr.invars) == 1
  packed_arg = pe.Var(0, "arg")
  packed_consts = pe.Var(0, "consts")

  rewritten = jaxpr.copy()
  rewritten.constvars = []
  rewritten.invars = [packed_arg]
  split_arg = _unpack_eqn(packed_arg, [jaxpr.invars[0], packed_consts])
  split_consts = _unpack_eqn(packed_consts, jaxpr.constvars)
  rewritten.eqns = [split_arg, split_consts] + list(jaxpr.eqns)

  arg_shape = c.GetShape(operand)
  return xla._jaxpr_computation(rewritten, axis_env, (), (), arg_shape)