def transform(self, node, results):
    """Return a re-indented clone of the matched suite without its leading newline.

    The suite captured as ``results['suite']`` is cloned, so the original tree
    is not modified. The difference between the suite's indentation width and
    ``node``'s indentation width is passed to ``self.dedent``. Then the first
    newline following the block header is removed.

    :param node: matched node whose indentation the suite is aligned to.
    :param results: pattern-match results; must contain ``'suite'``.
    :returns: the transformed suite clone.
    """
    suite = results['suite'].clone()
    # TODO: handle tabs -- indentation is compared by character count only.
    dedent = len(find_indentation(suite)) - len(find_indentation(node))
    self.dedent(suite, dedent)

    # Remove the first newline behind the classdef header.
    first = suite.children[0]
    if first.type == token.NEWLINE:
        if len(first.value) == 1:
            del suite.children[0]
        else:
            # Assignment, not comparison: the previous ``==`` discarded its
            # result, so the leading newline was never actually stripped.
            first.value = first.value[1:]

    return suite
示例#2
0
    def transform(self, node, results):
        """Return a re-indented clone of the matched suite without its leading newline.

        The suite captured as ``results['suite']`` is cloned, so the original
        tree is not modified. The difference between the suite's indentation
        width and ``node``'s indentation width is passed to ``self.dedent``.
        Then the first newline following the block header is removed.

        :param node: matched node whose indentation the suite is aligned to.
        :param results: pattern-match results; must contain ``'suite'``.
        :returns: the transformed suite clone.
        """
        suite = results['suite'].clone()
        # TODO: handle tabs -- indentation is compared by character count only.
        dedent = len(find_indentation(suite)) - len(find_indentation(node))
        self.dedent(suite, dedent)

        # Remove the first newline behind the classdef header.
        first = suite.children[0]
        if first.type == token.NEWLINE:
            if len(first.value) == 1:
                del suite.children[0]
            else:
                # Assignment, not comparison: the previous ``==`` discarded
                # its result, so the leading newline was never stripped.
                first.value = first.value[1:]

        return suite
示例#3
0
    def add_py2_annot(self, argtypes, restype, node, results):
        """Attach a Python-2 style ``# type: (...) -> ...`` comment to a function.

        A one-line body whose first child has a blank prefix is first expanded
        into NEWLINE + INDENT ... DEDENT, so the comment has a line of its own.
        A signature longer than 64 characters, or one with more than 5
        arguments, is written as ``(...) -> restype``. In that case the
        argument types are handed to ``self.insert_long_form``; this only
        happens when the short form is longer than the ``(...)`` form. If no
        INDENT token is available, a message is logged instead.

        :param argtypes: argument type strings.
        :param restype: return type string.
        :param node: the function definition node.
        :param results: match results; ``results['suite'][0]`` is the body.
        """
        body = results['suite'][0].children

        # Insert '# type: {annot}' comment.
        # For reference, see lib2to3/fixes/fix_tuple_params.py in stdlib.
        is_one_liner = len(body) >= 1 and body[0].type != token.NEWLINE
        if is_one_liner and body[0].prefix.strip() == '':
            body[0].prefix = ''
            body.insert(0, Leaf(token.NEWLINE, '\n'))
            body.insert(1, Leaf(token.INDENT, find_indentation(node) + '    '))
            body.append(Leaf(token.DEDENT, ''))

        has_indent = len(body) >= 2 and body[1].type == token.INDENT
        if not has_indent:
            self.log_message(
                "%s:%d: cannot insert annotation for one-line function" %
                (self.filename, node.get_lineno()))
            return

        degen_str = '(...) -> %s' % restype
        short_str = '(%s) -> %s' % (', '.join(argtypes), restype)
        too_long = len(short_str) > 64 or len(argtypes) > 5
        if too_long and len(short_str) > len(degen_str):
            self.insert_long_form(node, results, argtypes)
            annot_str = degen_str
        else:
            annot_str = short_str

        # The comment goes in the INDENT token's prefix, ahead of any
        # comments already there.
        indent_leaf = body[1]
        indent_leaf.prefix = '%s# type: %s\n%s' % (
            indent_leaf.value, annot_str, indent_leaf.prefix)
        indent_leaf.changed()
示例#4
0
    def transform(self, node, results):
        """Normalise the blank lines/comments that precede a definition.

        The newlines live either in ``node``'s own prefix or in the prefix of
        the whitespace-carrying node before it (from
        ``get_whitespace_before_definition``). That prefix is rewritten with
        ``self.trim_comments`` using the limits from
        ``self.get_newline_limits(node)``. It is only updated, and marked
        changed, when the result differs.
        """
        # Sometimes newlines are in the prefix of the current node, sometimes
        # in the prefix of the previous sibling.
        if '\n' in node.prefix:
            holder = node
        else:
            holder = get_whitespace_before_definition(node)
            if not holder:
                # No previous node, must be the first node.
                return

        # INDENT/NEWLINE tokens do not keep the indentation in their prefix
        # (DEDENTs do), so only the latter needs it re-added.
        keeps_indent_in_prefix = holder.type not in (token.INDENT, token.NEWLINE)
        indentation = find_indentation(node) if keeps_indent_in_prefix else ''

        low, high = self.get_newline_limits(node)
        trimmed = self.trim_comments(indentation, holder.prefix, low, high)
        if trimmed == holder.prefix:
            return
        holder.prefix = trimmed
        holder.changed()
示例#5
0
    def transform(self, node, results):
        """Replace a unittest ``assert*`` method call with an equivalent statement.

        The call's arguments are cloned and resolved against the real
        ``unittest.TestCase`` method with ``utils.resolve_func_args``.
        ``_method_map[method]`` then builds the replacement:
        ``assertRaises*`` / ``assertWarns*`` builders also receive the node's
        indentation and the original arglist. Every other method becomes an
        ``assert_stmt`` with an optional ``, msg`` tail.

        :raises RuntimeError: if ``unittest.TestCase`` has no such method.
        """
        posargs = []
        kwargs = {}

        def collect(arg):
            # Commas only separate arguments.
            if isinstance(arg, Leaf) and arg.type == token.COMMA:
                return
            if isinstance(arg, Node) and arg.type == syms.argument:
                # keyword argument: NAME '=' value
                name, equal, value = arg.children
                assert name.type == token.NAME
                assert equal.type == token.EQUAL
                value = value.clone()
                value.prefix = " "
                kwargs[name.value] = value
                return
            assert not kwargs, 'all positional args are assumed to come first'
            posargs.append(arg.clone())

        method = results['method'][0].value
        # map (deprecated) aliases to original to avoid analysing
        # the decorator function
        method = _method_aliases.get(method, method)

        arglist = results['arglist']
        # Either an "arglist" node or a single bare argument.
        pieces = arglist.children if arglist.type == syms.arglist else [arglist]
        for piece in pieces:
            collect(piece)

        try:
            test_func = getattr(unittest.TestCase, method)
        except AttributeError:
            raise RuntimeError("Your unittest package does not support '%s'. "
                               "consider updating the package" % method)

        required_args, argsdict = utils.resolve_func_args(test_func, posargs, kwargs)

        builder = _method_map[method]
        if method.startswith(('assertRaises', 'assertWarns')):
            n_stmt = builder(*required_args,
                             indent=find_indentation(node),
                             kws=argsdict,
                             arglist=arglist)
        else:
            n_stmt = Node(syms.assert_stmt,
                          [Name('assert'), builder(*required_args, kws=argsdict)])

        message = argsdict.get('msg', None)
        if message is not None:
            n_stmt.children.extend((Name(','), message))
        n_stmt.prefix = node.prefix

        return n_stmt
    def transform(self, node, results):
        """Translate a unittest ``assert*`` method call into a plain statement.

        Positional and keyword arguments are cloned, matched against the
        ``unittest.TestCase`` method signature via ``utils.resolve_func_args``,
        and passed to the builder in ``_method_map``. For
        ``assertRaises*`` / ``assertWarns*`` the builder's result is used as
        is; for everything else it is wrapped as ``assert <expr>``. A
        non-None ``msg`` is appended as ``, msg``.

        :raises RuntimeError: if ``unittest.TestCase`` lacks the method.
        """
        positional = []
        keywords = {}

        def record(arg):
            is_comma = isinstance(arg, Leaf) and arg.type == token.COMMA
            if is_comma:
                return
            is_keyword = isinstance(arg, Node) and arg.type == syms.argument
            if not is_keyword:
                assert not keywords, 'all positional args are assumed to come first'
                positional.append(arg.clone())
                return
            # keyword argument: NAME '=' value
            name, equal, value = arg.children
            assert name.type == token.NAME
            assert equal.type == token.EQUAL
            cloned = value.clone()
            cloned.prefix = " "
            keywords[name.value] = cloned

        # map (deprecated) aliases to original to avoid analysing
        # the decorator function
        raw_method = results['method'][0].value
        method = _method_aliases.get(raw_method, raw_method)

        arglist = results['arglist']
        if arglist.type == syms.arglist:
            # A real "arglist" node: visit every child, commas included.
            for arg in arglist.children:
                record(arg)
        else:
            # A single bare argument.
            record(arglist)

        try:
            test_func = getattr(unittest.TestCase, method)
        except AttributeError:
            raise RuntimeError("Your unittest package does not support '%s'. "
                               "consider updating the package" % method)

        required_args, argsdict = utils.resolve_func_args(test_func, positional, keywords)

        if not method.startswith(('assertRaises', 'assertWarns')):
            expression = _method_map[method](*required_args, kws=argsdict)
            n_stmt = Node(syms.assert_stmt, [Name('assert'), expression])
        else:
            n_stmt = _method_map[method](*required_args,
                                         indent=find_indentation(node),
                                         kws=argsdict,
                                         arglist=arglist)

        message = argsdict.get('msg', None)
        if message is not None:
            n_stmt.children.extend((Name(','), message))
        n_stmt.prefix = node.prefix
        return n_stmt
