Esempio n. 1
0
 def leave_AnnAssign(
     self, original_node: cst.AnnAssign, updated_node: cst.AnnAssign
 ) -> Union[cst.BaseSmallStatement, cst.RemovalSentinel]:
     """Replace an annotated assignment with a plain, un-annotated assignment.

     A declaration that carries no value (e.g. ``foo: str`` or
     ``self.foo: str``) becomes ``foo = ...`` so that the resulting
     ``Assign`` node always has a value. Otherwise the original value is
     reused unchanged.

     Args:
         original_node: The annotated assignment; its target and value are
             the ones copied into the result.
         updated_node: Unused.

     Returns:
         A ``cst.Assign`` with a single target and no annotation.
     """
     # NOTE(review): target and value are read from ``original_node``, so any
     # edits made to child nodes during traversal are not carried over --
     # confirm this is intended.
     value = original_node.value
     if value is None:
         # Applies to every kind of target (names, attributes, subscripts...):
         # a value-less declaration must not yield an Assign whose value is None.
         value = cst.Ellipsis()
     return cst.Assign(
         targets=[cst.AssignTarget(target=original_node.target)],
         value=value)
class StripStringsCommand(VisitorBasedCodemodCommand):
    """Codemod that unquotes string type annotations.

    Each string literal found inside an annotation is replaced by the
    expression it spells out, and ``__future__.annotations`` is registered as
    a needed import. Strings that sit inside a ``Literal[...]`` subscript are
    left untouched.
    """

    DESCRIPTION: str = (
        "Converts string type annotations to 3.7-compatible forward references."
    )

    METADATA_DEPENDENCIES = (QualifiedNameProvider, )

    # We want to gate the SimpleString visitor below to only SimpleStrings inside
    # an Annotation.
    @m.call_if_inside(m.Annotation())
    # We also want to gate the SimpleString visitor below to ensure that we don't
    # erroneously strip strings from a Literal.
    @m.call_if_not_inside(
        m.Subscript(
            # We could match on value=m.Name("Literal") here, but then we might miss
            # instances where people are importing typing_extensions directly, or
            # importing Literal as an alias. Both the standard-library and the
            # typing_extensions spelling of Literal are recognised.
            value=m.MatchMetadataIfTrue(
                QualifiedNameProvider,
                lambda qualnames: any(
                    qualname.name in ("typing.Literal",
                                      "typing_extensions.Literal")
                    for qualname in qualnames),
            )))
    def leave_SimpleString(
        self, original_node: libcst.SimpleString,
        updated_node: libcst.SimpleString
    ) -> Union[libcst.SimpleString, libcst.BaseExpression]:
        """Replace a quoted annotation with the expression it contains.

        Args:
            original_node: The string node before traversal (unused).
            updated_node: The string node whose evaluated value is parsed.

        Returns:
            The expression parsed from the string's evaluated value.
        """
        AddImportsVisitor.add_needed_import(self.context, "__future__",
                                            "annotations")
        # Just use LibCST to evaluate the expression itself, and insert that as the
        # annotation.
        return parse_expression(updated_node.evaluated_value,
                                config=self.module.config_for_parsing)
Esempio n. 3
0
 def visit_Annotation(self, node: cst.Annotation):
     """Record that an annotation worth noting has been visited.

     ``self.type_annot_visited`` is set to ``True`` unless the annotation is
     the bare name ``None`` or an empty list literal.
     """
     none_name = match.Name(value='None')
     empty_list = match.List(elements=[])
     ignorable = match.Annotation(
         annotation=match.OneOf(none_name, empty_list))
     if match.matches(node, ignorable):
         return
     self.type_annot_visited = True
