Ejemplo n.º 1
0
    def test_extract_simple(self) -> None:
        """``m.extract`` returns saved nodes on a match and ``None`` otherwise."""
        expression = cst.parse_expression("a + b[c], d(e, f * g)")

        def binop_then_call(left_matcher):
            # Tuple of (<binary op whose left side is saved as "left">, <any call>).
            return m.Tuple(elements=[
                m.Element(
                    m.BinaryOperation(
                        left=m.SaveMatchedNode(left_matcher, "left"))),
                m.Element(m.Call()),
            ])

        # Positive case: `a` is a Name, so it is captured under "left".
        found = m.extract(expression, binop_then_call(m.Name()))
        first_element = cst.ensure_type(expression, cst.Tuple).elements[0]
        expected_left = cst.ensure_type(first_element.value,
                                        cst.BinaryOperation).left
        self.assertEqual(found, {"left": expected_left})

        # Negative case: `a` is not a Subscript, so the whole match fails.
        self.assertIsNone(
            m.extract(expression, binop_then_call(m.Subscript())))
Ejemplo n.º 2
0
    def test_predicate_logic_operators_on_attributes(self) -> None:
        """``|``, ``&`` and ``~`` combine metadata matchers on node attributes."""

        def span_to(end_column):
            # Position range on line 1 from column 0 up to ``end_column``.
            return m.MatchMetadata(
                meta.PositionProvider,
                self._make_coderange((1, 0), (1, end_column)))

        def assert_matches(code, matcher, expected=True):
            node, wrapper = self._make_fixture(code)
            result = matches(node, matcher, metadata_resolver=wrapper)
            if expected:
                self.assertTrue(result)
            else:
                self.assertFalse(result)

        # OR: the left operand is either one or two columns wide.
        name_matcher = m.BinaryOperation(
            left=m.Name(metadata=span_to(1) | span_to(2)))
        assert_matches("a + b", name_matcher)
        int_matcher = m.BinaryOperation(
            left=m.Integer(metadata=span_to(1) | span_to(2)))
        assert_matches("12 + 3", int_matcher)
        assert_matches("123 + 4", int_matcher, expected=False)

        # AND: one column wide and in a load context.
        load_matcher = m.BinaryOperation(left=m.Name(
            metadata=span_to(1)
            & m.MatchMetadata(meta.ExpressionContextProvider,
                              meta.ExpressionContext.LOAD)))
        assert_matches("a + b", load_matcher)
        assert_matches("ab + cd", load_matcher, expected=False)

        # NOT: anything that is not a store context.
        not_store_matcher = m.BinaryOperation(left=m.Name(
            metadata=~m.MatchMetadata(meta.ExpressionContextProvider,
                                      meta.ExpressionContext.STORE)))
        assert_matches("a + b", not_store_matcher)
Ejemplo n.º 3
0
    def test_predicate_logic(self) -> None:
        """``OneOf``, ``AllOf`` and ``DoesNotMatch`` combine metadata matchers."""

        def position(end_column):
            # Position range on line 1 from column 0 up to ``end_column``.
            return m.MatchMetadata(
                meta.PositionProvider,
                self._make_coderange((1, 0), (1, end_column)))

        def verify(code, matcher, should_match=True):
            node, wrapper = self._make_fixture(code)
            outcome = matches(node, matcher, metadata_resolver=wrapper)
            if should_match:
                self.assertTrue(outcome)
            else:
                self.assertFalse(outcome)

        # OneOf: the left operand spans one or two columns.
        either_width = m.BinaryOperation(
            left=m.OneOf(position(1), position(2)))
        verify("a + b", either_width)
        verify("12 + 3", either_width)
        verify("123 + 4", either_width, should_match=False)

        # AllOf: one column wide AND in a load context.
        narrow_load = m.BinaryOperation(left=m.AllOf(
            position(1),
            m.MatchMetadata(meta.ExpressionContextProvider,
                            meta.ExpressionContext.LOAD),
        ))
        verify("a + b", narrow_load)
        verify("ab + cd", narrow_load, should_match=False)

        # DoesNotMatch: anything that is not a store context.
        not_store = m.BinaryOperation(left=m.DoesNotMatch(
            m.MatchMetadata(meta.ExpressionContextProvider,
                            meta.ExpressionContext.STORE)))
        verify("a + b", not_store)
Ejemplo n.º 4
0
    def test_extract_metadata(self) -> None:
        """``m.extract`` can save a node matched through a metadata predicate."""
        module = cst.parse_module("a + b[c], d(e, f * g)")
        wrapper = cst.MetadataWrapper(module)
        statement = cst.ensure_type(wrapper.module.body[0],
                                    cst.SimpleStatementLine)
        expression = cst.ensure_type(statement.body[0], cst.Expr).value

        def extract_left(end_column):
            # Save the binary op's left Name as "left" when its position is
            # exactly (1, 0)-(1, end_column).
            return m.extract(
                expression,
                m.Tuple(elements=[
                    m.Element(
                        m.BinaryOperation(left=m.Name(metadata=m.SaveMatchedNode(
                            m.MatchMetadata(
                                meta.PositionProvider,
                                self._make_coderange((1, 0), (1, end_column)),
                            ),
                            "left",
                        )))),
                    m.Element(m.Call()),
                ]),
                metadata_resolver=wrapper,
            )

        # `a` occupies columns 0-1, so it is extracted.
        found = extract_left(1)
        binop = cst.ensure_type(
            cst.ensure_type(expression, cst.Tuple).elements[0].value,
            cst.BinaryOperation,
        )
        self.assertEqual(found, {"left": binop.left})

        # A two-column range does not describe `a`, so nothing is extracted.
        self.assertIsNone(extract_left(2))