示例#7
0
def find_indentation(node):
    """Return the indentation string of the block enclosing *node*.

    ``lib2to3.fixer_util.find_indentation`` is used when lib2to3 can be
    imported. Otherwise the function walks from *node* up through its
    parents and returns the INDENT value of the first ``suite`` that has
    more than two children and whose second child is an INDENT token. It
    returns "" when no such suite exists.
    """
    try:
        from lib2to3.fixer_util import find_indentation as _lib2to3_impl
        return _lib2to3_impl(node)
    except ImportError:
        pass

    # Fallback for interpreters without lib2to3.
    current = node
    while current is not None:
        is_block = current.type == symbols.suite and len(current.children) > 2
        if is_block and current.children[1].type == token.INDENT:
            return current.children[1].value
        current = current.parent
    return ""
示例#8
0
文件: utils.py 项目: sashka/pep8ify
def find_indentation(node):
    """Indentation string in effect for *node*; "" at module level.

    lib2to3's own implementation is preferred. Without lib2to3, the parent
    chain is searched for a ``suite`` with more than two children whose
    second child is an INDENT token, and that token's value is returned.
    """
    try:
        from lib2to3 import fixer_util
        return fixer_util.find_indentation(node)
    except ImportError:
        def _ancestors(start):
            # Yield *start* and then each parent until the root is passed.
            while start is not None:
                yield start
                start = start.parent

        for ancestor in _ancestors(node):
            if ancestor.type != symbols.suite or len(ancestor.children) <= 2:
                continue
            indent_token = ancestor.children[1]
            if indent_token.type == token.INDENT:
                return indent_token.value
        return ""
    def transform(self, node, results):
        """Rename the matched method to ``__str__`` and decorate the class.

        ``results["unifunc"]`` is replaced by a ``__str__`` name that keeps
        its prefix. The class node is wrapped in a ``decorated`` node with
        ``@python_2_unicode_compatible`` and the class's original prefix, and
        ``touch_import`` is asked for that name from ``django.utils.encoding``.
        """
        old_name = results["unifunc"]
        new_name = Name("__str__", prefix=old_name.prefix)
        old_name.replace(new_name)

        # The class now sits below the decorator, so it starts on a new line
        # at its original indentation.
        klass = node.clone()
        klass.prefix = '\n' + find_indentation(node)

        at_sign = Leaf(token.AT, "@")
        decorator_name = Name('python_2_unicode_compatible')
        decorator = Node(syms.decorator, [at_sign, decorator_name])
        wrapper = Node(syms.decorated, [decorator, klass], prefix=node.prefix)
        node.replace(wrapper)

        touch_import('django.utils.encoding', 'python_2_unicode_compatible', wrapper)
    def transform_semi(self, node):
        """Split semicolon-separated statements onto separate lines.

        Every SEMI child of *node* is replaced with a NEWLINE leaf followed by
        an INDENT leaf whose value is the semicolon's indentation. Leading
        whitespace is first stripped from the prefix of the statement that
        follows the semicolon.

        :returns: *node*, modified in place.
        """
        for piece in node.children:
            if piece.type != token.SEMI:
                continue

            # Strip any whitespace from the next sibling.
            following = piece.next_sibling
            stripped = following.prefix.lstrip()
            if following.prefix != stripped:
                following.prefix = stripped
                following.changed()

            # Replace the semi with a newline.
            depth = find_indentation(piece)
            piece.replace([Leaf(token.NEWLINE, u"\n"), Leaf(token.INDENT, depth)])
            piece.changed()
        return node
    def transform_colon(self, node):
        """Move a statement that follows a compound-statement colon into a block.

        *node* is replaced by a new ``suite`` made of a NEWLINE, an INDENT
        whose value is four spaces prepended to *node*'s indentation, a clone
        of *node* (leading prefix whitespace removed) and an empty DEDENT.

        :returns: the clone, which is the node that now lives in the tree.
        """
        body = node.clone()
        # Strip any whitespace that could have been there.
        body.prefix = body.prefix.lstrip()

        deeper = u' ' * 4 + find_indentation(node)
        suite = Node(symbols.suite, [
            Leaf(token.NEWLINE, u'\n'),
            Leaf(token.INDENT, deeper),
            body,
            Leaf(token.DEDENT, u''),
        ])
        node.replace(suite)
        node.changed()

        # Hand back the clone rather than the detached original
        # (original note: "in case semi").
        return body
示例#12
0
    def transform(self, node, results):
        """Run the parent print fixer and tag its output with a marker comment.

        A single ``# fix_print_with_import`` line is placed directly above each
        statement the parent fixer rewrote, at that statement's indentation.
        Nodes whose prefix already contains the marker are returned unchanged,
        which stops the fixer from tagging its own output again.

        :returns: the parent fixer's result, or ``node`` when already tagged.
        """
        if "fix_print_with_import" in node.prefix:
            return node

        r = super(FixPrintWithImport, self).transform(node, results)

        # Nothing was rewritten: leave the prefix alone.
        if not r or r == node:
            return r

        if not r.prefix:
            # No prefix to extend, so rebuild the indentation from the tree.
            indentation = find_indentation(node)
            r.prefix = "# fix_print_with_import\n" + indentation
        else:
            # Insert the marker after the prefix's trailing run of spaces/tabs
            # and repeat that run so the statement keeps its column.
            # ``\Z`` anchors at the true end of the string; ``$`` also matches
            # before a final newline. ``count=1`` stops re.sub (Python 3.7+)
            # from also replacing the empty match right after the whitespace
            # run. Either effect previously produced a duplicated marker line.
            r.prefix = re.sub(r'([ \t]*)\Z', r'\1# fix_print_with_import\n\1',
                              r.prefix, count=1)

        return r
