Esempio n. 1
0
    def _convert_call(self, node, matched_api_name):
        """
        Convert the call node.

        Args:
            node (ast.Call): The call node to convert; its `func`, `lineno` and `col_offset` are read.
            matched_api_name (str): The API name matched for this call, looked up in ALL_MAPPING.

        Returns:
            ast.AST, the node parsed from the mapped code when it differs from the original code, else None.
        """
        new_node = None
        code = pasta.dump(node)
        api_name = pasta.dump(node.func)
        warning_info = get_prompt_info(matched_api_name)
        if warning_info is None:
            warning_info = ''
        if matched_api_name in ALL_MAPPING:
            logger.info("Line %3d start converting API: %s", node.lineno,
                        api_name)
            new_code = self.mapping_api(node)
            if new_code != code:
                try:
                    new_node = pasta.parse(new_code).body[0].value
                    # The text before the first '(' is the new API name. partition() keeps
                    # the whole string when there is no '(', whereas slicing with find()
                    # would use -1 and silently drop the last character.
                    new_api_name = new_code.partition('(')[0]
                    detail_msg = self._get_detail_prompt_msg(node, new_node)
                    if detail_msg:
                        warning_info = detail_msg + ' ' + warning_info
                except AttributeError:
                    # The first parsed statement has no `value` attribute (or a helper
                    # above raised AttributeError): fall back to the statement itself.
                    new_node = pasta.parse(new_code).body[0]
                    new_api_name = new_code
                self._process_log.info(
                    node.lineno, node.col_offset, LOG_FMT_CONVERT_WITH_TIPS %
                    (api_name, new_api_name, warning_info))
        else:
            logger.warning("Line %3d: found unsupported API: %s%s",
                           node.lineno, api_name, warning_info)
            self._process_log.warning(
                node.lineno, node.col_offset,
                LOG_FMT_NOT_CONVERT % (api_name, warning_info))

        return new_node
Esempio n. 2
0
    def parse_args(self, call_name: str, args_str: str):
        """
        Parse call_name and args_str.

        Args:
            call_name (str): str of the call function, etc.
            args_str (str): str of args for function, which starts with '(' and end with ')'.

        Returns:
            OrderedDict, all args parsed, plus the key "call_name" holding `call_name`.

        Raises:
            ValueError: If can not use ast to parse or the required parse node not type of ast.Call,
            or the given args_str not valid, or more positional args are given than `self.params` holds.
        """
        # expr is REQUIRED to meet (**) format
        if not (len(args_str) >= 2 and args_str[0] == "(" and args_str[-1] == ")"):
            raise ValueError('[{}] is think as args str, it should start with "(" and end with ")"'.format(args_str))

        # Prefix a dummy callee so that the argument list parses as a call expression.
        # Only the parse itself is guarded, so the more specific error below is not
        # swallowed and rewritten, and KeyboardInterrupt/SystemExit are not intercepted.
        try:
            ast_node = ast.parse("whatever_call_name" + args_str)
        except (SyntaxError, ValueError, MemoryError, RecursionError) as err:
            raise ValueError("can't parse code:\n{}".format(args_str)) from err

        # The first statement may be something other than a bare call expression
        # (e.g. a conditional expression or an assignment), so check it explicitly.
        first_stmt = ast_node.body[0] if ast_node.body else None
        call_node = getattr(first_stmt, 'value', None)
        if not isinstance(call_node, ast.Call):
            raise ValueError('call name with args str [{}] not instance of ast.Call'.format(args_str))

        # regard all actual parameter as one parameter
        if len(self.params) == 1:
            k = list(self.params.keys())[0]
            if k.startswith('*'):
                value = args_str[1:-1]
                return OrderedDict([(k, value), ("call_name", call_name)])

        args = OrderedDict()

        # param which name not assigned
        param_iter = iter(self.params.keys())
        if len(call_node.args) > len(self.params):
            raise ValueError('Parse args of torch in {}, but there is problems with params'.format(call_name))
        for arg in call_node.args:
            if isinstance(arg, ast.Starred):
                logger.debug("Find *%s", arg.value.id)
                args['*'] = arg.value.id
            else:
                # remove \n
                args[next(param_iter)] = pasta.dump(arg).strip()

        # params which name is assigned
        for keyword in call_node.keywords:
            if keyword.arg is None:
                # keyword.arg is None for a `**mapping` argument
                logger.info("Find **%s", keyword.value.id)
                args['**'] = keyword.value.id
            else:
                # remove \n
                args[keyword.arg] = pasta.dump(keyword.value).strip()

        args["call_name"] = call_name
        return args
