def _convert_call(self, node, matched_api_name):
    """
    Convert the call node.

    Args:
        node: The call node to convert; its source and its ``func`` are dumped with pasta.
        matched_api_name (str): The API name the call was matched to.

    Returns:
        The node parsed from the mapped code, or None when the API has no mapping
        or the mapped code equals the original code.
    """
    original_code = pasta.dump(node)
    api_name = pasta.dump(node.func)
    prompt = get_prompt_info(matched_api_name)
    warning_info = '' if prompt is None else prompt

    # No mapping available: report the API as not converted and leave the node alone.
    if matched_api_name not in ALL_MAPPING:
        logger.warning("Line %3d: found unsupported API: %s%s", node.lineno, api_name, warning_info)
        self._process_log.warning(
            node.lineno, node.col_offset,
            LOG_FMT_NOT_CONVERT % (api_name, warning_info))
        return None

    logger.info("Line %3d start converting API: %s", node.lineno, api_name)
    mapped_code = self.mapping_api(node)
    if mapped_code == original_code:
        # The mapping produced the same text, so there is nothing to replace or log.
        return None

    try:
        replacement = pasta.parse(mapped_code).body[0].value
        # find the first call name
        new_api_name = mapped_code[:mapped_code.find('(')]
        detail_msg = self._get_detail_prompt_msg(node, replacement)
        if detail_msg:
            warning_info = detail_msg + ' ' + warning_info
    except AttributeError:
        # Fall back to the parsed statement itself and report the whole mapped code as the name.
        replacement = pasta.parse(mapped_code).body[0]
        new_api_name = mapped_code

    self._process_log.info(
        node.lineno, node.col_offset,
        LOG_FMT_CONVERT_WITH_TIPS % (api_name, new_api_name, warning_info))
    return replacement
def parse_args(self, call_name: str, args_str: str):
    """
    Parse call_name and args_str.

    Args:
        call_name (str): str of the call function, etc.
        args_str (str): str of args for function, which starts with '(' and end with ')'.

    Returns:
        OrderedDict, all args parsed. Positional arguments are keyed by the names in
        ``self.params`` (in order), keyword arguments by their own name, ``*x`` under '*',
        ``**x`` under '**', and ``call_name`` under "call_name".

    Raises:
        ValueError: If can not use ast to parse or the required parse node not type of ast.Call,
            or the given args_str not valid, or more positional arguments are given than
            ``self.params`` declares.
    """
    # expr is REQUIRED to meet (**) format
    if not (len(args_str) >= 2 and args_str[0] == "(" and args_str[-1] == ")"):
        raise ValueError('[{}] is think as args str, it should start with "(" and end with ")"'.format(args_str))

    # Only the parse itself is guarded. ``Exception`` (rather than a bare ``except``) keeps
    # every parser failure mapped to ValueError while letting KeyboardInterrupt/SystemExit through.
    try:
        ast_node = ast.parse("whatever_call_name" + args_str)
    except Exception as err:
        raise ValueError("can't parse code:\n{}".format(args_str)) from err

    # Checked outside the ``try`` so this specific message is not replaced by the generic one.
    call_node = getattr(ast_node.body[0], "value", None)
    if not isinstance(call_node, ast.Call):
        raise ValueError('call name with args str [{}] not instance of ast.Call'.format(args_str))

    # regard all actual parameter as one parameter
    if len(self.params) == 1:
        only_param = next(iter(self.params.keys()))
        if only_param.startswith('*'):
            return OrderedDict([(only_param, args_str[1:-1]), ("call_name", call_name)])

    if len(call_node.args) > len(self.params):
        raise ValueError('Parse args of torch in {}, but there is problems with params'.format(call_name))

    args = OrderedDict()

    # param which name not assigned
    param_iter = iter(self.params.keys())
    for arg in call_node.args:
        if isinstance(arg, ast.Starred):
            logger.debug("Find *%s", arg.value.id)
            args['*'] = arg.value.id
        else:
            # remove \n
            args[next(param_iter)] = pasta.dump(arg).strip()

    # params which name is assigned
    for keyword in call_node.keywords:
        if keyword.arg is None:
            logger.info("Find **%s", keyword.value.id)
            args['**'] = keyword.value.id
        else:
            # remove \n
            args[keyword.arg] = pasta.dump(keyword.value).strip()

    args["call_name"] = call_name
    return args
