def visit_Attribute(self, node):
    """Swap an old fully-qualified API attribute node for its upgraded form.

    Children are visited first; if the attribute's dotted name appears in
    self.modify_dict, a freshly parsed node for the new name replaces it.
    """
    self.generic_visit(node)
    full_name = get_attr_full_name(node)
    # Leave nodes that are not in the upgrade table untouched.
    if full_name not in self.modify_dict:
        return node
    upgraded_name = self.modify_dict[full_name]['name']
    # Parse the replacement name into an expression node.
    upgraded_node = gast.parse(upgraded_name).body[0].value
    print_info("\033[1;33mUpgrade API (%s->%s)\033[0m" %
               (full_name, upgraded_name))
    return upgraded_node
def visit_Call(self, node):
    """Rename keyword arguments of a call per the API's 'rename' mapping."""
    full_name = get_attr_full_name(node.func)
    # Only calls whose API entry carries a "rename" table are touched.
    if full_name in self.modify_dict and "rename" in self.modify_dict[full_name]:
        rename_map = self.modify_dict[full_name]["rename"]
        rename_keywords_to(node, rename_map)
        for old_param in rename_map:
            print_info("\033[1;33mRename Params (%s->%s) in API (%s)\033[0m" %
                       (old_param, rename_map[old_param], full_name))
    return node
def visit_Call(self, node):
    """Drop keyword arguments listed in the API's 'delete' entry."""
    full_name = get_attr_full_name(node.func)
    # Only calls whose API entry carries a "delete" table are touched.
    if full_name in self.modify_dict and "delete" in self.modify_dict[full_name]:
        removed_params = self.modify_dict[full_name]["delete"]
        delete_keywords_from(node, removed_params)
        for removed in removed_params:
            print_info("\033[1;33mDelete Params (%s) from API (%s)\033[0m" %
                       (removed, full_name))
    return node
def visit_Call(self, node):
    """Insert keyword arguments listed in the API's 'add' entry."""
    full_name = get_attr_full_name(node.func)
    # Only calls whose API entry carries an "add" table are touched.
    if full_name in self.modify_dict and "add" in self.modify_dict[full_name]:
        added_params = self.modify_dict[full_name]["add"]
        add_keywords_to(node, added_params)
        for added in added_params:
            print_info("\033[1;33mAdd Params (%s) to API (%s)\033[0m" %
                       (added, full_name))
    return node
def api_reader(path):
    """Load an API-mapping JSON file located relative to this module.

    Args:
        path: JSON file path, resolved against this module's directory.

    Returns:
        The parsed JSON object, or an empty dict when the file cannot be
        read or parsed (best-effort, matching the original behavior).
    """
    print_info("\033[1;32m start to read data.json\033[0m")
    path = os.path.join(os.path.dirname(__file__), path)
    try:
        with open(path, 'r') as fr:
            return json.load(fr)
    except (OSError, ValueError):
        # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit
        # still propagate; json.JSONDecodeError is a ValueError subclass.
        # Bug fix: the original message had a bare "%s" that was never
        # interpolated -- supply the offending path.
        print_info("\033[1;31m %s read data.json fail\033[0m" % path)
        return dict()
def check_modify_dict(dict_path):
    """Verify every API name in the modify dict imports in this environment.

    APIs that fail to import are appended (one per line) to the report file
    under REPORT_FOLDER; a pre-existing report is removed first.

    Args:
        dict_path: Path to the modify-dict JSON, passed to api_reader.
    """
    api_dict = api_reader(dict_path)
    if not api_dict:
        print("parser json dict error")
        exit(1)
    print_info("\033[1;32m start to lead api_dic\033[0m")
    api_ls = [v["name"] for v in api_dict.values()]
    output_path = os.path.join(os.path.dirname(__file__), OUTPUT_DICT_REPORT)
    # Start from a clean report each run.
    if os.path.isfile(output_path) and os.access(output_path, os.R_OK):
        print("File exists and generate new report")
        os.remove(output_path)
    for each_api in api_ls:
        try:
            # Split "pkg.sub.name" into package path and final attribute.
            folder_name, _, api_name = each_api.rpartition('.')
            # NOTE(review): exec of a generated import; input comes from the
            # local data.json, not untrusted users.
            code_str = "from {0} import {1}".format(folder_name, api_name)
            exec(code_str)
        except Exception:
            # Narrowed from a bare `except:`.
            info_str = "\033[1;31m api {0} error, we can't find api in target env\n\033[0m".format(
                each_api)
            print_info(info_str)
            if not os.path.exists(REPORT_FOLDER):
                # Bug fix: the original "%s" placeholder was never filled in.
                print_info("\033[1;33m %s create output folder\033[0m" %
                           REPORT_FOLDER)
                os.mkdir(REPORT_FOLDER)
            with open(output_path, 'a') as fw:
                fw.write(each_api + '\n')
