Exemplo n.º 1
0
def _while_loop_translation_rule(c, axis_env, *args, **kwargs):
    """Lower a while-loop primitive to a single XLA ``While`` op.

    Args:
      c: XLA computation builder that the ``While`` op is emitted into.
      axis_env: axis environment forwarded to ``xla.jaxpr_computation``.
      *args: flat operands, laid out as the cond consts, then the body consts,
        then the initial loop-carried values (split using ``cond_nconsts`` and
        ``body_nconsts``).
      **kwargs: expected to provide ``cond_jaxpr``, ``body_jaxpr``,
        ``cond_nconsts`` and ``body_nconsts``.

    Returns:
      An XLA tuple holding only the final loop-carried values (the consts are
      dropped from the While result).
    """
    cond_jaxpr, body_jaxpr, cond_nconsts, body_nconsts = split_dict(
        kwargs, ["cond_jaxpr", "body_jaxpr", "cond_nconsts", "body_nconsts"])
    cond_consts, body_consts, init_vals = split_list(
        args, [cond_nconsts, body_nconsts])
    # A non-scalar predicate means the cond yields one boolean per element
    # (presumably a batched/vmapped loop); it is OR-reduced to a scalar below.
    batched = bool(cond_jaxpr.out_avals[0].shape)

    # Since jaxprs don't have tuples and have multiple return values, but we need
    # the HLO While loop to take a single tuple input and output a single boolean
    # (for the cond computation) or a single tuple output (for the body
    # computation), we build XLA computations that handle the tuple munging before
    # generating a Call into the computations formed from the jaxprs.

    # The carry packs both sets of consts together with the loop values, so each
    # sub-computation can read everything it needs from one tuple parameter.
    init_carry = c.Tuple(*(cond_consts + body_consts + init_vals))

    # Cond computation: unpack the carry, call the cond jaxpr on
    # (cond consts + loop values), and take its first output as the predicate.
    cond_c = xb.make_computation_builder("cond_computation")
    cond_carry = cond_c.ParameterWithShape(c.GetShape(init_carry))
    cond_carry_elts = [
        cond_c.GetTupleElement(cond_carry, i) for i in range(len(args))
    ]
    # x: cond consts, z: loop values (the body consts are not used here).
    x, _, z = split_list(cond_carry_elts, [cond_nconsts, body_nconsts])
    cond_outs = cond_c.Call(
        xla.jaxpr_computation(cond_jaxpr.jaxpr, axis_env, cond_jaxpr.literals,
                              (), *_map(cond_c.GetShape, x + z)), x + z)
    pred = cond_c.GetTupleElement(cond_outs, 0)
    if batched:
        # XLA While needs a scalar predicate: keep looping while ANY element of
        # the batched predicate is true (logical-or reduction over all dims).
        scalar = xla_client.Shape.array_shape(onp.dtype(onp.bool_), ())
        or_ = xla.primitive_computation(lax.or_p, scalar, scalar)
        pred = cond_c.Reduce(pred, cond_c.Constant(onp.array(False)), or_,
                             list(range(cond_jaxpr.out_avals[0].ndim)))

    # Body computation: unpack the carry and call the body jaxpr on
    # (body consts + loop values).
    body_c = xb.make_computation_builder("body_computation")
    body_carry = body_c.ParameterWithShape(c.GetShape(init_carry))
    body_carry_elts = [
        body_c.GetTupleElement(body_carry, i) for i in range(len(args))
    ]
    # x: cond consts, y: body consts, z: loop values.
    x, y, z = split_list(body_carry_elts, [cond_nconsts, body_nconsts])
    body_out = body_c.Call(
        xla.jaxpr_computation(body_jaxpr.jaxpr, axis_env, body_jaxpr.literals,
                              (), *_map(body_c.GetShape, y + z)), y + z)
    new_z = [
        body_c.GetTupleElement(body_out, i) for i in range(len(init_vals))
    ]
    if batched:
        # Re-evaluate the cond on the incoming (old) loop values and combine
        # old and new values per element via _pred_bcast_select; presumably
        # elements whose predicate is already false keep their old value —
        # NOTE(review): confirm against _pred_bcast_select.
        body_cond_outs = body_c.Call(
            xla.jaxpr_computation(cond_jaxpr.jaxpr, axis_env,
                                  cond_jaxpr.literals, (),
                                  *_map(body_c.GetShape, x + z)), x + z)
        body_pred = body_c.GetTupleElement(body_cond_outs, 0)
        new_z = _map(partial(_pred_bcast_select, body_c, body_pred), new_z, z)
        assert _map(body_c.GetShape, new_z) == _map(body_c.GetShape,
                                                    z)  # no broadcast
    # Consts are passed through unchanged so the next iteration sees them too.
    new_carry = body_c.Tuple(*(x + y + new_z))

    ans = c.While(cond_c.Build(pred), body_c.Build(new_carry), init_carry)
    ans_elts = [c.GetTupleElement(ans, i) for i in range(len(args))]
    # Strip the consts; only the final loop values are returned.
    _, _, z = split_list(ans_elts, [cond_nconsts, body_nconsts])
    return c.Tuple(*z)
