def add_future(node, symbol):
    """Insert ``from __future__ import <symbol>`` unless it is already imported.

    The top-level children of the tree containing ``node`` are scanned in
    order.  String statements (docstrings) and existing ``__future__``
    imports are passed over; the new import is placed in front of the first
    child that is neither.  If ``symbol`` shows up in one of the existing
    future imports, the tree is left untouched.
    """
    root = fixer_util.find_root(node)

    # ``anchor`` ends up being the child the new statement is inserted before.
    anchor = node
    for position, anchor in enumerate(root.children):
        is_string_stmt = (
            anchor.type == syms.simple_stmt
            and len(anchor.children) > 0
            and anchor.children[0].type == token.STRING
        )
        if is_string_stmt:
            # A docstring: keep looking below it.
            continue
        future_names = check_future_import(anchor)
        if not future_names:
            # First child that is not a future statement: insert here.
            break
        if symbol in future_names:
            # The feature is already imported, nothing to add.
            return

    new_import = fixer_util.FromImport(
        '__future__', [Leaf(token.NAME, symbol, prefix=" ")])

    # Hand the anchor's prefix (comments / whitespace such as a copyright
    # header or shebang) over to the new import so it stays above it.
    new_import.prefix, anchor.prefix = anchor.prefix, ''

    statement = Node(syms.simple_stmt, [new_import, fixer_util.Newline()])
    root.insert_child(position, statement)
def future_import(feature, node):
    """Add ``from __future__ import <feature>`` to the tree containing ``node``.

    Nothing happens when the tree already imports ``feature`` from
    ``__future__``.  Otherwise the statement is inserted below any leading
    string statements (docstrings) and existing ``__future__`` imports.

    lib2to3 keeps the comments and blank lines that precede a node in that
    node's ``prefix``.  When the new statement becomes the very first child
    of the tree it therefore takes over the prefix of the child it
    displaces, so a shebang, encoding declaration or license header stays
    above the import instead of being pushed below it.

    Args:
        feature: Name of the ``__future__`` feature to import.
        node: Any node of the tree to modify.
    """
    root = find_root(node)

    if does_tree_import(u"__future__", feature, node):
        return

    # Index of the child the new statement is inserted before.  Stays 0 for
    # an empty root and ends on the last child if the loop never breaks.
    insert_pos = 0
    for insert_pos, child in enumerate(root.children):
        if child.type == syms.simple_stmt and \
           len(child.children) > 0 and child.children[0].type == token.STRING:
            # skip over docstring
            continue
        names = check_future_import(child)
        if not names:
            # not a future statement; need to insert before this
            break
        if feature in names:
            # already imported
            return

    import_ = FromImport(u'__future__', [Leaf(token.NAME, feature, prefix=" ")])

    if insert_pos == 0 and root.children:
        # Keep leading comments (shebang, encoding line, license) on top.
        displaced = root.children[0]
        import_.prefix = displaced.prefix
        displaced.prefix = u''

    children = [import_, Newline()]
    root.insert_child(insert_pos, Node(syms.simple_stmt, children))
def future_import2(feature, node):
    """
    Alternative to future_import() (its author noted it might not work).

    Unless the tree already imports ``feature`` from ``__future__``, a
    ``from __future__ import <feature>`` statement is inserted directly
    after the first top-level string statement, or at position 0 when there
    is none.  NEWLINE children following that position are stepped over,
    and the prefix of the first other child is moved onto the new statement.
    """
    root = find_root(node)

    if does_tree_import(u"__future__", feature, node):
        return

    # Position just past the first string statement, if there is one.
    insert_pos = 0
    for position, child in enumerate(root.children):
        if (child.type == syms.simple_stmt and child.children
                and child.children[0].type == token.STRING):
            insert_pos = position + 1
            break

    # Step over NEWLINE tokens; steal the prefix of whatever comes next.
    prefix = u""
    for follower in root.children[insert_pos:]:
        if follower.type == token.NEWLINE:
            insert_pos += 1
            continue
        prefix = follower.prefix
        follower.prefix = u""
        break

    name_leaf = Leaf(token.NAME, feature, prefix=u" ")
    import_ = FromImport(u"__future__", [name_leaf])
    statement = Node(syms.simple_stmt, [import_, Newline()], prefix=prefix)
    root.insert_child(insert_pos, statement)
def add_global_assignment_after_imports(_name, assignment, node):
    """
    Insert ``assignment`` as a top-level statement after the first run of
    import-like statements (adapted from touch_import).

    Nothing is inserted when ``_name`` already has a binding in the tree.
    When no usable import position is found, the statement goes right after
    the first string statement (docstring), or at the very beginning.
    """
    root = find_root(node)

    if find_binding(_name, root):
        return

    top_level = root.children

    # Find the first import-ish statement, then walk to the end of that run.
    insert_pos = 0
    for start, child in enumerate(top_level):
        if is_import_ish_stmt(child):
            run_length = 0
            for run_length, candidate in enumerate(top_level[start:]):
                if not is_import_ish_stmt(candidate):
                    break
            insert_pos = start + run_length
            break

    # No import position found: fall back to just below the docstring,
    # otherwise stay at the beginning of the file.
    if insert_pos == 0:
        for position, child in enumerate(top_level):
            is_string_stmt = (child.type == syms.simple_stmt
                              and child.children
                              and child.children[0].type == token.STRING)
            if is_string_stmt:
                insert_pos = position + 1
                break

    statement = Node(syms.simple_stmt, [assignment, Newline()])
    root.insert_child(insert_pos, statement)
