Пример #1
0
 def test_does_not_match_operator_true(self) -> None:
     """Inverted matchers (``~``) should match nodes that their operand rejects."""

     def call_with(arg_value):
         # Every call under test has the shape ``foo(<arg_value>)``.
         return cst.Call(func=cst.Name("foo"), args=(cst.Arg(arg_value),))

     cases = (
         # A single argument whose value is anything but the name None.
         (
             call_with(cst.Name("True")),
             m.Call(args=(m.Arg(value=~(m.Name("None"))),)),
         ),
         # The inversion may also wrap the whole ``Arg`` matcher.
         (
             call_with(cst.Integer("1")),
             m.Call(args=(~(m.Arg(m.Name("None"))),)),
         ),
         # An argument that is neither True nor False (inverted OneOf).
         (
             call_with(cst.Integer("1")),
             m.Call(args=(m.Arg(value=~(m.Name("True") | m.Name("False"))),)),
         ),
         # The same idea expressed as an AllOf of two inverted matchers.
         (
             call_with(cst.Name("None")),
             m.Call(
                 args=(m.Arg(value=(~(m.Name("True"))) & (~(m.Name("False")))),)
             ),
         ),
         # Roundabout check that ``|`` works with inverted operands.
         (
             call_with(cst.Name("False")),
             m.Call(
                 args=(m.Arg(value=(~(m.Name("True"))) | (~(m.Name("True")))),)
             ),
         ),
         # Roundabout check that ``~`` works on an AllOf.
         (
             call_with(cst.Integer("1")),
             m.Call(args=(m.Arg(value=~(m.Name() & m.Name("True"))),)),
         ),
         # A Name whose value does not match the regex for True.
         (cst.Name("False"), m.Name(value=~(m.MatchRegex(r"True")))),
     )
     for node, matcher in cases:
         self.assertTrue(matches(node, matcher))
Пример #2
0
    def visit_Lambda(self, node: cst.Lambda) -> None:
        """Report a lambda that only forwards its parameters to another call.

        The lambda is reported, with ``call.func`` offered as the replacement,
        when two things hold. Its parameter list must satisfy
        ``self._is_simple_parameter_spec``. Its body must be a call whose
        arguments are exactly the lambda's parameter names, in order, each passed
        positionally (no ``*``/``**``, no keyword).
        """
        forwarded_args = [
            m.Arg(value=m.Name(value=param.name.value), star="", keyword=None)
            for param in node.params.params
        ]
        forwarding_lambda = m.Lambda(
            params=m.MatchIfTrue(self._is_simple_parameter_spec),
            body=m.Call(args=forwarded_args),
        )
        if not m.matches(node, forwarding_lambda):
            return

        call = cst.ensure_type(node.body, cst.Call)
        callee_name = get_full_name_for_node(call)
        self.report(
            node,
            UNNECESSARY_LAMBDA.format(
                function="function" if callee_name is None else callee_name
            ),
            replacement=call.func,
        )
Пример #3
0
 def test_or_matcher_true(self) -> None:
     """``OneOf`` should match as soon as one of its alternatives matches."""
     true_or_false = m.OneOf(m.Name("True"), m.Name("False"))

     # The identifier True on its own.
     self.assertTrue(matches(libcst.Name("True"), true_or_false))

     # Any assignment of True or False, whatever the target is.
     assign_true = libcst.Assign(
         (libcst.AssignTarget(libcst.Name("x")),), libcst.Name("True")
     )
     self.assertTrue(matches(assign_true, m.Assign(value=true_or_false)))

     # foo(1, 2, 3) against two candidate argument sequences; the second fits.
     call = libcst.Call(
         libcst.Name("foo"),
         tuple(libcst.Arg(libcst.Integer(digit)) for digit in ("1", "2", "3")),
     )
     candidate_args = m.OneOf(
         tuple(m.Arg(m.Integer(digit)) for digit in ("3", "2", "1")),
         tuple(m.Arg(m.Integer(digit)) for digit in ("1", "2", "3")),
     )
     self.assertTrue(matches(call, m.Call(m.Name("foo"), candidate_args)))
Пример #4
0
 def test_extract_multiple(self) -> None:
     """``extract`` should return every ``SaveMatchedNode`` capture, keyed by name."""
     expression = cst.parse_expression("a + b[c], d(e, f * g)")
     pattern = m.Tuple(
         elements=[
             m.Element(
                 m.BinaryOperation(left=m.SaveMatchedNode(m.Name(), "left"))
             ),
             m.Element(m.Call(func=m.SaveMatchedNode(m.Name(), "func"))),
         ]
     )
     nodes = m.extract(expression, pattern)

     # The expected captures are the very nodes inside the parsed tuple.
     first, second = cst.ensure_type(expression, cst.Tuple).elements
     expected = {
         "left": cst.ensure_type(first.value, cst.BinaryOperation).left,
         "func": cst.ensure_type(second.value, cst.Call).func,
     }
     self.assertEqual(nodes, expected)
Пример #5
0
 def test_or_matcher_false(self) -> None:
     """``OneOf`` must not match when none of its alternatives matches."""
     true_or_false = m.OneOf(m.Name("True"), m.Name("False"))

     # None is neither True nor False.
     self.assertFalse(matches(libcst.Name("None"), true_or_false))

     # Assigning None to a target is not assigning True or False to it.
     assign_none = libcst.Assign(
         (libcst.AssignTarget(libcst.Name("x")),), libcst.Name("None")
     )
     self.assertFalse(matches(assign_none, m.Assign(value=true_or_false)))

     # foo(1, 2, 3) against two argument sequences, neither of which is 1, 2, 3.
     call = libcst.Call(
         libcst.Name("foo"),
         tuple(libcst.Arg(libcst.Integer(digit)) for digit in ("1", "2", "3")),
     )
     candidate_args = m.OneOf(
         tuple(m.Arg(m.Integer(digit)) for digit in ("3", "2", "1")),
         tuple(m.Arg(m.Integer(digit)) for digit in ("4", "5", "6")),
     )
     self.assertFalse(matches(call, m.Call(m.Name("foo"), candidate_args)))
Пример #6
0
    def collect_targets(self, stack: Tuple[cst.BaseExpression, ...]) -> Tuple[List[cst.BaseExpression], Dict[cst.BaseExpression, List[cst.BaseExpression]]]:
        """Group the ``_ISINSTANCE`` calls in *stack* that test the same target.

        A call is grouped when it meets three conditions:
        * it takes exactly two arguments;
        * its second argument is not a tuple literal;
        * ``QualifiedNameProvider`` resolves its callee to ``_ISINSTANCE``
          (presumably the builtin ``isinstance``; confirm at its definition).

        Grouping uses the call's first argument, compared with ``deep_equals``.

        Returns:
            ``(operands, targets)``.
            ``operands`` keeps every other expression of *stack* in order. It
            holds each grouped target once, at the position of that target's
            first call.
            ``targets`` maps each first-seen target node to the second
            arguments of all its calls, in order.
        """
        targets: Dict[cst.BaseExpression, List[cst.BaseExpression]] = {}
        operands: List[cst.BaseExpression] = []
        # Shape: <callee>(<anything>, <anything except a tuple literal>)
        two_arg_call = m.Call(func=m.DoNotCare(), args=[m.Arg(), m.Arg(~m.Tuple())])

        for operand in stack:
            if not m.matches(operand, two_arg_call):
                operands.append(operand)
                continue

            call = cst.ensure_type(operand, cst.Call)
            if not QualifiedNameProvider.has_name(self, call, _ISINSTANCE):
                operands.append(operand)
                continue

            target, match = call.args[0].value, call.args[1].value
            # Existing keys are compared structurally, not via dict lookup.
            known = next(
                (seen for seen in targets if target.deep_equals(seen)), None
            )
            if known is None:
                operands.append(target)
                targets[target] = [match]
            else:
                targets[known].append(match)

        return operands, targets
Пример #7
0
class ToolbarAddToolCommand(VisitorBasedCodemodCommand):
    """Rewrite ``<expr>.DoAddTool(...)`` calls into ``<expr>.AddTool(...)``.

    A call is rewritten only when it passes at least one keyword listed in
    ``args_map``. Those keywords are renamed to their mapped names.
    """

    DESCRIPTION: str = "Transforms wx.Toolbar.DoAddTool method into AddTool"

    # Old keyword name -> new keyword name.
    args_map = {"id": "toolId"}
    # One matcher per old keyword, mapped to that keyword's new name.
    args_matchers_map = {
        matchers.Arg(keyword=matchers.Name(value=value)): renamed
        for value, renamed in args_map.items()
    }
    # ``<expr>.DoAddTool(...)`` that uses at least one keyword from ``args_map``.
    call_matcher = matchers.Call(
        func=matchers.Attribute(attr=matchers.Name(value="DoAddTool")),
        args=matchers.MatchIfTrue(lambda args: bool(
            set(arg.keyword.value for arg in args if arg and arg.keyword).
            intersection(ToolbarAddToolCommand.args_map.keys()))),
    )

    def leave_Call(self, original_node: cst.Call,
                   updated_node: cst.Call) -> cst.Call:
        """Rename the method and its mapped keywords; return other calls unchanged."""
        if not matchers.matches(updated_node, self.call_matcher):
            return updated_node

        # Update method's call: ``x.DoAddTool(...)`` -> ``x.AddTool(...)``.
        new_func = updated_node.func.with_changes(attr=cst.Name(value="AddTool"))

        # Transform keywords. Each argument is renamed at most once, based on its
        # current keyword, so one rename can never feed into another. The node is
        # rebuilt once, after all arguments are processed.
        new_args = [self._rename_keyword(arg) for arg in updated_node.args]
        return updated_node.with_changes(func=new_func, args=new_args)

    def _rename_keyword(self, arg: cst.Arg) -> cst.Arg:
        """Return *arg* with its keyword renamed by the first matching entry, if any."""
        for arg_matcher, renamed in self.args_matchers_map.items():
            if matchers.matches(arg, arg_matcher):
                return arg.with_changes(keyword=cst.Name(value=renamed))
        return arg
