Example #1
def register_checker(pattern, checker, extra):
    if "python_minimum_version" in extra and sys.version_info < extra["python_minimum_version"]:
        return
    if "python_disabled_version" in extra and sys.version_info > extra["python_disabled_version"]:
        return
    pattern = patcomp.compile_pattern(pattern)
    collected_checkers.append((pattern, checker, extra))
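
A minimal usage sketch for the snippet above (hedged: the checker call signature, the module-level collected_checkers list, and the imports are assumptions, since the surrounding module is not shown). It registers a checker for 'global' statements and runs every registered pattern over a parsed module:

import sys
from lib2to3 import patcomp, pygram, pytree
from lib2to3.pgen2 import driver

collected_checkers = []  # assumed module-level registry used by register_checker

def warn_on_global(node, results):
    # hypothetical checker: report 'global' statements
    print('global statement at line', node.get_lineno())

register_checker("global_stmt< 'global' any* >", warn_on_global, {})

d = driver.Driver(pygram.python_grammar, convert=pytree.convert)
tree = d.parse_string("def f():\n    global x\n    x = 1\n")
for pattern, checker, extra in collected_checkers:
    for n in tree.pre_order():
        results = {}
        if pattern.match(n, results):
            checker(n, results)  # prints: global statement at line 2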
Example #2
class Util(object):

    return_expr = compile_pattern("""return_stmt< 'return' any >""")

    @classmethod
    def has_return_exprs(cls, node):
        """Traverse the tree below node looking for 'return expr'.

        Return True if at least one 'return expr' is found, False if not.
        (If both 'return' and 'return expr' are found, return True.)
        """
        results = {}
        if cls.return_expr.match(node, results):
            return True
        for child in node.children:
            if child.type not in (syms.funcdef, syms.classdef):
                if cls.has_return_exprs(child):
                    return True
        return False

    driver = driver.Driver(pygram.python_grammar, convert=pytree.convert)

    @classmethod
    def parse_string(cls, text):
        """Use lib2to3 to parse text into a Node."""

        text = text.strip()
        if not text:
            # cls.driver.parse_string just returns the ENDMARKER Leaf, wrap in a Node
            # for consistency
            return Node(syms.file_input, [Leaf(token.ENDMARKER, '')])

        # workaround: parsing text without trailing '\n' throws exception
        text += '\n'
        return cls.driver.parse_string(text)
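
A short usage sketch for the Util helpers above (hedged; it assumes the lib2to3 imports the class relies on — compile_pattern, driver, pygram, pytree, syms, token, Node, Leaf — are in scope). Note that has_return_exprs deliberately skips nested funcdef/classdef children, so it is normally called on a funcdef node itself:

tree = Util.parse_string("def f():\n    return 1")
func = tree.children[0]                     # the funcdef node
print(Util.has_return_exprs(func))          # True: 'return 1' carries an expression

bare = Util.parse_string("def g():\n    return").children[0]
print(Util.has_return_exprs(bare))          # False: a bare 'return' has no expression node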
Example #3
def register_checker(pattern, checker, extra):
    if ('python_minimum_version' in extra
            and sys.version_info < extra['python_minimum_version']):
        return
    if ('python_disabled_version' in extra
            and sys.version_info > extra['python_disabled_version']):
        return
    pattern = patcomp.compile_pattern(pattern)
    collected_checkers.append((pattern, checker, extra))
Example #4
    def _RegisterPattern(self, method, pattern):
        """
        Registers a new pattern and the handling method.

        :param str method:
            The method name to handle the node matching the pattern.

        :param str pattern:
            The pattern to match AST nodes.
            This follows the lib2to3 pattern syntax.
        """
        from lib2to3.patcomp import compile_pattern
        self.patterns.append((method, compile_pattern(pattern)))
Example #5
    def _RegisterPattern(self, method, pattern):
        """
        Registers a new pattern and the handling method.

        :param str method:
            The method name to handle the node matching the pattern.

        :param str pattern:
            The pattern to match AST nodes.
            This follows the lib2to3 pattern syntax.
        """
        from lib2to3.patcomp import compile_pattern
        self.patterns.append((method, compile_pattern(pattern)))
Example #6
class Util:
    """Utility functions for working with Nodes."""

    return_expr = compile_pattern("""return_stmt< 'return' any >""")

    @classmethod
    def has_return_exprs(cls, node):
        """Traverse the tree below node looking for 'return expr'.

    Args:
      node: The AST node at the root of the subtree.

    Returns:
      True if 'return' or 'return expr' is found, False otherwise.
    """
        results = {}
        if cls.return_expr.match(node, results):
            return True
        for child in node.children:
            if child.type not in (syms.funcdef, syms.classdef):
                if cls.has_return_exprs(child):
                    return True
        return False

    driver = driver.Driver(pygram.python_grammar, convert=pytree.convert)

    @classmethod
    def parse_string(cls, text):
        """Use lib2to3 to parse text into a Node."""

        text = text.strip()
        if not text:
            # cls.driver.parse_string just returns the ENDMARKER Leaf, wrap in
            # a Node for consistency
            return Node(syms.file_input, [Leaf(token.ENDMARKER, '')])

        # workaround: parsing text without trailing '\n' throws exception
        text += '\n'
        return cls.driver.parse_string(text)
Example #7
 def register_pattern(self, method, pattern):
     """Register method to handle given pattern.
     """
     self.patterns.append((method, compile_pattern(pattern)))
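
A hedged sketch of how such (method name, compiled pattern) registrations might be consumed; the dispatcher class and handler are hypothetical, only register_pattern mirrors the snippet above:

from lib2to3 import patcomp, pygram, pytree
from lib2to3.pgen2 import driver

class PatternDispatcher:
    def __init__(self):
        self.patterns = []

    def register_pattern(self, method, pattern):
        """Register method to handle given pattern."""
        self.patterns.append((method, patcomp.compile_pattern(pattern)))

    def dispatch(self, tree):
        # Call the first registered handler whose pattern matches each node.
        for node in tree.pre_order():
            for method, pattern in self.patterns:
                results = {}
                if pattern.match(node, results):
                    getattr(self, method)(node, results)
                    break

    def handle_return(self, node, results):
        print('return statement at line', node.get_lineno())

dispatcher = PatternDispatcher()
dispatcher.register_pattern('handle_return', "return_stmt< 'return' any >")
d = driver.Driver(pygram.python_grammar, convert=pytree.convert)
dispatcher.dispatch(d.parse_string("def f():\n    return 1\n"))  # return statement at line 2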
Example #8
class FixAnnotate(BaseFix):

    # This fixer is compatible with the bottom matcher.
    BM_compatible = True

    # This fixer shouldn't run by default.
    explicit = True

    PATTERN = FuncSignature.PATTERN

    counter = None if not os.getenv('MAXFIXES') else int(os.getenv('MAXFIXES'))

    def __init__(self, options, log):
        super(FixAnnotate, self).__init__(options, log)

        # ParsedPyi obtained from .pyi file
        self.parsed_pyi = None

        # Did we add globals required by pyi to the top of the py file
        self.added_pyi_globals = False

        self.logger = logging.getLogger('FixAnnotate')

        # Options below

        # List of things to import from "__future__"
        self.future_imports = tuple()

        # insert type annotations in PEP484 style. Otherwise insert as comments
        self._annotate_pep484 = False

        # Strip comments and formatting from type annotations (False breaks comment output mode)
        self._strip_pyi_formatting = not self.annotate_pep484

    @property
    def annotate_pep484(self):
        return self._annotate_pep484

    @annotate_pep484.setter
    def annotate_pep484(self, value):
        self._annotate_pep484 = bool(value)
        self._strip_pyi_formatting = not self.annotate_pep484

    def transform(self, node, results):
        assert self.parsed_pyi, 'must provide pyi_string'

        if FixAnnotate.counter is not None:
            if FixAnnotate.counter <= 0:
                return

        cur_sig = FuncSignature(node, results)
        if not self.can_annotate(cur_sig):
            return

        if FixAnnotate.counter is not None:
            FixAnnotate.counter -= 1

        # Compute the annotation, or directly insert if not self.emit_as_comment
        annot = self.get_or_insert_annotation(cur_sig)

        if not self.annotate_pep484 and annot:
            if cur_sig.try_insert_comment_annotation(annot) and 'Any' in annot:
                touch_import('typing', 'Any', node)

        self.add_globals(node)

    def get_or_insert_annotation(self, cur_sig):
        """If self.annotate_pep484, insert, otherwise return as comment string"""
        arg_types = []
        for i, arg_sig in enumerate(cur_sig.arg_sigs):
            pyi_sig = self.parsed_pyi.funcs[cur_sig.full_name]
            new_type = pyi_sig.arg_sigs[i].arg_type
            new_type = clean_clone(new_type, self._strip_pyi_formatting)

            if self.annotate_pep484:
                if new_type:
                    arg_sig.insert_annotation(new_type)
            else:
                is_first = (i == 0)

                if new_type:
                    arg_types.append(arg_sig.stars + str(new_type).strip())
                elif self.infer_should_annotate(cur_sig, arg_sig, is_first):
                    arg_types.append(arg_sig.stars + 'Any')

        pyi_sig = self.parsed_pyi.funcs[cur_sig.full_name]
        ret_type = pyi_sig.ret_type

        if not self.annotate_pep484:
            if not ret_type:
                ret_type = self.infer_ret_type(cur_sig)

            return '(' + ', '.join(arg_types) + ') -> ' + str(ret_type).strip()
        elif ret_type:
            cur_sig.insert_ret_annotation(ret_type)

    def can_annotate(self, cur_sig):
        if cur_sig.has_pep484_annotations or cur_sig.has_comment_annotations:
            self.logger.warning('already annotated, skipping %s', cur_sig)
            return False

        if cur_sig.full_name not in self.parsed_pyi.funcs:
            self.logger.warning('no signature for %s, skipping', cur_sig)
            return False

        pyi_sig = self.parsed_pyi.funcs[cur_sig.full_name]

        if not pyi_sig.has_pep484_annotations:
            self.logger.warning(
                'ignoring pyi definition with no annotations: %s', pyi_sig)
            return False

        if not self.func_sig_compatible(cur_sig, pyi_sig):
            self.logger.warning('incompatible annotation, skipping %s',
                                cur_sig)
            return False

        return True

    def add_globals(self, node):
        """Add required globals to the root of node. Idempotent."""
        if self.added_pyi_globals:
            return
        # TODO: get rid of this -- added to prevent adding .parsed_pyi.top_lines every time
        # we annotate a different function in the same file, but can break when we run the tool
        # twice on the same file. Have to do something like what touch_import does.
        self.added_pyi_globals = True

        imports, top_lines = self.parsed_pyi.imports, self.parsed_pyi.top_lines

        # Copy imports if not already present
        for pkg, names in imports:
            if names is None:
                # TODO: do ourselves, touch_import puts stuff above license headers
                touch_import(None, pkg, node)  # == 'import pkg'
            else:
                for name in names:
                    touch_import(pkg, name, node)

        root = find_root(node)

        import_idx = [
            idx for idx, node in enumerate(root.children)
            if self.import_pattern.match(node)
        ]
        if import_idx:
            future_insert_pos = import_idx[0]
            top_insert_pos = import_idx[-1] + 1
        else:
            future_insert_pos = top_insert_pos = 0

            # first string (normally docstring)
            for idx, node in enumerate(root.children):
                if (node.type == syms.simple_stmt and node.children
                        and node.children[0].type == token.STRING):
                    future_insert_pos = top_insert_pos = idx + 1
                    break

        top_lines = '\n'.join(top_lines)
        top_lines = Util.parse_string(top_lines)  # strips some newlines
        for offset, node in enumerate(top_lines.children[:-1]):
            root.insert_child(top_insert_pos + offset, node)

        # touch_import doesn't do proper order for __future__
        pkg = '__future__'
        future_imports = [
            n for n in self.future_imports
            if not does_tree_import(pkg, n, root)
        ]
        for offset, name in enumerate(future_imports):
            node = FromImport(pkg, [Leaf(token.NAME, name, prefix=" ")])
            node = Node(syms.simple_stmt, [node, Newline()])
            root.insert_child(future_insert_pos + offset, node)

    @staticmethod
    def func_sig_compatible(cur_sig, pyi_sig):
        """Can cur_sig be annotated with the info in pyi_sig: number of arguments must match,
        they must have the same star signature and they can't be tuple arguments.
        """

        if len(pyi_sig.arg_sigs) != len(cur_sig.arg_sigs):
            return False

        for pyi, cur in zip(pyi_sig.arg_sigs, cur_sig.arg_sigs):
            # Entirely skip functions that use tuple args
            if cur.is_tuple or pyi.is_tuple:
                return False

            # Stars are expected to match
            if cur.stars != pyi.stars:
                return False

        return True

    @staticmethod
    def infer_ret_type(cur_sig):
        """Heuristic for return value of a function."""
        if cur_sig.short_name == '__init__' or not cur_sig.has_return_exprs:
            return 'None'
        return 'Any'

    @staticmethod
    def infer_should_annotate(func, arg, at_start):
        """Heuristic for whether arg (in func) should be annotated."""

        if func.is_method and at_start and 'staticmethod' not in func.decorators:
            # Don't annotate the first argument if it's named 'self'.
            # Don't annotate the first argument of a class method.
            if 'self' == arg.name or 'classmethod' in func.decorators:
                return False

        return True

    def set_pyi_string(self, pyi_string):
        """Set the annotations the fixer will use"""
        self.parsed_pyi = self.parse_pyi_string(pyi_string)
        self.added_pyi_globals = False

    def parse_pyi_string(self, text):
        """Parse .pyi string, return as ParsedPyi"""
        tree = Util.parse_string(text)

        funcs = {}
        for node, match_results in generate_matches(tree, self.pattern):
            sig = FuncSignature(node, match_results)

            if sig.full_name in funcs:
                self.logger.warning('Ignoring redefinition: %s', sig)
            else:
                funcs[sig.full_name] = sig

        imports = []
        for node, match_results in generate_top_matches(
                tree, self.import_pattern):
            imp = self.parse_top_import(node, match_results)
            if imp:
                imports.append(imp)

        top_lines = []
        for node, match_results in generate_top_matches(
                tree, self.assign_pattern):
            text = str(node).strip()

            # hack to avoid shadowing real variables -- proper solution is more complicated,
            # use util.find_binding
            if 'TypeVar' in text or (text and '_' == text[0]):
                top_lines.append(text)
            else:
                self.logger.warning("ignoring %s", repr(text))

        return ParsedPyi(tuple(imports), top_lines, funcs)

    assign_pattern = compile_pattern("""
    simple_stmt< expr_stmt<any+> any* >
    """)

    import_pattern = compile_pattern("""
    simple_stmt<
        ( import_from< 'from' pkg=any+ 'import' ['('] names=any [')'] > |
          import_name< 'import' pkg=any+ > )
        any*
    >
    """)
    import_as_pattern = compile_pattern("""import_as_name<NAME 'as' NAME>""")

    def parse_top_import(self, node, results):
        """Takes result of import_pattern, returns component strings:

        Examples:

        'from pkg import a,b,c' gives
        ('pkg', ('a', 'b', 'c'))

        'import pkg' gives
        ('pkg', None)

        'from pkg import a as b' or 'import pkg as pkg2' are not supported.
        """

        # TODO: this might have to be generalized to "get top-level statements that aren't
        # class or function definitions":
        # _T = typing.TypeVar('_T') is used in pyis.
        # Still not clear what is and isn't valid in a pyi... Could we have a loop?

        pkg, names = results['pkg'], results.get('names', None)
        pkg = ''.join(map(str, pkg)).strip()

        if names:
            is_import_as = any(
                True for _ in generate_matches(names, self.import_as_pattern))

            if is_import_as:
                # fixer_util.touch_import doesn't handle this
                # If necessary, will have to stick import at top of .py file
                self.logger.warning('Ignoring unhandled import-as: %s',
                                    repr(str(node).strip()))
                return None

            names = split_comma(names.leaves())
            for name in names:
                assert 1 == len(name)
                assert name[0].type in (token.NAME, token.STAR)
            names = [name[0].value for name in names]

        return pkg, names
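
For orientation, a hedged illustration (the function is hypothetical) of what the comment-mode path above produces: given a .pyi stub containing 'def parse(data: bytes) -> str: ...', try_insert_comment_annotation leaves the .py function looking roughly like this; 'from typing import Any' is only added when the comment mentions Any:

def parse(data):
    # type: (bytes) -> str
    return data.decode()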
Example #9
class FixAnnotate(BaseFix):

    # This fixer is compatible with the bottom matcher.
    BM_compatible = True

    # This fixer shouldn't run by default.
    explicit = True

    # The pattern to match.
    PATTERN = """
              funcdef< 'def' name=any parameters< '(' [args=any] ')' > ':' suite=any+ >
              """

    counter = None if not os.getenv('MAXFIXES') else int(os.getenv('MAXFIXES'))

    def transform(self, node, results):
        if FixAnnotate.counter is not None:
            if FixAnnotate.counter <= 0:
                return
        suite = results['suite']
        children = suite[0].children

        # NOTE: I've reverse-engineered the structure of the parse tree.
        # It's always a list of nodes, the first of which contains the
        # entire suite.  Its children seem to be:
        #
        #   [0] NEWLINE
        #   [1] INDENT
        #   [2...n-2] statements (the first may be a docstring)
        #   [n-1] DEDENT
        #
        # Comments before the suite are part of the INDENT's prefix.
        #
        # "Compact" functions (e.g. "def foo(x, y): return max(x, y)")
        # have a different structure that isn't matched by PATTERN.

        ## print('-'*60)
        ## print(node)
        ## for i, ch in enumerate(children):
        ##     print(i, repr(ch.prefix), repr(ch))

        # Check if there's already an annotation.
        for ch in children:
            if ch.prefix.lstrip().startswith('# type:'):
                return  # There's already a # type: comment here; don't change anything.

        # Compute the annotation
        annot = self.make_annotation(node, results)

        # Insert '# type: {annot}' comment.
        # For reference, see lib2to3/fixes/fix_tuple_params.py in stdlib.
        if len(children) >= 2 and children[1].type == token.INDENT:
            children[1].prefix = '{}# type: {}\n{}'.format(children[1].value, annot, children[1].prefix)
            children[1].changed()
            if FixAnnotate.counter is not None:
                FixAnnotate.counter -= 1

        # Also add 'from typing import Any' at the top.
        if 'Any' in annot:
            touch_import('typing', 'Any', node)

    def make_annotation(self, node, results):
        name = results['name']
        assert isinstance(name, Leaf), repr(name)
        assert name.type == token.NAME, repr(name)
        decorators = self.get_decorators(node)
        is_method = self.is_method(node)
        if name.value == '__init__' or not self.has_return_exprs(node):
            restype = 'None'
        else:
            restype = 'Any'
        args = results.get('args')
        argtypes = []
        if isinstance(args, Node):
            children = args.children
        elif isinstance(args, Leaf):
            children = [args]
        else:
            children = []
        # Interpret children according to the following grammar:
        # (('*'|'**')? NAME ['=' expr] ','?)*
        stars = inferred_type = ''
        in_default = False
        at_start = True
        for child in children:
            if isinstance(child, Leaf):
                if child.value in ('*', '**'):
                    stars += child.value
                elif child.type == token.NAME and not in_default:
                    if not is_method or not at_start or 'staticmethod' in decorators:
                        inferred_type = 'Any'
                    else:
                        # Always skip the first argument if it's named 'self'.
                        # Always skip the first argument of a class method.
                        if child.value == 'self' or 'classmethod' in decorators:
                            pass
                        else:
                            inferred_type = 'Any'
                elif child.value == '=':
                    in_default = True
                elif in_default and child.value != ',':
                    if child.type == token.NUMBER:
                        if re.match(r'\d+[lL]?$', child.value):
                            inferred_type = 'int'
                        else:
                            inferred_type = 'float'  # TODO: complex?
                    elif child.type == token.STRING:
                        if child.value.startswith(('u', 'U')):
                            inferred_type = 'unicode'
                        else:
                            inferred_type = 'str'
                    elif child.type == token.NAME and child.value in ('True', 'False'):
                        inferred_type = 'bool'
                elif child.value == ',':
                    if inferred_type:
                        argtypes.append(stars + inferred_type)
                    # Reset
                    stars = inferred_type = ''
                    in_default = False
                    at_start = False
        if inferred_type:
            argtypes.append(stars + inferred_type)
        return '(' + ', '.join(argtypes) + ') -> ' + restype

    # The parse tree has a different shape when there is a single
    # decorator vs. when there are multiple decorators.
    DECORATED = "decorated< (d=decorator | decorators< dd=decorator+ >) funcdef >"
    decorated = compile_pattern(DECORATED)

    def get_decorators(self, node):
        """Return a list of decorators found on a function definition.

        This is a list of strings; only simple decorators
        (e.g. @staticmethod) are returned.

        If the function is undecorated or only non-simple decorators
        are found, return [].
        """
        if node.parent is None:
            return []
        results = {}
        if not self.decorated.match(node.parent, results):
            return []
        decorators = results.get('dd') or [results['d']]
        decs = []
        for d in decorators:
            for child in d.children:
                if isinstance(child, Leaf) and child.type == token.NAME:
                    decs.append(child.value)
        return decs

    def is_method(self, node):
        """Return whether the node occurs (directly) inside a class."""
        node = node.parent
        while node is not None:
            if node.type == syms.classdef:
                return True
            if node.type == syms.funcdef:
                return False
            node = node.parent
        return False

    RETURN_EXPR = "return_stmt< 'return' any >"
    return_expr = compile_pattern(RETURN_EXPR)

    def has_return_exprs(self, node):
        """Traverse the tree below node looking for 'return expr'.

        Return True if at least one 'return expr' is found, False if not.
        (If both 'return' and 'return expr' are found, return True.)
        """
        results = {}
        if self.return_expr.match(node, results):
            return True
        for child in node.children:
            if child.type not in (syms.funcdef, syms.classdef):
                if self.has_return_exprs(child):
                    return True
        return False
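
A hedged illustration of this fixer's output: make_annotation falls back to Any for arguments without defaults, infers int/float/str/bool from default values, and uses None or Any for the return type, so a function like 'def add(x, y=0): return x + y' (written in block form, since compact one-liners are not matched) comes out roughly as:

from typing import Any  # added by touch_import because the comment mentions Any

def add(x, y=0):
    # type: (Any, int) -> Any
    return x + y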
Example #10
class FuncSignature:
    """A function or method."""

    _full_name: str

    # The pattern to match.
    PATTERN = """
              funcdef<
                'def' name=NAME
                parameters< '(' [args=any+] ')' >
                ['->' ret_annotation=any]
                colon=':' suite=any+ >
              """

    def __init__(self, node, match_results):
        """node must match PATTERN."""

        name = match_results.get('name')
        assert isinstance(name, Leaf), repr(name)
        assert name.type == token.NAME, repr(name)

        self._ret_type = match_results.get('ret_annotation')
        self._full_name = self._make_function_key(name)

        args = self._split_args(match_results.get('args'))
        self._arg_sigs = tuple(map(ArgSignature, args))

        self._node = node
        self._match_results = match_results
        self._inserted_ret_annotation = False

    def __str__(self):
        return self.full_name

    @property
    def full_name(self):
        """Fully-qualified name string."""
        return self._full_name

    @property
    def short_name(self):
        return self._match_results.get('name').value

    @property
    def ret_type(self):
        """Return type, Node? or None."""
        return self._ret_type

    @property
    def arg_sigs(self):
        """List[ArgSignature]."""
        return self._arg_sigs

    # The parse tree has a different shape when there is a single
    # decorator vs. when there are multiple decorators.
    decorated_pattern = compile_pattern("""
    decorated< (d=decorator | decorators< dd=decorator+ >) funcdef >
    """)

    @property
    def decorators(self):
        """A list of the function's decorators.

    This is a list of strings; only simple decorators (e.g. @staticmethod) are
    returned. If the function is undecorated or only non-simple decorators
    are found, return [].

    Returns:
      The names of the function's decorators as a list of strings. Only simple
      decorators (e.g. @staticmethod) are returned. If the function is not
      decorated or only non-simple decorators are found, return [].
    """
        # TODO(tsudol): memoize
        node = self._node
        if node.parent is None:
            return []
        results = {}
        if not self.decorated_pattern.match(node.parent, results):
            return []
        decorators = results.get('dd') or [results['d']]
        decs = []
        for d in decorators:
            for child in d.children:
                if child.type == token.NAME:
                    decs.append(child.value)
        return decs

    @property
    def is_method(self):
        """Whether we are (directly) inside a class."""
        # TODO(tsudol): memoize
        node = self._node.parent
        while node is not None:
            if node.type == syms.classdef:
                return True
            if node.type == syms.funcdef:
                return False
            node = node.parent
        return False

    @property
    def has_return_exprs(self):
        """True if function has "return expr" anywhere."""
        return Util.has_return_exprs(self._node)

    @property
    def has_pep484_annotations(self):
        """Do we have any pep484 annotations?"""
        return self.ret_type or any(arg.arg_type for arg in self.arg_sigs)

    @property
    def has_comment_annotations(self):
        """Do we have any comment annotations?"""
        children = self._match_results['suite'][0].children
        for ch in children:
            if ch.prefix.lstrip().startswith('# type:'):
                return True

        return False

    def insert_ret_annotation(self, ret_type):
        """In-place annotation. Can only be called once."""
        assert not self._inserted_ret_annotation
        self._inserted_ret_annotation = True

        colon = self._match_results.get('colon')
        # TODO(tsudol): insert as a Node, not as a prefix
        colon.prefix = ' -> ' + str(ret_type).strip() + colon.prefix

    def try_insert_comment_annotation(self, annotation):
        """Try to insert '# type: {annotation}' comment."""
        # For reference, see lib2to3/fixes/fix_tuple_params.py in stdlib.
        # "Compact" functions (e.g. 'def foo(x, y): return max(x, y)')
        # are not annotated.

        children = self._match_results['suite'][0].children
        if not (len(children) >= 2 and children[1].type == token.INDENT):
            return False  # can't annotate

        node = children[1]
        node.prefix = '%s# type: %s\n%s' % (node.value, annotation,
                                            node.prefix)
        node.changed()
        return True

    scope_pattern = compile_pattern("""(
    funcdef < 'def'   name=TOKEN any*> |
    classdef< 'class' name=TOKEN any*>
    )""")

    @classmethod
    def _make_function_key(cls, node):
        """Return the fully-qualified name of the function the node is under.

    If source is

    class C:
      def foo(self):
        x = 1

    We'll return 'C.foo' for any nodes related to 'x', '1', 'foo', 'self',
    and either 'C' or '' otherwise.

    Args:
      node: The node to start searching from.

    Returns:
      The function key as a string.
    """
        result = []
        while node is not None:
            match_result = {}
            if cls.scope_pattern.match(node, match_result):
                result.append(match_result['name'].value)

            node = node.parent

        return '.'.join(reversed(result))

    @staticmethod
    def _split_args(args):
        """Turns the match of PATTERN.args into a list of non-empty lists of nodes.

    Args:
      args: The value matched by PATTERN.args.

    Returns:
      A list of non-empty lists of nodes, where each list corresponds to a
      function argument.
    """
        if args is None:
            return []

        assert isinstance(args, list) and len(args) == 1, repr(args)

        args = args[0]
        if isinstance(args, Leaf) or args.type == syms.tname:
            args = [args]
        else:
            args = args.children

        return split_comma(args)
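
A hedged sketch of the scope-based key computed by scope_pattern above, reusing the Util helper from the earlier examples; calling the private classmethod directly is only for illustration:

tree = Util.parse_string("class C:\n    def foo(self):\n        x = 1")
for leaf in tree.leaves():
    if leaf.value == 'x':
        print(FuncSignature._make_function_key(leaf))  # 'C.foo'
        break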
Example #11
class BaseFixAnnotate(BaseFix):

    # This fixer is compatible with the bottom matcher.
    BM_compatible = True

    # This fixer shouldn't run by default.
    explicit = True

    # The pattern to match.
    PATTERN = """
              funcdef< 'def' name=any parameters=parameters< '(' [args=any] rpar=')' > ':' suite=any+ >
              """

    _maxfixes = os.getenv('MAXFIXES')
    counter = None if not _maxfixes else int(_maxfixes)
    _type_options = None  # type: Optional[Dict[str, Any]]

    @property
    def type_options(self):
        if self._type_options is None:
            self._type_options = self.options.get('typewriter', {})
        return self._type_options

    def should_skip(self, node, results):
        if BaseFixAnnotate.counter is not None:
            if BaseFixAnnotate.counter <= 0:
                return True

        # Check if there's already a long-form annotation for some argument.
        parameters = results.get('parameters')
        if parameters is not None:
            for ch in parameters.pre_order():
                if ch.prefix.lstrip().startswith('# type:'):
                    return True

        args = results.get('args')
        if args is not None:
            for ch in args.pre_order():
                if ch.prefix.lstrip().startswith('# type:'):
                    return True

        children = results['suite'][0].children

        # NOTE: I've reverse-engineered the structure of the parse tree.
        # It's always a list of nodes, the first of which contains the
        # entire suite.  Its children seem to be:
        #
        #   [0] NEWLINE
        #   [1] INDENT
        #   [2...n-2] statements (the first may be a docstring)
        #   [n-1] DEDENT
        #
        # Comments before the suite are part of the INDENT's prefix.
        #
        # "Compact" functions (e.g. "def foo(x, y): return max(x, y)")
        # have a different structure (no NEWLINE, INDENT, or DEDENT).

        # Check if there's already an annotation.
        for ch in children:
            if ch.prefix.lstrip().startswith('# type:'):
                return True  # There's already a # type: comment here; don't change anything.

        # Python 3 style return annotations are already skipped by the pattern

        # Python 3 style argument annotation structure
        #
        # Structure of the argument tokens for one positional argument without a default value:
        # + LPAR '('
        # + NAME_NODE_OR_LEAF arg1
        # + RPAR ')'
        #
        # NAME_NODE_OR_LEAF is either:
        # 1. Just a leaf with value NAME
        # 2. A node with children: NAME, ':', node expr or value leaf
        #
        # Structure of the argument tokens for one arg with a default value, or for multiple
        # args (with or without default values) and/or extra arguments:
        # + LPAR '('
        # + node
        #   [
        #     + NAME_NODE_OR_LEAF
        #      [
        #        + EQUAL '='
        #        + node expr or value leaf
        #      ]
        #    (
        #        + COMMA ','
        #        + NAME_NODE_OR_LEAF positional argn
        #      [
        #        + EQUAL '='
        #        + node expr or value leaf
        #      ]
        #    )*
        #   ]
        #   [
        #     + STAR '*'
        #     [
        #     + NAME_NODE_OR_LEAF positional star argument name
        #     ]
        #   ]
        #   [
        #     + COMMA ','
        #     + DOUBLESTAR '**'
        #     + NAME_NODE_OR_LEAF positional keyword argument name
        #   ]
        # + RPAR ')'

        # Let's skip Python 3 argument annotations
        it = iter(args.children) if args else iter([])
        for ch in it:
            if ch.type == token.STAR:
                # *arg part
                ch = next(it)
                if ch.type == token.COMMA:
                    continue
            elif ch.type == token.DOUBLESTAR:
                # **kwargs part
                ch = next(it)
            if ch.type > 256:
                # this is a node, therefore an annotation
                assert ch.children[0].type == token.NAME
                return True
            try:
                ch = next(it)
                if ch.type == token.COLON:
                    # this is an annotation
                    return True
                elif ch.type == token.EQUAL:
                    ch = next(it)
                    ch = next(it)
                assert ch.type == token.COMMA
                continue
            except StopIteration:
                break

        return False

    def transform(self, node, results):
        if self.should_skip(node, results):
            return

        # Compute the annotation
        annot = self.make_annotation(node, results)
        if annot is None:
            return
        argtypes, restype = annot

        if self.type_options['annotation_style'] == 'py3':
            self.add_py3_annot(argtypes, restype, node, results)
        else:
            self.add_py2_annot(argtypes, restype, node, results)

        # Common to py2 and py3 style annotations:
        if BaseFixAnnotate.counter is not None:
            BaseFixAnnotate.counter -= 1

        # Also add 'from typing import Any' at the top if needed.
        self.patch_imports(argtypes + [restype], node)

    def add_py3_annot(self, argtypes, restype, node, results):
        # type: (List[str], str, Node, Dict[str, Any]) -> None

        args = results.get('args')  # type: Optional[Node]

        argleaves = []  # type: List[Tuple[str, Leaf]]
        if args is None:
            # function with 0 arguments
            it = iter([])  # type: Iterator[Union[Leaf, Node]]
        elif len(args.children) == 0:
            # function with 1 argument
            it = iter([args])
        else:
            # function with multiple arguments or 1 arg with default value
            it = iter(args.children)

        for ch in it:
            argstyle = 'name'
            if ch.type == token.STAR:
                # *arg part
                argstyle = 'star'
                ch = next(it)
                if ch.type == token.COMMA:
                    continue
            elif ch.type == token.DOUBLESTAR:
                # **kwargs part
                argstyle = 'keyword'
                ch = next(it)
            assert ch.type == token.NAME
            assert isinstance(ch, Leaf)
            argleaves.append((argstyle, ch))
            try:
                ch = next(it)
                if ch.type == token.EQUAL:
                    ch = next(it)
                    ch = next(it)
                assert ch.type == token.COMMA
                continue
            except StopIteration:
                break

        # when self or cls is not annotated, len(argleaves) == len(argtypes) + 1
        argleaves = argleaves[len(argleaves) - len(argtypes):]

        for ch_withstyle, chtype in zip(argleaves, argtypes):
            style, ch = ch_withstyle
            if style == 'star':
                assert chtype[0] == '*'
                assert chtype[1] != '*'
                chtype = chtype[1:]
            elif style == 'keyword':
                assert chtype[0:2] == '**'
                assert chtype[2] != '*'
                chtype = chtype[2:]
            ch.value = '%s: %s' % (ch.value, chtype)

            # put spaces around the equal sign
            if ch.next_sibling and ch.next_sibling.type == token.EQUAL:
                nextch = ch.next_sibling
                if not nextch.prefix[:1].isspace():
                    nextch.prefix = ' ' + nextch.prefix
                nextch_ = nextch.next_sibling
                assert nextch_ is not None
                if not nextch_.prefix[:1].isspace():
                    nextch_.prefix = ' ' + nextch_.prefix

        # Add return annotation
        rpar = results['rpar']
        rpar.value = '%s -> %s' % (rpar.value, restype)

        rpar.changed()

    def use_py2_long_form(self, argtypes, short_str, degen_str):
        # type: (List[str], str, str) -> bool
        if self.type_options['comment_style'] == 'single':
            return False
        elif self.type_options['comment_style'] == 'multi':
            return False
        else:  # auto
            return ((len(short_str) > 64 or len(argtypes) > 5)
                    and len(short_str) > len(degen_str))

    def add_py2_annot(self, argtypes, restype, node, results):
        # type: (List[str], str, Node, Dict[str, Any]) -> None

        children = results['suite'][0].children

        # Insert '# type: {annot}' comment.
        # For reference, see lib2to3/fixes/fix_tuple_params.py in stdlib.
        if len(children) >= 1 and children[0].type != token.NEWLINE:
            # one liner function
            if children[0].prefix.strip() == '':
                children[0].prefix = ''
                children.insert(0, Leaf(token.NEWLINE, '\n'))
                children.insert(
                    1, Leaf(token.INDENT, find_indentation(node) + '    '))
                children.append(Leaf(token.DEDENT, ''))

        if len(children) >= 2 and children[1].type == token.INDENT:
            degen_str = '(...) -> %s' % restype
            short_str = '(%s) -> %s' % (', '.join(argtypes), restype)
            if self.use_py2_long_form(argtypes, short_str, degen_str):
                self.insert_long_form(node, results, argtypes)
                annot_str = degen_str
            else:
                annot_str = short_str

            indent_node = children[1]
            comment, sep, other_comments = indent_node.prefix.partition('\n')
            comment = comment.rstrip() + sep
            annot_str = '# type: %s\n' % (annot_str,)
            if comment == annot_str:
                return

            if comment and not is_type_comment(comment):
                # push existing non-type comment to next line
                annot_str += comment

            indent_node.prefix = indent_node.value + annot_str + other_comments
            indent_node.changed()
        else:
            self.log_message("%s:%d: cannot insert annotation for one-line function" %
                             (self.filename, node.get_lineno()))

    def insert_long_form(self, node, results, argtypes):
        # type: (Node, Dict[str, Any], List[str]) -> None

        argtypes = list(argtypes)  # We destroy it
        args = results['args']
        if isinstance(args, Node):
            children = args.children
        elif isinstance(args, Leaf):
            children = [args]
        else:
            children = []
        # Interpret children according to the following grammar:
        # (('*'|'**')? NAME ['=' expr] ','?)*
        flag = False  # Set when the next leaf should get a type prefix
        indent = ''  # Will be set by the first child

        def set_prefix(child):
            if argtypes:
                arg = argtypes.pop(0).lstrip('*')
            else:
                arg = 'Any'  # Somehow there aren't enough args
            if not arg:
                # Skip self (look for 'check_self' below)
                prefix = child.prefix.rstrip()
            else:
                prefix = '  # type: ' + arg
                old_prefix = child.prefix.strip()
                if old_prefix:
                    assert old_prefix.startswith('#')
                    prefix += '  ' + old_prefix
            child.prefix = prefix + '\n' + indent

        check_self = self.is_method(node)
        for child in children:
            if isinstance(child, Leaf):
                if check_self and child.type == token.NAME:
                    check_self = False
                    if child.value in ('self', 'cls'):
                        argtypes.insert(0, '')
                if not indent:
                    indent = ' ' * child.column
                if child.value == ',':
                    flag = True
                elif flag:
                    set_prefix(child)
                    flag = False

        need_comma = len(children) >= 1 and children[-1].type != token.COMMA
        if need_comma and len(children) >= 2:
            if (children[-1].type == token.NAME and
                    (children[-2].type in (token.STAR, token.DOUBLESTAR))):
                need_comma = False
        if need_comma:
            children.append(Leaf(token.COMMA, u","))
        # Find the ')' and insert a prefix before it too.
        parameters = args.parent
        close_paren = parameters.children[-1]
        assert close_paren.type == token.RPAR, close_paren
        set_prefix(close_paren)
        assert not argtypes, argtypes

    def patch_imports(self, types, node):
        for typ in types:
            if 'Any' in typ:
                touch_import('typing', 'Any', node)
                break

    def make_annotation(self, node, results):
        # type: (Node, Dict[str, Any]) -> Optional[Tuple[List[str], str]]
        """Return the type annotations.

        Given the current Node and the dictionary parsed from PATTERN,
        return the annotations for the arguments and the return type as strings.
        """
        raise NotImplementedError

    # The parse tree has a different shape when there is a single
    # decorator vs. when there are multiple decorators.
    DECORATED = "decorated< (d=decorator | decorators< dd=decorator+ >) funcdef >"
    decorated = compile_pattern(DECORATED)

    def get_decorators(self, node):
        """Return a list of decorators found on a function definition.

        This is a list of strings; only simple decorators
        (e.g. @staticmethod) are returned.

        If the function is undecorated or only non-simple decorators
        are found, return [].
        """
        if node.parent is None:
            return []
        results = {}
        if not self.decorated.match(node.parent, results):
            return []
        decorators = results.get('dd') or [results['d']]
        decs = []
        for d in decorators:
            for child in d.children:
                if isinstance(child, Leaf) and child.type == token.NAME:
                    decs.append(child.value)
        return decs

    def is_method(self, node):
        """Return whether the node occurs (directly) inside a class."""
        node = node.parent
        while node is not None:
            if node.type == syms.classdef:
                return True
            if node.type == syms.funcdef:
                return False
            node = node.parent
        return False

    RETURN_EXPR = "return_stmt< 'return' any >"
    return_expr = compile_pattern(RETURN_EXPR)

    def has_return_exprs(self, node):
        """Traverse the tree below node looking for 'return expr'.

        Return True if at least one 'return expr' is found, False if not.
        (If both 'return' and 'return expr' are found, return True.)
        """
        results = {}
        if self.return_expr.match(node, results):
            return True
        for child in node.children:
            if child.type not in (syms.funcdef, syms.classdef):
                if self.has_return_exprs(child):
                    return True
        return False

    YIELD_EXPR = "yield_expr< 'yield' [any] >"
    yield_expr = compile_pattern(YIELD_EXPR)

    def is_generator(self, node):
        """Traverse the tree below node looking for 'yield [expr]'."""
        results = {}
        if self.yield_expr.match(node, results):
            return True
        for child in node.children:
            if child.type not in (syms.funcdef, syms.classdef):
                if self.is_generator(child):
                    return True
        return False
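
A hedged illustration of the py3 branch above: assuming a subclass's make_annotation returned (['Any', 'int'], 'Any') for 'def scale(x, factor=2): return x * factor' (in block form), add_py3_annot rewrites the parameter NAME leaves and the closing parenthesis in place, giving roughly:

from typing import Any  # patch_imports adds this because 'Any' appears in the annotation

def scale(x: Any, factor: int = 2) -> Any:
    return x * factor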
Example #12
 def register_pattern(self, method, pattern):
     """Register method to handle given pattern.
     """
     self.patterns.append((method, compile_pattern(pattern)))
Example #13
class FixMergePyi(BaseFix):
  """Specialized lib2to3 fixer for applying pyi annotations."""

  # This fixer is compatible with the bottom matcher.
  BM_compatible = True

  # This fixer shouldn't run by default.
  explicit = True

  PATTERN = FuncSignature.PATTERN

  def __init__(self, options, log):
    super(FixMergePyi, self).__init__(options, log)

    # ParsedPyi obtained from .pyi file
    self.parsed_pyi = None

    # Did we add globals required by pyi to the top of the py file
    self.added_pyi_globals = False

    self.logger = logging.getLogger(self.__class__.__name__)

    # Options below

    # insert type annotations in PEP484 style. Otherwise insert as comments
    self._annotate_pep484 = False

  @property
  def annotate_pep484(self):
    return self._annotate_pep484

  @annotate_pep484.setter
  def annotate_pep484(self, value):
    self._annotate_pep484 = bool(value)

  def transform(self, node, results):
    assert self.parsed_pyi, 'must provide pyi_string'

    src_sig = FuncSignature(node, results)
    if not self.can_annotate(src_sig):
      return
    pyi_sig = self.parsed_pyi.funcs[src_sig.full_name]

    if self.annotate_pep484:
      self.insert_annotation(src_sig, pyi_sig)
    else:
      annot = self.get_comment_annotation(src_sig, pyi_sig)
      if src_sig.try_insert_comment_annotation(annot) and 'Any' in annot:
        touch_import('typing', 'Any', node)

    self.add_globals(node)

  def insert_annotation(self, src_sig, pyi_sig):
    """Insert annotation in PEP484 format."""
    for arg_sig, pyi_arg_sig in zip(src_sig.arg_sigs, pyi_sig.arg_sigs):
      if not pyi_arg_sig.arg_type:
        continue
      new_type = clean_clone(pyi_arg_sig.arg_type, False)
      arg_sig.insert_annotation(new_type)

    if pyi_sig.ret_type:
      src_sig.insert_ret_annotation(pyi_sig.ret_type)

  def get_comment_annotation(self, src_sig, pyi_sig):
    """Return function annotation as a comment string, doesn't modify tree."""
    arg_types = []
    for i, (arg_sig, pyi_arg_sig) in enumerate(
        zip(src_sig.arg_sigs, pyi_sig.arg_sigs)):
      is_first = (i == 0)
      new_type = clean_clone(pyi_arg_sig.arg_type, True)

      if new_type:
        new_type_str = str(new_type).strip()
      elif self.infer_should_annotate(src_sig, arg_sig, is_first):
        new_type_str = 'Any'
      else:
        continue

      arg_types.append(arg_sig.stars + new_type_str)

    ret_type = pyi_sig.ret_type
    if not ret_type:
      ret_type = self.infer_ret_type(src_sig)

    return '(' + ', '.join(arg_types) + ') -> ' + str(ret_type).strip()

  def can_annotate(self, src_sig):
    if src_sig.has_pep484_annotations or src_sig.has_comment_annotations:
      self.logger.warning('already annotated, skipping %s', src_sig)
      return False

    if src_sig.full_name not in self.parsed_pyi.funcs:
      self.logger.warning('no signature for %s, skipping', src_sig)
      return False

    pyi_sig = self.parsed_pyi.funcs[src_sig.full_name]

    if not pyi_sig.has_pep484_annotations:
      self.logger.warning('ignoring pyi definition with no annotations: %s',
                          pyi_sig)
      return False

    if not self.func_sig_compatible(src_sig, pyi_sig):
      self.logger.warning('incompatible annotation, skipping %s', src_sig)
      return False

    return True

  def add_globals(self, node):
    """Add required globals to the root of node. Idempotent."""
    if self.added_pyi_globals:
      return
    # TODO(tsudol): get rid of this -- added to prevent adding
    # .parsed_pyi.top_lines every time we annotate a different function in the
    # same file, but can break when we run the tool twice on the same file. Have
    # to do something like what touch_import does.
    self.added_pyi_globals = True

    imports, top_lines = self.parsed_pyi.imports, self.parsed_pyi.top_lines

    # Copy imports if not already present
    for pkg, names in imports:
      if names is None:
        # TODO(tsudol): do ourselves, touch_import puts stuff above license
        # headers.
        touch_import(None, pkg, node)  # == 'import pkg'
      else:
        for name in names:
          touch_import(pkg, name, node)

    root = find_root(node)

    import_idx = [
        idx for idx, node in enumerate(root.children)
        if self.import_pattern.match(node)
    ]
    if import_idx:
      insert_pos = import_idx[-1] + 1
    else:
      insert_pos = 0

      # first string (normally docstring)
      for idx, node in enumerate(root.children):
        if (node.type == syms.simple_stmt and node.children and
            node.children[0].type == token.STRING):
          insert_pos = idx + 1
          break

    top_lines = '\n'.join(top_lines)
    top_lines = Util.parse_string(top_lines)  # strips some newlines
    for offset, node in enumerate(top_lines.children[:-1]):
      root.insert_child(insert_pos + offset, node)

  @staticmethod
  def func_sig_compatible(src_sig, pyi_sig):
    """Can src_sig be annotated with the info in pyi_sig?

    For the two signatures to be compatible, the number of arguments
    must match, they must have the same star args and they can't be tuple
    arguments.

    Args:
      src_sig: A FuncSignature representing the .py signature.
      pyi_sig: A FuncSignature representing the .pyi signature.

    Returns:
      True if the two signatures are compatible, False otherwise.
    """
    if len(pyi_sig.arg_sigs) != len(src_sig.arg_sigs):
      return False

    for pyi, cur in zip(pyi_sig.arg_sigs, src_sig.arg_sigs):
      # Entirely skip functions that use tuple args
      if cur.is_tuple or pyi.is_tuple:
        return False

      # Stars are expected to match
      if cur.stars != pyi.stars:
        return False

    return True

  @staticmethod
  def infer_ret_type(src_sig):
    """Heuristic for return type of a function."""
    if src_sig.short_name == '__init__' or not src_sig.has_return_exprs:
      return 'None'
    return 'Any'

  @staticmethod
  def infer_should_annotate(func, arg, at_start):
    """Heuristic for whether arg, in func, should be annotated."""

    if func.is_method and at_start and 'staticmethod' not in func.decorators:
      # Don't annotate the first argument if it's named 'self'.
      # Don't annotate the first argument of a class method.
      if 'self' == arg.name or 'classmethod' in func.decorators:
        return False

    return True

  def set_pyi_string(self, pyi_string):
    """Set the annotations the fixer will use."""
    self.parsed_pyi = self.parse_pyi_string(pyi_string)
    self.added_pyi_globals = False

  def parse_pyi_string(self, text):
    """Parse .pyi string, return as ParsedPyi."""
    tree = Util.parse_string(text)

    funcs = {}
    for node, match_results in generate_matches(tree, self.pattern):
      sig = FuncSignature(node, match_results)

      if sig.full_name in funcs:
        self.logger.warning('Ignoring redefinition: %s', sig)
      else:
        funcs[sig.full_name] = sig

    imports = []
    for node, match_results in generate_top_matches(tree, self.import_pattern):
      imp = self.parse_top_import(node, match_results)
      if imp:
        imports.append(imp)

    top_lines = []
    for node, match_results in generate_top_matches(tree, self.assign_pattern):
      text = str(node).strip()

      # hack to avoid shadowing real variables -- proper solution is more
      # complicated, use util.find_binding
      if 'TypeVar' in text or (text and '_' == text[0]):
        top_lines.append(text)
      else:
        self.logger.warning('ignoring %s', repr(text))

    return ParsedPyi(tuple(imports), top_lines, funcs)

  assign_pattern = compile_pattern("""
    simple_stmt< expr_stmt<any+> any* >
    """)

  import_pattern = compile_pattern("""
    simple_stmt<
        ( import_from< 'from' pkg=any+ 'import' ['('] names=any [')'] > |
          import_name< 'import' pkg=any+ > )
        any*
    >
    """)
  import_as_pattern = compile_pattern("""import_as_name<NAME 'as' NAME>""")

  def parse_top_import(self, node, results):
    """Splits the result of import_pattern into component strings.

    Examples:

    'from pkg import a,b,c' gives
    ('pkg', ('a', 'b', 'c'))

    'import pkg' gives
    ('pkg', None)

    'from pkg import a as b' or 'import pkg as pkg2' are not supported.

    Args:
      node: The import statement node.
      results: The values from import_pattern.

    Returns:
      A tuple of the package name (string) and the list of imported names (list
      of strings).
    """

    # TODO(tsudol): this might have to be generalized to "get top-level
    # statements that aren't class or function definitions":
    # _T = typing.TypeVar('_T') is used in pyis.
    # Still not clear what is and isn't valid in a pyi... Could we have a loop?

    pkg, names = results['pkg'], results.get('names', None)
    pkg = ''.join(map(str, pkg)).strip()

    if names:
      is_import_as = any(
          True for _ in generate_matches(names, self.import_as_pattern))

      if is_import_as:
        # fixer_util.touch_import doesn't handle this
        # If necessary, will have to stick import at top of .py file
        self.logger.warning('Ignoring unhandled import-as: %s',
                            repr(str(node).strip()))
        return None

      names = split_comma(names.leaves())
      for name in names:
        assert 1 == len(name)
        assert name[0].type in (token.NAME, token.STAR)
      names = [name[0].value for name in names]

    return pkg, names
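
A hedged sketch exercising the import_pattern above in isolation (lib2to3 only, no pyi machinery), using the same Driver setup as the Util helper:

from lib2to3 import patcomp, pygram, pytree
from lib2to3.pgen2 import driver

d = driver.Driver(pygram.python_grammar, convert=pytree.convert)
import_pattern = patcomp.compile_pattern("""
    simple_stmt<
        ( import_from< 'from' pkg=any+ 'import' ['('] names=any [')'] > |
          import_name< 'import' pkg=any+ > )
        any*
    >
    """)

for line in ("from os import path, sep", "import sys"):
    stmt = d.parse_string(line + "\n").children[0]  # the simple_stmt node
    results = {}
    if import_pattern.match(stmt, results):
        pkg = ''.join(map(str, results['pkg'])).strip()
        print(pkg, results.get('names'))  # 'os' with an import_as_names node; 'sys' with None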
Example #14
    'print'          into 'print()'
    'print ...'      into 'print(...)'
    'print ... ,'    into 'print(..., end=" ")'
    'print >>x, ...' into 'print(..., file=x)'

No changes are applied if print_function is imported from __future__

"""

# Local imports
from lib2to3 import patcomp, pytree, fixer_base
from lib2to3.pgen2 import token
from lib2to3.fixer_util import Name, Call, Comma, String
from libmodernize import add_future

parend_expr = patcomp.compile_pattern("""atom< '(' [atom|STRING|NAME] ')' >""")


class FixPrint(fixer_base.BaseFix):

    BM_compatible = True

    PATTERN = """
              simple_stmt< any* bare='print' any* > | print_stmt
              """

    def transform(self, node, results):
        assert results

        bare_print = results.get("bare")
Example #15
def build_pattern(mapping=MAPPING):
    PATTERN = """
    power< local=(%s)
         tail=any*
    >
    """
    return PATTERN % "|".join("'%s'" % key.split(".")[-1]
                              for key in mapping.keys())


## Pattern for finding from module import *
from_import_pattern = """import_from<'from' module=(%s) 'import' star='*'>"""
module_names = set(["'%s'" % key.split(".", 1)[0] for key in MAPPING.keys()])
from_import_pattern = from_import_pattern % "|".join(module_names)

from_import_expr = patcomp.compile_pattern(from_import_pattern)


class FixChangedNamesAggressive(FixChangedNames):
    mapping = MAPPING
    run_order = 2

    def compile_pattern(self):
        # We override this, so MAPPING can be programmatically altered and the
        # changes will be reflected in PATTERN.
        self.PATTERN = build_pattern(self.mapping)
        name2mod = defaultdict(set)
        for key in self.mapping.keys():
            mod, name = key.split(".", 1)
            name2mod[name].add(mod)
        self._names_to_modules = name2mod
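
A hedged sketch of build_pattern from the example above; MAPPING here is a small stand-in for the module's real mapping and only shows the shape of the PATTERN string handed to compile_pattern:

MAPPING = {'os.getcwd': 'os.getcwdu', 'sys.maxint': 'sys.maxsize'}
print(build_pattern(MAPPING))
#     power< local=('getcwd'|'maxint')
#          tail=any*
#     >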
Example #16
    'print >>x, ...' into 'print(..., file=x)'

No changes are applied if print_function is imported from __future__

"""

from __future__ import unicode_literals

from lib2to3 import patcomp, pytree, fixer_base
from lib2to3.pgen2 import token
from lib2to3.fixer_util import Name, Call, Comma, FromImport, Newline, String

from libmodernize import check_future_import

parend_expr = patcomp.compile_pattern(
    """atom< '(' [arith_expr|atom|power|term|STRING|NAME] ')' >"""
)


class FixPrint(fixer_base.BaseFix):

    BM_compatible = True

    PATTERN = """
              simple_stmt< any* bare='print' any* > | print_stmt
              """

    def start_tree(self, tree, filename):
        self.found_print = False

    def transform(self, node, results):
Example #17
    "print(...)"     not changed
    "print ... ,"    into "print(..., end=' ')"
    "print >>x, ..." into "print(..., file=x)"

No changes are applied if print_function is imported from __future__

"""

# Local imports
from lib2to3 import patcomp, pytree, fixer_base
from lib2to3.pgen2 import token
from lib2to3.fixer_util import Name, Call, Comma, String

# from libmodernize import add_future

parend_expr = patcomp.compile_pattern(
    """atom< '(' [arith_expr|atom|power|term|STRING|NAME] ')' >""")


class FixPrint(fixer_base.BaseFix):

    BM_compatible = True

    PATTERN = """
              simple_stmt< any* bare='print' any* > | print_stmt
              """

    def transform(self, node, results):
        assert results

        bare_print = results.get("bare")
Beispiel #18
0
class FixAnnotate(BaseFix):

    # This fixer is compatible with the bottom matcher.
    BM_compatible = True

    # This fixer shouldn't run by default.
    explicit = True

    # The pattern to match.
    PATTERN = """
              funcdef< 'def' name=any parameters< '(' [args=any] ')' > ':' suite=any+ >
              """

    counter = None if not os.getenv('MAXFIXES') else int(os.getenv('MAXFIXES'))

    def transform(self, node, results):
        if FixAnnotate.counter is not None:
            if FixAnnotate.counter <= 0:
                return
        suite = results['suite']
        children = suite[0].children

        # NOTE: I've reverse-engineered the structure of the parse tree.
        # It's always a list of nodes, the first of which contains the
        # entire suite.  Its children seem to be:
        #
        #   [0] NEWLINE
        #   [1] INDENT
        #   [2...n-2] statements (the first may be a docstring)
        #   [n-1] DEDENT
        #
        # Comments before the suite are part of the INDENT's prefix.
        #
        # "Compact" functions (e.g. "def foo(x, y): return max(x, y)")
        # have a different structure that isn't matched by PATTERN.

        ## print('-'*60)
        ## print(node)
        ## for i, ch in enumerate(children):
        ##     print(i, repr(ch.prefix), repr(ch))

        # Check if there's already an annotation.
        for ch in children:
            if ch.prefix.lstrip().startswith('# type:'):
                return  # There's already a # type: comment here; don't change anything.

        # Compute the annotation
        annot = self.make_annotation(node, results)
        if annot is None:
            return

        # Insert '# type: {annot}' comment.
        # For reference, see lib2to3/fixes/fix_tuple_params.py in stdlib.
        if len(children) >= 2 and children[1].type == token.INDENT:
            argtypes, restype = annot
            degen_str = '(...) -> %s' % restype
            short_str = '(%s) -> %s' % (', '.join(argtypes), restype)
            if (len(short_str) > 64 or len(argtypes) > 5) and len(short_str) > len(degen_str):
                self.insert_long_form(node, results, argtypes)
                annot_str = degen_str
            else:
                annot_str = short_str
            children[1].prefix = '%s# type: %s\n%s' % (children[1].value, annot_str,
                                                       children[1].prefix)
            children[1].changed()
            if FixAnnotate.counter is not None:
                FixAnnotate.counter -= 1

            # Also add 'from typing import Any' at the top if needed.
            self.patch_imports(argtypes + [restype], node)

    def insert_long_form(self, node, results, argtypes):
        argtypes = list(argtypes)  # We destroy it
        args = results['args']
        if isinstance(args, Node):
            children = args.children
        elif isinstance(args, Leaf):
            children = [args]
        else:
            children = []
        # Interpret children according to the following grammar:
        # (('*'|'**')? NAME ['=' expr] ','?)*
        flag = False  # Set when the next leaf should get a type prefix
        indent = ''  # Will be set by the first child

        def set_prefix(child):
            if argtypes:
                arg = argtypes.pop(0).lstrip('*')
            else:
                arg = 'Any'  # Somehow there aren't enough args
            if not arg:
                # Skip self (look for 'check_self' below)
                prefix = child.prefix.rstrip()
            else:
                prefix = '  # type: ' + arg
                old_prefix = child.prefix.strip()
                if old_prefix:
                    assert old_prefix.startswith('#')
                    prefix += '  ' + old_prefix
            child.prefix = prefix + '\n' + indent

        check_self = self.is_method(node)
        for child in children:
            if check_self and isinstance(child, Leaf) and child.type == token.NAME:
                check_self = False
                if child.value in ('self', 'cls'):
                    argtypes.insert(0, '')
            if not indent:
                indent = ' ' * child.column
            if isinstance(child, Leaf) and child.value == ',':
                flag = True
            elif isinstance(child, Leaf) and flag:
                set_prefix(child)
                flag = False
        # Find the ')' and insert a prefix before it too.
        parameters = args.parent
        close_paren = parameters.children[-1]
        assert close_paren.type == token.RPAR, close_paren
        set_prefix(close_paren)
        assert not argtypes, argtypes

    def patch_imports(self, types, node):
        for typ in types:
            if 'Any' in typ:
                touch_import('typing', 'Any', node)
                break

    def make_annotation(self, node, results):
        name = results['name']
        assert isinstance(name, Leaf), repr(name)
        assert name.type == token.NAME, repr(name)
        decorators = self.get_decorators(node)
        is_method = self.is_method(node)
        if name.value == '__init__' or not self.has_return_exprs(node):
            restype = 'None'
        else:
            restype = 'Any'
        args = results.get('args')
        argtypes = []
        if isinstance(args, Node):
            children = args.children
        elif isinstance(args, Leaf):
            children = [args]
        else:
            children = []
        # Interpret children according to the following grammar:
        # (('*'|'**')? NAME ['=' expr] ','?)*
        stars = inferred_type = ''
        in_default = False
        at_start = True
        for child in children:
            if isinstance(child, Leaf):
                if child.value in ('*', '**'):
                    stars += child.value
                elif child.type == token.NAME and not in_default:
                    if not is_method or not at_start or 'staticmethod' in decorators:
                        inferred_type = 'Any'
                    else:
                        # Always skip the first argument if it's named 'self'.
                        # Always skip the first argument of a class method.
                        if child.value == 'self' or 'classmethod' in decorators:
                            pass
                        else:
                            inferred_type = 'Any'
                elif child.value == '=':
                    in_default = True
                elif in_default and child.value != ',':
                    if child.type == token.NUMBER:
                        if re.match(r'\d+[lL]?$', child.value):
                            inferred_type = 'int'
                        else:
                            inferred_type = 'float'  # TODO: complex?
                    elif child.type == token.STRING:
                        if child.value.startswith(('u', 'U')):
                            inferred_type = 'unicode'
                        else:
                            inferred_type = 'str'
                    elif child.type == token.NAME and child.value in ('True', 'False'):
                        inferred_type = 'bool'
                elif child.value == ',':
                    if inferred_type:
                        argtypes.append(stars + inferred_type)
                    # Reset
                    stars = inferred_type = ''
                    in_default = False
                    at_start = False
        if inferred_type:
            argtypes.append(stars + inferred_type)
        return argtypes, restype

    # The parse tree has a different shape when there is a single
    # decorator vs. when there are multiple decorators.
    DECORATED = "decorated< (d=decorator | decorators< dd=decorator+ >) funcdef >"
    decorated = compile_pattern(DECORATED)

    def get_decorators(self, node):
        """Return a list of decorators found on a function definition.

        This is a list of strings; only simple decorators
        (e.g. @staticmethod) are returned.

        If the function is undecorated or only non-simple decorators
        are found, return [].
        """
        if node.parent is None:
            return []
        results = {}
        if not self.decorated.match(node.parent, results):
            return []
        decorators = results.get('dd') or [results['d']]
        decs = []
        for d in decorators:
            for child in d.children:
                if isinstance(child, Leaf) and child.type == token.NAME:
                    decs.append(child.value)
        return decs

    def is_method(self, node):
        """Return whether the node occurs (directly) inside a class."""
        node = node.parent
        while node is not None:
            if node.type == syms.classdef:
                return True
            if node.type == syms.funcdef:
                return False
            node = node.parent
        return False

    RETURN_EXPR = "return_stmt< 'return' any >"
    return_expr = compile_pattern(RETURN_EXPR)

    def has_return_exprs(self, node):
        """Traverse the tree below node looking for 'return expr'.

        Return True if at least one 'return expr' is found, False if not.
        (If both bare 'return' and 'return expr' are found, return True.)
        """
        results = {}
        if self.return_expr.match(node, results):
            return True
        for child in node.children:
            if child.type not in (syms.funcdef, syms.classdef):
                if self.has_return_exprs(child):
                    return True
        return False

    YIELD_EXPR = "yield_expr< 'yield' [any] >"
    yield_expr = compile_pattern(YIELD_EXPR)

    def is_generator(self, node):
        """Traverse the tree below node looking for 'yield [expr]'."""
        results = {}
        if self.yield_expr.match(node, results):
            return True
        for child in node.children:
            if child.type not in (syms.funcdef, syms.classdef):
                if self.is_generator(child):
                    return True
        return False
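The has_return_exprs and is_generator methods above share one idiom: match a compiled pattern against a node, then recurse into its children while refusing to descend into nested function or class definitions. A standalone sketch of that idiom, assuming lib2to3 is available (it was removed from the standard library in Python 3.13):

from lib2to3 import patcomp, pygram, pytree
from lib2to3.pgen2 import driver

syms = pygram.python_symbols
yield_expr = patcomp.compile_pattern("yield_expr< 'yield' [any] >")
d = driver.Driver(pygram.python_grammar, convert=pytree.convert)


def contains_yield(node):
    """Return True if a yield occurs below node, without entering nested defs/classes."""
    if yield_expr.match(node):
        return True
    return any(contains_yield(child) for child in node.children
               if child.type not in (syms.funcdef, syms.classdef))


tree = d.parse_string("def gen():\n    yield 1\n")
funcdef = next(n for n in tree.pre_order() if n.type == syms.funcdef)
print(contains_yield(funcdef))  # True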
Beispiel #19
0
class Pyi(collections.namedtuple('Pyi', 'imports assignments funcs')):
    """A parsed pyi."""
    def _get_imports(self, inserted_types):
        """Get the imports that provide the given types."""
        used_names = set()
        for node in inserted_types + self.assignments:
            for leaf in node.leaves():
                if leaf.type == token.NAME:
                    used_names.add(leaf.value)
                    # All dotted prefixes are possible imports.
                    value = leaf.value
                    while '.' in value:
                        value, _ = value.rsplit('.', 1)
                        used_names.add(value)
        for (pkg, pkg_alias), names in self.imports:
            if not names:
                if (pkg_alias or pkg) in used_names:
                    yield ((pkg, pkg_alias), names)
            else:
                names = [(name, alias) for name, alias in names
                         if name == '*' or (alias or name) in used_names]
                if names:
                    yield ((pkg, pkg_alias), names)

    def add_globals(self, tree, inserted_types):
        """Add required globals to the tree. Idempotent."""
        # Copy imports if not already present
        top_lines = []

        def import_name(name, alias):
            return name + ('' if alias is None else ' as %s' % alias)

        for (pkg, pkg_alias), names in self._get_imports(inserted_types):
            if not names:
                if does_tree_import(None, pkg_alias or pkg, tree):
                    continue
                top_lines.append('import %s\n' % import_name(pkg, pkg_alias))
            else:
                assert pkg_alias is None
                import_names = []
                for name, alias in names:
                    if does_tree_import(pkg, alias or name, tree):
                        continue
                    import_names.append(import_name(name, alias))
                if not import_names:
                    continue
                top_lines.append('from %s import %s\n' %
                                 (pkg, ', '.join(import_names)))

        import_idx = [
            idx for idx, idx_node in enumerate(tree.children)
            if self.import_pattern.match(idx_node)
        ]
        if import_idx:
            insert_pos = import_idx[-1] + 1
        else:
            insert_pos = 0

            # No imports found: insert after the first string (normally the module docstring), if any.
            for idx, idx_node in enumerate(tree.children):
                if (idx_node.type == syms.simple_stmt and idx_node.children
                        and idx_node.children[0].type == token.STRING):
                    insert_pos = idx + 1
                    break

        if self.assignments:
            top_lines.append('\n')
            top_lines.extend(str(a).strip() + '\n' for a in self.assignments)
        top_lines = Util.parse_string(''.join(top_lines))
        for offset, offset_node in enumerate(top_lines.children[:-1]):
            tree.insert_child(insert_pos + offset, offset_node)

    @classmethod
    def _log_warning(cls, *args):
        logger = logging.getLogger(cls.__name__)
        logger.warning(*args)

    @classmethod
    def parse(cls, text):
        """Parse .pyi string, return as Pyi."""
        tree = Util.parse_string(text)

        funcs = {}
        for node, match_results in generate_matches(tree,
                                                    cls.function_pattern):
            sig = FuncSignature(node, match_results)

            if sig.full_name in funcs:
                cls._log_warning('Ignoring redefinition: %s', sig)
            else:
                funcs[sig.full_name] = sig

        imports = []
        # Any is sometimes inserted as a default type, so make sure typing.Any is
        # always importable.
        any_import = False
        for node, match_results in generate_top_matches(
                tree, cls.import_pattern):
            pkg, names = cls.parse_top_import(match_results)
            if pkg == ('typing', None) and names:
                if ('Any', None) not in names:
                    names.insert(0, ('Any', None))
                any_import = True
            imports.append((pkg, names))
        if not any_import:
            imports.append((('typing', None), [('Any', None)]))

        assignments = []
        for node, match_results in generate_top_matches(
                tree, cls.assign_pattern):
            text = str(node)

            # hack to avoid shadowing real variables -- proper solution is more
            # complicated, use util.find_binding
            if 'TypeVar' in text or (text and text[0] == '_'):
                assignments.append(node)
            else:
                cls._log_warning('ignoring %s', repr(text))

        return cls(tuple(imports), tuple(assignments), funcs)

    function_pattern = compile_pattern(FuncSignature.PATTERN)

    assign_pattern = compile_pattern("""
    simple_stmt< expr_stmt<any+> any* >
    """)

    import_pattern = compile_pattern("""
    simple_stmt<
        ( import_from< 'from' pkg=any+ 'import' ['('] names=any [')'] > |
          import_name< 'import' pkg=any+ > )
        any*
    >
    """)

    @classmethod
    def _parse_import_alias(cls, leaves):
        assert leaves[-2].value == 'as'
        name = ''.join(leaf.value for leaf in leaves[:-2])
        return (name, leaves[-1].value)

    @classmethod
    def parse_top_import(cls, results):
        """Splits the result of import_pattern into component strings.

    Examples:

    'from pkg import a,b,c' gives
    (('pkg', None), [('a', None), ('b', None), ('c', None)])

    'import pkg' gives
    (('pkg', None), [])

    'from pkg import a as b' gives
    (('pkg', None), [('a', 'b')])

    'import pkg as pkg2' gives
    (('pkg', 'pkg2'), [])

    'import pkg.a as b' gives
    (('pkg.a', 'b'), [])

    Args:
      results: The values from import_pattern.

    Returns:
      A tuple of the package name and the list of imported names. Each name is a
      tuple of original name and alias.
    """

        pkg, names = results['pkg'], results.get('names', None)

        if (len(pkg) == 1
                and pkg[0].type == pygram.python_symbols.dotted_as_name):
            pkg_out = cls._parse_import_alias(list(pkg[0].leaves()))
        else:
            pkg_out = (''.join(map(str, pkg)).strip(), None)

        names_out = []
        if names:
            names = split_comma(names.leaves())
            for name in names:
                if len(name) == 1:
                    assert name[0].type in (token.NAME, token.STAR)
                    names_out.append((name[0].value, None))
                else:
                    names_out.append(cls._parse_import_alias(name))

        return pkg_out, names_out
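To see why parse_top_import special-cases a single dotted_as_name node: lib2to3 represents 'import pkg.a as b' with the whole dotted name and alias under one node, whose leaves _parse_import_alias then splits into ('pkg.a', 'b'). A small sketch (assuming lib2to3 is available):

from lib2to3 import pygram, pytree
from lib2to3.pgen2 import driver

d = driver.Driver(pygram.python_grammar, convert=pytree.convert)
tree = d.parse_string("import pkg.a as b\n")
node = next(n for n in tree.pre_order()
            if n.type == pygram.python_symbols.dotted_as_name)
print([leaf.value for leaf in node.leaves()])  # ['pkg', '.', 'a', 'as', 'b']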
    
Beispiel #20
0
def build_pattern(mapping=MAPPING):
    PATTERN = """
    power< local=(%s)
         tail=any*
    >
    """
    return PATTERN % "|".join("'%s'" % key.split(".")[-1] for key in mapping.keys())


## Pattern for finding from module import *
from_import_pattern = """import_from<'from' module=(%s) 'import' star='*'>"""
module_names = set(["'%s'" % key.split(".", 1)[0] for key in MAPPING.keys()])
from_import_pattern = from_import_pattern % "|".join(module_names)

from_import_expr = patcomp.compile_pattern(from_import_pattern)


class FixChangedNamesAggressive(FixChangedNames):
    mapping = MAPPING
    run_order = 2

    def compile_pattern(self):
        # We override this so MAPPING can be programmatically altered and the
        # changes will be reflected in PATTERN.
        self.PATTERN = build_pattern(self.mapping)
        name2mod = defaultdict(set)
        for key in self.mapping.keys():
            mod, name = key.split(".", 1)
            name2mod[name].add(mod)
        self._names_to_modules = name2mod
Beispiel #21
0
    'print ...'      into 'print(...)'
    'print ... ,'    into 'print(..., end=" ")'
    'print >>x, ...' into 'print(..., file=x)'

No changes are applied if print_function is imported from __future__

"""

# Local imports
from lib2to3 import patcomp, pytree, fixer_base
from lib2to3.pgen2 import token
from lib2to3.fixer_util import Name, Call, Comma, String
from libmodernize import add_future

parend_expr = patcomp.compile_pattern(
              """atom< '(' [atom|STRING|NAME] ')' >"""
              )


class FixPrint(fixer_base.BaseFix):

    BM_compatible = True

    PATTERN = """
              simple_stmt< any* bare='print' any* > | print_stmt
              """

    def transform(self, node, results):
        assert results

        bare_print = results.get("bare")
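The parend_expr pattern used by the FixPrint variants above recognizes a print argument that is already a single parenthesized atom, so 'print (x)' is treated as a call while 'print (x, y)' is not. A standalone check of what it accepts (assuming lib2to3 is available):

from lib2to3 import patcomp, pygram, pytree
from lib2to3.pgen2 import driver

parend_expr = patcomp.compile_pattern("""atom< '(' [atom|STRING|NAME] ')' >""")
d = driver.Driver(pygram.python_grammar, convert=pytree.convert)

for src in ("(x)\n", "(x, y)\n"):
    # file_input -> simple_stmt -> the expression node
    expr = d.parse_string(src).children[0].children[0]
    print(repr(src.strip()), parend_expr.match(expr))
# '(x)' True
# '(x, y)' False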
Beispiel #22
0
class FixDict(fixer_base.BaseFix):
    BM_compatible = True

    PATTERN = """
    power< head=any+
         trailer< '.' method=('keys'|'items'|'values'|
                              'iterkeys'|'iteritems'|'itervalues'|
                              'viewkeys'|'viewitems'|'viewvalues') >
         parens=trailer< '(' ')' >
         tail=any*
    >
    """

    def transform(self, node, results):
        head = results["head"]
        method = results["method"][0]  # Extract node for method name
        tail = results["tail"]
        syms = self.syms
        method_name = method.value
        isiter = method_name.startswith(u"iter")
        isview = method_name.startswith(u"view")
        head = [n.clone() for n in head]
        tail = [n.clone() for n in tail]
        # no changes necessary if the call is in a special context
        special = not tail and self.in_special_context(node, isiter)
        new = pytree.Node(syms.power, head)
        new.prefix = u""
        if isiter or isview:
            # replace the method with the six function,
            # e.g. d.iteritems() -> iteritems(d), with 'from six import iteritems' added at the top
            new = Call(Name(method_name), [new])
            touch_import_top('six', method_name, node)
        elif special:
            # it is not necessary to change this case
            return node
        elif method_name in ("items", "values"):
            # ensure to return a list in python 3
            new = Call(Name(u"list" + method_name), [new])
            touch_import_top('future.utils', 'list' + method_name, node)
        else:
            # method_name is "keys": remove the call and cast the dict to a list
            new = Call(Name(u"list"), [new])

        if tail:
            new = pytree.Node(syms.power, [new] + tail)
        new.prefix = node.prefix
        return new

    P1 = "power< func=NAME trailer< '(' node=any ')' > any* >"
    p1 = patcomp.compile_pattern(P1)

    def in_special_context(self, node, isiter):
        # it is not wrapped
        if node.parent is None:
            return False
        results = {}
        if (node.parent.parent is not None
                and self.p1.match(node.parent.parent, results)
                and results["node"] is node):

            if isiter:
                # iter(d.iterkeys()) -> iter(d.keys()), etc.
                return results["func"].value in iter_exempt
            else:
                # list(d.keys()) -> list(d.keys()), etc.
                return results["func"].value in fixer_util.consuming_calls
        return False
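A standalone sketch of the in_special_context check above: for the node representing d.keys() inside list(d.keys()), p1 matches the enclosing call and binds func='list', which is in lib2to3's fixer_util.consuming_calls, so the fixer can leave the expression alone (iter_exempt is defined elsewhere in the fixer module and is not reproduced here):

from lib2to3 import patcomp, pygram, pytree, fixer_util
from lib2to3.pgen2 import driver

p1 = patcomp.compile_pattern("power< func=NAME trailer< '(' node=any ')' > any* >")
d = driver.Driver(pygram.python_grammar, convert=pytree.convert)

tree = d.parse_string("list(d.keys())\n")
call = tree.children[0].children[0]   # power node for list(d.keys())
inner = call.children[1].children[1]  # power node for d.keys()
results = {}
if p1.match(call, results) and results["node"] is inner:
    print(results["func"].value in fixer_util.consuming_calls)  # True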