def register_checker(pattern, checker, extra):
    """Compile ``pattern`` and record it together with its checker.

    Registration is silently skipped when ``extra`` restricts the checker to
    interpreter versions that exclude the running one:

    * ``python_minimum_version`` -- skipped when ``sys.version_info`` is lower.
    * ``python_disabled_version`` -- skipped when ``sys.version_info`` is higher.

    Otherwise the ``(compiled_pattern, checker, extra)`` triple is appended to
    ``collected_checkers``.
    """
    too_old = (
        "python_minimum_version" in extra
        and sys.version_info < extra["python_minimum_version"]
    )
    if too_old:
        return

    disabled = (
        "python_disabled_version" in extra
        and sys.version_info > extra["python_disabled_version"]
    )
    if disabled:
        return

    compiled = patcomp.compile_pattern(pattern)
    collected_checkers.append((compiled, checker, extra))
class Util(object):
    """Helpers for parsing source text and inspecting lib2to3 trees."""

    # Matches a return statement made of the 'return' keyword followed by
    # exactly one more node (the returned expression).
    return_expr = compile_pattern("""return_stmt< 'return' any >""")

    @classmethod
    def has_return_exprs(cls, node):
        """Report whether a 'return expr' statement occurs at or below node.

        Children that are themselves function or class definitions are not
        descended into, so returns belonging to nested scopes are ignored.
        Only returns that match ``return_expr`` count.

        Returns True when at least one such statement is found, else False.
        """
        if cls.return_expr.match(node, {}):
            return True
        nested_scopes = (syms.funcdef, syms.classdef)
        return any(
            cls.has_return_exprs(child)
            for child in node.children
            if child.type not in nested_scopes
        )

    driver = driver.Driver(pygram.python_grammar, convert=pytree.convert)

    @classmethod
    def parse_string(cls, text):
        """Use lib2to3 to parse text into a Node."""
        stripped = text.strip()
        if stripped:
            # workaround: parsing text without trailing '\n' throws exception
            return cls.driver.parse_string(stripped + '\n')
        # cls.driver.parse_string just returns the ENDMARKER Leaf for empty
        # input, so wrap it in a file_input Node for consistency.
        return Node(syms.file_input, [Leaf(token.ENDMARKER, '')])
def register_checker(pattern, checker, extra):
    """Register ``checker`` for ``pattern`` unless version limits exclude it.

    ``extra`` may carry 'python_minimum_version' and/or
    'python_disabled_version'; the checker is dropped when the running
    interpreter is older than the former or newer than the latter. When it is
    kept, the compiled pattern, the checker and ``extra`` are appended to
    ``collected_checkers`` as one tuple.
    """
    if 'python_minimum_version' in extra:
        if sys.version_info < extra['python_minimum_version']:
            return
    if 'python_disabled_version' in extra:
        if sys.version_info > extra['python_disabled_version']:
            return

    entry = (patcomp.compile_pattern(pattern), checker, extra)
    collected_checkers.append(entry)
def _RegisterPattern(self, method, pattern):
    """
    Compile a lib2to3 pattern and remember which method handles it.

    The ``(method, compiled_pattern)`` pair is appended to ``self.patterns``.

    :param str method:
        The method name to handle the node matching the pattern.

    :param str pattern:
        The pattern to match AST nodes, written in the lib2to3 pattern syntax.
    """
    from lib2to3.patcomp import compile_pattern

    compiled = compile_pattern(pattern)
    self.patterns.append((method, compiled))
class Util:
    """Utility functions for working with Nodes."""

    return_expr = compile_pattern("""return_stmt< 'return' any >""")

    @classmethod
    def has_return_exprs(cls, node):
        """Search the subtree rooted at node for a 'return expr' statement.

        Args:
          node: The AST node at the root of the subtree.

        Returns:
          True if some node matches ``return_expr`` (the 'return' keyword
          followed by one more node), False otherwise. Children that are
          function or class definitions are not searched.
        """
        skipped_types = (syms.funcdef, syms.classdef)
        pending = [node]
        while pending:
            current = pending.pop()
            if cls.return_expr.match(current, {}):
                return True
            pending.extend(
                child for child in current.children
                if child.type not in skipped_types
            )
        return False

    driver = driver.Driver(pygram.python_grammar, convert=pytree.convert)

    @classmethod
    def parse_string(cls, text):
        """Use lib2to3 to parse text into a Node."""
        source = text.strip()
        if not source:
            # cls.driver.parse_string just returns the ENDMARKER Leaf, wrap in
            # a Node for consistency
            end_marker = Leaf(token.ENDMARKER, '')
            return Node(syms.file_input, [end_marker])
        # workaround: parsing text without trailing '\n' throws exception
        source += '\n'
        return cls.driver.parse_string(source)
def register_pattern(self, method, pattern):
    """Compile pattern and record that method handles nodes matching it.

    The ``(method, compiled_pattern)`` pair is appended to ``self.patterns``.
    """
    compiled = compile_pattern(pattern)
    self.patterns.append((method, compiled))