def touch_import_top(package, name_to_import, node):
    """Add an import statement near the top of the module if it is missing.

    Works like `does_tree_import` but adds the import when the name was not
    imported.  The statement is placed directly below the first run of
    ``__future__`` imports (whatever features they name).  Without any
    ``__future__`` import it is placed below leading string statements
    (docstrings), or at the very top.  The new statement takes over the
    prefix of the node it displaces, so comments that lib2to3 stores in
    that prefix (shebang, license header) stay above the import.

    Calling this multiple times adds the imports in reverse order.

    Based on lib2to3.fixer_util.touch_import()

    Args:
        package: Module to import from, or None for ``import name_to_import``.
        name_to_import: The name to import.
        node: Any node of the tree to modify.
    """
    root = find_root(node)

    if does_tree_import(package, name_to_import, root):
        return

    top_level = root.children

    # Locate the first __future__ import of any feature; a fixed list of
    # feature names would miss e.g. ``with_statement`` and place the new
    # import above it, which is not valid Python.
    first_future = None
    for idx, child in enumerate(top_level):
        if check_future_import(child):
            first_future = idx
            break

    if first_future is not None:
        # Skip to the end of the consecutive run of __future__ imports.
        insert_pos = first_future + 1
        while (insert_pos < len(top_level)
               and check_future_import(top_level[insert_pos])):
            insert_pos += 1
    else:
        # No __future__ imports: go below any leading docstring, otherwise
        # insert at the top.
        insert_pos = 0
        for insert_pos, child in enumerate(top_level):
            if child.type != syms.simple_stmt:
                break
            if not (child.children and child.children[0].type == token.STRING):
                break

    if package is None:
        import_ = Node(syms.import_name, [
            Leaf(token.NAME, u"import"),
            Leaf(token.NAME, name_to_import, prefix=u" ")
        ])
    else:
        import_ = FromImport(package, [Leaf(token.NAME, name_to_import, prefix=u" ")])

    new_stmt = Node(syms.simple_stmt, [import_, Newline()])
    if insert_pos < len(top_level):
        # Move leading comments / blank lines onto the new statement.
        displaced = top_level[insert_pos]
        new_stmt.prefix = displaced.prefix
        displaced.prefix = u''
    root.insert_child(insert_pos, new_stmt)
def touch_import_top(package, name_to_import, node):
    """Add an import statement near the top of the module if it is missing.

    Works like `does_tree_import` but adds the import when the name was not
    imported.  Placement rules:

    * below the first consecutive run of ``__future__`` imports, regardless
      of which features they import;
    * otherwise below leading string statements (docstrings);
    * otherwise at the very top.

    The inserted statement takes over the prefix of the node it displaces,
    because lib2to3 stores preceding comments and blank lines there; this
    keeps a shebang or license header above the new import.

    Calling this multiple times adds the imports in reverse order.

    Based on lib2to3.fixer_util.touch_import()

    Args:
        package: Module to import from, or None for ``import name_to_import``.
        name_to_import: The name to import.
        node: Any node of the tree to modify.
    """
    root = find_root(node)

    if does_tree_import(package, name_to_import, root):
        return

    siblings = root.children

    # Index of the first __future__ import, or None.  Every feature counts,
    # not just a fixed handful, so the import never lands above a future
    # statement.
    start = next(
        (idx for idx, child in enumerate(siblings) if check_future_import(child)),
        None,
    )

    if start is not None:
        # Advance past the whole run of adjacent __future__ imports.
        insert_pos = start + 1
        while (insert_pos < len(siblings)
               and check_future_import(siblings[insert_pos])):
            insert_pos += 1
    else:
        # No __future__ imports: insert after leading docstring(s), if any.
        insert_pos = 0
        for insert_pos, child in enumerate(siblings):
            if child.type != syms.simple_stmt:
                break
            if not (child.children and child.children[0].type == token.STRING):
                break

    if package is None:
        import_ = Node(syms.import_name, [
            Leaf(token.NAME, u"import"),
            Leaf(token.NAME, name_to_import, prefix=u" ")
        ])
    else:
        import_ = FromImport(package, [Leaf(token.NAME, name_to_import, prefix=u" ")])

    new_stmt = Node(syms.simple_stmt, [import_, Newline()])
    if insert_pos < len(siblings):
        # Carry leading comments / blank lines over to the new statement.
        displaced = siblings[insert_pos]
        new_stmt.prefix = displaced.prefix
        displaced.prefix = u''
    root.insert_child(insert_pos, new_stmt)