def check_target_env_api_json(api_json_path):
    """Check every paddle2.0 API in the JSON sheet imports in this environment.

    APIs that fail to import are appended to the report file at OUTPUT_PATH;
    a pre-existing report is removed first.

    Args:
        api_json_path: Path to the JSON file (expects a "Sheet1" list whose
            items carry a "paddle2.0" key), passed to api_reader.
    """
    api_dict = api_reader(api_json_path)
    if not api_dict:
        print("parser json dict error")
        exit(1)
    print_info("\033[1;32m start to lead api_dic\033[0m")
    api_ls = [item["paddle2.0"] for item in api_dict["Sheet1"]]
    output_path = os.path.join(os.path.dirname(__file__), OUTPUT_PATH)
    # Start from a clean report each run.
    if os.path.isfile(output_path) and os.access(output_path, os.R_OK):
        print("File exists and generate new report")
        os.remove(output_path)
    for each_api in api_ls:
        try:
            # Split "pkg.sub.name" into package path and final attribute.
            folder_name, _, api_name = each_api.rpartition('.')
            code_str = "from {0} import {1}".format(folder_name, api_name)
            exec(code_str)
        except Exception:
            # Narrowed from a bare `except:`.
            # Bug fix: dropped the stray un-interpolated "%s" that printed
            # literally next to the {0} placeholder.
            info_str = "\033[1;31m api {0} error, we can't find api in target env\033[0m".format(
                each_api)
            print_info(info_str)
            if not os.path.exists(OUTPUT_PATH):
                print_info("\033[1;33m %s create output folder\033[0m" %
                           OUTPUT_PATH)
                # Bug fix: os.mkdirs does not exist; use os.makedirs.
                os.makedirs(OUTPUT_PATH)
            # Bug fix: 'aw' is not a valid open() mode; append with 'a'.
            # A newline is added so successive entries do not run together.
            with open(output_path, 'a') as fw:
                fw.write(info_str + '\n')
