Esempio n. 1
0
def _while_loop_translation_rule(c, init_val, cond_consts, body_consts,
                                 aval_out, cond_jaxpr, body_jaxpr):
  """Lower a while loop onto an XLA ``While`` op.

  XLA threads a single carry value through both the cond and the body
  computations. This rule packs ``(init_val, cond_consts, body_consts)`` into
  one tuple and rewrites copies of both jaxprs to accept that tuple. Each
  rewritten jaxpr unpacks the tuple into its original input variable and its
  own constvars. The rewritten body re-packs its result alongside the
  unchanged const slots, so the carry keeps the same layout.

  Args:
    c: computation builder that receives the ``While`` op.
    init_val: initial loop value; slot 0 of the carry tuple.
    cond_consts: slot 1 of the carry. It is unpacked into
      ``cond_jaxpr.constvars``, so it is presumably a tuple op.
    body_consts: slot 2 of the carry. It is unpacked into
      ``body_jaxpr.constvars``, so it is presumably a tuple op.
    aval_out: not used by this rule.
    cond_jaxpr: predicate jaxpr; must have exactly one invar.
    body_jaxpr: step jaxpr; must have exactly one invar.

  Returns:
    Element 0 (the loop value) of the final carry tuple.

  Raises:
    AssertionError: if either jaxpr does not have exactly one invar.
  """
  carry = c.Tuple(init_val, cond_consts, body_consts)
  carry_shape = c.GetShape(carry)

  carry_in = pe.Var(0, "loop_carry")
  carry_out = pe.Var(0, "loop_carry_out")
  cond_consts_var = pe.Var(0, "cond_consts")
  body_consts_var = pe.Var(0, "body_consts")

  def retarget(jaxpr, consts_var):
    # Copy `jaxpr` so that its only input is the carry tuple. Also return the
    # equations that split the tuple back into the original invar and this
    # jaxpr's own constvars.
    assert len(jaxpr.invars) == 1
    retargeted = jaxpr.copy()
    retargeted.constvars = []
    retargeted.invars = [carry_in]
    prologue = [
        _unpack_eqn(carry_in,
                    [jaxpr.invars[0], cond_consts_var, body_consts_var]),
        _unpack_eqn(consts_var, jaxpr.constvars),
    ]
    return retargeted, prologue

  cond_fn, cond_prologue = retarget(cond_jaxpr, cond_consts_var)
  cond_fn.eqns = cond_prologue + list(cond_jaxpr.eqns)

  body_fn, body_prologue = retarget(body_jaxpr, body_consts_var)
  body_fn.outvar = carry_out
  # Re-pack the new loop value next to the untouched const slots, so the body
  # returns a tuple laid out exactly like the one it received.
  epilogue = [_pack_eqn([body_jaxpr.outvar, cond_consts_var, body_consts_var],
                        carry_out)]
  body_fn.eqns = body_prologue + list(body_jaxpr.eqns) + epilogue

  cond_comp = xla.jaxpr_computation(cond_fn, (), (), carry_shape)
  body_comp = xla.jaxpr_computation(body_fn, (), (), carry_shape)
  final_carry = c.While(cond_comp, body_comp, carry)
  return c.GetTupleElement(final_carry, 0)
