Example #1
0
 def reducer():
     """Build the XLA reducer computation named "select_and_gather_pair_reducer".

     The computation has two rank-0 parameters of ``double_word_dtype``. It
     keeps whichever parameter wins a comparison of ``fst(...)`` applied to
     each one. The comparison is ``Ge`` when ``select_prim`` is ``lax.ge_p``
     and ``Le`` otherwise.

     Returns:
       The result of ``build()`` on the reducer builder.
     """
     builder = xc.XlaBuilder("select_and_gather_pair_reducer")

     def scalar_param(index):
         # Each reducer operand is a rank-0 array of double_word_dtype.
         shape = xla_client.Shape.array_shape(np.dtype(double_word_dtype), ())
         return xla.parameter(builder, index, shape)

     lhs = scalar_param(0)
     rhs = scalar_param(1)

     # Only the two ordering primitives are expected here.
     assert select_prim is lax.ge_p or select_prim is lax.le_p
     if select_prim is lax.ge_p:
         compare = xops.Ge
     else:
         compare = xops.Le

     # Compare on fst(...) of each operand, but select the whole operand.
     lhs_key = fst(builder, lhs)
     rhs_key = fst(builder, rhs)
     xops.Select(compare(lhs_key, rhs_key), lhs, rhs)
     return builder.build()
Example #2
0
def _sharded_jit_translation_rule(ctx, avals_in, avals_out, *in_nodes,
                                  in_parts, out_parts_thunk, nparts,
                                  name, call_jaxpr, local_in_parts,
                                  local_out_parts_thunk, local_nparts):
  """Lower a ``sharded_jit`` call to an XLA ``Call`` of a sub-computation.

  The lowering proceeds in four steps:

  1. A sub-builder named ``sharded_jit_<name>`` gets one parameter per entry
     of ``in_nodes``, shaped like that node in ``ctx.builder``. The matching
     sharding is attached to each parameter with ``xla.set_sharding``.
  2. ``call_jaxpr`` is lowered into the sub-builder.
  3. Its outputs are annotated with the shardings from ``out_parts_thunk()``
     and packed into a tuple.
  4. The built sub-computation is called from ``ctx.builder`` and the
     result is destructured.

  ``avals_in``, ``avals_out``, ``nparts``, ``local_in_parts``,
  ``local_out_parts_thunk`` and ``local_nparts`` are accepted but not read
  here. ``in_parts`` must be a tuple, since it is concatenated with one.

  Returns:
    The value of ``xla.xla_destructure`` applied to the ``Call`` node.
  """
  sub_builder = xc.XlaBuilder(f"sharded_jit_{name}")

  # Leading in_nodes with no entry in in_parts are assumed to be constants;
  # they get a None sharding so they are replicated.
  num_constants = len(in_nodes) - len(in_parts)
  assert num_constants >= 0
  in_shardings = (None,) * num_constants + in_parts

  # xla.set_sharding is used instead of xla.with_sharding because inlined
  # calls shouldn't have shardings set directly on the inputs or outputs.
  params = [
      xla.set_sharding(
          sub_builder,
          xla.parameter(sub_builder, idx, ctx.builder.GetShape(node)),
          sharding)
      for idx, (node, sharding) in enumerate(safe_zip(in_nodes, in_shardings))
  ]

  body_ctx = ctx.replace(
      builder=sub_builder,
      name_stack=new_name_stack(wrap_name(name, "sharded_jit")))
  results = xla.jaxpr_subcomp(body_ctx, call_jaxpr, (), *params)

  out_shardings = out_parts_thunk()
  assert len(out_shardings) == len(results)
  sharded_results = [xla.set_sharding(sub_builder, res, sharding)
                     for res, sharding in safe_zip(results, out_shardings)]

  computation = sub_builder.build(xops.Tuple(sub_builder, sharded_results))
  call = xops.Call(ctx.builder, computation, list(in_nodes))
  return xla.xla_destructure(ctx.builder, call)
Example #3
0
def _comparator_builder(op_type, is_max_k):
    """Build the scalar comparator computation used for top-k.

    The computation declares four scalar parameters: two of ``op_type``
    followed by two of ``int32``. Only the first two are compared. The
    int32 parameters are declared but never read (presumably the paired
    index operands of a multi-operand sort; confirm against the caller).

    Args:
      op_type: Element type given to ``xc.Shape.scalar_shape`` for the two
        compared parameters.
      is_max_k: When true, compare with ``Gt`` and name the computation
        ``top_k_gt_comparator``; otherwise use ``Lt`` and
        ``top_k_lt_comparator``.

    Returns:
      The computation built from the comparison result.
    """
    suffix = 'gt' if is_max_k else 'lt'
    compare = xc.ops.Gt if is_max_k else xc.ops.Lt
    builder = xc.XlaBuilder(f'top_k_{suffix}_comparator')

    lhs, rhs = (xla.parameter(builder, i, xc.Shape.scalar_shape(op_type))
                for i in (0, 1))

    # Declared to complete the comparator signature; not used below.
    index_dtype = np.dtype(np.int32)
    for i in (2, 3):
        xla.parameter(builder, i, xc.Shape.scalar_shape(index_dtype))

    return builder.build(compare(lhs, rhs))
Example #4
0
 def test_parameter_replication(self):
     """Check that an explicit ``False`` replication argument reaches the HLO.

     Builds a single F32 scalar parameter with ``""`` and ``False`` as the
     trailing positional arguments. The test then asserts that the HLO text
     contains ``parameter_replication={false}``.
     """
     c = xc.XlaBuilder("test")
     # NOTE(review): the trailing positional args are presumably
     # xla.parameter's name ("" = unnamed) and replication flag; confirm
     # against its signature before switching them to keywords.
     xla.parameter(c, 0, xc.Shape.array_shape(xc.PrimitiveType.F32, ()),
                   "", False)
     # Use the lowercase `build`, matching the other builder call sites,
     # instead of the capitalized `Build` spelling.
     built_c = c.build()
     assert "parameter_replication={false}" in built_c.as_hlo_text()
Example #5
0
 def test_parameter_replication_default(self):
     """Check that a parameter built with defaults emits no replication info.

     Builds a single F32 scalar parameter without the optional trailing
     arguments. The test then asserts that the word "replication" does not
     appear anywhere in the HLO text.
     """
     c = xc.XlaBuilder("test")
     xla.parameter(c, 0, xc.Shape.array_shape(xc.PrimitiveType.F32, ()))
     # Use the lowercase `build`, matching the other builder call sites,
     # instead of the capitalized `Build` spelling.
     built_c = c.build()
     assert "replication" not in built_c.as_hlo_text()