def _create_simple_stmt_node_and_insert_behind(code, node):
    """Build a ``simple_stmt`` node from ``code`` and insert it behind ``node``.

    Nothing happens (and ``None`` is returned) when ``node`` is ``None`` or
    its type is not ``python_symbols.simple_stmt``.

    Args:
        code: Source text handed to ``utils.code_repr``.
        node: Reference node. It supplies the newline (through
            ``utils.newline_node``) and the prefix (through
            ``utils.get_indent``) of the new statement, and is the anchor
            passed to ``utils.insert_node_behind``.
    """
    if node is None:
        return
    if node.type != python_symbols.simple_stmt:
        return

    # Start from a statement that only holds a newline leaf derived from
    # the reference node.
    new_stmt = Node(python_symbols.simple_stmt, [utils.newline_node(node)])

    # Take the first grandchild of the parsed snippet and detach it from its
    # parsed tree so it can be adopted by the new statement.
    parsed_child = utils.code_repr(code).children[0].children[0]
    parsed_child.parent = None
    new_stmt.insert_child(0, parsed_child)

    new_stmt.prefix = utils.get_indent(node)
    utils.insert_node_behind(node, new_stmt)
def _remove_with_dygraph_guard(node: LN, capture: Capture, filename: Filename):
    """Replace a captured ``with`` statement by its body between two paddle calls.

    The node in ``capture['with']`` is removed from its parent. Starting at
    the position it occupied, the following are inserted in order:

    1. a statement parsed from ``'paddle.disable_static'`` followed by
       ``str(capture['arg_list'])``,
    2. the children of ``capture['suite']``, after the suite's first NEWLINE,
       first INDENT and last DEDENT leaf children have been dropped and the
       remaining leaves passed through ``utils.dec_indent``,
    3. a statement parsed from ``'paddle.enable_static()'``.

    Args:
        node: Matched node; only used as the argument of
            ``utils.newline_node``.
        capture: Mapping providing the ``'with'``, ``'arg_list'`` and
            ``'suite'`` nodes.
        filename: Not used by this function.
    """

    def _make_simple_stmt(code):
        # Wrap the first grandchild of the parsed snippet, followed by a
        # newline leaf, in a fresh simple_stmt node.
        stmt = Node(python_symbols.simple_stmt, [utils.newline_node(node)])
        parsed_child = utils.code_repr(code).children[0].children[0]
        parsed_child.parent = None
        stmt.insert_child(0, parsed_child)
        return stmt

    # Locate the with statement in its parent by identity; the replacement
    # statements are inserted starting at this index.
    with_node = capture['with']
    parent = with_node.parent
    idx = None
    for position, child in enumerate(parent.children):
        if child is with_node:
            idx = position
            break

    # Statement for "paddle.disable_static<args>"; it takes over the prefix
    # of the with statement.
    arg_list_nodes = capture['arg_list']
    simple_stmt_disable_static = _make_simple_stmt(
        'paddle.disable_static' + str(arg_list_nodes))
    simple_stmt_disable_static.prefix = with_node.prefix

    # Statement for "paddle.enable_static()".
    simple_stmt_enable_static = _make_simple_stmt('paddle.enable_static()')
    simple_stmt_enable_static.prefix = utils.get_indent(with_node)

    suite_node = capture['suite']

    # Remove the first NEWLINE leaf that is a direct child of the suite.
    for child in suite_node.children:
        if isinstance(child, Leaf) and child.type == token.NEWLINE:
            child.remove()
            break

    # Remove the first INDENT leaf child and hand its prefix plus its
    # indentation text to the node that follows it.
    for child in suite_node.children:
        if isinstance(child, Leaf) and child.type == token.INDENT:
            follower = child.next_sibling
            if follower is not None:
                follower.prefix = child.prefix + child.value
            child.remove()
            break

    # Move the prefixes of the trailing run of DEDENT leaves onto the node
    # that follows the with statement. If there is no such node the prefixes
    # are left where they are instead of raising AttributeError.
    after_with = with_node.next_sibling
    if after_with is not None:
        for leaf in reversed(list(suite_node.leaves())):
            if leaf.type != token.DEDENT:
                break
            after_with.prefix = leaf.prefix + after_with.prefix
            leaf.prefix = ""

    # Remove the last DEDENT leaf that is a direct child of the suite.
    for child in reversed(suite_node.children):
        if isinstance(child, Leaf) and child.type == token.DEDENT:
            child.remove()
            break

    # Reduce the indentation of everything left in the suite: INDENT leaves
    # carry it in their value, all other leaves in their prefix.
    for leaf in suite_node.leaves():
        if leaf.type == token.INDENT:
            leaf.value = utils.dec_indent(leaf.value)
        else:
            leaf.prefix = utils.dec_indent(leaf.prefix)

    # Swap the with statement for: disable_static, suite body, enable_static.
    with_node.remove()
    parent.insert_child(idx, simple_stmt_disable_static)
    idx += 1
    # Iterate over a snapshot so re-parenting cannot disturb the iteration.
    for child in list(suite_node.children):
        parent.insert_child(idx, child)
        idx += 1
    parent.insert_child(idx, simple_stmt_enable_static)