class FixAnnotate(BaseFix):
    """lib2to3 fixer that copies annotations from a parsed .pyi into a module.

    For every function matched by ``PATTERN`` the fixer looks up a signature
    with the same fully-qualified name in ``self.parsed_pyi`` and either
    inserts PEP 484 annotations or a ``# type:`` comment, depending on
    ``annotate_pep484``. ``set_pyi_string`` must be called before the fixer
    runs.
    """

    # This fixer is compatible with the bottom matcher.
    BM_compatible = True

    # This fixer shouldn't run by default.
    explicit = True

    PATTERN = FuncSignature.PATTERN

    # Optional cap on the number of annotated functions, read once from the
    # MAXFIXES environment variable; None means "no limit".
    counter = None if not os.getenv('MAXFIXES') else int(os.getenv('MAXFIXES'))

    def __init__(self, options, log):
        super(FixAnnotate, self).__init__(options, log)

        # ParsedPyi obtained from .pyi file
        self.parsed_pyi = None

        # Did we add globals required by pyi to the top of the py file
        self.added_pyi_globals = False

        self.logger = logging.getLogger('FixAnnotate')

        # Options below

        # List of things to import from "__future__"
        self.future_imports = tuple()

        # Insert type annotations in PEP484 style. Otherwise insert as
        # comments.
        self._annotate_pep484 = False

        # Strip comments and formatting from type annotations (False breaks
        # comment output mode).
        self._strip_pyi_formatting = not self.annotate_pep484

    @property
    def annotate_pep484(self):
        """Whether annotations are inserted in PEP 484 style."""
        return self._annotate_pep484

    @annotate_pep484.setter
    def annotate_pep484(self, value):
        # Formatting is stripped exactly when annotations go into comments.
        self._annotate_pep484 = bool(value)
        self._strip_pyi_formatting = not self.annotate_pep484

    def transform(self, node, results):
        """Annotate the function at node if the pyi has a usable signature."""
        assert self.parsed_pyi, 'must provide pyi_string'

        limited = FixAnnotate.counter is not None
        if limited and FixAnnotate.counter <= 0:
            return

        cur_sig = FuncSignature(node, results)
        if not self.can_annotate(cur_sig):
            return

        if limited:
            FixAnnotate.counter -= 1

        # In PEP 484 mode this inserts directly and yields None; otherwise it
        # yields the comment text to insert.
        annot = self.get_or_insert_annotation(cur_sig)

        if annot and not self.annotate_pep484:
            inserted = cur_sig.try_insert_comment_annotation(annot)
            if inserted and 'Any' in annot:
                touch_import('typing', 'Any', node)

        self.add_globals(node)

    def get_or_insert_annotation(self, cur_sig):
        """If self.annotate_pep484, insert, otherwise return as comment string.

        In PEP 484 mode nothing is returned (None).
        """
        pyi_sig = self.parsed_pyi.funcs[cur_sig.full_name]

        comment_types = []
        for index, arg_sig in enumerate(cur_sig.arg_sigs):
            new_type = clean_clone(pyi_sig.arg_sigs[index].arg_type,
                                   self._strip_pyi_formatting)

            if self.annotate_pep484:
                if new_type:
                    arg_sig.insert_annotation(new_type)
                continue

            if new_type:
                comment_types.append(arg_sig.stars + str(new_type).strip())
            elif self.infer_should_annotate(cur_sig, arg_sig, index == 0):
                comment_types.append(arg_sig.stars + 'Any')

        ret_type = pyi_sig.ret_type

        if self.annotate_pep484:
            if ret_type:
                cur_sig.insert_ret_annotation(ret_type)
            return None

        if not ret_type:
            ret_type = self.infer_ret_type(cur_sig)
        return '(' + ', '.join(comment_types) + ') -> ' + str(ret_type).strip()

    def can_annotate(self, cur_sig):
        """Return True if cur_sig may be annotated; log the reason if not."""
        if cur_sig.has_pep484_annotations or cur_sig.has_comment_annotations:
            self.logger.warning('already annotated, skipping %s', cur_sig)
            return False

        known = self.parsed_pyi.funcs
        if cur_sig.full_name not in known:
            self.logger.warning('no signature for %s, skipping', cur_sig)
            return False

        pyi_sig = known[cur_sig.full_name]
        if not pyi_sig.has_pep484_annotations:
            self.logger.warning(
                'ignoring pyi definition with no annotations: %s', pyi_sig)
            return False

        if not self.func_sig_compatible(cur_sig, pyi_sig):
            self.logger.warning('incompatible annotation, skipping %s', cur_sig)
            return False

        return True

    def add_globals(self, node):
        """Add required globals to the root of node. Idempotent."""
        if self.added_pyi_globals:
            return

        # TODO: get rid of this -- added to prevent adding
        # .parsed_pyi.top_lines every time we annotate a different function in
        # the same file, but can break when we run the tool twice on the same
        # file. Have to do something like what touch_import does.
        self.added_pyi_globals = True

        imports = self.parsed_pyi.imports
        top_lines = self.parsed_pyi.top_lines

        # Copy imports if not already present
        for pkg, names in imports:
            if names is not None:
                for name in names:
                    touch_import(pkg, name, node)
            else:
                # TODO: do ourselves, touch_import puts stuff above license
                # headers
                touch_import(None, pkg, node)  # == 'import pkg'

        root = find_root(node)

        import_positions = [
            position for position, stmt in enumerate(root.children)
            if self.import_pattern.match(stmt)
        ]

        if import_positions:
            future_insert_pos = import_positions[0]
            top_insert_pos = import_positions[-1] + 1
        else:
            future_insert_pos = top_insert_pos = 0

            # first string (normally docstring)
            for position, stmt in enumerate(root.children):
                is_string_stmt = (stmt.type == syms.simple_stmt and
                                  stmt.children and
                                  stmt.children[0].type == token.STRING)
                if is_string_stmt:
                    future_insert_pos = top_insert_pos = position + 1
                    break

        # Re-parse the collected top-level lines; this strips some newlines.
        parsed_top = Util.parse_string('\n'.join(top_lines))
        # The last child is left out of the insertion.
        for offset, stmt in enumerate(parsed_top.children[:-1]):
            root.insert_child(top_insert_pos + offset, stmt)

        # touch_import doesn't do proper order for __future__
        pkg = '__future__'
        missing_future = [
            name for name in self.future_imports
            if not does_tree_import(pkg, name, root)
        ]
        for offset, name in enumerate(missing_future):
            import_node = FromImport(pkg, [Leaf(token.NAME, name, prefix=" ")])
            stmt = Node(syms.simple_stmt, [import_node, Newline()])
            root.insert_child(future_insert_pos + offset, stmt)

    @staticmethod
    def func_sig_compatible(cur_sig, pyi_sig):
        """Can cur_sig be annotated with the info in pyi_sig?

        The number of arguments must match, they must have the same star
        signature and neither side may use tuple arguments.
        """
        if len(pyi_sig.arg_sigs) != len(cur_sig.arg_sigs):
            return False

        return all(
            not (cur.is_tuple or pyi.is_tuple) and cur.stars == pyi.stars
            for pyi, cur in zip(pyi_sig.arg_sigs, cur_sig.arg_sigs)
        )

    @staticmethod
    def infer_ret_type(cur_sig):
        """Heuristic for return value of a function."""
        returns_nothing = (cur_sig.short_name == '__init__' or
                           not cur_sig.has_return_exprs)
        return 'None' if returns_nothing else 'Any'

    @staticmethod
    def infer_should_annotate(func, arg, at_start):
        """Heuristic for whether arg (in func) should be annotated."""
        implicit_slot = (func.is_method and at_start and
                         'staticmethod' not in func.decorators)
        if not implicit_slot:
            return True

        # Don't annotate the first argument if it's named 'self'.
        # Don't annotate the first argument of a class method.
        return not ('self' == arg.name or 'classmethod' in func.decorators)

    def set_pyi_string(self, pyi_string):
        """Set the annotations the fixer will use"""
        self.parsed_pyi = self.parse_pyi_string(pyi_string)
        self.added_pyi_globals = False

    def parse_pyi_string(self, text):
        """Parse .pyi string, return as ParsedPyi"""
        tree = Util.parse_string(text)

        funcs = {}
        for func_node, match_results in generate_matches(tree, self.pattern):
            sig = FuncSignature(func_node, match_results)
            if sig.full_name not in funcs:
                funcs[sig.full_name] = sig
            else:
                self.logger.warning('Ignoring redefinition: %s', sig)

        parsed_imports = (
            self.parse_top_import(import_node, match_results)
            for import_node, match_results in generate_top_matches(
                tree, self.import_pattern)
        )
        imports = [imp for imp in parsed_imports if imp]

        top_lines = []
        for stmt, _ in generate_top_matches(tree, self.assign_pattern):
            line = str(stmt).strip()

            # hack to avoid shadowing real variables -- proper solution is
            # more complicated, use util.find_binding
            if 'TypeVar' in line or line.startswith('_'):
                top_lines.append(line)
            else:
                self.logger.warning("ignoring %s", repr(line))

        return ParsedPyi(tuple(imports), top_lines, funcs)

    assign_pattern = compile_pattern(
        " simple_stmt< expr_stmt<any+> any* > ")

    import_pattern = compile_pattern(
        " simple_stmt< ( "
        "import_from< 'from' pkg=any+ 'import' ['('] names=any [')'] > | "
        "import_name< 'import' pkg=any+ > ) any* > ")

    import_as_pattern = compile_pattern("""import_as_name<NAME 'as' NAME>""")

    def parse_top_import(self, node, results):
        """Takes result of import_pattern, returns component strings.

        Examples:

        'from pkg import a,b,c' gives ('pkg', ('a', 'b', 'c'))
        'import pkg' gives ('pkg', None)
        'from pkg import a as b' or 'import pkg as pkg2' are not supported.
        """
        # TODO: this might have to be generalized to "get top-level statements
        # that aren't class or function definitions":
        # _T = typing.TypeVar('_T') is used in pyis.
        # Still not clear what is and isn't valid in a pyi... Could we have a
        # loop?
        names = results.get('names', None)
        pkg = ''.join(map(str, results['pkg'])).strip()

        if names:
            has_import_as = any(
                True for _ in generate_matches(names, self.import_as_pattern))

            if has_import_as:
                # fixer_util.touch_import doesn't handle this
                # If necessary, will have to stick import at top of .py file
                self.logger.warning('Ignoring unhandled import-as: %s',
                                    repr(str(node).strip()))
                return None

            groups = split_comma(names.leaves())
            for group in groups:
                assert 1 == len(group)
                assert group[0].type in (token.NAME, token.STAR)
            names = [group[0].value for group in groups]

        return pkg, names