def transformer_file(upgrade_config_dict, input, modify_dict=None, is_dir=False):
    """Upgrade one source file's paddle APIs and write it to the output path.

    Args:
        upgrade_config_dict: Config dict; "output_path" gives the output root.
        input: Path of the file to upgrade.
        modify_dict: API modification dict forwarded to transformer_root.
        is_dir: True when `input` comes from a directory walk, so the output
            mirrors the input's directory layout.

    Returns:
        -1 when the file is skipped, copied verbatim, or fails to parse;
        None on a successful upgrade.
    """
    input = os.path.normpath(input)
    (dirpath, filename) = os.path.split(input)
    abs_path = os.path.abspath(input)
    # Ignore hidden files.
    if filename.startswith("."):
        return -1
    if is_dir:
        out_dir = os.path.join(upgrade_config_dict["output_path"], dirpath)
        if not os.path.isdir(out_dir):
            os.makedirs(out_dir, 0o777)
        out_file = os.path.join(out_dir, filename)
    else:
        out_file = upgrade_config_dict["output_path"]
    out_file = os.path.normpath(out_file)
    check_stat = check_paddle(input)
    # Shell scripts and files that do not use paddle are copied verbatim.
    if filename.endswith(".sh") or check_stat != 0:
        with open(out_file, 'w') as fw:
            print_info("\033[1;34mStart upgrading model %s\033[0m" % (input))
            # Bug fix: close the input handle instead of leaking it.
            with open(input, 'r') as fr:
                fw.write(fr.read())
        print_info(
            "\033[1;34mUpgrade Complete. The updated file %s has been written sucess\033[0m"
            % (out_file))
        return -1
    cache_file = None
    # Bug fix: rstrip(".py") strips *characters* from the set {'.', 'p', 'y'},
    # mangling names like "copy.py" -> "co"; use splitext instead.
    module_name = os.path.splitext(filename)[0]
    #TODO basic strategy for avoiding _builtin_ module conflict, which need more general solution
    if module_name in BUILD_IN_FUN + ["bert", "word2vec", "yolov3"]:
        # Copy to a prefixed cache file so importing it cannot shadow a builtin.
        cache_file = "cache_%s" % filename
        cache_dir = os.path.join(dirpath, cache_file)
        shutil.copyfile(input, cache_dir)
        module_name = os.path.splitext(cache_file)[0]
    try:
        mdl_inst = importlib.import_module(module_name, package=dirpath)
    except Exception as e:
        print_info(
            "\033[1;32m%s, so we use another strategy to dynamically import module\033[0m"
            % e)
        # NOTE(review): os.path.split returns (dir, file), so module_name gets
        # the directory and dirpath the filename here; kept as-is since the
        # name only labels the spec -- confirm before changing.
        module_name, dirpath = os.path.split(abs_path)
        print("-->module name and package name:", module_name, dirpath)
        spec = importlib.util.spec_from_file_location(module_name, abs_path)
        new_mdl_inst = importlib.util.module_from_spec(spec)
        mdl_inst = new_mdl_inst
    size = os.path.getsize(input)
    print_info("\033[1;34mStart upgrading model %s\033[0m" % (input))
    if size != 0:
        try:
            root = gast.parse(inspect.getsource(mdl_inst))
            from_count_visitor = FromCountVisitor(root)
            from_count_visitor.visit(root)
            future_count = from_count_visitor.from_import_count
            print_info("\033[1;34mfuture count is %s \033[0m" % (future_count))
            #TODO under testing: root = insert_import_module(root)
            #currently works well with dygraph/transformer/model.py,
            #which correctly insert paddle right after __future__ import
            #to avoid error.
            insert_import_module_with_postion(
                root, mdl_name="paddle", pos=future_count)
            import_dict = scan_module_import(root)
            root = replace_full_name(root, import_dict)
            root = transformer_root(root, modify_dict)
            with open(out_file, 'w', encoding="utf8") as fw:
                fw.write(astor.to_source(gast.gast_to_ast(root)))
        except Exception as e:
            print_info(
                '\033[1;33;41mParser and upgrade %s error!!, please check API and convert it manually, with error %s \033[0m'
                % (input, e))
            return -1
    else:
        # Empty source file: copy it through unchanged.
        with open(out_file, 'w') as fw:
            # Bug fix: the original opened bare `filename` (no directory
            # part), which fails when cwd differs from the file's directory.
            with open(input, 'r') as fr:
                fw.write(fr.read())
    if cache_file is not None:
        os.remove(cache_dir)
    print_info(
        "\033[1;34mUpgrade Complete. The updated file %s has been written sucess\033[0m"
        % (out_file))
    print_info("")
def _upgrade_one_file(upgrade_config_dict, path, modify_dict, delete_pattern,
                      is_dir):
    """Upgrade a single file under a 30s timeout, skipping deleted-API files."""
    # Bug fix: close the file handle instead of leaking it; searching the raw
    # text replaces the original readlines()+join round-trip.
    with open(path, 'r') as fr:
        content = fr.read()
    match = re.search(delete_pattern, content)
    if match:
        # The file uses an API that no longer exists -- manual work required.
        delete_api = match.group(0)
        print_info(
            "\033[1;31m %s API has been deleted, please check file %s, use a replacement policy and convert it manually\033[0m"
            % (delete_api, path))
        return
    try:
        eventlet.monkey_patch()
        # Timeout(…, False) silently cancels instead of raising on expiry.
        with eventlet.Timeout(30, False):
            try:
                transformer_file(
                    upgrade_config_dict, path, modify_dict, is_dir=is_dir)
            except Exception as e:
                print_info(
                    "\033[1;31m %s upgrade error, please check file, use a replacement policy and convert it manually, with error %s.\033[0m"
                    % (path, e))
    except Exception as e:
        print_info(
            "\033[1;31m %s upgrade timeout, please check file, use a replacement policy and convert it manually, with error %s.\033[0m"
            % (path, e))


