def _while_loop_translation_rule(c, init_val, cond_consts, body_consts,
                                 aval_out, cond_jaxpr, body_jaxpr):
  """Translate a while loop into one XLA ``While`` over a 3-tuple carry.

  The loop carry is the tuple ``(init_val, cond_consts, body_consts)``. Copies
  of both jaxprs are rewritten so that they take a single carry variable as
  their only input, have no constvars, and begin with ``_unpack_eqn``
  equations for the carry and for their own consts. The body copy also ends
  with a ``_pack_eqn`` equation that combines the original body output with
  the two consts variables into the new carry variable.

  Args:
    c: computation builder; used here for ``Tuple``, ``GetShape``, ``While``
      and ``GetTupleElement``.
    init_val: initial loop value, stored as element 0 of the carry.
    cond_consts: stored as element 1 of the carry.
    body_consts: stored as element 2 of the carry.
    aval_out: not used by this function.
    cond_jaxpr: jaxpr for the loop condition; must have exactly one invar.
    body_jaxpr: jaxpr for the loop body; must have exactly one invar.

  Returns:
    Element 0 of the ``While`` result.
  """
  carry = c.Tuple(init_val, cond_consts, body_consts)
  carry_shape = c.GetShape(carry)

  # Fresh variables shared by both rewritten jaxprs.
  carry_in = pe.Var(0, "loop_carry")
  carry_out = pe.Var(0, "loop_carry_out")
  cond_consts_var = pe.Var(0, "cond_consts")
  body_consts_var = pe.Var(0, "body_consts")

  # Cond: the single carry input is unpacked before the original equations.
  assert len(cond_jaxpr.invars) == 1
  tupled_cond = cond_jaxpr.copy()
  tupled_cond.constvars = []
  tupled_cond.invars = [carry_in]
  cond_prologue = [
      _unpack_eqn(carry_in,
                  [cond_jaxpr.invars[0], cond_consts_var, body_consts_var]),
      _unpack_eqn(cond_consts_var, cond_jaxpr.constvars),
  ]
  tupled_cond.eqns = cond_prologue + list(cond_jaxpr.eqns)

  # Body: same prologue shape, plus a trailing equation producing the carry.
  assert len(body_jaxpr.invars) == 1
  tupled_body = body_jaxpr.copy()
  tupled_body.constvars = []
  tupled_body.invars = [carry_in]
  tupled_body.outvar = carry_out
  body_prologue = [
      _unpack_eqn(carry_in,
                  [body_jaxpr.invars[0], cond_consts_var, body_consts_var]),
      _unpack_eqn(body_consts_var, body_jaxpr.constvars),
  ]
  tupled_body.eqns = body_prologue + list(body_jaxpr.eqns) + [
      _pack_eqn([body_jaxpr.outvar, cond_consts_var, body_consts_var],
                carry_out),
  ]

  cond_comp = xla.jaxpr_computation(tupled_cond, (), (), carry_shape)
  body_comp = xla.jaxpr_computation(tupled_body, (), (), carry_shape)
  loop_result = c.While(cond_comp, body_comp, carry)
  return c.GetTupleElement(loop_result, 0)
