def refactor_import(q: Query, change_spec):
    """Rewrite paddle imports so every file uses a plain ``import paddle``.

    1. add "import paddle" if needed.
    2. remove "import paddle.mod" if needed.
    3. remove "import paddle.module as mod", and convert "mod.api" to
       "paddle.module.api"
    4. remove "from paddle.module import api", and convert "api" to
       "paddle.module.api" (also "from paddle.module import api as nick",
       converting "nick" to "paddle.module.api")

    Args:
        q: query object; one ``select(...).filter(...)`` plus three
            ``modify(...)`` steps are registered on it, in that order.
        change_spec: not used by this function.

    Returns:
        The same query object ``q``.
    """
    # select import_name and import_from
    pattern = """
    (
        file_input< any* >
    |
        name_import=import_name< 'import' '{name}' >
    |
        as_import=import_name< 'import'
            (
                module_name='{name}'
            |
                module_name=dotted_name< {dotted_name} any* >
            |
                dotted_as_name<
                    (
                        module_name='{name}'
                    |
                        module_name=dotted_name< {dotted_name} any* >
                    )
                    'as' module_nickname=any
                >
            )
        >
    |
        from_import=import_from< 'from'
            (
                module_name='{name}'
            |
                module_name=dotted_name< {dotted_name} any* >
            )
            'import' ['(']
            (
                import_as_name<
                    module_import=any
                    'as'
                    module_nickname=any
                >*
            |
                import_as_names<
                    module_imports=any*
                >
            |
                module_import=any
            )
        [')'] >
    |
        leaf_node=NAME
    )
    """
    package = 'paddle'
    pattern = pattern.format(
        name=package,
        dotted_name=" ".join(quoted_parts(package)),
        power_name=" ".join(power_parts(package)),
    )

    # filename -> {local name bound by a paddle import: full dotted path}
    imports_map = {}
    # files that already contain a plain "import paddle"
    paddle_imported = set()
    # files that contain a paddle import which is kept or rewritten; these
    # need "import paddle" once the other paddle imports are removed
    paddle_found = set()

    def _find_imports(node: LN, capture: Capture, filename: Filename):
        """Record the local names bound by each paddle import.

        Used as a filter purely for its side effects; always returns True.
        """
        if not is_import(node) or not capture:
            return True
        if 'name_import' in capture:
            paddle_imported.add(filename)
            paddle_found.add(filename)
            return True
        if 'as_import' not in capture and 'from_import' not in capture:
            return True

        # _remove_import deletes every as_import/from_import statement, so
        # "import paddle" must be (re)added even when no alias is bound, e.g.
        # "import paddle.fluid" whose uses are spelled "paddle.fluid.xxx".
        paddle_found.add(filename)
        names = imports_map.setdefault(filename, {})
        module_name = str(capture['module_name']).strip()

        if 'as_import' in capture:
            # "import paddle.module as mod": mod -> paddle.module
            if 'module_nickname' in capture:
                nickname = str(capture['module_nickname']).strip()
                names[nickname] = module_name
            return True

        # from_import
        if 'module_import' in capture:
            leaf = capture['module_import']
            if leaf.type == token.NAME:
                api = leaf.value.strip()
                if 'module_nickname' in capture:
                    # "from paddle.module import api as nick": only "nick" is
                    # bound, and it refers to paddle.module.api
                    nickname = str(capture['module_nickname']).strip()
                    names[nickname] = module_name + '.' + api
                else:
                    names[api] = module_name + '.' + api
        if 'module_imports' in capture:
            for child in capture['module_imports']:
                if child.type == token.NAME:
                    api = child.value.strip()
                    names[api] = module_name + '.' + api
                elif child.type == python_symbols.import_as_name:
                    # "from paddle.module import a as b, c": b -> paddle.module.a
                    api_leaf, _, nick_leaf = child.children
                    names[str(nick_leaf).strip()] = (
                        module_name + '.' + str(api_leaf).strip())
        return True

    q.select(pattern).filter(_find_imports)

    # convert to full module path
    def _full_module_path(node: LN, capture: Capture, filename: Filename):
        """Replace a bare use of an imported alias with its full paddle path."""
        if not (isinstance(node, Leaf) and node.type == token.NAME):
            return
        if filename not in imports_map:
            return
        logger.debug("{} [{}]: {}".format(filename, list(capture), node))
        # skip import statement
        if utils.is_import_node(node):
            return
        # skip left operand in argument list
        if utils.is_argument_node(node) and utils.is_left_operand(node):
            return
        # skip if it's already a full module path (a NAME after '.' is an
        # attribute, not a reference to the imported alias)
        if node.prev_sibling is not None and node.prev_sibling.type == token.DOT:
            return
        new_name = imports_map[filename].get(node.value)
        if new_name is None or node.parent is None:
            return

        new_node = utils.code_repr(new_name).children[0].children[0]
        new_node.parent = None
        if isinstance(new_node, Leaf):
            # single name such as "paddle" (from "import paddle as pd"):
            # a Leaf has no children, so swap it in directly
            new_node.prefix = node.prefix
            node.replace(new_node)
        else:
            new_node.children[0].prefix = node.prefix
            if node.parent.type == python_symbols.power:
                # splice "paddle", ".module", ".api" into the existing power
                # node so trailing calls/subscripts stay attached
                node.replace(new_node.children)
            else:
                node.replace(new_node)
        log_info(filename, node.get_lineno(), "{} -> {}".format(
            utils.node2code(node), utils.node2code(new_node)))

    q.modify(_full_module_path)

    # remove as_import and from_import
    def _remove_import(node: LN, capture: Capture, filename: Filename):
        """Delete "import paddle.x [as y]" and "from paddle.x import ..."."""
        if not is_import(node):
            return
        _node = capture.get('as_import', None) or capture.get(
            'from_import', None)
        if _node is None:
            return
        prefix = _node.prefix
        p = _node.parent
        _node.remove()
        log_warning(filename, p.get_lineno(),
                    'remove "{}"'.format(utils.node2code(_node)))
        # delete NEWLINE node after delete as_import or from_import
        if p and p.children and len(
                p.children) == 1 and p.children[0].type == token.NEWLINE:
            p.children[0].remove()
            # restore comment: leading comments lived in the removed
            # statement's prefix, so move them onto the next statement
            if p.next_sibling is not None:
                p.next_sibling.prefix = prefix + p.next_sibling.prefix

    q.modify(_remove_import)

    # add "import paddle" if needed
    def _add_import(node: LN, capture: Capture, filename: Filename):
        """Insert "import paddle" into files that used paddle but lack it."""
        if node.type != python_symbols.file_input:
            return
        if filename in paddle_imported:
            return
        if filename in paddle_found:
            touch_import(None, 'paddle', node, force=True)
            log_info(filename, node.get_lineno(), 'add "import paddle"')
            paddle_imported.add(filename)

    q.modify(_add_import)
    return q
def transform_imports(query: 'Query', old_name: str, new_name: str) -> 'Query':
    """Attach one select/modify step per modifier class to *query*.

    Every class in the module-level ``modifiers`` collection is instantiated
    with ``old_name``/``new_name``. Its ``selector`` template is filled in
    with the ``name``, ``dotted_name`` and ``power_name`` placeholders
    derived from ``old_name``. The instance is then registered as the
    modifier for that selection.

    Args:
        query: query to extend.
        old_name: name being replaced (presumably a dotted module path).
        new_name: replacement name passed to each modifier.

    Returns:
        The query produced by the last ``select(...).modify(...)`` call, or
        *query* itself when ``modifiers`` is empty.
    """
    placeholders = {
        'name': old_name,
        'dotted_name': ' '.join(bowler_helpers.quoted_parts(old_name)),
        'power_name': ' '.join(bowler_helpers.power_parts(old_name)),
    }
    for modifier_cls in modifiers:
        instance = modifier_cls(old_name=old_name, new_name=new_name)
        rendered_selector = instance.selector.format(**placeholders)
        query = query.select(rendered_selector).modify(instance)
    return query