def test_parameter_replication(self):
  """Check the HLO of a one-parameter computation built with
  ``xb.parameter(..., "", False)`` contains ``parameter_replication={false}``.
  """
  builder = xb.make_computation_builder("test")
  scalar_f32 = xc.Shape.array_shape(xc.PrimitiveType.F32, ())
  # The returned parameter op is not needed; only its presence in the HLO is.
  xb.parameter(builder, 0, scalar_f32, "", False)
  hlo_text = builder.Build().as_hlo_text()
  assert "parameter_replication={false}" in hlo_text
def _nonzero_translation_rule(c, dims, avals, operands):
  """Lower a nonzero computation on one operand into XLA ops on builder ``c``.

  The rule emits two results:
    * ``out_dim``: the count of nonzero entries of ``vals``, summed over its
      last axis (a Reduce with ``lax.add_p`` over int32 0/1 indicators).
    * ``out_val``: int32 positions along the last axis, stably sorted so that
      positions holding nonzero entries come first, in their original order.

  Args:
    c: computation builder that all ops here are appended to.
    dims: not read by this rule.
    avals: not read by this rule.
    operands: a one-element sequence whose single element is itself a
      one-element sequence holding the XLA operand ``vals``.

  Returns:
    ``[[out_dim], [out_val]]``: a list of two one-element lists.
  """
  (vals,), = operands
  shape = c.get_shape(vals)
  # NOTE(review): assumes ``vals`` has rank >= 1; for a rank-0 operand this
  # would be -1 and the Reduce below would be invalid -- confirm callers.
  last_axis = len(shape.dimensions()) - 1
  # A zero of the operand's own dtype, broadcast to the operand's shape.
  zeros = xops.Broadcast(xb.constant(c, np.zeros((), shape.numpy_dtype())), shape.dimensions())
  s32_etype = xc.dtype_to_etype(np.dtype('int32'))
  # 1 where vals != 0, else 0, as int32 so it can be summed and sorted.
  nonzero_indicators = xops.ConvertElementType(xops.Ne(vals, zeros), s32_etype)
  # Scalar int32 aval used to build the addition reducer.
  i = core.ShapedArray((), np.dtype('int32'))
  # Sum the indicators over the last axis -> number of nonzeros.
  out_dim = xops.Reduce(c, [nonzero_indicators], [xb.constant(c, np.array(0, np.dtype('int32')))], xla.primitive_subcomputation(lax.add_p, i, i), (last_axis,))
  c.get_shape(out_dim)  # xla type checking
  # Comparator for a two-operand Sort: XLA passes (lhs, rhs) scalars for each
  # operand, so 4 parameters; params[0]/params[1] belong to the first operand
  # (the indicators). Gt orders indicators descending, putting 1s first.
  # The comprehension reuses the name ``i``; comprehension scope leaves the
  # outer ``i`` above untouched.
  subc = xb.make_computation_builder("sort_gt_comparator")
  params = [xb.parameter(subc, i, xc.Shape.array_shape(s32_etype, ())) for i in range(4)]
  comparator = subc.build(xops.Gt(params[0], params[1]))
  # Positions 0..n-1 along the last axis, carried along as the second operand.
  iota_shape = xc.Shape.array_shape(xc.PrimitiveType.S32, shape.dimensions())
  # is_stable=True keeps nonzero positions in their original order.
  # NOTE(review): no ``dimension`` is passed, so this relies on xops.Sort's
  # default -- presumably the last axis, matching the Iota; confirm.
  ans = xops.Sort(c, [nonzero_indicators, xops.Iota(c, iota_shape, last_axis)], is_stable=True, comparator=comparator)
  # Sort returns a tuple; keep only the permuted positions.
  _, out_val = xla.xla_destructure(c, ans)
  c.get_shape(out_val)  # xla type checking
  return [[out_dim], [out_val]]
def test_parameter_replication_default(self):
  """Check the HLO of a one-parameter computation built without the extra
  ``xb.parameter`` arguments never mentions ``replication``.
  """
  builder = xb.make_computation_builder("test")
  # Only the side effect on the builder matters; the parameter op is unused.
  xb.parameter(builder, 0, xc.Shape.array_shape(xc.PrimitiveType.F32, ()))
  hlo_text = builder.Build().as_hlo_text()
  assert "replication" not in hlo_text
def _make_params(c, dim_in_avals, in_avals):
  """Create XLA parameters for ``dim_in_avals`` followed by ``in_avals``.

  Parameter numbers come from one shared counter: the parameters for
  ``dim_in_avals`` are numbered first (0, 1, ...) and those for ``in_avals``
  continue from there.

  Args:
    c: computation builder passed through to ``xb.parameter``.
    dim_in_avals: avals whose XLA shapes become the leading parameters.
    in_avals: avals whose XLA shapes become the remaining parameters.

  Returns:
    A pair ``(dim_params, in_params)`` of lists. Each element is the list of
    parameter ops created for one aval, one per shape returned by
    ``xla.aval_to_xla_shapes``.
  """
  counter = it.count()

  def make(aval):
    # One parameter per XLA shape of ``aval``, numbered from the shared counter.
    return [xb.parameter(c, next(counter), shape)
            for shape in xla.aval_to_xla_shapes(aval)]

  # Build both lists eagerly and in order. That way the dim params always get
  # the lowest parameter numbers, regardless of which result the caller
  # iterates first. (A lazy ``map`` would tie numbering to consumption order.)
  dim_params = [make(aval) for aval in dim_in_avals]
  in_params = [make(aval) for aval in in_avals]
  return dim_params, in_params