def _while_loop_translation_rule(c, axis_env, *args, **kwargs):
  """Emit an XLA ``While`` for a loop described by cond and body jaxprs.

  ``args`` is split as cond consts, then body consts, then the initial loop
  values; ``cond_nconsts`` and ``body_nconsts`` give the first two lengths.
  All of ``args`` are carried through the loop as one flat tuple, and only the
  trailing loop values are returned.

  The loop is treated as batched when the first output aval of ``cond_jaxpr``
  has a non-empty shape. In that case the cond predicate is OR-reduced over
  all of its dimensions, and the body evaluates the cond jaxpr again on the
  incoming loop values and passes every updated value, paired with its
  previous value, through ``_pred_bcast_select``.

  Args:
    c: builder for the enclosing computation.
    axis_env: forwarded unchanged to ``xla.jaxpr_computation``.
    *args: cond consts, body consts and initial loop values, in that order.
    **kwargs: must provide ``cond_jaxpr``, ``body_jaxpr``, ``cond_nconsts``
      and ``body_nconsts``.

  Returns:
    A tuple op holding the final loop values.
  """
  cond_jaxpr, body_jaxpr, cond_nconsts, body_nconsts = split_dict(
      kwargs, ["cond_jaxpr", "body_jaxpr", "cond_nconsts", "body_nconsts"])
  cond_consts, body_consts, init_vals = split_list(
      args, [cond_nconsts, body_nconsts])
  num_carry = len(args)
  pred_aval = cond_jaxpr.out_avals[0]
  batched = bool(pred_aval.shape)

  # Jaxprs have no tuples and may return several values, whereas the HLO While
  # needs a single tuple operand, a single boolean from the cond computation
  # and a single tuple from the body computation. The wrapper computations
  # built below do the tuple packing/unpacking around a Call into the
  # computations formed from the jaxprs.
  init_carry = c.Tuple(*(cond_consts + body_consts + init_vals))

  # --- cond computation -----------------------------------------------------
  cond_c = xb.make_computation_builder("cond_computation")
  cond_param = cond_c.ParameterWithShape(c.GetShape(init_carry))
  cond_elts = [
      cond_c.GetTupleElement(cond_param, idx) for idx in range(num_carry)
  ]
  cond_cs, _, cond_vals = split_list(cond_elts, [cond_nconsts, body_nconsts])
  cond_operands = cond_cs + cond_vals
  cond_comp = xla.jaxpr_computation(
      cond_jaxpr.jaxpr, axis_env, cond_jaxpr.literals, (),
      *_map(cond_c.GetShape, cond_operands))
  pred = cond_c.GetTupleElement(cond_c.Call(cond_comp, cond_operands), 0)
  if batched:
    # Collapse the non-scalar predicate to one boolean with a logical OR.
    bool_scalar = xla_client.Shape.array_shape(onp.dtype(onp.bool_), ())
    or_comp = xla.primitive_computation(lax.or_p, bool_scalar, bool_scalar)
    reduce_init = cond_c.Constant(onp.array(False))
    all_dims = list(range(pred_aval.ndim))
    pred = cond_c.Reduce(pred, reduce_init, or_comp, all_dims)

  # --- body computation -----------------------------------------------------
  body_c = xb.make_computation_builder("body_computation")
  body_param = body_c.ParameterWithShape(c.GetShape(init_carry))
  body_elts = [
      body_c.GetTupleElement(body_param, idx) for idx in range(num_carry)
  ]
  cond_cs, body_cs, vals = split_list(body_elts, [cond_nconsts, body_nconsts])
  body_operands = body_cs + vals
  body_comp = xla.jaxpr_computation(
      body_jaxpr.jaxpr, axis_env, body_jaxpr.literals, (),
      *_map(body_c.GetShape, body_operands))
  body_out = body_c.Call(body_comp, body_operands)
  new_vals = [
      body_c.GetTupleElement(body_out, idx) for idx in range(len(init_vals))
  ]
  if batched:
    # Evaluate the cond again on the pre-update values to get the predicate
    # handed to _pred_bcast_select.
    # NOTE(review): presumably this keeps the old value wherever the
    # predicate is already false -- confirm against _pred_bcast_select.
    recheck_operands = cond_cs + vals
    recheck_comp = xla.jaxpr_computation(
        cond_jaxpr.jaxpr, axis_env, cond_jaxpr.literals, (),
        *_map(body_c.GetShape, recheck_operands))
    body_pred = body_c.GetTupleElement(
        body_c.Call(recheck_comp, recheck_operands), 0)
    select = partial(_pred_bcast_select, body_c, body_pred)
    new_vals = _map(select, new_vals, vals)
    assert _map(body_c.GetShape, new_vals) == _map(body_c.GetShape, vals)  # no broadcast
  new_carry = body_c.Tuple(*(cond_cs + body_cs + new_vals))

  # --- while op ---------------------------------------------------------------
  cond_built = cond_c.Build(pred)
  body_built = body_c.Build(new_carry)
  loop_out = c.While(cond_built, body_built, init_carry)
  out_elts = [c.GetTupleElement(loop_out, idx) for idx in range(num_carry)]
  _, _, final_vals = split_list(out_elts, [cond_nconsts, body_nconsts])
  return c.Tuple(*final_vals)
