コード例 #1
0
ファイル: test_extract.py プロジェクト: annieliu10/Euphoria
    def test_extract_sequence_element(self) -> None:
        """SaveMatchedNode around a ZeroOrMore inside a sequence pattern.

        The wildcard's whole run of matched arguments is saved under one key;
        when the wildcard's inner matcher rejects the arguments, extract()
        returns None.
        """
        source = cst.parse_expression("a + b[c], d(e, f * g, h.i.j)")

        # Positive case: every argument of d(...) is captured as "args".
        capture_all_args = m.Tuple(
            elements=[
                m.DoNotCare(),
                m.Element(
                    m.Call(args=[m.SaveMatchedNode(m.ZeroOrMore(), "args")])
                ),
            ]
        )
        second_element = cst.ensure_type(source, cst.Tuple).elements[1]
        call = cst.ensure_type(second_element.value, cst.Call)
        self.assertEqual(
            m.extract(source, capture_all_args), {"args": call.args}
        )

        # Negative case: none of e, f * g, h.i.j is a subscript, so the
        # ZeroOrMore(Arg(Subscript())) wildcard cannot consume the arguments.
        capture_subscript_args = m.Tuple(
            elements=[
                m.DoNotCare(),
                m.Element(
                    m.Call(
                        args=[
                            m.SaveMatchedNode(
                                m.ZeroOrMore(m.Arg(m.Subscript())), "args"
                            )
                        ]
                    )
                ),
            ]
        )
        self.assertIsNone(m.extract(source, capture_subscript_args))
コード例 #2
0
def has_footer_comment(body):
    """Return True if *body* is an IndentedBlock whose footer has a comment line.

    The ZeroOrMore() wildcards on either side let the commented EmptyLine
    appear at any position in the footer; a sequence pattern must describe
    the whole footer, not just part of it.
    """
    commented_line = m.EmptyLine(comment=m.Comment())
    footer_pattern = [m.ZeroOrMore(), commented_line, m.ZeroOrMore()]
    return m.matches(body, m.IndentedBlock(footer=footer_pattern))
コード例 #3
0
ファイル: test_extract.py プロジェクト: annieliu10/Euphoria
 def test_extract_optional_wildcard_tail(self) -> None:
     """Trailing ZeroOrMore wildcards that match nothing save empty tuples."""
     single_item_list = cst.parse_expression("[3]")
     pattern = m.List(
         elements=[
             m.Element(value=m.Integer(value="3")),
             m.SaveMatchedNode(m.ZeroOrMore(), "tail1"),
             m.SaveMatchedNode(m.ZeroOrMore(), "tail2"),
         ]
     )
     expected = {"tail1": (), "tail2": ()}
     self.assertEqual(m.extract(single_item_list, pattern), expected)
コード例 #4
0
 def visit_ImportFrom(self, node: cst.ImportFrom) -> None:
     """Flag ``from __future__ import annotations`` when it is encountered.

     The flag is only ever set to True here, never reset, so any matching
     import seen during traversal leaves it set.
     """
     future_annotations = m.ImportFrom(
         module=m.Name("__future__"),
         names=[
             m.ZeroOrMore(),
             m.ImportAlias(name=m.Name("annotations")),
             m.ZeroOrMore(),
         ],
     )
     if not m.matches(node, future_annotations):
         return
     self.has_future_annotations_import = True
コード例 #5
0
ファイル: test_replace.py プロジェクト: shannonzhu/LibCST
    def test_replace_sequence_extract(self) -> None:
        """replace() passes a sequence captured by SaveMatchedNode to its callback."""

        def _flip_params(
            node: cst.CSTNode,
            extraction: Dict[str, Union[cst.CSTNode, Sequence[cst.CSTNode]]],
        ) -> cst.CSTNode:
            func_def = cst.ensure_type(node, cst.FunctionDef)
            # pyre-ignore "params" holds a sequence of Param nodes, which pyre cannot infer.
            flipped = list(reversed(extraction["params"]))
            return func_def.with_changes(params=cst.Parameters(params=flipped))

        # Sequence extraction must keep working when used through replace().
        source = cst.parse_module(
            "def bar(baz: int, foo: int, ) -> int:\n    return baz + foo\n")
        pattern = m.FunctionDef(
            params=m.Parameters(
                params=m.SaveMatchedNode([m.ZeroOrMore(m.Param())], "params")))
        rewritten = m.replace(source, pattern, _flip_params)
        self.assertEqual(
            cst.ensure_type(rewritten, cst.Module).code,
            "def bar(foo: int, baz: int, ) -> int:\n    return baz + foo\n")
コード例 #6
0
ファイル: wxpython.py プロジェクト: expobrain/python-codemods
class MakeModalCommand(VisitorBasedCodemodCommand):
    """Codemod that adds a ``MakeModal()`` helper to classes calling it.

    A ``self.MakeModal(...)`` call marks the innermost enclosing class. If that
    class does not already define ``MakeModal(self, ...)``, the helper in
    ``method_cst`` is appended to its body when the class is left.
    """

    DESCRIPTION: str = "Replace built-in method MakeModal with helper"

    # Matches a ``def MakeModal(self, ...)`` method definition.
    method_matcher = matchers.FunctionDef(
        name=matchers.Name(value="MakeModal"),
        params=matchers.Parameters(params=[
            matchers.Param(name=matchers.Name(value="self")),
            matchers.ZeroOrMore()
        ]),
    )
    # Matches a ``self.MakeModal(...)`` call.
    call_matcher = matchers.Call(
        func=matchers.Attribute(value=matchers.Name(value="self"),
                                attr=matchers.Name(value="MakeModal")))

    # Helper method appended to classes that call MakeModal() without
    # defining it.
    method_cst = cst.parse_statement(
        textwrap.dedent("""
            def MakeModal(self, modal=True):
                if modal and not hasattr(self, '_disabler'):
                    self._disabler = wx.WindowDisabler(self)
                if not modal and hasattr(self, '_disabler'):
                    del self._disabler
            """))

    def __init__(self, context: CodemodContext):
        super().__init__(context)

        # One entry per enclosing class, innermost last. Each entry starts as
        # the original ClassDef and is replaced by a copy with the helper
        # appended once a MakeModal() call shows the helper is needed.
        self.stack: List[cst.ClassDef] = []

    def visit_ClassDef(self, node: cst.ClassDef) -> None:
        self.stack.append(node)

    def leave_ClassDef(self, original_node: cst.ClassDef,
                       updated_node: cst.ClassDef) -> cst.ClassDef:
        tracked = self.stack.pop()
        if tracked is original_node:
            # Nothing to add: keep all changes made to the class's children.
            return updated_node
        # The helper is needed. Append it to updated_node rather than
        # returning the tracked node, which is built from original_node and
        # would discard changes made to nested nodes (e.g. inner classes).
        return updated_node.with_changes(
            body=updated_node.body.with_changes(
                body=[*updated_node.body.body, self.method_cst]))

    def leave_Call(self, original_node: cst.Call,
                   updated_node: cst.Call) -> cst.Call:
        # A self.MakeModal() call outside any class has no class to fix.
        if not self.stack or not matchers.matches(updated_node,
                                                  self.call_matcher):
            return updated_node

        current_class = self.stack[-1]
        has_make_modal_method = any(
            matchers.matches(statement, self.method_matcher)
            for statement in current_class.body.body)

        # If not, schedule the helper for the current class.
        if not has_make_modal_method:
            self.stack[-1] = current_class.with_changes(
                body=current_class.body.with_changes(
                    body=[*current_class.body.body, self.method_cst]))

        return updated_node