Esempio n. 2
0
def _while_loop_translation_rule(c, axis_env, *args, **kwargs):
    """Lower the while-loop primitive to an XLA ``While`` op.

    ``args`` is the flat operand list ``cond_consts + body_consts + init_vals``.
    It is split using the ``cond_nconsts`` / ``body_nconsts`` entries of
    ``kwargs``. ``kwargs`` must also hold ``cond_jaxpr`` and ``body_jaxpr``.
    Both are accessed through ``.jaxpr`` and ``.literals``, and the cond is
    also accessed through ``.out_avals``. They are presumably closed-jaxpr
    objects (confirm against the primitive's binder).

    When the cond's first output has a non-scalar shape (``batched``), the
    predicate is OR-reduced to a scalar to drive the loop. Inside the body,
    new and old carry values are then combined under the per-element
    predicate.

    Args:
      c: computation builder that receives the ``While`` op.
      axis_env: forwarded unchanged to ``xla.jaxpr_computation``.
      *args: flat operands, laid out as described above.
      **kwargs: must contain ``cond_jaxpr``, ``body_jaxpr``,
        ``cond_nconsts`` and ``body_nconsts``.

    Returns:
      A tuple op holding the final loop-carried values (the ``init_vals``
      slots of the carry). The constants are dropped.

    Raises:
      AssertionError: in the batched case, if combining new and old carry
        values changes any element's shape.
    """
    cond_jaxpr, body_jaxpr, cond_nconsts, body_nconsts = split_dict(
        kwargs, ["cond_jaxpr", "body_jaxpr", "cond_nconsts", "body_nconsts"])
    cond_consts, body_consts, init_vals = split_list(
        args, [cond_nconsts, body_nconsts])
    # A non-scalar predicate output means the loop was batched (e.g. by vmap).
    batched = bool(cond_jaxpr.out_avals[0].shape)

    # Since jaxprs don't have tuples and have multiple return values, but we need
    # the HLO While loop to take a single tuple input and output a single boolean
    # (for the cond computation) or a single tuple output (for the body
    # computation), we build XLA computations that handle the tuple munging before
    # generating a Call into the computations formed from the jaxprs.

    # Carry layout: (cond consts..., body consts..., loop values...).
    init_carry = c.Tuple(*(cond_consts + body_consts + init_vals))

    cond_c = xb.make_computation_builder("cond_computation")
    cond_carry = cond_c.ParameterWithShape(c.GetShape(init_carry))
    cond_carry_elts = [
        cond_c.GetTupleElement(cond_carry, i) for i in range(len(args))
    ]
    # The cond sees only its own consts (x) and the loop values (z); the body
    # consts are ignored here.
    x, _, z = split_list(cond_carry_elts, [cond_nconsts, body_nconsts])
    cond_outs = cond_c.Call(
        xla.jaxpr_computation(cond_jaxpr.jaxpr, axis_env, cond_jaxpr.literals,
                              (), *_map(cond_c.GetShape, x + z)), x + z)
    pred = cond_c.GetTupleElement(cond_outs, 0)
    if batched:
        # OR-reduce over every dimension (initial value False), so the loop
        # keeps running while any element's predicate is true.
        scalar = xla_client.Shape.array_shape(onp.dtype(onp.bool_), ())
        or_ = xla.primitive_computation(lax.or_p, scalar, scalar)
        pred = cond_c.Reduce(pred, cond_c.Constant(onp.array(False)), or_,
                             list(range(cond_jaxpr.out_avals[0].ndim)))

    body_c = xb.make_computation_builder("body_computation")
    body_carry = body_c.ParameterWithShape(c.GetShape(init_carry))
    body_carry_elts = [
        body_c.GetTupleElement(body_carry, i) for i in range(len(args))
    ]
    # x: cond consts, y: body consts, z: current loop values.
    x, y, z = split_list(body_carry_elts, [cond_nconsts, body_nconsts])
    body_out = body_c.Call(
        xla.jaxpr_computation(body_jaxpr.jaxpr, axis_env, body_jaxpr.literals,
                              (), *_map(body_c.GetShape, y + z)), y + z)
    new_z = [
        body_c.GetTupleElement(body_out, i) for i in range(len(init_vals))
    ]
    if batched:
        # Re-evaluate the cond on the incoming loop values (x + z). Combine the
        # new and old values under that per-element predicate with
        # `_pred_bcast_select`. Presumably this keeps elements whose predicate
        # is already false unchanged; confirm against _pred_bcast_select.
        body_cond_outs = body_c.Call(
            xla.jaxpr_computation(cond_jaxpr.jaxpr, axis_env,
                                  cond_jaxpr.literals, (),
                                  *_map(body_c.GetShape, x + z)), x + z)
        body_pred = body_c.GetTupleElement(body_cond_outs, 0)
        new_z = _map(partial(_pred_bcast_select, body_c, body_pred), new_z, z)
        assert _map(body_c.GetShape, new_z) == _map(body_c.GetShape,
                                                    z)  # no broadcast
    # Constants (x, y) are passed through unchanged; only the loop values update.
    new_carry = body_c.Tuple(*(x + y + new_z))

    ans = c.While(cond_c.Build(pred), body_c.Build(new_carry), init_carry)
    ans_elts = [c.GetTupleElement(ans, i) for i in range(len(args))]
    # Drop both const groups and return only the final loop values.
    _, _, z = split_list(ans_elts, [cond_nconsts, body_nconsts])
    return c.Tuple(*z)
