コード例 #1
0
 def visit_Attribute(self, node):
     """Swap a dotted API reference for its upgraded spelling.

     Child nodes are visited first. The full dotted name of ``node`` is then
     looked up in ``self.replace_dict``. On a hit, the replacement string is
     parsed with ``gast`` and the resulting expression node is returned in
     place of ``node``. Otherwise ``node`` itself is returned.
     """
     self.generic_visit(node)
     old_name = get_attr_full_name(node)
     if old_name not in self.replace_dict:
         return node
     new_name = self.replace_dict[old_name]
     # Parse the new name as a tiny module and keep the expression of its
     # first statement (e.g. "a.b.c" -> an Attribute expression node).
     replacement = gast.parse(new_name).body[0].value
     print_info("\033[1;31mUpgrade API (%s->%s)\033[0m" %
                (old_name, new_name))
     return replacement
コード例 #2
0
def api_reader(path):
    """Load a JSON file whose path is relative to this module's directory.

    Args:
        path: File path joined onto ``os.path.dirname(__file__)``.

    Returns:
        Whatever ``json.load`` decodes from the file. On any failure (missing
        or unreadable file, invalid JSON, bad ``path`` argument) a failure
        message naming the path is printed and an empty ``dict`` is returned.
        Callers in this module treat an empty result as an error.
    """
    print_info("\033[1;32m start to read data.json\033[0m")
    full_path = path
    try:
        full_path = os.path.join(os.path.dirname(__file__), path)
        with open(full_path, 'r') as fr:
            return json.load(fr)
    except Exception:
        # Deliberate best-effort: report which file failed and fall back to
        # an empty dict (the old message left its "%s" unformatted).
        print_info("\033[1;31m %s read data.json fail\033[0m" % (full_path,))
        return dict()
コード例 #3
0
 def visit_Call(self, node):
     """Tally one call to an API that appears in ``self.modify_dict``.

     Child nodes are visited first. When the callee's full dotted name is a
     key of ``self.modify_dict``, the ``'count'`` field of its entry in
     ``self.counter_dict`` is incremented (missing field starts at 0). The
     whole counter dict is then passed to ``save_counter_dict`` with
     ``COUNTER_OUTPUT_PATH_ORI``, and the running total is logged.

     Returns:
         ``node``.

     NOTE(review): ``self.counter_dict[name]`` must already exist (or the
     dict must be a defaultdict); otherwise this raises KeyError — confirm.
     """
     self.generic_visit(node)
     callee = get_attr_full_name(node.func)
     if callee not in self.modify_dict:
         return node
     stats = self.counter_dict[callee]
     stats['count'] = stats.get('count', 0) + 1
     save_counter_dict(COUNTER_OUTPUT_PATH_ORI, self.counter_dict, callee)
     print_info("Counting old API once (%s)->hit No(%s) times" %
                (callee, self.counter_dict[callee]['count']))
     return node
コード例 #4
0
 def visit_Call(self, node):
     """Add keyword arguments to a call whose API has an ``"add"`` rule.

     The rule ``self.modify_dict[api]["add"]`` is handed to
     ``add_keywords_to`` together with ``node``, and each of its entries
     (iterated as parameter names) is logged. Calls without such a rule are
     returned untouched.

     NOTE(review): unlike the attribute/counter visitors, this does not call
     ``self.generic_visit``, so calls nested inside this one are not visited
     here — confirm that is intended.
     """
     callee = get_attr_full_name(node.func)
     if callee not in self.modify_dict or "add" not in self.modify_dict[callee]:
         return node
     added_params = self.modify_dict[callee]["add"]
     add_keywords_to(node, added_params)
     for name in added_params:
         print_info("\033[1;33mAdd Params (%s) to API (%s)\033[0m" %
                    (name, callee))
     return node
