def alter_operator(
    self, original_op: Union[cst.Equal, cst.NotEqual]
) -> Union[cst.Is, cst.IsNot]:
    """Return the identity operator corresponding to an equality operator.

    ``!=`` (``cst.NotEqual``) becomes ``is not`` (``cst.IsNot``); any other
    operator, i.e. ``==`` (``cst.Equal``), becomes ``is`` (``cst.Is``).
    The whitespace before and after *original_op* is copied onto the new
    operator so the surrounding source formatting is preserved.
    """
    if isinstance(original_op, cst.NotEqual):
        replacement_cls = cst.IsNot
    else:
        replacement_cls = cst.Is
    return replacement_cls(
        whitespace_before=original_op.whitespace_before,
        whitespace_after=original_op.whitespace_after,
    )
示例#2
0
 def visit_BooleanOperation(self, node: cst.BooleanOperation) -> None:
     """Suggest explicit ``is not None`` checks for Optional operands of a boolean op.

     Each operand of *node* (``left`` and ``right``) that is a bare ``Name``
     for which ``self._is_optional_type`` returns true is replaced by the
     expression built by ``self._gen_comparison_to_none(..., operator=cst.IsNot())``.

     All replacements for *node* are gathered into a single report. Two
     separate reports on the same node would each carry a replacement that
     rewrites only one operand, so applying one would discard the other's fix.
     """
     # Maps the BooleanOperation field name ("left"/"right") to its replacement.
     changes = {}

     left_expression: cst.BaseExpression = node.left
     if m.matches(left_expression, m.Name()):
         # Eg: the "x" in "x and y".
         left_expression = cast(cst.Name, left_expression)
         if self._is_optional_type(left_expression):
             changes["left"] = self._gen_comparison_to_none(
                 variable_name=left_expression.value, operator=cst.IsNot()
             )

     right_expression: cst.BaseExpression = node.right
     if m.matches(right_expression, m.Name()):
         # Eg: the "y" in "x and y".
         right_expression = cast(cst.Name, right_expression)
         if self._is_optional_type(right_expression):
             changes["right"] = self._gen_comparison_to_none(
                 variable_name=right_expression.value, operator=cst.IsNot()
             )

     if changes:
         self.report(node, replacement=node.with_changes(**changes))
示例#3
0
    def replace_not_in_condition(
            self, _, updated_node: cst.UnaryOperation) -> cst.BaseExpression:
        """Fold a leading ``not`` into the comparison it negates.

        ``not a in B`` becomes ``a not in B`` and ``not a is B`` becomes
        ``a is not B``. The parentheses of the ``not`` expression itself are
        kept on the resulting comparison.

        Shapes without a single equivalent negated operator are returned
        unchanged rather than rewritten incorrectly: chained comparisons, and
        operators other than ``in``/``is``.

        Raises whatever ``cst.ensure_type`` raises if the operand of ``not``
        is not a ``cst.Comparison``.
        """
        # NOTE(review): the matcher that routes nodes here is not visible in
        # this chunk; presumably it only passes ``not <comparison>`` nodes.
        comparison_node: cst.Comparison = cst.ensure_type(
            updated_node.expression, cst.Comparison)

        # TODO: Implement support for multiple consecutive 'not ... in B',
        # even if it does not make any sense in practice.
        # Until then, leave chained comparisons alone: keeping only the first
        # target would silently delete the rest of the expression.
        if len(comparison_node.comparisons) != 1:
            return updated_node

        target = comparison_node.comparisons[0]
        operator = target.operator
        # Negate the operator that is actually present; unconditionally using
        # IsNot would turn ``not a in B`` into the unrelated ``a is not B``.
        if isinstance(operator, cst.In):
            negated_operator = cst.NotIn()
        elif isinstance(operator, cst.Is):
            negated_operator = cst.IsNot()
        else:
            return updated_node

        return cst.Comparison(
            left=comparison_node.left,
            lpar=updated_node.lpar,
            rpar=updated_node.rpar,
            comparisons=[target.with_changes(operator=negated_operator)],
        )