Ejemplo n.º 5
0
    def test_extract_predicates(self) -> None:
        """Only the ``|`` alternative that matched contributes a saved node."""
        matcher = m.Tuple(elements=[
            m.Element(
                m.BinaryOperation(
                    left=m.SaveMatchedNode(m.Name(), "left"))),
            m.Element(
                m.Call(func=m.SaveMatchedNode(m.Name(), "func")
                       | m.SaveMatchedNode(m.Attribute(), "attr"))),
        ])

        # (source, key under which the call's callee is expected to be saved)
        cases = (
            ("a + b[c], d(e, f * g)", "func"),
            ("a + b[c], d.z(e, f * g)", "attr"),
        )
        for code, callee_key in cases:
            expression = cst.parse_expression(code)
            found = m.extract(expression, matcher)
            elements = cst.ensure_type(expression, cst.Tuple).elements
            expected = {
                "left": cst.ensure_type(elements[0].value,
                                        cst.BinaryOperation).left,
                callee_key: cst.ensure_type(elements[1].value, cst.Call).func,
            }
            self.assertEqual(found, expected)
Ejemplo n.º 6
0
 def _has_none(node):
     """Return True if ``node`` is ``None`` or a BinaryOperation tree with a ``None`` leaf."""
     pending = [node]
     while pending:
         current = pending.pop()
         if m.matches(current, m.Name("None")):
             return True
         if m.matches(current, m.BinaryOperation()):
             # Push right first so the left operand is examined first.
             pending.append(current.right)
             pending.append(current.left)
     return False
Ejemplo n.º 7
0
 def test_simple_matcher_false(self) -> None:
     """Matching fails when a node's syntactic position does not line up."""

     def at(start, end):
         # Metadata matcher for an exact syntactic position range.
         return m.MatchMetadata(
             meta.SyntacticPositionProvider, self._make_coderange(start, end)
         )

     # A lone Name matched against the wrong line (2 instead of 1).
     node, wrapper = self._make_fixture("foo")
     misplaced_name = m.Name(value="foo", metadata=at((2, 0), (2, 3)))
     self.assertFalse(matches(node, misplaced_name, metadata_resolver=wrapper))

     # A binary expression whose children are not at the exact expected spots.
     node, wrapper = self._make_fixture("foo + bar")
     misplaced_children = m.BinaryOperation(
         left=at((1, 0), (1, 1)),
         right=at((1, 4), (1, 5)),
     )
     self.assertFalse(
         matches(node, misplaced_children, metadata_resolver=wrapper))
Ejemplo n.º 8
0
    def leave_BinaryOperation(
        self, original_node: cst.BinaryOperation, updated_node: cst.BinaryOperation
    ) -> cst.BaseExpression:
        """Rewrite ``"...%s..." % expr`` into an equivalent f-string.

        Eligibility of the string and the right-hand expression is decided by
        ``_match_simple_string`` and ``_gen_match_simple_expression`` (defined
        elsewhere). A tuple on the right supplies one value per ``%s``; any other
        expression is treated as a single value.

        Returns the new ``cst.FormattedString``. Returns ``original_node`` unchanged
        when the pattern does not match, when the number of ``%s`` placeholders
        differs from the number of values, or when a value cannot be re-quoted.
        """
        expr_key = "expr"
        extracts = m.extract(
            original_node,
            m.BinaryOperation(
                left=m.MatchIfTrue(_match_simple_string),
                operator=m.Modulo(),
                right=m.SaveMatchedNode(
                    m.MatchIfTrue(_gen_match_simple_expression(self.module)), expr_key,
                ),
            ),
        )
        if not extracts:
            return original_node

        expr = extracts[expr_key]
        simple_string = cst.ensure_type(original_node.left, cst.SimpleString)
        # Literal braces must be doubled so they stay literal inside the f-string.
        innards = simple_string.raw_value.replace("{", "{{").replace("}", "}}")
        tokens = innards.split("%s")
        expressions = (
            [elm.value for elm in expr.elements]
            if isinstance(expr, cst.Tuple)
            else [expr]
        )
        # Each "%s" consumes exactly one value. With a literal tuple, a mismatch in
        # either direction is a TypeError at runtime, so leave the code alone rather
        # than emit an f-string that silently drops (or lacks) values.
        if len(expressions) != len(tokens) - 1:
            return original_node

        escape_transformer = EscapeStringQuote(simple_string.quote)
        parts = []
        if tokens[0]:
            parts.append(cst.FormattedStringText(value=tokens[0]))
        for expression, token in zip(expressions, tokens[1:]):
            try:
                parts.append(
                    cst.FormattedStringExpression(
                        expression=cast(
                            cst.BaseExpression,
                            expression.visit(escape_transformer),
                        )
                    )
                )
            except Exception:
                # Best effort: any value that cannot be re-quoted aborts the rewrite.
                return original_node
            if token:
                parts.append(cst.FormattedStringText(value=token))
        start = f"f{simple_string.prefix}{simple_string.quote}"
        return cst.FormattedString(
            parts=parts, start=start, end=simple_string.quote
        )
Ejemplo n.º 9
0
 def visit_BinaryOperation(self, node: cst.BinaryOperation) -> None:
     """Report ``<str literal> % ...`` expressions (f-string candidates)."""
     percent_on_literal = m.matches(
         node, m.BinaryOperation(left=m.SimpleString(), operator=m.Modulo()))
     if not percent_on_literal:
         return
     # SimpleString can be bytes and fstring don't support bytes.
     # https://www.python.org/dev/peps/pep-0498/#no-binary-f-strings
     literal = cst.ensure_type(node.left, cst.SimpleString)
     if isinstance(literal.evaluated_value, str):
         self.report(node)
Ejemplo n.º 10
0
 def visit_BinaryOperation(self, node: cst.BinaryOperation) -> None:
     """Report ``"..." % args`` formatting while ``self.logging_stack`` is non-empty.

     NOTE(review): ``logging_stack`` is presumably maintained by sibling visitors
     to track enclosing logging calls; confirm against the rest of the class.
     """
     if not self.logging_stack:
         return
     percent_format = m.BinaryOperation(
         left=m.SimpleString() | m.ConcatenatedString(),
         operator=m.Modulo(),
     )
     if m.matches(node, percent_format):
         self.report(node)