Пример #8
0
    def __extract_assign_newtype(self, node: cst.Assign) -> None:
        """
        Attempts extracting a NewType declaration from the provided Assign node.

        ``node`` matches when it has one of these shapes:
        * ``Name = NewType("...", ...)``
        * ``Name = <expr>.NewType("...", ...)``, e.g. ``typing.NewType``

        On a match, the assigned name is appended to ``self.class_defs``. Any
        other node is ignored.
        """
        matcher_newtype = match.Assign(
            # Exactly one assignment target, which must be a plain name. The
            # (non-empty) name string is captured under the key "type".
            targets=[
                match.AssignTarget(
                    target=match.Name(
                        value=match.SaveMatchedNode(match.MatchRegex(r'(.)+'), "type")
                    )
                )
            ],
            # The assigned value must be a call to ``NewType`` or ``<expr>.NewType``.
            value=match.Call(
                func=match.Name(value="NewType")
                | match.Attribute(attr=match.Name(value="NewType")),
                args=[
                    # First argument: the type's name as a simple string literal.
                    match.Arg(value=match.SimpleString()),
                    # Any remaining arguments (NewType's base type).
                    match.ZeroOrMore(),
                ],
            ),
        )

        extracted_type = match.extract(node, matcher_newtype)

        if extracted_type is not None:
            # Append the additional type to the list
            # TODO: Either rename class defs, or create new list for additional types
            self.class_defs.append(extracted_type["type"].strip("\'"))
Пример #9
0
    def leave_Attribute(self, original_node: cst.Attribute,
                        updated_node: cst.Attribute):
        """Obfuscate both sides of a ``<value>.<attr>`` access.

        The method does three things:
        * pops ``self.attribute_stack``;
        * passes the attribute name to ``self.obf_var``, whose result replaces
          the node;
        * if the object part is a plain name that ``self.can_rename`` accepts,
          substitutes ``self.get_new_cst_name`` for it.
        """
        self.attribute_stack.pop()

        # For ``tail.head``: ``tail`` is the object expression, ``head`` the attribute.
        tail = updated_node.value
        head = updated_node.attr

        # Obfuscate the method/field name. libcst always stores
        # ``Attribute.attr`` as a ``Name``, so no other node type can appear
        # here (the former ``Call`` branch was unreachable).
        if m.matches(head, m.Name()):
            updated_node = self.obf_var(head, updated_node)

        # Obfuscate the object name when it is a plain identifier.
        if m.matches(tail, m.Name()):
            tail = cst.ensure_type(tail, cst.Name)
            if self.can_rename(tail.value, 'v', 'a', 'ca'):
                updated_node = updated_node.with_changes(
                    value=self.get_new_cst_name(tail.value))

        return updated_node
Пример #10
0
class HttpRequestXReadLinesTransformer(BaseDjCodemodTransformer):
    """Replace `HttpRequest.xreadlines()` by iterating over the request."""

    deprecated_in = DJANGO_2_0
    removed_in = DJANGO_3_0

    # Deliberately conservative. Only `.xreadlines()` calls on these are rewritten:
    #   * a bare name `request` or `req`;
    #   * an attribute named `request` or `req` (e.g. `self.request`, `view.req`).
    matcher = m.Call(
        func=m.Attribute(
            value=(
                m.Name(value="request")
                | m.Name(value="req")
                | m.Attribute(attr=m.Name(value="request"))
                | m.Attribute(attr=m.Name(value="req"))
            ),
            attr=m.Name(value="xreadlines"),
        )
    )

    def leave_Call(self, original_node: Call,
                   updated_node: Call) -> BaseExpression:
        # Anything that is not `<request>.xreadlines()` goes to the base class.
        if not m.matches(updated_node, self.matcher):
            return super().leave_Call(original_node, updated_node)
        # `request.xreadlines()` -> `request`: the request itself is iterable.
        return updated_node.func.value
Пример #11
0
 def test_at_least_n_matcher_args_false(self) -> None:
     """``AtLeastN`` must not match when too few trailing arguments satisfy it."""
     # Every pattern below is matched against the call ``foo(1, 2, 3)``.
     call = cst.Call(
         func=cst.Name("foo"),
         args=(
             cst.Arg(cst.Integer("1")),
             cst.Arg(cst.Integer("2")),
             cst.Arg(cst.Integer("3")),
         ),
     )
     arg_patterns = (
         # The integer 1 followed by at least two string arguments. The
         # trailing arguments are integers, so this must fail.
         (
             m.Arg(m.Integer("1")),
             m.AtLeastN(m.Arg(m.SimpleString()), n=2),
         ),
         # The integer 1 followed by at least three wildcard arguments. Only
         # two arguments follow, so this must fail.
         (m.Arg(m.Integer("1")), m.AtLeastN(m.Arg(), n=3)),
         # The integer 1 followed by at least two arguments that are the
         # integer 2. The trailing arguments are 2 and 3, so this must fail.
         (
             m.Arg(m.Integer("1")),
             m.AtLeastN(m.Arg(m.Integer("2")), n=2),
         ),
     )
     for index, args in enumerate(arg_patterns):
         self.assertFalse(
             matches(call, m.Call(func=m.Name("foo"), args=args)),
             f"pattern #{index} unexpectedly matched foo(1, 2, 3)",
         )
Пример #12
0
 def test_zero_or_more_matcher_args_true(self) -> None:
     """``ZeroOrMore`` should absorb a variable number of matching arguments."""
     # Every pattern below is matched against the call ``foo(1, 2, 3)``.
     call = cst.Call(
         func=cst.Name("foo"),
         args=(
             cst.Arg(cst.Integer("1")),
             cst.Arg(cst.Integer("2")),
             cst.Arg(cst.Integer("3")),
         ),
     )
     arg_patterns = (
         # The integer 1, then any number of wildcard arguments.
         (m.Arg(m.Integer("1")), m.ZeroOrMore(m.Arg())),
         # The integer 1, then any number of integers of any value.
         (m.Arg(m.Integer("1")), m.ZeroOrMore(m.Arg(m.Integer()))),
         # Any number of 1s or 2s, then a 2, then anything. The leading
         # ZeroOrMore must not swallow the 2 that the explicit ``Arg`` needs,
         # which verifies non-greedy behavior in the matcher.
         (
             m.ZeroOrMore(m.Arg(m.OneOf(m.Integer("1"), m.Integer("2")))),
             m.Arg(m.Integer("2")),
             m.ZeroOrMore(),
         ),
         # The integer 1, then any number of integers whose value is 2 or 3.
         (
             m.Arg(m.Integer("1")),
             m.ZeroOrMore(m.Arg(m.OneOf(m.Integer("2"), m.Integer("3")))),
         ),
     )
     for index, args in enumerate(arg_patterns):
         self.assertTrue(
             matches(call, m.Call(func=m.Name("foo"), args=args)),
             f"pattern #{index} did not match foo(1, 2, 3)",
         )
Пример #13
0
 def test_at_least_n_matcher_no_args_true(self) -> None:
     """``AtLeastN`` without an inner matcher should act as a counted wildcard."""
     # Every pattern below is matched against the call ``foo(1, 2, 3)``.
     call = cst.Call(
         func=cst.Name("foo"),
         args=(
             cst.Arg(cst.Integer("1")),
             cst.Arg(cst.Integer("2")),
             cst.Arg(cst.Integer("3")),
         ),
     )
     arg_patterns = (
         # At least one argument.
         (m.AtLeastN(n=1),),
         # At least two arguments.
         (m.AtLeastN(n=2),),
         # At least three arguments.
         (m.AtLeastN(n=3),),
         # At least two arguments, the first one being the integer 1.
         (m.Arg(m.Integer("1")), m.AtLeastN(n=1)),
         # At least three arguments, the first one being the integer 1.
         (m.Arg(m.Integer("1")), m.AtLeastN(n=2)),
         # At least three arguments. One argument must be the integer 2, with
         # at least one argument before it and at least one after it.
         (m.AtLeastN(n=1), m.Arg(m.Integer("2")), m.AtLeastN(n=1)),
         # At least two arguments, the last one being the integer 3.
         (m.AtLeastN(n=1), m.Arg(m.Integer("3"))),
         # At least three arguments, the last one being the integer 3.
         (m.AtLeastN(n=2), m.Arg(m.Integer("3"))),
     )
     for index, args in enumerate(arg_patterns):
         self.assertTrue(
             matches(call, m.Call(func=m.Name("foo"), args=args)),
             f"pattern #{index} did not match foo(1, 2, 3)",
         )
Пример #14
0
class Checker(m.MatcherDecoratableVisitor):
    """Report the constructs this checker flags in one module.

    Three constructs are flagged:
    * division without ``from __future__ import division``;
    * uses of ``sys.maxint``;
    * calls to ``assertEquals`` / ``assertItemsEqual``.

    Each finding is printed as ``<path>:<line>:<column>: <message>`` and sets
    ``self.errors``. A check is skipped when its key is in ``ignored``. The keys
    are "division", "sys.maxint", or the assert method's name.
    """

    METADATA_DEPENDENCIES = (PositionProvider,)

    def __init__(
        self, path: Path, verbose: bool = False, ignored: Optional[List[str]] = None
    ):
        super().__init__()
        self.path = path
        self.verbose = verbose
        self.ignored = set(ignored or [])
        # Set once ``from __future__ import division`` has been visited.
        self.future_division = False
        # Set as soon as any finding is printed.
        self.errors = False
        # Names of the enclosing classes/functions, innermost last.
        self.stack: List[str] = []

    def _report(self, node, message: str) -> None:
        # Print one finding at ``node``'s start position and flag the run as failed.
        start = self.get_metadata(PositionProvider, node).start
        print(f"{self.path}:{start.line}:{start.column}: {message}")
        self.errors = True

    @m.call_if_inside(m.ImportFrom(module=m.Name("__future__")))
    @m.visit(m.ImportAlias(name=m.Name("division")))
    def import_div(self, node: ImportAlias) -> None:
        self.future_division = True

    @m.visit(m.BinaryOperation(operator=m.Divide()))
    def check_div(self, node: BinaryOperation) -> None:
        if "division" in self.ignored or self.future_division:
            return
        self._report(node, "division without `from __future__ import division`")

    @m.visit(m.Attribute(attr=m.Name("maxint"), value=m.Name("sys")))
    def check_maxint(self, node: Attribute) -> None:
        if "sys.maxint" not in self.ignored:
            self._report(node, "use of sys.maxint")

    def visit_ClassDef(self, node: ClassDef) -> None:
        self.stack.append(node.name.value)

    def leave_ClassDef(self, node: ClassDef) -> None:
        self.stack.pop()

    def visit_FunctionDef(self, node: FunctionDef) -> None:
        self.stack.append(node.name.value)

    def leave_FunctionDef(self, node: FunctionDef) -> None:
        self.stack.pop()

    def visit_ClassDef_bases(self, node: "ClassDef") -> None:
        # Intentionally does nothing; its purpose is not evident from this class.
        return None

    @m.visit(
        m.Call(
            func=m.Attribute(
                attr=m.OneOf(m.Name("assertEquals"), m.Name("assertItemsEqual"))
            )
        )
    )
    def visit_old_assert(self, node: Call) -> None:
        method_name = ensure_type(node.func, Attribute).attr.value
        if method_name not in self.ignored:
            self._report(node, f"use of {method_name}")