Esempio n. 3
0
 def make_computation(name, jaxpr, op_shape):
     """Build an XLA computation called `name` that evaluates `jaxpr`.

     The computation takes one tuple parameter of shape `op_shape`. It pulls
     out one element per entry of `jaxpr.in_avals` and calls the lowered
     `jaxpr.jaxpr` (with `jaxpr.literals`) on those elements. `axis_env` is
     not a parameter; it is read from the enclosing scope.
     """
     builder = xb.make_computation_builder(name)
     packed = builder.ParameterWithShape(op_shape)
     n_in = len(jaxpr.in_avals)
     args = [builder.GetTupleElement(packed, k) for k in range(n_in)]
     arg_shapes = _map(builder.GetShape, args)
     lowered = xla.jaxpr_computation(jaxpr.jaxpr, axis_env, jaxpr.literals, (),
                                     *arg_shapes)
     result = builder.Call(lowered, args)
     return builder.Build(result)
Esempio n. 4
0
 def make_computation(jaxpr, operand):
   """Lower `jaxpr` to an XLA computation that takes a single tuple input.

   The input tuple is ``(arg, consts)``. The first element feeds the jaxpr's
   only invar, and the second is unpacked into the jaxpr's constvars. The
   parameter shape comes from `operand`, via the builder `c` taken from the
   enclosing scope.

   Raises:
     AssertionError: if `jaxpr` does not have exactly one invar.
   """
   assert len(jaxpr.invars) == 1
   packed_arg = pe.Var(0, "arg")
   packed_consts = pe.Var(0, "consts")
   rewritten = jaxpr.copy()
   rewritten.constvars = []
   rewritten.invars = [packed_arg]
   # Split the tuple input first, then expand the consts tuple, and only then
   # run the original equations.
   prologue = [
       _unpack_eqn(packed_arg, [jaxpr.invars[0], packed_consts]),
       _unpack_eqn(packed_consts, jaxpr.constvars),
   ]
   rewritten.eqns = prologue + list(jaxpr.eqns)
   return xla.jaxpr_computation(rewritten, (), (), c.GetShape(operand))
Esempio n. 5
0
def xlogy_translate(c, x, y, jaxpr, aval, consts):
    """Translation rule: call the lowered `jaxpr` on operands `x` and `y`.

    `consts` is forwarded as the second argument of `xla.jaxpr_computation`,
    presumably the jaxpr's constant values. `aval` is not used.
    """
    operand_shapes = (c.GetShape(x), c.GetShape(y))
    lowered = xla.jaxpr_computation(jaxpr, consts, (), *operand_shapes)
    return c.Call(lowered, (x, y))
Esempio n. 6
0
def _standard_gamma_translate(c, key, alpha, jaxpr, aval, consts):
    """Translation rule: call the lowered `jaxpr` on operands `key` and `alpha`.

    `consts` is forwarded as the second argument of `xla.jaxpr_computation`,
    presumably the jaxpr's constant values. `aval` is not used.
    """
    operands = (key, alpha)
    key_shape = c.GetShape(key)
    alpha_shape = c.GetShape(alpha)
    lowered = xla.jaxpr_computation(jaxpr, consts, (), key_shape, alpha_shape)
    return c.Call(lowered, operands)