def add_globals(self, node):
    """Add required globals to the root of node.

    Runs at most once per instance: the ``added_pyi_globals`` flag is set on
    the first call and later calls return immediately.  The call copies the
    imports and top-level lines from ``self.parsed_pyi`` into the tree and
    adds the missing ``__future__`` imports listed in ``self.future_imports``.
    """
    if self.added_pyi_globals:
        return

    # TODO: get rid of this -- added to prevent adding .parsed_pyi.top_lines
    # every time we annotate a different function in the same file, but can
    # break when we run the tool twice on the same file. Have to do something
    # like what touch_import does.
    self.added_pyi_globals = True

    parsed = self.parsed_pyi
    imports, top_lines = parsed.imports, parsed.top_lines

    # Copy imports if not already present.
    for pkg, names in imports:
        if names is None:
            # TODO: do ourselves, touch_import puts stuff above license headers
            touch_import(None, pkg, node)  # == 'import pkg'
            continue
        for name in names:
            touch_import(pkg, name, node)

    root = find_root(node)

    import_positions = [
        position
        for position, child in enumerate(root.children)
        if self.import_pattern.match(child)
    ]

    if import_positions:
        # __future__ imports go before the first import, top lines after
        # the last one.
        future_insert_pos = import_positions[0]
        top_insert_pos = import_positions[-1] + 1
    else:
        future_insert_pos = top_insert_pos = 0
        # Without imports, go right below the first string statement
        # (normally the docstring).
        for position, child in enumerate(root.children):
            if (child.type == syms.simple_stmt and child.children
                    and child.children[0].type == token.STRING):
                future_insert_pos = top_insert_pos = position + 1
                break

    parsed_top = Util.parse_string('\n'.join(top_lines))  # strips some newlines
    for offset, stmt in enumerate(parsed_top.children[:-1]):
        root.insert_child(top_insert_pos + offset, stmt)

    # touch_import doesn't do proper order for __future__
    pkg = '__future__'
    missing_features = [
        feature for feature in self.future_imports
        if not does_tree_import(pkg, feature, root)
    ]
    for offset, feature in enumerate(missing_features):
        import_ = FromImport(pkg, [Leaf(token.NAME, feature, prefix=" ")])
        stmt = Node(syms.simple_stmt, [import_, Newline()])
        root.insert_child(future_insert_pos + offset, stmt)
def create_type_checking_import(package, name, node):
    # type: (str, str, Node) -> None
    """
    Create import statement of the form
        `from <package> import <name>`
    within a TYPE_CHECKING block.

    The first top-level ``if`` whose condition is ``TYPE_CHECKING`` or
    ``typing.TYPE_CHECKING`` is reused; the import is inserted before the
    last child of its suite.  When there is no such block, a new one is
    generated at the bottom of the current import block.

    Parameters
    -------------
    package : str
    name : str
        Name of type being imported
    node : Node
    """
    root = find_root(node)

    # [Grammar]
    # if_stmt: 'if' namedexpr_test ':' suite
    #          ('elif' namedexpr_test ':' suite)* ['else' ':' suite]
    # children[1] is the condition and children[3] the suite of the 'if'.
    type_checking_suite = None
    for child in root.children:
        if child.type != syms.if_stmt:
            continue
        condition = str(child.children[1]).strip()
        if condition in ('typing.TYPE_CHECKING', 'TYPE_CHECKING'):
            type_checking_suite = child.children[3]
            break

    if type_checking_suite is None:
        # Generate a new TYPE_CHECKING block at the bottom of the current
        # import block and return.
        _new_type_check_with_import(
            package, name, root, _get_bottom_of_imports(root))
        return

    insert_pos = len(type_checking_suite.children) - 1
    # NOTE(review): this prefix is passed to the helper for the statement
    # added inside the suite -- confirm its width matches the block's
    # indentation.
    import_ = _generate_import_node(package, name, prefix=" ")
    statement = Node(syms.simple_stmt, [import_, Newline()])
    type_checking_suite.insert_child(insert_pos, statement)
def indentation_step(node):
    u"""
    Return the string that makes up one indentation level in node's tree.

    Dirty little trick to get the difference between each indentation level.
    Implemented by finding the shortest indentation string
    (technically, the "least" of all of the indentation strings, but
    tabs and spaces mixed won't get this far, so those are synonymous.)

    When nothing in the tree is indented, four spaces are returned.
    """
    r = find_root(node)
    # Collect all indentations into one set.
    all_indents = {leaf.value for leaf in r.pre_order()
                   if leaf.type == token.INDENT}
    if not all_indents:
        # nothing is indented anywhere, so we get to pick what we want
        return u"    "  # four spaces is a popular convention
    return min(all_indents)
def indentation_step(node):
    """
    Return the string that makes up one indentation level in node's tree.

    Dirty little trick to get the difference between each indentation level.
    Implemented by finding the shortest indentation string
    (technically, the "least" of all of the indentation strings, but
    tabs and spaces mixed won't get this far, so those are synonymous.)

    Falls back to four spaces when the tree contains no INDENT token.
    """
    root = find_root(node)
    # Collect every distinct indentation string found in the tree.
    indents = {leaf.value for leaf in root.pre_order()
               if leaf.type == token.INDENT}
    if indents:
        return min(indents)
    # nothing is indented anywhere, so we get to pick what we want:
    # four spaces is a popular convention
    return u"    "
