def _while_loop_translation_rule(c, axis_env, *args, **kwargs):
  """Lowers a `while_loop` primitive to an XLA `While` op.

  Args:
    c: the XLA computation builder for the enclosing computation.
    axis_env: axis environment threaded through to `xla.jaxpr_computation`.
    *args: flat XLA operands: cond consts, then body consts, then loop carry.
    **kwargs: must contain "cond_jaxpr", "body_jaxpr", "cond_nconsts",
      "body_nconsts" (split out below via `split_dict`).

  Returns:
    An XLA tuple op holding the final loop-carry values (consts dropped).
  """
  cond_jaxpr, body_jaxpr, cond_nconsts, body_nconsts = split_dict(
      kwargs, ["cond_jaxpr", "body_jaxpr", "cond_nconsts", "body_nconsts"])
  # Partition the flat operand list into the three logical groups.
  cond_consts, body_consts, init_vals = split_list(
      args, [cond_nconsts, body_nconsts])
  # A non-scalar cond output means the loop was batched (e.g. under vmap):
  # the predicate has per-batch-element values rather than a single bool.
  batched = bool(cond_jaxpr.out_avals[0].shape)

  # Since jaxprs don't have tuples and have multiple return values, but we need
  # the HLO While loop to take a single tuple input and output a single boolean
  # (for the cond computation) or a single tuple output (for the body
  # computation), we build XLA computations that handle the tuple munging before
  # generating a Call into the computations formed from the jaxprs.
  init_carry = c.Tuple(*(cond_consts + body_consts + init_vals))

  # --- cond computation: carry tuple -> scalar bool predicate ---
  cond_c = xb.make_computation_builder("cond_computation")
  cond_carry = cond_c.ParameterWithShape(c.GetShape(init_carry))
  cond_carry_elts = [
      cond_c.GetTupleElement(cond_carry, i) for i in range(len(args))
  ]
  # x = cond consts, z = carry values; body consts are unused by cond.
  x, _, z = split_list(cond_carry_elts, [cond_nconsts, body_nconsts])
  cond_outs = cond_c.Call(
      xla.jaxpr_computation(cond_jaxpr.jaxpr, axis_env, cond_jaxpr.literals,
                            (), *_map(cond_c.GetShape, x + z)), x + z)
  pred = cond_c.GetTupleElement(cond_outs, 0)
  if batched:
    # Reduce the batched predicate with logical-or over all of its axes so
    # the While keeps running while ANY batch element's cond is still true.
    scalar = xla_client.Shape.array_shape(onp.dtype(onp.bool_), ())
    or_ = xla.primitive_computation(lax.or_p, scalar, scalar)
    pred = cond_c.Reduce(pred, cond_c.Constant(onp.array(False)), or_,
                         list(range(cond_jaxpr.out_avals[0].ndim)))

  # --- body computation: carry tuple -> updated carry tuple ---
  body_c = xb.make_computation_builder("body_computation")
  body_carry = body_c.ParameterWithShape(c.GetShape(init_carry))
  body_carry_elts = [
      body_c.GetTupleElement(body_carry, i) for i in range(len(args))
  ]
  # x = cond consts, y = body consts, z = carry values.
  x, y, z = split_list(body_carry_elts, [cond_nconsts, body_nconsts])
  body_out = body_c.Call(
      xla.jaxpr_computation(body_jaxpr.jaxpr, axis_env, body_jaxpr.literals,
                            (), *_map(body_c.GetShape, y + z)), y + z)
  new_z = [body_c.GetTupleElement(body_out, i) for i in range(len(init_vals))]
  if batched:
    # In the batched case, batch elements whose cond is already false must
    # keep their old carry: re-evaluate cond inside the body and select
    # per-element between the new and old carry values.
    body_cond_outs = body_c.Call(
        xla.jaxpr_computation(cond_jaxpr.jaxpr, axis_env, cond_jaxpr.literals,
                              (), *_map(body_c.GetShape, x + z)), x + z)
    body_pred = body_c.GetTupleElement(body_cond_outs, 0)
    new_z = _map(partial(_pred_bcast_select, body_c, body_pred), new_z, z)
    assert _map(body_c.GetShape, new_z) == _map(body_c.GetShape, z)  # no broadcast
  # Consts are passed through unchanged; only the carry slots are updated.
  new_carry = body_c.Tuple(*(x + y + new_z))

  ans = c.While(cond_c.Build(pred), body_c.Build(new_carry), init_carry)
  ans_elts = [c.GetTupleElement(ans, i) for i in range(len(args))]
  # Drop the const slots from the result; callers only see the final carry.
  _, _, z = split_list(ans_elts, [cond_nconsts, body_nconsts])
  return c.Tuple(*z)