示例#13
0
    def transform(self, node, results):
        """Rename the matched method to ``__str__`` and add the compat decorator.

        ``results["unifunc"]`` becomes ``__str__`` and keeps its prefix. The
        class is re-emitted inside a ``decorated`` node carrying
        ``@python_2_unicode_compatible``, and ``touch_import`` is called for
        that name from ``django.utils.encoding``.
        """
        method_name = results["unifunc"]
        method_name.replace(Name("__str__", prefix=method_name.prefix))

        klass = node.clone()
        klass.prefix = '\n' + find_indentation(node)

        decorated = Node(
            syms.decorated,
            [
                Node(syms.decorator,
                     [Leaf(token.AT, "@"), Name('python_2_unicode_compatible')]),
                klass,
            ],
            prefix=node.prefix,
        )
        node.replace(decorated)

        touch_import('django.utils.encoding', 'python_2_unicode_compatible',
                     decorated)
    def transform(self, node, results):
        """Apply the parent print fixer and mark rewritten statements.

        Every statement changed by the parent fixer gets exactly one
        ``# fix_print_with_import`` comment line above it, at the statement's
        indentation. Nodes that already carry the marker in their prefix are
        returned untouched, so repeated runs do not stack markers.

        :returns: the parent fixer's result, or ``node`` when already tagged.
        """
        if "fix_print_with_import" in node.prefix:
            return node

        r = super(FixPrintWithImport, self).transform(node, results)

        # Parent fixer declined or returned the same node: nothing to tag.
        if not r or r == node:
            return r

        if not r.prefix:
            # Empty prefix: derive the indentation from the tree instead.
            indentation = find_indentation(node)
            r.prefix = "# fix_print_with_import\n" + indentation
        else:
            # Put the marker after the trailing spaces/tabs of the prefix and
            # repeat them. The ``\Z`` anchor plus ``count=1`` guarantees a
            # single substitution; the former ``$``-anchored, unbounded
            # pattern could match twice and emit the marker twice.
            r.prefix = re.sub(r'([ \t]*)\Z', r'\1# fix_print_with_import\n\1',
                              r.prefix, count=1)

        return r
示例#15
0
def patch(t):
    """Insert a ``profile`` decorator before each function enclosing a collected node.

    A ``Visitor`` is run over *t*. For every node in ``visitor.r``, that node
    and each of its ancestors is examined. Every ``funcdef`` among them gets
    ``Decorator("profile")`` inserted before it exactly once (tracked by
    object identity), and its prefix is reset to its indentation.
    """
    visitor = Visitor()
    visitor.visit(t)

    decorated_ids = set()
    for start in visitor.r:
        current = start
        while current:
            is_func = node_name(current) == "funcdef"
            if not is_func and current.parent is None:
                break
            if is_func and id(current) not in decorated_ids:
                decorated_ids.add(id(current))  # xxx:
                insert_before(current, Decorator("profile"))
                current.prefix = u.find_indentation(current)
            current = current.parent
示例#16
0
def _insert_node(t0, node) -> None:
    """Append *node* to *t0* when it lies inside a top-level def or class.

    The ``node_name`` chain from *node* up to the root must end in
    ``"file_input"``. The root's child on that chain must be a ``funcdef``,
    ``classdef`` or ``async_funcdef``; otherwise nothing happens. When that
    holds, *node*'s prefix becomes two newlines at zero indentation, or a
    newline plus its indentation, and *node* is appended to *t0*.
    """
    lineage = []
    current = node
    while current is not None:
        lineage.append(node_name(current))
        current = current.parent

    assert lineage[-1] == "file_input"
    if lineage[-2] not in ("funcdef", "classdef", "async_funcdef"):
        return

    indentation = find_indentation(node)
    node.prefix = "\n\n" if indentation == "" else f"\n{indentation}"
    t0.append_child(node)
    def transform(self, node, results):
        """Split one multi-module ``import`` into one statement per module.

        ``results`` is iterated directly (presumably yielding leaves). Each
        NAME leaf becomes its own ``simple_stmt`` of the form ``import <name>``.
        Every statement after the first is prefixed with *node*'s indentation.
        """
        module_names = [leaf.value for leaf in results if leaf.type == token.NAME]
        indent = find_indentation(node)

        statements = []
        for position, module_name in enumerate(module_names):
            # Don't add more indentation if this is the first one.
            prefix = indent if position else None
            keyword = Leaf(token.NAME, 'import', prefix=prefix)
            target = Leaf(token.NAME, module_name, prefix=" ")
            import_node = Node(symbols.import_name, [keyword, target])
            statements.append(
                Node(symbols.simple_stmt, [import_node, Leaf(token.NEWLINE, '\n')]))

        node.replace(statements)
        node.changed()
示例#18
0
    def transform(self, node: LN, capture: Capture) -> None:
        """Rewrite ``import a, b as c`` into one ``import`` per module, applying renames.

        Each imported module name is looked up with ``get_new_name``. A
        ``NameRemovedError`` is reported via ``self.warn`` and that name is
        dropped. The first new statement inherits *node*'s prefix. *node* is
        only replaced when at least one name was actually renamed.
        """
        imports = []
        for item in capture["module_names"]:
            kind = item.type
            if kind == token.NAME:
                imports.append((item.value, None))
            elif kind == syms.dotted_name:
                imports.append((traverse_dotted_name(item), None))
            elif kind == syms.dotted_as_name:
                # ``dotted.name as alias``: children are [name, 'as', alias].
                target, as_kw, alias = item.children[0], item.children[1], item.children[2]
                assert as_kw.type == token.NAME and as_kw.value == "as"
                imports.append((traverse_dotted_name(target), alias.value))
            # Commas (token.COMMA) carry nothing and are skipped.

        indent = find_indentation(node)
        renamed_any = False
        statements = []
        for name, nick in imports:
            try:
                replacement = get_new_name(name)
            except NameRemovedError as exc:
                self.warn(node, str(exc))
                continue

            if replacement:
                renamed_any = True
                name = replacement
            statements.append(Node(
                syms.import_name,
                [Name("import"), ImportAsName(name, nick, prefix=" ")],
                prefix=f"\n{indent}",
            ))

        if not statements:
            return

        statements[0].prefix = node.prefix
        if renamed_any:
            node.replace(statements)
