def _sharded_jit_translation_rule(ctx, avals_in, avals_out, *in_nodes,
                                  in_parts, out_parts_thunk, nparts,
                                  name, call_jaxpr, local_in_parts,
                                  local_out_parts_thunk, local_nparts):
    """Lower a ``sharded_jit`` call into an inlined XLA sub-computation.

    A fresh ``XlaBuilder`` named ``sharded_jit_<name>`` is created. One
    parameter is declared on it per operand in ``in_nodes``, and each
    parameter is annotated with its matching partition spec. ``call_jaxpr``
    is lowered into that builder, and its outputs are annotated with the
    specs produced by ``out_parts_thunk``. The built sub-computation is then
    invoked from ``ctx.builder`` with ``in_nodes``, and its tuple result is
    destructured.

    Only ``ctx``, ``in_nodes``, ``in_parts``, ``out_parts_thunk``, ``name``
    and ``call_jaxpr`` are read here. ``avals_in``, ``avals_out``,
    ``nparts``, ``local_in_parts``, ``local_out_parts_thunk`` and
    ``local_nparts`` are accepted for the rule's signature but unused.

    Returns:
        The value of ``xla.xla_destructure`` applied to the call's tuple
        result (presumably one XLA op per output; verify against
        ``xla.xla_destructure``).
    """
    sub_builder = xc.XlaBuilder(f"sharded_jit_{name}")

    # Any operands beyond those described by in_parts are assumed to be
    # leading constants. Left-pad the specs with None so those operands
    # are replicated.
    n_leading_consts = len(in_nodes) - len(in_parts)
    assert n_leading_consts >= 0
    padded_parts = (None,) * n_leading_consts + in_parts

    # xla.set_sharding (rather than xla.with_sharding) is used because
    # inlined calls should not have shardings set directly on their inputs
    # or outputs. Each parameter is declared and then annotated before the
    # next one is created, preserving op-emission order.
    sharded_params = [
        xla.set_sharding(
            sub_builder,
            xla.parameter(sub_builder, idx, ctx.builder.GetShape(node)),
            part)
        for idx, (node, part) in enumerate(safe_zip(in_nodes, padded_parts))
    ]

    inner_ctx = ctx.replace(
        builder=sub_builder,
        name_stack=new_name_stack(wrap_name(name, "sharded_jit")))
    results = xla.jaxpr_subcomp(inner_ctx, call_jaxpr, (), *sharded_params)

    result_parts = out_parts_thunk()
    assert len(result_parts) == len(results)
    annotated_results = [
        xla.set_sharding(sub_builder, res, part)
        for res, part in safe_zip(results, result_parts)
    ]

    computation = sub_builder.build(xops.Tuple(sub_builder, annotated_results))
    call_op = xops.Call(ctx.builder, computation, list(in_nodes))
    return xla.xla_destructure(ctx.builder, call_op)
def _sharding_constraint_translation_rule(ctx, avals_in, avals_out, x_node,
                                          partitions):
    """Annotate ``x_node`` with the sharding ``partitions`` on ``ctx.builder``.

    The work is delegated to ``xla.set_sharding``. ``avals_in`` and
    ``avals_out`` are part of the translation-rule signature but are not
    read here.

    Returns:
        A one-element list holding the op returned by ``xla.set_sharding``.
    """
    annotated = xla.set_sharding(ctx.builder, x_node, partitions)
    return [annotated]