# NOTE: these rules rely on the surrounding module's imports (core, util, xla,
# xc, ir, std, mhlo, partial, ...); `map` here is jax.util.safe_map, which
# returns a list rather than an iterator.

def xla_fallback_lowering(prim: core.Primitive, ctx: LoweringContext,
                          avals_in, avals_out, *args, **params):
  # Build the XLA HLO subcomputation for this primitive and convert it into
  # an MLIR submodule.
  xla_computation = xla.primitive_subcomputation(ctx.platform, ctx.axis_env,
                                                 prim, *avals_in, **params)
  submodule_str = xc._xla.mlir.xla_computation_to_mlir_module(xla_computation)
  submodule = ir.Module.parse(submodule_str)

  # Splice the submodule's functions into the module under construction,
  # turning its "main" into a private function we can call.
  callee_name = None
  for op in submodule.body.operations:
    ctx.module.body.append(op)
    if op.name.value == "main":
      callee_name = ir.StringAttr(ctx.symbol_table.insert(op)).value
      op.attributes["sym_visibility"] = ir.StringAttr.get("private")
    else:
      ctx.symbol_table.insert(op)

  output_types = map(aval_to_ir_types, avals_out)
  flat_output_types = util.flatten(output_types)
  output_type = (ir.TupleType.get_tuple(flat_output_types)
                 if prim.multiple_results else flat_output_types[0])

  call = std.CallOp([output_type], ir.FlatSymbolRefAttr.get(callee_name),
                    flatten_lowering_ir_args(args)).result
  if not prim.multiple_results:
    return [call]
  # XLA packs multiple results into a tuple; unpack it into flat IR values.
  flat_results = [mhlo.GetTupleElementOp(typ, call, i32_attr(i)).result
                  for i, typ in enumerate(flat_output_types)]
  return util.unflatten(flat_results, map(len, output_types))
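# Hedged sketch, not part of the original source: the i32_attr helper used for
# the GetTupleElementOp indices above is assumed to build a 32-bit integer
# attribute along these lines.
def i32_attr(i):
  return ir.IntegerAttr.get(ir.IntegerType.get_signless(32), i)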
def _call_lowering(fn_name, stack_name, call_jaxpr, backend, ctx, avals_in,
                   avals_out, *args):
  xla.check_backend_matches(backend, ctx.platform)
  output_types = map(aval_to_ir_types, avals_out)
  flat_output_types = util.flatten(output_types)
  sub_ctx = ctx.replace(
      name_stack=xla.extend_name_stack(ctx.name_stack, stack_name))
  # Emit the callee as its own function and call it by symbol name.
  symbol_name = lower_jaxpr_to_fun(sub_ctx, fn_name,
                                   core.ClosedJaxpr(call_jaxpr, ()))
  call = std.CallOp(flat_output_types,
                    ir.FlatSymbolRefAttr.get(symbol_name),
                    flatten_lowering_ir_args(args))
  return util.unflatten(call.results, map(len, output_types))
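# Hedged sketch, not part of the original source: a call primitive carrying
# `name`, `backend`, and `call_jaxpr` params could delegate to _call_lowering
# like this, assuming the rule(ctx, avals_in, avals_out, *args, **params)
# signature used above and a register_lowering() helper for wiring rules to
# primitives. `my_call_p` is hypothetical.
def _generic_call_lowering(ctx, avals_in, avals_out, *args, name,
                           backend=None, call_jaxpr):
  return _call_lowering(name, name, call_jaxpr, backend, ctx, avals_in,
                        avals_out, *args)

register_lowering(my_call_p, _generic_call_lowering)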
def _sharded_jit_lowering(ctx, *in_nodes,
                          in_parts, out_parts_thunk, nparts,
                          name, call_jaxpr, local_in_parts,
                          local_out_parts_thunk, local_nparts):
  # We assume any extra leading in_nodes are constants and replicate them.
  num_extra_nodes = len(in_nodes) - len(in_parts)
  assert num_extra_nodes >= 0
  in_parts = (None,) * num_extra_nodes + in_parts

  # Attach sharding annotations to the inputs that have them.
  args = []
  for ns, sharding in safe_zip(
      safe_map(mlir.wrap_singleton_ir_values, in_nodes), in_parts):
    if sharding is not None:
      args.append(
          [mlir.wrap_with_sharding_op(n, xla.sharding_to_proto(sharding))
           for n in ns])
    else:
      args.append(ns)

  sub_ctx = ctx.module_context.replace(
      name_stack=extend_name_stack(wrap_name(name, "sharded_jit")))
  fn = mlir.lower_jaxpr_to_fun(sub_ctx, f"sharded_jit_{name}",
                               core.ClosedJaxpr(call_jaxpr, ()))

  output_types = safe_map(mlir.aval_to_ir_types, ctx.avals_out)
  flat_output_types = util.flatten(output_types)
  call = std.CallOp(flat_output_types,
                    ir.FlatSymbolRefAttr.get(fn.name.value),
                    mlir.flatten_lowering_ir_args(args))
  out_nodes = util.unflatten(call.results, safe_map(len, output_types))

  # Attach sharding annotations to the outputs that have them.
  out_parts = out_parts_thunk()
  outputs = []
  for ns, sharding in safe_zip(out_nodes, out_parts):
    if sharding is not None:
      outputs.append(
          [mlir.wrap_with_sharding_op(n, xla.sharding_to_proto(sharding))
           for n in ns])
    else:
      outputs.append(ns)
  return outputs
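# Hedged sketch, not part of the original source: this rule would be
# registered for sharded_jit's call primitive, here assumed to be named
# sharded_call_p.
mlir.register_lowering(sharded_call_p, _sharded_jit_lowering)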
def cache_lowering(f):
  """Wraps the lowering rule `f` so that its body is emitted once as an
  out-of-line function per (primitive, avals, params) key and reused via a
  call on later instances."""
  def cached_lowering(ctx, *args, **params):
    assert ctx.primitive is not None
    key = (ctx.primitive, tuple(ctx.avals_in), tuple(ctx.avals_out),
           tuple(params.items()))
    try:
      func = ctx.module_context.cached_primitive_lowerings.get(key)
    except TypeError:
      # If the parameters aren't hashable, give up on caching.
      # TODO(phawkins): switch to requiring hashability, when XLA fallback
      # computations have been ported to MHLO.
      return f(ctx, *args, **params)
    if func is None:
      func = _emit_lowering_rule_as_fun(partial(f, **params), ctx)
      ctx.module_context.cached_primitive_lowerings[key] = func

    output_types = map(aval_to_ir_types, ctx.avals_out)
    flat_output_types = util.flatten(output_types)
    call = std.CallOp(flat_output_types,
                      ir.FlatSymbolRefAttr.get(func.name.value),
                      flatten_lowering_ir_args(args))
    return util.unflatten(call.results, map(len, output_types))
  return cached_lowering
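# Hedged usage sketch, not part of the original source: cache_lowering acts as
# a decorator for lowering rules, so repeated instances of the same primitive
# with identical avals and params share one emitted function.
# `_my_prim_lowering` is hypothetical.
@cache_lowering
def _my_prim_lowering(ctx, *args, **params):
  ...  # emit MHLO ops for the primitive and return its output IR values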