def _optimization_barrier_lowering_rule(ctx, *args):
  # Emit a single mhlo.OptimizationBarrierOp over the flattened operands, then
  # regroup its results to mirror the structure of the inputs.
  barrier_types = _map(mlir.aval_to_ir_types, ctx.avals_in)
  flat_barrier_types = util.flatten(barrier_types)
  flat_args = mlir.flatten_lowering_ir_args(args)
  barrier_op = mhlo.OptimizationBarrierOp(flat_barrier_types, flat_args)
  return util.unflatten(barrier_op.results, _map(len, barrier_types))
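
# Hedged usage sketch, not part of the original source: a lowering rule like
# the one above is typically attached to its primitive via
# mlir.register_lowering. The primitive name below is assumed for
# illustration only.
#
#   optimization_barrier_p = core.Primitive("optimization_barrier")
#   optimization_barrier_p.multiple_results = True
#   mlir.register_lowering(optimization_barrier_p,
#                          _optimization_barrier_lowering_rule)
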
def function_effect_lowering(ctx, *, effect):
  # Wrap a no-op body in a private function and lower the effect as a call to
  # it, threading the effect tokens through the call and back out.
  def _f(ctx):
    ctx.set_tokens_out(ctx.tokens_in)
    return []
  func = mlir._emit_lowering_rule_as_fun(_f, ctx)

  output_types = map(mlir.aval_to_ir_types, ctx.avals_out)
  token_types = [mlir.token_type() for _ in ctx.tokens_in.items()]
  output_types = [*token_types, *output_types]
  flat_output_types = util.flatten(output_types)
  call = mlir.func_dialect.CallOp(
      flat_output_types,
      mlir.ir.FlatSymbolRefAttr.get(func.name.value),
      mlir.flatten_lowering_ir_args(ctx.tokens_in.tokens()))
  tokens, out = util.split_list(call.results, [len(ctx.tokens_in)])
  ctx.set_tokens_out(mlir.TokenSet(zip(ctx.tokens_in.effects(), tokens)))
  return out
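
# Design note (hedged inference, not stated in the original): routing the
# tokens through a private func and a func.call, rather than returning them
# directly, exercises token threading across a function-call boundary. As with
# the rule above, this would be registered for an effectful primitive, e.g.
# mlir.register_lowering(some_effectful_p, function_effect_lowering), where the
# `effect` keyword comes from the primitive's bind-time params;
# `some_effectful_p` is a hypothetical name.
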
def _sharded_jit_lowering(ctx, *in_nodes,
                          in_parts, out_parts_thunk, nparts,
                          name, call_jaxpr, local_in_parts,
                          local_out_parts_thunk, local_nparts):
  # We assume any extra leading in_nodes are constants and replicate them.
  num_extra_nodes = len(in_nodes) - len(in_parts)
  assert num_extra_nodes >= 0
  in_parts = (None,) * num_extra_nodes + in_parts

  # Annotate each input with its sharding, if any.
  args = []
  for ns, sharding in safe_zip(
      safe_map(mlir.wrap_singleton_ir_values, in_nodes), in_parts):
    if sharding is not None:
      args.append(
          [mlir.wrap_with_sharding_op(n, xla.sharding_to_proto(sharding))
           for n in ns])
    else:
      args.append(ns)

  # Lower the callee jaxpr to its own function and call it.
  sub_ctx = ctx.module_context.replace(
      name_stack=extend_name_stack(wrap_name(name, "sharded_jit")))
  fn = mlir.lower_jaxpr_to_fun(sub_ctx, f"sharded_jit_{name}",
                               core.ClosedJaxpr(call_jaxpr, ()))

  output_types = safe_map(mlir.aval_to_ir_types, ctx.avals_out)
  flat_output_types = util.flatten(output_types)
  call = std.CallOp(flat_output_types,
                    ir.FlatSymbolRefAttr.get(fn.name.value),
                    mlir.flatten_lowering_ir_args(args))
  out_nodes = util.unflatten(call.results, safe_map(len, output_types))

  # Annotate each output with its sharding, if any.
  out_parts = out_parts_thunk()
  outputs = []
  for ns, sharding in safe_zip(out_nodes, out_parts):
    if sharding is not None:
      outputs.append(
          [mlir.wrap_with_sharding_op(n, xla.sharding_to_proto(sharding))
           for n in ns])
    else:
      outputs.append(ns)

  return outputs
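
# Hedged registration sketch, not part of the original source: a call-like rule
# such as this one would be attached to the sharded_jit call primitive via
# mlir.register_lowering; the primitive name is assumed for illustration only.
#
#   mlir.register_lowering(sharded_call_p, _sharded_jit_lowering)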