コード例 #5
0
 def visit_Call(self, node):
     """Rename keyword arguments of a call whose API has a ``"rename"`` rule.

     ``self.modify_dict[api]["rename"]`` maps old keyword names to new ones.
     It is applied to ``node`` via ``rename_keywords_to``, and every
     ``old->new`` pair is logged. Calls without such a rule are returned
     untouched.

     NOTE(review): no ``self.generic_visit`` call here, so nested calls are
     not visited by this method — confirm that is intended.
     """
     callee = get_attr_full_name(node.func)
     if callee not in self.modify_dict or "rename" not in self.modify_dict[callee]:
         return node
     renames = self.modify_dict[callee]["rename"]
     rename_keywords_to(node, renames)
     for old_name in renames:
         print_info(
             "\033[1;33mRename Params (%s->%s) in API (%s)\033[0m" %
             (old_name, renames[old_name], callee))
     return node
コード例 #6
0
 def visit_Call(self, node):
     """Drop keyword arguments from a call whose API has a ``"delete"`` rule.

     ``self.modify_dict[api]["delete"]`` is passed to
     ``delete_keywords_from`` together with ``node``, and each of its entries
     (iterated as parameter names) is logged. Calls without such a rule are
     returned untouched.

     NOTE(review): no ``self.generic_visit`` call here, so nested calls are
     not visited by this method — confirm that is intended.
     """
     callee = get_attr_full_name(node.func)
     if callee not in self.modify_dict or "delete" not in self.modify_dict[callee]:
         return node
     removed_params = self.modify_dict[callee]["delete"]
     delete_keywords_from(node, removed_params)
     for name in removed_params:
         print_info(
             "\033[1;33mDelete Params (%s) from API (%s)\033[0m" %
             (name, callee))
     return node
コード例 #7
0
def check_modify_dict(dict_path):
    """Report modify-dict APIs that cannot be imported in the current env.

    Reads ``dict_path`` with ``api_reader``. Every value's ``"name"`` (a
    dotted path such as ``"pkg.sub.attr"``) is checked with the equivalent of
    ``from pkg.sub import attr``. Each name that fails is printed and
    appended, one per line, to ``OUTPUT_DICT_REPORT`` (resolved next to this
    module). A readable existing report is removed first, so each run starts
    fresh.

    Exits the process with status 1 when the dict is empty or unreadable.
    """
    import importlib

    api_dict = api_reader(dict_path)
    if not api_dict:
        print("parser json dict error")
        exit(1)
    print_info("\033[1;32m start to lead api_dic\033[0m")
    api_ls = [v["name"] for v in api_dict.values()]

    output_path = os.path.join(os.path.dirname(__file__), OUTPUT_DICT_REPORT)
    if os.path.isfile(output_path) and os.access(output_path, os.R_OK):
        print("File exists and generate new report")
        os.remove(output_path)

    for each_api in api_ls:
        try:
            # Same semantics as ``from <folder> import <name>`` without
            # building source code for exec() out of JSON strings.
            folder_name, _, api_name = each_api.rpartition('.')
            module = importlib.import_module(folder_name)
            if not hasattr(module, api_name):
                # ``from a import b`` also accepts a submodule ``a.b``.
                importlib.import_module(each_api)
        except Exception:
            # Any failure (missing module/attribute, empty package name,
            # error raised while importing) counts as "not found".
            info_str = "\033[1;31m api {0} error, we can't find api in target env\n\033[0m".format(
                each_api)
            print_info(info_str)
            # NOTE(review): REPORT_FOLDER is resolved relative to the cwd
            # while output_path is relative to this module — confirm they
            # refer to the same directory.
            if not os.path.exists(REPORT_FOLDER):
                print_info("\033[1;33m %s create output folder\033[0m" %
                           (REPORT_FOLDER,))
                os.mkdir(REPORT_FOLDER)

            with open(output_path, 'a') as fw:
                fw.write(each_api + '\n')