def test_parameter_replication(self):
  """Passing replicated=False must emit parameter_replication={false} HLO."""
  c = xb.make_computation_builder("test")
  # The call is kept for its side effect of registering parameter 0 on the
  # builder; the returned op itself is not needed (was bound to an unused
  # local before).
  xb.parameter(c, 0, xc.Shape.array_shape(xc.PrimitiveType.F32, ()), "",
               False)
  built_c = c.Build()
  assert "parameter_replication={false}" in built_c.as_hlo_text()
def make_computation(name, jaxpr, op_shape):
  """Builds a named XLA computation that unpacks a tuple parameter and
  evaluates `jaxpr` on its elements via `xla.jaxpr_subcomp`, returning the
  results re-packed as a single tuple."""
  builder = xb.make_computation_builder(name)
  tuple_param = builder.ParameterWithShape(op_shape)
  elements = [builder.GetTupleElement(tuple_param, idx)
              for idx in range(len(jaxpr.in_avals))]
  results = xla.jaxpr_subcomp(builder, jaxpr.jaxpr, backend, axis_env,
                              _map(builder.Constant, jaxpr.literals), (),
                              *elements)
  return builder.Build(builder.Tuple(*results))
def _nonzero_translation_rule(c, dims, avals, operands):
  """Lowers a `nonzero`-style primitive to XLA ops.

  Computes (a) the count of nonzero entries along the last axis and (b) the
  indices of those entries, obtained by stable-sorting an iota by a 0/1
  nonzero-indicator key so that indices of nonzero elements come first.

  Args:
    c: the XLA computation builder.
    dims: unused here (part of the translation-rule signature).
    avals: unused here (part of the translation-rule signature).
    operands: a singleton of a singleton holding the input XLA op.

  Returns:
    ``[[out_dim], [out_val]]`` — the nonzero count op and the index op,
    each wrapped in its own list (one list per output, one buffer each).
  """
  (vals,), = operands
  shape = c.get_shape(vals)
  last_axis = len(shape.dimensions()) - 1
  # Broadcast a scalar zero of the input dtype so Ne compares elementwise.
  zeros = xops.Broadcast(xb.constant(c, np.zeros((), shape.numpy_dtype())),
                         shape.dimensions())
  s32_etype = xc.dtype_to_etype(np.dtype('int32'))
  # 1 where the input is nonzero, 0 elsewhere (as int32).
  nonzero_indicators = xops.ConvertElementType(xops.Ne(vals, zeros), s32_etype)
  i = core.ShapedArray((), np.dtype('int32'))
  # Sum the indicators over the last axis to count nonzeros.
  out_dim = xops.Reduce(c, [nonzero_indicators],
                        [xb.constant(c, np.array(0, np.dtype('int32')))],
                        xla.primitive_subcomputation(lax.add_p, i, i),
                        (last_axis,))
  c.get_shape(out_dim)  # xla type checking
  # Comparator for a two-operand Sort: compares only the first key pair
  # (params 0 and 1); params 2 and 3 belong to the carried iota values.
  subc = xb.make_computation_builder("sort_gt_comparator")
  params = [xb.parameter(subc, i, xc.Shape.array_shape(s32_etype, ()))
            for i in range(4)]
  comparator = subc.build(xops.Gt(params[0], params[1]))
  iota_shape = xc.Shape.array_shape(xc.PrimitiveType.S32, shape.dimensions())
  # Stable descending sort by the 0/1 indicator: positions of nonzero
  # elements move to the front while preserving their relative order.
  ans = xops.Sort(c, [nonzero_indicators, xops.Iota(c, iota_shape, last_axis)],
                  is_stable=True, comparator=comparator)
  _, out_val = xla.xla_destructure(c, ans)
  c.get_shape(out_val)  # xla type checking
  return [[out_dim], [out_val]]
def make_computation(name, jaxpr, op_shape):
  """Builds a named XLA computation that unpacks a tuple parameter, forms a
  sub-computation from `jaxpr` via `xla.jaxpr_computation`, and Calls it on
  the unpacked elements."""
  builder = xb.make_computation_builder(name)
  tuple_param = builder.ParameterWithShape(op_shape)
  elements = [builder.GetTupleElement(tuple_param, idx)
              for idx in range(len(jaxpr.in_avals))]
  element_shapes = _map(builder.GetShape, elements)
  callee = xla.jaxpr_computation(jaxpr.jaxpr, axis_env, jaxpr.literals, (),
                                 *element_shapes)
  result = builder.Call(callee, elements)
  return builder.Build(result)