Ejemplo n.º 11
0
 def test_extract_tautology(self) -> None:
     """Saving the root matcher's node yields the whole parsed expression."""
     expression = cst.parse_expression("a + b[c], d(e, f * g)")
     whole_tuple = m.Tuple(elements=[
         m.Element(m.BinaryOperation()),
         m.Element(m.Call()),
     ])
     saved = m.extract(expression, m.SaveMatchedNode(whole_tuple, name="node"))
     self.assertEqual(saved, {"node": expression})
Ejemplo n.º 12
0
    def visit_BinaryOperation(self, node: cst.BinaryOperation) -> None:
        """Report ``"...%s..." % expr`` and, when safe, offer an f-string autofix.

        Eligibility is decided by ``_match_simple_string`` and
        ``_gen_match_simple_expression`` (defined elsewhere). A tuple on the right
        supplies one value per ``%s``; any other expression is a single value.

        - Eligible, with counts matching: report with an f-string ``replacement``.
        - Eligible but counts differ, or a value cannot be re-quoted: report
          without a replacement.
        - Not eligible, but a ``str`` literal is ``%``-formatted: report without
          a replacement (bytes literals are skipped; f-strings cannot be bytes).
        """
        expr_key = "expr"
        extracts = m.extract(
            node,
            m.BinaryOperation(
                left=m.MatchIfTrue(_match_simple_string),
                operator=m.Modulo(),
                right=m.SaveMatchedNode(
                    m.MatchIfTrue(
                        _gen_match_simple_expression(
                            self.context.wrapper.module)),
                    expr_key,
                ),
            ),
        )

        if extracts:
            expr = extracts[expr_key]
            simple_string = cst.ensure_type(node.left, cst.SimpleString)
            # Literal braces must be doubled so they stay literal in the f-string.
            innards = simple_string.raw_value.replace("{",
                                                      "{{").replace("}", "}}")
            tokens = innards.split("%s")
            expressions = ([elm.value for elm in expr.elements] if isinstance(
                expr, cst.Tuple) else [expr])
            # Each "%s" consumes exactly one value. If the counts differ in either
            # direction, report without an autofix: an f-string would silently drop
            # extra values (or lack some), whereas the original raises TypeError.
            if len(expressions) != len(tokens) - 1:
                self.report(node)
                return
            escape_transformer = EscapeStringQuote(simple_string.quote)
            parts = []
            if tokens[0]:
                parts.append(cst.FormattedStringText(value=tokens[0]))
            for expression, token in zip(expressions, tokens[1:]):
                try:
                    parts.append(
                        cst.FormattedStringExpression(expression=cast(
                            cst.BaseExpression,
                            expression.visit(escape_transformer),
                        )))
                except Exception:
                    # Best effort: report without a fix if a value can't be re-quoted.
                    self.report(node)
                    return
                if token:
                    parts.append(cst.FormattedStringText(value=token))
            start = f"f{simple_string.prefix}{simple_string.quote}"
            replacement = cst.FormattedString(parts=parts,
                                              start=start,
                                              end=simple_string.quote)
            self.report(node, replacement=replacement)
        elif m.matches(
                node,
                m.BinaryOperation(
                    left=m.SimpleString(),
                    operator=m.Modulo())) and isinstance(
                        cst.ensure_type(
                            node.left, cst.SimpleString).evaluated_value, str):
            self.report(node)