コード例 #8
0
def check_target_env_api_json(api_json_path):
    """Report paddle 2.0 APIs from a sheet export that cannot be imported.

    Reads ``api_json_path`` with ``api_reader`` and collects the
    ``"paddle2.0"`` field of every row in ``api_dict["Sheet1"]``. Each dotted
    name is checked with the equivalent of ``from pkg.sub import attr``.
    Failures are printed and appended, one line each, to ``OUTPUT_PATH``
    (resolved next to this module). A readable existing report is removed
    first.

    Exits the process with status 1 when the JSON is empty or unreadable.
    """
    import importlib

    api_dict = api_reader(api_json_path)
    if not api_dict:
        print("parser json dict error")
        exit(1)
    print_info("\033[1;32m start to lead api_dic\033[0m")
    api_ls = [item["paddle2.0"] for item in api_dict["Sheet1"]]

    output_path = os.path.join(os.path.dirname(__file__), OUTPUT_PATH)
    if os.path.isfile(output_path) and os.access(output_path, os.R_OK):
        print("File exists and generate new report")
        os.remove(output_path)

    for each_api in api_ls:
        try:
            # Equivalent of ``from <folder> import <name>`` without exec().
            folder_name, _, api_name = each_api.rpartition('.')
            module = importlib.import_module(folder_name)
            if not hasattr(module, api_name):
                importlib.import_module(each_api)
        except Exception:
            info_str = "\033[1;31m api {0} error, we can't find api in target env\033[0m".format(
                each_api)
            print_info(info_str)
            # Create the report's parent directory (the old code tried
            # os.mkdirs, which does not exist, on the report file path).
            output_dir = os.path.dirname(output_path)
            if output_dir and not os.path.isdir(output_dir):
                print_info("\033[1;33m %s create output folder\033[0m" %
                           (output_dir,))
                os.makedirs(output_dir, exist_ok=True)

            # 'a' (not the invalid 'aw'); plain text, one line per API.
            with open(output_path, 'a') as fw:
                fw.write(
                    "api {0} error, we can't find api in target env\n".format(
                        each_api))
コード例 #9
0
def main(upgrade_api_args):
    """Upgrade all configured files in parallel, or a single file.

    ``upgrade_api_args`` must provide ``"args_file"`` (config path),
    ``"modify_dict"`` and ``"delete_dict"``. If either of the first two is
    missing, a hint is printed and the process exits with status 1.

    When ``input_path`` from the config is not a file, the file list comes
    from ``get_cur_file_list()``. That list is processed with a process or
    thread pool (``PROCESS_ELSE_THREAD``, ``MAX_WORKERS``). A single file is
    upgraded in-process under a 30 second eventlet timeout.
    """
    if not upgrade_api_args.get("args_file", None):
        print(
            "\033[1;34mPlease set config file!! Default path is api_upgrade_src/conf/upgrade.conf\033[0m"
        )
        exit(1)
    if not upgrade_api_args.get("modify_dict", None):
        print(
            "\033[1;34mPlease set modify_dict file!! Default path is api_upgrade_src/dict/modify.dict\033[0m"
        )
        exit(1)

    upgrade_config_dict = load_config(upgrade_api_args["args_file"])
    input_path = upgrade_config_dict["input_path"]
    if not os.path.isfile(input_path):
        file_py_list = get_cur_file_list()
    else:
        file_py_list = input_path
    modify_dict = load_modify_dict(upgrade_api_args["modify_dict"])
    delete_list = load_delete_dict(upgrade_api_args["delete_dict"])
    # NOTE(review): entries are joined unescaped, so "." in an API name acts
    # as a regex wildcard — confirm whether entries are meant to be regexes.
    delete_pattern = "|".join(delete_list)

    if isinstance(file_py_list, list):
        # Both modes share one code path; only the pool type differs.
        if PROCESS_ELSE_THREAD:
            executor_cls = concurrent.futures.ProcessPoolExecutor
        else:  # multi_threading
            executor_cls = concurrent.futures.ThreadPoolExecutor
        # Leaving the with-block calls shutdown(wait=True), so every future
        # is finished before the results are inspected below.
        with executor_cls(max_workers=MAX_WORKERS) as executor:
            future_list = [
                executor.submit(transformer_file, upgrade_config_dict, path,
                                modify_dict, True, delete_pattern)
                for path in file_py_list
            ]
        for future in concurrent.futures.as_completed(future_list):
            error = future.exception()
            if error is not None:
                print_info(
                    "\033[1;31m parallel error with future exception: %s \033[0m"
                    % (error))
        print_info("\033[1;33m all done.\033[0m")

    elif file_py_list is None:
        print_info(
            "\033[1;31mInput error: input must be a directory or a python file\033[0m"
        )
    else:
        try:
            eventlet.monkey_patch()
            # With exception=False an expired Timeout silently leaves the
            # with-block, so track completion to be able to report it.
            finished = False
            with eventlet.Timeout(30, False):
                try:
                    transformer_file(upgrade_config_dict,
                                     input_path,
                                     modify_dict,
                                     is_dir=False,
                                     delete_pattern=delete_pattern)
                except Exception as e:
                    print_info(
                        "\033[1;31m %s upgrade error, please check file, use a replacement policy and convert it manually, with error %s. \033[0m"
                        % (input_path, e))
                finished = True
            if not finished:
                print_info(
                    "\033[1;31m %s upgrade timeout, please check file, use a replacement policy and convert it manually, with error %s.\033[0m"
                    % (input_path, "timed out after 30 seconds"))
        except Exception as e:
            print_info(
                "\033[1;31m %s upgrade timeout, please check file, use a replacement policy and convert it manually, with error %s.\033[0m"
                % (input_path, e))