Esempio n. 3
0
    def _convert_cell(self, cell_scope):
        """
        Rewrite a PyTorch Module class definition as a MindSpore Cell class.

        Args:
            cell_scope (pasta.base.Scope): Scope of the network class definition that inherits from torch.nn.Module.
        """
        class_node = cell_scope.node
        logger.info("Line %3d: start converting nn.Module %s",
                    class_node.lineno, self._code_analyzer.get_name(class_node))

        # Step 1: update each function definition recorded for this cell scope.
        cell_definitions = self._code_analyzer.network_definitions()['cell']
        member_scopes = cell_definitions.get(cell_scope, [])
        for member_scope in member_scopes:
            self._update_function_def(member_scope)

        # Step 2: update the base name of the class.
        self._update_base_name(cell_scope)
Esempio n. 4
0
    def _convert_function(self, func_scope, is_forward):
        """
        Visit a PyTorch function definition to convert it into a MindSpore function.

        Args:
            func_scope (pasta.base.scope.Scope): The node scope of function definition.
            is_forward (boolean): If the function is defined in forward function in nn.Module in torch.
        """
        definition_node = func_scope.node
        logger.info("Line %3d: start converting function %s()",
                    definition_node.lineno, definition_node.name)

        enclosing_node = func_scope.parent_scope.node

        # Reset the per-function traversal state before visiting.
        self._stack.clear()
        self._new_call_nodes.clear()
        if enclosing_node:
            # Seed the stack with the node of the parent scope.
            self._stack.append(enclosing_node)

        self._is_forward_function = is_forward
        self.visit(func_scope.node)
Esempio n. 5
0
    def convert_function(self, fun_name, fun, is_forward):
        """
        Scan the source of a PyTorch function and convert the APIs found in it.

        Args:
            fun_name (str): The str of function name.
            fun (func): The function to convert.
            is_forward (bool): If the function is defined in forward function in nn.Module in torch.

        Returns:
            dict, old code and converted code map if convert happens, else {}.
        """
        _, first_line = inspect.getsourcelines(fun)
        logger.info("Line %3d: start converting function %s()", first_line,
                    fun_name)

        code = inspect.getsource(fun)
        # Strings are immutable, so keeping a reference is enough to remember the original.
        original_code = code

        pos = 0
        while pos < len(code):
            api_name = self.find_api(code, pos, is_forward)
            if not api_name:
                pos += 1
                continue

            # Line of the API = first line of the function + newlines before `pos`.
            api_line = first_line + code.count('\n', 0, pos)

            if api_name in ALL_MAPPING:
                logger.info("Line %3d start converting API: %s", api_line,
                            api_name)
                # convert_api returns the rewritten code and the position to resume from.
                code, pos = self.convert_api(code, pos, api_name)
                self.convert_info += "[Convert][Line{:3d}] {} is converted.\n".format(
                    api_line, api_name)
                continue

            if api_name in ALL_UNSUPPORTED:
                if api_name in UNSUPPORTED_WARN_INFOS:
                    warn_info = ". " + UNSUPPORTED_WARN_INFOS[api_name]
                else:
                    warn_info = ""
                logger.warning("Line %3d: found unsupported API: %s%s",
                               api_line, api_name, warn_info)
                self.convert_info += "[Unconvert][Line{:3d}] {} didn't convert{}\n".format(
                    api_line, api_name, warn_info)

            pos += 1

        if original_code == code:
            return {}
        return {original_code: code}