Ejemplo n.º 13
0
class ShedFixers(VisitorBasedCodemodCommand):
    """Fix a variety of small problems.

    Replaces `raise NotImplemented` with `raise NotImplementedError`,
    and converts always-failing assert statements to explicit `raise` statements.

    Also includes code closely modelled on pybetter's fixers, because it's
    considerably faster to run all transforms in a single pass if possible.
    """

    DESCRIPTION = "Fix a variety of style, performance, and correctness issues."

    @m.call_if_inside(m.Raise(exc=m.Name(value="NotImplemented")))
    def leave_Name(self, _, updated_node):  # noqa
        """Rename `NotImplemented` to `NotImplementedError` inside `raise`.

        The `call_if_inside` matcher only constrains `exc`, so every Name in the
        statement reaches here, including one in a `from <cause>` clause. Only the
        `NotImplemented` Name itself is renamed.
        """
        if updated_node.value != "NotImplemented":
            return updated_node
        return updated_node.with_changes(value="NotImplementedError")

    def leave_Assert(self, _, updated_node):  # noqa
        """Remove always-true asserts; turn always-false ones into `raise`.

        Only asserts whose test `literal_eval` can evaluate are changed; any
        other assert is returned untouched.
        """
        test_code = cst.Module("").code_for_node(updated_node.test)
        try:
            test_literal = literal_eval(test_code)
        except Exception:
            return updated_node
        if test_literal:
            return cst.RemovalSentinel.REMOVE
        if updated_node.msg is None:
            return cst.Raise(cst.Name("AssertionError"))
        return cst.Raise(
            cst.Call(cst.Name("AssertionError"),
                     args=[cst.Arg(updated_node.msg)]))

    @m.leave(
        m.ComparisonTarget(comparator=oneof_names("None", "False", "True"),
                           operator=m.Equal()))
    def convert_none_cmp(self, _, updated_node):
        """Inspired by Pybetter."""
        # `x == None` -> `x is None` (likewise for True / False).
        return updated_node.with_changes(operator=cst.Is())

    @m.leave(
        m.UnaryOperation(
            operator=m.Not(),
            expression=m.Comparison(
                comparisons=[m.ComparisonTarget(operator=m.In())]),
        ))
    def replace_not_in_condition(self, _, updated_node):
        """Also inspired by Pybetter."""
        # `not a in b` -> `a not in b`, keeping the outer parentheses.
        expr = cst.ensure_type(updated_node.expression, cst.Comparison)
        return cst.Comparison(
            left=expr.left,
            lpar=updated_node.lpar,
            rpar=updated_node.rpar,
            comparisons=[
                expr.comparisons[0].with_changes(operator=cst.NotIn())
            ],
        )

    @m.leave(
        m.Call(
            lpar=[m.AtLeastN(n=1, matcher=m.LeftParen())],
            rpar=[m.AtLeastN(n=1, matcher=m.RightParen())],
        ))
    def remove_pointless_parens_around_call(self, _, updated_node):
        # This is *probably* valid, but we might have e.g. a multi-line parenthesised
        # chain of attribute accesses ("fluent interface"), where we need the parens.
        noparens = updated_node.with_changes(lpar=[], rpar=[])
        try:
            compile(self.module.code_for_node(noparens), "<string>", "eval")
            return noparens
        except SyntaxError:
            return updated_node

    # The following methods fix https://pypi.org/project/flake8-comprehensions/

    # `star=""` excludes `list(*(x for x in y))`, which unpacks the generator into
    # positional arguments and is not equivalent to `[x for x in y]`.
    @m.leave(
        m.Call(func=m.Name("list"), args=[m.Arg(m.GeneratorExp(), star="")]))
    def replace_generator_in_call_with_comprehension(self, _, updated_node):
        """Fix flake8-comprehensions C400-402 and 403-404.

        C400-402: Unnecessary generator - rewrite as a <list/set/dict> comprehension.
        Note that set and dict conversions are handled by pyupgrade!
        """
        return cst.ListComp(elt=updated_node.args[0].value.elt,
                            for_in=updated_node.args[0].value.for_in)

    @m.leave(
        m.Call(func=m.Name("list"), args=[m.Arg(m.ListComp(), star="")])
        | m.Call(func=m.Name("set"), args=[m.Arg(m.SetComp(), star="")])
        | m.Call(
            func=m.Name("list"),
            args=[m.Arg(m.Call(func=oneof_names("sorted", "list")), star="")],
        ))
    def replace_unnecessary_list_around_sorted(self, _, updated_node):
        """Fix flake8-comprehensions C411 and C413.

        Unnecessary <list/reversed> call around sorted().

        Also covers C411 Unnecessary list call around list comprehension
        for lists and sets.
        """
        return updated_node.args[0].value

    @m.leave(
        m.Call(
            func=m.Name("reversed"),
            args=[m.Arg(m.Call(func=m.Name("sorted")), star="")],
        ))
    def replace_unnecessary_reversed_around_sorted(self, _, updated_node):
        """Fix flake8-comprehensions C413.

        Unnecessary reversed call around sorted().
        """
        call = updated_node.args[0].value
        args = list(call.args)
        for i, arg in enumerate(args):
            if m.matches(arg.keyword, m.Name("reverse")):
                try:
                    val = bool(
                        literal_eval(self.module.code_for_node(arg.value)))
                except Exception:
                    # Not a literal: negate it at runtime. `not` binds tighter
                    # than `and`/`or` and cannot take a bare conditional
                    # expression, lambda or walrus, so parenthesise those, e.g.
                    # `reverse=a or b` -> `reverse=not (a or b)`.
                    value = arg.value
                    if m.matches(
                            value,
                            m.BooleanOperation() | m.IfExp() | m.Lambda()
                            | m.NamedExpr()) and not value.lpar:
                        value = value.with_changes(lpar=[cst.LeftParen()],
                                                   rpar=[cst.RightParen()])
                    args[i] = arg.with_changes(
                        value=cst.UnaryOperation(cst.Not(), value))
                else:
                    if not val:
                        args[i] = arg.with_changes(value=cst.Name("True"))
                    else:
                        # reversed(sorted(..., reverse=True)) == sorted(...):
                        # drop the keyword entirely.
                        del args[i]
                        if args and i == len(args):
                            # The removed argument was the last one; don't
                            # leave a dangling comma on the new last argument.
                            args[-1] = remove_trailing_comma(args[-1])
                break
        else:
            args.append(
                cst.Arg(keyword=cst.Name("reverse"), value=cst.Name("True")))
        return call.with_changes(args=args)

    _sets = oneof_names("set", "frozenset")
    _seqs = oneof_names("list", "reversed", "sorted", "tuple")

    @m.leave(
        m.Call(func=_sets, args=[m.Arg(m.Call(func=_sets | _seqs), star="")])
        | m.Call(
            func=oneof_names("list", "tuple"),
            args=[m.Arg(m.Call(func=oneof_names("list", "tuple")), star="")],
        )
        | m.Call(
            func=m.Name("sorted"),
            args=[m.Arg(m.Call(func=_seqs), star=""),
                  m.ZeroOrMore()],
        ))
    def replace_unnecessary_nested_calls(self, _, updated_node):
        """Fix flake8-comprehensions C414.

        Unnecessary <list/reversed/sorted/tuple> call within <list/set/sorted/tuple>()..
        """
        inner_args = updated_node.args[0].value.args
        # The matcher does not constrain the inner call's arguments. Bail out on
        # `set(list())` (nothing to unwrap; indexing would raise) and on a first
        # argument that is starred or a keyword (it is not the iterable itself).
        if not inner_args or inner_args[0].star or inner_args[0].keyword is not None:
            return updated_node
        return updated_node.with_changes(
            args=[cst.Arg(inner_args[0].value)] +
            list(updated_node.args[1:]), )

    @m.leave(
        m.Call(
            func=oneof_names("reversed", "set", "sorted"),
            args=[
                m.Arg(m.Subscript(slice=[m.SubscriptElement(ALL_ELEMS_SLICE)]))
            ],
        ))
    def replace_unnecessary_subscript_reversal(self, _, updated_node):
        """Fix flake8-comprehensions C415.

        Unnecessary subscript reversal of iterable within <reversed/set/sorted>().
        """
        # NOTE(review): ALL_ELEMS_SLICE is defined elsewhere. If it matches a
        # reversing slice such as `[::-1]`, then `reversed(x[::-1])` -> `reversed(x)`
        # inverts the iteration order (flake8-comprehensions suggests plain `x`);
        # confirm which slices it matches.
        return updated_node.with_changes(
            args=[cst.Arg(updated_node.args[0].value.value)], )

    @m.leave(
        multi(
            m.ListComp,
            m.SetComp,
            elt=m.Name(),
            for_in=m.CompFor(target=m.Name(),
                             ifs=[],
                             inner_for_in=None,
                             asynchronous=None),
        ))
    def replace_unnecessary_listcomp_or_setcomp(self, _, updated_node):
        """Fix flake8-comprehensions C416.

        Unnecessary <list/set> comprehension - rewrite using <list/set>().
        """
        # Only `[x for x in it]` (element is the loop variable itself) qualifies.
        if updated_node.elt.value == updated_node.for_in.target.value:
            func = cst.Name(
                "list" if isinstance(updated_node, cst.ListComp) else "set")
            return cst.Call(func=func,
                            args=[cst.Arg(updated_node.for_in.iter)])
        return updated_node

    @m.leave(m.Subscript(oneof_names("Union", "Literal")))
    def reorder_union_literal_contents_none_last(self, _, updated_node):
        """Move a `None` element to the end of `Union[...]` / `Literal[...]`."""
        subscript = list(updated_node.slice)
        try:
            # Stable sort on a bool key: only `None` moves, others keep order.
            subscript.sort(key=lambda elt: elt.slice.value.value == "None")
            subscript[-1] = remove_trailing_comma(subscript[-1])
            return updated_node.with_changes(slice=subscript)
        except Exception:  # Single-element literals are not slices, etc.
            return updated_node

    @m.call_if_inside(m.Annotation(annotation=m.BinaryOperation()))
    @m.leave(
        m.BinaryOperation(
            left=m.Name("None") | m.BinaryOperation(),
            operator=m.BitOr(),
            right=m.DoNotCare(),
        ))
    def reorder_union_operator_contents_none_last(self, _, updated_node):
        """In annotations, swap `None | X` (or `(... None ...) | X`) to put `None` last."""

        def _has_none(node):
            # True if `node` is `None` or a BinaryOperation tree with a `None` leaf.
            if m.matches(node, m.Name("None")):
                return True
            elif m.matches(node, m.BinaryOperation()):
                return _has_none(node.left) or _has_none(node.right)
            else:
                return False

        node_left = updated_node.left
        if _has_none(node_left):
            return updated_node.with_changes(left=updated_node.right,
                                             right=node_left)
        else:
            return updated_node

    @m.leave(m.Subscript(value=m.Name("Literal")))
    def flatten_literal_subscript(self, _, updated_node):
        """Flatten `Literal[Literal[a, b], c]` into `Literal[a, b, c]`."""
        new_slice = []
        for item in updated_node.slice:
            if m.matches(item.slice.value, m.Subscript(m.Name("Literal"))):
                new_slice += item.slice.value.slice
            else:
                new_slice.append(item)
        return updated_node.with_changes(slice=new_slice)

    @m.leave(m.Subscript(value=m.Name("Union")))
    def flatten_union_subscript(self, _, updated_node):
        """Flatten nested `Union`/`Optional` inside `Union[...]`; put one `None` last."""
        new_slice = []
        has_none = False
        for item in updated_node.slice:
            if m.matches(item.slice.value, m.Subscript(m.Name("Optional"))):
                new_slice += item.slice.value.slice  # peel off "Optional"
                has_none = True
            elif m.matches(item.slice.value,
                           m.Subscript(m.Name("Union"))) and m.matches(
                               updated_node.value, item.slice.value.value):
                new_slice += item.slice.value.slice  # peel off nested "Union"
            elif m.matches(item.slice.value, m.Name("None")):
                has_none = True
            else:
                new_slice.append(item)
        if has_none:
            new_slice.append(
                cst.SubscriptElement(slice=cst.Index(cst.Name("None"))))
        return updated_node.with_changes(slice=new_slice)

    @m.leave(m.Else(m.IndentedBlock([m.SimpleStatementLine([m.Pass()])])))
    def discard_empty_else_blocks(self, _, updated_node):
        # An `else: pass` block can always simply be discarded, and libcst ensures
        # that an Else node can only ever occur attached to an If, While, For, or Try
        # node; in each case `None` is the valid way to represent "no else block".
        if m.findall(updated_node, m.Comment()):
            return updated_node  # If there are any comments, keep the node
        return cst.RemoveFromParent()

    @m.leave(
        m.Lambda(params=m.MatchIfTrue(lambda node: (
            node.star_kwarg is None and not node.kwonly_params and not node.
            posonly_params and isinstance(node.star_arg, cst.MaybeSentinel) and
            all(param.default is None for param in node.params)))))
    def remove_lambda_indirection(self, _, updated_node):
        """Replace `lambda a, b: f(a, b)` with plain `f`."""
        params = updated_node.params.params
        same_args = [
            m.Arg(m.Name(param.name.value), star="", keyword=None)
            for param in params
        ]
        if not m.matches(updated_node.body, m.Call(args=same_args)):
            return updated_node
        func = cst.ensure_type(updated_node.body, cst.Call).func
        # If the callee itself mentions a lambda parameter (`lambda f: f(f)`,
        # `lambda x: x.m(x)`), it would be unbound once the lambda is removed.
        # This is conservative: an attribute named like a parameter also bails.
        if any(m.findall(func, m.Name(param.name.value)) for param in params):
            return updated_node
        return func

    @m.leave(
        m.BooleanOperation(
            left=m.Call(m.Name("isinstance"), [m.Arg(), m.Arg()]),
            operator=m.Or(),
            right=m.Call(m.Name("isinstance"), [m.Arg(), m.Arg()]),
        ))
    def collapse_isinstance_checks(self, _, updated_node):
        """Merge `isinstance(x, A) or isinstance(x, B)` into `isinstance(x, (A, B))`."""
        left_target, left_type = updated_node.left.args
        right_target, right_type = updated_node.right.args
        if left_target.deep_equals(right_target):
            merged_type = cst.Arg(
                cst.Tuple([
                    cst.Element(left_type.value),
                    cst.Element(right_type.value)
                ]))
            return updated_node.left.with_changes(
                args=[left_target, merged_type])
        return updated_node