示例#19
0
    def add_py2_annot(self, argtypes, restype, node, results):
        # type: (List[str], str, Node, Dict[str, Any]) -> None
        """Insert or refresh the ``# type:`` signature comment of a function.

        A one-line body with a blank leading prefix is first expanded into an
        indented block. ``self.use_py2_long_form`` decides between the short
        ``(args) -> ret`` form and the ``(...) -> ret`` form; for the latter,
        ``self.insert_long_form`` is also called. The first line of the INDENT
        token's prefix is checked:

        * identical to the new comment: nothing changes;
        * a non-type comment: pushed below the new comment;
        * an existing type comment: replaced.
        """
        suite_children = results['suite'][0].children

        # Insert '# type: {annot}' comment.
        # For reference, see lib2to3/fixes/fix_tuple_params.py in stdlib.
        if suite_children and suite_children[0].type != token.NEWLINE:
            # one liner function
            first_stmt = suite_children[0]
            if first_stmt.prefix.strip() == '':
                first_stmt.prefix = ''
                suite_children.insert(0, Leaf(token.NEWLINE, '\n'))
                suite_children.insert(
                    1, Leaf(token.INDENT, find_indentation(node) + '    '))
                suite_children.append(Leaf(token.DEDENT, ''))

        if not (len(suite_children) >= 2 and suite_children[1].type == token.INDENT):
            self.log_message("%s:%d: cannot insert annotation for one-line function" %
                             (self.filename, node.get_lineno()))
            return

        degen_str = '(...) -> %s' % restype
        short_str = '(%s) -> %s' % (', '.join(argtypes), restype)
        if self.use_py2_long_form(argtypes, short_str, degen_str):
            self.insert_long_form(node, results, argtypes)
            signature = degen_str
        else:
            signature = short_str

        indent_leaf = suite_children[1]
        first_line, newline, remainder = indent_leaf.prefix.partition('\n')
        existing = first_line.rstrip() + newline
        new_comment = '# type: %s\n' % (signature,)
        if existing == new_comment:
            # Already carries exactly this annotation.
            return

        if existing and not is_type_comment(existing):
            # push existing non-type comment to next line
            new_comment += existing

        indent_leaf.prefix = indent_leaf.value + new_comment + remainder
        indent_leaf.changed()
    def transform(self, node, results):
        """Rename a module-level function and extend its parameter list.

        Only definitions whose first leaf starts at column 0 are handled.

        * A NAME child equal to ``self.funcname`` is renamed to ``self.newname``.
        * ``self.pre_params`` are prepended and ``self.post_params`` appended
          to the parameter list.
        * If ``self.add_statement`` is set, it is appended to *node* as a raw
          leaf prefixed with the body's indentation.
        * If ``self.remove`` is set, *node* is recorded in ``self.results``
          and detached from the tree.

        :returns: *node*, or None when it was skipped or removed.
        """
        # Determine the node's column number by finding the first leaf.
        leaf = node
        while not isinstance(leaf, Leaf):
            leaf = leaf.children[0]
        # Only match functions and the global indentation level.
        if leaf.column != 0:
            return None

        indent = None
        for child in node.children:
            if isinstance(child, Node) and child.type == python_symbols.suite:
                indent = find_indentation(child)
            if (isinstance(child, Leaf) and child.type == token.NAME
                    and child.value == self.funcname):
                child.value = self.newname
            elif (isinstance(child, Node)
                  and child.type == python_symbols.parameters):
                # Assumes the lib2to3 layout '(' [params] ')': more than two
                # children means the function already declares parameters.
                had_params = len(child.children) > 2
                pre_params = []
                for param in self.pre_params:
                    pre_params.append(Leaf(token.NAME, param))
                    pre_params.append(Leaf(token.COMMA, ', '))
                post_params = []
                for param in self.post_params:
                    post_params.append(Leaf(token.COMMA, ','))
                    post_params.append(Leaf(token.NAME, param))
                if not had_params:
                    # Nothing sits between the two groups, so exactly one
                    # separator is surplus. Previously only a trailing pre
                    # comma was dropped, giving "(,b)" or "(a, ,b)" whenever
                    # post params were added to an empty parameter list.
                    if post_params:
                        post_params.pop(0)
                    elif pre_params:
                        pre_params.pop()
                # insert_child (unlike raw slice assignment) sets each new
                # leaf's parent, keeping sibling navigation consistent.
                for offset, new_leaf in enumerate(pre_params, start=1):
                    child.insert_child(offset, new_leaf)
                for new_leaf in post_params:
                    child.insert_child(len(child.children) - 1, new_leaf)
                child.changed()
        if self.add_statement:
            # NOTE(review): ``indent`` is still None when no suite child was
            # found (e.g. a one-line body); this concatenation then raises
            # TypeError.
            node.append_child(
                Leaf(0, indent + self.add_statement.rstrip() + '\n'))
        if self.remove:
            self.results.append(node)
            node.replace([])
            return None
        return node
示例#21
0
    def transform_member(self, node, results):
        """Transform for imports of specific module elements. Replaces
           the module to be imported from with the appropriate new
           module.

           Two shapes are handled:

           * a single imported member (``results["member"]``): the module
             name ``results["mod_member"]`` is replaced by the first
             ``MAPPING`` entry whose member collection contains that member;
           * several members (``results["members"]``): members are grouped
             by their new module and the statement is replaced by one
             ``FromImport`` per new module, separated by ``Newline()``.

           ``self.cannot_convert`` is called when the single member, or
           every one of several members, has no mapping.
        """
        mod_member = results.get("mod_member")
        # The module name's prefix; reused as the prefix of every name written
        # back below.
        pref = mod_member.prefix
        member = results.get("member")

        # Simple case with only a single member being imported
        if member:
            # this may be a list of length one, or just a node
            if isinstance(member, list):
                member = member[0]
            new_name = None
            # Each MAPPING entry is indexed as change[0] = new module name,
            # change[1] = collection of member names it provides; the first
            # entry containing the member wins.
            for change in MAPPING[mod_member.value]:
                if member.value in change[1]:
                    new_name = change[0]
                    break
            if new_name:
                mod_member.replace(Name(new_name, prefix=pref))
            else:
                self.cannot_convert(node, "This is an invalid module element")

        # Multiple members being imported
        else:
            # a dictionary for replacements, order matters
            modules = []   # new module names, in first-seen order
            mod_dict = {}  # new module name -> list of member nodes
            members = results["members"]
            for member in members:
                # we only care about the actual members
                if member.type == syms.import_as_name:
                    # ``name as alias``: children are [name, 'as', alias].
                    # NOTE(review): as_name is computed but never used below.
                    as_name = member.children[2].value
                    member_name = member.children[0].value
                else:
                    member_name = member.value
                    as_name = None
                # Skip the comma separators between members.
                if member_name != u",":
                    # NOTE(review): no ``break`` here, so a member listed under
                    # several new modules is imported from each of them, and a
                    # member with no entry at all is silently dropped.
                    for change in MAPPING[mod_member.value]:
                        if member_name in change[1]:
                            if change[0] not in mod_dict:
                                modules.append(change[0])
                            mod_dict.setdefault(change[0], []).append(member)

            new_nodes = []
            indentation = find_indentation(node)
            first = True

            def handle_name(name, prefix):
                # Build fresh nodes for one member with the given prefix,
                # keeping (cloned) ``as`` and alias children when present.
                if name.type == syms.import_as_name:
                    kids = [
                        Name(name.children[0].value, prefix=prefix),
                        name.children[1].clone(), name.children[2].clone()
                    ]
                    return [Node(syms.import_as_name, kids)]
                return [Name(name.value, prefix=prefix)]

            for module in modules:
                elts = mod_dict[module]
                names = []
                # Comma-separate the members destined for this module.
                for elt in elts[:-1]:
                    names.extend(handle_name(elt, pref))
                    names.append(Comma())
                names.extend(handle_name(elts[-1], pref))
                new = FromImport(module, names)
                # Statements after the first always get the original
                # indentation; the first only gets it when the parent's prefix
                # already ends with that indentation.
                if not first or node.parent.prefix.endswith(indentation):
                    new.prefix = indentation
                new_nodes.append(new)
                first = False
            if new_nodes:
                nodes = []
                # Interleave Newline() leaves between the new statements.
                for new_node in new_nodes[:-1]:
                    nodes.extend([new_node, Newline()])
                nodes.append(new_nodes[-1])
                node.replace(nodes)
            else:
                self.cannot_convert(node, "All module elements are invalid")