class FixAnnotate(BaseFix): # This fixer is compatible with the bottom matcher. BM_compatible = True # This fixer shouldn't run by default. explicit = True # The pattern to match. PATTERN = """ funcdef< 'def' name=any parameters< '(' [args=any] ')' > ':' suite=any+ > """ counter = None if not os.getenv('MAXFIXES') else int(os.getenv('MAXFIXES')) def transform(self, node, results): if FixAnnotate.counter is not None: if FixAnnotate.counter <= 0: return suite = results['suite'] children = suite[0].children # NOTE: I've reverse-engineered the structure of the parse tree. # It's always a list of nodes, the first of which contains the # entire suite. Its children seem to be: # # [0] NEWLINE # [1] INDENT # [2...n-2] statements (the first may be a docstring) # [n-1] DEDENT # # Comments before the suite are part of the INDENT's prefix. # # "Compact" functions (e.g. "def foo(x, y): return max(x, y)") # have a different structure that isn't matched by PATTERN. ## print('-'*60) ## print(node) ## for i, ch in enumerate(children): ## print(i, repr(ch.prefix), repr(ch)) # Check if there's already an annotation. for ch in children: if ch.prefix.lstrip().startswith('# type:'): return # There's already a # type: comment here; don't change anything. # Compute the annotation annot = self.make_annotation(node, results) # Insert '# type: {annot}' comment. # For reference, see lib2to3/fixes/fix_tuple_params.py in stdlib. if len(children) >= 2 and children[1].type == token.INDENT: children[1].prefix = '{}# type: {}\n{}'.format(children[1].value, annot, children[1].prefix) children[1].changed() if FixAnnotate.counter is not None: FixAnnotate.counter -= 1 # Also add 'from typing import Any' at the top. 
if 'Any' in annot: touch_import('typing', 'Any', node) def make_annotation(self, node, results): name = results['name'] assert isinstance(name, Leaf), repr(name) assert name.type == token.NAME, repr(name) decorators = self.get_decorators(node) is_method = self.is_method(node) if name.value == '__init__' or not self.has_return_exprs(node): restype = 'None' else: restype = 'Any' args = results.get('args') argtypes = [] if isinstance(args, Node): children = args.children elif isinstance(args, Leaf): children = [args] else: children = [] # Interpret children according to the following grammar: # (('*'|'**')? NAME ['=' expr] ','?)* stars = inferred_type = '' in_default = False at_start = True for child in children: if isinstance(child, Leaf): if child.value in ('*', '**'): stars += child.value elif child.type == token.NAME and not in_default: if not is_method or not at_start or 'staticmethod' in decorators: inferred_type = 'Any' else: # Always skip the first argument if it's named 'self'. # Always skip the first argument of a class method. if child.value == 'self' or 'classmethod' in decorators: pass else: inferred_type = 'Any' elif child.value == '=': in_default = True elif in_default and child.value != ',': if child.type == token.NUMBER: if re.match(r'\d+[lL]?$', child.value): inferred_type = 'int' else: inferred_type = 'float' # TODO: complex? elif child.type == token.STRING: if child.value.startswith(('u', 'U')): inferred_type = 'unicode' else: inferred_type = 'str' elif child.type == token.NAME and child.value in ('True', 'False'): inferred_type = 'bool' elif child.value == ',': if inferred_type: argtypes.append(stars + inferred_type) # Reset stars = inferred_type = '' in_default = False at_start = False if inferred_type: argtypes.append(stars + inferred_type) return '(' + ', '.join(argtypes) + ') -> ' + restype # The parse tree has a different shape when there is a single # decorator vs. when there are multiple decorators. 
DECORATED = "decorated< (d=decorator | decorators< dd=decorator+ >) funcdef >" decorated = compile_pattern(DECORATED) def get_decorators(self, node): """Return a list of decorators found on a function definition. This is a list of strings; only simple decorators (e.g. @staticmethod) are returned. If the function is undecorated or only non-simple decorators are found, return []. """ if node.parent is None: return [] results = {} if not self.decorated.match(node.parent, results): return [] decorators = results.get('dd') or [results['d']] decs = [] for d in decorators: for child in d.children: if isinstance(child, Leaf) and child.type == token.NAME: decs.append(child.value) return decs def is_method(self, node): """Return whether the node occurs (directly) inside a class.""" node = node.parent while node is not None: if node.type == syms.classdef: return True if node.type == syms.funcdef: return False node = node.parent return False RETURN_EXPR = "return_stmt< 'return' any >" return_expr = compile_pattern(RETURN_EXPR) def has_return_exprs(self, node): """Traverse the tree below node looking for 'return expr'. Return True if at least 'return expr' is found, False if not. (If both 'return' and 'return expr' are found, return True.) """ results = {} if self.return_expr.match(node, results): return True for child in node.children: if child.type not in (syms.funcdef, syms.classdef): if self.has_return_exprs(child): return True return False
class FuncSignature: """A function or method.""" _full_name: str # The pattern to match. PATTERN = """ funcdef< 'def' name=NAME parameters< '(' [args=any+] ')' > ['->' ret_annotation=any] colon=':' suite=any+ > """ def __init__(self, node, match_results): """node must match PATTERN.""" name = match_results.get('name') assert isinstance(name, Leaf), repr(name) assert name.type == token.NAME, repr(name) self._ret_type = match_results.get('ret_annotation') self._full_name = self._make_function_key(name) args = self._split_args(match_results.get('args')) self._arg_sigs = tuple(map(ArgSignature, args)) self._node = node self._match_results = match_results self._inserted_ret_annotation = False def __str__(self): return self.full_name @property def full_name(self): """Fully-qualified name string.""" return self._full_name @property def short_name(self): return self._match_results.get('name').value @property def ret_type(self): """Return type, Node? or None.""" return self._ret_type @property def arg_sigs(self): """List[ArgSignature].""" return self._arg_sigs # The parse tree has a different shape when there is a single # decorator vs. when there are multiple decorators. decorated_pattern = compile_pattern(""" decorated< (d=decorator | decorators< dd=decorator+ >) funcdef > """) @property def decorators(self): """A list of the function's decorators. This is a list of strings; only simple decorators (e.g. @staticmethod) are returned. If the function is undecorated or only non-simple decorators are found, return []. Returns: The names of the function's decorators as a list of strings. Only simple decorators (e.g. @staticmethod) are returned. If the function is not decorated or only non-simple decorators are found, return []. 
""" # TODO(tsudol): memoize node = self._node if node.parent is None: return [] results = {} if not self.decorated_pattern.match(node.parent, results): return [] decorators = results.get('dd') or [results['d']] decs = [] for d in decorators: for child in d.children: if child.type == token.NAME: decs.append(child.value) return decs @property def is_method(self): """Whether we are (directly) inside a class.""" # TODO(tsudol): memoize node = self._node.parent while node is not None: if node.type == syms.classdef: return True if node.type == syms.funcdef: return False node = node.parent return False @property def has_return_exprs(self): """True if function has "return expr" anywhere.""" return Util.has_return_exprs(self._node) @property def has_pep484_annotations(self): """Do we have any pep484 annotations?""" return self.ret_type or any(arg.arg_type for arg in self.arg_sigs) @property def has_comment_annotations(self): """Do we have any comment annotations?""" children = self._match_results['suite'][0].children for ch in children: if ch.prefix.lstrip().startswith('# type:'): return True return False def insert_ret_annotation(self, ret_type): """In-place annotation. Can only be called once.""" assert not self._inserted_ret_annotation self._inserted_ret_annotation = True colon = self._match_results.get('colon') # TODO(tsudol): insert as a Node, not as a prefix colon.prefix = ' -> ' + str(ret_type).strip() + colon.prefix def try_insert_comment_annotation(self, annotation): """Try to insert '# type: {annotation}' comment.""" # For reference, see lib2to3/fixes/fix_tuple_params.py in stdlib. # "Compact" functions (e.g. 'def foo(x, y): return max(x, y)') # are not annotated. 
children = self._match_results['suite'][0].children if not (len(children) >= 2 and children[1].type == token.INDENT): return False # can't annotate node = children[1] node.prefix = '%s# type: %s\n%s' % (node.value, annotation, node.prefix) node.changed() return True scope_pattern = compile_pattern("""( funcdef < 'def' name=TOKEN any*> | classdef< 'class' name=TOKEN any*> )""") @classmethod def _make_function_key(cls, node): """Return the fully-qualified name of the function the node is under. If source is class C: def foo(self): x = 1 We'll return 'C.foo' for any nodes related to 'x', '1', 'foo', 'self', and either 'C' or '' otherwise. Args: node: The node to start searching from. Returns: The function key as a string. """ result = [] while node is not None: match_result = {} if cls.scope_pattern.match(node, match_result): result.append(match_result['name'].value) node = node.parent return '.'.join(reversed(result)) @staticmethod def _split_args(args): """Turns the match of PATTERN.args into a list of non-empty lists of nodes. Args: args: The value matched by PATTERN.args. Returns: A list of non-empty lists of nodes, where each list corresponds to a function argument. """ if args is None: return [] assert isinstance(args, list) and len(args) == 1, repr(args) args = args[0] if isinstance(args, Leaf) or args.type == syms.tname: args = [args] else: args = args.children return split_comma(args)
class BaseFixAnnotate(BaseFix):
    """Base lib2to3 fixer that writes type annotations onto functions.

    Subclasses supply the types through ``make_annotation``; this class
    decides whether a function can be annotated (``should_skip``) and writes
    the result either as Python 3 annotations or as ``# type:`` comments,
    according to the 'typewriter' entry of the fixer options
    ('annotation_style' and 'comment_style').
    """

    # This fixer is compatible with the bottom matcher.
    BM_compatible = True

    # This fixer shouldn't run by default.
    explicit = True

    # The pattern to match.
    PATTERN = """ funcdef< 'def' name=any parameters=parameters< '(' [args=any] rpar=')' > ':' suite=any+ > """

    # Optional cap on the number of annotated functions, read once from the
    # MAXFIXES environment variable; None means "no limit".
    _maxfixes = os.getenv('MAXFIXES')
    counter = None if not _maxfixes else int(_maxfixes)

    _type_options = None  # type: Optional[Dict[str, Any]]

    @property
    def type_options(self):
        """The 'typewriter' options dict, fetched lazily from self.options."""
        if self._type_options is None:
            self._type_options = self.options.get('typewriter', {})
        return self._type_options

    def should_skip(self, node, results):
        """Return True when the matched function must be left untouched.

        That is the case when the MAXFIXES budget is used up, or when the
        function already carries a type comment or a Python 3 argument
        annotation.
        """
        if BaseFixAnnotate.counter is not None:
            if BaseFixAnnotate.counter <= 0:
                return True

        # Check if there's already a long-form annotation for some argument.
        parameters = results.get('parameters')
        if parameters is not None:
            for ch in parameters.pre_order():
                if ch.prefix.lstrip().startswith('# type:'):
                    return True
        args = results.get('args')
        if args is not None:
            for ch in args.pre_order():
                if ch.prefix.lstrip().startswith('# type:'):
                    return True

        children = results['suite'][0].children

        # NOTE: I've reverse-engineered the structure of the parse tree.
        # It's always a list of nodes, the first of which contains the
        # entire suite.  Its children seem to be:
        #
        #   [0] NEWLINE
        #   [1] INDENT
        #   [2...n-2] statements (the first may be a docstring)
        #   [n-1] DEDENT
        #
        # Comments before the suite are part of the INDENT's prefix.
        #
        # "Compact" functions (e.g. "def foo(x, y): return max(x, y)")
        # have a different structure (no NEWLINE, INDENT, or DEDENT).

        # Check if there's already an annotation.
        for ch in children:
            if ch.prefix.lstrip().startswith('# type:'):
                return True  # There's already a # type: comment here; don't change anything.

        # Python 3 style return annotations are already skipped by the pattern.
        #
        # Python 3 style argument annotation structure
        #
        # Structure of the argument tokens for one positional argument
        # without default value:
        # + LPAR '('
        # + NAME_NODE_OR_LEAF arg1
        # + RPAR ')'
        #
        # NAME_NODE_OR_LEAF is either:
        # 1. Just a leaf with value NAME
        # 2. A node with children: NAME, ':', node expr or value leaf
        #
        # Structure of the argument tokens for one arg with default value or
        # multiple args, with or without default value, and/or with extra
        # arguments:
        # + LPAR '('
        # + node
        #   [
        #     + NAME_NODE_OR_LEAF
        #      [
        #        + EQUAL '='
        #        + node expr or value leaf
        #      ]
        #    (
        #        + COMMA ','
        #        + NAME_NODE_OR_LEAF positional argn
        #      [
        #        + EQUAL '='
        #        + node expr or value leaf
        #      ]
        #    )*
        #   ]
        #   [
        #     + STAR '*'
        #     [
        #     + NAME_NODE_OR_LEAF positional star argument name
        #     ]
        #   ]
        #   [
        #     + COMMA ','
        #     + DOUBLESTAR '**'
        #     + NAME_NODE_OR_LEAF positional keyword argument name
        #   ]
        # + RPAR ')'

        # Let's skip Python 3 argument annotations
        it = iter(args.children) if args else iter([])
        for ch in it:
            if ch.type == token.STAR:
                # *arg part
                ch = next(it)
                if ch.type == token.COMMA:
                    continue
            elif ch.type == token.DOUBLESTAR:
                # **kwargs part
                ch = next(it)
            if ch.type > 256:
                # this is a node, therefore an annotation
                assert ch.children[0].type == token.NAME
                return True
            try:
                ch = next(it)
                if ch.type == token.COLON:
                    # this is an annotation
                    return True
                elif ch.type == token.EQUAL:
                    ch = next(it)
                    ch = next(it)
                assert ch.type == token.COMMA
                continue
            except StopIteration:
                break
        return False

    def transform(self, node, results):
        """Annotate the matched function unless it has to be skipped."""
        if self.should_skip(node, results):
            return

        # Compute the annotation
        annot = self.make_annotation(node, results)
        if annot is None:
            return
        argtypes, restype = annot

        if self.type_options['annotation_style'] == 'py3':
            self.add_py3_annot(argtypes, restype, node, results)
        else:
            self.add_py2_annot(argtypes, restype, node, results)

        # Common to py2 and py3 style annotations:
        if BaseFixAnnotate.counter is not None:
            BaseFixAnnotate.counter -= 1

        # Also add 'from typing import Any' at the top if needed.
        self.patch_imports(argtypes + [restype], node)

    def add_py3_annot(self, argtypes, restype, node, results):
        # type: (List[str], str, Node, Dict[str, Any]) -> None
        """Write argtypes and restype as Python 3 annotations, in place."""
        args = results.get('args')  # type: Optional[Node]

        argleaves = []  # type: List[Tuple[str, Leaf]]
        if args is None:
            # function with 0 arguments
            it = iter([])  # type: Iterator[Union[Leaf, Node]]
        elif len(args.children) == 0:
            # function with 1 argument
            it = iter([args])
        else:
            # function with multiple arguments or 1 arg with default value
            it = iter(args.children)

        for ch in it:
            argstyle = 'name'
            if ch.type == token.STAR:
                # *arg part
                argstyle = 'star'
                ch = next(it)
                if ch.type == token.COMMA:
                    continue
            elif ch.type == token.DOUBLESTAR:
                # **kwargs part
                argstyle = 'keyword'
                ch = next(it)
            assert ch.type == token.NAME
            assert isinstance(ch, Leaf)
            argleaves.append((argstyle, ch))
            try:
                ch = next(it)
                if ch.type == token.EQUAL:
                    ch = next(it)
                    ch = next(it)
                assert ch.type == token.COMMA
                continue
            except StopIteration:
                break

        # when self or cls is not annotated, argleaves == argtypes+1
        argleaves = argleaves[len(argleaves) - len(argtypes):]

        for ch_withstyle, chtype in zip(argleaves, argtypes):
            style, ch = ch_withstyle
            # The stars are already present in the source; drop them from the
            # type before writing it after the name.
            if style == 'star':
                assert chtype[0] == '*'
                assert chtype[1] != '*'
                chtype = chtype[1:]
            elif style == 'keyword':
                assert chtype[0:2] == '**'
                assert chtype[2] != '*'
                chtype = chtype[2:]
            ch.value = '%s: %s' % (ch.value, chtype)

            # put spaces around the equal sign
            if ch.next_sibling and ch.next_sibling.type == token.EQUAL:
                nextch = ch.next_sibling
                if not nextch.prefix[:1].isspace():
                    nextch.prefix = ' ' + nextch.prefix
                nextch_ = nextch.next_sibling
                assert nextch_ is not None
                if not nextch_.prefix[:1].isspace():
                    nextch_.prefix = ' ' + nextch_.prefix

        # Add return annotation
        rpar = results['rpar']
        rpar.value = '%s -> %s' % (rpar.value, restype)

        rpar.changed()

    def use_py2_long_form(self, argtypes, short_str, degen_str):
        # type: (List[str], str, str) -> bool
        """Decide whether to emit one type comment per argument.

        'single' never uses the long form. 'multi' uses it whenever there are
        argument types to spread over lines. Any other value ('auto') uses it
        for long signatures (more than 64 characters or more than 5 argument
        types), provided the short form is longer than the degenerate
        '(...) -> X' form.
        """
        comment_style = self.type_options['comment_style']
        if comment_style == 'single':
            return False
        if comment_style == 'multi':
            # insert_long_form requires an 'args' match, so a signature
            # without any argument type keeps the short form.
            return bool(argtypes)
        # auto
        return ((len(short_str) > 64 or len(argtypes) > 5)
                and len(short_str) > len(degen_str))

    def add_py2_annot(self, argtypes, restype, node, results):
        # type: (List[str], str, Node, Dict[str, Any]) -> None
        """Write the signature as a '# type:' comment under the def line."""
        children = results['suite'][0].children

        # Insert '# type: {annot}' comment.
        # For reference, see lib2to3/fixes/fix_tuple_params.py in stdlib.
        if len(children) >= 1 and children[0].type != token.NEWLINE:
            # one liner function: turn the suite into an indented block,
            # unless something other than whitespace precedes the statement.
            if children[0].prefix.strip() == '':
                children[0].prefix = ''
                children.insert(0, Leaf(token.NEWLINE, '\n'))
                children.insert(
                    1, Leaf(token.INDENT, find_indentation(node) + ' '))
                children.append(Leaf(token.DEDENT, ''))

        if len(children) >= 2 and children[1].type == token.INDENT:
            degen_str = '(...) -> %s' % restype
            short_str = '(%s) -> %s' % (', '.join(argtypes), restype)
            if self.use_py2_long_form(argtypes, short_str, degen_str):
                self.insert_long_form(node, results, argtypes)
                annot_str = degen_str
            else:
                annot_str = short_str

            indent_node = children[1]
            comment, sep, other_comments = indent_node.prefix.partition('\n')
            comment = comment.rstrip() + sep
            annot_str = '# type: %s\n' % (annot_str,)
            if comment == annot_str:
                # The very same comment is already there.
                return

            if comment and not is_type_comment(comment):
                # push existing non-type comment to next line
                annot_str += comment

            indent_node.prefix = indent_node.value + annot_str + other_comments
            indent_node.changed()
        else:
            self.log_message("%s:%d: cannot insert annotation for one-line function" %
                             (self.filename, node.get_lineno()))

    def insert_long_form(self, node, results, argtypes):
        # type: (Node, Dict[str, Any], List[str]) -> None
        """Put each argument on its own line followed by its type comment."""
        argtypes = list(argtypes)  # We destroy it
        args = results['args']
        if isinstance(args, Node):
            children = args.children
        elif isinstance(args, Leaf):
            children = [args]
        else:
            children = []
        # Interpret children according to the following grammar:
        # (('*'|'**')? NAME ['=' expr] ','?)*
        flag = False  # Set when the next leaf should get a type prefix
        indent = ''  # Will be set by the first child

        def set_prefix(child):
            # Consumes the next pending type and writes it in front of child.
            if argtypes:
                arg = argtypes.pop(0).lstrip('*')
            else:
                arg = 'Any'  # Somehow there aren't enough args
            if not arg:
                # Skip self (look for 'check_self' below)
                prefix = child.prefix.rstrip()
            else:
                prefix = ' # type: ' + arg
                old_prefix = child.prefix.strip()
                if old_prefix:
                    assert old_prefix.startswith('#')
                    prefix += ' ' + old_prefix
            child.prefix = prefix + '\n' + indent

        check_self = self.is_method(node)
        for child in children:
            if isinstance(child, Leaf):
                if check_self and child.type == token.NAME:
                    check_self = False
                    if child.value in ('self', 'cls'):
                        # Empty placeholder: self/cls gets no type comment.
                        argtypes.insert(0, '')
                if not indent:
                    indent = ' ' * child.column
                if child.value == ',':
                    flag = True
                elif flag:
                    set_prefix(child)
                    flag = False
        need_comma = len(children) >= 1 and children[-1].type != token.COMMA
        if need_comma and len(children) >= 2:
            # No comma is added after a trailing *args / **kwargs.
            if (children[-1].type == token.NAME and
                    (children[-2].type in (token.STAR, token.DOUBLESTAR))):
                need_comma = False
        if need_comma:
            children.append(Leaf(token.COMMA, u","))
        # Find the ')' and insert a prefix before it too.
        parameters = args.parent
        close_paren = parameters.children[-1]
        assert close_paren.type == token.RPAR, close_paren
        set_prefix(close_paren)
        assert not argtypes, argtypes

    def patch_imports(self, types, node):
        """Add 'from typing import Any' if any of the types mentions Any."""
        for typ in types:
            if 'Any' in typ:
                touch_import('typing', 'Any', node)
                break

    def make_annotation(self, node, results):
        # type: (Node, Dict[str, Any]) -> Optional[Tuple[List[str], str]]
        """Return the type annotations.

        Given the current Node and the dictionary parsed from PATTERN,
        return the annotations for the arguments and the return type as
        strings, or None when the function should not be annotated.
        Subclasses must override this.
        """
        raise NotImplementedError

    # The parse tree has a different shape when there is a single
    # decorator vs. when there are multiple decorators.
    DECORATED = "decorated< (d=decorator | decorators< dd=decorator+ >) funcdef >"
    decorated = compile_pattern(DECORATED)

    def get_decorators(self, node):
        """Return a list of decorators found on a function definition.

        This is a list of strings; only simple decorators
        (e.g. @staticmethod) are returned.

        If the function is undecorated or only non-simple decorators
        are found, return [].
        """
        if node.parent is None:
            return []
        results = {}
        if not self.decorated.match(node.parent, results):
            return []
        decorators = results.get('dd') or [results['d']]
        decs = []
        for d in decorators:
            for child in d.children:
                if isinstance(child, Leaf) and child.type == token.NAME:
                    decs.append(child.value)
        return decs

    def is_method(self, node):
        """Return whether the node occurs (directly) inside a class."""
        node = node.parent
        while node is not None:
            if node.type == syms.classdef:
                return True
            if node.type == syms.funcdef:
                return False
            node = node.parent
        return False

    RETURN_EXPR = "return_stmt< 'return' any >"
    return_expr = compile_pattern(RETURN_EXPR)

    def has_return_exprs(self, node):
        """Traverse the tree below node looking for 'return expr'.

        Return True if at least 'return expr' is found, False if not.
        (If both 'return' and 'return expr' are found, return True.)
        Nested function and class definitions are not searched.
        """
        results = {}
        if self.return_expr.match(node, results):
            return True
        for child in node.children:
            if child.type not in (syms.funcdef, syms.classdef):
                if self.has_return_exprs(child):
                    return True
        return False

    YIELD_EXPR = "yield_expr< 'yield' [any] >"
    yield_expr = compile_pattern(YIELD_EXPR)

    def is_generator(self, node):
        """Traverse the tree below node looking for 'yield [expr]'.

        Nested function and class definitions are not searched.
        """
        results = {}
        if self.yield_expr.match(node, results):
            return True
        for child in node.children:
            if child.type not in (syms.funcdef, syms.classdef):
                if self.is_generator(child):
                    return True
        return False