Ejemplo n.º 14
0
class Checker(m.MatcherDecoratableVisitor):
    """Report legacy-Python constructs found while visiting a module.

    Flags `/` without `from __future__ import division`, uses of `sys.maxint`,
    and calls to `assertEquals` / `assertItemsEqual`. Each finding is printed as
    ``path:line:column: message`` and sets ``self.errors``. Checks whose key
    appears in ``ignored`` are skipped.
    """

    METADATA_DEPENDENCIES = (PositionProvider,)

    def __init__(
        self, path: Path, verbose: bool = False, ignored: Optional[List[str]] = None
    ):
        super().__init__()
        self.path = path
        self.verbose = verbose
        self.ignored = set(ignored or [])
        # Becomes True once `from __future__ import division` has been seen.
        self.future_division = False
        # Becomes True as soon as any finding is printed.
        self.errors = False
        # Names of the enclosing classes / functions at the current visit point.
        self.stack: List[str] = []

    def _emit_error(self, node, message: str) -> None:
        """Print ``message`` prefixed with ``node``'s start position; record the error."""
        start = self.get_metadata(PositionProvider, node).start
        print(f"{self.path}:{start.line}:{start.column}: {message}")
        self.errors = True

    @m.call_if_inside(m.ImportFrom(module=m.Name("__future__")))
    @m.visit(m.ImportAlias(name=m.Name("division")))
    def import_div(self, node: ImportAlias) -> None:
        self.future_division = True

    @m.visit(m.BinaryOperation(operator=m.Divide()))
    def check_div(self, node: BinaryOperation) -> None:
        if "division" in self.ignored or self.future_division:
            return
        self._emit_error(
            node, "division without `from __future__ import division`")

    @m.visit(m.Attribute(attr=m.Name("maxint"), value=m.Name("sys")))
    def check_maxint(self, node: Attribute) -> None:
        if "sys.maxint" not in self.ignored:
            self._emit_error(node, "use of sys.maxint")

    def visit_ClassDef(self, node: ClassDef) -> None:
        self.stack.append(node.name.value)

    def leave_ClassDef(self, node: ClassDef) -> None:
        self.stack.pop()

    def visit_FunctionDef(self, node: FunctionDef) -> None:
        self.stack.append(node.name.value)

    def leave_FunctionDef(self, node: FunctionDef) -> None:
        self.stack.pop()

    def visit_ClassDef_bases(self, node: "ClassDef") -> None:
        # Attribute visitor hook for `ClassDef.bases`; intentionally a no-op.
        return

    @m.visit(
        m.Call(
            func=m.Attribute(attr=m.Name("assertEquals") | m.Name("assertItemsEqual"))
        )
    )
    def visit_old_assert(self, node: Call) -> None:
        method_name = ensure_type(node.func, Attribute).attr.value
        if method_name not in self.ignored:
            self._emit_error(node, f"use of {method_name}")