def add_globals(self, node):
    """Add required globals to the root of node.

    Runs at most once per instance (guarded by ``self.added_pyi_globals``).
    Copies the imports from ``self.parsed_pyi`` that are not yet present and
    inserts its top-level lines after the last import, or after the first
    string statement when the tree has no imports.
    """
    if self.added_pyi_globals:
        return

    # TODO(tsudol): get rid of this -- added to prevent adding
    # .parsed_pyi.top_lines every time we annotate a different function in the
    # same file, but can break when we run the tool twice on the same file. Have
    # to do something like what touch_import does.
    self.added_pyi_globals = True

    imports, top_lines = self.parsed_pyi.imports, self.parsed_pyi.top_lines

    # Copy imports if not already present.
    for pkg, names in imports:
        if names is not None:
            for name in names:
                touch_import(pkg, name, node)
        else:
            # TODO(tsudol): do ourselves, touch_import puts stuff above license
            # headers.
            touch_import(None, pkg, node)  # == 'import pkg'

    root = find_root(node)

    import_positions = [
        position
        for position, child in enumerate(root.children)
        if self.import_pattern.match(child)
    ]

    if import_positions:
        insert_pos = import_positions[-1] + 1
    else:
        # Just below the first string (normally the docstring), else 0.
        insert_pos = next(
            (position + 1
             for position, child in enumerate(root.children)
             if (child.type == syms.simple_stmt and child.children
                 and child.children[0].type == token.STRING)),
            0,
        )

    parsed_top = Util.parse_string('\n'.join(top_lines))  # strips some newlines
    for offset, stmt in enumerate(parsed_top.children[:-1]):
        root.insert_child(insert_pos + offset, stmt)
def type_by_import_stmt(package, name, node):
    # type: (str, str, Node) -> Optional[str]
    """
    Return name that needs to be annotated based on available imports

    Two lookups are tried in order: a direct import of ``name`` from
    ``package``, then an import of the module itself, in which case the
    result is ``<module binding>.<name>``.

    Parameters
    -------------
    package : str
    name : str
        Name of type being imported
    node : Node

    Returns
    -------------
    Optional[str]
        Returns valid name for type if current import statement exists, or
        None if there is no import statement
    """
    root = find_root(node)

    # Case 1: ``from <package> import <name>``
    direct = find_import_info(package, name, root)
    if direct:
        return direct.binding

    # Case 2: the module is imported.  Split 'pkg.mod' into ('pkg', 'mod');
    # a bare 'mod' yields an empty package part.
    pkg, _, mod = package.rpartition('.')
    module_import_info = find_import_info(pkg, mod, root)
    if not module_import_info:
        return None
    return '.'.join((module_import_info.binding, name))
def add_future(node, symbol):
    """Insert ``from __future__ import <symbol>`` unless it is already imported.

    The top-level children of the tree containing ``node`` are scanned in
    order; string statements (docstrings) and existing ``__future__`` imports
    are skipped and the new import is inserted before the first other child.

    lib2to3 stores the comments and whitespace preceding a node in that
    node's ``prefix``.  The new import takes over the prefix of the child it
    is inserted before, so a copyright header, shebang or encoding line stays
    above the import.

    Args:
        node: Any node of the tree to modify.
        symbol: Name of the ``__future__`` feature to import.
    """
    root = find_root(node)

    # Index of / reference to the child the import is inserted before.
    insert_pos = 0
    anchor = None
    for insert_pos, anchor in enumerate(root.children):
        if anchor.type == syms.simple_stmt and \
           len(anchor.children) > 0 and anchor.children[0].type == token.STRING:
            # skip over docstring
            continue
        names = check_future_import(anchor)
        if not names:
            # not a future statement; need to insert before this
            break
        if symbol in names:
            # already imported
            return

    import_ = FromImport('__future__',
                         [Leaf(token.NAME, symbol, prefix=" ")])

    if anchor is not None:
        # Place after any comments or whitespace. (copyright, shebang etc.)
        import_.prefix = anchor.prefix
        anchor.prefix = ''

    children = [import_, Newline()]
    root.insert_child(insert_pos, Node(syms.simple_stmt, children))
def create_import(package, name, node):
    # type: (str, str, Node) -> None
    """
    Create import statement of the form
        `from <package> import <name>`
    and insert it at the position reported by ``_get_bottom_of_imports``.

    Parameters
    -------------
    package : str
    name : str
        Name of type being imported
    node : Node
    """
    root = find_root(node)
    position = _get_bottom_of_imports(root)

    statement = Node(
        syms.simple_stmt,
        [_generate_import_node(package, name), Newline()],
    )
    root.insert_child(position, statement)
def add_globals(self, node):
    """Add required globals to the root of node.

    Only the first call on an instance does any work; the
    ``added_pyi_globals`` flag short-circuits later calls.  Missing imports
    from ``self.parsed_pyi`` are added through ``touch_import`` and its
    top-level lines are inserted after the last import (or after the first
    string statement, or at position 0).
    """
    if self.added_pyi_globals:
        return

    # TODO: get rid of this -- added to prevent adding .parsed_pyi.top_lines
    # every time we annotate a different function in the same file, but can
    # break when we run the tool twice on the same file. Have to do something
    # like what touch_import does.
    self.added_pyi_globals = True

    stub = self.parsed_pyi
    imports, top_lines = stub.imports, stub.top_lines

    # Copy imports if not already present.
    for pkg, names in imports:
        if names is None:
            # TODO: do ourselves, touch_import puts stuff above license headers
            touch_import(None, pkg, node)  # == 'import pkg'
            continue
        for name in names:
            touch_import(pkg, name, node)

    root = find_root(node)

    last_import = None
    for position, child in enumerate(root.children):
        if self.import_pattern.match(child):
            last_import = position

    if last_import is not None:
        insert_pos = last_import + 1
    else:
        insert_pos = 0
        # first string (normally docstring)
        for position, child in enumerate(root.children):
            if (child.type == syms.simple_stmt and child.children
                    and child.children[0].type == token.STRING):
                insert_pos = position + 1
                break

    joined = '\n'.join(top_lines)
    parsed_top = Util.parse_string(joined)  # strips some newlines
    for offset, stmt in enumerate(parsed_top.children[:-1]):
        root.insert_child(insert_pos + offset, stmt)