class FixMergePyi(BaseFix):
    """lib2to3 fixer that copies annotations from a parsed .pyi into source.

    The stub must be supplied through ``set_pyi_string`` before the fixer
    runs. Annotations are written as type comments unless
    ``annotate_pep484`` is switched on.
    """

    # Safe to drive from lib2to3's bottom matcher.
    BM_compatible = True

    # Only runs when requested explicitly.
    explicit = True

    PATTERN = FuncSignature.PATTERN

    def __init__(self, options, log):
        super(FixMergePyi, self).__init__(options, log)

        self.logger = logging.getLogger(self.__class__.__name__)

        # ParsedPyi built from the stub text; None until set_pyi_string().
        self.parsed_pyi = None

        # True once the stub's imports / top-level lines have been copied.
        self.added_pyi_globals = False

        # Option: emit PEP 484 annotations instead of type comments.
        self._annotate_pep484 = False

    @property
    def annotate_pep484(self):
        """Whether annotations are inserted in PEP 484 syntax."""
        return self._annotate_pep484

    @annotate_pep484.setter
    def annotate_pep484(self, value):
        self._annotate_pep484 = bool(value)

    def transform(self, node, results):
        """Annotate one matched function from its stub signature, if possible."""
        assert self.parsed_pyi, 'must provide pyi_string'

        source_sig = FuncSignature(node, results)
        if self.can_annotate(source_sig):
            stub_sig = self.parsed_pyi.funcs[source_sig.full_name]

            if not self.annotate_pep484:
                comment = self.get_comment_annotation(source_sig, stub_sig)
                inserted = source_sig.try_insert_comment_annotation(comment)
                if inserted and 'Any' in comment:
                    touch_import('typing', 'Any', node)
            else:
                self.insert_annotation(source_sig, stub_sig)

            self.add_globals(node)

    def insert_annotation(self, src_sig, pyi_sig):
        """Write the stub's argument and return types as PEP 484 annotations."""
        for target_arg, stub_arg in zip(src_sig.arg_sigs, pyi_sig.arg_sigs):
            if stub_arg.arg_type:
                target_arg.insert_annotation(
                    clean_clone(stub_arg.arg_type, False))

        if pyi_sig.ret_type:
            src_sig.insert_ret_annotation(pyi_sig.ret_type)

    def get_comment_annotation(self, src_sig, pyi_sig):
        """Build the ``(args) -> ret`` comment text; the tree is not modified."""
        pieces = []
        paired = zip(src_sig.arg_sigs, pyi_sig.arg_sigs)
        for position, (target_arg, stub_arg) in enumerate(paired):
            cloned = clean_clone(stub_arg.arg_type, True)
            if cloned:
                type_text = str(cloned).strip()
            elif self.infer_should_annotate(src_sig, target_arg,
                                            position == 0):
                type_text = 'Any'
            else:
                continue
            pieces.append(target_arg.stars + type_text)

        returns = pyi_sig.ret_type or self.infer_ret_type(src_sig)
        return '(%s) -> %s' % (', '.join(pieces), str(returns).strip())

    def can_annotate(self, src_sig):
        """Return True when ``src_sig`` may be annotated; log why when not."""
        warn = self.logger.warning

        if src_sig.has_pep484_annotations or src_sig.has_comment_annotations:
            warn('already annotated, skipping %s', src_sig)
            return False

        known = self.parsed_pyi.funcs
        if src_sig.full_name not in known:
            warn('no signature for %s, skipping', src_sig)
            return False

        stub_sig = known[src_sig.full_name]
        if not stub_sig.has_pep484_annotations:
            warn('ignoring pyi definition with no annotations: %s', stub_sig)
            return False

        if not self.func_sig_compatible(src_sig, stub_sig):
            warn('incompatible annotation, skipping %s', src_sig)
            return False

        return True

    def add_globals(self, node):
        """Copy the stub's imports and kept top-level lines into the tree root.

        Does nothing after the first call for the current stub.
        """
        if self.added_pyi_globals:
            return
        # TODO(tsudol): get rid of this -- added to prevent adding
        # .parsed_pyi.top_lines every time we annotate a different function
        # in the same file, but can break when we run the tool twice on the
        # same file. Have to do something like what touch_import does.
        self.added_pyi_globals = True

        parsed = self.parsed_pyi

        # Copy imports if not already present.
        for pkg, names in parsed.imports:
            if names is not None:
                for name in names:
                    touch_import(pkg, name, node)
            else:
                # TODO(tsudol): do ourselves, touch_import puts stuff above
                # license headers.
                touch_import(None, pkg, node)  # == 'import pkg'

        root = find_root(node)
        import_positions = [
            pos for pos, stmt in enumerate(root.children)
            if self.import_pattern.match(stmt)
        ]
        if import_positions:
            insert_pos = import_positions[-1] + 1
        else:
            insert_pos = 0
            # No imports: go after the first string statement (normally the
            # docstring), if there is one.
            for pos, stmt in enumerate(root.children):
                is_string_stmt = (stmt.type == syms.simple_stmt and
                                  stmt.children and
                                  stmt.children[0].type == token.STRING)
                if is_string_stmt:
                    insert_pos = pos + 1
                    break

        # Parsing strips some newlines.
        parsed_lines = Util.parse_string('\n'.join(parsed.top_lines))
        for offset, stmt in enumerate(parsed_lines.children[:-1]):
            root.insert_child(insert_pos + offset, stmt)

    @staticmethod
    def func_sig_compatible(src_sig, pyi_sig):
        """Decide whether ``pyi_sig`` can be applied to ``src_sig``.

        Compatible means: same number of arguments, no tuple argument on
        either side, and identical star prefixes position by position.

        Args:
          src_sig: A FuncSignature representing the .py signature.
          pyi_sig: A FuncSignature representing the .pyi signature.

        Returns:
          True if the two signatures are compatible, False otherwise.
        """
        stub_args, source_args = pyi_sig.arg_sigs, src_sig.arg_sigs
        if len(stub_args) != len(source_args):
            return False

        for stub_arg, source_arg in zip(stub_args, source_args):
            uses_tuple = source_arg.is_tuple or stub_arg.is_tuple
            if uses_tuple or source_arg.stars != stub_arg.stars:
                return False
        return True

    @staticmethod
    def infer_ret_type(src_sig):
        """Guess a return type: 'None' for __init__ / no return expr, else 'Any'."""
        returns_nothing = (src_sig.short_name == '__init__' or
                           not src_sig.has_return_exprs)
        return 'None' if returns_nothing else 'Any'

    @staticmethod
    def infer_should_annotate(func, arg, at_start):
        """Heuristic: should ``arg`` of ``func`` receive an annotation?

        The first argument of a non-static method is skipped when it is
        named 'self' or when the method is a classmethod.
        """
        if not (func.is_method and at_start):
            return True
        if 'staticmethod' in func.decorators:
            return True
        return not ('self' == arg.name or 'classmethod' in func.decorators)

    def set_pyi_string(self, pyi_string):
        """Parse ``pyi_string`` and make it the stub used by this fixer."""
        self.parsed_pyi = self.parse_pyi_string(pyi_string)
        self.added_pyi_globals = False

    def parse_pyi_string(self, text):
        """Parse .pyi text into a ParsedPyi (imports, top lines, functions)."""
        tree = Util.parse_string(text)

        signatures = {}
        for func_node, captured in generate_matches(tree, self.pattern):
            sig = FuncSignature(func_node, captured)
            if sig.full_name not in signatures:
                signatures[sig.full_name] = sig
            else:
                self.logger.warning('Ignoring redefinition: %s', sig)

        found_imports = []
        for imp_node, captured in generate_top_matches(tree,
                                                       self.import_pattern):
            parsed_import = self.parse_top_import(imp_node, captured)
            if parsed_import:
                found_imports.append(parsed_import)

        kept_lines = []
        for stmt, captured in generate_top_matches(tree, self.assign_pattern):
            line = str(stmt).strip()

            # hack to avoid shadowing real variables -- proper solution is
            # more complicated, use util.find_binding
            if 'TypeVar' in line or (line and '_' == line[0]):
                kept_lines.append(line)
            else:
                self.logger.warning('ignoring %s', repr(line))

        return ParsedPyi(tuple(found_imports), kept_lines, signatures)

    assign_pattern = compile_pattern(""" simple_stmt< expr_stmt<any+> any* > """)

    import_pattern = compile_pattern(""" simple_stmt< ( import_from< 'from' pkg=any+ 'import' ['('] names=any [')'] > | import_name< 'import' pkg=any+ > ) any* > """)

    import_as_pattern = compile_pattern("""import_as_name<NAME 'as' NAME>""")

    def parse_top_import(self, node, results):
        """Split an ``import_pattern`` match into plain strings.

        Examples:
          'from pkg import a,b,c' gives ('pkg', ('a', 'b', 'c'))
          'import pkg' gives ('pkg', None)
          'from pkg import a as b' or 'import pkg as pkg2' are not supported.

        Args:
          node: The import statement node.
          results: The values from import_pattern.

        Returns:
          A tuple of the package name (string) and the imported names, or
          None when the imported names contain an unsupported 'as' alias.
        """
        # TODO(tsudol): this might have to be generalized to "get top-level
        # statements that aren't class or function definitions":
        # _T = typing.TypeVar('_T') is used in pyis.
        # Still not clear what is and isn't valid in a pyi... Could we have
        # a loop?
        names = results.get('names', None)
        pkg = ''.join(map(str, results['pkg'])).strip()

        if names:
            for _ in generate_matches(names, self.import_as_pattern):
                # fixer_util.touch_import doesn't handle this.
                # If necessary, will have to stick import at top of .py file.
                self.logger.warning('Ignoring unhandled import-as: %s',
                                    repr(str(node).strip()))
                return None

            names = split_comma(names.leaves())
            for group in names:
                assert 1 == len(group)
                assert group[0].type in (token.NAME, token.STAR)
            names = [group[0].value for group in names]

        return pkg, names