Esempio n. 4
0
class DoubleQuoteForwardRefsTransformer(m.MatcherDecoratableTransformer):
    """Rewrite single-quoted strings inside annotations with double quotes."""

    @m.call_if_inside(m.Annotation())
    def leave_SimpleString(self, original_node: cst.SimpleString,
                           updated_node: cst.SimpleString) -> cst.SimpleString:
        """Return the string re-quoted with ``"`` when that is safe.

        Args:
            original_node: The string node before traversal (unused).
            updated_node: The string node to re-quote.

        Returns:
            A copy of ``updated_node`` using double quotes, or ``updated_node``
            itself when the swap could change the string's meaning.
        """
        # For prettiness, convert all single-quoted forward refs to double-quoted.
        raw = updated_node.value
        single_quoted = (len(raw) >= 2 and raw.startswith("'")
                         and raw.endswith("'"))
        # Swapping only the outermost characters is wrong for triple-quoted
        # strings (the inner quote pairs would become part of the text) and for
        # bodies containing a double quote (it would end the new literal early).
        if single_quoted and not raw.startswith("'''") and '"' not in raw:
            return updated_node.with_changes(value=f'"{raw[1:-1]}"')
        return updated_node
Esempio n. 5
0
class DoubleQuoteForwardRefsTransformer(m.MatcherDecoratableTransformer):
    """Rewrite single-quoted strings inside annotations with double quotes."""

    @m.call_if_inside(m.Annotation())
    def leave_SimpleString(self, original_node: cst.SimpleString,
                           updated_node: cst.SimpleString) -> cst.SimpleString:
        """Return the string re-quoted with ``"`` when its value is unchanged.

        Args:
            original_node: The string node before traversal (unused).
            updated_node: The string node to re-quote.

        Returns:
            A copy of ``updated_node`` using double quotes when the re-quoted
            source evaluates to the same value; otherwise ``updated_node``.
        """
        # For prettiness, convert all single-quoted forward refs to double-quoted.
        if "'" in updated_node.quote:
            new_value = f'"{updated_node.value[1:-1]}"'
            try:
                if updated_node.evaluated_value == ast.literal_eval(new_value):
                    return updated_node.with_changes(value=new_value)
            except (SyntaxError, ValueError):
                # SyntaxError: the candidate is not valid source, e.g. the body
                # holds an unescaped double quote.
                # ValueError: the candidate parses, but not as a literal
                # (e.g. 'a" + "b' turns into the expression "a" + "b").
                # Either way the original quoting is kept.
                pass
        return updated_node
class StripStringsCommand(VisitorBasedCodemodCommand):
    """Codemod that unquotes string type annotations.

    Each string literal found inside an annotation is replaced by the
    expression it spells out, and ``__future__.annotations`` is registered as
    a needed import. Strings inside a ``Literal[...]`` subscript are skipped.
    """

    DESCRIPTION: str = "Converts string type annotations to 3.7-compatible forward references."

    @m.call_if_inside(m.Annotation())
    # Skip strings that are Literal arguments, whether Literal is written bare
    # or reached through an attribute such as ``typing.Literal``.
    @m.call_if_not_inside(
        m.Subscript(m.Name("Literal") | m.Attribute(attr=m.Name("Literal")))
    )
    def leave_SimpleString(
        self, original_node: libcst.SimpleString, updated_node: libcst.SimpleString
    ) -> Union[libcst.SimpleString, libcst.BaseExpression]:
        """Replace a quoted annotation with the expression it contains.

        Args:
            original_node: The string node before traversal (unused).
            updated_node: The string node whose source text is evaluated and
                then parsed as an expression.

        Returns:
            The expression parsed from the evaluated string.
        """
        AddImportsVisitor.add_needed_import(self.context, "__future__", "annotations")
        return parse_expression(
            literal_eval(updated_node.value), config=self.module.config_for_parsing
        )
 def contains_union_with_none(self, node: cst.Annotation) -> bool:
     """Tell whether ``node`` is a two-element ``Union[...]`` involving ``None``.

     Matches an annotation that is a subscript of the bare name ``Union``
     with exactly two elements, one of which is the name ``None`` (in either
     position).
     """
     any_element = m.SubscriptElement(m.Index())
     none_element = m.SubscriptElement(m.Index(m.Name("None")))
     two_element_union = m.Subscript(
         value=m.Name("Union"),
         slice=m.OneOf(
             [any_element, none_element],
             [none_element, any_element],
         ),
     )
     return m.matches(node, m.Annotation(two_element_union))
