def emitFunctionHook(self, func):
    """Check that ``func`` survives a print -> import -> print round trip.

    The function is printed to source with ``_jit_python_print``. That source
    is imported into a fresh ``torch.jit.CompilationUnit``, and the imported
    function is printed again. The two printed sources must match, as checked
    by ``self.assertMultiLineEqual``.

    The check is skipped (returns ``None`` immediately) when:
      * the name is ``"<lambda>"`` or contains ``"aten::"``, since such names
        cannot be exported; or
      * ``_inline_everything`` is falsy.

    A ``RuntimeError`` raised during the round trip is swallowed when
    ``self._isHookExceptionOk`` accepts it, and re-raised otherwise.
    """
    fn_name = func.name
    unexportable = fn_name == "<lambda>" or "aten::" in fn_name
    if unexportable or not _inline_everything:
        return

    # Emit hooks stay disabled while the printed source is compiled;
    # otherwise importing it would re-enter this very hook.
    with torch.jit._disable_emit_hooks():
        try:
            printed, consts = _jit_python_print(func)
            unit = torch.jit.CompilationUnit()._import(printed, consts)
            reimported = getattr(unit, fn_name)
            reprinted, _ = _jit_python_print(reimported)
            self.assertMultiLineEqual(printed, reprinted)
        except RuntimeError as err:
            if self._isHookExceptionOk(err):
                return
            raise
def emitFunctionHook(self, func):
    """Check that ``func`` survives a print -> import -> print round trip.

    The function is printed with ``_jit_python_print`` and imported into a
    fresh ``torch.jit.CompilationUnit``. The imported copy is printed again,
    and ``self.assertMultiLineEqual`` requires both printed sources to match.

    Functions named ``"<lambda>"`` or containing ``"aten::"`` cannot be
    exported, so they are skipped.

    A ``RuntimeError`` whose message mentions an unexportable Python function
    or closure is tolerated. Any other ``RuntimeError`` propagates.
    """
    fn_name = func.name
    if fn_name == "<lambda>" or "aten::" in fn_name:
        return

    # Keep the emit hook switched off while the printed source is parsed,
    # so the import does not re-enter this hook.
    with self.disableEmitHook():
        try:
            printed, consts = _jit_python_print(func)
            unit = torch.jit.CompilationUnit()._import(printed, consts)
            reimported = getattr(unit, fn_name)
            reprinted, _ = _jit_python_print(reimported)
            self.assertMultiLineEqual(printed, reprinted)
        except RuntimeError as err:
            message = str(err)
            tolerated = ("Could not export Python function" in message
                         or "closures are not exportable" in message)
            if not tolerated:
                raise
def getExportImportCopy(self, m, also_test_file=True, map_location=None):
    """Return a copy of ``m`` produced by exporting it and importing it back.

    For a ``torch._C.Function``, the function is printed with
    ``_jit_python_print`` and re-imported into a new
    ``torch.jit.CompilationUnit``. The attribute of that unit named
    ``m.name`` is returned, and ``also_test_file`` / ``map_location`` are
    not used on this path.

    Otherwise ``m`` (presumably a scripted module -- confirm with callers) is
    written with ``torch.jit.save`` into an in-memory ``io.BytesIO`` buffer
    and read back with ``torch.jit.load(..., map_location=map_location)``.

    Args:
        m: The object to round-trip.
        also_test_file: When true, the in-memory copy is also saved to a
            temporary file and loaded again. That file-loaded copy is the
            one returned.
        map_location: Forwarded to every ``torch.jit.load`` call.
    """
    if isinstance(m, torch._C.Function):
        printed, consts = _jit_python_print(m)
        unit = torch.jit.CompilationUnit()._import(printed, consts)
        return getattr(unit, m.name)

    stream = io.BytesIO()
    torch.jit.save(m, stream)
    stream.seek(0)  # rewind so load() reads from the start of the buffer
    from_memory = torch.jit.load(stream, map_location=map_location)

    if also_test_file:
        # Exercise the on-disk save/load path on top of the in-memory one.
        with TemporaryFileName() as path:
            from_memory.save(path)
            return torch.jit.load(path, map_location=map_location)
    return from_memory