'print' into 'print()' 'print ...' into 'print(...)' 'print ... ,' into 'print(..., end=" ")' 'print >>x, ...' into 'print(..., file=x)' No changes are applied if print_function is imported from __future__ """ # Local imports from lib2to3 import patcomp, pytree, fixer_base from lib2to3.pgen2 import token from lib2to3.fixer_util import Name, Call, Comma, String from libmodernize import add_future parend_expr = patcomp.compile_pattern("""atom< '(' [atom|STRING|NAME] ')' >""") class FixPrint(fixer_base.BaseFix): BM_compatible = True PATTERN = """ simple_stmt< any* bare='print' any* > | print_stmt """ def transform(self, node, results): assert results bare_print = results.get("bare")
def build_pattern(mapping=MAPPING):
    """Return a lib2to3 pattern matching any mapped name as a ``power`` head.

    Only the last dotted component of every key in ``mapping`` is used.
    """
    template = """ power< local=(%s) tail=any* > """
    local_names = ["'%s'" % dotted.split(".")[-1] for dotted in mapping.keys()]
    return template % "|".join(local_names)


## Pattern for finding from module import *
from_import_pattern = """import_from<'from' module=(%s) 'import' star='*'>"""
# Top-level package of every mapped name, quoted for use in the pattern.
module_names = {"'%s'" % key.split(".", 1)[0] for key in MAPPING.keys()}
from_import_pattern = from_import_pattern % "|".join(module_names)
from_import_expr = patcomp.compile_pattern(from_import_pattern)


class FixChangedNamesAggressive(FixChangedNames):
    """Variant of FixChangedNames driven by this module's MAPPING."""

    mapping = MAPPING
    run_order = 2

    def compile_pattern(self):
        """Rebuild PATTERN and the name -> modules index from ``self.mapping``.

        Overridden so that programmatic changes to the mapping are picked
        up when the pattern is (re)built.
        """
        # NOTE(review): the base class compile_pattern is not invoked here;
        # confirm that is intended.
        self.PATTERN = build_pattern(self.mapping)

        modules_by_name = defaultdict(set)
        for dotted in self.mapping:
            module, attr = dotted.split(".", 1)
            modules_by_name[attr].add(module)
        self._names_to_modules = modules_by_name
