Example #1
    def compile(self, imported_module: Module):
        """Compiles an imported module.

    Args:
      imported_module: The MLIR module consisting of funcs in the torch
        dialect.
    Returns:
      An opaque, backend-specific module object that can be passed to load.
      The object may actually be something more specific to the backend (e.g.
      for IREE, it is a serialized VM flatbuffer) but the contract is that
      it is operated on by methods on this class.
    """
        # TODO: Once transitioned to new Python API, don't reparse the module.
        with Context() as context:
            # Frontend.
            pm = PassManager.parse(",".join(FRONTEND_PASSES))
            pm.run(imported_module)
            if self._debug:
                logging.debug("Frontend IR:{}", imported_module)

            # Backend.
            # Note that this is a separate pass manager purely to aid in debugging.
            pm = PassManager()
            self._refjit.build_backend_compilation_pipeline(pm)
            pm.run(imported_module)
            if self._debug:
                logging.debug("Backend IR:{}", imported_module)

        jit_module = self._refjit.JITModule.from_compiled_module(
            imported_module, refjit_backend.get_runtime_libs())
        return jit_module
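
For context, a minimal usage sketch of the compile()/load() contract described in the docstring. The backend class name below is a placeholder (anything not shown in the example above is an assumption); Example #8 shows a real end-to-end run.

# Sketch only: `NpcompBackend` stands in for whichever concrete backend class
# defines the compile() above (the RefJIT backend in this example).
backend = NpcompBackend()                     # placeholder class name
compiled = backend.compile(imported_module)   # opaque, backend-specific object
jit_module = backend.load(compiled)           # the matching load() (not shown)
# Functions of the original MLIR module are now callable on jit_module.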
Example #2
    def compile(self, imported_module: Module):
        """Compiles an imported module, with a flat list of functions.
    The module is expected to be in linalg-on-tensors + scalar code form.
    TODO: More clearly define the backend contract. Generally this will
    extend to support globals, lists, and other stuff.

    Args:
      imported_module: The MLIR module consisting of funcs in the torch
        dialect.
    Returns:
      An opaque, backend-specific module object that can be passed to load.
      The object may actually be something more specific to the backend (e.g.
      for IREE, it is a serialized VM flatbuffer) but the contract is that
      it is operated on by methods on this class.
    """
        with imported_module.context as context:
            if self._debug:
                logging.debug("IR passed to RefJIT compiler backend:\n{}",
                              imported_module)
            # Backend.
            # Note that this is a separate pass manager purely to aid in debugging.
            pm = PassManager()
            self._refjit.build_backend_compilation_pipeline(pm)
            pm.run(imported_module)
            if self._debug:
                logging.debug(
                    "RefBackend input IR (this is what the RefBackend compiler sees):\n{}",
                    imported_module)

        jit_module = self._refjit.JITModule.from_compiled_module(
            imported_module, refjit_backend.get_runtime_libs())
        return jit_module
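
Because this variant expects IR already in linalg-on-tensors + scalar form, a quick pre-flight check can make a bad input obvious before compile() is called. The sketch below uses the generic IR-walking API of the MLIR Python bindings and only looks one region level deep; treat it as an illustration, not part of the backend.

# Sketch: collect the dialect prefixes of the ops directly inside each function
# so a module that was never lowered (e.g. still full of torch.* ops) stands out.
def dialects_in(module):
    seen = set()
    for func in module.body.operations:
        for region in func.regions:
            for block in region.blocks:
                for op in block.operations:
                    # Both Operation and OpView expose the full "dialect.op" name.
                    name = getattr(op, "operation", op).name
                    seen.add(name.split(".", 1)[0])
    return seen

print(dialects_in(imported_module))   # expect linalg/std/etc., not torch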
Example #3
    def compile(self, imported_module: Module):
        """Compiles an imported module, with a flat list of functions.
    The module is expected to conform to the npcomp backend contract.
    See the VerifyBackendContract pass for more details.

    Args:
      imported_module: The MLIR module consisting of funcs in the torch
        dialect.
    Returns:
      An opaque, backend-specific module object that can be passed to load.
      The object may actually be something more specific to the backend (e.g.
      for IREE, it is a serialized VM flatbuffer) but the contract is that
      it is operated on by methods on this class.
    """
        with imported_module.context as context:
            if self._debug:
                logging.debug("IR passed to IREE compiler backend:\n{}",
                              imported_module)
            pipeline_str = "npcomp-backend-to-iree-frontend-pipeline"
            if self._debug:
                logging.debug("Running Prepare For IREE pipeline '{}'",
                              pipeline_str)
            pm = PassManager.parse(pipeline_str)
            pm.run(imported_module)
            if self._debug:
                logging.debug(
                    "IREE Input IR (this is what IREE's compiler will see):\n{}",
                    imported_module)

            # Backend.
            binary = ireec.compile_str(str(imported_module),
                                       target_backends=["dylib-llvm-aot"])
        return binary
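
The bytes returned here are a complete IREE VM flatbuffer, so besides handing them to this backend's load(), it is natural to persist them for later inspection or reuse. A short sketch (the output path is illustrative):

# Sketch: compile once, keep the artifact on disk, and/or load it right away.
vmfb = backend.compile(imported_module)   # `backend` is this IREE backend instance
with open("/tmp/module.vmfb", "wb") as f:
    f.write(vmfb)
jit_module = backend.load(vmfb)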
Example #4
    def compile(self, imported_module: Module):
        """Compiles an imported module.

    Args:
      imported_module: The MLIR module as imported from the
        ImportFrontend.
    Returns:
      An opaque, backend-specific module object that can be passed to load.
      The object may actually be something more specific to the backend (e.g.
      for IREE, it is a serialized VM flatbuffer) but the contract is that
      it is operated on by methods on this class.
    """
        with imported_module.context as context:
            # Frontend.
            if self._debug:
                logging.debug("Input IR:\n{}", imported_module)
            assert (imported_module.operation.verify()
                    ), "Imported module does not verify"
            pm = PassManager.parse(",".join(FRONTEND_PASSES))
            pm.run(imported_module)
            if self._debug:
                logging.debug("Frontend IR:\n{}", imported_module)

            # Backend.
            # Note that this is a separate pass manager purely to aid in debugging.
            pm = PassManager()
            self._refjit.build_backend_compilation_pipeline(pm)
            pm.run(imported_module)
            if self._debug:
                logging.debug("Backend IR:\n{}", imported_module)

        jit_module = self._refjit.JITModule.from_compiled_module(
            imported_module, refjit_backend.get_runtime_libs())
        return jit_module
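
The bare assert on verify() above gives little to go on when it fires. A hedged alternative sketch that dumps the offending IR before failing:

# Sketch: same verification gate, but print the bad IR so the importer bug can
# actually be diagnosed (verify() is treated as boolean, as in the code above).
import sys

if not imported_module.operation.verify():
    print(imported_module, file=sys.stderr)
    raise ValueError("Imported module does not verify; see IR dumped above")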
Example #5
def lower_object_graph(imported_module: Module):
    """Lowers an imported module that has TorchScript object graph semantics.

    Args:
        imported_module: The MLIR module consisting of IR as imported by the
        torch_mlir.import_module. It is lowered in place.
    Returns:
        The imported_module, for convenience chaining methods.
    """
    with imported_module.context as context:
        if logging.debug_enabled():
            logging.debug("Initial PyTorch object graph IR:\n{}",
                          imported_module)

        # Object graph lowering.
        pipeline_str = "torchscript-to-npcomp-backend-pipeline"
        if logging.debug_enabled():
            logging.debug("Running Torch object graph lowering pipeline '{}'",
                          pipeline_str)
        pm = PassManager.parse(pipeline_str)
        pm.run(imported_module)
    return imported_module
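
For reference, a hedged sketch of how a module with TorchScript object graph semantics reaches this function in the first place; the exact import_module() call signature is an assumption based on the docstring above.

# Sketch: script a module, import it through torch_mlir, then lower in place.
import torch
import torch_mlir

class Cosine(torch.nn.Module):
    def forward(self, x):
        return torch.cos(x)

mb = torch_mlir.ModuleBuilder()
mb.import_module(torch.jit.script(Cosine())._c)   # assumed signature
lowered = lower_object_graph(mb.module)           # lowered in place and returned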
Example #6
    def compile(self, imported_module: Module):
        """Compiles an imported module.

    Args:
      imported_module: The MLIR module as imported from the ImportFrontend.
    Returns:
      An opaque, backend-specific module object that can be passed to load.
      The object may actually be something more specific to the backend (e.g.
      for IREE, it is a serialized VM flatbuffer) but the contract is that
      it is operated on by methods on this class.
    """
        with imported_module.context:
            # Frontend.
            if self._debug:
                logging.debug("Input IR:\n{}", imported_module)
            assert (imported_module.operation.verify()
                    ), "Imported module does not verify"
            pm = PassManager.parse(",".join(FRONTEND_PASSES))
            pm.run(imported_module)
            if self._debug:
                logging.debug("Frontend IR:{}", imported_module)

        # TODO: There should be some common utility for invoking backend processes
        # safely (and have options like saving temps, etc).
        args = [
            iree_backend.get_translate_exe(),
            "--iree-mlir-to-vm-bytecode-module"
        ]
        p = subprocess.Popen(args,
                             stdin=subprocess.PIPE,
                             stdout=subprocess.PIPE)
        imported_module.operation.print(binary=True,
                                        enable_debug_info=True,
                                        file=p.stdin)
        out, err = p.communicate()
        return out
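
A hedged sketch of the "invoke backend processes safely" TODO above: go through a temp file that can optionally be kept, and fail loudly on a non-zero exit code. Only the translate flag is taken from the code above; everything else is illustrative.

# Sketch: file-based, checked invocation of the translate tool.
import subprocess
import tempfile

def translate_to_vm_bytecode(imported_module, translate_exe, save_temps=False):
    with tempfile.NamedTemporaryFile(suffix=".mlir", delete=not save_temps) as tmp:
        tmp.write(str(imported_module).encode())
        tmp.flush()
        result = subprocess.run(
            [translate_exe, "--iree-mlir-to-vm-bytecode-module", tmp.name],
            stdout=subprocess.PIPE,
            check=True)   # raises CalledProcessError on failure
    return result.stdout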
Example #7
def lower_module(imported_module: Module):
    """Compiles an imported module, with a flat list of functions.

    Args:
        imported_module: The MLIR module consisting of funcs and globals in
        the torch dialect. It is lowered in place.
    Returns:
        The imported_module, for convenience chaining methods.
    """
    with imported_module.context as context:
        if logging.debug_enabled():
            logging.debug("Initial PyTorch IR:\n{}", imported_module)
        # Frontend.
        pipeline_str = "torch-globalized-module-to-npcomp-backend-pipeline"
        if logging.debug_enabled():
            logging.debug("Running Torch->backend pipeline '{}'", pipeline_str)
        pm = PassManager.parse(pipeline_str)
        pm.run(imported_module)
        if logging.debug_enabled():
            logging.debug("Backend IR:\n{}", imported_module)
    return imported_module
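
The "Initial PyTorch IR" and "Backend IR" dumps in lower_module only appear once npcomp's debug logging is switched on, which Example #8 below does with logging.enable(). A minimal sketch (names as in that example):

# Sketch: enable debug logging, then lower a captured module in place.
from npcomp.compiler.utils import logging
logging.enable()
lowered = lower_module(mb.module)   # `mb` built as in Example #8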
Example #8
import torch
import torch_mlir

import npcomp
from npcomp.compiler.pytorch.backend import frontend_lowering, iree
from npcomp.compiler.utils import logging

import test_utils

logging.enable()

torch.manual_seed(0)
input = torch.rand(2, 3)

mb = torch_mlir.ModuleBuilder()
with mb.capture_function("cos", [input]) as f:
    result = torch.cos(input)
    f.returns([result])

backend = iree.IreeNpcompBackend()
jit_module = backend.load(
    backend.compile(frontend_lowering.lower_module(mb.module)))

logging.debug(f"Executing jit_module.cos")
test_utils.compare_outputs(torch.cos, jit_module.cos, input)

# This fails because ModuleBuilder represents torch.cos with a constant:
#   https://github.com/llvm/mlir-npcomp/issues/135
test_utils.compare_outputs(torch.cos, jit_module.cos, input + 1)
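
test_utils.compare_outputs is a small repo-local helper that is not shown here. A hedged sketch of the kind of check it performs; the calling convention of the jitted function (ndarrays in, ndarray out) is an assumption.

# Sketch: run the reference op and the jitted function, compare numerically.
import numpy as np

def compare_outputs_sketch(torch_fn, jit_fn, *args):
    expected = torch_fn(*args).numpy()
    actual = jit_fn(*[a.numpy() for a in args])   # assumed ndarray interface
    assert np.allclose(expected, actual, rtol=1e-5, atol=1e-6), "outputs differ"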