Exemplo n.º 2
0
 def test_parameter_replication(self):
     """Check that passing False as the replication argument shows in HLO.

     The parameter is built with ``""`` and ``False`` as its trailing
     positional arguments; the printed HLO must then contain
     ``parameter_replication={false}``.
     """
     c = xb.make_computation_builder("test")
     # Only the parameter op added to the builder matters; the returned
     # handle is not needed, so it is not bound to a name.
     xb.parameter(c, 0,
                  xc.Shape.array_shape(xc.PrimitiveType.F32, ()),
                  "", False)
     built_c = c.Build()
     # assertIn (unlike a bare assert) survives `python -O` and reports the
     # HLO text on failure.
     self.assertIn("parameter_replication={false}", built_c.as_hlo_text())
Exemplo n.º 3
0
 def make_computation(name, jaxpr, op_shape):
   """Build an XLA computation that lowers `jaxpr` inline.

   The computation takes one tuple parameter of shape `op_shape`, passes its
   first ``len(jaxpr.in_avals)`` elements (plus the jaxpr's literals as
   constants) to ``xla.jaxpr_subcomp``, and returns the outputs as a tuple.
   `backend` and `axis_env` come from the enclosing scope.
   """
   builder = xb.make_computation_builder(name)
   packed = builder.ParameterWithShape(op_shape)
   num_inputs = len(jaxpr.in_avals)
   unpacked = []
   for idx in range(num_inputs):
     unpacked.append(builder.GetTupleElement(packed, idx))
   literal_consts = _map(builder.Constant, jaxpr.literals)
   results = xla.jaxpr_subcomp(builder, jaxpr.jaxpr, backend, axis_env,
                               literal_consts, (), *unpacked)
   root = builder.Tuple(*results)
   return builder.Build(root)
Exemplo n.º 4
0
def _nonzero_translation_rule(c, dims, avals, operands):
  """Lower `nonzero` to XLA ops.

  Builds an int32 indicator (1 where the operand is nonzero), counts the
  nonzeros along the last axis, and stable-sorts an iota of positions along
  the last axis by indicator in descending order, so nonzero positions come
  first.

  Returns:
    ``[[count], [sorted_positions]]``.
  """
  (vals,), = operands
  vals_shape = c.get_shape(vals)
  dimensions = vals_shape.dimensions()
  last_axis = len(dimensions) - 1
  int32 = np.dtype('int32')

  # Indicator array: ConvertElementType(vals != 0) -> int32.
  zero_scalar = xb.constant(c, np.zeros((), vals_shape.numpy_dtype()))
  zeros = xops.Broadcast(zero_scalar, dimensions)
  s32_etype = xc.dtype_to_etype(int32)
  is_nonzero = xops.ConvertElementType(xops.Ne(vals, zeros), s32_etype)

  # Count of nonzeros: sum of the indicator along the last axis.
  s32_scalar_aval = core.ShapedArray((), int32)
  zero_s32 = xb.constant(c, np.array(0, int32))
  add_comp = xla.primitive_subcomputation(lax.add_p, s32_scalar_aval,
                                          s32_scalar_aval)
  out_dim = xops.Reduce(c, [is_nonzero], [zero_s32], add_comp, (last_axis,))
  c.get_shape(out_dim)  # xla type checking

  # Sort comparator for two operands receives four scalar parameters; only
  # the first two (the indicator pair) are compared, with `>`.
  cmp_builder = xb.make_computation_builder("sort_gt_comparator")
  cmp_params = [
      xb.parameter(cmp_builder, k, xc.Shape.array_shape(s32_etype, ()))
      for k in range(4)
  ]
  comparator = cmp_builder.build(xops.Gt(cmp_params[0], cmp_params[1]))

  iota_shape = xc.Shape.array_shape(xc.PrimitiveType.S32, dimensions)
  positions = xops.Iota(c, iota_shape, last_axis)
  sorted_pair = xops.Sort(c, [is_nonzero, positions],
                          is_stable=True, comparator=comparator)
  _, sorted_positions = xla.xla_destructure(c, sorted_pair)
  c.get_shape(sorted_positions)  # xla type checking
  return [[out_dim], [sorted_positions]]