Ejemplo n.º 15
0
class DatetimeUtcnow_(VisitorBasedCodemodCommand):
    """Codemod replacing naive ``utcnow()`` calls with ``now(UTC)``.

    Each pattern is handled for both the ``datetime.`` and the
    ``datetime.datetime.`` spelling:

    * ``datetime.utcnow()`` -> ``datetime.now(UTC)``
    * ``datetime.utcnow().replace(tzinfo=<utc>)`` -> ``datetime.now(UTC)``
    * ``(datetime.utcnow() + delta).replace(tzinfo=<utc>)`` (also ``-``)
      -> ``datetime.now(UTC) + delta``
    * ``UTC.localize(datetime.utcnow())`` -> ``datetime.now(UTC)``

    ``<utc>`` is a ``tzinfo=`` keyword argument whose value is one of
    ``timezone.utc``, ``utc``, ``UTC``, ``pytz.UTC`` or ``pytz.utc``.
    Every rewrite schedules an import of ``UTC`` from
    ``bulb.platform.common.timezones`` and removal of the ``pytz`` /
    ``datetime.timezone`` imports if they become unused.
    """

    DESCRIPTION: str = "Converts from datetime.utcnow() to datetime.now(UTC)"

    # Keyword argument ``tzinfo=timezone.utc``.
    timezone_utc_matcher = m.Arg(
        value=m.Attribute(value=m.Name(value="timezone"), attr=m.Name(value="utc")),
        keyword=m.Name(value="tzinfo"),
    )

    # Keyword argument ``tzinfo=`` with a bare ``utc``/``UTC`` name or the pytz
    # singleton, which pytz exposes under both ``pytz.UTC`` and ``pytz.utc``.
    utc_matcher = m.Arg(
        value=m.OneOf(
            m.Name(value="utc"),
            m.Name(value="UTC"),
            m.Attribute(
                value=m.Name(value="pytz"),
                attr=m.OneOf(m.Name(value="UTC"), m.Name(value="utc")),
            ),
        ),
        keyword=m.Name(value="tzinfo"),
    )

    # ``datetime.utcnow()`` with no arguments.
    datetime_utcnow_matcher = m.Call(
        func=m.Attribute(value=m.Name(value="datetime"), attr=m.Name(value="utcnow")),
        args=[],
    )
    # ``datetime.datetime.utcnow()`` with no arguments.
    datetime_datetime_utcnow_matcher = m.Call(
        func=m.Attribute(
            value=m.Attribute(
                value=m.Name(value="datetime"), attr=m.Name(value="datetime")
            ),
            attr=m.Name(value="utcnow"),
        ),
        args=[],
    )

    # ``datetime.utcnow().replace(tzinfo=<utc>)``.
    datetime_replace_matcher = m.Call(
        func=m.Attribute(value=datetime_utcnow_matcher, attr=m.Name(value="replace")),
        args=[m.OneOf(timezone_utc_matcher, utc_matcher)],
    )
    # ``datetime.datetime.utcnow().replace(tzinfo=<utc>)``.
    datetime_datetime_replace_matcher = m.Call(
        func=m.Attribute(
            value=datetime_datetime_utcnow_matcher,
            attr=m.Name(value="replace"),
        ),
        args=[m.OneOf(timezone_utc_matcher, utc_matcher)],
    )

    # ``(<utcnow call> + x).replace(tzinfo=<utc>)`` or the same with ``-``.
    # Shifting an aware "now" by a timedelta stays aware, so the trailing
    # ``.replace`` can be dropped in both cases.
    timedelta_replace_matcher = m.Call(
        func=m.Attribute(
            value=m.BinaryOperation(
                left=m.OneOf(
                    datetime_utcnow_matcher, datetime_datetime_utcnow_matcher
                ),
                operator=m.OneOf(m.Add(), m.Subtract()),
            ),
            attr=m.Name(value="replace"),
        ),
        args=[m.OneOf(timezone_utc_matcher, utc_matcher)],
    )

    # ``UTC.localize(<utcnow call>)``.
    utc_localize_matcher = m.Call(
        func=m.Attribute(value=m.Name(value="UTC"), attr=m.Name(value="localize")),
        args=[
            m.Arg(
                value=m.OneOf(
                    datetime_utcnow_matcher, datetime_datetime_utcnow_matcher
                )
            )
        ],
    )

    def _update_imports(self) -> None:
        """Schedule import cleanup and the ``UTC`` import for this module."""
        RemoveImportsVisitor.remove_unused_import(self.context, "pytz")
        RemoveImportsVisitor.remove_unused_import(self.context, "pytz", "utc")
        RemoveImportsVisitor.remove_unused_import(self.context, "pytz", "UTC")
        RemoveImportsVisitor.remove_unused_import(self.context, "datetime", "timezone")
        AddImportsVisitor.add_needed_import(
            self.context, "bulb.platform.common.timezones", "UTC"
        )

    def _now_utc_call(self, updated_node: cst.Call, *, qualified: bool) -> cst.Call:
        """Return *updated_node* rewritten as ``datetime.now(UTC)``.

        With ``qualified=True`` the receiver is ``datetime.datetime`` instead
        of ``datetime``, matching the spelling used by the original code.
        """
        receiver: cst.BaseExpression = cst.Name(value="datetime")
        if qualified:
            receiver = cst.Attribute(value=receiver, attr=cst.Name(value="datetime"))
        return updated_node.with_changes(
            func=cst.Attribute(value=receiver, attr=cst.Name(value="now")),
            args=[cst.Arg(value=cst.Name(value="UTC"))],
        )

    @m.leave(datetime_utcnow_matcher)
    def datetime_utcnow_call(
        self, original_node: cst.Call, updated_node: cst.Call
    ) -> cst.Call:
        """``datetime.utcnow()`` -> ``datetime.now(UTC)``."""
        self._update_imports()
        return self._now_utc_call(updated_node, qualified=False)

    @m.leave(datetime_datetime_utcnow_matcher)
    def datetime_datetime_utcnow_call(
        self, original_node: cst.Call, updated_node: cst.Call
    ) -> cst.Call:
        """``datetime.datetime.utcnow()`` -> ``datetime.datetime.now(UTC)``."""
        self._update_imports()
        return self._now_utc_call(updated_node, qualified=True)

    @m.leave(datetime_replace_matcher)
    def datetime_replace(
        self, original_node: cst.Call, updated_node: cst.Call
    ) -> cst.Call:
        """``datetime.utcnow().replace(tzinfo=<utc>)`` -> ``datetime.now(UTC)``."""
        self._update_imports()
        return self._now_utc_call(updated_node, qualified=False)

    @m.leave(datetime_datetime_replace_matcher)
    def datetime_datetime_replace(
        self, original_node: cst.Call, updated_node: cst.Call
    ) -> cst.Call:
        """``datetime.datetime.utcnow().replace(tzinfo=<utc>)`` -> ``datetime.datetime.now(UTC)``."""
        self._update_imports()
        return self._now_utc_call(updated_node, qualified=True)

    @m.leave(timedelta_replace_matcher)
    def timedelta_replace(
        self, original_node: cst.Call, updated_node: cst.Call
    ) -> cst.BinaryOperation:
        """Drop ``.replace(tzinfo=<utc>)`` around an already-rewritten ``now(UTC) +/- x``.

        The inner ``utcnow()`` call has been rewritten by its own leave hook
        by the time this runs, so the receiver of ``.replace`` is returned.
        """
        self._update_imports()
        return cast(cst.BinaryOperation, cast(cst.Attribute, updated_node.func).value)

    @m.leave(utc_localize_matcher)
    def utc_localize(
        self, original_node: cst.Call, updated_node: cst.Call
    ) -> cst.Call:
        """``UTC.localize(<utcnow call>)`` -> its (already rewritten) argument."""
        self._update_imports()
        return cast(cst.Call, updated_node.args[0].value)
