import argparse

import torch
import torch_mlir


def main():
    parser = argparse.ArgumentParser(description="Utility for .pt files")
    parser.add_argument("pt_file", metavar="PT_FILE", type=str,
                        help="the .pt file to import")
    parser.add_argument("--dump", action="store_true",
                        help="dump the pytorch module")
    parser.add_argument("--import", action="store_true",
                        help="import the pytorch module")
    parser.add_argument("--exported-name", action="append",
                        help="""
Name to export, such as `my.submodule.forward` (default = export all).
Can be passed repeatedly.
""")
    args = parser.parse_args()

    # TODO: Investigate why "cpu" is needed.
    module = torch.jit.load(args.pt_file, map_location="cpu")

    if args.dump:
        module._c.dump(code=True, attrs=False, params=False)

    # `import` is a Python keyword, so getattr is needed.
    if getattr(args, "import", False):
        class_annotator = torch_mlir.ClassAnnotator()
        if args.exported_name is not None:
            # Start from "export nothing", then export only the requested paths.
            class_annotator.exportNone(module._c._type())
            for name in args.exported_name:
                class_annotator.exportPath(name.split("."), module._c._type())
        mb = torch_mlir.ModuleBuilder()
        mb.import_module(module._c, class_annotator)
        mb.module.operation.print(large_elements_limit=16)
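# Standard script entry point, plus example invocations of the CLI defined
# above. The filename `pt_util.py` used below is an assumption for
# illustration, not something this excerpt specifies:
#
#   python pt_util.py model.pt --dump
#   python pt_util.py model.pt --import --exported-name my.submodule.forward
if __name__ == "__main__":
    main()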
import torch
import torch_mlir


class Submodule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # The `exported` attribute is assumed here: the excerpt truncates this
        # class, but the annotator below exports the path ['s', 'exported'].
        self.exported = 1

    def forward(self):
        return


class TestModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.s = Submodule()

    def forward(self, tensor, value_tensor):
        return self.s.forward()


test_module = TestModule()
recursivescriptmodule = torch.jit.script(test_module)

annotator = torch_mlir.ClassAnnotator()
class_type = recursivescriptmodule._c._type()
annotator.exportNone(class_type)
annotator.exportPath(class_type, ['s', 'exported'])
annotator.exportPath(class_type, ['s', 'forward'])
annotator.annotateArgs(class_type, ['forward'], [
    None,
    ((1024, 2), torch.float32, False),
    ((42, -1, 7), torch.int8, True),
])

# "Change detector" test + "documentation" for the repr of `ClassAnnotator`.
# This is semi-load-bearing because users interact with this class and its
# repr will show up in error messages, so it should be pretty readable.
#
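# A sketch of the step the comment above refers to: printing the annotator
# exercises its repr (the original test's expected-output CHECK lines are
# elided from this excerpt).
print(annotator)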
def compile(self, program: torch.nn.Module) -> Any:
    mb = torch_mlir.ModuleBuilder()
    scripted = torch.jit.script(program)
    class_annotator = torch_mlir.ClassAnnotator()

    extract_annotations(program, scripted, class_annotator)

    # TODO: Find a way to make each of these calls own its own
    # "debuggable error report" situation.
    try:
        # Capture diagnostics so they can be included in the error report.
        sys.stderr = StringIO()
        # Import the TorchScript module to MLIR.
        mb.import_module(scripted._c, class_annotator)
    except Exception as e:
        raise Exception(f"""
PyTorch TorchScript module -> NPCOMP Object Graph IR import failed with:
Exception:
{e}
Diagnostics:
{sys.stderr.getvalue()}
""") from None
    finally:
        sys.stderr = sys.__stderr__

    try:
        sys.stderr = StringIO()
        asm_for_error_report = mb.module.operation.get_asm(
            large_elements_limit=10, enable_debug_info=True)
        pipeline_str = "torchscript-to-npcomp-backend-pipeline"
        # Lower module in place to make it ready for compiler backends.
        with mb.module.context:
            pm = PassManager.parse(pipeline_str)
            pm.run(mb.module)
    except Exception as e:
        # TODO: More robust.
        # - don't arbitrarily clutter up /tmp. When a test suite has many
        #   tests, this can be a big disk cost (also, /tmp/ is frequently a
        #   RAM fs, which increases worries about capacity).
        # - don't have colliding filenames (hard to do without cluttering
        #   up /tmp)
        # - if we do have colliding filenames, writes should at least
        #   avoid being racy.
        filename = os.path.join(tempfile.gettempdir(),
                                scripted.original_name + '.mlir')
        with open(filename, 'w') as f:
            f.write(asm_for_error_report)
        raise Exception(f"""
NPCOMP TorchScript Object Graph IR -> NPCOMP Backend IR lowering failed with
the following diagnostics:
{sys.stderr.getvalue()}

Error can be reproduced with:
$ npcomp-opt -{pipeline_str} {filename}
""") from None
    finally:
        sys.stderr = sys.__stderr__

    try:
        sys.stderr = StringIO()
        asm_for_error_report = mb.module.operation.get_asm(
            large_elements_limit=10, enable_debug_info=True)
        return self.backend.compile(mb.module)
    except Exception as e:
        filename = os.path.join(tempfile.gettempdir(),
                                scripted.original_name + '.mlir')
        with open(filename, 'w') as f:
            f.write(asm_for_error_report)
        raise Exception(f"""
NPCOMP Backend lowering for {self.backend.__class__.__name__} failed
with the following diagnostics:

## Exception:
{e}

## Stderr:
{sys.stderr.getvalue()}

## Input IR has been saved in {filename}
""") from None
    finally:
        sys.stderr = sys.__stderr__
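# For reference: the module-level imports the `compile` method above relies on
# (they would normally sit at the top of the file). Every name is used in the
# method; the exact source modules for `PassManager` and `extract_annotations`
# are assumptions based on a typical npcomp/torch_mlir layout, not something
# this excerpt shows.
import os
import sys
import tempfile
from io import StringIO
from typing import Any

import torch
import torch_mlir
from mlir.passmanager import PassManager  # assumed module path
from torch_mlir.torchscript_annotations import extract_annotations  # assumed module path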