Example #1
0
def main(argv=None):
    """Load a TorchScript `.pt` file and optionally dump it and/or import it.

    Command-line arguments:
      PT_FILE          path passed to `torch.jit.load` (with map_location="cpu").
      --dump           dump the loaded module via `module._c.dump` with
                       code=True, attrs=False, params=False.
      --import         import the module with `torch_mlir.ModuleBuilder` and
                       print the resulting MLIR module.
      --exported-name  dotted path to export, e.g. `my.submodule.forward`;
                       may be repeated. When given, everything is first
                       un-exported and then only the listed paths are
                       exported; when omitted, the annotator is left untouched.

    Args:
        argv: optional list of argument strings to parse. None (the default)
            makes argparse parse the process command line.
    """
    parser = argparse.ArgumentParser(description="Utility for .pt files")
    parser.add_argument("pt_file",
                        metavar="PT_FILE",
                        type=str,
                        help="the .pt file to import")
    parser.add_argument("--dump",
                        action="store_true",
                        help="dump the pytorch module")
    parser.add_argument("--import",
                        action="store_true",
                        help="import the pytorch module")
    parser.add_argument("--exported-name",
                        action="append",
                        help="""
Name to export, such as `my.submodule.forward`(default = export all).
Can pass repeatedly.
""")
    args = parser.parse_args(argv)
    # TODO: Investigate why "cpu" is needed.
    module = torch.jit.load(args.pt_file, map_location="cpu")

    if args.dump:
        module._c.dump(code=True, attrs=False, params=False)

    # `import` is a Python keyword, so the parsed flag can only be read via
    # getattr (`args.import` is a syntax error).
    if getattr(args, "import", False):
        class_type = module._c._type()
        class_annotator = torch_mlir.ClassAnnotator()
        if args.exported_name is not None:
            # Un-export everything, then opt the requested paths back in.
            class_annotator.exportNone(class_type)
            for name in args.exported_name:
                # exportPath takes the root class type first, then the list
                # of path components (same order as the annotator test's
                # `exportPath(class_type, ['s', 'forward'])`).
                class_annotator.exportPath(class_type, name.split("."))
        mb = torch_mlir.ModuleBuilder()
        mb.import_module(module._c, class_annotator)
        mb.module.operation.print(large_elements_limit=16)


class TestModule(torch.nn.Module):
    """Root-module fixture for the ClassAnnotator test: holds one submodule `s`.

    `Submodule` is defined elsewhere in the file (not visible here); presumably
    it provides a zero-argument `forward` and an `exported` member, since both
    `s.forward` and `s.exported` are exported by path below — TODO confirm.
    """

    def __init__(self):
        super().__init__()
        # The submodule reached by the ['s', ...] export paths.
        self.s = Submodule()

    # `tensor` and `value_tensor` are not read here. They appear to exist so
    # that `annotateArgs(class_type, ['forward'], ...)` has argument slots to
    # annotate (the annotation list has three entries: presumably `self`,
    # `tensor`, `value_tensor`). Kept as `#` comments rather than a docstring
    # because this method is compiled by TorchScript.
    def forward(self, tensor, value_tensor):
        return self.s.forward()


# Script the fixture so the annotator can operate on its TorchScript class type.
test_module = TestModule()
recursivescriptmodule = torch.jit.script(test_module)

annotator = torch_mlir.ClassAnnotator()
# The root ClassType that all export paths and annotations are relative to.
class_type = recursivescriptmodule._c._type()

# Un-export everything, then export two members of submodule `s` by path
# (paths are lists of attribute names starting at the root class type).
annotator.exportNone(class_type)
annotator.exportPath(class_type, ['s', 'exported'])
annotator.exportPath(class_type, ['s', 'forward'])
# Annotate the root `forward`'s arguments, one list entry per argument in
# order. The leading `None` presumably stands for `self` (left unannotated).
# Each remaining entry looks like (shape, dtype, flag); -1 in a shape
# presumably marks a dimension of unknown size. The `True` flag lines up with
# the parameter named `value_tensor`, which suggests it marks value
# semantics — NOTE(review): confirm against the ClassAnnotator API.
annotator.annotateArgs(class_type, ['forward'], [
    None,
    ((1024, 2), torch.float32, False),
    ((42, -1, 7), torch.int8, True),
])

# "Change detector" test + "documentation" for the repr of `ClassAnnotator`.
# This is semi-load-bearing because users interact with this class and repr
# will show up in error messages, so should be pretty readable.
#
    def compile(self, program: torch.nn.Module) -> Any:
        mb = torch_mlir.ModuleBuilder()
        scripted = torch.jit.script(program)
        class_annotator = torch_mlir.ClassAnnotator()

        extract_annotations(program, scripted, class_annotator)

        # TODO: Find a way to make each of these calls own its own
        # "debuggable error report" situation.
        try:
            sys.stderr = StringIO()
            # Import the TorchScript module to MLIR
            mb.import_module(scripted._c, class_annotator)
        except Exception as e:
            raise Exception(f"""
PyTorch TorchScript module -> NPCOMP Object Graph IR import failed with:
Exception:
{e}
Diagnostics:
{sys.stderr.getvalue()}
""") from None
        finally:
            sys.stderr = sys.__stderr__

        try:
            sys.stderr = StringIO()
            asm_for_error_report = mb.module.operation.get_asm(
                large_elements_limit=10, enable_debug_info=True)
            pipeline_str = "torchscript-to-npcomp-backend-pipeline"
            # Lower module in place to make it ready for compiler backends.
            with mb.module.context:
                pm = PassManager.parse(pipeline_str)
                pm.run(mb.module)
        except Exception as e:
            # TODO: More robust.
            # - don't arbitrarily clutter up /tmp. When a test suite has many
            #   tests, this can be a big disk cost (also, /tmp/ is frequently a
            #   RAM fs, which increases worries about capacity).
            # - don't have colliding filenames (hard to do without cluttering
            #   up /tmp)
            # - if we do have have colliding filenames, writes should at least
            #   avoid being racy.
            filename = os.path.join(tempfile.gettempdir(),
                                    scripted.original_name + '.mlir')
            with open(filename, 'w') as f:
                f.write(asm_for_error_report)
            raise Exception(f"""
NPCOMP TorchScript Object Graph IR -> NPCOMP Backend IR lowering failed with the following diagnostics:
{sys.stderr.getvalue()}

Error can be reproduced with:
$ npcomp-opt -{pipeline_str} {filename}
""") from None
        finally:
            sys.stderr = sys.__stderr__
        try:
            sys.stderr = StringIO()
            asm_for_error_report = mb.module.operation.get_asm(
                large_elements_limit=10, enable_debug_info=True)
            return self.backend.compile(mb.module)
        except Exception as e:
            filename = os.path.join(tempfile.gettempdir(),
                                    scripted.original_name + '.mlir')
            with open(filename, 'w') as f:
                f.write(asm_for_error_report)
            raise Exception(f"""
NPCOMP Backend lowering for {self.backend.__class__.__name__} failed with the following diagnostics:
## Exception:
{e}

## Stderr:
{sys.stderr.getvalue()}

## Input IR has been saved in {filename}
""") from None
        finally:
            sys.stderr = sys.__stderr__