def compile(self, imported_module: Module):
    """Compiles an imported module with the RefJIT backend.

    Runs the frontend passes (FRONTEND_PASSES), then the RefJIT backend
    compilation pipeline, and finally JIT-compiles the lowered module.

    Args:
      imported_module: The MLIR module consisting of funcs in the torch
        dialect. It is lowered in place by both pass pipelines.

    Returns:
      An opaque, backend specific module object that can be passed to load.
      Here it is the object returned by
      `self._refjit.JITModule.from_compiled_module`; the contract is that it
      is only operated on by methods on this class.
    """
    # Use the module's own MLIR context. PassManager.parse()/PassManager()
    # bind to the current default context, and a pass manager may only run
    # on ops from that same context; a fresh Context() would not match the
    # context the imported module was created in.
    with imported_module.context:
        # Frontend.
        pm = PassManager.parse(",".join(FRONTEND_PASSES))
        pm.run(imported_module)
        if self._debug:
            logging.debug("Frontend IR:\n{}", imported_module)

        # Backend.
        # Note that this is a separate pass manager purely to aid in debugging.
        pm = PassManager()
        self._refjit.build_backend_compilation_pipeline(pm)
        pm.run(imported_module)
        if self._debug:
            logging.debug("Backend IR:\n{}", imported_module)

        jit_module = self._refjit.JITModule.from_compiled_module(
            imported_module, refjit_backend.get_runtime_libs())
    return jit_module
def compile(self, imported_module: Module):
    """Compile a module holding a flat list of functions with RefJIT.

    The incoming module must already be in linalg-on-tensors + scalar code
    form.

    TODO: More clearly define the backend contract. Generally this will
    extend to support globals, lists, and other stuff.

    Args:
      imported_module: The MLIR module consisting of funcs in the torch
        dialect. The RefJIT backend compilation pipeline lowers it in place.

    Returns:
      The object produced by `self._refjit.JITModule.from_compiled_module`.
      Treat it as an opaque, backend-specific value meant to be passed to
      `load`; the contract is that only methods on this class operate on it.
    """
    with imported_module.context:
        if self._debug:
            logging.debug("IR passed to RefJIT compiler backend:\n{}",
                          imported_module)

        # The backend gets its own pass manager purely so that it can be
        # debugged in isolation.
        backend_pm = PassManager()
        self._refjit.build_backend_compilation_pipeline(backend_pm)
        backend_pm.run(imported_module)

        if self._debug:
            logging.debug(
                "RefBackend input IR (this is what the RefBackend compiler sees):\n{}",
                imported_module)

        runtime_libs = refjit_backend.get_runtime_libs()
        return self._refjit.JITModule.from_compiled_module(
            imported_module, runtime_libs)
def compile(self, imported_module: Module):
    """Compile a backend-contract-conforming module with IREE.

    The module is expected to conform to the npcomp backend contract; the
    VerifyBackendContract pass describes that contract in detail.

    Args:
      imported_module: The MLIR module consisting of funcs in the torch
        dialect. The IREE-preparation pipeline rewrites it in place.

    Returns:
      The result of `ireec.compile_str` for the "dylib-llvm-aot" target.
      Treat it as an opaque, backend-specific value meant to be passed to
      `load` (for IREE, a serialized VM flatbuffer); only methods on this
      class are expected to operate on it.
    """
    prepare_pipeline = "npcomp-backend-to-iree-frontend-pipeline"
    with imported_module.context:
        if self._debug:
            logging.debug("IR passed to IREE compiler backend:\n{}",
                          imported_module)
            logging.debug("Running Prepare For IREE pipeline '{}'",
                          prepare_pipeline)

        PassManager.parse(prepare_pipeline).run(imported_module)

        if self._debug:
            logging.debug(
                "IREE Input IR (this is what IREE's compiler will see):\n{}",
                imported_module)

        # Backend: IREE consumes the textual assembly of the prepared module.
        module_asm = str(imported_module)
        return ireec.compile_str(module_asm,
                                 target_backends=["dylib-llvm-aot"])
def compile(self, imported_module: Module):
    """Compiles an imported module with the RefJIT backend.

    Verifies the module, runs the frontend passes (FRONTEND_PASSES), then the
    RefJIT backend compilation pipeline, and finally JIT-compiles the result.

    Args:
      imported_module: The MLIR module as imported from the ImportFrontend.
        It is lowered in place by both pass pipelines.

    Returns:
      An opaque, backend specific module object that can be passed to load.
      Here it is the object returned by
      `self._refjit.JITModule.from_compiled_module`; the contract is that it
      is only operated on by methods on this class.

    Raises:
      AssertionError: If the imported module does not verify.
    """
    with imported_module.context:
        # Frontend.
        if self._debug:
            logging.debug("Input IR:\n{}", imported_module)
        # Explicit check instead of `assert` so verification is not stripped
        # under `python -O`; AssertionError is kept for existing callers.
        if not imported_module.operation.verify():
            raise AssertionError("Imported module does not verify")
        pm = PassManager.parse(",".join(FRONTEND_PASSES))
        pm.run(imported_module)
        if self._debug:
            logging.debug("Frontend IR:\n{}", imported_module)

        # Backend.
        # Note that this is a separate pass manager purely to aid in debugging.
        pm = PassManager()
        self._refjit.build_backend_compilation_pipeline(pm)
        pm.run(imported_module)
        if self._debug:
            logging.debug("Backend IR:\n{}", imported_module)

        jit_module = self._refjit.JITModule.from_compiled_module(
            imported_module, refjit_backend.get_runtime_libs())
    return jit_module