示例#4
0
    def leave_If(self, original_node: cst.If) -> None:
        """Report ``if x`` on an Optional ``x`` and merge nested ``elif`` reports.

        The ``test`` of *original_node* is replaced with ``x is not None`` (built
        by ``self._gen_comparison_to_none``) when it is a bare name for which
        ``self._is_optional_type`` is true.

        When the ``orelse`` branch is itself an ``If`` (an ``elif``) that this
        rule already reported, that report is removed from
        ``self.context.reports``. Its replacement node becomes the new
        ``orelse``, so a single report on the outer ``If`` regenerates the
        ``elif`` chain correctly.
        """
        changes: Dict[str, cst.CSTNode] = {}

        condition: cst.BaseExpression = original_node.test
        if m.matches(condition, m.Name()):
            # A simple truthiness check such as "if x".
            condition = cast(cst.Name, condition)
            if self._is_optional_type(condition):
                # Rewrite "if x" as "if x is not None".
                changes["test"] = self._gen_comparison_to_none(
                    variable_name=condition.value, operator=cst.IsNot()
                )

        else_branch = original_node.orelse
        if else_branch is not None and m.matches(else_branch, m.If()):
            # This runs on leaving the outer If so the `elif` is generated
            # correctly: a report already raised on the inner If is pulled out
            # and re-issued as part of the outer If's report.
            kept_reports = []
            elif_report: Optional[CstLintRuleReport] = None
            for existing in self.context.reports:
                # Only take reports from this very rule on this exact node, so
                # other lint rules' reports are never dropped.
                is_own_elif_report = (
                    isinstance(existing, CstLintRuleReport)
                    and existing.node is else_branch
                    and existing.code == self.__class__.__name__
                )
                if is_own_elif_report:
                    elif_report = existing
                else:
                    kept_reports.append(existing)
            if elif_report is not None:
                self.context.reports = kept_reports
                changes["orelse"] = cst.ensure_type(
                    elif_report.replacement_node, cst.CSTNode
                )

        if changes:
            self.report(
                original_node, replacement=original_node.with_changes(**changes)
            )