'print >>x, ...' into 'print(..., file=x)' No changes are applied if print_function is imported from __future__ """ from __future__ import unicode_literals from lib2to3 import patcomp, pytree, fixer_base from lib2to3.pgen2 import token from lib2to3.fixer_util import Name, Call, Comma, FromImport, Newline, String from libmodernize import check_future_import parend_expr = patcomp.compile_pattern( """atom< '(' [arith_expr|atom|power|term|STRING|NAME] ')' >""" ) class FixPrint(fixer_base.BaseFix): BM_compatible = True PATTERN = """ simple_stmt< any* bare='print' any* > | print_stmt """ def start_tree(self, tree, filename): self.found_print = False def transform(self, node, results):
"print(...)" not changed "print ... ," into "print(..., end=' ')" "print >>x, ..." into "print(..., file=x)" No changes are applied if print_function is imported from __future__ """ # Local imports from lib2to3 import patcomp, pytree, fixer_base from lib2to3.pgen2 import token from lib2to3.fixer_util import Name, Call, Comma, String # from libmodernize import add_future parend_expr = patcomp.compile_pattern( """atom< '(' [arith_expr|atom|power|term|STRING|NAME] ')' >""") class FixPrint(fixer_base.BaseFix): BM_compatible = True PATTERN = """ simple_stmt< any* bare='print' any* > | print_stmt """ def transform(self, node, results): assert results bare_print = results.get("bare")
class FixAnnotate(BaseFix):
    """lib2to3 fixer that inserts a ``# type:`` signature comment in functions.

    Argument types are guessed from default values (falling back to 'Any')
    and the return type is 'None' or 'Any'. When the MAXFIXES environment
    variable is set, at most that many functions are annotated per process.
    """

    # This fixer is compatible with the bottom matcher.
    BM_compatible = True

    # This fixer shouldn't run by default.
    explicit = True

    # The pattern to match.
    PATTERN = """ funcdef< 'def' name=any parameters< '(' [args=any] ')' > ':' suite=any+ > """

    # Remaining number of fixes allowed, or None for "no limit".
    counter = None if not os.getenv('MAXFIXES') else int(os.getenv('MAXFIXES'))

    # Integer literals: decimal / legacy octal, hex, octal, binary; optional
    # Python 2 long suffix; underscores tolerated.
    _INT_LITERAL = re.compile(
        r'(?:0[xX][0-9a-fA-F_]+|0[oO][0-7_]+|0[bB][01_]+|\d[\d_]*)[lL]?$')

    def transform(self, node, results):
        """Add a ``# type:`` comment to the matched function, if it has none.

        Args:
            node: The matched ``funcdef`` node.
            results: Match results for PATTERN ('name', 'args', 'suite').

        Returns:
            None; the tree is modified in place.
        """
        if FixAnnotate.counter is not None:
            if FixAnnotate.counter <= 0:
                return
        suite = results['suite']
        children = suite[0].children

        # NOTE: I've reverse-engineered the structure of the parse tree.
        # It's always a list of nodes, the first of which contains the
        # entire suite.  Its children seem to be:
        #
        #   [0] NEWLINE
        #   [1] INDENT
        #   [2...n-2] statements (the first may be a docstring)
        #   [n-1] DEDENT
        #
        # Comments before the suite are part of the INDENT's prefix.
        #
        # "Compact" functions (e.g. "def foo(x, y): return max(x, y)")
        # have a different structure that isn't matched by PATTERN.

        # Check if there's already an annotation.
        for ch in children:
            if ch.prefix.lstrip().startswith('# type:'):
                return  # Already has a # type: comment; don't change anything.

        # Compute the annotation
        annot = self.make_annotation(node, results)
        if annot is None:
            return

        # Insert '# type: {annot}' comment.
        # For reference, see lib2to3/fixes/fix_tuple_params.py in stdlib.
        if len(children) >= 2 and children[1].type == token.INDENT:
            argtypes, restype = annot
            degen_str = '(...) -> %s' % restype
            short_str = '(%s) -> %s' % (', '.join(argtypes), restype)
            # Long signatures get per-argument comments plus a '(...)' summary.
            if ((len(short_str) > 64 or len(argtypes) > 5)
                    and len(short_str) > len(degen_str)):
                self.insert_long_form(node, results, argtypes)
                annot_str = degen_str
            else:
                annot_str = short_str
            children[1].prefix = '%s# type: %s\n%s' % (children[1].value,
                                                       annot_str,
                                                       children[1].prefix)
            children[1].changed()
            if FixAnnotate.counter is not None:
                FixAnnotate.counter -= 1

            # Also add 'from typing import Any' at the top if needed.
            self.patch_imports(argtypes + [restype], node)

    def insert_long_form(self, node, results, argtypes):
        """Put one ``# type:`` comment after each argument of a long signature.

        Args:
            node: The ``funcdef`` node being annotated.
            results: Match results for PATTERN; 'args' must be present.
            argtypes: Argument type strings; not modified (a copy is consumed).
        """
        argtypes = list(argtypes)  # We destroy it
        args = results['args']
        if isinstance(args, Node):
            children = args.children
        elif isinstance(args, Leaf):
            children = [args]
        else:
            children = []
        # Interpret children according to the following grammar:
        # (('*'|'**')? NAME ['=' expr] ','?)*
        flag = False  # Set when the next leaf should get a type prefix
        indent = ''  # Will be set by the first leaf

        def set_prefix(child):
            """Move the type of the preceding argument into child's prefix."""
            if argtypes:
                arg = argtypes.pop(0).lstrip('*')
            else:
                arg = 'Any'  # Somehow there aren't enough args
            if not arg:
                # Skip self (look for 'check_self' below)
                prefix = child.prefix.rstrip()
            else:
                prefix = ' # type: ' + arg
                old_prefix = child.prefix.strip()
                if old_prefix:
                    assert old_prefix.startswith('#')
                    prefix += ' ' + old_prefix
            child.prefix = prefix + '\n' + indent

        check_self = self.is_method(node)
        for child in children:
            if not isinstance(child, Leaf):
                # Only leaves have .value / .column; reading .column on a
                # Node used to raise AttributeError.
                continue
            if check_self and child.type == token.NAME:
                check_self = False
                if child.value in ('self', 'cls'):
                    argtypes.insert(0, '')
            if not indent:
                indent = ' ' * child.column
            if child.value == ',':
                flag = True
            elif flag:
                set_prefix(child)
                flag = False
        # Find the ')' and insert a prefix before it too.
        parameters = args.parent
        close_paren = parameters.children[-1]
        assert close_paren.type == token.RPAR, close_paren
        set_prefix(close_paren)
        assert not argtypes, argtypes

    def patch_imports(self, types, node):
        """Ensure ``from typing import Any`` exists if any type mentions Any."""
        for typ in types:
            if 'Any' in typ:
                touch_import('typing', 'Any', node)
                break

    @staticmethod
    def _infer_number_type(literal):
        """Map the source text of a NUMBER token to 'int', 'float' or 'complex'."""
        if literal[-1:] in ('j', 'J'):
            return 'complex'
        if FixAnnotate._INT_LITERAL.match(literal):
            return 'int'
        return 'float'

    def make_annotation(self, node, results):
        """Guess the argument and return types of the matched function.

        Args:
            node: The ``funcdef`` node.
            results: Match results for PATTERN.

        Returns:
            A tuple ``(argtypes, restype)``: a list of type strings (each
            carrying its '*' / '**' prefix) and the return type string.
        """
        name = results['name']
        assert isinstance(name, Leaf), repr(name)
        assert name.type == token.NAME, repr(name)
        decorators = self.get_decorators(node)
        is_method = self.is_method(node)
        if name.value == '__init__' or not self.has_return_exprs(node):
            restype = 'None'
        else:
            restype = 'Any'
        args = results.get('args')
        argtypes = []
        if isinstance(args, Node):
            children = args.children
        elif isinstance(args, Leaf):
            children = [args]
        else:
            children = []
        # Interpret children according to the following grammar:
        # (('*'|'**')? NAME ['=' expr] ','?)*
        stars = inferred_type = ''
        in_default = False
        at_start = True
        for child in children:
            if isinstance(child, Leaf):
                if child.value in ('*', '**'):
                    stars += child.value
                elif child.type == token.NAME and not in_default:
                    if (not is_method or not at_start
                            or 'staticmethod' in decorators):
                        inferred_type = 'Any'
                    else:
                        # Always skip the first argument if it's named 'self'.
                        # Always skip the first argument of a class method.
                        if child.value == 'self' or 'classmethod' in decorators:
                            pass
                        else:
                            inferred_type = 'Any'
                elif child.value == '=':
                    in_default = True
                elif in_default and child.value != ',':
                    if child.type == token.NUMBER:
                        inferred_type = self._infer_number_type(child.value)
                    elif child.type == token.STRING:
                        if child.value.startswith(('u', 'U')):
                            inferred_type = 'unicode'
                        else:
                            inferred_type = 'str'
                    elif (child.type == token.NAME
                          and child.value in ('True', 'False')):
                        inferred_type = 'bool'
                elif child.value == ',':
                    if inferred_type:
                        argtypes.append(stars + inferred_type)
                    # Reset
                    stars = inferred_type = ''
                    in_default = False
                    at_start = False
        if inferred_type:
            argtypes.append(stars + inferred_type)
        return argtypes, restype

    # The parse tree has a different shape when there is a single
    # decorator vs. when there are multiple decorators.
    DECORATED = "decorated< (d=decorator | decorators< dd=decorator+ >) funcdef >"
    decorated = compile_pattern(DECORATED)

    def get_decorators(self, node):
        """Return the decorator names found on a function definition.

        For each decorator matched on the parent of ``node``, the values of
        its direct NAME-leaf children are collected. Returns [] when the
        node has no parent or is not decorated.
        """
        if node.parent is None:
            return []
        results = {}
        if not self.decorated.match(node.parent, results):
            return []
        decorators = results.get('dd') or [results['d']]
        decs = []
        for d in decorators:
            for child in d.children:
                if isinstance(child, Leaf) and child.type == token.NAME:
                    decs.append(child.value)
        return decs

    def is_method(self, node):
        """Return whether the node occurs (directly) inside a class."""
        node = node.parent
        while node is not None:
            if node.type == syms.classdef:
                return True
            if node.type == syms.funcdef:
                return False
            node = node.parent
        return False

    RETURN_EXPR = "return_stmt< 'return' any >"
    return_expr = compile_pattern(RETURN_EXPR)

    def has_return_exprs(self, node):
        """Traverse the tree below node looking for 'return expr'.

        Return True if at least one 'return expr' is found, False if not.
        (If both 'return' and 'return expr' are found, return True.)
        Nested function and class definitions are not searched.
        """
        results = {}
        if self.return_expr.match(node, results):
            return True
        for child in node.children:
            if child.type not in (syms.funcdef, syms.classdef):
                if self.has_return_exprs(child):
                    return True
        return False

    YIELD_EXPR = "yield_expr< 'yield' [any] >"
    yield_expr = compile_pattern(YIELD_EXPR)

    def is_generator(self, node):
        """Traverse the tree below node looking for 'yield [expr]'.

        Nested function and class definitions are not searched.
        """
        results = {}
        if self.yield_expr.match(node, results):
            return True
        for child in node.children:
            if child.type not in (syms.funcdef, syms.classdef):
                if self.is_generator(child):
                    return True
        return False