Пример #15
0
from typing import List, Optional, Set, Tuple, Union

import libcst as cst
from libcst import matchers as m

from tornado_async_transformer.helpers import (
    name_attr_possibilities,
    some_version_of,
    with_added_imports,
)

# --- Matchers -------------------------------------------------------------
# Dotted names are expanded by ``some_version_of`` (defined in
# tornado_async_transformer.helpers); see it for the spellings each one covers.

# ``raise gen.Return`` as a bare class, as a call with arguments, or as any call.
gen_return_statement_matcher = m.Raise(exc=some_version_of("tornado.gen.Return"))
gen_return_call_with_args_matcher = m.Raise(
    exc=m.Call(
        func=some_version_of("tornado.gen.Return"),
        args=[m.AtLeastN(n=1)],
    )
)
gen_return_call_matcher = m.Raise(
    exc=m.Call(func=some_version_of("tornado.gen.Return"))
)
gen_return_matcher = m.OneOf(gen_return_statement_matcher, gen_return_call_matcher)

# Calls to ``gen.sleep(...)`` and ``gen.Task(...)``.
gen_sleep_matcher = m.Call(func=some_version_of("gen.sleep"))
gen_task_matcher = m.Call(func=some_version_of("gen.Task"))

# The ``tornado.gen.coroutine`` and ``tornado.testing.gen_test`` decorators.
gen_coroutine_decorator_matcher = m.Decorator(
    decorator=some_version_of("tornado.gen.coroutine")
)
gen_test_coroutine_decorator = m.Decorator(
    decorator=some_version_of("tornado.testing.gen_test")
)
coroutine_decorator_matcher = m.OneOf(
    gen_coroutine_decorator_matcher, gen_test_coroutine_decorator
)
coroutine_matcher = m.FunctionDef(
    asynchronous=None,
    decorators=[m.ZeroOrMore(), coroutine_decorator_matcher,
                m.ZeroOrMore()],
Пример #16
0
def is_foreign_key(node: Call) -> bool:
    """Return True if *node* calls an attribute named ``ForeignKey``.

    ``models.ForeignKey(...)`` matches. A bare ``ForeignKey(...)`` call does not,
    because its callee is a plain name rather than an attribute.
    """
    foreign_key_call = m.Call(func=m.Attribute(attr=m.Name("ForeignKey")))
    return m.matches(node, foreign_key_call)
Пример #17
0
 def leave_Call(self, original_node: Call,
                updated_node: Call) -> BaseExpression:
     """Rename calls to ``url(...)`` into ``re_path(...)``.

     Only the callee name changes. The arguments, parentheses and whitespace of
     ``updated_node`` are preserved. Every other call is delegated to the base
     class.
     """
     if m.matches(updated_node, m.Call(func=m.Name("url"))):
         # ``with_changes`` keeps lpar/rpar and whitespace. Building a fresh
         # ``Call``/``Name`` would silently reset them to their defaults.
         return updated_node.with_changes(
             func=updated_node.func.with_changes(value="re_path"))
     return super().leave_Call(original_node, updated_node)
Пример #18
0
class Modernizer(m.MatcherDecoratableTransformer):
    """Report Python 2 idioms and rewrite some of them for ``future``/``builtins``.

    Rewrites (each one conditional; see the individual ``fix_*`` methods):

    * ``filter``/``map``/``zip``/``range`` calls whose name was not imported
      from ``builtins`` are wrapped in ``list(...)``. The same applies to
      ``.keys()``/``.values()``/``.items()`` calls. Neither is wrapped inside
      ``list``/``set``/``tuple``/``.join`` calls, comprehensions or ``for`` loops.
    * ``xrange``/``raw_input`` become ``range``/``input``.
    * ``d.iterkeys()``/``d.itervalues()``/``d.iteritems()`` become ``iterkeys(d)``
      and so on.
    * The name ``unicode`` becomes ``text_type``.

    The names the rewritten code needs are collected during traversal.
    ``leave_Module`` then adds them to, or creates, the
    ``from builtins import ...`` and ``from future.utils import ...``
    statements. With ``verbose``, parameter and variable type comments are
    printed.
    """

    METADATA_DEPENDENCIES = (PositionProvider,)
    # FIXME use a stack of e.g. SimpleStatementLine then proper visit_Import/ImportFrom to store the ssl node

    def __init__(
        self, path: Path, verbose: bool = False, ignored: Optional[List[str]] = None
    ) -> None:
        super().__init__()
        self.path = path
        self.verbose = verbose
        self.ignored = set(ignored or [])
        self.errors = False
        self.stack: List[Tuple[str, ...]] = []
        self.annotations: Dict[
            Tuple[str, ...], Comment  # key: tuple of canonical variable name
        ] = {}
        self.python_future_updated_node: Optional[SimpleStatementLine] = None
        self.python_future_imports: Dict[str, str] = {}
        self.python_future_new_imports: Set[str] = set()
        # imported name -> local alias, e.g. {"map": "map"}; "*" for a star import
        self.builtins_imports: Dict[str, str] = {}
        self.builtins_new_imports: Set[str] = set()
        self.builtins_updated_node: Optional[SimpleStatementLine] = None
        self.future_utils_imports: Dict[str, str] = {}
        self.future_utils_new_imports: Set[str] = set()
        self.future_utils_updated_node: Optional[SimpleStatementLine] = None
        # self.last_import_node: Optional[CSTNode] = None
        # Statement after which new import lines are inserted: the last import
        # line seen while visiting, or the last import line we inserted.
        self.last_import_node_stmt: Optional[CSTNode] = None

    # @m.call_if_inside(m.ImportFrom(module=m.Name("__future__")))
    # @m.visit(m.ImportAlias() | m.ImportStar())
    # def import_python_future_check(self, node: Union[ImportAlias, ImportStar]) -> None:
    #     self.add_import(self.python_future_imports, node)

    # @m.leave(m.ImportFrom(module=m.Name("__future__")))
    # def import_python_future_modify(
    #     self, original_node: ImportFrom, updated_node: ImportFrom
    # ) -> Union[BaseSmallStatement, RemovalSentinel]:
    #     return updated_node

    @m.call_if_inside(m.ImportFrom(module=m.Name("builtins")))
    @m.visit(m.ImportAlias() | m.ImportStar())
    def import_builtins_check(self, node: Union[ImportAlias, ImportStar]) -> None:
        """Record a name imported by ``from builtins import ...``."""
        self.add_import(self.builtins_imports, node)

    # @m.leave(m.ImportFrom(module=m.Name("builtins")))
    # def builtins_modify(
    #     self, original_node: ImportFrom, updated_node: ImportFrom
    # ) -> Union[BaseSmallStatement, RemovalSentinel]:
    #     return updated_node

    @m.call_if_inside(
        m.ImportFrom(module=m.Attribute(value=m.Name("future"), attr=m.Name("utils")))
    )
    @m.visit(m.ImportAlias() | m.ImportStar())
    def import_future_utils_check(self, node: Union[ImportAlias, ImportStar]) -> None:
        """Record a name imported by ``from future.utils import ...``."""
        self.add_import(self.future_utils_imports, node)

    # @m.leave(
    #     m.ImportFrom(module=m.Attribute(value=m.Name("future"), attr=m.Name("utils")))
    # )
    # def future_utils_modify(
    #     self, original_node: ImportFrom, updated_node: ImportFrom
    # ) -> Union[BaseSmallStatement, RemovalSentinel]:
    #     return updated_node

    @staticmethod
    def add_import(
        imports: Dict[str, str], node: Union[ImportAlias, ImportStar]
    ) -> None:
        """Store ``name -> alias`` for an import alias (``"*" -> "*"`` for a star import)."""
        if isinstance(node, ImportAlias):
            imports[node.name.value] = (
                node.asname.name.value if node.asname else node.name.value
            )
        else:
            imports["*"] = "*"

    @staticmethod
    def _parse_type_comment(comment: str) -> Optional[str]:
        """Return the annotation of a ``# type: ...`` comment, or None.

        The annotation runs up to the next ``#`` (for example a trailing
        ``# noqa``) or to the end of the comment. Annotations containing
        spaces, such as ``Dict[str, int]``, are therefore kept whole.
        """
        mo = re.match(r"#\s*type:\s*([^#]*?)\s*(?:#|$)", comment)
        return (mo.group(1) or None) if mo else None

    # @m.call_if_not_inside(m.BaseCompoundStatement())
    # def visit_Import(self, node: Import) -> Optional[bool]:
    #     self.last_import_node = node
    #     return None

    # @m.call_if_not_inside(m.BaseCompoundStatement())
    # def visit_ImportFrom(self, node: ImportFrom) -> Optional[bool]:
    #     self.last_import_node = node
    #     return None

    @m.call_if_not_inside(m.ClassDef() | m.FunctionDef() | m.If())
    def visit_SimpleStatementLine(self, node: SimpleStatementLine) -> Optional[bool]:
        """Remember the most recent import line outside classes, functions and ``if``s."""
        for n in node.body:
            if m.matches(n, m.Import() | m.ImportFrom()):
                self.last_import_node_stmt = node
        return None

    @m.call_if_not_inside(m.ClassDef() | m.FunctionDef() | m.If())
    def leave_SimpleStatementLine(
        self, original_node: SimpleStatementLine, updated_node: SimpleStatementLine
    ) -> Union[BaseStatement, RemovalSentinel]:
        """Remember the existing ``__future__``/``builtins``/``future.utils`` import lines."""
        for n in updated_node.body:
            if m.matches(n, m.ImportFrom(module=m.Name("__future__"))):
                self.python_future_updated_node = updated_node
            elif m.matches(n, m.ImportFrom(module=m.Name("builtins"))):
                self.builtins_updated_node = updated_node
            elif m.matches(
                n,
                m.ImportFrom(
                    module=m.Attribute(value=m.Name("future"), attr=m.Name("utils"))
                ),
            ):
                self.future_utils_updated_node = updated_node
        return updated_node

    # @m.visit(
    #     m.AllOf(
    #         m.SimpleStatementLine(),
    #         m.MatchIfTrue(
    #             lambda node: any(m.matches(c, m.Assign()) for c in node.children)
    #         ),
    #         m.MatchIfTrue(
    #             lambda node: "# type:" in node.trailing_whitespace.comment.value
    #         ),
    #     )
    # )
    # def visit_assign(self, node: SimpleStatementSuite) -> None:
    #     return None

    def visit_Param(self, node: Param) -> Optional[bool]:
        """In verbose mode, print the parameter with the type from its ``# type:`` comment."""
        parse_type_comment = self._parse_type_comment

        class Visitor(m.MatcherDecoratableVisitor):
            def __init__(self):
                super().__init__()
                self.ptype: Optional[str] = None

            def visit_TrailingWhitespace_comment(
                self, node: "TrailingWhitespace"
            ) -> None:
                if node.comment and "type:" in node.comment.value:
                    self.ptype = parse_type_comment(node.comment.value)
                return None

        v = Visitor()
        node.visit(v)
        if self.verbose:
            pos = self.get_metadata(PositionProvider, node).start
            print(
                f"{self.path}:{pos.line}:{pos.column}: parameter {node.name.value}: {v.ptype or 'unknown type'}"
            )
        return None

    @m.visit(m.SimpleStatementLine())
    def visit_simple_stmt(self, node: SimpleStatementLine) -> None:
        """In verbose mode, print every name assigned on this line with its ``# type:``."""
        assign = None
        for c in node.body:
            if m.matches(c, m.Assign()):
                assign = ensure_type(c, Assign)
        if assign is None or not self.verbose:
            return

        # The statement's type comment trails the logical line. The old code
        # tested a matcher *object* (always truthy) and stored the parsed
        # type in a local, so the type was never reported.
        comment = node.trailing_whitespace.comment
        vtype = None
        if comment and "type:" in comment.value:
            vtype = self._parse_type_comment(comment.value)

        class NameVisitor(m.MatcherDecoratableVisitor):
            def __init__(self):
                super().__init__()
                self.names: List[str] = []

            def visit_Name(self, node: Name) -> Optional[bool]:
                self.names.append(node.value)
                return None

        pos = self.get_metadata(PositionProvider, node).start
        for target in assign.targets:
            v = NameVisitor()
            target.visit(v)
            for name in v.names:
                print(
                    f"{self.path}:{pos.line}:{pos.column}: variable {name}: {vtype or 'unknown type'}"
                )

    def visit_FunctionDef_body(self, node: FunctionDef) -> None:
        # NOTE(review): placeholder. The nested visitor does not record anything
        # yet (see FIXME/TODO), so this traversal currently has no effect.
        class Visitor(m.MatcherDecoratableVisitor):
            def __init__(self):
                super().__init__()

            def visit_EmptyLine_comment(self, node: "EmptyLine") -> None:
                # FIXME too many matches on test_param_02
                if not node.comment:
                    return
                # TODO: use comment.value
                return None

        v = Visitor()
        node.visit(v)
        return None

    map_matcher = m.Call(
        func=m.Name("filter") | m.Name("map") | m.Name("zip") | m.Name("range")
    )

    @m.visit(map_matcher)
    def visit_map(self, node: Call) -> None:
        """Request ``from builtins import <name>`` for filter/map/zip/range if missing."""
        func_name = ensure_type(node.func, Name).value
        if func_name not in self.builtins_imports:
            self.builtins_new_imports.add(func_name)

    @m.call_if_not_inside(
        m.Call(
            func=m.Name("list")
            | m.Name("set")
            | m.Name("tuple")
            | m.Attribute(attr=m.Name("join"))
        )
        | m.CompFor()
        | m.For()
    )
    @m.leave(map_matcher)
    def fix_map(self, original_node: Call, updated_node: Call) -> BaseExpression:
        """Wrap the call in ``list(...)`` unless its name is already imported from builtins."""
        # TODO test with CompFor etc.
        # TODO improve join test
        func_name = ensure_type(updated_node.func, Name).value
        if func_name not in self.builtins_imports:
            updated_node = Call(func=Name("list"), args=[Arg(updated_node)])
        return updated_node

    @m.visit(m.Call(func=m.Name("xrange") | m.Name("raw_input")))
    def visit_xrange(self, node: Call) -> None:
        """Request ``from builtins import range/input`` for xrange/raw_input if missing."""
        orig_func_name = ensure_type(node.func, Name).value
        func_name = "range" if orig_func_name == "xrange" else "input"
        if func_name not in self.builtins_imports:
            self.builtins_new_imports.add(func_name)

    @m.leave(m.Call(func=m.Name("xrange") | m.Name("raw_input")))
    def fix_xrange(self, original_node: Call, updated_node: Call) -> BaseExpression:
        """Rename ``xrange``/``raw_input`` calls to ``range``/``input``."""
        orig_func_name = ensure_type(updated_node.func, Name).value
        func_name = "range" if orig_func_name == "xrange" else "input"
        return updated_node.with_changes(func=Name(func_name))

    iter_matcher = m.Call(
        func=m.Attribute(
            attr=m.Name("iterkeys") | m.Name("itervalues") | m.Name("iteritems")
        )
    )

    @m.visit(iter_matcher)
    def visit_iter(self, node: Call) -> None:
        """Request ``from future.utils import iterkeys/itervalues/iteritems`` if missing."""
        func_name = ensure_type(node.func, Attribute).attr.value
        if func_name not in self.future_utils_imports:
            self.future_utils_new_imports.add(func_name)

    @m.leave(iter_matcher)
    def fix_iter(self, original_node: Call, updated_node: Call) -> BaseExpression:
        """Rewrite ``d.iterX()`` as ``iterX(d)``."""
        attribute = ensure_type(updated_node.func, Attribute)
        func_name = attribute.attr
        dict_name = attribute.value
        return updated_node.with_changes(func=func_name, args=[Arg(dict_name)])

    not_iter_matcher = m.Call(
        func=m.Attribute(attr=m.Name("keys") | m.Name("values") | m.Name("items"))
    )

    @m.call_if_not_inside(
        m.Call(
            func=m.Name("list")
            | m.Name("set")
            | m.Name("tuple")
            | m.Attribute(attr=m.Name("join"))
        )
        | m.CompFor()
        | m.For()
    )
    @m.leave(not_iter_matcher)
    def fix_not_iter(self, original_node: Call, updated_node: Call) -> BaseExpression:
        """Wrap a ``.keys()``/``.values()``/``.items()`` call in ``list(...)``."""
        updated_node = Call(func=Name("list"), args=[Arg(updated_node)])
        return updated_node

    @m.call_if_not_inside(m.Import() | m.ImportFrom())
    @m.leave(m.Name(value="unicode"))
    def fix_unicode(self, original_node: Name, updated_node: Name) -> BaseExpression:
        """Rename ``unicode`` to ``text_type`` and request its ``future.utils`` import."""
        # NOTE(review): every Name node with this value matches, including the
        # attribute part of ``obj.unicode`` and keywords such as ``f(unicode=1)``.
        value = "text_type"
        if value not in self.future_utils_imports:
            self.future_utils_new_imports.add(value)
        return updated_node.with_changes(value=value)

    def leave_Module(self, original_node: Module, updated_node: Module) -> Module:
        """Add the collected ``builtins`` and ``future.utils`` imports to the module."""
        updated_node = self.update_imports(
            original_node,
            updated_node,
            "builtins",
            self.builtins_updated_node,
            self.builtins_imports,
            self.builtins_new_imports,
            True,
        )
        updated_node = self.update_imports(
            original_node,
            updated_node,
            "future.utils",
            self.future_utils_updated_node,
            self.future_utils_imports,
            self.future_utils_new_imports,
            False,
        )
        return updated_node

    def _import_insertion_index(
        self, original_module: Module, updated_module: Module
    ) -> int:
        """Return the index in ``updated_module.body`` at which a new import line goes.

        The index is just after ``self.last_import_node_stmt``. With no recorded
        import, it is the top of the module, after the module docstring if any.
        """
        target = self.last_import_node_stmt
        if target is None:
            # Keep a module docstring first so that it remains the docstring.
            return 1 if updated_module.get_docstring(clean=False) is not None else 0
        # A line inserted by an earlier update_imports call lives in updated_module.
        for index, stmt in enumerate(updated_module.body):
            if stmt is target:
                return index + 1
        # Otherwise target is an original top-level statement recorded while
        # visiting. Until something is inserted, the original and updated
        # bodies stay position-aligned.
        last_index = -1
        for last_index, original in enumerate(
            original_module.body[: len(updated_module.body)]
        ):
            if original is target:
                return last_index + 1
        # NOTE(review): the recorded import is not a top-level statement (e.g. it
        # is nested in ``try``). As before, fall back to after the last aligned
        # statement.
        return last_index + 1

    def update_imports(
        self,
        original_module: Module,
        updated_module: Module,
        import_name: str,
        updated_import_node: Optional[SimpleStatementLine],
        current_imports: Dict[str, str],
        new_imports: Set[str],
        noqa: bool,
    ) -> Module:
        """Make ``from <import_name> import ...`` also import ``new_imports``.

        If the module has no such line (``updated_import_node`` is None), a new
        one is inserted. It goes after the last recorded import; if there is
        none, it goes at the top, followed by two blank lines. Otherwise the
        existing line is replaced by one importing the union of old and new
        names, unless the line is a ``*`` import. ``noqa`` appends
        ``  # noqa`` to the import line.
        """
        if not new_imports:
            return updated_module
        noqa_comment = "  # noqa" if noqa else ""
        if updated_import_node is None:
            blank_lines = "\n\n" if self.last_import_node_stmt is None else ""
            insert_at = self._import_insertion_index(original_module, updated_module)
            parsed = parse_module(
                f"from {import_name} import {', '.join(sorted(new_imports))}{noqa_comment}\n{blank_lines}",
                config=updated_module.config_for_parsing,
            )
            new_nodes = list(parsed.children)
            # Track the inserted line itself. The old code stored the parsed
            # Module, which is never found in the body, so a second insertion
            # (leave_Module runs once per import source) landed at the end.
            self.last_import_node_stmt = parsed.body[0]
            body = list(updated_module.body)
            return updated_module.with_changes(
                body=body[:insert_at] + new_nodes + body[insert_at:]
            )
        if "*" in current_imports:
            return updated_module
        current_imports_set = {
            k if k == v else f"{k} as {v}" for k, v in current_imports.items()
        }
        stmt = parse_statement(
            f"from {import_name} import {', '.join(sorted(new_imports | current_imports_set))}{noqa_comment}",
            config=updated_module.config_for_parsing,
        )
        # Preserve the comments and blank lines above the line being replaced.
        stmt = stmt.with_changes(leading_lines=updated_import_node.leading_lines)
        return updated_module.deep_replace(updated_import_node, stmt)
Пример #19
0
    def visit_Call(self, node: cst.Call) -> None:
        """Report ``tuple``/``list``/``set``/``dict`` calls that could be literals.

        Two shapes are handled:

        * A call whose only argument is a positional list/tuple literal, e.g.
          ``list([1, 2])`` -> ``[1, 2]``. This is reported with
          ``UNNECESSARY_LITERAL``. ``set([])``/``set(())`` is suggested as
          ``set()``, because there is no empty-set literal.
        * An argument-less ``tuple()``/``list()``/``dict()`` call, rewritten to
          the empty literal and reported with ``UNNCESSARY_CALL``.

        A ``dict(...)`` call is only rewritten when every element of its
        argument is a two-item tuple/list, i.e. a key/value pair.
        """
        # Only a single positional, unstarred argument may be replaced by its
        # contents. ``dict(a=[])`` is ``{"a": []}`` and ``list(*[x])`` is
        # ``list(x)``; neither equals the literal argument.
        literal_arg = m.Arg(value=m.List() | m.Tuple(), keyword=None, star="")
        literal_call = m.Call(
            func=m.Name("tuple") | m.Name("list") | m.Name("set")
            | m.Name("dict"),
            args=[literal_arg],
        )
        empty_call = m.Call(
            func=m.Name("tuple") | m.Name("list") | m.Name("dict"), args=[])
        if not (m.matches(node, literal_call) or m.matches(node, empty_call)):
            return

        # Key/value pairs: two plain (unstarred) items in a tuple or a list.
        pairs_matcher = m.ZeroOrMore(
            m.Element(m.Tuple(elements=[m.Element(), m.Element()]))
            | m.Element(m.List(elements=[m.Element(), m.Element()])))

        call_name = cst.ensure_type(node.func, cst.Name).value

        # An empty call is an "unnecessary call", rewritten to the empty
        # literal. set() never reaches here in that shape: it has no literal.
        if not node.args:
            elements = []
            message_formatter = UNNCESSARY_CALL
        else:
            arg = node.args[0].value
            elements = cst.ensure_type(
                arg, cst.List if isinstance(arg, cst.List) else cst.Tuple).elements
            message_formatter = UNNECESSARY_LITERAL

        if call_name == "tuple":
            new_node = cst.Tuple(elements=elements)
        elif call_name == "list":
            new_node = cst.List(elements=elements)
        elif call_name == "set":
            # set() has no equivalent literal. An empty argument is therefore
            # reported as an unnecessary literal with set() as the replacement.
            if len(elements) == 0:
                self.report(
                    node,
                    UNNECESSARY_LITERAL.format(func=call_name),
                    replacement=node.deep_replace(
                        node, cst.Call(func=cst.Name("set"))),
                )
                return
            new_node = cst.Set(elements=elements)
        elif len(elements) == 0 or m.matches(
                node.args[0].value,
                m.Tuple(elements=[pairs_matcher])
                | m.List(elements=[pairs_matcher]),
        ):
            # Each element wraps a two-item Tuple/List (guaranteed by the matcher).
            new_node = cst.Dict(elements=[
                cst.DictElement(pair.value.elements[0].value,
                                pair.value.elements[1].value)
                for pair in elements
            ])
        else:
            # Unrecognized form.
            return

        self.report(
            node,
            message_formatter.format(func=call_name),
            replacement=node.deep_replace(node, new_node),
        )
Пример #20
0
 def leave_Call(self, original_node: Call,
                updated_node: Call) -> BaseExpression:
     """Replace a ``<old_name>(...)`` call by the bare name ``self.new_name``.

     This applies only while ``self.is_entity_imported`` is true. The whole
     call expression, arguments included, is replaced by the name. Every
     other call is delegated to the parent class.
     """
     if not self.is_entity_imported:
         return super().leave_Call(original_node, updated_node)
     if not m.matches(updated_node, m.Call(func=m.Name(self.old_name))):
         return super().leave_Call(original_node, updated_node)
     return Name(self.new_name)
Пример #21
0
class ShedFixers(VisitorBasedCodemodCommand):
    """Fix a variety of small problems.

    Replaces `raise NotImplemented` with `raise NotImplementedError`,
    and converts always-failing assert statements to explicit `raise` statements.

    Also includes code closely modelled on pybetter's fixers, because it's
    considerably faster to run all transforms in a single pass if possible.
    """

    DESCRIPTION = "Fix a variety of style, performance, and correctness issues."

    @m.call_if_inside(m.Raise(exc=m.Name(value="NotImplemented")))
    def leave_Name(self, _, updated_node):  # noqa
        """Rename the raised ``NotImplemented`` to ``NotImplementedError``."""
        # The decorator fires for *every* Name inside the matching ``raise``,
        # including the cause in ``raise NotImplemented from err``. Only the
        # ``NotImplemented`` name itself may be renamed.
        if updated_node.value != "NotImplemented":
            return updated_node
        return updated_node.with_changes(value="NotImplementedError")

    def leave_Assert(self, _, updated_node):  # noqa
        """Drop always-true asserts and turn always-false ones into ``raise``.

        Only tests that ``literal_eval`` can evaluate are touched; any other
        assert is returned unchanged.
        """
        test_code = cst.Module("").code_for_node(updated_node.test)
        try:
            test_literal = literal_eval(test_code)
        except Exception:  # not a literal: leave the assert alone
            return updated_node
        if test_literal:
            return cst.RemovalSentinel.REMOVE
        if updated_node.msg is None:
            return cst.Raise(cst.Name("AssertionError"))
        return cst.Raise(
            cst.Call(cst.Name("AssertionError"),
                     args=[cst.Arg(updated_node.msg)]))

    @m.leave(
        m.ComparisonTarget(comparator=oneof_names("None", "False", "True"),
                           operator=m.Equal()))
    def convert_none_cmp(self, _, updated_node):
        """Inspired by Pybetter."""
        return updated_node.with_changes(operator=cst.Is())

    @m.leave(
        m.UnaryOperation(
            operator=m.Not(),
            expression=m.Comparison(
                comparisons=[m.ComparisonTarget(operator=m.In())]),
        ))
    def replace_not_in_condition(self, _, updated_node):
        """Also inspired by Pybetter."""
        expr = cst.ensure_type(updated_node.expression, cst.Comparison)
        return cst.Comparison(
            left=expr.left,
            lpar=updated_node.lpar,
            rpar=updated_node.rpar,
            comparisons=[
                expr.comparisons[0].with_changes(operator=cst.NotIn())
            ],
        )

    @m.leave(
        m.Call(
            lpar=[m.AtLeastN(n=1, matcher=m.LeftParen())],
            rpar=[m.AtLeastN(n=1, matcher=m.RightParen())],
        ))
    def remove_pointless_parens_around_call(self, _, updated_node):
        # This is *probably* valid, but we might have e.g. a multi-line parenthesised
        # chain of attribute accesses ("fluent interface"), where we need the parens.
        noparens = updated_node.with_changes(lpar=[], rpar=[])
        try:
            compile(self.module.code_for_node(noparens), "<string>", "eval")
            return noparens
        except SyntaxError:
            return updated_node

    # The following methods fix https://pypi.org/project/flake8-comprehensions/

    @m.leave(m.Call(func=m.Name("list"), args=[m.Arg(m.GeneratorExp())]))
    def replace_generator_in_call_with_comprehension(self, _, updated_node):
        """Fix flake8-comprehensions C400-402 and 403-404.

        C400-402: Unnecessary generator - rewrite as a <list/set/dict> comprehension.
        Note that set and dict conversions are handled by pyupgrade!
        """
        return cst.ListComp(elt=updated_node.args[0].value.elt,
                            for_in=updated_node.args[0].value.for_in)

    @m.leave(
        m.Call(func=m.Name("list"), args=[m.Arg(m.ListComp(), star="")])
        | m.Call(func=m.Name("set"), args=[m.Arg(m.SetComp(), star="")])
        | m.Call(
            func=m.Name("list"),
            args=[m.Arg(m.Call(func=oneof_names("sorted", "list")), star="")],
        ))
    def replace_unnecessary_list_around_sorted(self, _, updated_node):
        """Fix flake8-comprehensions C411 and C413.

        Unnecessary <list/reversed> call around sorted().

        Also covers C411 Unnecessary list call around list comprehension
        for lists and sets.
        """
        return updated_node.args[0].value

    @m.leave(
        m.Call(
            func=m.Name("reversed"),
            args=[m.Arg(m.Call(func=m.Name("sorted")), star="")],
        ))
    def replace_unnecessary_reversed_around_sorted(self, _, updated_node):
        """Fix flake8-comprehensions C413.

        Unnecessary reversed call around sorted().
        """
        call = updated_node.args[0].value
        args = list(call.args)
        for i, arg in enumerate(args):
            if m.matches(arg.keyword, m.Name("reverse")):
                try:
                    val = bool(
                        literal_eval(self.module.code_for_node(arg.value)))
                except Exception:
                    # Non-literal flag: negate it at runtime instead.
                    args[i] = arg.with_changes(
                        value=cst.UnaryOperation(cst.Not(), arg.value))
                else:
                    if not val:
                        args[i] = arg.with_changes(value=cst.Name("True"))
                    else:
                        del args[i]
                        args[i - 1] = remove_trailing_comma(args[i - 1])
                break
        else:
            args.append(
                cst.Arg(keyword=cst.Name("reverse"), value=cst.Name("True")))
        return call.with_changes(args=args)

    _sets = oneof_names("set", "frozenset")
    _seqs = oneof_names("list", "reversed", "sorted", "tuple")

    @m.leave(
        m.Call(func=_sets, args=[m.Arg(m.Call(func=_sets | _seqs), star="")])
        | m.Call(
            func=oneof_names("list", "tuple"),
            args=[m.Arg(m.Call(func=oneof_names("list", "tuple")), star="")],
        )
        | m.Call(
            func=m.Name("sorted"),
            args=[m.Arg(m.Call(func=_seqs), star=""),
                  m.ZeroOrMore()],
        ))
    def replace_unnecessary_nested_calls(self, _, updated_node):
        """Fix flake8-comprehensions C414.

        Unnecessary <list/reversed/sorted/tuple> call within <list/set/sorted/tuple>()..
        """
        inner_args = updated_node.args[0].value.args
        # ``set(list())`` has no argument to unwrap (it used to raise
        # IndexError). A starred or keyword first argument, as in
        # ``list(tuple(*xs))``, is not the iterable itself.
        if (not inner_args or inner_args[0].keyword is not None
                or inner_args[0].star != ""):
            return updated_node
        return updated_node.with_changes(
            args=[cst.Arg(inner_args[0].value)] +
            list(updated_node.args[1:]), )

    @m.leave(
        m.Call(
            func=oneof_names("reversed", "set", "sorted"),
            args=[
                m.Arg(m.Subscript(slice=[m.SubscriptElement(ALL_ELEMS_SLICE)]))
            ],
        ))
    def replace_unnecessary_subscript_reversal(self, _, updated_node):
        """Fix flake8-comprehensions C415.

        Unnecessary subscript reversal of iterable within <reversed/set/sorted>().
        """
        return updated_node.with_changes(
            args=[cst.Arg(updated_node.args[0].value.value)], )

    @m.leave(
        multi(
            m.ListComp,
            m.SetComp,
            elt=m.Name(),
            for_in=m.CompFor(target=m.Name(),
                             ifs=[],
                             inner_for_in=None,
                             asynchronous=None),
        ))
    def replace_unnecessary_listcomp_or_setcomp(self, _, updated_node):
        """Fix flake8-comprehensions C416.

        Unnecessary <list/set> comprehension - rewrite using <list/set>().
        """
        if updated_node.elt.value == updated_node.for_in.target.value:
            func = cst.Name(
                "list" if isinstance(updated_node, cst.ListComp) else "set")
            return cst.Call(func=func,
                            args=[cst.Arg(updated_node.for_in.iter)])
        return updated_node

    @m.leave(m.Subscript(oneof_names("Union", "Literal")))
    def reorder_union_literal_contents_none_last(self, _, updated_node):
        """Move ``None`` entries of ``Union[...]``/``Literal[...]`` to the end (stable sort)."""
        subscript = list(updated_node.slice)
        try:
            subscript.sort(key=lambda elt: elt.slice.value.value == "None")
            subscript[-1] = remove_trailing_comma(subscript[-1])
            return updated_node.with_changes(slice=subscript)
        except Exception:  # Single-element literals are not slices, etc.
            return updated_node

    @m.call_if_inside(m.Annotation(annotation=m.BinaryOperation()))
    @m.leave(
        m.BinaryOperation(
            left=m.Name("None") | m.BinaryOperation(),
            operator=m.BitOr(),
            right=m.DoNotCare(),
        ))
    def reorder_union_operator_contents_none_last(self, _, updated_node):
        """In ``X | Y`` annotations, swap the operands when ``None`` is on the left."""

        def _has_none(node):
            if m.matches(node, m.Name("None")):
                return True
            elif m.matches(node, m.BinaryOperation()):
                return _has_none(node.left) or _has_none(node.right)
            else:
                return False

        node_left = updated_node.left
        if _has_none(node_left):
            return updated_node.with_changes(left=updated_node.right,
                                             right=node_left)
        else:
            return updated_node

    @m.leave(m.Subscript(value=m.Name("Literal")))
    def flatten_literal_subscript(self, _, updated_node):
        """Inline the elements of ``Literal[...]`` nested inside a ``Literal``."""
        new_slice = []
        for item in updated_node.slice:
            if m.matches(item.slice.value, m.Subscript(m.Name("Literal"))):
                new_slice += item.slice.value.slice
            else:
                new_slice.append(item)
        return updated_node.with_changes(slice=new_slice)

    @m.leave(m.Subscript(value=m.Name("Union")))
    def flatten_union_subscript(self, _, updated_node):
        """Inline nested ``Union``/``Optional`` and emit ``None`` once, at the end."""
        new_slice = []
        has_none = False
        for item in updated_node.slice:
            if m.matches(item.slice.value, m.Subscript(m.Name("Optional"))):
                new_slice += item.slice.value.slice  # peel off "Optional"
                has_none = True
            elif m.matches(item.slice.value,
                           m.Subscript(m.Name("Union"))) and m.matches(
                               updated_node.value, item.slice.value.value):
                new_slice += item.slice.value.slice  # peel off nested "Union"
            elif m.matches(item.slice.value, m.Name("None")):
                has_none = True
            else:
                new_slice.append(item)
        if has_none:
            new_slice.append(
                cst.SubscriptElement(slice=cst.Index(cst.Name("None"))))
        return updated_node.with_changes(slice=new_slice)

    @m.leave(m.Else(m.IndentedBlock([m.SimpleStatementLine([m.Pass()])])))
    def discard_empty_else_blocks(self, _, updated_node):
        # An `else: pass` block can always simply be discarded, and libcst ensures
        # that an Else node can only ever occur attached to an If, While, For, or Try
        # node; in each case `None` is the valid way to represent "no else block".
        if m.findall(updated_node, m.Comment()):
            return updated_node  # If there are any comments, keep the node
        return cst.RemoveFromParent()

    @m.leave(
        m.Lambda(params=m.MatchIfTrue(lambda node: (
            node.star_kwarg is None and not node.kwonly_params and not node.
            posonly_params and isinstance(node.star_arg, cst.MaybeSentinel) and
            all(param.default is None for param in node.params)))))
    def remove_lambda_indirection(self, _, updated_node):
        """Replace ``lambda a, b: f(a, b)`` with ``f``."""
        param_names = {param.name.value for param in updated_node.params.params}
        same_args = [
            m.Arg(m.Name(param.name.value), star="", keyword=None)
            for param in updated_node.params.params
        ]
        if not m.matches(updated_node.body, m.Call(args=same_args)):
            return updated_node
        func = cst.ensure_type(updated_node.body, cst.Call).func
        # ``lambda x: x.method(x)`` must not become ``x.method``: the callee
        # would then refer to a name that is only bound inside the lambda.
        if m.findall(func,
                     m.Name(value=m.MatchIfTrue(lambda v: v in param_names))):
            return updated_node
        return func

    @m.leave(
        m.BooleanOperation(
            left=m.Call(m.Name("isinstance"), [m.Arg(), m.Arg()]),
            operator=m.Or(),
            right=m.Call(m.Name("isinstance"), [m.Arg(), m.Arg()]),
        ))
    def collapse_isinstance_checks(self, _, updated_node):
        """Merge ``isinstance(x, A) or isinstance(x, B)`` into ``isinstance(x, (A, B))``."""
        left_target, left_type = updated_node.left.args
        right_target, right_type = updated_node.right.args
        # Compare the target expressions, not the whole Arg nodes, so that
        # comma/whitespace differences do not block the rewrite. Starred type
        # arguments (``isinstance(x, *ts)``) cannot be put into a tuple element.
        if (left_target.star == right_target.star
                and left_target.value.deep_equals(right_target.value)
                and left_type.star == "" and right_type.star == ""):
            merged_type = cst.Arg(
                cst.Tuple([
                    cst.Element(left_type.value),
                    cst.Element(right_type.value)
                ]))
            return updated_node.left.with_changes(
                args=[left_target, merged_type])
        return updated_node
Пример #22
0
 def test_at_most_n_matcher_no_args_true(self) -> None:
     """``AtMostN``/``ZeroOrOne`` inside ``Call.args`` accept calls within the bound."""
     one_arg_call = cst.Call(func=cst.Name("foo"), args=(cst.Arg(cst.Integer("1")),))
     two_arg_call = cst.Call(
         func=cst.Name("foo"),
         args=(cst.Arg(cst.Integer("1")), cst.Arg(cst.Integer("2"))),
     )
     cases = (
         # foo(...) with at most two arguments; one is given.
         (one_arg_call, m.Call(func=m.Name("foo"), args=(m.AtMostN(n=2),))),
         # foo(...) with at most two arguments; two are given.
         (two_arg_call, m.Call(func=m.Name("foo"), args=(m.AtMostN(n=2),))),
         # foo(...) with at most six arguments, the last being the integer 1.
         (
             one_arg_call,
             m.Call(
                 func=m.Name("foo"), args=[m.AtMostN(n=5), m.Arg(m.Integer("1"))]
             ),
         ),
         # foo(...) with at most six arguments, the last being the integer 2.
         (
             two_arg_call,
             m.Call(
                 func=m.Name("foo"), args=(m.AtMostN(n=5), m.Arg(m.Integer("2")))
             ),
         ),
         # foo(...) with at most six arguments, the first being the integer 1.
         (
             two_arg_call,
             m.Call(
                 func=m.Name("foo"), args=(m.Arg(m.Integer("1")), m.AtMostN(n=5))
             ),
         ),
         # foo(...) with at most two arguments, the first being the integer 1.
         (
             two_arg_call,
             m.Call(func=m.Name("foo"), args=(m.Arg(m.Integer("1")), m.ZeroOrOne())),
         ),
     )
     for node, matcher in cases:
         self.assertTrue(matches(node, matcher))
Пример #23
0
    def leave_Call(  # noqa: C901
            self, original_node: cst.Call,
            updated_node: cst.Call) -> cst.BaseExpression:
        """Rewrite a ``"...".format(...)`` call as an equivalent f-string.

        Each replacement field of the string literal is resolved against the
        call's arguments and embedded as a ``FormattedStringExpression``. The
        node is left unchanged, with a warning, when the conversion cannot be
        done safely. That happens for:

        * a bytes prefix;
        * a non-empty format spec;
        * an unresolvable field name;
        * comments, nested f-strings or ``await`` in an argument;
        * backslashes in an argument;
        * an embedded string using the same quote as the literal.

        Args:
            original_node: Unused here.
            updated_node: The call that is inspected and, on success, replaced.

        Returns:
            A ``cst.FormattedString`` on success, otherwise ``updated_node``.
        """
        # Let's figure out if this is a "".format() call
        extraction = self.extract(
            updated_node,
            m.Call(func=m.Attribute(
                value=m.SaveMatchedNode(m.SimpleString(), "string"),
                attr=m.Name("format"),
            )),
        )
        if extraction is None:
            return updated_node

        stringnode = cst.ensure_type(extraction["string"], cst.SimpleString)

        # bytes literals have no .format(), and "b" cannot be combined with
        # "f" in a string prefix, so there is nothing valid to emit.
        if "b" in stringnode.prefix.lower():
            self.warn("Unsupported bytes prefix in format() call")
            return updated_node
        # "u" is only legal as a standalone prefix ("fu"/"uf" is a
        # SyntaxError) and is a no-op on Python 3 str literals, so drop it.
        prefix = "".join(ch for ch in stringnode.prefix if ch not in "uU")

        fstring: List[cst.BaseFormattedStringContent] = []
        inserted_sequence: int = 0
        tokens = _get_tokens(stringnode.raw_value)
        for (literal_text, field_name, format_spec, conversion) in tokens:
            if literal_text:
                fstring.append(cst.FormattedStringText(literal_text))
            if field_name is None:
                # This is not a format-specification
                continue
            if format_spec is not None and len(format_spec) > 0:
                # TODO: This is supportable since format specs are compatible
                # with f-string format specs, but it would require matching
                # format specifier expansions.
                self.warn(
                    f"Unsupported format_spec {format_spec} in format() call"
                )
                return updated_node

            # Auto-insert field sequence if it is empty
            if field_name == "":
                field_name = str(inserted_sequence)
                inserted_sequence += 1
            expr = _find_expr_from_field_name(field_name, updated_node.args)
            if expr is None:
                # Most likely they used * expansion in a format.
                self.warn(
                    f"Unsupported field_name {field_name} in format() call"
                )
                return updated_node

            # Verify that we don't have any comments or newlines. Comments aren't
            # allowed in f-strings, and newlines need parenthesization. We can
            # have formattedstrings inside other formattedstrings, but I chose not
            # to deal with that for now.
            if self.findall(expr, m.Comment()):
                # We could strip comments, but this is a formatting change so
                # we choose not to for now.
                self.warn("Unsupported comment in format() call")
                return updated_node
            if self.findall(expr, m.FormattedString()):
                self.warn("Unsupported f-string in format() call")
                return updated_node
            if self.findall(expr, m.Await()):
                # This is fixed in 3.7 but we don't currently have a flag
                # to enable/disable it.
                self.warn("Unsupported await in format() call")
                return updated_node

            # Stripping newlines is effectively a format-only change.
            expr = cst.ensure_type(
                expr.visit(StripNewlinesTransformer(self.context)),
                cst.BaseExpression,
            )

            # Try our best to swap quotes on any strings that won't fit
            expr = cst.ensure_type(
                expr.visit(
                    SwitchStringQuotesTransformer(self.context,
                                                  stringnode.quote[0])),
                cst.BaseExpression,
            )

            # Verify that the resulting expression doesn't have a backslash
            # in it.
            raw_expr_string = self.module.code_for_node(expr)
            if "\\" in raw_expr_string:
                self.warn("Unsupported backslash in format expression")
                return updated_node

            # For safety sake, if this is a dict/set or dict/set comprehension,
            # wrap it in parens so that it doesn't accidentally create an
            # escape.
            if (raw_expr_string.startswith("{")
                    or raw_expr_string.endswith("}")) and (not expr.lpar or
                                                           not expr.rpar):
                expr = expr.with_changes(lpar=[cst.LeftParen()],
                                         rpar=[cst.RightParen()])

            # Verify that any strings we insert don't have the same quote
            quote_gatherer = StringQuoteGatherer(self.context)
            expr.visit(quote_gatherer)
            for stringend in quote_gatherer.stringends:
                if stringend in stringnode.quote:
                    self.warn(
                        "Cannot embed string with same quote from format() call"
                    )
                    return updated_node

            fstring.append(
                cst.FormattedStringExpression(expression=expr,
                                              conversion=conversion))
        return cst.FormattedString(
            parts=fstring,
            start=f"f{prefix}{stringnode.quote}",
            end=stringnode.quote,
        )
Пример #24
0
 def test_complex_matcher_false(self) -> None:
     """Matchers that disagree with the call ``foo(1, 2, 3)`` must all fail."""
     three_int_call = cst.Call(
         func=cst.Name("foo"),
         args=(
             cst.Arg(cst.Integer("1")),
             cst.Arg(cst.Integer("2")),
             cst.Arg(cst.Integer("3")),
         ),
     )
     rejecting_matchers = [
         # Wrong node type: this is a Call, not a FunctionDef.
         m.FunctionDef(),
         # Wrong callee: the function is "foo", not "bar".
         m.Call(m.Name("bar")),
         # Right callee, but two arguments instead of three.
         m.Call(func=m.Name("foo"), args=(m.Arg(), m.Arg())),
         # Right callee and arity, but the integers in the order 3, 2, 1.
         m.Call(
             func=m.Name("foo"),
             args=(
                 m.Arg(m.Integer("3")),
                 m.Arg(m.Integer("2")),
                 m.Arg(m.Integer("1")),
             ),
         ),
         # Three arguments, but the last one is 3 rather than the integer 1.
         m.Call(
             func=m.Name("foo"),
             args=(m.DoNotCare(), m.DoNotCare(), m.Arg(m.Integer("1"))),
         ),
     ]
     for matcher in rejecting_matchers:
         self.assertFalse(matches(three_int_call, matcher))
Пример #25
0
class DeprecationWarningsCommand(VisitorBasedCodemodCommand):
    """Codemod that rewrites calls to deprecated wx factory functions.

    Each ``deprecated_symbols_map`` entry pairs a deprecated symbol with its
    replacement. A plain string replacement becomes ``wx.<name>``. A
    ``(owner, member)`` tuple becomes ``wx.<owner>.<member>``.
    """

    DESCRIPTION: str = "Rename deprecated methods"

    deprecated_symbols_map: List[Tuple[str, Union[str, Tuple[str, str]]]] = [
        ("BitmapFromImage", "Bitmap"),
        ("ImageFromStream", "Image"),
        ("EmptyIcon", "Icon"),
        ("DateTimeFromDMY", ("DateTime", "FromDMY")),
    ]
    # Entries are (symbol, matcher for a bare `symbol(...)` call, replacement).
    matchers_short_map = set(
        (old_name, matchers.Call(func=matchers.Name(value=old_name)), new_name)
        for old_name, new_name in deprecated_symbols_map
    )
    # Entries are (matcher for a qualified `wx.symbol(...)` call, replacement).
    matchers_full_map = set(
        (
            matchers.Call(
                func=matchers.Attribute(
                    value=matchers.Name(value="wx"),
                    attr=matchers.Name(value=old_name),
                )
            ),
            new_name,
        )
        for old_name, new_name in deprecated_symbols_map
    )

    def __init__(self, context: CodemodContext):
        super().__init__(context)

        # Symbols imported from the wx package; populated by visit_Module.
        self.wx_imports: Set[str] = set()

    def visit_Module(self, node: cst.Module) -> None:
        """Record which symbols the module imports from ``wx``."""
        import_gatherer = GatherImportsVisitor(self.context)
        node.visit(import_gatherer)
        self.wx_imports = import_gatherer.object_mapping.get("wx", set())

    @staticmethod
    def _nested_wx_func(owner: str, member: str) -> cst.Attribute:
        """Build the callee expression ``wx.<owner>.<member>``."""
        owner_attr = cst.Attribute(value=cst.Name(value="wx"),
                                   attr=cst.Name(value=owner))
        return cst.Attribute(value=owner_attr, attr=cst.Name(value=member))

    def leave_Call(self, original_node: cst.Call,
                   updated_node: cst.Call) -> cst.Call:
        """Point a call to a deprecated wx symbol at its replacement."""
        # Bare calls such as `BitmapFromImage(...)`, only when the symbol was
        # imported from wx.
        for symbol, call_matcher, replacement in self.matchers_short_map:
            if symbol not in self.wx_imports:
                continue
            if not matchers.matches(updated_node, call_matcher):
                continue

            # The bare symbol's import may now be unused; the replacement is
            # spelled through the top-level `wx` module, so that is imported.
            RemoveImportsVisitor.remove_unused_import_by_node(
                self.context, original_node)
            AddImportsVisitor.add_needed_import(self.context, "wx")

            if isinstance(replacement, tuple):
                new_func = self._nested_wx_func(replacement[0], replacement[1])
            else:
                new_func = cst.Attribute(value=cst.Name(value="wx"),
                                         attr=cst.Name(value=replacement))
            return updated_node.with_changes(func=new_func)

        # Qualified calls such as `wx.BitmapFromImage(...)`.
        for call_matcher, replacement in self.matchers_full_map:
            if not matchers.matches(updated_node, call_matcher):
                continue
            if isinstance(replacement, tuple):
                return updated_node.with_changes(
                    func=self._nested_wx_func(replacement[0], replacement[1]))
            # Reuse the existing `wx.<symbol>` attribute and only swap its name.
            renamed_func = updated_node.func.with_changes(
                attr=cst.Name(value=replacement))
            return updated_node.with_changes(func=renamed_func)

        # Not a deprecated call: leave it alone.
        return updated_node
Пример #26
0
 def test_complex_matcher_true(self) -> None:
     """Matchers of increasing specificity all accept the call ``foo(1, 2, 3)``."""
     call_node = cst.Call(
         func=cst.Name("foo"),
         args=(
             cst.Arg(cst.Integer("1")),
             cst.Arg(cst.Integer("2")),
             cst.Arg(cst.Integer("3")),
         ),
     )
     accepting_matchers = [
         # Any Call at all, ignoring the arguments.
         m.Call(),
         # Any Call whose callee is the name "foo".
         m.Call(m.Name("foo")),
         # A call to "foo" with three arguments.
         m.Call(func=m.Name("foo"), args=(m.Arg(), m.Arg(), m.Arg())),
         # A call to "foo" with three integer arguments.
         m.Call(
             func=m.Name("foo"),
             args=(m.Arg(m.Integer()), m.Arg(m.Integer()), m.Arg(m.Integer())),
         ),
         # A call to "foo" with the integer arguments 1, 2, 3.
         m.Call(
             func=m.Name("foo"),
             args=(
                 m.Arg(m.Integer("1")),
                 m.Arg(m.Integer("2")),
                 m.Arg(m.Integer("3")),
             ),
         ),
         # A call to "foo" with three arguments, the last being the integer 3.
         m.Call(
             func=m.Name("foo"),
             args=(m.DoNotCare(), m.DoNotCare(), m.Arg(m.Integer("3"))),
         ),
     ]
     for matcher in accepting_matchers:
         self.assertTrue(matches(call_node, matcher))
Пример #27
0
def is_one_to_one_field(node: Call) -> bool:
    """Return True if *node* calls an attribute named ``OneToOneField``.

    Only attribute-style callees such as ``models.OneToOneField(...)`` match.
    A bare ``OneToOneField(...)`` call, whose callee is a Name, does not.
    """
    one_to_one_call = m.Call(
        func=m.Attribute(attr=m.Name(value="OneToOneField")),
    )
    return m.matches(node, one_to_one_call)
Пример #28
0
 def test_zero_or_more_matcher_no_args_true(self) -> None:
     """ZeroOrMore wildcards let fixed arguments be found anywhere in ``foo(1, 2, 3)``."""
     call_node = cst.Call(
         func=cst.Name("foo"),
         args=(
             cst.Arg(cst.Integer("1")),
             cst.Arg(cst.Integer("2")),
             cst.Arg(cst.Integer("3")),
         ),
     )
     accepting_arg_patterns = [
         # The first argument is the integer 1; anything may follow.
         (m.Arg(m.Integer("1")), m.ZeroOrMore()),
         # Some argument is the integer 1.
         (m.ZeroOrMore(), m.Arg(m.Integer("1")), m.ZeroOrMore()),
         # Some argument is the integer 2.
         (m.ZeroOrMore(), m.Arg(m.Integer("2")), m.ZeroOrMore()),
         # Some argument is the integer 3.
         (m.ZeroOrMore(), m.Arg(m.Integer("3")), m.ZeroOrMore()),
         # The last argument is the integer 3.
         (m.ZeroOrMore(), m.Arg(m.Integer("3"))),
         # The integers 1 and 3 appear, in that order, anywhere in the list.
         (
             m.ZeroOrMore(),
             m.Arg(m.Integer("1")),
             m.ZeroOrMore(),
             m.Arg(m.Integer("3")),
             m.ZeroOrMore(),
         ),
         # The integers 1, 2 and 3 appear, in that order, anywhere in the list.
         (
             m.ZeroOrMore(),
             m.Arg(m.Integer("1")),
             m.ZeroOrMore(),
             m.Arg(m.Integer("2")),
             m.ZeroOrMore(),
             m.Arg(m.Integer("3")),
             m.ZeroOrMore(),
         ),
     ]
     for arg_pattern in accepting_arg_patterns:
         self.assertTrue(
             matches(call_node, m.Call(func=m.Name("foo"), args=arg_pattern))
         )
Пример #29
0
 def leave_Call(self, original_node: Call,
                updated_node: Call) -> BaseExpression:
     """Replace a call whose callee matches ``self.name_matcher`` with a Name.

     On a match, the whole call, arguments included, becomes
     ``Name(self.new_name)``. Any other call is delegated to the parent
     class's ``leave_Call``.
     """
     callee_is_target = m.matches(updated_node, m.Call(func=self.name_matcher))
     if not callee_is_target:
         return super().leave_Call(original_node, updated_node)
     return Name(self.new_name)
Пример #30
0
class DatetimeUtcnow_(VisitorBasedCodemodCommand):
	
	# The leave handlers below emit `datetime.now(UTC)`; there is no
	# `datetime.utc()`, so the description now names the real target.
	DESCRIPTION: str = "Converts from datetime.utcnow() to datetime.now(UTC)"

	# Keyword argument `tzinfo=timezone.utc`.
	timezone_utc_matcher = m.Arg(
		value=m.Attribute(
			value=m.Name(value="timezone"), attr=m.Name(value="utc")
		),
		keyword=m.Name(value="tzinfo"),
	)

	# Keyword argument `tzinfo=utc`, `tzinfo=UTC` or `tzinfo=pytz.UTC`.
	utc_matcher = m.Arg(
		value=m.OneOf(
			m.Name(value="utc"),
			m.Name(value="UTC"),
			m.Attribute(value=m.Name(value="pytz"), attr=m.Name(value="UTC")),
		),
		keyword=m.Name(value="tzinfo"),
	)

	# `datetime.utcnow()` with no arguments.
	datetime_utcnow_matcher = m.Call(
		func=m.Attribute(
			value=m.Name(value="datetime"), attr=m.Name(value="utcnow")
		),
		args=[],
	)
	# `datetime.datetime.utcnow()` with no arguments.
	datetime_datetime_utcnow_matcher = m.Call(
		func=m.Attribute(
			value=m.Attribute(
				value=m.Name(value="datetime"), attr=m.Name(value="datetime")
			),
			attr=m.Name(value="utcnow"),
		),
		args=[],
	)

	# `datetime.utcnow().replace(<single UTC tzinfo argument>)`.
	datetime_replace_matcher = m.Call(
		func=m.Attribute(
			value=datetime_utcnow_matcher, attr=m.Name(value="replace")
		),
		args=[m.OneOf(timezone_utc_matcher, utc_matcher)],
	)
	# `datetime.datetime.utcnow().replace(<single UTC tzinfo argument>)`.
	datetime_datetime_replace_matcher = m.Call(
		func=m.Attribute(
			value=datetime_datetime_utcnow_matcher,
			attr=m.Name(value="replace"),
		),
		args=[m.OneOf(timezone_utc_matcher, utc_matcher)],
	)

	# `(<utcnow call> + <anything>).replace(<single UTC tzinfo argument>)`.
	timedelta_replace_matcher = m.Call(
		func=m.Attribute(
			value=m.BinaryOperation(
				left=m.OneOf(
					datetime_utcnow_matcher, datetime_datetime_utcnow_matcher
				),
				operator=m.Add(),
			),
			attr=m.Name(value="replace"),
		),
		args=[m.OneOf(timezone_utc_matcher, utc_matcher)],
	)

	# `UTC.localize(<utcnow call>)`.
	utc_localize_matcher = m.Call(
		func=m.Attribute(
			value=m.Name(value="UTC"), attr=m.Name(value="localize"),
		),
		args=[
			m.Arg(
				value=m.OneOf(
					datetime_utcnow_matcher, datetime_datetime_utcnow_matcher
				)
			)
		],
	)
	
	def _update_imports(self):
		"""Schedule the import changes every rewrite in this codemod needs.

		The pytz / datetime.timezone imports are queued for removal if they
		end up unused. ``UTC`` from bulb.platform.common.timezones is queued
		for addition.
		"""
		removals = (
			("pytz",),
			("pytz", "utc"),
			("pytz", "UTC"),
			("datetime", "timezone"),
		)
		for removal in removals:
			RemoveImportsVisitor.remove_unused_import(self.context, *removal)
		AddImportsVisitor.add_needed_import(
			self.context, "bulb.platform.common.timezones", "UTC"
		)
	
	@m.leave(datetime_utcnow_matcher)
	def datetime_utcnow_call(
		self, original_node: cst.Call, updated_node: cst.Call
	) -> cst.Call:
		"""Rewrite ``datetime.utcnow()`` as ``datetime.now(UTC)``."""
		self._update_imports()

		now_func = cst.Attribute(
			value=cst.Name(value="datetime"), attr=cst.Name("now")
		)
		utc_args = [cst.Arg(value=cst.Name(value="UTC"))]
		return updated_node.with_changes(func=now_func, args=utc_args)
	
	@m.leave(datetime_datetime_utcnow_matcher)
	def datetime_datetime_utcnow_call(
		self, original_node: cst.Call, updated_node: cst.Call
	) -> cst.Call:
		"""Rewrite ``datetime.datetime.utcnow()`` as ``datetime.datetime.now(UTC)``."""
		self._update_imports()

		datetime_class = cst.Attribute(
			value=cst.Name(value="datetime"),
			attr=cst.Name(value="datetime"),
		)
		now_func = cst.Attribute(value=datetime_class, attr=cst.Name(value="now"))
		utc_args = [cst.Arg(value=cst.Name(value="UTC"))]
		return updated_node.with_changes(func=now_func, args=utc_args)
	
	@m.leave(datetime_replace_matcher)
	def datetime_replace(
		self, original_node: cst.Call, updated_node: cst.Call
	) -> cst.Call:
		"""Collapse ``datetime.utcnow().replace(tzinfo=<utc>)`` into ``datetime.now(UTC)``."""
		self._update_imports()

		now_func = cst.Attribute(
			value=cst.Name(value="datetime"), attr=cst.Name("now")
		)
		utc_args = [cst.Arg(value=cst.Name(value="UTC"))]
		# The whole `.replace(...)` call is replaced, so the tzinfo argument goes away.
		return updated_node.with_changes(func=now_func, args=utc_args)
	
	@m.leave(datetime_datetime_replace_matcher)
	def datetime_datetime_replace(
		self, original_node: cst.Call, updated_node: cst.Call
	) -> cst.Call:
		"""Collapse ``datetime.datetime.utcnow().replace(tzinfo=<utc>)`` into ``datetime.datetime.now(UTC)``."""
		self._update_imports()

		datetime_class = cst.Attribute(
			value=cst.Name(value="datetime"),
			attr=cst.Name(value="datetime"),
		)
		now_func = cst.Attribute(value=datetime_class, attr=cst.Name(value="now"))
		utc_args = [cst.Arg(value=cst.Name(value="UTC"))]
		return updated_node.with_changes(func=now_func, args=utc_args)
	
	@m.leave(timedelta_replace_matcher)
	def timedelta_replace(
		self, original_node: cst.Call, updated_node: cst.Call
	) -> cst.BinaryOperation:
		"""Strip the trailing ``.replace(tzinfo=<utc>)`` from ``(<utcnow call> + x).replace(...)``.

		Only the ``BinaryOperation`` receiver of ``.replace`` is returned. The
		inner ``utcnow()`` call is presumably already rewritten by the other
		leave handlers. TODO confirm the visit order.
		"""
		self._update_imports()

		replace_attr = cast(cst.Attribute, updated_node.func)
		return cast(cst.BinaryOperation, replace_attr.value)
	
	@m.leave(utc_localize_matcher)
	def utc_localize(
		self, original_node: cst.Call, updated_node: cst.Call
	) -> cst.Call:
		"""Unwrap ``UTC.localize(<utcnow call>)`` to just its single argument."""
		self._update_imports()

		localized = updated_node.args[0].value
		return cast(cst.Call, localized)