def _convert_cell(self, cell_scope):
    """
    Convert a PyTorch Module class into MindSpore Cell class.

    Every function scope registered under the class is updated first, then the
    class's base name is updated.

    Args:
        cell_scope (pasta.base.Scope): The network class definition node inherits from torch.nn.Module.
    """
    class_node = cell_scope.node
    logger.info("Line %3d: start converting nn.Module %s",
                class_node.lineno, self._code_analyzer.get_name(class_node))

    # step1. update function definition
    functions_by_cell = self._code_analyzer.network_definitions()['cell']
    for member_scope in functions_by_cell.get(cell_scope, []):
        self._update_function_def(member_scope)

    # step2. update base name of class
    self._update_base_name(cell_scope)
def _convert_function(self, func_scope, is_forward):
    """
    Convert a PyTorch function into MindSpore function.

    Resets the visitor's stack and its list of new call nodes, seeds the stack with
    the enclosing scope's node when there is one, then visits the function node.

    Args:
        func_scope (pasta.base.scope.Scope): The node scope of function definition.
        is_forward (boolean): If the function is defined in forward function in nn.Module in torch.
    """
    function_node = func_scope.node
    logger.info("Line %3d: start converting function %s()", function_node.lineno, function_node.name)

    enclosing_node = func_scope.parent_scope.node

    # Start every function from a clean traversal state.
    self._stack.clear()
    self._new_call_nodes.clear()
    if enclosing_node:
        self._stack.append(enclosing_node)

    self._is_forward_function = is_forward
    self.visit(func_scope.node)
def convert_function(self, fun_name, fun, is_forward):
    """
    Convert a PyTorch function into MindSpore function.

    Scans the function's source one position at a time; each API found is either
    converted in place or recorded as unsupported in ``self.convert_info``.

    Args:
        fun_name (str): The str of function name.
        fun (func): The function to convert.
        is_forward (bool): If the function is defined in forward function in nn.Module in torch.

    Returns:
        dict, old code and converted code map if convert happens, else {}.
    """
    _, first_line = inspect.getsourcelines(fun)
    logger.info("Line %3d: start converting function %s()", first_line, fun_name)

    original_code = inspect.getsource(fun)
    code = original_code
    pos = 0
    while pos < len(code):
        api_name = self.find_api(code, pos, is_forward)
        if not api_name:
            pos += 1
            continue

        # Line of the match = first line of the function + newlines preceding the position.
        api_line = first_line + code[:pos].count('\n')

        if api_name in ALL_MAPPING:
            logger.info("Line %3d start converting API: %s", api_line, api_name)
            # convert_api returns the rewritten code and the position to resume from.
            code, pos = self.convert_api(code, pos, api_name)
            self.convert_info += "[Convert][Line{:3d}] {} is converted.\n".format(
                api_line, api_name)
            continue

        if api_name in ALL_UNSUPPORTED:
            if api_name in UNSUPPORTED_WARN_INFOS:
                warn_info = ". " + UNSUPPORTED_WARN_INFOS[api_name]
            else:
                warn_info = ""
            logger.warning("Line %3d: found unsupported API: %s%s",
                           api_line, api_name, warn_info)
            self.convert_info += "[Unconvert][Line{:3d}] {} didn't convert{}\n".format(
                api_line, api_name, warn_info)

        pos += 1

    if original_code != code:
        return {original_code: code}
    return {}