コード例 #7
0
        class TestVisitor(MatcherDecoratableTransformer):
            """Records SimpleString values visited inside a function named ``foo``."""

            def __init__(self) -> None:
                super().__init__()
                # ``value`` of every SimpleString visited inside foo(), in order.
                self.visits: List[str] = []

            @call_if_inside(
                m.FunctionDef(
                    m.Name("foo"),
                    params=m.Parameters([m.ZeroOrMore()]),
                )
            )
            def visit_SimpleString(self, node: cst.SimpleString) -> None:
                seen = node.value
                self.visits.append(seen)
コード例 #8
0
 def visit_Module(self, node: cst.Module) -> None:
     """Report a module whose header does not start with the required lines.

     Does nothing when the rule is disabled. Otherwise, if ``node.header``
     does not begin with the lines described by ``self.header_matcher``, a
     report is raised whose replacement prepends ``self.header_replacement``
     to the existing header.
     """
     if self.rule_disabled:
         return
     required_prefix = m.Module(header=[*self.header_matcher, m.ZeroOrMore()])
     if m.matches(node, required_prefix):
         return
     fixed = node.with_changes(header=[*self.header_replacement, *node.header])
     self.report(node, replacement=fixed)
コード例 #9
0
ファイル: test_extract.py プロジェクト: annieliu10/Euphoria
 def test_extract_sequence_multiple_wildcards(self) -> None:
     """Wildcards around a fixed element capture the head and the tail."""
     expression = cst.parse_expression("1, 2, 3, 4")
     pattern = m.Tuple(
         elements=(
             m.SaveMatchedNode(m.ZeroOrMore(), "head"),
             m.SaveMatchedNode(
                 m.Element(value=m.Integer(value="3")), "element"
             ),
             m.SaveMatchedNode(m.ZeroOrMore(), "tail"),
         )
     )
     captured = m.extract(expression, pattern)
     items = cst.ensure_type(expression, cst.Tuple).elements
     expected = {
         "head": tuple(items[:2]),
         "element": items[2],
         "tail": tuple(items[3:]),
     }
     self.assertEqual(captured, expected)
コード例 #10
0
 def _has_testnode(node: cst.Module) -> bool:
     """Return True if *node*'s top-level body has a test function or class.

     A test function is a ``def`` whose name starts with ``test_``; a test
     class is a ``class`` whose name starts with ``Test``.
     """
     # Sequence wildcard matchers match LibCST nodes that appear in a row in a
     # sequence; they never implicitly match a partial sequence. A pattern for
     # a sequence must therefore describe all of it, which is why ZeroOrMore()
     # frames the interesting part on both sides.
     def _name_starting_with(prefix):
         return m.Name(
             value=m.MatchIfTrue(lambda value: value.startswith(prefix)))

     test_definition = m.OneOf(
         m.FunctionDef(name=_name_starting_with("test_")),
         m.ClassDef(name=_name_starting_with("Test")),
     )
     body_pattern = [
         m.ZeroOrMore(),
         m.AtLeastN(n=1, matcher=test_definition),
         m.ZeroOrMore(),
     ]
     return m.matches(node, m.Module(body=body_pattern))
コード例 #11
0
 def test_zero_or_more_matcher_args_false(self) -> None:
     """foo(1, ...) must not match when ZeroOrMore rejects a trailing argument."""
     # The call under test is foo(1, 2, 3).
     call = cst.Call(
         func=cst.Name("foo"),
         args=(
             cst.Arg(cst.Integer("1")),
             cst.Arg(cst.Integer("2")),
             cst.Arg(cst.Integer("3")),
         ),
     )
     rejecting_tails = (
         # The remaining arguments are integers, not strings.
         m.ZeroOrMore(m.Arg(m.SimpleString())),
         # The remaining arguments are not all the integer 2 (the last is 3).
         m.ZeroOrMore(m.Arg(m.Integer("2"))),
     )
     for tail in rejecting_tails:
         pattern = m.Call(
             func=m.Name("foo"),
             args=(m.Arg(m.Integer("1")), tail),
         )
         self.assertFalse(matches(call, pattern))
コード例 #12
0
ファイル: test_extract.py プロジェクト: annieliu10/Euphoria
 def test_extract_sequence(self) -> None:
     """SaveMatchedNode around a sequence pattern captures the whole sequence."""
     expression = cst.parse_expression("a + b[c], d(e, f * g, h.i.j)")
     saved_args = m.SaveMatchedNode([m.ZeroOrMore()], "args")
     pattern = m.Tuple(
         elements=[m.DoNotCare(), m.Element(m.Call(args=saved_args))]
     )
     second = cst.ensure_type(expression, cst.Tuple).elements[1]
     expected_args = cst.ensure_type(second.value, cst.Call).args
     self.assertEqual(m.extract(expression, pattern), {"args": expected_args})
コード例 #13
0
ファイル: test_extract.py プロジェクト: annieliu10/Euphoria
 def test_extract_optional_wildcard(self) -> None:
     """A ZeroOrOne that matches nothing leaves its SaveMatchedNode unrecorded."""
     expression = cst.parse_expression("a + b[c], d(e, f * g)")
     # The last argument (f * g) is not an Attribute, so ZeroOrOne matches zero.
     optional_attribute = m.ZeroOrOne(
         m.Arg(m.SaveMatchedNode(m.Attribute(), "arg"))
     )
     pattern = m.Tuple(
         elements=[
             m.DoNotCare(),
             m.Element(m.Call(args=[m.ZeroOrMore(), optional_attribute])),
         ]
     )
     self.assertEqual(m.extract(expression, pattern), {})