def make_computation(name, jaxpr, op_shape):
  """Build a computation that unpacks a tuple parameter and calls ``jaxpr``.

  The computation takes one parameter of shape ``op_shape``, extracts one
  tuple element per entry of ``jaxpr.in_avals``, and calls the computation
  formed from ``jaxpr`` on those elements.

  NOTE(review): ``axis_env`` is not a parameter; it is read from the
  surrounding scope, so this presumably lives inside another function.

  Args:
    name: name given to the new computation builder.
    jaxpr: object exposing ``jaxpr``, ``literals`` and ``in_avals``.
    op_shape: shape of the single tuple parameter.

  Returns:
    The built computation whose root is the result of the call.
  """
  builder = xb.make_computation_builder(name)
  packed_arg = builder.ParameterWithShape(op_shape)
  num_inputs = len(jaxpr.in_avals)
  unpacked = [
      builder.GetTupleElement(packed_arg, idx) for idx in range(num_inputs)
  ]
  inner = xla.jaxpr_computation(
      jaxpr.jaxpr, axis_env, jaxpr.literals, (),
      *_map(builder.GetShape, unpacked))
  result = builder.Call(inner, unpacked)
  return builder.Build(result)
def make_computation(jaxpr, operand):
  """Build a computation from ``jaxpr`` taking one packed ``(arg, consts)`` input.

  A copy of ``jaxpr`` is rewritten to have no constvars and a single fresh
  input variable; two ``_unpack_eqn`` equations are placed ahead of the
  original equations, one for the packed input and one for the consts.

  NOTE(review): ``c`` is not a parameter; it is read from the surrounding
  scope, so this presumably lives inside another function.

  Args:
    jaxpr: jaxpr with exactly one invar.
    operand: builder op whose shape becomes the computation's argument shape.

  Returns:
    The result of ``xla.jaxpr_computation`` on the rewritten jaxpr.
  """
  assert len(jaxpr.invars) == 1
  packed_var = pe.Var(0, "arg")
  consts_tuple_var = pe.Var(0, "consts")

  tupled = jaxpr.copy()
  tupled.constvars = []
  tupled.invars = [packed_var]
  prologue = [
      _unpack_eqn(packed_var, [jaxpr.invars[0], consts_tuple_var]),
      _unpack_eqn(consts_tuple_var, jaxpr.constvars),
  ]
  tupled.eqns = prologue + list(jaxpr.eqns)

  operand_shape = c.GetShape(operand)
  return xla.jaxpr_computation(tupled, (), (), operand_shape)
def xlogy_translate(c, x, y, jaxpr, aval, consts):
  """Translate by calling the computation built from ``jaxpr`` on ``(x, y)``.

  Args:
    c: computation builder; used for ``GetShape`` and ``Call``.
    x: first operand.
    y: second operand.
    jaxpr: jaxpr handed to ``xla.jaxpr_computation``.
    aval: not used by this function.
    consts: constants handed to ``xla.jaxpr_computation`` with ``jaxpr``.

  Returns:
    The result of the ``Call`` op.
  """
  operands = (x, y)
  operand_shapes = [c.GetShape(operand) for operand in operands]
  computation = xla.jaxpr_computation(jaxpr, consts, (), *operand_shapes)
  return c.Call(computation, operands)
def _standard_gamma_translate(c, key, alpha, jaxpr, aval, consts):
  """Translate by calling the computation built from ``jaxpr`` on ``(key, alpha)``.

  Args:
    c: computation builder; used for ``GetShape`` and ``Call``.
    key: first operand.
    alpha: second operand.
    jaxpr: jaxpr handed to ``xla.jaxpr_computation``.
    aval: not used by this function.
    consts: constants handed to ``xla.jaxpr_computation`` with ``jaxpr``.

  Returns:
    The result of the ``Call`` op.
  """
  key_shape = c.GetShape(key)
  alpha_shape = c.GetShape(alpha)
  gamma_computation = xla.jaxpr_computation(
      jaxpr, consts, (), key_shape, alpha_shape)
  return c.Call(gamma_computation, (key, alpha))