Ejemplo n.º 1
0
    def parse_args(self, call_name: str, args_str: str):
        """
        Parse the argument string of a call into a mapping of parameter name to source text.

        The argument string is parsed with ``ast`` after prefixing a dummy callee name.
        Positional arguments are matched, in order, against the keys of ``self.params``;
        keyword arguments are stored under their own names. ``*x`` and ``**x`` unpacking
        are stored under the keys ``'*'`` and ``'**'``. If ``self.params`` holds exactly
        one parameter whose name starts with ``'*'``, the whole text between the outer
        parentheses is stored under that parameter name instead.

        Args:
            call_name (str): str of the call function, stored under the key ``"call_name"``.
            args_str (str): str of args for function, which starts with '(' and ends with ')'.

        Returns:
            OrderedDict, all args parsed, plus the ``"call_name"`` entry.

        Raises:
            ValueError: If ``args_str`` is not wrapped in parentheses, can not be parsed
                by ast, does not parse to an ``ast.Call`` node, or has more positional
                arguments than ``self.params`` has entries.
        """

        def _unpacked_source(node):
            # A bare name keeps its identifier; any other expression (attribute,
            # call, literal, ...) has no ``id``, so fall back to its source text.
            if isinstance(node, ast.Name):
                return node.id
            return pasta.dump(node).strip()

        # expr is REQUIRED to meet (**) format
        if not (len(args_str) >= 2 and args_str[0] == "(" and args_str[-1] == ")"):
            raise ValueError('[{}] is think as args str, it should start with "(" and end with ")"'.format(args_str))

        # Prefix a dummy callee so that the argument string parses as a call expression.
        # Only ``Exception`` is trapped so KeyboardInterrupt/SystemExit still propagate.
        try:
            call_node = ast.parse("whatever_call_name" + args_str).body[0].value
        except Exception as err:
            raise ValueError("can't parse code:\n{}".format(args_str)) from err

        # Checked outside the try block so this specific message is not replaced
        # by the generic parse-failure message.
        if not isinstance(call_node, ast.Call):
            raise ValueError('call name with args str [{}] not instance of ast.Call'.format(args_str))

        # regard all actual parameter as one parameter
        if len(self.params) == 1:
            only_param = next(iter(self.params.keys()))
            if only_param.startswith('*'):
                return OrderedDict([(only_param, args_str[1:-1]), ("call_name", call_name)])

        args = OrderedDict()

        # param which name not assigned
        if len(call_node.args) > len(self.params):
            raise ValueError('Parse args of torch in {}, but there is problems with params'.format(call_name))
        param_iter = iter(self.params.keys())
        for arg in call_node.args:
            if isinstance(arg, ast.Starred):
                starred = _unpacked_source(arg.value)
                logger.debug("Find *%s", starred)
                args['*'] = starred
            else:
                # strip() removes the surrounding whitespace/newline of the dumped source
                args[next(param_iter)] = pasta.dump(arg).strip()

        # params which name is assigned
        for keyword in call_node.keywords:
            if keyword.arg is None:
                # ``arg is None`` marks ``**mapping`` unpacking
                double_starred = _unpacked_source(keyword.value)
                logger.info("Find **%s", double_starred)
                args['**'] = double_starred
            else:
                # strip() removes the surrounding whitespace/newline of the dumped source
                args[keyword.arg] = pasta.dump(keyword.value).strip()

        args["call_name"] = call_name
        return args
Ejemplo n.º 2
0
    def _judge_forward(self, func_scope):
        """
        Check if function is a forward function.

        Args:
            func_scope (pasta.base.scope.Scope): The node scope of function definition.

        Returns:
            boolean, True or False
        """
        is_forward = func_scope.node in self._forward_list.values()
        if is_forward:
            logger.debug("%s is a forward function", self._code_analyzer.get_name(func_scope))
        return is_forward
Ejemplo n.º 3
0
    def judge_forward(name, forward_list):
        """
        Tell whether a function name denotes a forward function.

        A name qualifies when it is contained in ``forward_list`` or when its last
        dot-separated component is exactly ``"forward"``. A debug message is logged
        for every match.

        Args:
            name (str): The function name.
            forward_list (set): A set of forward function.

        Returns:
            bool, True if the name is a forward function, otherwise False.
        """
        # Membership is tested first; the name is only split when that test fails.
        if name not in forward_list and name.rsplit(".", 1)[-1] != "forward":
            return False

        logger.debug("%s is a forward function", name)
        return True
Ejemplo n.º 4
0
    def convert(self, import_name, output_dir, report_dir):
        """
        Convert the source of an importable module and write the result plus a report.

        The module is imported, its source file is analysed for forward calls, and the
        source text is rewritten through ``self.update_code_and_convert_info``. The
        converted code is written to ``output_dir`` under the source file's base name.
        The conversion report is stored in ``self.convert_info`` and appended to
        ``<name>_report.txt`` in ``report_dir``.

        Args:
            import_name (str): The module from which to import the module to convert.
            output_dir (str): The path to save converted file.
            report_dir (str): The path to save report file.
        """
        logger.info("Start converting %s", import_name)

        module = importlib.import_module(import_name)
        source_path = inspect.getsourcefile(module)
        logger.info("Script file is %s", source_path)

        forward_list = set(ForwardCall(source_path).calls)
        logger.debug("Forward_list: %s", forward_list)

        # replace python function under nn.Module
        mapping = self.get_mapping(module, forward_list)
        converted_code = self.update_code_and_convert_info(inspect.getsource(module), mapping)

        # Report layout: fixed two-line header, the sorted per-item lines, then a footer.
        report_lines = [
            '[Start Convert]\n',
            'The module is {}.\n'.format(import_name),
            *sorted(self.convert_info.splitlines(keepends=True)),
            '[Convert Over]',
        ]
        self.convert_info = ''.join(report_lines)

        source_name = os.path.basename(source_path)

        dest_file = os.path.join(output_dir, source_name)
        code_fd = os.open(dest_file, self.flags, self.modes)
        with os.fdopen(code_fd, 'w') as code_file:
            code_file.write(converted_code)
        logger.info("Convert success. Result is wrote to %s.", dest_file)

        # Report name: base name without its last extension, remaining dots turned into '_'.
        report_name = '_'.join(source_name.split('.')[:-1]) + '_report.txt'
        dest_report_file = os.path.join(report_dir, report_name)
        report_fd = os.open(dest_report_file, self.flags, self.modes)
        with os.fdopen(report_fd, 'a') as report_file:
            report_file.write(self.convert_info)
        logger.info("Convert report is saved in %s", dest_report_file)