class Pyi(collections.namedtuple('Pyi', 'imports assignments funcs')):
    """A parsed pyi.

    Fields:
      imports: tuple of ``((pkg, pkg_alias), names)`` entries, where
        ``names`` is a list of ``(name, alias)`` pairs (empty for a plain
        ``import pkg``).
      assignments: tuple of the top-level assignment nodes kept by parse().
      funcs: dict mapping a function's full name to its FuncSignature.
    """

    def _get_imports(self, inserted_types):
        """Yield the stored imports that provide names actually in use.

        Args:
          inserted_types: Iterable of nodes whose NAME leaves count as used,
            in addition to the names used by the stored assignments.

        Yields:
          ``((pkg, pkg_alias), names)`` entries, with ``names`` reduced to
          the used ones (a '*' import is always kept).
        """
        used_names = set()
        # Walk both groups separately so a list of inserted types can be
        # combined with the assignments tuple.
        for group in (inserted_types, self.assignments):
            for node in group:
                for leaf in node.leaves():
                    if leaf.type != token.NAME:
                        continue
                    value = leaf.value
                    used_names.add(value)
                    # All prefixes are possible imports. Peel one dotted
                    # component per iteration so the loop terminates.
                    while '.' in value:
                        value, _ = value.rsplit('.', 1)
                        used_names.add(value)

        for (pkg, pkg_alias), names in self.imports:
            if not names:
                if (pkg_alias or pkg) in used_names:
                    yield ((pkg, pkg_alias), names)
            else:
                names = [(name, alias) for name, alias in names
                         if name == '*' or (alias or name) in used_names]
                if names:
                    yield ((pkg, pkg_alias), names)

    def add_globals(self, tree, inserted_types):
        """Insert the imports and assignments the inserted types need.

        Imports that ``does_tree_import`` already finds in ``tree`` are not
        added again. New lines go after the last existing top-level import,
        or else after the first string statement (normally the docstring),
        or else at the very top.

        Args:
          tree: Root node of the module being modified.
          inserted_types: Nodes of the types that were inserted into tree.
        """
        top_lines = []

        def import_name(name, alias):
            return name + ('' if alias is None else ' as %s' % alias)

        # Copy imports if not already present.
        for (pkg, pkg_alias), names in self._get_imports(inserted_types):
            if not names:
                if does_tree_import(None, pkg_alias or pkg, tree):
                    continue
                top_lines.append('import %s\n' % import_name(pkg, pkg_alias))
            else:
                assert pkg_alias is None
                import_names = []
                for name, alias in names:
                    if does_tree_import(pkg, alias or name, tree):
                        continue
                    import_names.append(import_name(name, alias))
                if not import_names:
                    continue
                top_lines.append(
                    'from %s import %s\n' % (pkg, ', '.join(import_names)))

        import_idx = [
            idx for idx, idx_node in enumerate(tree.children)
            if self.import_pattern.match(idx_node)
        ]
        if import_idx:
            insert_pos = import_idx[-1] + 1
        else:
            insert_pos = 0

            # first string (normally docstring)
            for idx, idx_node in enumerate(tree.children):
                if (idx_node.type == syms.simple_stmt and idx_node.children
                        and idx_node.children[0].type == token.STRING):
                    insert_pos = idx + 1
                    break

        if self.assignments:
            top_lines.append('\n')
            top_lines.extend(str(a).strip() + '\n' for a in self.assignments)
        top_lines = Util.parse_string(''.join(top_lines))
        # The last child of the parsed tree is skipped; only the statements
        # before it are inserted.
        for offset, offset_node in enumerate(top_lines.children[:-1]):
            tree.insert_child(insert_pos + offset, offset_node)

    @classmethod
    def _log_warning(cls, *args):
        """Emit a warning on the logger named after this class."""
        logger = logging.getLogger(cls.__name__)
        logger.warning(*args)

    @classmethod
    def parse(cls, text):
        """Parse .pyi string, return as Pyi.

        Args:
          text: The contents of the .pyi file.

        Returns:
          A Pyi holding the top-level imports, the kept top-level
          assignments and the function signatures found in ``text``.
        """
        tree = Util.parse_string(text)

        funcs = {}
        for node, match_results in generate_matches(tree, cls.function_pattern):
            sig = FuncSignature(node, match_results)

            if sig.full_name in funcs:
                cls._log_warning('Ignoring redefinition: %s', sig)
            else:
                funcs[sig.full_name] = sig

        imports = []
        # Any is sometimes inserted as a default type, so make sure
        # typing.Any is always importable.
        any_import = False
        for node, match_results in generate_top_matches(
                tree, cls.import_pattern):
            pkg, names = cls.parse_top_import(match_results)
            if pkg == ('typing', None) and names:
                if ('Any', None) not in names:
                    names.insert(0, ('Any', None))
                any_import = True
            imports.append((pkg, names))
        if not any_import:
            imports.append((('typing', None), [('Any', None)]))

        assignments = []
        for node, match_results in generate_top_matches(
                tree, cls.assign_pattern):
            text = str(node)
            # str(node) includes the node's prefix (leading whitespace and
            # comments), so test the stripped text for the '_' marker.
            stripped = text.strip()

            # hack to avoid shadowing real variables -- proper solution is
            # more complicated, use util.find_binding
            if 'TypeVar' in text or (stripped and stripped[0] == '_'):
                assignments.append(node)
            else:
                cls._log_warning('ignoring %s', repr(text))

        return cls(tuple(imports), tuple(assignments), funcs)

    function_pattern = compile_pattern(FuncSignature.PATTERN)

    assign_pattern = compile_pattern(""" simple_stmt< expr_stmt<any+> any* > """)

    import_pattern = compile_pattern(""" simple_stmt< ( import_from< 'from' pkg=any+ 'import' ['('] names=any [')'] > | import_name< 'import' pkg=any+ > ) any* > """)

    @classmethod
    def _parse_import_alias(cls, leaves):
        """Turn the leaves of ``<dotted name> as <alias>`` into (name, alias)."""
        assert leaves[-2].value == 'as'
        name = ''.join(leaf.value for leaf in leaves[:-2])
        return (name, leaves[-1].value)

    @classmethod
    def parse_top_import(cls, results):
        """Splits the result of import_pattern into component strings.

        Examples:

        'from pkg import a,b,c' gives
          (('pkg', None), [('a', None), ('b', None), ('c', None)])
        'import pkg' gives
          (('pkg', None), [])
        'from pkg import a as b' gives
          (('pkg', None), [('a', 'b')])
        'import pkg as pkg2' gives
          (('pkg', 'pkg2'), [])
        'import pkg.a as b' gives
          (('pkg.a', 'b'), [])

        Args:
          results: The values from import_pattern.

        Returns:
          A tuple of the package name and the list of imported names. Each
          name is a tuple of original name and alias.
        """
        pkg, names = results['pkg'], results.get('names', None)

        if (len(pkg) == 1
                and pkg[0].type == pygram.python_symbols.dotted_as_name):
            pkg_out = cls._parse_import_alias(list(pkg[0].leaves()))
        else:
            pkg_out = (''.join(map(str, pkg)).strip(), None)

        names_out = []
        if names:
            names = split_comma(names.leaves())
            for name in names:
                if len(name) == 1:
                    assert name[0].type in (token.NAME, token.STAR)
                    names_out.append((name[0].value, None))
                else:
                    names_out.append(cls._parse_import_alias(name))

        return pkg_out, names_out