示例#22
0
    def transform(self, node, results):
        """Insert a ``# type:`` signature comment into an unannotated function.

        The function is skipped when any of these hold:

        * the class-wide ``FixAnnotate.counter`` is set and has reached zero;
        * any node under the ``parameters``/``args`` captures, or any child of
          the body, already has a prefix starting with ``# type:``;
        * ``self.make_annotation`` returns None.

        One-line bodies with a blank prefix are first expanded into an
        indented block. Signatures longer than 64 characters, or with more
        than 5 arguments, use the ``(...)`` form together with
        ``self.insert_long_form``. After an insertion, the counter (when set)
        is decremented and ``self.patch_imports`` receives every type used.
        """
        counter = FixAnnotate.counter
        if counter is not None and counter <= 0:
            return

        def already_typed(tree_node):
            return tree_node.prefix.lstrip().startswith('# type:')

        # Check if there's already a long-form annotation for some argument.
        for capture_name in ('parameters', 'args'):
            captured = results.get(capture_name)
            if captured is not None and any(
                    already_typed(ch) for ch in captured.pre_order()):
                return

        children = results['suite'][0].children

        # Observed body layout: [0] NEWLINE, [1] INDENT (its prefix holds any
        # comments placed before the body), [2..n-2] statements (the first
        # may be a docstring), [n-1] DEDENT. "Compact" functions such as
        # "def foo(x, y): return max(x, y)" have no NEWLINE/INDENT/DEDENT.

        if any(already_typed(ch) for ch in children):
            return  # There's already a # type: comment here; don't change anything.

        annot = self.make_annotation(node, results)
        if annot is None:
            return

        # Insert '# type: {annot}' comment.
        # For reference, see lib2to3/fixes/fix_tuple_params.py in stdlib.
        if (children and children[0].type != token.NEWLINE
                and children[0].prefix.strip() == ''):
            children[0].prefix = ''
            children.insert(0, Leaf(token.NEWLINE, '\n'))
            children.insert(
                1, Leaf(token.INDENT, find_indentation(node) + '    '))
            children.append(Leaf(token.DEDENT, ''))

        if len(children) < 2 or children[1].type != token.INDENT:
            self.log_message(
                "%s:%d: cannot insert annotation for one-line function" %
                (self.filename, node.get_lineno()))
            return

        argtypes, restype = annot
        degen_str = '(...) -> %s' % restype
        short_str = '(%s) -> %s' % (', '.join(argtypes), restype)
        needs_long_form = len(short_str) > 64 or len(argtypes) > 5
        if needs_long_form and len(short_str) > len(degen_str):
            self.insert_long_form(node, results, argtypes)
            annot_str = degen_str
        else:
            annot_str = short_str

        indent_leaf = children[1]
        indent_leaf.prefix = '%s# type: %s\n%s' % (
            indent_leaf.value, annot_str, indent_leaf.prefix)
        indent_leaf.changed()
        if FixAnnotate.counter is not None:
            FixAnnotate.counter -= 1

        # Also add 'from typing import Any' at the top if needed.
        self.patch_imports(argtypes + [restype], node)
示例#23
0
    def transform(self, node, results):
        """Rewrite an equality-helper call (e.g. ``eq_(a, b)``) as an assertion.

        The first two positional arguments become ``a == b``, or ``a is b``
        when ``b`` is the literal ``None``, ``True`` or ``False``. A third
        positional argument (a custom message) and any keyword arguments are
        ignored. When the call is the value of a ``return`` statement (a
        custom helper doing ``return eq_(...)``), only the comparison is
        produced so the helper keeps returning a value. Otherwise the
        ``assert`` keyword is prepended.

        :returns: a list of replacement nodes, or None (node left unchanged)
            when the call does not have two or three positional arguments.
        """
        posargs = []
        # Keyword arguments are stored by name, so this must be a dict; it
        # used to be a list, which raised TypeError for any keyword argument.
        kwargs = {}

        def process_arg(arg):
            if isinstance(arg, Leaf) and arg.type == token.COMMA:
                # Separators carry no information.
                return
            elif isinstance(arg, Node) and arg.type == syms.argument:
                # keyword argument: NAME '=' value
                name, equal, value = arg.children
                assert name.type == token.NAME
                assert equal.type == token.EQUAL
                value = value.clone()
                value.prefix = " "
                kwargs[name.value] = value
            else:
                assert not kwargs, 'all positional args are assumed to come first'
                posargs.append(arg.clone())

        # custom helper with `return eq_(...)`: we don't render the `assert`
        # in that case, so the code continues functioning.
        custom_helper = node.parent.type == syms.return_stmt

        if results['arglist'].type == syms.arglist:
            for arg in results['arglist'].children:
                process_arg(arg)
        else:
            process_arg(results['arglist'])

        if len(posargs) not in (2, 3):
            # Unsupported call shape: leave the node alone instead of failing
            # with UnboundLocalError on ``left``/``right`` as before.
            return None
        # Ignore customized assert messages (third positional arg) for now.
        left, right = posargs[0], posargs[1]

        left.prefix = " "
        right.prefix = " "

        strip_newlines(left)
        strip_newlines(right)

        # Compare against singletons by identity, everything else by equality.
        is_singleton = (isinstance(right, Leaf)
                        and right.value in ('None', 'True', 'False'))
        op = Name('is' if is_singleton else '==', prefix=' ')
        body = [Node(syms.comparison, (left, op, right))]

        if custom_helper:
            return body

        ret = Name('assert')
        indent = find_indentation(node)
        if node.parent.prefix.endswith(indent):
            ret.prefix = indent

        return [ret] + body
示例#24
0
def is_top_level(node):
    """Is node at top indentation level (i.e. module globals)?

    Returns True when ``find_indentation(node)`` is empty, i.e. the node is
    not nested inside any indented block.
    """
    # An empty indentation string means module level; the former
    # ``bool(len(...))`` returned the opposite of what the docstring promises.
    return len(find_indentation(node)) == 0
