def compile_or_get_cached(backend, computation, compile_options):
  # Avoid import cycle between jax and jax.experimental
  from jax.experimental.compilation_cache import compilation_cache as cc

  if isinstance(computation, ir.Module):
    sym_name = computation.operation.attributes['sym_name']
    module_name = ir.StringAttr(sym_name).value
    computation = mlir.module_to_string(computation)
  else:
    module_name = computation.name()

  # Persistent compilation cache only implemented on TPU.
  # TODO(skye): add warning when initializing cache on unsupported default platform
  if cc.is_initialized() and backend.platform == 'tpu':
    cached_executable = cc.get_executable(computation, compile_options, backend)
    if cached_executable is not None:
      logging.info('Persistent compilation cache hit for %s.', module_name)
      return cached_executable
    else:
      compiled = backend_compile(backend, computation, compile_options)
      cc.put_executable(module_name, computation, compile_options, compiled,
                        backend)
      return compiled

  if FLAGS.jax_dump_ir_to:
    ir_str = (computation if isinstance(computation, str)
              else computation.as_hlo_text())
    _dump_ir_to_file(module_name, ir_str)
  return backend_compile(backend, computation, compile_options)
def xla_fallback_lowering(prim: core.Primitive, ctx: LoweringContext,
                          avals_in, avals_out, *args, **params):
  xla_computation = xla.primitive_subcomputation(ctx.platform, ctx.axis_env,
                                                 prim, *avals_in, **params)
  submodule_str = xc._xla.mlir.xla_computation_to_mlir_module(xla_computation)
  submodule = ir.Module.parse(submodule_str)
  callee_name = None
  for op in submodule.body.operations:
    ctx.module.body.append(op)
    if op.name.value == "main":
      callee_name = ir.StringAttr(ctx.symbol_table.insert(op)).value
      op.attributes["sym_visibility"] = ir.StringAttr.get("private")
    else:
      ctx.symbol_table.insert(op)

  output_types = map(aval_to_ir_types, avals_out)
  flat_output_types = util.flatten(output_types)
  output_type = (ir.TupleType.get_tuple(flat_output_types)
                 if prim.multiple_results else flat_output_types[0])
  call = std.CallOp([output_type],
                    ir.FlatSymbolRefAttr.get(callee_name),
                    flatten_lowering_ir_args(args)).result
  if not prim.multiple_results:
    return [call]
  flat_results = [
      mhlo.GetTupleElementOp(typ, call, i32_attr(i)).result
      for i, typ in enumerate(flat_output_types)
  ]
  return util.unflatten(flat_results, map(len, output_types))
def fallback(ctx: LoweringRuleContext, *args, **params):
  module_ctx = ctx.module_context
  xla_computation = xla.primitive_subcomputation(module_ctx.platform,
                                                 module_ctx.axis_env, prim,
                                                 *ctx.avals_in, **params)
  submodule_str = xc._xla.mlir.xla_computation_to_mlir_module(xla_computation)
  submodule = ir.Module.parse(submodule_str)
  callee_name = None
  for op in submodule.body.operations:
    op = typing.cast(FuncOpType, op)
    module_ctx.module.body.append(op)
    if op.name.value == "main":
      op.attributes["sym_name"] = ir.StringAttr.get(
          f"xla_fallback_{prim.name}")
      callee_name = ir.StringAttr(
          module_ctx.symbol_table.insert(op)).value
      op.attributes["sym_visibility"] = ir.StringAttr.get("private")
    else:
      module_ctx.symbol_table.insert(op)

  output_types = map(aval_to_ir_types, ctx.avals_out)
  flat_output_types = util.flatten(output_types)
  output_type = (ir.TupleType.get_tuple(flat_output_types)
                 if prim.multiple_results else flat_output_types[0])
  call = func_dialect.CallOp([output_type],
                             ir.FlatSymbolRefAttr.get(callee_name),
                             flatten_lowering_ir_args(args)).result
  if not prim.multiple_results:
    return [call]
  if jax._src.lib.mlir_api_version < 6:
    flat_results = [
        mhlo.GetTupleElementOp(typ, call, i32_attr(i)).result
        for i, typ in enumerate(flat_output_types)
    ]
  else:
    flat_results = [
        mhlo.GetTupleElementOp(call, i32_attr(i)).result
        for i in range(len(flat_output_types))
    ]
  return util.unflatten(flat_results, map(len, output_types))
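# Hedged sketch, not part of the original source: `fallback` above closes over
# a free variable `prim`, so it is presumably created inside a per-primitive
# wrapper along these lines. The wrapper name below is an assumption for
# illustration, not the verbatim JAX API.
def _xla_fallback_lowering_sketch(prim):
  def fallback(ctx, *args, **params):
    ...  # body as in `fallback` above, closing over `prim`
  return fallback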
def compile_or_get_cached(backend, computation, compile_options):
  # Avoid import cycle between jax and jax.experimental
  from jax.experimental.compilation_cache import compilation_cache as cc

  if isinstance(computation, ir.Module):
    sym_name = computation.operation.attributes['sym_name']
    module_name = ir.StringAttr(sym_name).value
    # Convert ir.Module to str representation (the default), unless the
    # back-end explicitly flags the ability to handle a module directly
    # (avoiding the overhead of back and forth conversions)
    if getattr(backend, "needs_str_ir", True):
      computation = mlir.module_to_string(computation)
  else:
    module_name = computation.name()

  # Persistent compilation cache only implemented on TPU.
  # TODO(skye): add warning when initializing cache on unsupported default platform
  if cc.is_initialized() and backend.platform == 'tpu':
    cached_executable = cc.get_executable(computation, compile_options, backend)
    if cached_executable is not None:
      logging.info('Persistent compilation cache hit for %s.', module_name)
      return cached_executable
    else:
      compiled = backend_compile(backend, computation, compile_options)
      cc.put_executable(module_name, computation, compile_options, compiled,
                        backend)
      return compiled

  if FLAGS.jax_dump_ir_to:
    if isinstance(computation, xc.XlaComputation):
      ir_str = computation.as_hlo_text()
    elif isinstance(computation, ir.Module):
      ir_str = mlir.module_to_string(computation)
    else:
      assert isinstance(computation, str)
      ir_str = computation
    _dump_ir_to_file(module_name, ir_str)
  return backend_compile(backend, computation, compile_options)
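# Hedged sketch, not part of the original source: illustrates the
# `needs_str_ir` opt-out that compile_or_get_cached above reads via
# getattr(backend, "needs_str_ir", True). Both backend classes here are
# hypothetical stand-ins for illustration only.
class _DirectModuleBackend:
  # Declares that it can consume an ir.Module directly, so the
  # mlir.module_to_string conversion is skipped.
  platform = 'tpu'
  needs_str_ir = False

class _StringIRBackend:
  # No attribute: the getattr default (True) applies and the module is
  # converted to its string representation before compilation.
  platform = 'cpu'

assert getattr(_DirectModuleBackend(), "needs_str_ir", True) is False
assert getattr(_StringIRBackend(), "needs_str_ir", True) is True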
def lower_jaxpr_to_fun(ctx: LoweringContext, name: str,
                       jaxpr: core.ClosedJaxpr, *,
                       public: bool = False,
                       replace_units_with_dummy: bool = False,
                       replace_tokens_with_dummy: bool = False) -> str:
  """Lowers jaxpr and its callees to an IR function.

  Assumes that an MLIR context, location, and insertion point are set.

  Args:
    ctx: the lowering context.
    name: the function name. The name will be uniquified by the symbol table,
      so it is ok to use the same name multiple times.
    jaxpr: the jaxpr to lower.
    public: if true, the function's visibility is set to "public".
    replace_units_with_dummy: if true, unit arguments/return values are
      replaced with bool arrays of size [0].
    replace_tokens_with_dummy: if true, token arguments/return values are
      replaced with bool arrays of size [0].

  Returns the name of the function.
  """
  def aval_to_types(aval):
    if replace_units_with_dummy and aval is core.abstract_unit:
      aval = core.ShapedArray((), np.dtype(np.bool_))
    elif replace_tokens_with_dummy and aval is core.abstract_token:
      aval = core.ShapedArray((), np.dtype(np.bool_))
    return aval_to_ir_types(aval)

  input_types = map(aval_to_types, jaxpr.in_avals)
  output_types = map(aval_to_types, jaxpr.out_avals)
  flat_input_types = util.flatten(input_types)
  flat_output_types = util.flatten(output_types)
  ftype = ir.FunctionType.get(flat_input_types, flat_output_types)
  func_op = builtin.FuncOp(name, ftype, ip=ctx.ip)
  func_op.attributes["sym_visibility"] = ir.StringAttr.get(
      "public" if public else "private")
  symbol_name = ir.StringAttr(ctx.symbol_table.insert(func_op)).value
  entry_block = func_op.add_entry_block()
  with ir.InsertionPoint(entry_block):
    unflattened_args = util.unflatten(entry_block.arguments,
                                      map(len, input_types))
    args: List[List[ir.Value]] = []
    for aval, arg in zip(jaxpr.in_avals, unflattened_args):
      if replace_units_with_dummy and aval is core.abstract_unit:
        args.append([])
      elif replace_tokens_with_dummy and aval is core.abstract_token:
        args.append(mhlo.CreateTokenOp(mhlo.TokenType.get()).results)
      else:
        args.append(arg)
    callee_name_stack = xla.extend_name_stack(ctx.name_stack,
                                              xla.wrap_name(name, 'jit'))
    out_vals = jaxpr_subcomp(ctx.replace(name_stack=callee_name_stack),
                             jaxpr.jaxpr, map(ir_constants, jaxpr.consts),
                             *args)
    outs = []
    for aval, out in zip(jaxpr.out_avals, out_vals):
      if replace_units_with_dummy and aval is core.abstract_unit:
        outs.append(ir_constants(np.zeros((), np.bool_)))
      elif replace_tokens_with_dummy and aval is core.abstract_token:
        outs.append(ir_constants(np.zeros((), np.bool_)))
      else:
        outs.append(out)
    std.ReturnOp(util.flatten(outs))

  return symbol_name
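# Hedged sketch, not part of the original source: lower_jaxpr_to_fun assumes an
# ambient MLIR context, location, and insertion point. A caller might establish
# them roughly like this, using the same `ir` bindings as the code above; the
# helper name is hypothetical and the LoweringContext construction is omitted
# because it is JAX-internal.
def _lowering_environment_sketch():
  with ir.Context() as context, ir.Location.unknown(context):
    module = ir.Module.create()
    ip = ir.InsertionPoint(module.body)
    # A LoweringContext carrying `module` and `ip` would then be passed to
    # lower_jaxpr_to_fun(ctx, "main", closed_jaxpr, public=True).
    return module, ip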