Exemplo n.º 5
0
 def make_computation(name, jaxpr, op_shape):
     """Wrap `jaxpr` in an XLA computation taking one tuple parameter.

     The tuple parameter has shape `op_shape`. Its first
     ``len(jaxpr.in_avals)`` elements are passed to a Call of the computation
     lowered from `jaxpr`, and the Call's result is the computation's root.
     `axis_env` comes from the enclosing scope.
     """
     builder = xb.make_computation_builder(name)
     packed = builder.ParameterWithShape(op_shape)
     unpacked = [builder.GetTupleElement(packed, idx)
                 for idx in range(len(jaxpr.in_avals))]
     arg_shapes = _map(builder.GetShape, unpacked)
     callee = xla.jaxpr_computation(jaxpr.jaxpr, axis_env, jaxpr.literals, (),
                                    *arg_shapes)
     result = builder.Call(callee, unpacked)
     return builder.Build(result)
Exemplo n.º 6
0
    def test_error_bad_consumer_id(self):
        """Try to use reserved consumer ID 0.

        Check that we get the proper error from the runtime.
        """
        builder = xla_bridge.make_computation_builder(self._testMethodName)
        tok = hcb.xops.CreateToken(builder)
        expected_msg = "Consumer ID cannot be a reserved value: 0"
        with self.assertRaisesRegex(RuntimeError, expected_msg):
            payload = xla_bridge.constant(
                builder, np.zeros((2, 3), dtype=np.float32))
            hcb._outfeed_receiver.receiver.add_outfeed(
                builder, tok, 0, [payload])
Exemplo n.º 7
0
def _dynamic_xla_call_impl(*args, jaxpr, num_consts):
  """Compile `jaxpr` (with dynamic dimension binders) to XLA and execute it.

  Args:
    *args: values for ``jaxpr.in_dim_binders``, then `num_consts` constants,
      then the call's operands, in that order.
    jaxpr: the jaxpr to lower; its dimension binders and outputs drive
      parameter construction and result partitioning.
    num_consts: number of constant operands following the dimension values.

  Returns:
    Whatever ``execute_compiled`` produces for the compiled computation.
  """
  num_dim_binders = len(jaxpr.in_dim_binders)
  in_dim_vals, consts, operands = split_list(args,
                                             [num_dim_binders, num_consts])
  dim_avals = [binder.aval for binder in jaxpr.in_dim_binders]
  builder = xb.make_computation_builder("dxla_call")
  dim_params, params = _make_params(builder, dim_avals,
                                    map(xla.abstractify, operands))
  const_params = _xla_consts(builder, consts)
  dim_outs, outs = djaxpr_subcomp(builder, jaxpr, dim_params,
                                  const_params + params)

  # Flatten every group of output ops into one tuple root.
  flat_outs = []
  for group in dim_outs + outs:
    flat_outs.extend(group)
  root = xops.Tuple(builder, flat_outs)
  executable = xb.get_backend(None).compile(builder.build(root))

  out_avals = [var.aval for var in jaxpr.outs]
  result_handlers = map(result_handler, out_avals)
  out_bufcounts = [aval._num_buffers for aval in out_avals]
  partitioner = result_partitioner(jaxpr.in_dim_binders, in_dim_vals,
                                   jaxpr.out_dims, out_bufcounts)
  return execute_compiled(executable, partitioner, result_handlers,
                          in_dim_vals, operands)
Exemplo n.º 8
0
 def test_error_different_shapes(self):
     """Try to register different shapes for the same consumer ID."""
     builder = xla_bridge.make_computation_builder(self._testMethodName)
     tok = hcb.xops.CreateToken(builder)

     def register(shape, dtype):
         # Register one outfeed of zeros (given shape/dtype) under ID 123.
         hcb._outfeed_receiver.receiver.add_outfeed(
             builder, tok, 123,
             [xla_bridge.constant(builder, np.zeros(shape, dtype=dtype))])

     register((2, 3), np.float32)
     mismatch = ".*does not match previous shape element_type.*"
     # Same dimensions, different element type.
     with self.assertRaisesRegex(RuntimeError, mismatch):
         register((2, 3), np.int32)
     # Same element type, different dimensions.
     with self.assertRaisesRegex(RuntimeError, mismatch):
         register((2, ), np.float32)
Exemplo n.º 9
0
 def test_parameter_replication_default(self):
     """Check that no replication attribute appears when none is passed."""
     c = xb.make_computation_builder("test")
     # Only the parameter op added to the builder matters; the returned
     # handle is not needed, so it is not bound to a name.
     xb.parameter(c, 0,
                  xc.Shape.array_shape(xc.PrimitiveType.F32, ()))
     built_c = c.Build()
     # assertNotIn (unlike a bare assert) survives `python -O` and reports
     # the HLO text on failure.
     self.assertNotIn("replication", built_c.as_hlo_text())