コード例 #10
0
def transformer_file(upgrade_config_dict,
                     input,
                     modify_dict=None,
                     is_dir=False,
                     delete_pattern=None):
    """Upgrade one source file and write the result to the output location.

    Steps:
      1. If ``delete_pattern`` is a non-empty regex that matches the file
         text, warn about the first matched (deleted) API. Processing
         continues.
      2. Hidden files (basename starting with ".") are skipped.
      3. Resolve the output file. With ``is_dir`` the input's directory is
         recreated under ``upgrade_config_dict["output_path"]``; otherwise
         that ``output_path`` is the output file itself.
      4. ``.sh`` files, and files for which ``check_paddle(input)`` is
         non-zero, are copied verbatim.
      5. Otherwise the source is parsed with ``gast``. ``import paddle`` is
         inserted at the position computed by ``FromCountVisitor``, names are
         expanded via ``scan_module_import``/``replace_full_name``, and
         ``transformer_root`` applies ``modify_dict``. The result is written
         with ``astor``.

    Returns:
        -1 when the file was skipped, copied verbatim, or failed to parse or
        transform. ``None`` on a successful upgrade (or an empty input file).
    """
    import tokenize

    with open(input, 'r') as fr:
        content = fr.read()
    # re.search("", ...) matches every file and re.search(None, ...) raises
    # TypeError, so only search when a non-empty pattern is supplied.
    if delete_pattern:
        match = re.search(delete_pattern, content)
        if match:
            delete_api = match.group(0)
            print_info(
                "\033[1;31m %s API has been deleted, please check file %s, use a replacement policy and convert it manually\033[0m"
                % (delete_api, input))

    input = os.path.normpath(input)
    (dirpath, filename) = os.path.split(input)

    if filename.startswith("."):
        return -1

    if is_dir:
        # NOTE(review): if dirpath is absolute, os.path.join discards the
        # output root — confirm callers always pass relative paths.
        out_dir = os.path.join(upgrade_config_dict["output_path"], dirpath)
        # exist_ok avoids a race when parallel workers create the same dir.
        os.makedirs(out_dir, 0o777, exist_ok=True)
        out_file = os.path.join(out_dir, filename)
    else:
        out_file = upgrade_config_dict["output_path"]

    out_file = os.path.normpath(out_file)
    check_stat = check_paddle(input)
    if filename.endswith(".sh") or check_stat != 0:
        # Copy verbatim without rewriting.
        with open(out_file, 'w') as fw:
            print_info("\033[1;34mStart upgrading model %s\033[0m" % (input))
            with open(input, 'r') as fr:
                fw.write(fr.read())
            print_info(
                "\033[1;34mUpgrade Complete. The updated file %s has been written sucess\033[0m"
                % (out_file))
        return -1

    size = os.path.getsize(input)
    print_info("\033[1;34mStart upgrading model %s\033[0m" % (input))

    if size != 0:
        try:
            # Read the source text directly instead of importing the file:
            # importing executes the user's script, and a same-named module
            # already in sys.modules would return the wrong source.
            # tokenize.open honours PEP 263 coding cookies and a UTF-8 BOM.
            with tokenize.open(input) as fr:
                source = fr.read()
            root = gast.parse(source)
            from_count_visitor = FromCountVisitor(root)
            from_count_visitor.visit(root)
            future_count = from_count_visitor.from_import_count
            print_info("\033[1;34mfuture count is %s \033[0m" % (future_count))

            insert_import_module_with_postion(root,
                                              mdl_name="paddle",
                                              pos=future_count)

            import_dict = scan_module_import(root)
            root = replace_full_name(root, import_dict)

            root = transformer_root(root, modify_dict)
            with open(out_file, 'w', encoding="utf8") as fw:
                fw.write(astor.to_source(gast.gast_to_ast(root)))
        except Exception as e:
            print_info(
                '\033[1;33;41mParser and upgrade %s error!!, please check API and convert it manually, with error %s \033[0m'
                % (input, e))
            return -1
    else:
        # Empty input: copy it from its real path (the old code opened the
        # bare basename relative to the cwd).
        with open(input, 'r') as fr, open(out_file, 'w') as fw:
            fw.write(fr.read())

    print_info(
        "\033[1;34mUpgrade Complete. The updated file %s has been written sucess\033[0m"
        % (out_file))
    print_info("")