def sort_imports(root, capture, filename):
    """Sort the statements returned by ``get_top_import_nodes(root)`` in place.

    Every statement is classified into a group (docstring, ``__future__``,
    stdlib, third party, first party) and sorted by
    ``(group, plain-import-before-from-import, module, first name)``.
    The original statements are removed from the tree and clones are inserted
    at the start of ``root`` in sorted order, with an empty line between
    groups and between different first-party root modules.

    Relative imports (``from . import x``, ``from ..pkg import y``) are
    treated as first-party imports.

    Args:
        root: Tree node whose leading statements are reordered.
        capture: Not used by this function.
        filename: Not used by this function.

    Raises:
        ValueError: If an import statement has a shape that none of the
            branches below recognises.
    """

    def _leaf_text(tree_node):
        # Concatenated leaf values, ignoring whitespace/comment prefixes.
        return ''.join(leaf.value for leaf in tree_node.leaves())

    statement_nodes = get_top_import_nodes(root)
    module_imports = []

    # * Go through all top-of-file imports.
    # * Index them by module name.
    # * Do inline sorting of imported names (`import b, c, a` --> `import a, b, c`)
    for i, stmt in enumerate(statement_nodes):
        first_name = None
        imp = stmt.children[0]
        if imp.type == syms.import_name:
            module_node = imp.children[1]
            if module_node.type == TOKEN.NAME:
                # 'import os'
                module = module_node.value
            elif module_node.type == syms.dotted_name:
                # 'import x.y'
                module = str(module_node)
            elif module_node.type == syms.dotted_as_name:
                # 'import os as OS' and 'import os.path as osp'. In the
                # second form the first child is a dotted_name node, which
                # has no `.value`, so the name is built from its leaves.
                module = _leaf_text(module_node.children[0])
            elif module_node.type == syms.dotted_as_names:
                # 'import os, io'
                module = _sort_imported_names(module_node)
            else:
                raise ValueError(f"Unknown import format: {imp}")
        elif imp.type == syms.import_from:
            # Children are: 'from', optional leading dots, optional module,
            # 'import', then the imported names. Leading dots shift the
            # position of the 'import' keyword, so look it up explicitly.
            import_kw = next(
                (pos for pos, child in enumerate(imp.children)
                 if child.type == TOKEN.NAME and child.value == 'import'),
                None)
            if import_kw is None:
                raise ValueError(f"Unknown import format: {imp}")
            # 'from x import z', 'from x.y import z', 'from ..x import z'
            module = ''.join(
                _leaf_text(child) for child in imp.children[1:import_kw])
            names = [n for n in imp.children[import_kw + 1:]
                     if n.type != TOKEN.LPAR]
            if not names:
                raise ValueError(f"Unknown import format: {imp}")
            if names[0].type == TOKEN.NAME:
                # 'from x import y'
                first_name = names[0].value
            elif names[0].type == syms.import_as_name:
                # 'from x import y as z'
                first_name = names[0].children[0].value
            elif names[0].type == syms.import_as_names:
                # 'from x import y, z'
                # 'from x import y as a, z as b'
                first_name = _sort_imported_names(names[0])
            else:
                raise ValueError(f"Unknown import format: {imp}")
        else:
            # top-of-module docstring. float to top.
            module = ''

        module = module.strip()
        root_module_name = module.split('.')[0]

        # do 'from ...' imports after 'import ...' imports.
        from_ = 1 if first_name is not None else 0

        if module.startswith('.'):
            # relative import; its root name is empty, so it has to be
            # recognised before the docstring check below.
            group = Groups.FIRST_PARTY
        elif root_module_name == '':
            # module docstring
            group = Groups.DOCSTRING
        elif root_module_name == '__future__':
            # special case; must come first
            group = Groups.FUTURE
        elif root_module_name in STDLIB_MODULES:
            # stdlib modules
            group = Groups.STDLIB
        elif root_module_name not in cfg['first_party_modules']:
            # third party modules
            group = Groups.THIRD_PARTY
        else:
            # first party modules
            group = Groups.FIRST_PARTY

        # note: the `i` here is for a weird edge case where you try to sort
        # two of the same exact import.
        # turns out, Node instances aren't comparable, so we get an error if
        # the sort ever has to compare them.
        # So we insert a unique integer before them, thus preventing us ever
        # having to compare the node instances.
        module_imports.append(
            (group, from_, module.lower(),
             first_name and first_name.lower(), i, stmt))

    # Now sort the various lines we've encountered.
    module_imports.sort()

    # Now, clear out all the statements from the parse tree
    for n in statement_nodes:
        n.remove()

    # Then repopulate the tree with the sorted nodes, cleaning up whitespace
    # as we go.
    last_group = None
    last_root_module_name = None
    for i, (group, _, module_lower, _, _, stmt_node) in enumerate(module_imports):
        assert len(stmt_node.children) == 2
        root_module_name = module_lower.split('.')[0]
        import_node = stmt_node.children[0]
        newline_node = stmt_node.children[1]
        prefix = strip_prefix(import_node.prefix)
        if i != 0:
            if last_group != group:
                # add a space between groups.
                prefix = f'\n{prefix}'
            elif (last_root_module_name != root_module_name
                  and group == Groups.FIRST_PARTY):
                # also add a space between different first-party projects.
                prefix = f'\n{prefix}'
        new_stmt = Node(
            syms.simple_stmt, [import_node.clone(), newline_node.clone()])
        new_stmt.prefix = prefix
        root.insert_child(i, new_stmt)
        last_group = group
        last_root_module_name = root_module_name