Ejemplo n.º 16
0
    def obf_universal(self, node: cst.CSTNode, *types):
        """Recursively obfuscate names inside an expression-like CST node.

        Dispatches on the node's type and rebuilds it from obfuscated
        children. A bare ``Name`` is replaced via ``self.get_new_cst_name``
        when ``self.can_rename(name, *types)`` allows it. Node kinds that are
        not handled below are returned unchanged.

        Args:
            node: The CST node to transform.
            *types: Rename categories passed to ``self.can_rename`` when *node*
                itself is a ``Name``. Defaults to ``('a', 'ca', 'v', 'cv')``.
                They are not forwarded to the recursive calls.

        Returns:
            The rewritten node, or *node* itself when nothing applies.
        """
        if m.matches(node, m.Name()):
            types = types or ('a', 'ca', 'v', 'cv')
            node = cst.ensure_type(node, cst.Name)
            if self.can_rename(node.value, *types):
                node = self.get_new_cst_name(node)

        elif m.matches(node, m.NameItem()):
            node = cst.ensure_type(node, cst.NameItem)
            node = node.with_changes(name=self.obf_universal(node.name))

        elif m.matches(node, m.Call()):
            node = cst.ensure_type(node, cst.Call)
            if self.change_methods or self.change_functions:
                node = self.new_obf_function_name(node)
            if self.change_arguments or self.change_method_arguments:
                node = self.obf_function_args(node)

        elif m.matches(node, m.Attribute()):
            node = cst.ensure_type(node, cst.Attribute)
            value = node.value
            attr = node.attr
            # NOTE(review): both results are discarded, so the returned
            # Attribute is never rewritten. Every other branch assigns its
            # recursive results back; confirm whether this is intentional
            # (e.g. to avoid renaming ``module.attr``) before changing it.
            self.obf_universal(value)
            self.obf_universal(attr)

        elif m.matches(node, m.AssignTarget()):
            node = cst.ensure_type(node, cst.AssignTarget)
            node = node.with_changes(target=self.obf_universal(node.target))

        elif m.matches(node, m.List() | m.Tuple()):
            node = (cst.ensure_type(node, cst.List) if m.matches(node, m.List())
                    else cst.ensure_type(node, cst.Tuple))
            node = node.with_changes(
                elements=[self.obf_universal(el) for el in node.elements])

        elif m.matches(node, m.Subscript()):
            node = cst.ensure_type(node, cst.Subscript)
            # Slices first, then the subscripted value (original order kept).
            new_slice = [el.with_changes(slice=self.obf_slice(el.slice))
                         for el in node.slice]
            node = node.with_changes(slice=new_slice)
            node = node.with_changes(value=self.obf_universal(node.value))

        elif m.matches(node, m.Element()):
            node = cst.ensure_type(node, cst.Element)
            node = node.with_changes(value=self.obf_universal(node.value))

        elif m.matches(node, m.Dict()):
            node = cst.ensure_type(node, cst.Dict)
            node = node.with_changes(
                elements=[self.obf_universal(el) for el in node.elements])

        elif m.matches(node, m.DictElement()):
            node = cst.ensure_type(node, cst.DictElement)
            new_key = self.obf_universal(node.key)
            new_val = self.obf_universal(node.value)
            node = node.with_changes(key=new_key, value=new_val)

        elif m.matches(node, m.StarredDictElement()):
            node = cst.ensure_type(node, cst.StarredDictElement)
            node = node.with_changes(value=self.obf_universal(node.value))

        elif m.matches(node, m.If() | m.While()):
            # Both ``If`` and ``While`` expose ``test``. Narrow to the type that
            # actually matched. (The previous code narrowed an ``If`` to
            # ``IfExp`` and built ``cst.If | cst.IfExp`` from node classes,
            # so this branch always raised.)
            node = (cst.ensure_type(node, cst.If) if m.matches(node, m.If())
                    else cst.ensure_type(node, cst.While))
            node = node.with_changes(test=self.obf_universal(node.test))

        elif m.matches(node, m.IfExp()):
            node = cst.ensure_type(node, cst.IfExp)
            node = node.with_changes(body=self.obf_universal(node.body))
            node = node.with_changes(test=self.obf_universal(node.test))
            node = node.with_changes(orelse=self.obf_universal(node.orelse))

        elif m.matches(node, m.Comparison()):
            node = cst.ensure_type(node, cst.Comparison)
            # Comparison targets are processed before ``left`` (original order).
            new_compars = [self.obf_universal(target) for target in node.comparisons]
            node = node.with_changes(left=self.obf_universal(node.left))
            node = node.with_changes(comparisons=new_compars)

        elif m.matches(node, m.ComparisonTarget()):
            node = cst.ensure_type(node, cst.ComparisonTarget)
            node = node.with_changes(
                comparator=self.obf_universal(node.comparator))

        elif m.matches(node, m.FormattedString()):
            node = cst.ensure_type(node, cst.FormattedString)
            node = node.with_changes(
                parts=[self.obf_universal(part) for part in node.parts])

        elif m.matches(node, m.FormattedStringExpression()):
            node = cst.ensure_type(node, cst.FormattedStringExpression)
            node = node.with_changes(
                expression=self.obf_universal(node.expression))

        elif m.matches(node, m.BinaryOperation() | m.BooleanOperation()):
            node = (cst.ensure_type(node, cst.BinaryOperation)
                    if m.matches(node, m.BinaryOperation())
                    else cst.ensure_type(node, cst.BooleanOperation))
            node = node.with_changes(left=self.obf_universal(node.left),
                                     right=self.obf_universal(node.right))

        elif m.matches(node, m.UnaryOperation()):
            node = cst.ensure_type(node, cst.UnaryOperation)
            node = node.with_changes(
                expression=self.obf_universal(node.expression))

        elif m.matches(node, m.ListComp()):
            node = cst.ensure_type(node, cst.ListComp)
            node = node.with_changes(elt=self.obf_universal(node.elt))
            node = node.with_changes(for_in=self.obf_universal(node.for_in))

        elif m.matches(node, m.DictComp()):
            node = cst.ensure_type(node, cst.DictComp)
            node = node.with_changes(key=self.obf_universal(node.key))
            node = node.with_changes(value=self.obf_universal(node.value))
            node = node.with_changes(for_in=self.obf_universal(node.for_in))

        elif m.matches(node, m.CompFor()):
            node = cst.ensure_type(node, cst.CompFor)
            node = node.with_changes(target=self.obf_universal(node.target))
            node = node.with_changes(iter=self.obf_universal(node.iter))
            node = node.with_changes(
                ifs=[self.obf_universal(el) for el in node.ifs])
            # Chained clauses (``... for a in b for c in a``) hang off
            # ``inner_for_in``. Recurse so their targets and iterables are
            # renamed consistently with the comprehension's element expression.
            if node.inner_for_in is not None:
                node = node.with_changes(
                    inner_for_in=self.obf_universal(node.inner_for_in))

        elif m.matches(node, m.CompIf()):
            node = cst.ensure_type(node, cst.CompIf)
            node = node.with_changes(test=self.obf_universal(node.test))

        # Literals (Integer, Float, SimpleString) and any unhandled node kinds
        # fall through and are returned unchanged.
        return node