def lower_object_graph(imported_module: Module):
    """Lower a module with TorchScript object graph semantics, in place.

    Args:
      imported_module: The MLIR module consisting of IR as imported by
        torch_mlir.import_module. The lowering pipeline mutates it in place.

    Returns:
      The same `imported_module` object, so calls can be chained.
    """
    object_graph_pipeline = "torchscript-to-npcomp-backend-pipeline"
    with imported_module.context:
        if logging.debug_enabled():
            logging.debug("Initial PyTorch object graph IR:\n{}",
                          imported_module)
            logging.debug("Running Torch object graph lowering pipeline '{}'",
                          object_graph_pipeline)
        # Object graph lowering.
        PassManager.parse(object_graph_pipeline).run(imported_module)
    return imported_module
def compile(self, imported_module: Module):
    """Compiles an imported module to IREE VM bytecode.

    Verifies the module, runs the frontend passes (FRONTEND_PASSES), then
    pipes the module's bytecode into the IREE translate tool with
    `--iree-mlir-to-vm-bytecode-module`.

    Args:
      imported_module: The MLIR module as imported from the ImportFrontend.
        It is lowered in place by the frontend passes.

    Returns:
      The bytes the IREE translate tool wrote to stdout. It is an opaque,
      backend specific object that can be passed to load (for IREE, a
      serialized VM flatbuffer); the contract is that it is only operated on
      by methods on this class.

    Raises:
      AssertionError: If the imported module does not verify.
      subprocess.CalledProcessError: If the translate tool exits with a
        non-zero status (previously this silently returned its stdout).
    """
    import io

    with imported_module.context:
        if self._debug:
            logging.debug("Input IR:\n{}", imported_module)
        # Explicit check instead of `assert` so verification is not stripped
        # under `python -O`; AssertionError is kept for existing callers.
        if not imported_module.operation.verify():
            raise AssertionError("Imported module does not verify")

        # Frontend.
        pm = PassManager.parse(",".join(FRONTEND_PASSES))
        pm.run(imported_module)
        if self._debug:
            logging.debug("Frontend IR:\n{}", imported_module)

        # TODO: There should be some common utility for invoking backend processes
        # safely (and have options like saving temps, etc).
        # Serialize to memory first; subprocess.run(input=...) then feeds stdin
        # while draining stdout, so the pipes cannot fill up and block.
        module_bytes = io.BytesIO()
        imported_module.operation.print(binary=True,
                                        enable_debug_info=True,
                                        file=module_bytes)
        args = [
            iree_backend.get_translate_exe(),
            "--iree-mlir-to-vm-bytecode-module"
        ]
        # check=True: a failed translation must not be returned as if it were
        # a valid (possibly empty) flatbuffer.
        completed = subprocess.run(args,
                                   input=module_bytes.getvalue(),
                                   stdout=subprocess.PIPE,
                                   check=True)
    return completed.stdout
def lower_module(imported_module: Module):
    """Lower a module holding a flat list of functions, in place.

    Args:
      imported_module: The MLIR module consisting of funcs and globals in the
        torch dialect. The lowering pipeline mutates it in place.

    Returns:
      The same `imported_module` object, so calls can be chained.
    """
    backend_pipeline = "torch-globalized-module-to-npcomp-backend-pipeline"
    with imported_module.context:
        if logging.debug_enabled():
            logging.debug("Initial PyTorch IR:\n{}", imported_module)
            logging.debug("Running Torch->backend pipeline '{}'",
                          backend_pipeline)
        # Frontend.
        PassManager.parse(backend_pipeline).run(imported_module)
        if logging.debug_enabled():
            logging.debug("Backend IR:\n{}", imported_module)
    return imported_module
import torch
import torch_mlir

import npcomp
from npcomp.compiler.pytorch.backend import frontend_lowering, iree, refjit
from npcomp.compiler.utils import logging

import test_utils

# End-to-end check: capture torch.cos with ModuleBuilder, lower it, compile
# it with the IREE npcomp backend, and compare against eager PyTorch.
# `iree` is imported above because the backend below is constructed from it;
# without that import the script fails with a NameError.

logging.enable()

torch.manual_seed(0)
# Named to avoid shadowing the builtin `input`.
input_tensor = torch.rand(2, 3)

mb = torch_mlir.ModuleBuilder()
with mb.capture_function("cos", [input_tensor]) as f:
    result = torch.cos(input_tensor)
    f.returns([result])

backend = iree.IreeNpcompBackend()
jit_module = backend.load(
    backend.compile(frontend_lowering.lower_module(mb.module)))

logging.debug("Executing jit_module.cos")
test_utils.compare_outputs(torch.cos, jit_module.cos, input_tensor)

# This fails because ModuleBuilder represents torch.cos with a constant:
# https://github.com/llvm/mlir-npcomp/issues/135
test_utils.compare_outputs(torch.cos, jit_module.cos, input_tensor + 1)