示例#25
0
def is_top_level(node):
    """Tell whether *node* lives at module level, i.e. has no indentation."""
    indentation = find_indentation(node)
    return len(indentation) == 0
    def transform(self, node, results):
        """Rewrite a ``self.assert*(...)`` call as a pytest-style statement.

        The call's arguments are split into positional and keyword arguments,
        matched against the signature of the same-named ``unittest.TestCase``
        method (after resolving aliases via ``_method_aliases``) with
        ``utils.resolve_func_args``, and handed to the converter registered in
        ``_method_map``.  ``assertRaises*``/``assertWarns*`` converters build
        the whole statement themselves; every other method becomes
        ``assert <expr>``.  If a ``msg`` argument was given, ``, <msg>`` is
        appended to the statement.  ``pytest``/``re`` imports are added as
        needed.

        :param node: the matched statement node being replaced.
        :param results: pattern captures; ``'method'`` and ``'arglist'`` are read.
        :returns: the replacement node, carrying *node*'s prefix.
        :raises RuntimeError: if ``unittest.TestCase`` has no such method.
        """
        def process_arg(arg):
            # Sort one call argument into ``posargs``/``kwargs``; both are
            # bound in the enclosing scope before this helper is first called.
            if isinstance(arg, Leaf) and arg.type == token.COMMA:
                return
            elif (isinstance(arg, Node) and arg.type == syms.argument
                  and arg.children[1].type == token.EQUAL):
                # keyword argument ``name=value``
                name, equal, value = arg.children
                assert name.type == token.NAME
                assert equal.type == token.EQUAL
                value = value.clone()
                kwargs[name.value] = value
                # Keep a line-breaking prefix as is; otherwise use one space.
                if '\n' in arg.prefix:
                    value.prefix = arg.prefix
                else:
                    value.prefix = arg.prefix.strip() + " "
            else:
                if (isinstance(arg, Node) and arg.type == syms.argument
                        and arg.children[0].type == token.DOUBLESTAR
                        and arg.children[0].value == '**'):
                    # ``**kwargs`` cannot be resolved statically: ignore it.
                    return
                assert not kwargs, 'all positional args are assumed to come first'
                if (isinstance(arg, Node) and arg.type == syms.argument
                        and arg.children[1].type == syms.comp_for):
                    # argument is a generator expression w/o
                    # parenthesis, add parenthesis
                    value = arg.clone()
                    value.children.insert(0, Leaf(token.LPAR, '('))
                    value.children.append(Leaf(token.RPAR, ')'))
                    posargs.append(value)
                else:
                    posargs.append(arg.clone())

        method = results['method'][0].value
        # map (deprecated) aliases to original to avoid analysing
        # the decorator function
        method = _method_aliases.get(method, method)

        posargs = []
        kwargs = {}

        # This is either a "arglist" or a single argument
        if results['arglist'].type == syms.arglist:
            for arg in results['arglist'].children:
                process_arg(arg)
        else:
            process_arg(results['arglist'])

        try:
            test_func = getattr(unittest.TestCase, method)
        except AttributeError:
            raise RuntimeError("Your unittest package does not support '%s'. "
                               "consider updating the package" % method)

        required_args, argsdict = utils.resolve_func_args(
            test_func, posargs, kwargs)

        if method.startswith(('assertRaises', 'assertWarns')):
            n_stmt = _method_map[method](*required_args,
                                         indent=find_indentation(node),
                                         kws=argsdict,
                                         arglist=results['arglist'],
                                         node=node)
        else:
            n_stmt = Node(syms.assert_stmt, [
                Name('assert'), _method_map[method](*required_args,
                                                    kws=argsdict)
            ])
        if argsdict.get('msg', None) is not None:
            n_stmt.children.extend((Name(','), argsdict['msg']))

        def fix_line_wrapping(x):
            # Line breaks outside brackets need a backslash continuation to
            # remain valid Python; breaks inside brackets are left alone.
            for c in x.children:
                # no need to worry about wrapping of "[", "{" and "("
                if c.type in [token.LSQB, token.LBRACE, token.LPAR]:
                    break
                if c.prefix.startswith('\n'):
                    c.prefix = c.prefix.replace('\n', ' \\\n')
                fix_line_wrapping(c)

        fix_line_wrapping(n_stmt)
        # the prefix should be set only after fixing line wrapping because it can contain a '\n'
        n_stmt.prefix = node.prefix

        # add necessary imports
        if 'Raises' in method or 'Warns' in method:
            add_import('pytest', node)
        if ('Regex' in method and 'Raises' not in method
                and 'Warns' not in method):
            add_import('re', node)

        return n_stmt
    def transform(self, node, results):
        """Rewrite a ``self.assert*(...)`` call as a pytest-style statement.

        The call's arguments are split into positional and keyword arguments,
        matched against the signature of the same-named ``unittest.TestCase``
        method (after resolving aliases via ``_method_aliases``) with
        ``utils.resolve_func_args``, and handed to the converter registered in
        ``_method_map``.  ``assertRaises*``/``assertWarns*`` converters build
        the whole statement themselves; every other method becomes
        ``assert <expr>``.  If a ``msg`` argument was given, ``, <msg>`` is
        appended to the statement.  ``pytest`` is imported for Raises/Warns
        methods and ``re`` for any ``*Regex*`` method.

        :param node: the matched statement node being replaced.
        :param results: pattern captures; ``'method'`` and ``'arglist'`` are read.
        :returns: the replacement node, carrying *node*'s prefix.
        :raises RuntimeError: if ``unittest.TestCase`` has no such method.
        """

        def process_arg(arg):
            # Sort one call argument into ``posargs``/``kwargs``; both are
            # bound in the enclosing scope before this helper is first called.
            if isinstance(arg, Leaf) and arg.type == token.COMMA:
                return
            elif (isinstance(arg, Node) and arg.type == syms.argument
                  and arg.children[1].type == token.EQUAL):
                # keyword argument ``name=value``; other ``argument`` nodes
                # (``*x``, ``**x``, bare generators) have only two children
                # and must not be unpacked as name/equal/value.
                name, equal, value = arg.children
                assert name.type == token.NAME
                assert equal.type == token.EQUAL
                value = value.clone()
                kwargs[name.value] = value
                # Keep a line-breaking prefix as is; otherwise use one space.
                if '\n' in arg.prefix:
                    value.prefix = arg.prefix
                else:
                    value.prefix = arg.prefix.strip() + " "
            else:
                if (isinstance(arg, Node) and arg.type == syms.argument
                        and arg.children[0].type == token.DOUBLESTAR
                        and arg.children[0].value == '**'):
                    # ``**kwargs`` cannot be resolved statically: ignore it.
                    return
                assert not kwargs, 'all positional args are assumed to come first'
                if (isinstance(arg, Node) and arg.type == syms.argument
                        and arg.children[1].type == syms.comp_for):
                    # argument is a generator expression w/o
                    # parenthesis, add parenthesis
                    value = arg.clone()
                    value.children.insert(0, Leaf(token.LPAR, '('))
                    value.children.append(Leaf(token.RPAR, ')'))
                    posargs.append(value)
                else:
                    posargs.append(arg.clone())

        method = results['method'][0].value
        # map (deprecated) aliases to original to avoid analysing
        # the decorator function
        method = _method_aliases.get(method, method)

        posargs = []
        kwargs = {}

        # This is either a "arglist" or a single argument
        if results['arglist'].type == syms.arglist:
            for arg in results['arglist'].children:
                process_arg(arg)
        else:
            process_arg(results['arglist'])

        try:
            test_func = getattr(unittest.TestCase, method)
        except AttributeError:
            raise RuntimeError("Your unittest package does not support '%s'. "
                               "consider updating the package" % method)

        required_args, argsdict = utils.resolve_func_args(test_func, posargs, kwargs)

        if method.startswith(('assertRaises', 'assertWarns')):
            n_stmt = _method_map[method](*required_args,
                                         indent=find_indentation(node),
                                         kws=argsdict,
                                         arglist=results['arglist'],
                                         node=node)
        else:
            n_stmt = Node(syms.assert_stmt,
                          [Name('assert'),
                           _method_map[method](*required_args, kws=argsdict)])
        if argsdict.get('msg', None) is not None:
            n_stmt.children.extend((Name(','), argsdict['msg']))

        def fix_line_wrapping(x):
            # Line breaks outside brackets need a backslash continuation to
            # remain valid Python; breaks inside brackets are left alone.
            for c in x.children:
                # no need to worry about wrapping of "[", "{" and "("
                if c.type in [token.LSQB, token.LBRACE, token.LPAR]:
                    break
                if c.prefix.startswith('\n'):
                    c.prefix = c.prefix.replace('\n', ' \\\n')
                fix_line_wrapping(c)
        fix_line_wrapping(n_stmt)
        # the prefix should be set only after fixing line wrapping because it can contain a '\n'
        n_stmt.prefix = node.prefix

        # add necessary imports
        if 'Raises' in method or 'Warns' in method:
            add_import('pytest', node)
        if 'Regex' in method:
            add_import('re', node)

        return n_stmt
示例#28
0
    def transform(self, node: LN, capture: Capture) -> None:
        """Replace a ``from <module> import ...`` statement with renamed imports.

        Every imported name is mapped with ``get_new_name``.  Names that end up
        in a package are regrouped into one ``from <package> import`` statement
        per package (packages sorted by name); names without a package become
        ``import <name>`` statements.  A name whose new name differs keeps the
        old name as its alias.  Names for which ``NameRemovedError`` is raised
        are skipped with a warning, and ``*`` imports are kept but flagged.  If
        nothing is left to import, *node* is left untouched.
        """
        module_name = traverse_dotted_name(capture["module_name"])

        # Work on a copy: nested ``import_as_names`` nodes are flattened by
        # appending their children to this worklist while it is iterated, and
        # that must not mutate the caller's capture.
        to_process = list(capture["module_imports"])
        imports: List[Tuple[str, Optional[str]]] = []
        for n in to_process:
            if n.type in (token.COMMA, token.LPAR, token.RPAR):
                continue
            elif n.type == token.STAR:
                self.warn(n, "Cannot guarantee * imports are correct.")
                imports.append(("*", None))
            elif n.type == token.NAME:
                imports.append((n.value, None))
            elif n.type == syms.import_as_name:
                import_name, import_nick = n.children[0], n.children[2]
                assert n.children[1].type == token.NAME and n.children[
                    1].value == "as"
                imports.append((import_name.value,
                                import_nick.value if import_nick else None))
            elif n.type == syms.import_as_names:
                to_process.extend(n.children)
            else:
                raise Exception(f"unexpected node {repr(n)}")

        imports_by_package: DefaultDict[str, List[Tuple[
            str, Optional[str]]]] = collections.defaultdict(list)
        for name, nick in imports:
            full_name = f"{module_name}.{name}"
            try:
                new_full_name = get_new_name(full_name) or full_name
            except NameRemovedError as exc:
                self.warn(node, str(exc))
                continue
            package, new_name = split_package_and_name(new_full_name)
            # Alias a renamed object to its old name so existing references
            # in the module keep working.
            if name != new_name and nick is None:
                nick = name
            imports_by_package[package].append((new_name, nick))

        indent = find_indentation(node)
        nodes = []
        for package, package_imports in sorted(imports_by_package.items(),
                                               key=lambda i: i[0]):
            if package:
                result = FromImport(package, package_imports,
                                    prefix=f"\n{indent}")
                nodes.append(result)
            else:
                for name, nick in package_imports:
                    nodes.append(
                        Node(
                            syms.import_name,
                            [
                                Name("import"),
                                ImportAsName(name, nick, prefix=" ")
                            ],
                            prefix=f"\n{indent}",
                        ))

        if not nodes:
            return

        # The first replacement inherits the original statement's prefix.
        nodes[0].prefix = node.prefix
        node.replace(nodes)