コード例 #14
0
    def _split_module(
        self, orig_module: libcst.Module, updated_module: libcst.Module
    ) -> Tuple[List[Union[libcst.SimpleStatementLine,
                          libcst.BaseCompoundStatement]],
               List[Union[libcst.SimpleStatementLine,
                          libcst.BaseCompoundStatement]], List[Union[
                              libcst.SimpleStatementLine,
                              libcst.BaseCompoundStatement]], ]:
        """Split the updated module's body around the import insertion point.

        Returns three slices of ``updated_module.body``:

        * statements that must stay above any new import (a leading
          ``__strict__`` assignment, or a leading string-expression docstring);
        * the statements from there up to the insertion point, which is just
          after the last statement containing one of ``self.all_imports``;
        * every statement after the insertion point.

        Statements are located in ``orig_module`` but sliced from
        ``updated_module``; see the comment in the body for why that is valid.
        """
        statement_before_import_location = 0
        import_add_location = 0

        # never insert an import before initial __strict__ flag
        if m.matches(
                orig_module,
                m.Module(body=[
                    m.SimpleStatementLine(body=[
                        m.Assign(targets=[
                            m.AssignTarget(target=m.Name("__strict__"))
                        ])
                    ]),
                    m.ZeroOrMore(),
                ]),
        ):
            statement_before_import_location = import_add_location = 1

        # This works under the principle that while we might modify node contents,
        # we have yet to modify the number of statements. So we can match on the
        # original tree but break up the statements of the modified tree. If we
        # change this assumption in this visitor, we will have to change this code.
        for i, statement in enumerate(orig_module.body):
            # Only the first statement can be the module docstring. A bare
            # string expression further down must not reset the insertion
            # point already found after earlier imports.
            if i == 0 and m.matches(
                    statement,
                    m.SimpleStatementLine(
                        body=[m.Expr(value=m.SimpleString())])):
                statement_before_import_location = import_add_location = 1
            elif isinstance(statement, libcst.SimpleStatementLine):
                # Identity comparison: all_imports holds nodes of orig_module.
                if any(possible_import is last_import
                       for possible_import in statement.body
                       for last_import in self.all_imports):
                    import_add_location = i + 1

        return (
            list(updated_module.body[:statement_before_import_location]),
            list(updated_module.
                 body[statement_before_import_location:import_add_location]),
            list(updated_module.body[import_add_location:]),
        )
コード例 #15
0
 def visit_Call(self, node: cst.Call) -> None:
     """Collect the string literal naming an Extension/addMacExtension call.

     Matches calls to ``Extension(...)`` or ``addMacExtension(...)`` whose
     first argument's value is a SimpleString, and appends that literal's
     ``evaluated_value`` to ``self.extension_names``.
     """
     extension_call = m.Call(
         func=m.OneOf(m.Name("Extension"), m.Name("addMacExtension")),
         args=(
             m.Arg(
                 value=m.SaveMatchedNode(m.SimpleString(), "extension_name")
             ),
             m.ZeroOrMore(m.DoNotCare()),
         ),
     )
     captured = m.extract(node, extension_call)
     if not captured:
         return
     name_literal = captured["extension_name"]
     assert isinstance(name_literal, cst.SimpleString)
     self.extension_names.append(name_literal.evaluated_value)
コード例 #16
0
ファイル: test_extract.py プロジェクト: annieliu10/Euphoria
 def test_extract_precedence_sequence_wildcard(self) -> None:
     """A SaveMatchedNode repeated by ZeroOrMore keeps its last match."""
     expression = cst.parse_expression("a + b[c], d(e, f * g)")
     every_arg = m.ZeroOrMore(m.Arg(m.SaveMatchedNode(m.DoNotCare(), "arg")))
     pattern = m.Tuple(
         elements=[m.DoNotCare(), m.Element(m.Call(args=[every_arg]))]
     )
     captured = m.extract(expression, pattern)
     second = cst.ensure_type(expression, cst.Tuple).elements[1]
     last_value = cst.ensure_type(second.value, cst.Call).args[1].value
     self.assertEqual(captured, {"arg": last_value})
コード例 #17
0
ファイル: codemods.py プロジェクト: jjpal/hypothesis
    def leave_Call(self, original_node, updated_node):
        """Spell out keywords for positional arguments to keyword-only params.

        Only calls whose single qualified name is in ``self.kwonly_functions``,
        whose callee is not itself a call, and whose first argument is
        positional are considered. Positional arguments lining up with
        KEYWORD_ONLY parameters of the real signature gain ``name=``.
        """
        qualnames = {
            qualified.name
            for qualified in self.get_metadata(
                cst.metadata.QualifiedNameProvider, original_node)
        }

        # Bail out unless this is exactly one known function called with at
        # least one leading positional argument.
        if len(qualnames) != 1:
            return updated_node
        if not qualnames.intersection(self.kwonly_functions):
            return updated_node
        leading_posarg_call = m.Call(
            func=m.DoesNotMatch(m.Call()),
            args=[m.Arg(keyword=None), m.ZeroOrMore()],
        )
        if not m.matches(updated_node, leading_posarg_call):
            return updated_node

        # Resolving the real function means e.g. Numpy-dependent code needs
        # Numpy installed, but it keeps one source of truth for signatures.
        params = signature(get_fn(*qualnames)).parameters.values()

        # floats() has a keyword-only allow_subnormal parameter that is not at
        # the end of its signature; drop it so the positions line up.
        if qualnames == {"hypothesis.strategies.floats"}:
            params = [p for p in params if p.name != "allow_subnormal"]

        if len(updated_node.args) > len(params):
            return updated_node

        assign_nospace = cst.AssignEqual(
            whitespace_before=cst.SimpleWhitespace(""),
            whitespace_after=cst.SimpleWhitespace(""),
        )
        newargs = []
        for p, arg in zip(params, updated_node.args):
            needs_keyword = (not (arg.keyword or arg.star)
                             and p.kind is Parameter.KEYWORD_ONLY)
            if needs_keyword:
                arg = arg.with_changes(keyword=cst.Name(p.name),
                                       equal=assign_nospace)
            newargs.append(arg)
        return updated_node.with_changes(args=newargs)