Esempio n. 8
0
class ShedFixers(VisitorBasedCodemodCommand):
    """Fix a variety of small problems.

    Replaces `raise NotImplemented` with `raise NotImplementedError`,
    and converts always-failing assert statements to explicit `raise` statements.

    Also includes code closely modelled on pybetter's fixers, because it's
    considerably faster to run all transforms in a single pass if possible.
    """

    DESCRIPTION = "Fix a variety of style, performance, and correctness issues."

    @m.call_if_inside(m.Raise(exc=m.Name(value="NotImplemented")))
    def leave_Name(self, _, updated_node):  # noqa
        """Rename the `NotImplemented` in `raise NotImplemented`."""
        # The gate above is true for every Name nested in the matching Raise,
        # including names in a `from <cause>` clause, so only the offending
        # name itself may be rewritten.
        if updated_node.value != "NotImplemented":
            return updated_node
        return updated_node.with_changes(value="NotImplementedError")

    def leave_Assert(self, _, updated_node):  # noqa
        """Simplify asserts whose test is a literal.

        A truthy literal test removes the statement; a falsy one becomes
        `raise AssertionError` (carrying the message, if any). Any test that
        cannot be evaluated as a literal leaves the assert untouched.
        """
        test_code = cst.Module("").code_for_node(updated_node.test)
        try:
            test_literal = literal_eval(test_code)
        except Exception:
            # Deliberately broad: anything that is not a plain literal is kept.
            return updated_node
        if test_literal:
            return cst.RemovalSentinel.REMOVE
        if updated_node.msg is None:
            return cst.Raise(cst.Name("AssertionError"))
        return cst.Raise(
            cst.Call(cst.Name("AssertionError"),
                     args=[cst.Arg(updated_node.msg)]))

    @m.leave(
        m.ComparisonTarget(comparator=oneof_names("None", "False", "True"),
                           operator=m.Equal()))
    def convert_none_cmp(self, _, updated_node):
        """Inspired by Pybetter.

        Turns `== None` / `== False` / `== True` into an `is` comparison.
        """
        return updated_node.with_changes(operator=cst.Is())

    @m.leave(
        m.UnaryOperation(
            operator=m.Not(),
            expression=m.Comparison(
                comparisons=[m.ComparisonTarget(operator=m.In())]),
        ))
    def replace_not_in_condition(self, _, updated_node):
        """Also inspired by Pybetter.

        Rewrites `not a in b` as `a not in b`, keeping the outer parentheses.
        """
        expr = cst.ensure_type(updated_node.expression, cst.Comparison)
        return cst.Comparison(
            left=expr.left,
            lpar=updated_node.lpar,
            rpar=updated_node.rpar,
            comparisons=[
                expr.comparisons[0].with_changes(operator=cst.NotIn())
            ],
        )

    @m.leave(
        m.Call(
            lpar=[m.AtLeastN(n=1, matcher=m.LeftParen())],
            rpar=[m.AtLeastN(n=1, matcher=m.RightParen())],
        ))
    def remove_pointless_parens_around_call(self, _, updated_node):
        """Strip parentheses wrapped around a call when the result still compiles."""
        # This is *probably* valid, but we might have e.g. a multi-line parenthesised
        # chain of attribute accesses ("fluent interface"), where we need the parens.
        noparens = updated_node.with_changes(lpar=[], rpar=[])
        try:
            compile(self.module.code_for_node(noparens), "<string>", "eval")
            return noparens
        except SyntaxError:
            return updated_node

    # The following methods fix https://pypi.org/project/flake8-comprehensions/

    # `star=""` keeps `list(*(x for x in y))` out of this rewrite: unpacking a
    # generator into arguments is not the same as iterating over it.
    @m.leave(
        m.Call(func=m.Name("list"),
               args=[m.Arg(m.GeneratorExp(), star="")]))
    def replace_generator_in_call_with_comprehension(self, _, updated_node):
        """Fix flake8-comprehensions C400-402 and 403-404.

        C400-402: Unnecessary generator - rewrite as a <list/set/dict> comprehension.
        Note that set and dict conversions are handled by pyupgrade!
        """
        return cst.ListComp(elt=updated_node.args[0].value.elt,
                            for_in=updated_node.args[0].value.for_in)

    @m.leave(
        m.Call(func=m.Name("list"), args=[m.Arg(m.ListComp(), star="")])
        | m.Call(func=m.Name("set"), args=[m.Arg(m.SetComp(), star="")])
        | m.Call(
            func=m.Name("list"),
            args=[m.Arg(m.Call(func=oneof_names("sorted", "list")), star="")],
        ))
    def replace_unnecessary_list_around_sorted(self, _, updated_node):
        """Fix flake8-comprehensions C411 and C413.

        Unnecessary <list/reversed> call around sorted().

        Also covers C411 Unnecessary list call around list comprehension
        for lists and sets.
        """
        # The wrapped argument already has the desired type; return it directly.
        return updated_node.args[0].value

    @m.leave(
        m.Call(
            func=m.Name("reversed"),
            args=[m.Arg(m.Call(func=m.Name("sorted")), star="")],
        ))
    def replace_unnecessary_reversed_around_sorted(self, _, updated_node):
        """Fix flake8-comprehensions C413.

        Unnecessary reversed call around sorted().
        """
        call = updated_node.args[0].value
        args = list(call.args)
        for i, arg in enumerate(args):
            if m.matches(arg.keyword, m.Name("reverse")):
                try:
                    val = bool(
                        literal_eval(self.module.code_for_node(arg.value)))
                except Exception:
                    # Not a literal: negate whatever expression was passed.
                    args[i] = arg.with_changes(
                        value=cst.UnaryOperation(cst.Not(), arg.value))
                else:
                    if not val:
                        args[i] = arg.with_changes(value=cst.Name("True"))
                    else:
                        # reverse=True inside reversed() cancels out: drop it.
                        del args[i]
                        args[i - 1] = remove_trailing_comma(args[i - 1])
                break
        else:
            # No `reverse` keyword was given, so add one.
            args.append(
                cst.Arg(keyword=cst.Name("reverse"), value=cst.Name("True")))
        return call.with_changes(args=args)

    _sets = oneof_names("set", "frozenset")
    _seqs = oneof_names("list", "reversed", "sorted", "tuple")

    @m.leave(
        m.Call(func=_sets, args=[m.Arg(m.Call(func=_sets | _seqs), star="")])
        | m.Call(
            func=oneof_names("list", "tuple"),
            args=[m.Arg(m.Call(func=oneof_names("list", "tuple")), star="")],
        )
        | m.Call(
            func=m.Name("sorted"),
            args=[m.Arg(m.Call(func=_seqs), star=""),
                  m.ZeroOrMore()],
        ))
    def replace_unnecessary_nested_calls(self, _, updated_node):
        """Fix flake8-comprehensions C414.

        Unnecessary <list/reversed/sorted/tuple> call within <list/set/sorted/tuple>()..
        """
        # Keep the outer call, but feed it the inner call's first argument.
        return updated_node.with_changes(
            args=[cst.Arg(updated_node.args[0].value.args[0].value)] +
            list(updated_node.args[1:]), )

    # `star=""` matters here: the replacement Arg is built without a star, so a
    # starred argument must not be rewritten.
    @m.leave(
        m.Call(
            func=oneof_names("reversed", "set", "sorted"),
            args=[
                m.Arg(m.Subscript(slice=[m.SubscriptElement(ALL_ELEMS_SLICE)]),
                      star="")
            ],
        ))
    def replace_unnecessary_subscript_reversal(self, _, updated_node):
        """Fix flake8-comprehensions C415.

        Unnecessary subscript reversal of iterable within <reversed/set/sorted>().
        """
        return updated_node.with_changes(
            args=[cst.Arg(updated_node.args[0].value.value)], )

    @m.leave(
        multi(
            m.ListComp,
            m.SetComp,
            elt=m.Name(),
            for_in=m.CompFor(target=m.Name(),
                             ifs=[],
                             inner_for_in=None,
                             asynchronous=None),
        ))
    def replace_unnecessary_listcomp_or_setcomp(self, _, updated_node):
        """Fix flake8-comprehensions C416.

        Unnecessary <list/set> comprehension - rewrite using <list/set>().
        """
        # Only an identity comprehension such as `[x for x in xs]` qualifies.
        if updated_node.elt.value == updated_node.for_in.target.value:
            func = cst.Name(
                "list" if isinstance(updated_node, cst.ListComp) else "set")
            return cst.Call(func=func,
                            args=[cst.Arg(updated_node.for_in.iter)])
        return updated_node

    @m.leave(m.Subscript(oneof_names("Union", "Literal")))
    def reorder_union_literal_contents_none_last(self, _, updated_node):
        """Move `None` to the end of a `Union[...]` / `Literal[...]` subscript."""
        subscript = list(updated_node.slice)
        try:
            # list.sort is stable, so the other elements keep their order.
            subscript.sort(key=lambda elt: elt.slice.value.value == "None")
            subscript[-1] = remove_trailing_comma(subscript[-1])
            return updated_node.with_changes(slice=subscript)
        except Exception:  # Single-element literals are not slices, etc.
            return updated_node

    @m.call_if_inside(m.Annotation(annotation=m.BinaryOperation()))
    @m.leave(
        m.BinaryOperation(
            left=m.Name("None") | m.BinaryOperation(),
            operator=m.BitOr(),
            right=m.DoNotCare(),
        ))
    def reorder_union_operator_contents_none_last(self, _, updated_node):
        """Swap the operands of `A | B` in an annotation when `None` is on the left."""
        def _has_none(node):
            # True when `None` appears anywhere in a (possibly nested) operation.
            if m.matches(node, m.Name("None")):
                return True
            elif m.matches(node, m.BinaryOperation()):
                return _has_none(node.left) or _has_none(node.right)
            else:
                return False

        node_left = updated_node.left
        if _has_none(node_left):
            return updated_node.with_changes(left=updated_node.right,
                                             right=node_left)
        else:
            return updated_node

    @m.leave(m.Subscript(value=m.Name("Literal")))
    def flatten_literal_subscript(self, _, updated_node):
        """Inline nested `Literal[...]` elements into the enclosing `Literal[...]`."""
        new_slice = []
        for item in updated_node.slice:
            if m.matches(item.slice.value, m.Subscript(m.Name("Literal"))):
                new_slice += item.slice.value.slice
            else:
                new_slice.append(item)
        return updated_node.with_changes(slice=new_slice)

    @m.leave(m.Subscript(value=m.Name("Union")))
    def flatten_union_subscript(self, _, updated_node):
        """Inline nested `Union[...]` / `Optional[...]` elements into a `Union[...]`.

        Any `None` that was present, or implied by an `Optional`, is emitted
        once, as the last element.
        """
        new_slice = []
        has_none = False
        for item in updated_node.slice:
            if m.matches(item.slice.value, m.Subscript(m.Name("Optional"))):
                new_slice += item.slice.value.slice  # peel off "Optional"
                has_none = True
            # NOTE(review): the second `m.matches` call is given a CST node
            # rather than a matcher; presumably this compares the two `Union`
            # names structurally -- confirm.
            elif m.matches(item.slice.value,
                           m.Subscript(m.Name("Union"))) and m.matches(
                               updated_node.value, item.slice.value.value):
                new_slice += item.slice.value.slice  # peel off nested "Union"
            elif m.matches(item.slice.value, m.Name("None")):
                has_none = True
            else:
                new_slice.append(item)
        if has_none:
            new_slice.append(
                cst.SubscriptElement(slice=cst.Index(cst.Name("None"))))
        return updated_node.with_changes(slice=new_slice)

    @m.leave(m.Else(m.IndentedBlock([m.SimpleStatementLine([m.Pass()])])))
    def discard_empty_else_blocks(self, _, updated_node):
        """Remove `else: pass` blocks that contain no comments."""
        # An `else: pass` block can always simply be discarded, and libcst ensures
        # that an Else node can only ever occur attached to an If, While, For, or Try
        # node; in each case `None` is the valid way to represent "no else block".
        if m.findall(updated_node, m.Comment()):
            return updated_node  # If there are any comments, keep the node
        return cst.RemoveFromParent()

    @m.leave(
        m.Lambda(params=m.MatchIfTrue(lambda node: (
            node.star_kwarg is None
            and not node.kwonly_params
            and not node.posonly_params
            and isinstance(node.star_arg, cst.MaybeSentinel)
            and all(param.default is None for param in node.params)))))
    def remove_lambda_indirection(self, _, updated_node):
        """Replace `lambda a, b: f(a, b)` with `f`.

        Applies only when the body is a single call whose arguments are
        exactly the lambda's parameters, in the same order.
        """
        same_args = [
            m.Arg(m.Name(param.name.value), star="", keyword=None)
            for param in updated_node.params.params
        ]
        if m.matches(updated_node.body, m.Call(args=same_args)):
            return cst.ensure_type(updated_node.body, cst.Call).func
        return updated_node

    @m.leave(
        m.BooleanOperation(
            left=m.Call(m.Name("isinstance"), [m.Arg(), m.Arg()]),
            operator=m.Or(),
            right=m.Call(m.Name("isinstance"), [m.Arg(), m.Arg()]),
        ))
    def collapse_isinstance_checks(self, _, updated_node):
        """Merge `isinstance(x, A) or isinstance(x, B)` into `isinstance(x, (A, B))`.

        The merge happens only when both calls test the same first argument.
        """
        left_target, left_type = updated_node.left.args
        right_target, right_type = updated_node.right.args
        if left_target.deep_equals(right_target):
            merged_type = cst.Arg(
                cst.Tuple([
                    cst.Element(left_type.value),
                    cst.Element(right_type.value)
                ]))
            return updated_node.left.with_changes(
                args=[left_target, merged_type])
        return updated_node
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from typing import Set, Union, cast