def main(upgrade_api_args):
    """Entry point: load config and dicts, then upgrade each input file.

    Args:
        upgrade_api_args: Dict with "args_file", "modify_dict" and
            "delete_dict" paths; exits with status 1 when the first two
            are missing.
    """
    if not upgrade_api_args.get("args_file", None):
        print(
            "\033[1;34mPlease set config file!! Default path is translate_src/conf/upgrade.conf\033[0m"
        )
        exit(1)
    if not upgrade_api_args.get("modify_dict", None):
        print(
            "\033[1;34mPlease set modify_dict file!! Default path is translate_src/dict/modify.dict\033[0m"
        )
        exit(1)
    upgrade_config_dict = load_config(upgrade_api_args["args_file"])
    # A directory input expands to the current file list; a single file is
    # passed through as a plain string.
    if not os.path.isfile(upgrade_config_dict["input_path"]):
        file_py_list = get_cur_file_list()
    else:
        file_py_list = upgrade_config_dict["input_path"]
    modify_dict = load_modify_dict(upgrade_api_args["modify_dict"])
    delete_list = load_delete_dict(upgrade_api_args["delete_dict"])
    delete_pattern = "|".join(delete_list)
    if isinstance(file_py_list, list):
        for path in file_py_list:
            _upgrade_one_file(upgrade_config_dict, path, modify_dict,
                              delete_pattern, is_dir=True)
    elif file_py_list is None:
        print_info(
            "\033[1;31mInput error: input must be a directory or a python file\033[0m"
        )
    else:
        _upgrade_one_file(upgrade_config_dict,
                          upgrade_config_dict["input_path"], modify_dict,
                          delete_pattern, is_dir=False)
def main(upgrade_api_args):
    """Entry point (parallel version): fan transformer_file out over a pool.

    Args:
        upgrade_api_args: Dict with "args_file", "modify_dict" and
            "delete_dict" paths; exits with status 1 when the first two
            are missing.
    """
    if not upgrade_api_args.get("args_file", None):
        print(
            "\033[1;34mPlease set config file!! Default path is translate_src/conf/upgrade.conf\033[0m"
        )
        exit(1)
    if not upgrade_api_args.get("modify_dict", None):
        print(
            "\033[1;34mPlease set modify_dict file!! Default path is translate_src/dict/modify.dict\033[0m"
        )
        exit(1)
    upgrade_config_dict = load_config(upgrade_api_args["args_file"])
    # A directory input expands to the current file list; a single file is
    # passed through as a plain string.
    if not os.path.isfile(upgrade_config_dict["input_path"]):
        file_py_list = get_cur_file_list()
    else:
        file_py_list = upgrade_config_dict["input_path"]
    modify_dict = load_modify_dict(upgrade_api_args["modify_dict"])
    delete_list = load_delete_dict(upgrade_api_args["delete_dict"])
    delete_pattern = "|".join(delete_list)
    if isinstance(file_py_list, list):
        # Pick the pool implementation from the module-level switch; both
        # branches were otherwise identical, so only the class differs.
        if PROCESS_OR_THREAD == "MULTI_PROCESS":
            executor_cls = concurrent.futures.ProcessPoolExecutor
        elif PROCESS_OR_THREAD == "MULTI_THREAD":
            executor_cls = concurrent.futures.ThreadPoolExecutor
        else:
            # Bug fix: the original hit a NameError at executor.shutdown()
            # when PROCESS_OR_THREAD held any other value.
            print_info(
                "\033[1;31m unknown PROCESS_OR_THREAD setting: %s \033[0m"
                % PROCESS_OR_THREAD)
            return
        executor = executor_cls(max_workers=MAX_WORKERS)
        future_list = []
        for path in file_py_list:
            # Submit one upgrade task per file and keep its future.
            future = executor.submit(transformer_file, upgrade_config_dict,
                                     path, modify_dict, True, delete_pattern)
            future_list.append(future)
        # shutdown() blocks until all submitted tasks have finished.
        executor.shutdown()
        for future in concurrent.futures.as_completed(future_list):
            if future.exception() is not None:
                print_info(
                    "\033[1;31m parallel error with future exception: %s \033[0m"
                    % (future.exception()))
            else:
                future.result()
        print_info("\033[1;33m all done.\033[0m")
    elif file_py_list is None:
        print_info(
            "\033[1;31mInput error: input must be a directory or a python file\033[0m"
        )
    else:
        try:
            eventlet.monkey_patch()
            # Timeout(…, False) silently cancels instead of raising on expiry.
            with eventlet.Timeout(30, False):
                try:
                    transformer_file(upgrade_config_dict,
                                     upgrade_config_dict["input_path"],
                                     modify_dict,
                                     is_dir=False,
                                     delete_pattern=delete_pattern)
                except Exception as e:
                    print_info(
                        "\033[1;31m %s upgrade error, please check file, use a replacement policy and convert it manually, with error %s. \033[0m"
                        % (upgrade_config_dict["input_path"], e))
        except Exception as e:
            print_info(
                "\033[1;31m %s upgrade timeout, please check file, use a replacement policy and convert it manually, with error %s.\033[0m"
                % (upgrade_config_dict["input_path"], e))