def future_import(feature, node):
    """
    Add ``from __future__ import <feature>`` to the tree containing ``node``.

    Returns without changes when the feature is already imported.  The new
    statement goes below any string statements (docstrings) and existing
    ``__future__`` imports.  If it ends up first in the file and the first
    child carries a shebang or encoding comment, that child's prefix is
    moved onto the new import so the comment stays on top.
    """
    root = find_root(node)

    if does_tree_import(u"__future__", feature, node):
        return

    # Remember the last child that carries a shebang or encoding line.
    shebang_encoding_idx = None

    for position, child in enumerate(root.children):
        if is_shebang_comment(child) or is_encoding_comment(child):
            shebang_encoding_idx = position
        is_string_stmt = (
            child.type == syms.simple_stmt
            and len(child.children) > 0
            and child.children[0].type == token.STRING
        )
        if is_string_stmt:
            # A docstring: look further down.
            continue
        future_names = check_future_import(child)
        if not future_names:
            # Not a future statement: the import goes in front of it.
            break
        if feature in future_names:
            # Already imported.
            return

    import_ = FromImport(u'__future__', [Leaf(token.NAME, feature, prefix=" ")])

    if shebang_encoding_idx == 0 and position == 0:
        # The import would become the first line: detach the shebang /
        # encoding prefix from the current first child and attach it to
        # the new import instead.
        first_child = root.children[0]
        import_.prefix = first_child.prefix
        first_child.prefix = u''

    # Terminate the import line with a newline.
    statement = Node(syms.simple_stmt, [import_, Newline()])
    root.insert_child(position, statement)
def future_import(feature, node):
    """
    Insert ``from __future__ import <feature>`` unless the tree has it.

    Top-level children are scanned in order.  String statements
    (docstrings) and other ``__future__`` imports are skipped, and the
    import is inserted before the first remaining child.  A shebang or
    encoding comment found on the first child is carried over to the new
    import when the import is inserted at position 0.
    """
    root = find_root(node)
    already_there = does_tree_import(u"__future__", feature, node)
    if already_there:
        return

    # Look for a shebang or encoding line while searching the insert spot.
    shebang_encoding_idx = None
    for slot, stmt in enumerate(root.children):
        if is_shebang_comment(stmt) or is_encoding_comment(stmt):
            shebang_encoding_idx = slot

        if (stmt.type == syms.simple_stmt
                and len(stmt.children) > 0
                and stmt.children[0].type == token.STRING):
            continue  # docstring

        imported = check_future_import(stmt)
        if not imported:
            break  # first non-future statement: insert before it
        if feature in imported:
            return  # nothing to add

    feature_leaf = Leaf(token.NAME, feature, prefix=" ")
    import_ = FromImport(u'__future__', [feature_leaf])

    if shebang_encoding_idx == 0 and slot == 0:
        # Keep the shebang / encoding comment above the import by moving
        # the first child's prefix onto the new statement.
        import_.prefix = root.children[0].prefix
        root.children[0].prefix = u''

    # The import line is closed with a newline token.
    root.insert_child(slot, Node(syms.simple_stmt, [import_, Newline()]))
def add_lambda_define(node, args):
    """Insert a ``TODO_PyObject lambda_<n> ...`` node at the top of the tree.

    A leading string statement (the file-header comment) is taken out while
    the insertion happens and put back as the first child afterwards.  The
    new node goes after the ``lambdef`` nodes already sitting at the top, and
    ``<n>`` is its index among them.

    Args:
        node: Any node of the tree to modify.
        args: List of parentless nodes appended after the name.

    Returns:
        The generated name, ``'lambda_<n>'``.
    """
    from lib2to3.fixer_util import find_root

    root = find_root(node)
    first = root.children[0]

    # File-header comment: a simple statement starting with a token of
    # type 3 (STRING in lib2to3's token numbering).  Set it aside.
    header = False
    if first.type == syms.simple_stmt and first.children[0].type == 3:
        header = first.clone()
        if first.parent:
            first.remove()
        else:
            del root.children[0]

    # Insert after the lambdef nodes that are already at the top.
    slot = 0
    while root.children[slot].type == syms.lambdef:
        slot += 1

    label = 'lambda_' + str(slot)
    # Leaf types: 1 is NAME, 4 is NEWLINE.
    pieces = [
        Leaf(1, 'TODO_PyObject'),
        Leaf(1, label, prefix=' '),
    ] + args + [Leaf(4, '\r\n')]
    root.insert_child(slot, Node(syms.lambdef, pieces))

    # Put the header back on top.
    if header:
        root.insert_child(0, header)
    return label