import libcst as cst
import libcst.matchers as m
from libcst.codemod._context import CodemodContext
from libcst.codemod._visitor import ContextAwareVisitor
from libcst.metadata import MetadataWrapper, QualifiedNameProvider

# Qualified names of callables whose string arguments are treated like
# string annotations.
FUNCS_CONSIDERED_AS_STRING_ANNOTATIONS = {"typing.TypeVar"}


def _is_string_annotation_func(qualnames) -> bool:
    """Return True if any qualified name is one of the functions listed above."""
    return any(qn.name in FUNCS_CONSIDERED_AS_STRING_ANNOTATIONS
               for qn in qualnames)


# Matches either an annotation node, or a call whose qualified name is in
# FUNCS_CONSIDERED_AS_STRING_ANNOTATIONS.
ANNOTATION_MATCHER: m.BaseMatcherNode = m.Annotation() | m.Call(
    metadata=m.MatchMetadataIfTrue(
        QualifiedNameProvider,
        _is_string_annotation_func,
    ))


class GatherNamesFromStringAnnotationsVisitor(ContextAwareVisitor):
    """
    Collects all names from string literals used for typing purposes.
    This includes annotations like ``foo: "SomeType"``, and parameters to
    special functions related to typing (currently only `typing.TypeVar`).

    After visiting, a set of all found names will be available on the ``names``
    attribute of this visitor.
    """