コード例 #18
0
ファイル: codemods.py プロジェクト: vlulla/hypothesis
    def leave_Call(self, original_node, updated_node):
        """Convert positional to keyword arguments.

        Only calls whose single qualified name is in ``self.kwonly_functions``,
        whose callee is not itself a call, and whose first argument is
        positional are considered. The function is imported to read its real
        signature, and positional arguments lining up with KEYWORD_ONLY
        parameters gain an explicit ``name=`` keyword.

        The call is returned unchanged when the function cannot be imported,
        or when it has more arguments than the signature has parameters
        (zip() would otherwise silently drop the surplus arguments).
        """
        metadata = self.get_metadata(cst.metadata.QualifiedNameProvider, original_node)
        qualnames = {qn.name for qn in metadata}

        # If this isn't one of our known functions, or it has no posargs, stop there.
        if (
            len(qualnames) != 1
            or not qualnames.intersection(self.kwonly_functions)
            or not m.matches(
                updated_node,
                m.Call(
                    func=m.DoesNotMatch(m.Call()),
                    args=[m.Arg(keyword=None), m.ZeroOrMore()],
                ),
            )
        ):
            return updated_node

        # Get the actual function object so that we can inspect the signature.
        # This does e.g. incur a dependency on Numpy to fix Numpy-dependent code,
        # but having a single source of truth about the signatures is worth it.
        mod, fn = list(qualnames.intersection(self.kwonly_functions))[0].rsplit(".", 1)
        try:
            func = getattr(importlib.import_module(mod), fn)
        except (ImportError, AttributeError):
            # Module not installed, or the installed version lacks ``fn``.
            return updated_node

        params = list(signature(func).parameters.values())
        # zip() stops at the shorter input: with more arguments than
        # parameters the extra arguments would vanish from the call.
        if len(updated_node.args) > len(params):
            return updated_node

        # Create new arg nodes with the newly required keywords
        assign_nospace = cst.AssignEqual(
            whitespace_before=cst.SimpleWhitespace(""),
            whitespace_after=cst.SimpleWhitespace(""),
        )
        newargs = [
            arg
            if arg.keyword or arg.star or p.kind is not Parameter.KEYWORD_ONLY
            else arg.with_changes(keyword=cst.Name(p.name), equal=assign_nospace)
            for p, arg in zip(params, updated_node.args)
        ]
        return updated_node.with_changes(args=newargs)
コード例 #19
0
ファイル: cst_visitor.py プロジェクト: saltudelft/libsa4py
    def __extract_assign_newtype(self, node: cst.Assign):
        """
        Record the name bound by a ``Name = NewType("...", ...)`` assignment.

        The assignment must have exactly one target, which must be a plain
        name, and its value must be a call to a function named ``NewType``
        whose first argument is a simple string literal (further arguments
        are accepted). On a match, the target's name, with single quotes
        stripped from both ends, is appended to ``self.class_defs``.
        """
        # The target identifier (not the string literal) is saved as "type".
        saved_target = match.AssignTarget(
            target=match.Name(
                value=match.SaveMatchedNode(match.MatchRegex(r'(.)+'), "type")))
        newtype_call = match.Call(
            func=match.Name(value="NewType"),
            args=[match.Arg(value=match.SimpleString()), match.ZeroOrMore()])
        matcher_newtype = match.Assign(targets=[saved_target], value=newtype_call)

        extracted_type = match.extract(node, matcher_newtype)
        if extracted_type is None:
            return
        # TODO: Either rename class defs, or create new list for additional types
        self.class_defs.append(extracted_type["type"].strip("\'"))
コード例 #20
0
# ``raise <tornado.gen.Return>(...)`` with at least one argument. The accepted
# spellings of the callee are whatever ``some_version_of`` produces (defined
# outside this excerpt).
gen_return_call_with_args_matcher = m.Raise(exc=m.Call(
    func=some_version_of("tornado.gen.Return"), args=[m.AtLeastN(n=1)]))
# ``raise <tornado.gen.Return>(...)`` with any arguments.
gen_return_call_matcher = m.Raise(exc=m.Call(
    func=some_version_of("tornado.gen.Return")))
# Either form of returning from a legacy coroutine: gen_return_statement_matcher
# (defined outside this excerpt) or a raised gen.Return(...) call.
gen_return_matcher = gen_return_statement_matcher | gen_return_call_matcher
# Calls to gen.sleep(...) and gen.Task(...).
gen_sleep_matcher = m.Call(func=some_version_of("gen.sleep"))
gen_task_matcher = m.Call(func=some_version_of("gen.Task"))
# @tornado.gen.coroutine and @tornado.testing.gen_test decorators.
gen_coroutine_decorator_matcher = m.Decorator(
    decorator=some_version_of("tornado.gen.coroutine"))
gen_test_coroutine_decorator = m.Decorator(
    decorator=some_version_of("tornado.testing.gen_test"))
coroutine_decorator_matcher = (gen_coroutine_decorator_matcher
                               | gen_test_coroutine_decorator)
# A non-async ``def`` with one of the coroutine decorators anywhere in its
# decorator list.
coroutine_matcher = m.FunctionDef(
    asynchronous=None,
    decorators=[m.ZeroOrMore(), coroutine_decorator_matcher,
                m.ZeroOrMore()],
)


class TransformError(Exception):
    """Signals a known, recognised problem hit while transforming the tree."""