def test_error_bad_consumer_id(self):
  """Try to use reserved consumer ID 0.

  Check that we get the proper error from the runtime."""
  builder = xla_bridge.make_computation_builder(self._testMethodName)
  tok = hcb.xops.CreateToken(builder)
  payload = [xla_bridge.constant(builder, np.zeros((2, 3), dtype=np.float32))]
  with self.assertRaisesRegex(
      RuntimeError, "Consumer ID cannot be a reserved value: 0"):
    hcb._outfeed_receiver.receiver.add_outfeed(builder, tok, 0, payload)
def _dynamic_xla_call_impl(*args, jaxpr, num_consts):
  """Impl rule for a dynamic-shape XLA call: lower `jaxpr`, compile, execute.

  Args:
    *args: flat positional inputs; the leading entries are dimension values
      (one per `jaxpr.in_dim_binders`), followed by `num_consts` constants,
      then the regular array arguments.
    jaxpr: the dimension-polymorphic jaxpr to lower (has `in_dim_binders`,
      `outs`, `out_dims` attributes — project-specific jaxpr variant).
    num_consts: number of leading constants after the dimension values.

  Returns:
    Whatever `execute_compiled` produces for the compiled computation —
    presumably the handled output values; TODO confirm against caller.
  """
  # NOTE: `args` is rebound here to just the trailing array arguments.
  in_dim_vals, consts, args = split_list(args, [len(jaxpr.in_dim_binders),
                                                num_consts])
  dim_in_avals = [v.aval for v in jaxpr.in_dim_binders]
  c = xb.make_computation_builder("dxla_call")
  # Build builder parameters for the dimension inputs and array inputs, and
  # embed the constants directly into the computation.
  dim_params, params = _make_params(c, dim_in_avals, map(xla.abstractify, args))
  const_params = _xla_consts(c, consts)
  dim_outs, outs = djaxpr_subcomp(c, jaxpr, dim_params, const_params + params)
  # Flatten per-output buffer lists (each output may span several buffers)
  # into one result tuple.
  out = xops.Tuple(c, [o for ops in dim_outs + outs for o in ops])
  compiled = xb.get_backend(None).compile(c.build(out))
  result_handlers = map(result_handler, [v.aval for v in jaxpr.outs])
  out_bufcounts = [v.aval._num_buffers for v in jaxpr.outs]
  # The partitioner splits the flat buffer list back into per-output groups,
  # using the runtime dimension values to size dynamic outputs.
  partitioner = result_partitioner(jaxpr.in_dim_binders, in_dim_vals,
                                   jaxpr.out_dims, out_bufcounts)
  return execute_compiled(compiled, partitioner, result_handlers,
                          in_dim_vals, args)
def test_error_different_shapes(self):
  """Try to register different shapes for the same consumer ID."""
  builder = xla_bridge.make_computation_builder(self._testMethodName)
  tok = hcb.xops.CreateToken(builder)
  add_outfeed = hcb._outfeed_receiver.receiver.add_outfeed

  # First registration for consumer 123 fixes its shape: f32[2,3].
  add_outfeed(builder, tok, 123,
              [xla_bridge.constant(builder,
                                   np.zeros((2, 3), dtype=np.float32))])

  # Same shape but a different element type must be rejected.
  with self.assertRaisesRegex(
      RuntimeError, ".*does not match previous shape element_type.*"):
    add_outfeed(builder, tok, 123,
                [xla_bridge.constant(builder,
                                     np.zeros((2, 3), dtype=np.int32))])

  # Same element type but a different rank/shape must also be rejected.
  with self.assertRaisesRegex(
      RuntimeError, ".*does not match previous shape element_type.*"):
    add_outfeed(builder, tok, 123,
                [xla_bridge.constant(builder,
                                     np.zeros((2,), dtype=np.float32))])
def test_parameter_replication_default(self):
  """Omitting the replication argument must emit no replication HLO at all."""
  c = xb.make_computation_builder("test")
  # The call is kept for its side effect of registering parameter 0 on the
  # builder; the returned op itself is not needed (was bound to an unused
  # local before).
  xb.parameter(c, 0, xc.Shape.array_shape(xc.PrimitiveType.F32, ()))
  built_c = c.Build()
  assert "replication" not in built_c.as_hlo_text()