'print ...' into 'print(...)' 'print ... ,' into 'print(..., end=" ")' 'print >>x, ...' into 'print(..., file=x)' No changes are applied if print_function is imported from __future__ """ # Local imports from lib2to3 import patcomp, pytree, fixer_base from lib2to3.pgen2 import token from lib2to3.fixer_util import Name, Call, Comma, String from libmodernize import add_future parend_expr = patcomp.compile_pattern( """atom< '(' [atom|STRING|NAME] ')' >""" ) class FixPrint(fixer_base.BaseFix): BM_compatible = True PATTERN = """ simple_stmt< any* bare='print' any* > | print_stmt """ def transform(self, node, results): assert results bare_print = results.get("bare")
class FixDict(fixer_base.BaseFix):
    """Fixer rewriting dict ``keys/items/values`` (and iter*/view*) calls."""

    BM_compatible = True

    PATTERN = """ power< head=any+ trailer< '.' method=('keys'|'items'|'values'| 'iterkeys'|'iteritems'|'itervalues'| 'viewkeys'|'viewitems'|'viewvalues') > parens=trailer< '(' ')' > tail=any* > """

    def transform(self, node, results):
        """Return the replacement for a matched dict-method call.

        iter*/view* calls become a call to the same-named ``six`` function.
        A plain call in a special context is returned unchanged. Otherwise
        items()/values() become listitems()/listvalues() from
        ``future.utils`` and keys() becomes ``list(d)``.
        """
        syms = self.syms
        method_leaf = results["method"][0]  # Extract node for method name
        method_name = method_leaf.value
        isiter = method_name.startswith(u"iter")
        isview = method_name.startswith(u"view")

        head_copy = [part.clone() for part in results["head"]]
        tail_copy = [part.clone() for part in results["tail"]]

        # No changes are necessary if the call is in a special context.
        special = not tail_copy and self.in_special_context(node, isiter)

        receiver = pytree.Node(syms.power, head_copy)
        receiver.prefix = u""

        if isiter or isview:
            # Replace the method with the six function,
            # e.g. d.iteritems() -> from six import iteritems\n iteritems(d)
            replacement = Call(Name(method_name), [receiver])
            touch_import_top('six', method_name, node)
        elif special:
            # It is not necessary to change this case.
            return node
        elif method_name in ("items", "values"):
            # Make sure a list is returned under Python 3.
            replacement = Call(Name(u"list" + method_name), [receiver])
            touch_import_top('future.utils', 'list' + method_name, node)
        else:
            # method_name is "keys": drop it and cast the dict to a list.
            replacement = Call(Name(u"list"), [receiver])

        if tail_copy:
            replacement = pytree.Node(syms.power, [replacement] + tail_copy)
        replacement.prefix = node.prefix
        return replacement

    P1 = "power< func=NAME trailer< '(' node=any ')' > any* >"
    p1 = patcomp.compile_pattern(P1)

    def in_special_context(self, node, isiter):
        """Tell whether ``node`` is the sole argument of an exempt call.

        The grandparent of ``node`` must match ``P1`` with ``node`` as the
        captured argument. The calling function's name is then looked up in
        ``iter_exempt`` (when ``isiter``) or ``fixer_util.consuming_calls``.
        """
        parent = node.parent
        if parent is None:
            # It is not wrapped in anything.
            return False

        grandparent = parent.parent
        captured = {}
        wrapped = (grandparent is not None
                   and self.p1.match(grandparent, captured)
                   and captured["node"] is node)
        if not wrapped:
            return False

        func_name = captured["func"].value
        if isiter:
            # iter(d.iterkeys()) -> iter(d.keys()), etc.
            return func_name in iter_exempt
        # list(d.keys()) -> list(d.keys()), etc.
        return func_name in fixer_util.consuming_calls