示例#5
0
class ComparisonTest(CSTNodeTest):
    """Tests for ``cst.Comparison`` nodes, including comparisons inside ``cst.IfExp``."""

    @data_provider((
        # Simple comparison statements
        (
            cst.Comparison(
                cst.Name("foo"),
                (cst.ComparisonTarget(cst.LessThan(), cst.Integer("5")), ),
            ),
            "foo < 5",
        ),
        (
            cst.Comparison(
                cst.Name("foo"),
                (cst.ComparisonTarget(cst.NotEqual(), cst.Integer("5")), ),
            ),
            "foo != 5",
        ),
        (
            cst.Comparison(
                cst.Name("foo"),
                (cst.ComparisonTarget(cst.Is(), cst.Name("True")), )),
            "foo is True",
        ),
        (
            cst.Comparison(
                cst.Name("foo"),
                (cst.ComparisonTarget(cst.IsNot(), cst.Name("False")), ),
            ),
            "foo is not False",
        ),
        (
            cst.Comparison(
                cst.Name("foo"),
                (cst.ComparisonTarget(cst.In(), cst.Name("bar")), )),
            "foo in bar",
        ),
        (
            cst.Comparison(
                cst.Name("foo"),
                (cst.ComparisonTarget(cst.NotIn(), cst.Name("bar")), ),
            ),
            "foo not in bar",
        ),
        # Comparison with parens
        (
            cst.Comparison(
                lpar=(cst.LeftParen(), ),
                left=cst.Name("foo"),
                comparisons=(cst.ComparisonTarget(
                    operator=cst.NotIn(), comparator=cst.Name("bar")), ),
                rpar=(cst.RightParen(), ),
            ),
            "(foo not in bar)",
        ),
        # Word operators need no surrounding whitespace next to parenthesized operands
        (
            cst.Comparison(
                left=cst.Name("a",
                              lpar=(cst.LeftParen(), ),
                              rpar=(cst.RightParen(), )),
                comparisons=(
                    cst.ComparisonTarget(
                        operator=cst.Is(
                            whitespace_before=cst.SimpleWhitespace(""),
                            whitespace_after=cst.SimpleWhitespace(""),
                        ),
                        comparator=cst.Name("b",
                                            lpar=(cst.LeftParen(), ),
                                            rpar=(cst.RightParen(), )),
                    ),
                    cst.ComparisonTarget(
                        operator=cst.Is(
                            whitespace_before=cst.SimpleWhitespace(""),
                            whitespace_after=cst.SimpleWhitespace(""),
                        ),
                        comparator=cst.Name("c",
                                            lpar=(cst.LeftParen(), ),
                                            rpar=(cst.RightParen(), )),
                    ),
                ),
            ),
            "(a)is(b)is(c)",
        ),
        # Valid expressions that look like they shouldn't parse
        (
            cst.Comparison(
                left=cst.Integer("5"),
                comparisons=(cst.ComparisonTarget(
                    operator=cst.NotIn(
                        whitespace_before=cst.SimpleWhitespace("")),
                    comparator=cst.Name("bar"),
                ), ),
            ),
            "5not in bar",
        ),
        # Validate that spacing works properly
        (
            cst.Comparison(
                lpar=(cst.LeftParen(
                    whitespace_after=cst.SimpleWhitespace(" ")), ),
                left=cst.Name("foo"),
                comparisons=(cst.ComparisonTarget(
                    operator=cst.NotIn(
                        whitespace_before=cst.SimpleWhitespace("  "),
                        whitespace_between=cst.SimpleWhitespace("  "),
                        whitespace_after=cst.SimpleWhitespace("  "),
                    ),
                    comparator=cst.Name("bar"),
                ), ),
                rpar=(cst.RightParen(
                    whitespace_before=cst.SimpleWhitespace(" ")), ),
            ),
            "( foo  not  in  bar )",
        ),
        # More complex nodes (nested and chained comparisons), with expected positions
        (
            cst.Comparison(
                left=cst.Name("baz"),
                comparisons=(cst.ComparisonTarget(
                    operator=cst.Equal(),
                    comparator=cst.Comparison(
                        lpar=(cst.LeftParen(), ),
                        left=cst.Name("foo"),
                        comparisons=(cst.ComparisonTarget(
                            operator=cst.NotIn(),
                            comparator=cst.Name("bar")), ),
                        rpar=(cst.RightParen(), ),
                    ),
                ), ),
            ),
            "baz == (foo not in bar)",
            CodeRange((1, 0), (1, 23)),
        ),
        (
            cst.Comparison(
                left=cst.Name("a"),
                comparisons=(
                    cst.ComparisonTarget(operator=cst.GreaterThan(),
                                         comparator=cst.Name("b")),
                    cst.ComparisonTarget(operator=cst.GreaterThan(),
                                         comparator=cst.Name("c")),
                ),
            ),
            "a > b > c",
            CodeRange((1, 0), (1, 9)),
        ),
        # Whitespace around a word operator ('if') may be omitted when its
        # adjacent leading/trailing children are parenthesized
        (
            cst.IfExp(
                body=cst.Comparison(
                    left=cst.Name("a"),
                    comparisons=(cst.ComparisonTarget(
                        operator=cst.GreaterThan(),
                        comparator=cst.Name(
                            "b",
                            lpar=(cst.LeftParen(), ),
                            rpar=(cst.RightParen(), ),
                        ),
                    ), ),
                ),
                test=cst.Comparison(
                    left=cst.Name("c",
                                  lpar=(cst.LeftParen(), ),
                                  rpar=(cst.RightParen(), )),
                    comparisons=(cst.ComparisonTarget(
                        operator=cst.GreaterThan(),
                        comparator=cst.Name("d")), ),
                ),
                orelse=cst.Name("e"),
                whitespace_before_if=cst.SimpleWhitespace(""),
                whitespace_after_if=cst.SimpleWhitespace(""),
            ),
            "a > (b)if(c) > d else e",
        ),
        # Whitespace around word operators ('if'/'else') may also be omitted
        # when the comparison is entirely surrounded by parentheses
        (
            cst.IfExp(
                body=cst.Name("a"),
                test=cst.Comparison(
                    left=cst.Name("b"),
                    comparisons=(cst.ComparisonTarget(
                        operator=cst.GreaterThan(),
                        comparator=cst.Name("c")), ),
                    lpar=(cst.LeftParen(), ),
                    rpar=(cst.RightParen(), ),
                ),
                orelse=cst.Name("d"),
                whitespace_after_if=cst.SimpleWhitespace(""),
                whitespace_before_else=cst.SimpleWhitespace(""),
            ),
            "a if(b > c)else d",
        ),
    ))
    def test_valid(self,
                   node: cst.CSTNode,
                   code: str,
                   position: Optional[CodeRange] = None) -> None:
        """Check each valid node against its expected source text.

        ``node`` and ``code`` are passed to ``self.validate_node`` together with
        ``parse_expression``. ``position``, when supplied, is forwarded as
        ``expected_position``. It is presumably a (line, column) start/end range:
        the 9-character ``a > b > c`` uses ``CodeRange((1, 0), (1, 9))``.
        """
        self.validate_node(node,
                           code,
                           parse_expression,
                           expected_position=position)

    @data_provider((
        # Parentheses must be balanced
        (
            lambda: cst.Comparison(
                cst.Name("foo"),
                (cst.ComparisonTarget(cst.LessThan(), cst.Integer("5")), ),
                lpar=(cst.LeftParen(), ),
            ),
            "left paren without right paren",
        ),
        (
            lambda: cst.Comparison(
                cst.Name("foo"),
                (cst.ComparisonTarget(cst.LessThan(), cst.Integer("5")), ),
                rpar=(cst.RightParen(), ),
            ),
            "right paren without left paren",
        ),
        # A comparison needs at least one target
        (
            lambda: cst.Comparison(cst.Name("foo"), ()),
            "at least one ComparisonTarget",
        ),
        # Word operators need whitespace next to unparenthesized operands
        (
            lambda: cst.Comparison(
                left=cst.Name("foo"),
                comparisons=(cst.ComparisonTarget(
                    operator=cst.NotIn(whitespace_before=cst.SimpleWhitespace(
                        "")),
                    comparator=cst.Name("bar"),
                ), ),
            ),
            "at least one space around comparison operator",
        ),
        (
            lambda: cst.Comparison(
                left=cst.Name("foo"),
                comparisons=(cst.ComparisonTarget(
                    operator=cst.NotIn(whitespace_after=cst.SimpleWhitespace(
                        "")),
                    comparator=cst.Name("bar"),
                ), ),
            ),
            "at least one space around comparison operator",
        ),
        # Multi-target comparisons: every target's operator is checked, not just the first
        (
            lambda: cst.Comparison(
                left=cst.Name("a"),
                comparisons=(
                    cst.ComparisonTarget(operator=cst.Is(),
                                         comparator=cst.Name("b")),
                    cst.ComparisonTarget(
                        operator=cst.Is(whitespace_before=cst.SimpleWhitespace(
                            "")),
                        comparator=cst.Name("c"),
                    ),
                ),
            ),
            "at least one space around comparison operator",
        ),
        (
            lambda: cst.Comparison(
                left=cst.Name("a"),
                comparisons=(
                    cst.ComparisonTarget(operator=cst.Is(),
                                         comparator=cst.Name("b")),
                    cst.ComparisonTarget(
                        operator=cst.Is(whitespace_after=cst.SimpleWhitespace(
                            "")),
                        comparator=cst.Name("c"),
                    ),
                ),
            ),
            "at least one space around comparison operator",
        ),
        # Whitespace around the comparison itself
        # a ifb > c else d
        (
            lambda: cst.IfExp(
                body=cst.Name("a"),
                test=cst.Comparison(
                    left=cst.Name("b"),
                    comparisons=(cst.
                                 ComparisonTarget(operator=cst.GreaterThan(),
                                                  comparator=cst.Name("c")), ),
                ),
                orelse=cst.Name("d"),
                whitespace_after_if=cst.SimpleWhitespace(""),
            ),
            "Must have at least one space after 'if' keyword.",
        ),
        # a if b > celse d
        (
            lambda: cst.IfExp(
                body=cst.Name("a"),
                test=cst.Comparison(
                    left=cst.Name("b"),
                    comparisons=(cst.
                                 ComparisonTarget(operator=cst.GreaterThan(),
                                                  comparator=cst.Name("c")), ),
                ),
                orelse=cst.Name("d"),
                whitespace_before_else=cst.SimpleWhitespace(""),
            ),
            "Must have at least one space before 'else' keyword.",
        ),
    ))
    def test_invalid(self, get_node: Callable[[], cst.CSTNode],
                     expected_re: str) -> None:
        """Check that each invalid node is rejected with a matching message.

        Nodes are supplied as zero-argument factories so that construction is
        deferred into ``self.assert_invalid`` instead of running while the data
        provider is built. Presumably construction raises and the message must
        match the regex ``expected_re``.
        """
        self.assert_invalid(get_node, expected_re)