コード例 #11
0
def main(upgrade_api_args):
    """Upgrade all configured files sequentially, or a single file.

    ``upgrade_api_args`` must provide ``"args_file"`` (config path),
    ``"modify_dict"`` and ``"delete_dict"``. If either of the first two is
    missing, a hint is printed and the process exits with status 1.

    A file that matches the delete pattern is only reported, not upgraded.
    Every other file is upgraded by ``transformer_file`` under a 30 second
    eventlet timeout.
    """
    if not upgrade_api_args.get("args_file", None):
        print("\033[1;34mPlease set config file!! Default path is api_upgrade_src/conf/upgrade.conf\033[0m")
        exit(1)
    if not upgrade_api_args.get("modify_dict", None):
        print("\033[1;34mPlease set modify_dict file!! Default path is api_upgrade_src/dict/modify.dict\033[0m")
        exit(1)

    upgrade_config_dict = load_config(upgrade_api_args["args_file"])
    input_path = upgrade_config_dict["input_path"]
    if not os.path.isfile(input_path):
        file_py_list = get_cur_file_list()
    else:
        file_py_list = input_path
    modify_dict = load_modify_dict(upgrade_api_args["modify_dict"])
    delete_list = load_delete_dict(upgrade_api_args["delete_dict"])
    delete_pattern = "|".join(delete_list)

    def _upgrade_one(path, content, is_dir):
        # Report a deleted API, or run transformer_file under a 30s timeout.
        # An empty pattern would match every file, so it means "no check".
        match = re.search(delete_pattern, content) if delete_pattern else None
        if match:
            delete_api = match.group(0)
            print_info("\033[1;31m %s API has been deleted, please check file %s, use a replacement policy and convert it manually\033[0m" % (delete_api, path))
            return
        finished = False
        try:
            eventlet.monkey_patch()
            # With exception=False an expired Timeout silently leaves the
            # with-block; ``finished`` stays False in that case.
            with eventlet.Timeout(30, False):
                try:
                    transformer_file(upgrade_config_dict, path, modify_dict,
                                     is_dir=is_dir,
                                     delete_pattern=delete_pattern)
                except Exception as e:
                    if is_dir:
                        print_info("\033[1;31m %s upgrade error, please check file, use a replacement policy and convert it manually, with error: %s. \033[0m" % (path, e))
                    else:
                        print_info("\033[1;31m %s upgrade error, please check file, use a replacement policy and convert it manually\033[0m" % (path))
                finished = True
        except KeyboardInterrupt:
            raise
        except BaseException:
            # Deliberate best-effort (e.g. SystemExit from the upgraded
            # script); reported below like a timeout, as before.
            pass
        if not finished:
            print_info("\033[1;31m %s upgrade timeout, please check file, use a replacement policy and convert it manually\033[0m" % (path))

    if isinstance(file_py_list, list):
        for path in file_py_list:
            try:
                with open(path, 'r') as fr:
                    content = fr.read()
            except Exception:
                print_info("\033[1;31m %s \033[0m" % (path))
                raise
            _upgrade_one(path, content, is_dir=True)
    elif file_py_list is None:
        print_info("\033[1;31mInput error: input must be a directory or a python file\033[0m")
    else:
        with open(input_path, 'r') as fr:
            content = fr.read()
        _upgrade_one(input_path, content, is_dir=False)