def convert(self, infile, output_dir, report_dir):
    """
    Convert a module's code, code converted will be save in output_dir, and a report will be save in report_dir.

    Args:
        infile (str): The script to convert.
        output_dir (str): The path to save converted file.
        report_dir (str): The path to save report file.

    Raises:
        ScriptNotSupport: Re-raised after the failure is recorded in the report.
        Exception: Any other error is re-raised after the report is cleared, so no report is written.
    """
    path_parts = _path_split(infile)
    path_parts[-1], _ = _get_name_ext(path_parts[-1])
    module_name = '.'.join(path_parts)

    with open(infile, 'r') as source_file:
        content = source_file.read()

    self._infile = infile
    self._tree = pasta.parse(content)
    self._report.clear()

    script_name = os.path.basename(infile)
    try:
        logger.info("Script file is %s", infile)
        logger.info("Start converting %s", module_name)
        self._report.append('[Start Convert]')

        self._ast_editor = AstEditVisitor()
        self._ast_editor.process(self._tree)
        self._report.extend(self._ast_editor.get_logs())
        self._report.append('[Convert Over]')

        dest_file = os.path.join(output_dir, script_name)
        out_fd = os.open(dest_file, self.flags, self.modes)
        with os.fdopen(out_fd, 'w') as out_file:
            out_file.write(pasta.dump(self._tree))
        logger.info("Convert success. Result is wrote to %s.", dest_file)
    except ScriptNotSupport as error:
        self._report.append('[ScriptNotSupport] ' + error.message)
        self._report.append('[Convert failed]')
        raise error
    except Exception as error:
        # An empty report makes the ``finally`` clause skip writing the report file.
        self._report.clear()
        raise error
    finally:
        if self._report:
            # Report name: script name without its last extension, remaining dots turned into '_'.
            report_stem = '_'.join(script_name.split('.')[:-1])
            dest_report_file = os.path.join(report_dir, report_stem + '_report.txt')
            report_fd = os.open(dest_report_file, self.flags, self.modes)
            with os.fdopen(report_fd, 'a') as report_file:
                report_file.write('\n'.join(self._report))
            logger.info("Convert report is saved in %s", dest_report_file)
def convert_module(self, module_name, module, forward_list):
    """
    Convert a PyTorch module code into MindSpore module code.

    Args:
        module_name (str): The module's name.
        module (module): The module to convert.
        forward_list (set): A set of forward function.

    Returns:
        dict, map of old code and converted code.
    """
    _, line_no = inspect.getsourcelines(module)
    # Lazy %-style arguments, matching the other log calls in this file; the
    # rendered message is the same as the previous str.format version.
    logger.info("Line %3d: start converting nn.Module %s", line_no, module_name)

    mapped = {}
    for name, member in inspect.getmembers(module):
        if not self.is_valid_function(module, member):
            continue
        # Members are looked up in forward_list by their qualified "<module>.<member>" name.
        is_forward = self.judge_forward(
            "{}.{}".format(module_name, name), forward_list)
        mapped.update(self.convert_function(name, member, is_forward))
    return mapped
def convert(self, import_name, output_dir, report_dir):
    """
    Convert a module's code, code converted will be save in output_dir, and a report will be save in report_dir.

    Args:
        import_name (str): The module from which to import the module to convert.
        output_dir (str): The path to save converted file.
        report_dir (str): The path to save report file.
    """
    logger.info("Start converting %s", import_name)
    start_info = '[Start Convert]\n'
    module_info = 'The module is {}.\n'.format(import_name)

    import_mod = importlib.import_module(import_name)
    srcfile = inspect.getsourcefile(import_mod)
    logger.info("Script file is %s", srcfile)

    forward_list = set(ForwardCall(srcfile).calls)
    logger.debug("Forward_list: %s", forward_list)

    # replace python function under nn.Module
    mapping = self.get_mapping(import_mod, forward_list)
    converted_code = self.update_code_and_convert_info(inspect.getsource(import_mod), mapping)

    # Report layout: start line, module line, sorted per-API lines, closing marker.
    sorted_lines = sorted(self.convert_info.splitlines(keepends=True))
    self.convert_info = ''.join(
        [start_info, module_info] + sorted_lines + ['[Convert Over]'])

    src_basename = os.path.basename(srcfile)
    dest_file = os.path.join(output_dir, src_basename)
    code_fd = os.open(dest_file, self.flags, self.modes)
    with os.fdopen(code_fd, 'w') as code_file:
        code_file.write(converted_code)
    logger.info("Convert success. Result is wrote to %s.", dest_file)

    # Report name: source file name without its last extension, remaining dots turned into '_'.
    report_stem = '_'.join(src_basename.split('.')[:-1])
    dest_report_file = os.path.join(report_dir, report_stem + '_report.txt')
    report_fd = os.open(dest_report_file, self.flags, self.modes)
    with os.fdopen(report_fd, 'a') as report_file:
        report_file.write(self.convert_info)
    logger.info("Convert report is saved in %s", dest_report_file)