def touch_import_top(package, name_to_import, node):
    """Works like `does_tree_import` but adds an import statement at the top
    if it was not imported (but below any __future__ imports) and below any
    comments such as shebang lines.

    Based on lib2to3.fixer_util.touch_import()

    Calling this multiple times adds the imports in reverse order.

    Also adds "standard_library.install_aliases()" after "from future import
    standard_library".  This should probably be factored into another function.

    Args:
        package: Module to import from, or None for ``import name_to_import``.
        name_to_import: The name to import.
        node: Any node of the tree to modify.
    """
    root = find_root(node)

    if does_tree_import(package, name_to_import, root):
        return

    # Ideally, we would look for whether futurize --all-imports has been run,
    # as indicated by the presence of ``from builtins import (ascii, ...,
    # zip)`` -- and, if it has, we wouldn't import the name again.

    top_level = root.children

    # Look for the first __future__ import.  Any feature counts: checking a
    # fixed list of names would miss e.g. ``with_statement`` and put the new
    # import above a future statement, which is not valid Python.
    first_future = None
    for idx, child in enumerate(top_level):
        if check_future_import(child):
            first_future = idx
            break

    if first_future is not None:
        # Insert below the whole run of consecutive __future__ imports.
        insert_pos = first_future + 1
        while (insert_pos < len(top_level)
               and check_future_import(top_level[insert_pos])):
            insert_pos += 1
    else:
        # No __future__ imports.
        # We look for a docstring and insert the new node below that. If no
        # docstring exists, just insert the node at the top.
        insert_pos = 0
        for insert_pos, child in enumerate(top_level):
            if child.type != syms.simple_stmt:
                break
            if not is_docstring(child):
                # This is the usual case.
                break

    if package is None:
        import_ = Node(syms.import_name, [
            Leaf(token.NAME, u"import"),
            Leaf(token.NAME, name_to_import, prefix=u" ")
        ])
    else:
        import_ = FromImport(package, [Leaf(token.NAME, name_to_import, prefix=u" ")])

    if name_to_import == u'standard_library':
        # Add:
        #     standard_library.install_aliases()
        # after:
        #     from future import standard_library
        install_hooks = Node(syms.simple_stmt, [
            Node(syms.power, [
                Leaf(token.NAME, u'standard_library'),
                Node(syms.trailer, [
                    Leaf(token.DOT, u'.'),
                    Leaf(token.NAME, u'install_aliases')
                ]),
                Node(syms.trailer, [Leaf(token.LPAR, u'('), Leaf(token.RPAR, u')')])
            ])
        ])
        children_hooks = [install_hooks, Newline()]
    else:
        children_hooks = []

    children_import = [import_, Newline()]
    new_stmt = Node(syms.simple_stmt, children_import)
    if insert_pos < len(top_level):
        # The displaced node's prefix holds the comments / blank lines in
        # front of it; move them onto the new import so they stay above it.
        displaced = top_level[insert_pos]
        new_stmt.prefix = displaced.prefix
        displaced.prefix = u''
    root.insert_child(insert_pos, new_stmt)
    if len(children_hooks) > 0:
        root.insert_child(insert_pos + 1, Node(syms.simple_stmt, children_hooks))
def transform(self, node, results):
    """Prepend a ``#`` leaf to ``node`` when 'main' is a future feature.

    Looks up ``future_features`` on the root of the tree containing
    ``node``; if it contains ``'main'``, a leaf of type 1 (NAME in lib2to3's
    token numbering) with value ``'#'`` is inserted as the first child of
    ``node``.  ``results`` is not used.

    Returns:
        ``node`` itself.
    """
    if 'main' in find_root(node).future_features:
        node.insert_child(0, Leaf(1, '#'))
    # NOTE(review): the same node is returned rather than None -- confirm
    # the caller is meant to treat it as the replacement node.
    return node