示例#29
0
    def transform_member(self, node, results):
        """Transform for imports of specific module elements. Replaces
           the module to be imported from with the appropriate new
           module.

           Each imported member is looked up in ``MAPPING[<module>]``; members
           moving to the same new module are grouped into one
           ``from <module> import ...`` statement, in first-seen order.  A
           member without any mapping cannot be placed in a new module; it is
           reported through ``cannot_convert`` (and is still left out of the
           rewritten import).
        """
        mod_member = results.get("mod_member")
        pref = mod_member.prefix
        member = results.get("member")

        # Simple case with only a single member being imported
        if member:
            # this may be a list of length one, or just a node
            if isinstance(member, list):
                member = member[0]
            new_name = None
            for change in MAPPING[mod_member.value]:
                if member.value in change[1]:
                    new_name = change[0]
                    break
            if new_name:
                mod_member.replace(Name(new_name, prefix=pref))
            else:
                self.cannot_convert(node, "This is an invalid module element")

        # Multiple members being imported
        else:
            # ``modules`` keeps first-seen order, ``mod_dict`` groups members
            modules = []
            mod_dict = {}
            members = results["members"]
            for member in members:
                # we only care about the actual members
                if member.type == syms.import_as_name:
                    member_name = member.children[0].value
                else:
                    member_name = member.value
                if member_name == ",":
                    continue
                found = False
                for change in MAPPING[mod_member.value]:
                    if member_name in change[1]:
                        found = True
                        if change[0] not in mod_dict:
                            modules.append(change[0])
                        mod_dict.setdefault(change[0], []).append(member)
                if not found:
                    # Unmapped members would otherwise vanish without notice.
                    self.cannot_convert(
                        node, "%s is an invalid module element" % member_name)

            new_nodes = []
            indentation = find_indentation(node)
            first = True

            def handle_name(name, prefix):
                # Rebuild a member (NAME or ``x as y``) with the given prefix.
                if name.type == syms.import_as_name:
                    kids = [Name(name.children[0].value, prefix=prefix),
                            name.children[1].clone(),
                            name.children[2].clone()]
                    return [Node(syms.import_as_name, kids)]
                return [Name(name.value, prefix=prefix)]

            for module in modules:
                elts = mod_dict[module]
                names = []
                for elt in elts[:-1]:
                    names.extend(handle_name(elt, pref))
                    names.append(Comma())
                names.extend(handle_name(elts[-1], pref))
                new = FromImport(module, names)
                # Indent every statement after the first; the first only when
                # the original statement's parent prefix ends with the indent.
                if not first or node.parent.prefix.endswith(indentation):
                    new.prefix = indentation
                new_nodes.append(new)
                first = False
            if new_nodes:
                nodes = []
                for new_node in new_nodes[:-1]:
                    nodes.extend([new_node, Newline()])
                nodes.append(new_nodes[-1])
                node.replace(nodes)
            else:
                self.cannot_convert(node, "All module elements are invalid")
示例#30
0
def f(x):

    def g(y):
        return y + y
    return g(x + 1)