Esempio n. 6
0
    def convert(self, infile, output_dir, report_dir):
        """
        Convert a module's code, code converted will be save in output_dir, and a report will be save in report_dir.

        Args:
            infile (str): The script to convert.
            output_dir (str): The path to save converted file.
            report_dir (str): The path to save report file.

        Raises:
            ScriptNotSupport: Re-raised after the failure is recorded in the report.
            Exception: Any other error is re-raised after the report is cleared.
        """
        in_file_split = _path_split(infile)
        in_file_split[-1], _ = _get_name_ext(in_file_split[-1])
        module_name = '.'.join(in_file_split)
        with open(infile, 'r') as file:
            content = file.read()

        self._infile = infile
        self._tree = pasta.parse(content)
        self._report.clear()
        try:
            logger.info("Script file is %s", infile)
            logger.info("Start converting %s", module_name)
            self._report.append('[Start Convert]')
            self._ast_editor = AstEditVisitor()
            self._ast_editor.process(self._tree)
            self._report.extend(self._ast_editor.get_logs())
            self._report.append('[Convert Over]')
            dest_file = os.path.join(output_dir, os.path.basename(infile))
            with os.fdopen(os.open(dest_file, self.flags, self.modes),
                           'w') as file:
                file.write(pasta.dump(self._tree))
            logger.info("Convert success. Result is wrote to %s.", dest_file)
        except ScriptNotSupport as error:
            self._report.append('[ScriptNotSupport] ' + error.message)
            self._report.append('[Convert failed]')
            # Bare raise keeps the original traceback intact.
            raise
        except Exception:
            # Unexpected failure: empty the report so the finally block writes nothing.
            self._report.clear()
            raise
        finally:
            # The report is written on success and on ScriptNotSupport, both of which
            # leave self._report non-empty.
            if self._report:
                dest_report_file = os.path.join(
                    report_dir,
                    '_'.join(os.path.basename(infile).split('.')[:-1]) +
                    '_report.txt')
                with os.fdopen(
                        os.open(dest_report_file, self.flags, self.modes),
                        'a') as file:
                    file.write('\n'.join(self._report))
                logger.info("Convert report is saved in %s", dest_report_file)
Esempio n. 7
0
    def convert_module(self, module_name, module, forward_list):
        """
        Convert a PyTorch module code into MindSpore module code.

        Args:
            module_name (str): The module's name.
            module (module): The module to convert.
            forward_list (set): A set of forward function.

        Returns:
            dict, map of old code and converted code.
        """
        _, line_no = inspect.getsourcelines(module)
        # Pass the values as logging arguments so formatting only happens when the
        # record is actually emitted; the rendered message is unchanged.
        logger.info("Line %3d: start converting nn.Module %s", line_no,
                    module_name)

        mapped = {}
        for name, member in inspect.getmembers(module):
            if self.is_valid_function(module, member):
                # The qualified "<module_name>.<member name>" is what gets checked against forward_list.
                is_forward = self.judge_forward(
                    "{}.{}".format(module_name, name), forward_list)
                mapped.update(self.convert_function(name, member, is_forward))
        return mapped
Esempio n. 8
0
    def convert(self, import_name, output_dir, report_dir):
        """
        Convert the code of an importable module, writing the converted code and a report.

        The converted code is written into output_dir and the report is appended to a file in report_dir.

        Args:
            import_name (str): The module from which to import the module to convert.
            output_dir (str): The path to save converted file.
            report_dir (str): The path to save report file.
        """
        logger.info("Start converting %s", import_name)

        import_mod = importlib.import_module(import_name)
        srcfile = inspect.getsourcefile(import_mod)
        logger.info("Script file is %s", srcfile)

        forward_list = set(ForwardCall(srcfile).calls)
        logger.debug("Forward_list: %s", forward_list)

        # replace python function under nn.Module
        mapping = self.get_mapping(import_mod, forward_list)
        code = self.update_code_and_convert_info(inspect.getsource(import_mod), mapping)

        # Assemble the report: header lines, the sorted conversion lines, then the footer.
        report_lines = ['[Start Convert]\n', 'The module is {}.\n'.format(import_name)]
        report_lines.extend(sorted(self.convert_info.splitlines(keepends=True)))
        report_lines.append('[Convert Over]')
        self.convert_info = ''.join(report_lines)

        src_basename = os.path.basename(srcfile)

        dest_file = os.path.join(output_dir, src_basename)
        dest_fd = os.open(dest_file, self.flags, self.modes)
        with os.fdopen(dest_fd, 'w') as file:
            file.write(code)
        logger.info("Convert success. Result is wrote to %s.", dest_file)

        # Report file name: the base name minus its last extension, remaining dots replaced by '_'.
        report_name = '_'.join(src_basename.split('.')[:-1]) + '_report.txt'
        dest_report_file = os.path.join(report_dir, report_name)
        report_fd = os.open(dest_report_file, self.flags, self.modes)
        with os.fdopen(report_fd, 'a') as file:
            file.write(self.convert_info)
        logger.info("Convert report is saved in %s", dest_report_file)