def refactor_tree(self, tree, name): """Refactors a parse tree (modifying the tree in place). For compatible patterns the bottom matcher module is used. Otherwise the tree is traversed node-to-node for matches. Args: tree: a pytree.Node instance representing the root of the tree to be refactored. name: a human-readable name for this tree. Returns: True if the tree was modified, False otherwise. """ #类型处理 from fixer.pyobject_fix import recur_type recur_type(tree) #main处理 from fixer.main_fix import fix_main if "main" not in tree.future_features: fix_main(tree) #print的import from fixer.tools import add_include if "print" in tree.future_features: add_include(tree, "iostream") add_include(tree, "stdio.h") #在2to3外运行该函数依赖 from itertools import chain from lib2to3 import pytree, pygram from lib2to3.fixer_util import find_root for fixer in chain(self.pre_order, self.post_order): fixer.start_tree(tree, name) #use traditional matching for the incompatible fixers self.traverse_by(self.bmi_pre_order_heads, tree.pre_order()) self.traverse_by(self.bmi_post_order_heads, tree.post_order()) # obtain a set of candidate nodes match_set = self.BM.run(tree.leaves()) while any(match_set.values()): for fixer in self.BM.fixers: if fixer in match_set and match_set[fixer]: #sort by depth; apply fixers from bottom(of the AST) to top match_set[fixer].sort(key=pytree.Base.depth, reverse=True) if fixer.keep_line_order: #some fixers(eg fix_imports) must be applied #with the original file's line order match_set[fixer].sort(key=pytree.Base.get_lineno) for node in list(match_set[fixer]): if node in match_set[fixer]: match_set[fixer].remove(node) try: find_root(node) except ValueError: # this node has been cut off from a # previous transformation ; skip continue if node.fixers_applied and fixer in node.fixers_applied: # do not apply the same fixer again continue results = fixer.match(node) if results: new = fixer.transform(node, results) if new is not None: node.replace(new) 
#new.fixers_applied.append(fixer) for node in new.post_order(): # do not apply the fixer again to # this or any subnode if not node.fixers_applied: node.fixers_applied = [] node.fixers_applied.append(fixer) # update the original match set for # the added code new_matches = self.BM.run(new.leaves()) for fxr in new_matches: if not fxr in match_set: match_set[fxr] = [] match_set[fxr].extend(new_matches[fxr]) for fixer in chain(self.pre_order, self.post_order): fixer.finish_tree(tree, name) #from fixer.tools import new_line,recur #test_me=recur(tree,0) #print (test_me) return tree.was_changed
def transform(self, node, results):
    """Remove ``intern`` imports from ``sys`` or rebuild an ``intern(...)`` call.

    Which action is taken depends on the keys present in ``results``:

    * ``simple``: the binding of ``intern`` imported from ``sys`` is looked
      up in the tree and removed.
    * ``binding``: ``intern`` is removed from an import list.  ``pre`` and
      ``post`` select the case -- presumably the names listed before and
      after ``intern``; confirm against the fixer's pattern.
    * otherwise: a new ``power`` node ``intern(<obj>)<after>`` is built from
      clones of the matched ``obj``, ``lpar``, ``rpar`` and ``after`` parts.

    Returns:
        None when the tree was edited in place (import removal), otherwise
        the new node, which carries the prefix of ``node``.
    """
    name = results.get("name")  # fetched but not used below
    binding = results.get("binding")
    pre = results.get("pre")
    post = results.get("post")
    simple = results.get("simple")
    if simple:
        # NOTE(review): find_binding's result is used without a None check.
        binding = find_binding("intern", find_root(node), "sys")
        binding.remove()
        return
    if binding:
        if not pre and not post:
            new_binding = find_binding("intern", find_root(node), "sys")
            new_binding.remove()
            return
        elif not pre and post:
            # "intern" is the first entry: drop it and the comma after it.
            for ch in node.children:
                if type(ch) == pytree.Node:
                    assert ch.children[0].prefix + "intern" \
                        == str(ch.children[0])
                    ch.children[0].remove()  # intern
                    assert ch.children[0].prefix + "," \
                        == str(ch.children[0])
                    ch.children[0].remove()  # ,
            return
        elif not post and pre:
            # "intern" is the last entry: drop it and the comma before it.
            for ch in node.children:
                if type(ch) == pytree.Node:
                    assert ch.children[-1].prefix + "intern" \
                        == str(ch.children[-1])
                    ch.children[-1].remove()  # intern
                    assert ch.children[-1].prefix + "," \
                        == str(ch.children[-1])
                    ch.children[-1].remove()  # ,
            return
        elif post and pre:
            # "intern" sits in the middle: drop it and the comma before it.
            # NOTE(review): ch.children is modified while being iterated.
            for ch in node.children:
                if type(ch) == pytree.Node:
                    for ch_ in ch.children:
                        if ch_ and ch_.prefix + "intern" == str(ch_):
                            last_ch_ = ch_.prev_sibling
                            ch_.remove()  # intern
                            assert last_ch_.prefix + "," \
                                == str(last_ch_)
                            last_ch_.remove()  # ,
            return
    syms = self.syms
    obj = results["obj"].clone()
    # Wrap a single argument in an arglist node; reuse an existing arglist.
    if obj.type == syms.arglist:
        newarglist = obj.clone()
    else:
        newarglist = pytree.Node(syms.arglist, [obj.clone()])
    after = results["after"]
    if after:
        after = [n.clone() for n in after]
    new = pytree.Node(syms.power,
                      [Name("intern")] +
                      [pytree.Node(syms.trailer,
                                   [results["lpar"].clone(),
                                    newarglist,
                                    results["rpar"].clone()] + after)])
    new.prefix = node.prefix
    return new