def visit_Call(self, node):
    """
    Callback function when visiting a Call node of the AST tree.

    Matches the call against the known APIs, replaces the node in its parent when the
    mapping produces different code, records the outcome in ``self._process_log``, and
    then continues the traversal into the (possibly replaced) node.

    Args:
        node: The call node being visited; ``node.func`` is used for API matching.

    Returns:
        None. Returns early, without visiting children, when an enclosing node on the
        stack is an already-converted call whose code starts with this call's name.
    """
    code = pasta.dump(node)
    api_name = pasta.dump(node.func)

    # The parent node first call is equal to this node, skip when parent node is replaced.
    # This scenario occurs, for example, when out.view(out.size(0), -1) is first converted to
    # P.Reshape()(out, (out.size(0), -1)), will skip P.Reshape() in following visiting.
    # Access from the penultimate element in reverse order.
    for parent_node in self._stack[-2::-1]:
        if parent_node in self._new_call_nodes and pasta.dump(
                parent_node).startswith(api_name):
            return

    # The entry just below the top of the stack is treated as this node's parent
    # (presumably the top entry is the node itself; verify against the visitor's stack handling).
    parent = self._stack[-2]
    new_node = None
    # Defaults to the original code so the error log below always has a value to print.
    new_code = code
    matched_api_name, match_case = self.match_api(
        node.func, self._is_forward_function)
    if match_case in [
            ApiMatchingEnum.API_INFER, ApiMatchingEnum.API_MATCHED
    ]:
        warning_info = get_prompt_info(matched_api_name)
        if warning_info is None:
            warning_info = ''
        if matched_api_name in ALL_MAPPING:
            logger.info("Line %3d start converting API: %s", node.lineno, api_name)
            new_code = self.mapping_api(node)
            # Only build a replacement node (and log a conversion) when the code actually changed.
            if new_code != code:
                try:
                    new_node = pasta.parse(new_code).body[0].value
                    # find the first call name
                    new_api_name = new_code[:new_code.find('(')]
                except AttributeError:
                    # Fall back to the parsed statement itself and use the whole code as the name.
                    new_node = pasta.parse(new_code).body[0]
                    new_api_name = new_code
                self._process_log.info(
                    node.lineno, node.col_offset,
                    LOG_FMT_CONVERT_WITH_TIPS % (api_name, new_api_name, warning_info))
        else:
            # Matched API without a mapping: report it as not converted.
            logger.warning("Line %3d: found unsupported API: %s%s", node.lineno, api_name, warning_info)
            self._process_log.warning(
                node.lineno, node.col_offset,
                LOG_FMT_NOT_CONVERT % (api_name, warning_info))
    elif match_case in [
            ApiMatchingEnum.API_STANDARD, ApiMatchingEnum.API_FOUND
    ]:
        # These match cases are only reported as not converted; no mapping is attempted.
        self._process_log.warning(node.lineno, node.col_offset,
                                  LOG_FMT_NOT_CONVERT % (api_name, ''))
    else:
        pass

    if parent and new_node:
        # Update the new node's line/column info from the old node, then swap it into the parent.
        update_line_col = _LineColEditVisitor()
        update_line_col.update(new_node, node)
        pasta.ast_utils.replace_child(parent, node, new_node)
        # Remembered so the skip check at the top of this method can recognise it later.
        self._new_call_nodes.append(new_node)

        # Continue the traversal on the replacement and keep the stack top in sync with it.
        node = new_node
        self._stack[-1] = node
    try:
        self.generic_visit(node)
    except Exception:
        # Log both versions of the code for diagnosis, then let the error propagate.
        logger.error('original code:%s, new code:%s',
                     code, new_code, exc_info=True)
        raise