Esempio n. 9
0
    def visit_Call(self, node):
        """
        Callback function when visit AST tree.

        Matches the called API; when it has a mapping and the mapped code differs from
        the original, the call node is replaced in its parent before its children are visited.

        Args:
            node (ast.Call): The call node being visited.
        """
        code = pasta.dump(node)
        api_name = pasta.dump(node.func)

        # The parent node first call is equal to this node, skip when parent node is replaced.
        # This scenario occurs, for example, when out.view(out.size(0), -1) is first converted to
        # P.Reshape()(out, (out.size(0), -1)), will skip P.Reshape() in following visiting.
        # Access from the penultimate element in reverse order.
        for parent_node in self._stack[-2::-1]:
            if parent_node in self._new_call_nodes and pasta.dump(
                    parent_node).startswith(api_name):
                return
        parent = self._stack[-2]
        new_node = None
        new_code = code
        matched_api_name, match_case = self.match_api(
            node.func, self._is_forward_function)
        if match_case in [
                ApiMatchingEnum.API_INFER, ApiMatchingEnum.API_MATCHED
        ]:
            warning_info = get_prompt_info(matched_api_name)
            if warning_info is None:
                warning_info = ''
            if matched_api_name in ALL_MAPPING:
                logger.info("Line %3d start converting API: %s", node.lineno,
                            api_name)
                new_code = self.mapping_api(node)
                if new_code != code:
                    try:
                        new_node = pasta.parse(new_code).body[0].value
                        # The text before the first '(' is the new API name. partition()
                        # keeps the whole string when there is no '(', whereas slicing with
                        # find() would use -1 and silently drop the last character.
                        new_api_name = new_code.partition('(')[0]
                    except AttributeError:
                        # The first parsed statement has no `value` attribute:
                        # use the statement itself as the replacement node.
                        new_node = pasta.parse(new_code).body[0]
                        new_api_name = new_code
                    self._process_log.info(
                        node.lineno, node.col_offset,
                        LOG_FMT_CONVERT_WITH_TIPS %
                        (api_name, new_api_name, warning_info))
            else:
                logger.warning("Line %3d: found unsupported API: %s%s",
                               node.lineno, api_name, warning_info)
                self._process_log.warning(
                    node.lineno, node.col_offset,
                    LOG_FMT_NOT_CONVERT % (api_name, warning_info))

        elif match_case in [
                ApiMatchingEnum.API_STANDARD, ApiMatchingEnum.API_FOUND
        ]:
            self._process_log.warning(node.lineno, node.col_offset,
                                      LOG_FMT_NOT_CONVERT % (api_name, ''))

        if parent and new_node:
            update_line_col = _LineColEditVisitor()
            update_line_col.update(new_node, node)
            pasta.ast_utils.replace_child(parent, node, new_node)
            # Remember the replacement so nested visits can skip its leading call (see loop above).
            self._new_call_nodes.append(new_node)

            node = new_node
            self._stack[-1] = node
        try:
            self.generic_visit(node)
        except Exception:
            logger.error('original code:%s, new code:%s',
                         code,
                         new_code,
                         exc_info=True)
            raise