"""


def Decorator(name):
    """Build a bare ``@name`` decorator as a ``decorator`` syntax-tree node.

    The children are the ``@`` token, the decorator name and a trailing
    newline, so the node can be inserted directly before a ``def``.
    """
    # Use the grammar symbol rather than the literal 278: symbol numbers are
    # assigned from the grammar and can shift when the grammar changes.
    # Assumes ``u`` is lib2to3/fissix ``fixer_util``, which exposes ``syms``.
    return u.Node(
        u.syms.decorator, [u.Leaf(token.AT, "@", prefix=None),
                           u.Name(name),
                           u.Newline()])


def insert_before(node, new_node):
    """Insert *new_node* into *node*'s parent directly before *node*.

    Returns True on success, and False when *node* has no parent or is not
    among its parent's children.
    """
    parent = node.parent
    if parent is None:
        return False
    # Compare by identity: syntax-tree ``==`` is structural, so an earlier,
    # identical-looking sibling would otherwise be mistaken for *node*.
    for i, child in enumerate(parent.children):
        if child is node:
            parent.insert_child(i, new_node)
            return True
    return False


t = parse_string(code)
# Decorate only the first definition reported by run(), then stop.
for defs in run(t):
    target_def = defs[0]
    insert_before(target_def, Decorator("profile"))
    # The def now follows the decorator's newline; when it has no prefix of
    # its own, give it its indentation back.
    if not target_def.prefix:
        target_def.prefix = u.find_indentation(target_def)
    break
print(t)
示例#31
0
    def transform_member(self, node, results):
        """Transform for imports of specific module elements. Replaces
           the module to be imported from with the appropriate new
           module.

           Each imported member is looked up in ``MAPPING[<module>]``; members
           moving to the same new module are grouped into one
           ``from <module> import ...`` statement, in first-seen order.  A
           member without any mapping cannot be placed in a new module; it is
           reported through ``cannot_convert`` (and is still left out of the
           rewritten import).
        """
        mod_member = results.get('mod_member')
        pref = mod_member.prefix
        member = results.get('member')
        if member:
            # single member: may be a list of length one, or just a node
            if isinstance(member, list):
                member = member[0]
            new_name = None
            for change in MAPPING[mod_member.value]:
                if member.value in change[1]:
                    new_name = change[0]
                    break

            if new_name:
                mod_member.replace(Name(new_name, prefix=pref))
            else:
                self.cannot_convert(node, 'This is an invalid module element')
        else:
            # ``modules`` keeps first-seen order, ``mod_dict`` groups members
            modules = []
            mod_dict = {}
            members = results['members']
            for member in members:
                if member.type == syms.import_as_name:
                    member_name = member.children[0].value
                else:
                    member_name = member.value
                if member_name == u',':
                    continue
                found = False
                for change in MAPPING[mod_member.value]:
                    if member_name in change[1]:
                        found = True
                        if change[0] not in mod_dict:
                            modules.append(change[0])
                        mod_dict.setdefault(change[0], []).append(member)
                if not found:
                    # Unmapped members would otherwise vanish without notice.
                    self.cannot_convert(
                        node, '%s is an invalid module element' % member_name)

            new_nodes = []
            indentation = find_indentation(node)
            first = True

            def handle_name(name, prefix):
                # Rebuild a member (NAME or ``x as y``) with the given prefix.
                if name.type == syms.import_as_name:
                    kids = [Name(name.children[0].value, prefix=prefix), name.children[1].clone(), name.children[2].clone()]
                    return [Node(syms.import_as_name, kids)]
                return [Name(name.value, prefix=prefix)]

            for module in modules:
                elts = mod_dict[module]
                names = []
                for elt in elts[:-1]:
                    names.extend(handle_name(elt, pref))
                    names.append(Comma())

                names.extend(handle_name(elts[-1], pref))
                new = FromImport(module, names)
                # Indent every statement after the first; the first only when
                # the original statement's parent prefix ends with the indent.
                if not first or node.parent.prefix.endswith(indentation):
                    new.prefix = indentation
                new_nodes.append(new)
                first = False

            if new_nodes:
                nodes = []
                for new_node in new_nodes[:-1]:
                    nodes.extend([new_node, Newline()])

                nodes.append(new_nodes[-1])
                node.replace(nodes)
            else:
                self.cannot_convert(node, 'All module elements are invalid')
        return
示例#32
0
    def transform(self, node, results):
        """Rewrite a nose ``eq_(a, b[, msg])`` call as ``assert a == b``.

        ``is`` replaces ``==`` when *b* is the bare name ``None``, ``True`` or
        ``False``.  When the call is the operand of a ``return`` statement
        (a custom helper such as ``return eq_(...)``), only the comparison is
        returned, so the helper keeps returning a value.  A message argument,
        positional or keyword, is accepted but ignored.

        Returns ``None`` -- leaving the call unchanged -- when the arguments do
        not have that shape: not two or three positional arguments, or a
        ``*args``/``**kwargs``/bare-generator argument.
        """
        def process_arg(arg):
            # Sort one call argument into ``posargs``/``kwargs`` (bound in the
            # enclosing scope before the first call).  Returns False for an
            # argument this fixer cannot rewrite.
            if isinstance(arg, Leaf) and arg.type == token.COMMA:
                return True
            if isinstance(arg, Node) and arg.type == syms.argument:
                if arg.children[1].type != token.EQUAL:
                    # ``*args``, ``**kwargs`` or a bare generator expression
                    return False
                # keyword argument ``name=value``
                name, equal, value = arg.children
                assert name.type == token.NAME
                value = value.clone()
                value.prefix = " "
                kwargs[name.value] = value
                return True
            assert not kwargs, 'all positional args are assumed to come first'
            posargs.append(arg.clone())
            return True

        custom_helper = False

        if node.parent.type == syms.return_stmt:
            # custom helper with `return eq_(...)`
            # We're not rendering the `assert` in that case
            # to allow the code to continue functioning
            custom_helper = True

        posargs = []
        # keyword name -> cloned value node (a list here made any keyword
        # argument raise TypeError)
        kwargs = {}

        if results['arglist'].type == syms.arglist:
            arguments = results['arglist'].children
        else:
            arguments = [results['arglist']]
        for arg in arguments:
            if not process_arg(arg):
                return None

        if len(posargs) not in (2, 3):
            # Not the ``eq_(a, b[, msg])`` shape; leave the call alone instead
            # of failing on unbound ``left``/``right``.
            return None
        # A third positional argument is the assertion message (ignored).
        left, right = posargs[0], posargs[1]

        left.prefix = " "
        right.prefix = " "

        strip_newlines(left)
        strip_newlines(right)

        # Ignore customized assert messages for now
        if isinstance(right, Leaf) and right.value in ('None', 'True', 'False'):
            op = Name('is', prefix=' ')
            body = [Node(syms.comparison, (left, op, right))]
        else:
            op = Name('==', prefix=' ')
            body = [Node(syms.comparison, (left, op, right))]

        indent = find_indentation(node)

        ret = Name('assert')

        if node.parent.prefix.endswith(indent):
            ret.prefix = indent

        if custom_helper:
            return body

        return [ret] + body
示例#33
0
def is_top_level(node):
    """Return whether *node* is at module level (zero indentation)."""
    return not len(find_indentation(node))
示例#34
0
    def fix_leaves(self, node_to_split):
        """Re-flow *node_to_split* over several lines and swap in the result.

        The node's leaves (except NEWLINE tokens) are collected, the prefixes
        of non-first leaves that contained a line break are collapsed to ""
        or " ", and ``wrap_leaves`` splits them into lines with
        ``width=MAX_CHARS - 1`` and continuation lines starting with
        ``new_indent``.  A new node of the same type is built from those lines
        (NEWLINE leaves in between) and replaces *node_to_split*.  If any line
        break falls outside brackets, ``self.parenthesize_parent`` is called
        on the new node.

        Returns the "#" comments found in the prefixes of all but the first
        leaf, merged into one string that starts with ``new_indent + "#"``,
        or "" when there were none.  NOTE(review): presumably the caller
        re-attaches this, since the rewritten prefixes drop the comments --
        confirm against callers.
        """
        parent_depth = find_indentation(node_to_split)
        new_indent = "%s%s" % (' ' * 4, parent_depth)
        # For now, just indent additional lines by 4 more spaces
        # (the 4 spaces are placed in front of the parent's indentation).

        child_leaves = []
        combined_prefix = ""
        prev_leaf = None
        for index, leaf in enumerate(node_to_split.leaves()):
            # Gather comments before prefixes are overwritten below; of each
            # prefix only the text after its last "#" is kept.
            if index and leaf.prefix.count('#'):
                if not combined_prefix:
                    combined_prefix = "%s#" % new_indent
                combined_prefix += leaf.prefix.split('#')[-1]

            # We want to strip all newlines so we can properly insert newlines
            # where they should be
            if leaf.type != token.NEWLINE:
                if leaf.prefix.count('\n') and index:
                    # If the line contains a newline, we need to strip all
                    # whitespace since there were leading indent spaces.
                    # No space after "." or "(" and none before ")"; one
                    # space otherwise.  (Parsed as
                    # ``(prev_leaf and prev_leaf.type in ...) or leaf is RPAR``.)
                    if (prev_leaf and prev_leaf.type in [token.DOT, token.LPAR]
                        or leaf.type in [token.RPAR]):
                        leaf.prefix = ""
                    else:
                        leaf.prefix = " "

                    # Any inline comment in the replaced prefix was already
                    # appended to combined_prefix at the top of the loop.
                child_leaves.append(leaf)
                prev_leaf = leaf

        # Like TextWrapper, but for nodes. We split on MAX_CHARS - 1 since we
        # may need to insert a leading parenth. It's not great, but it would be
        # hard to do properly.
        split_leaves = wrap_leaves(child_leaves, width=MAX_CHARS - 1,
            subsequent_indent=new_indent)
        new_node = Node(node_to_split.type, [])

        # We want to keep track of if we are breaking inside a parenth
        open_count = 0
        need_parens = False
        for line_index, curr_line_nodes in enumerate(split_leaves):
            for node_index, curr_line_node in enumerate(curr_line_nodes):
                if line_index and not node_index:
                    # If first node in non-first line, reset prefix since there
                    # may have been spaces previously
                    curr_line_node.prefix = new_indent
                new_node.append_child(curr_line_node)
                if curr_line_node.type in OPENING_TOKENS:
                    open_count += 1
                if curr_line_node.type in CLOSING_TOKENS:
                    open_count -= 1

            if line_index != len(split_leaves) - 1:
                # Don't add newline at the end since it is part of the next
                # sibling
                new_node.append_child(Leaf(token.NEWLINE, '\n'))

                # Checks if we ended a line without being surrounded by parens
                if open_count <= 0:
                    need_parens = True
        if need_parens:
            # Parenthesize the parent if we're not inside parenths, braces,
            # brackets, since we inserted newlines between leaves.
            # The membership test presumably relies on Leaf equality (type and
            # value), i.e. "does the first line contain an '=' token?".
            parenth_before_equals = Leaf(token.EQUAL, "=") in split_leaves[0]
            self.parenthesize_parent(new_node, parenth_before_equals)
        node_to_split.replace(new_node)

        return combined_prefix