def transform(self, node, results):
    """Remove ``intern`` imports from ``sys`` or rebuild an ``intern(...)`` call.

    The keys present in ``results`` decide what happens:

    * ``simple``: the ``intern`` binding imported from ``sys`` is found in
      the tree and removed.
    * ``binding``: ``intern`` and one adjacent comma are removed from an
      import list; ``pre`` / ``post`` choose which comma (presumably they
      hold the names before / after ``intern`` -- verify against the
      fixer's pattern).
    * neither: a ``power`` node ``intern(<obj>)<after>`` is assembled from
      clones of the matched ``obj``, ``lpar``, ``rpar`` and ``after``.

    Returns:
        None for the in-place import edits, otherwise the new node with the
        prefix of ``node``.
    """
    name = results.get("name")  # read but never used in this method
    binding = results.get("binding")
    pre = results.get("pre")
    post = results.get("post")
    simple = results.get("simple")
    if simple:
        # NOTE(review): no None check on the find_binding result.
        binding = find_binding("intern", find_root(node), "sys")
        binding.remove()
        return
    if binding:
        if not pre and not post:
            new_binding = find_binding("intern", find_root(node), "sys")
            new_binding.remove()
            return
        elif not pre and post:
            # Leading "intern": remove it, then the comma that follows.
            for ch in node.children:
                if type(ch) == pytree.Node:
                    assert ch.children[0].prefix + "intern" \
                        == str(ch.children[0])
                    ch.children[0].remove()  # intern
                    assert ch.children[0].prefix + "," \
                        == str(ch.children[0])
                    ch.children[0].remove()  # ,
            return
        elif not post and pre:
            # Trailing "intern": remove it, then the comma that precedes.
            for ch in node.children:
                if type(ch) == pytree.Node:
                    assert ch.children[-1].prefix + "intern" \
                        == str(ch.children[-1])
                    ch.children[-1].remove()  # intern
                    assert ch.children[-1].prefix + "," \
                        == str(ch.children[-1])
                    ch.children[-1].remove()  # ,
            return
        elif post and pre:
            # "intern" in the middle: remove it and the comma before it.
            # NOTE(review): the list being iterated is modified in the loop.
            for ch in node.children:
                if type(ch) == pytree.Node:
                    for ch_ in ch.children:
                        if ch_ and ch_.prefix + "intern" == str(ch_):
                            last_ch_ = ch_.prev_sibling
                            ch_.remove()  # intern
                            assert last_ch_.prefix + "," \
                                == str(last_ch_)
                            last_ch_.remove()  # ,
            return
    syms = self.syms
    obj = results["obj"].clone()
    # A lone argument gets wrapped in an arglist; an arglist is cloned as is.
    if obj.type == syms.arglist:
        newarglist = obj.clone()
    else:
        newarglist = pytree.Node(syms.arglist, [obj.clone()])
    after = results["after"]
    if after:
        after = [n.clone() for n in after]
    new = pytree.Node(syms.power, [Name("intern")] + [
        pytree.Node(syms.trailer, [
            results["lpar"].clone(), newarglist, results["rpar"].clone()
        ] + after)
    ])
    new.prefix = node.prefix
    return new
def touch_import_top(package, name_to_import, node):
    """Works like `does_tree_import` but adds an import statement at the top
    if it was not imported (but below any __future__ imports).

    The new statement takes over the prefix of the node it displaces, so
    comments lib2to3 keeps in that prefix (e.g. a shebang line) stay above it.

    Based on lib2to3.fixer_util.touch_import()

    Calling this multiple times adds the imports in reverse order.

    Also adds "standard_library.install_hooks()" after "from future import
    standard_library".  This should probably be factored into another function.

    Args:
        package: Module to import from, or None for ``import name_to_import``.
        name_to_import: The name to import.
        node: Any node of the tree to modify.
    """
    root = find_root(node)

    if does_tree_import(package, name_to_import, root):
        return

    # Ideally, we would look for whether futurize --all-imports has been run,
    # as indicated by the presence of ``from future.builtins import (ascii, ...,
    # zip)`` -- and, if it has, we wouldn't import the name again.

    top_level = root.children

    # Look for the first __future__ import of any feature; a fixed list of
    # feature names would miss others and place the new import above them.
    first_future = None
    for idx, child in enumerate(top_level):
        if check_future_import(child):
            first_future = idx
            break

    if first_future is not None:
        # Insert below the whole run of consecutive __future__ imports.
        insert_pos = first_future + 1
        while (insert_pos < len(top_level)
               and check_future_import(top_level[insert_pos])):
            insert_pos += 1
    else:
        # No __future__ imports.
        # We look for a docstring and insert the new node below that. If no
        # docstring exists, just insert the node at the top.
        insert_pos = 0
        for insert_pos, child in enumerate(top_level):
            if child.type != syms.simple_stmt:
                break
            if not (child.children and child.children[0].type == token.STRING):
                # This is the usual case.
                break

    if package is None:
        import_ = Node(syms.import_name, [
            Leaf(token.NAME, u"import"),
            Leaf(token.NAME, name_to_import, prefix=u" ")
        ])
    else:
        import_ = FromImport(package, [Leaf(token.NAME, name_to_import, prefix=u" ")])

    if name_to_import == u'standard_library':
        # Add:
        #     standard_library.install_hooks()
        # after:
        #     from future import standard_library
        install_hooks = Node(syms.simple_stmt, [
            Node(syms.power, [
                Leaf(token.NAME, u'standard_library'),
                Node(syms.trailer, [
                    Leaf(token.DOT, u'.'),
                    Leaf(token.NAME, u'install_hooks')
                ]),
                Node(syms.trailer, [Leaf(token.LPAR, u'('), Leaf(token.RPAR, u')')])
            ])
        ])
        children_hooks = [install_hooks, Newline()]
    else:
        children_hooks = []

    children_import = [import_, Newline()]
    new_stmt = Node(syms.simple_stmt, children_import)
    if insert_pos < len(top_level):
        # Move the displaced node's leading comments / blank lines onto the
        # new import so they remain above it.
        displaced = top_level[insert_pos]
        new_stmt.prefix = displaced.prefix
        displaced.prefix = u''
    root.insert_child(insert_pos, new_stmt)
    if len(children_hooks) > 0:
        root.insert_child(insert_pos + 1, Node(syms.simple_stmt, children_hooks))