class TornadoAsyncTransformer(cst.CSTTransformer):
    """
    A libcst transformer that replaces the legacy @gen.coroutine/yield
    async syntax with the python3.7 native async/await syntax.

def is_valid_comment(comment_text: str) -> bool:
    # True when any entry of VALID_COMMENTS_FOR_NULL_TRUE occurs as a
    # substring of comment_text; False otherwise (including empty input).
    for marker in VALID_COMMENTS_FOR_NULL_TRUE:
        if marker in comment_text:
            return True
    return False


# Trailing whitespace that ends with a newline and carries a comment whose
# text passes is_valid_comment().
null_comment = m.TrailingWhitespace(
    comment=m.Comment(m.MatchIfTrue(is_valid_comment)),
    newline=m.Newline(),
)

field_without_comment = m.SimpleStatementLine(
    body=[
        m.Assign(value=(m.Call(
            args=[
                m.ZeroOrMore(),
                m.Arg(keyword=m.Name('null'), value=m.Name('True')),
                m.ZeroOrMore(),
            ],
            whitespace_before_args=m.DoesNotMatch(
                m.ParenthesizedWhitespace(null_comment)),
        )
                        | m.Call(
                            func=m.Attribute(attr=m.Name('NullBooleanField')),
                            whitespace_before_args=m.DoesNotMatch(
                                m.ParenthesizedWhitespace(null_comment)),
                        )
                        | m.Call(
                            func=m.Name('NullBooleanField'),
                            whitespace_before_args=m.DoesNotMatch(
                                m.ParenthesizedWhitespace(null_comment)),
コード例 #22
0
ファイル: _codemods.py プロジェクト: Zac-HD/shed
class ShedFixers(VisitorBasedCodemodCommand):
    """Fix a variety of small problems.

    Replaces `raise NotImplemented` with `raise NotImplementedError`,
    and converts always-failing assert statements to explicit `raise` statements.

    Also includes code closely modelled on pybetter's fixers, because it's
    considerably faster to run all transforms in a single pass if possible.
    """

    DESCRIPTION = "Fix a variety of style, performance, and correctness issues."

    @m.call_if_inside(m.Raise(exc=m.Name(value="NotImplemented")))
    def leave_Name(self, _, updated_node):  # noqa
        # call_if_inside runs this for every Name inside the matching Raise,
        # including the ``from`` clause (e.g. ``raise NotImplemented from
        # None`` or ``... from err``). Only rename NotImplemented itself.
        if updated_node.value != "NotImplemented":
            return updated_node
        return updated_node.with_changes(value="NotImplementedError")

    def leave_Assert(self, _, updated_node):  # noqa
        # Only asserts whose test is a Python literal are touched: a truthy
        # literal makes the assert a no-op (removed); a falsy one always fails,
        # so it becomes an explicit ``raise AssertionError(...)``.
        source = cst.Module("").code_for_node(updated_node.test)
        try:
            literal = literal_eval(source)
        except Exception:
            return updated_node
        if literal:
            return cst.RemovalSentinel.REMOVE
        message = updated_node.msg
        if message is None:
            return cst.Raise(cst.Name("AssertionError"))
        failing_call = cst.Call(cst.Name("AssertionError"),
                                args=[cst.Arg(message)])
        return cst.Raise(failing_call)

    @m.leave(
        m.ComparisonTarget(
            operator=m.Equal(),
            comparator=oneof_names("None", "False", "True"),
        ))
    def convert_none_cmp(self, _, updated_node):
        """Rewrite ``== None/False/True`` as ``is ...`` (inspired by Pybetter)."""
        identity_operator = cst.Is()
        return updated_node.with_changes(operator=identity_operator)

    @m.leave(
        m.UnaryOperation(
            operator=m.Not(),
            expression=m.Comparison(
                comparisons=[m.ComparisonTarget(operator=m.In())]),
        ))
    def replace_not_in_condition(self, _, updated_node):
        """Rewrite ``not x in y`` as ``x not in y`` (also inspired by Pybetter).

        Only single comparisons match; the parentheses that wrapped the
        ``not`` expression are carried over to the new Comparison.
        """
        comparison = cst.ensure_type(updated_node.expression, cst.Comparison)
        left_operand = comparison.left
        negated_target = comparison.comparisons[0].with_changes(
            operator=cst.NotIn())
        return cst.Comparison(
            left=left_operand,
            comparisons=[negated_target],
            lpar=updated_node.lpar,
            rpar=updated_node.rpar,
        )

    @m.leave(
        m.Call(
            lpar=[m.AtLeastN(n=1, matcher=m.LeftParen())],
            rpar=[m.AtLeastN(n=1, matcher=m.RightParen())],
        ))
    def remove_pointless_parens_around_call(self, _, updated_node):
        # Dropping the parentheses is *probably* fine, but e.g. a multi-line
        # parenthesised chain of attribute accesses ("fluent interface") needs
        # them. Keep the bare call only if it still compiles as an expression.
        bare_call = updated_node.with_changes(lpar=[], rpar=[])
        try:
            compile(self.module.code_for_node(bare_call), "<string>", "eval")
        except SyntaxError:
            return updated_node
        return bare_call

    # The following methods fix https://pypi.org/project/flake8-comprehensions/

    @m.leave(
        m.Call(func=m.Name("list"), args=[m.Arg(m.GeneratorExp(), star="")]))
    def replace_generator_in_call_with_comprehension(self, _, updated_node):
        """Fix flake8-comprehensions C400.

        C400: Unnecessary generator - rewrite as a list comprehension.
        The set and dict variants (C401/C402) are handled by pyupgrade.

        Only a plain (non-starred) argument is rewritten, consistent with the
        other call fixers here: ``list(*(x for x in y))`` unpacks the
        generator into separate arguments and is not ``[x for x in y]``.
        """
        generator = updated_node.args[0].value
        return cst.ListComp(elt=generator.elt, for_in=generator.for_in)

    @m.leave(
        m.Call(func=m.Name("list"), args=[m.Arg(m.ListComp(), star="")])
        | m.Call(func=m.Name("set"), args=[m.Arg(m.SetComp(), star="")])
        | m.Call(
            func=m.Name("list"),
            args=[
                m.Arg(m.Call(func=oneof_names("sorted", "list")), star="")
            ],
        ))
    def replace_unnecessary_list_around_sorted(self, _, updated_node):
        """Drop a redundant outer call (flake8-comprehensions C411 and C413).

        Handles ``list(<list comprehension>)``, ``set(<set comprehension>)``,
        ``list(sorted(...))`` and ``list(list(...))``: the single argument
        already produces the wanted type, so it replaces the whole call.
        """
        only_argument = updated_node.args[0]
        return only_argument.value

    @m.leave(
        m.Call(
            func=m.Name("reversed"),
            args=[m.Arg(m.Call(func=m.Name("sorted")), star="")],
        ))
    def replace_unnecessary_reversed_around_sorted(self, _, updated_node):
        """Fix flake8-comprehensions C413.

        Unnecessary reversed call around sorted(): ``reversed(sorted(...))``
        becomes ``sorted(...)`` with its ``reverse=`` argument flipped:

        * no ``reverse=`` argument: ``reverse=True`` is appended;
        * a falsy literal: replaced by ``True``;
        * a truthy literal: the argument is deleted, and remove_trailing_comma()
          is applied to the argument before it;
        * anything else: wrapped in ``not``, parenthesised first when it binds
          more loosely than ``not`` so ``a or b`` becomes ``not (a or b)``.
        """
        call = updated_node.args[0].value
        args = list(call.args)
        for i, arg in enumerate(args):
            if m.matches(arg.keyword, m.Name("reverse")):
                try:
                    val = bool(
                        literal_eval(self.module.code_for_node(arg.value)))
                except Exception:
                    # Not a literal: negate the expression at runtime instead.
                    operand = arg.value
                    # ``not`` binds tighter than and/or, conditional
                    # expressions and lambdas: ``not a or b`` would negate
                    # only ``a``, so such operands need parentheses.
                    if m.matches(
                            operand,
                            m.BooleanOperation() | m.IfExp() | m.Lambda(),
                    ) and not operand.lpar:
                        operand = operand.with_changes(
                            lpar=[cst.LeftParen()], rpar=[cst.RightParen()])
                    args[i] = arg.with_changes(
                        value=cst.UnaryOperation(cst.Not(), operand))
                else:
                    if not val:
                        args[i] = arg.with_changes(value=cst.Name("True"))
                    else:
                        del args[i]
                        args[i - 1] = remove_trailing_comma(args[i - 1])
                break
        else:
            args.append(
                cst.Arg(keyword=cst.Name("reverse"), value=cst.Name("True")))
        return call.with_changes(args=args)

    _sets = oneof_names("set", "frozenset")
    _seqs = oneof_names("list", "reversed", "sorted", "tuple")
    # The inner call is replaced by its first argument, so that argument must
    # exist and be a plain positional one: ``list(tuple())`` has nothing to
    # unwrap, and ``list(tuple(*xs))`` is not equivalent to ``list(xs)``.
    _plain_first_arg = (m.Arg(star="", keyword=None), m.ZeroOrMore())

    @m.leave(
        m.Call(
            func=_sets,
            args=[
                m.Arg(m.Call(func=_sets | _seqs, args=_plain_first_arg),
                      star="")
            ],
        )
        | m.Call(
            func=oneof_names("list", "tuple"),
            args=[
                m.Arg(m.Call(func=oneof_names("list", "tuple"),
                             args=_plain_first_arg),
                      star="")
            ],
        )
        | m.Call(
            func=m.Name("sorted"),
            args=[m.Arg(m.Call(func=_seqs, args=_plain_first_arg), star=""),
                  m.ZeroOrMore()],
        ))
    def replace_unnecessary_nested_calls(self, _, updated_node):
        """Fix flake8-comprehensions C414.

        Unnecessary <list/reversed/sorted/tuple> call within <list/set/sorted/tuple>().
        The inner call is replaced by its first (positional) argument; any
        further arguments of the outer call, e.g. ``key=`` for sorted(), are kept.
        """
        return updated_node.with_changes(
            args=[cst.Arg(updated_node.args[0].value.args[0].value)] +
            list(updated_node.args[1:]), )

    @m.leave(
        m.Call(
            func=oneof_names("reversed", "set", "sorted"),
            args=[
                m.Arg(m.Subscript(slice=[m.SubscriptElement(ALL_ELEMS_SLICE)]),
                      star="")
            ],
        ))
    def replace_unnecessary_subscript_reversal(self, _, updated_node):
        """Fix flake8-comprehensions C415.

        Unnecessary subscript reversal of iterable within <reversed/set/sorted>().
        The subscript is dropped and the subscripted value is passed directly.
        Starred arguments are excluded, as in the other call fixers here:
        ``set(*xs[...])`` unpacks the subscript's items into separate
        arguments, so dropping the subscript would change the call.
        """
        return updated_node.with_changes(
            args=[cst.Arg(updated_node.args[0].value.value)], )

    @m.leave(
        multi(
            m.ListComp,
            m.SetComp,
            elt=m.Name(),
            for_in=m.CompFor(
                target=m.Name(),
                ifs=[],
                inner_for_in=None,
                asynchronous=None,
            ),
        ))
    def replace_unnecessary_listcomp_or_setcomp(self, _, updated_node):
        """Fix flake8-comprehensions C416.

        ``[x for x in y]`` / ``{x for x in y}`` become ``list(y)`` / ``set(y)``
        when the element is exactly the loop variable; any other comprehension
        reaching this method is returned unchanged.
        """
        loop = updated_node.for_in
        if updated_node.elt.value != loop.target.value:
            return updated_node
        if isinstance(updated_node, cst.ListComp):
            constructor = "list"
        else:
            constructor = "set"
        return cst.Call(func=cst.Name(constructor), args=[cst.Arg(loop.iter)])

    @m.leave(m.Subscript(oneof_names("Union", "Literal")))
    def reorder_union_literal_contents_none_last(self, _, updated_node):
        """Move a bare ``None`` member to the end of ``Union[...]``/``Literal[...]``.

        list.sort is stable, so the other members keep their relative order.
        Any failure (e.g. a member that is not an indexed element) leaves the
        node unchanged.
        """
        members = list(updated_node.slice)

        def _is_none(member):
            return member.slice.value.value == "None"

        try:
            members.sort(key=_is_none)
            members[-1] = remove_trailing_comma(members[-1])
        except Exception:  # Single-element literals are not slices, etc.
            return updated_node
        return updated_node.with_changes(slice=members)

    @m.call_if_inside(m.Annotation(annotation=m.BinaryOperation()))
    @m.leave(
        m.BinaryOperation(
            left=m.Name("None") | m.BinaryOperation(),
            operator=m.BitOr(),
            right=m.DoNotCare(),
        ))
    def reorder_union_operator_contents_none_last(self, _, updated_node):
        """Swap the operands of ``A | B`` in annotations when ``A`` holds None.

        ``A`` holds None if it is the name ``None`` or a (possibly nested)
        binary operation with ``None`` somewhere among its operands.
        """

        def _mentions_none(expr):
            if m.matches(expr, m.Name("None")):
                return True
            if not m.matches(expr, m.BinaryOperation()):
                return False
            return _mentions_none(expr.left) or _mentions_none(expr.right)

        left_operand = updated_node.left
        if not _mentions_none(left_operand):
            return updated_node
        return updated_node.with_changes(left=updated_node.right,
                                         right=left_operand)

    @m.leave(m.Subscript(value=m.Name("Literal")))
    def flatten_literal_subscript(self, _, updated_node):
        """Inline nested literals: ``Literal[1, Literal[2, 3]]`` -> ``Literal[1, 2, 3]``."""
        flattened = []
        for element in updated_node.slice:
            inner = element.slice.value
            if m.matches(inner, m.Subscript(m.Name("Literal"))):
                flattened.extend(inner.slice)
            else:
                flattened.append(element)
        return updated_node.with_changes(slice=flattened)

    @m.leave(m.Subscript(value=m.Name("Union")))
    def flatten_union_subscript(self, _, updated_node):
        """Flatten a ``Union[...]`` annotation and put ``None`` last.

        * ``Optional[X]`` members are replaced by ``X`` and imply ``None``;
        * nested ``Union[...]`` members are inlined;
        * bare ``None`` members are dropped, and a single ``None`` is appended
          at the end if any ``None`` (explicit or via Optional) was seen.
        """
        new_slice = []
        has_none = False
        for item in updated_node.slice:
            member = item.slice.value
            if m.matches(member, m.Subscript(m.Name("Optional"))):
                new_slice += member.slice  # peel off "Optional"
                has_none = True
            elif m.matches(member, m.Subscript(m.Name("Union"))):
                # The decorator already fixes the outer name to Union, so a
                # nested Union can always be inlined; no second check needed.
                new_slice += member.slice  # peel off the nested "Union"
            elif m.matches(member, m.Name("None")):
                has_none = True
            else:
                new_slice.append(item)
        if has_none:
            new_slice.append(
                cst.SubscriptElement(slice=cst.Index(cst.Name("None"))))
        return updated_node.with_changes(slice=new_slice)

    @m.leave(m.Else(m.IndentedBlock([m.SimpleStatementLine([m.Pass()])])))
    def discard_empty_else_blocks(self, _, updated_node):
        # Drop an `else:` block whose only statement is `pass`. libcst only
        # attaches an Else to an If, While, For or Try node, and each of those
        # spells "no else block" as `None`, so removing it is always valid.
        # A block that carries any comment is kept, so the comment survives.
        has_comment = bool(m.findall(updated_node, m.Comment()))
        return updated_node if has_comment else cst.RemoveFromParent()

    @m.leave(
        m.Lambda(params=m.MatchIfTrue(lambda node: (
            node.star_kwarg is None and not node.kwonly_params and not node.
            posonly_params and isinstance(node.star_arg, cst.MaybeSentinel) and
            all(param.default is None for param in node.params)))))
    def remove_lambda_indirection(self, _, updated_node):
        """Rewrite ``lambda a, b: f(a, b)`` to plain ``f``.

        Applies only when:

        * the lambda has plain positional parameters without defaults;
        * its body calls something with exactly those parameters, in order,
          passed positionally;
        * the callee is a bare or dotted name (``f``, ``mod.f``) whose root
          is not one of the lambda's own parameters.

        Other callees are left alone. ``lambda x: x(x)`` would otherwise turn
        the parameter into a free variable. ``lambda x: make()(x)`` or
        ``lambda x: d[k](x)`` would otherwise evaluate ``make()`` / ``d[k]``
        once, eagerly, instead of on every call.
        """
        param_names = {param.name.value for param in updated_node.params.params}

        def _is_plain_callee(node):
            # Bare/dotted name that does not depend on the lambda's parameters.
            if isinstance(node, cst.Name):
                return node.value not in param_names
            if isinstance(node, cst.Attribute):
                return _is_plain_callee(node.value)
            return False

        same_args = [
            m.Arg(m.Name(param.name.value), star="", keyword=None)
            for param in updated_node.params.params
        ]
        if m.matches(updated_node.body, m.Call(args=same_args)):
            call = cst.ensure_type(updated_node.body, cst.Call)
            if _is_plain_callee(call.func):
                return call.func
        return updated_node

    @m.leave(
        m.BooleanOperation(
            left=m.Call(m.Name("isinstance"), [m.Arg(), m.Arg()]),
            operator=m.Or(),
            right=m.Call(m.Name("isinstance"), [m.Arg(), m.Arg()]),
        ))
    def collapse_isinstance_checks(self, _, updated_node):
        """Merge ``isinstance(x, A) or isinstance(x, B)`` into one call.

        Applies only when both calls test a deep-equal target and all four
        arguments are plain positional arguments (no ``*``/``**`` or
        keywords).

        A type argument that is already a literal tuple without comments has
        its members spliced in. So ``isinstance(x, (A, B)) or
        isinstance(x, C)`` becomes ``isinstance(x, (A, B, C))`` rather than
        ``isinstance(x, ((A, B), C))``.

        Parentheses around the ``or`` expression are carried over to the
        resulting call. Without them, code such as
        ``if(... or ...):`` / ``return(... or ...)`` would have the keyword
        directly touching the call.
        """
        left_target, left_type = updated_node.left.args
        right_target, right_type = updated_node.right.args
        plain_args = all(
            arg.star == "" and arg.keyword is None
            for arg in (left_target, left_type, right_target, right_type))
        if not plain_args or not left_target.deep_equals(right_target):
            return updated_node

        def _type_elements(arg):
            # Elements contributed to the merged tuple by one type argument.
            value = arg.value
            if (isinstance(value, cst.Tuple) and value.elements
                    and not m.findall(value, m.Comment())):
                *init, last = value.elements
                # Reset only the trailing comma; inner separators are kept.
                return [*init, last.with_changes(comma=cst.MaybeSentinel.DEFAULT)]
            return [cst.Element(value)]

        merged_type = cst.Arg(
            cst.Tuple(_type_elements(left_type) + _type_elements(right_type)))
        merged = updated_node.left.with_changes(args=[left_target, merged_type])
        if updated_node.lpar:
            merged = merged.with_changes(lpar=updated_node.lpar,
                                         rpar=updated_node.rpar)
        return merged
コード例 #23
0
 def test_zero_or_more_matcher_args_true(self) -> None:
     # Every pattern below is checked against the same call: foo(1, 2, 3).
     target = cst.Call(
         func=cst.Name("foo"),
         args=(
             cst.Arg(cst.Integer("1")),
             cst.Arg(cst.Integer("2")),
             cst.Arg(cst.Integer("3")),
         ),
     )
     # Argument patterns where ZeroOrMore carries an inner matcher.
     arg_patterns = (
         # The first argument is the integer 1; the remaining arguments
         # may be anything.
         (m.Arg(m.Integer("1")), m.ZeroOrMore(m.Arg())),
         # The first argument is the integer 1; the remaining arguments are
         # integers of any value.
         (m.Arg(m.Integer("1")), m.ZeroOrMore(m.Arg(m.Integer()))),
         # A possibly-empty run of 1s and 2s, then a 2, then anything. The
         # leading ZeroOrMore must not consume the 2 that the explicit Arg
         # needs, which verifies the matcher is not simply greedy.
         (
             m.ZeroOrMore(m.Arg(m.OneOf(m.Integer("1"), m.Integer("2")))),
             m.Arg(m.Integer("2")),
             m.ZeroOrMore(),
         ),
         # The first argument is the integer 1; the remaining arguments are
         # integers with the value 2 or 3.
         (
             m.Arg(m.Integer("1")),
             m.ZeroOrMore(m.Arg(m.OneOf(m.Integer("2"), m.Integer("3")))),
         ),
     )
     for args in arg_patterns:
         self.assertTrue(
             matches(target, m.Call(func=m.Name("foo"), args=args)))
コード例 #24
0
 def test_zero_or_more_matcher_no_args_true(self) -> None:
     # Every pattern below is checked against the same call: foo(1, 2, 3).
     target = cst.Call(
         func=cst.Name("foo"),
         args=(
             cst.Arg(cst.Integer("1")),
             cst.Arg(cst.Integer("2")),
             cst.Arg(cst.Integer("3")),
         ),
     )
     # Argument patterns that only use bare ZeroOrMore() wildcards.
     arg_patterns = (
         # The first argument is 1; any number of arguments may follow.
         (m.Arg(m.Integer("1")), m.ZeroOrMore()),
         # Some argument, anywhere, is 1.
         (m.ZeroOrMore(), m.Arg(m.Integer("1")), m.ZeroOrMore()),
         # Some argument, anywhere, is 2.
         (m.ZeroOrMore(), m.Arg(m.Integer("2")), m.ZeroOrMore()),
         # Some argument, anywhere, is 3.
         (m.ZeroOrMore(), m.Arg(m.Integer("3")), m.ZeroOrMore()),
         # The last argument is 3.
         (m.ZeroOrMore(), m.Arg(m.Integer("3"))),
         # A 1 and then a 3 appear somewhere, in that order.
         (
             m.ZeroOrMore(),
             m.Arg(m.Integer("1")),
             m.ZeroOrMore(),
             m.Arg(m.Integer("3")),
             m.ZeroOrMore(),
         ),
         # A 1, a 2 and a 3 appear somewhere, in that order.
         (
             m.ZeroOrMore(),
             m.Arg(m.Integer("1")),
             m.ZeroOrMore(),
             m.Arg(m.Integer("2")),
             m.ZeroOrMore(),
             m.Arg(m.Integer("3")),
             m.ZeroOrMore(),
         ),
     )
     for args in arg_patterns:
         self.assertTrue(
             matches(target, m.Call(func=m.Name("foo"), args=args)))
コード例 #25
0
    def visit_Call(self, node: cst.Call) -> None:
        """Report ``tuple``/``list``/``set``/``dict`` calls replaceable by literals.

        Two shapes are handled:

        * An argument-less ``tuple()``, ``list()`` or ``dict()`` is reported
          with ``UNNCESSARY_CALL`` and rewritten to ``()``, ``[]`` or ``{}``.
          ``set()`` is not matched without arguments.
        * A call whose only argument is a plain positional list or tuple
          display is reported with ``UNNECESSARY_LITERAL`` and rewritten to
          the equivalent display. For example, ``list((1, 2))`` becomes
          ``[1, 2]``, and an empty ``set([])`` becomes ``set()``.
          ``dict(...)`` is only rewritten when the display is empty or every
          element is a two-item tuple/list.

        Keyword and unpacked arguments are excluded because they mean
        something else. ``dict(a=[(1, 2)])`` is ``{"a": [(1, 2)]}`` and
        ``list(*[[1, 2]])`` is ``[1, 2]``.
        """
        if m.matches(
                node,
                m.Call(
                    func=m.Name("tuple") | m.Name("list") | m.Name("set")
                    | m.Name("dict"),
                    args=[
                        m.Arg(value=m.List() | m.Tuple(), keyword=None,
                              star="")
                    ],
                ),
        ) or m.matches(
                node,
                m.Call(func=m.Name("tuple") | m.Name("list") | m.Name("dict"),
                       args=[]),
        ):

            # Each element must be a two-item tuple/list of plain (unstarred)
            # items so that it maps directly onto one `key: value` entry.
            pairs_matcher = m.ZeroOrMore(
                m.Element(m.Tuple(elements=[m.Element(), m.Element()]))
                | m.Element(m.List(elements=[m.Element(), m.Element()])))

            exp = cst.ensure_type(node, cst.Call)
            call_name = cst.ensure_type(exp.func, cst.Name).value

            # An empty call is an "unnecessary call" that is rewritten to the
            # matching literal (set() is never matched without arguments).
            if not exp.args:
                elements = []
                message_formatter = UNNCESSARY_CALL
            else:
                arg = exp.args[0].value
                elements = cst.ensure_type(
                    arg, cst.List
                    if isinstance(arg, cst.List) else cst.Tuple).elements
                message_formatter = UNNECESSARY_LITERAL

            if call_name == "tuple":
                new_node = cst.Tuple(elements=elements)
            elif call_name == "list":
                new_node = cst.List(elements=elements)
            elif call_name == "set":
                # set() has no equivalent empty literal ({} is a dict), so an
                # empty display argument is reported as an unnecessary literal
                # and rewritten to a bare `set()` call.
                if not elements:
                    self.report(
                        node,
                        UNNECESSARY_LITERAL.format(func=call_name),
                        replacement=node.deep_replace(
                            node, cst.Call(func=cst.Name("set"))),
                    )
                    return
                new_node = cst.Set(elements=elements)
            elif not elements or m.matches(
                    exp.args[0].value,
                    m.Tuple(elements=[pairs_matcher])
                    | m.List(elements=[pairs_matcher]),
            ):
                # Every element is a two-item tuple/list: turn each into k: v.
                pairs = [
                    cst.ensure_type(
                        ele.value,
                        cst.Tuple
                        if isinstance(ele.value, cst.Tuple) else cst.List,
                    ) for ele in elements
                ]
                new_node = cst.Dict(elements=[
                    cst.DictElement(pair.elements[0].value,
                                    pair.elements[1].value) for pair in pairs
                ])
            else:
                # Unrecognized form, e.g. dict() of a display of non-pairs.
                return

            self.report(
                node,
                message_formatter.format(func=call_name),
                replacement=node.deep_replace(node, new_node),
            )