def alter_operator(
    self, original_op: Union[cst.Equal, cst.NotEqual]
) -> Union[cst.Is, cst.IsNot]:
    """Translate an equality operator into the matching identity operator.

    ``!=`` becomes ``is not`` and ``==`` becomes ``is``. The whitespace
    before and after the original operator is carried over to the new one,
    so the surrounding formatting does not change.
    """
    spacing = dict(
        whitespace_before=original_op.whitespace_before,
        whitespace_after=original_op.whitespace_after,
    )
    if isinstance(original_op, cst.NotEqual):
        return cst.IsNot(**spacing)
    # Anything that is not ``NotEqual`` (i.e. ``Equal``) maps to ``is``.
    return cst.Is(**spacing)
def visit_BooleanOperation(self, node: cst.BooleanOperation) -> None:
    """Flag bare Optional names used as operands of a boolean operation.

    For an expression such as ``x and y``, each operand is checked. An operand
    is rewritten when it is a bare ``Name`` and ``self._is_optional_type``
    accepts it. The rewrite is the comparison produced by
    ``self._gen_comparison_to_none`` with ``cst.IsNot()``.

    All rewritten operands are merged into ONE report on ``node``, so the
    suggested replacement fixes both sides at once. One report per operand
    would be wrong: each would replace the whole ``node`` while fixing only
    one side, producing overlapping autofixes that cannot both be applied.
    """
    changes: Dict[str, cst.CSTNode] = {}
    for field, operand in (("left", node.left), ("right", node.right)):
        if not m.matches(operand, m.Name()):
            # Only simple names such as "x and y" are handled.
            continue
        name = cast(cst.Name, operand)
        if self._is_optional_type(name):
            changes[field] = self._gen_comparison_to_none(
                variable_name=name.value, operator=cst.IsNot()
            )
    if changes:
        self.report(node, replacement=node.with_changes(**changes))
def replace_not_in_condition(
    self, _, updated_node: cst.UnaryOperation
) -> cst.BaseExpression:
    """Rewrite ``not A in B`` into the idiomatic ``A not in B``.

    Args:
        _: The original (pre-transform) node; unused.
        updated_node: A ``not`` unary operation whose operand must be a
            ``cst.Comparison`` (enforced with ``cst.ensure_type``).

    Returns:
        A ``Comparison`` of the form ``A not in B`` that keeps the parentheses
        of ``updated_node`` and the spacing of the original ``in`` operator.
        ``updated_node`` is returned unchanged when the comparison is chained
        or does not use ``in``, because those cases cannot be rewritten
        without changing semantics.
    """
    comparison_node: cst.Comparison = cst.ensure_type(
        updated_node.expression, cst.Comparison
    )
    # TODO: Implement support for multiple consecutive 'not ... in B',
    # even if it does not make any sense in practice.
    if len(comparison_node.comparisons) != 1:
        # `not A in B < C` means `not (A in B and B < C)`. Rewriting only the
        # first target and dropping the rest would change semantics.
        return updated_node
    target = comparison_node.comparisons[0]
    in_operator = target.operator
    if not isinstance(in_operator, cst.In):
        # Only `in` has a `not in` counterpart; leave anything else alone.
        return updated_node
    return cst.Comparison(
        left=comparison_node.left,
        lpar=updated_node.lpar,
        rpar=updated_node.rpar,
        comparisons=[
            # `not in` (membership), NOT `is not` (identity).
            target.with_changes(
                operator=cst.NotIn(
                    whitespace_before=in_operator.whitespace_before,
                    whitespace_after=in_operator.whitespace_after,
                )
            )
        ],
    )
def leave_If(self, original_node: cst.If) -> None:
    """Report ``if x`` on an Optional name and fold in any ``elif`` report.

    When the test is a bare ``Name`` accepted by ``self._is_optional_type``,
    it is replaced by the comparison from ``self._gen_comparison_to_none``
    with ``cst.IsNot()`` (``if x`` -> ``if x is not None``).

    When ``orelse`` is itself an ``If`` (an ``elif``), this rule may already
    have reported it, since that node is left first. That report is removed
    from ``self.context.reports`` and its replacement becomes this node's
    ``orelse``. The single combined report on the parent keeps the ``elif``
    rendered correctly.
    """
    changes: Dict[str, cst.CSTNode] = {}

    condition: cst.BaseExpression = original_node.test
    if m.matches(condition, m.Name()):
        condition = cast(cst.Name, condition)
        if self._is_optional_type(condition):
            changes["test"] = self._gen_comparison_to_none(
                variable_name=condition.value, operator=cst.IsNot()
            )

    orelse = original_node.orelse
    if orelse is not None and m.matches(orelse, m.If()):
        kept_reports = []
        elif_report: Optional[CstLintRuleReport] = None
        for report in self.context.reports:
            # Only take a report made by THIS rule on the elif node, so other
            # rules' reports are never dropped.
            belongs_to_elif = (
                isinstance(report, CstLintRuleReport)
                and report.node is orelse
                and report.code == self.__class__.__name__
            )
            if belongs_to_elif:
                elif_report = report
            else:
                kept_reports.append(report)
        if elif_report is not None:
            self.context.reports = kept_reports
            changes["orelse"] = cst.ensure_type(
                elif_report.replacement_node, cst.CSTNode
            )

    if changes:
        self.report(
            original_node, replacement=original_node.with_changes(**changes)
        )
class ComparisonTest(CSTNodeTest):
    """Tests for constructing and rendering ``cst.Comparison`` nodes.

    ``test_valid`` pairs each node with the source it should correspond to,
    optionally with an expected ``CodeRange``. ``test_invalid`` pairs a node
    factory with a regex for the error its construction should produce.
    """

    @data_provider(
        (
            # Simple comparison statements
            (
                cst.Comparison(
                    cst.Name("foo"),
                    (cst.ComparisonTarget(cst.LessThan(), cst.Integer("5")),),
                ),
                "foo < 5",
            ),
            (
                cst.Comparison(
                    cst.Name("foo"),
                    (cst.ComparisonTarget(cst.NotEqual(), cst.Integer("5")),),
                ),
                "foo != 5",
            ),
            (
                cst.Comparison(
                    cst.Name("foo"), (cst.ComparisonTarget(cst.Is(), cst.Name("True")),)
                ),
                "foo is True",
            ),
            (
                cst.Comparison(
                    cst.Name("foo"),
                    (cst.ComparisonTarget(cst.IsNot(), cst.Name("False")),),
                ),
                "foo is not False",
            ),
            (
                cst.Comparison(
                    cst.Name("foo"), (cst.ComparisonTarget(cst.In(), cst.Name("bar")),)
                ),
                "foo in bar",
            ),
            (
                cst.Comparison(
                    cst.Name("foo"),
                    (cst.ComparisonTarget(cst.NotIn(), cst.Name("bar")),),
                ),
                "foo not in bar",
            ),
            # Comparison with parens
            (
                cst.Comparison(
                    lpar=(cst.LeftParen(),),
                    left=cst.Name("foo"),
                    comparisons=(
                        cst.ComparisonTarget(
                            operator=cst.NotIn(), comparator=cst.Name("bar")
                        ),
                    ),
                    rpar=(cst.RightParen(),),
                ),
                "(foo not in bar)",
            ),
            # Word operator with no surrounding whitespace: every operand is
            # parenthesized, as in the expected "(a)is(b)is(c)".
            (
                cst.Comparison(
                    left=cst.Name(
                        "a", lpar=(cst.LeftParen(),), rpar=(cst.RightParen(),)
                    ),
                    comparisons=(
                        cst.ComparisonTarget(
                            operator=cst.Is(
                                whitespace_before=cst.SimpleWhitespace(""),
                                whitespace_after=cst.SimpleWhitespace(""),
                            ),
                            comparator=cst.Name(
                                "b", lpar=(cst.LeftParen(),), rpar=(cst.RightParen(),)
                            ),
                        ),
                        cst.ComparisonTarget(
                            operator=cst.Is(
                                whitespace_before=cst.SimpleWhitespace(""),
                                whitespace_after=cst.SimpleWhitespace(""),
                            ),
                            comparator=cst.Name(
                                "c", lpar=(cst.LeftParen(),), rpar=(cst.RightParen(),)
                            ),
                        ),
                    ),
                ),
                "(a)is(b)is(c)",
            ),
            # Valid expressions that look like they shouldn't parse
            (
                cst.Comparison(
                    left=cst.Integer("5"),
                    comparisons=(
                        cst.ComparisonTarget(
                            operator=cst.NotIn(
                                whitespace_before=cst.SimpleWhitespace("")
                            ),
                            comparator=cst.Name("bar"),
                        ),
                    ),
                ),
                "5not in bar",
            ),
            # Validate that spacing works properly
            (
                cst.Comparison(
                    lpar=(cst.LeftParen(whitespace_after=cst.SimpleWhitespace(" ")),),
                    left=cst.Name("foo"),
                    comparisons=(
                        cst.ComparisonTarget(
                            operator=cst.NotIn(
                                whitespace_before=cst.SimpleWhitespace(" "),
                                whitespace_between=cst.SimpleWhitespace(" "),
                                whitespace_after=cst.SimpleWhitespace(" "),
                            ),
                            comparator=cst.Name("bar"),
                        ),
                    ),
                    rpar=(cst.RightParen(whitespace_before=cst.SimpleWhitespace(" ")),),
                ),
                "( foo not in bar )",
            ),
            # Do some complex nodes
            # (the third tuple element is the expected CodeRange of the whole
            # expression: line 1, columns 0 through the length of the code)
            (
                cst.Comparison(
                    left=cst.Name("baz"),
                    comparisons=(
                        cst.ComparisonTarget(
                            operator=cst.Equal(),
                            comparator=cst.Comparison(
                                lpar=(cst.LeftParen(),),
                                left=cst.Name("foo"),
                                comparisons=(
                                    cst.ComparisonTarget(
                                        operator=cst.NotIn(), comparator=cst.Name("bar")
                                    ),
                                ),
                                rpar=(cst.RightParen(),),
                            ),
                        ),
                    ),
                ),
                "baz == (foo not in bar)",
                CodeRange((1, 0), (1, 23)),
            ),
            (
                cst.Comparison(
                    left=cst.Name("a"),
                    comparisons=(
                        cst.ComparisonTarget(
                            operator=cst.GreaterThan(), comparator=cst.Name("b")
                        ),
                        cst.ComparisonTarget(
                            operator=cst.GreaterThan(), comparator=cst.Name("c")
                        ),
                    ),
                ),
                "a > b > c",
                CodeRange((1, 0), (1, 9)),
            ),
            # Is safe to use with word operators if its leading/trailing
            # children are parenthesized
            (
                cst.IfExp(
                    body=cst.Comparison(
                        left=cst.Name("a"),
                        comparisons=(
                            cst.ComparisonTarget(
                                operator=cst.GreaterThan(),
                                comparator=cst.Name(
                                    "b",
                                    lpar=(cst.LeftParen(),),
                                    rpar=(cst.RightParen(),),
                                ),
                            ),
                        ),
                    ),
                    test=cst.Comparison(
                        left=cst.Name(
                            "c", lpar=(cst.LeftParen(),), rpar=(cst.RightParen(),)
                        ),
                        comparisons=(
                            cst.ComparisonTarget(
                                operator=cst.GreaterThan(), comparator=cst.Name("d")
                            ),
                        ),
                    ),
                    orelse=cst.Name("e"),
                    whitespace_before_if=cst.SimpleWhitespace(""),
                    whitespace_after_if=cst.SimpleWhitespace(""),
                ),
                "a > (b)if(c) > d else e",
            ),
            # is safe to use with word operators if entirely surrounded in parentheses
            (
                cst.IfExp(
                    body=cst.Name("a"),
                    test=cst.Comparison(
                        left=cst.Name("b"),
                        comparisons=(
                            cst.ComparisonTarget(
                                operator=cst.GreaterThan(), comparator=cst.Name("c")
                            ),
                        ),
                        lpar=(cst.LeftParen(),),
                        rpar=(cst.RightParen(),),
                    ),
                    orelse=cst.Name("d"),
                    whitespace_after_if=cst.SimpleWhitespace(""),
                    whitespace_before_else=cst.SimpleWhitespace(""),
                ),
                "a if(b > c)else d",
            ),
        )
    )
    def test_valid(
        self, node: cst.CSTNode, code: str, position: Optional[CodeRange] = None
    ) -> None:
        """Check one valid case via ``validate_node``.

        ``node``, its expected source ``code`` and ``parse_expression`` are
        forwarded to ``self.validate_node``. ``position``, when given, is
        forwarded as ``expected_position``.
        """
        self.validate_node(node, code, parse_expression, expected_position=position)

    @data_provider(
        (
            (
                lambda: cst.Comparison(
                    cst.Name("foo"),
                    (cst.ComparisonTarget(cst.LessThan(), cst.Integer("5")),),
                    lpar=(cst.LeftParen(),),
                ),
                "left paren without right paren",
            ),
            (
                lambda: cst.Comparison(
                    cst.Name("foo"),
                    (cst.ComparisonTarget(cst.LessThan(), cst.Integer("5")),),
                    rpar=(cst.RightParen(),),
                ),
                "right paren without left paren",
            ),
            (
                lambda: cst.Comparison(cst.Name("foo"), ()),
                "at least one ComparisonTarget",
            ),
            # Unparenthesized operands around a word operator need whitespace.
            (
                lambda: cst.Comparison(
                    left=cst.Name("foo"),
                    comparisons=(
                        cst.ComparisonTarget(
                            operator=cst.NotIn(
                                whitespace_before=cst.SimpleWhitespace("")
                            ),
                            comparator=cst.Name("bar"),
                        ),
                    ),
                ),
                "at least one space around comparison operator",
            ),
            (
                lambda: cst.Comparison(
                    left=cst.Name("foo"),
                    comparisons=(
                        cst.ComparisonTarget(
                            operator=cst.NotIn(
                                whitespace_after=cst.SimpleWhitespace("")
                            ),
                            comparator=cst.Name("bar"),
                        ),
                    ),
                ),
                "at least one space around comparison operator",
            ),
            # multi-target comparisons
            (
                lambda: cst.Comparison(
                    left=cst.Name("a"),
                    comparisons=(
                        cst.ComparisonTarget(operator=cst.Is(), comparator=cst.Name("b")),
                        cst.ComparisonTarget(
                            operator=cst.Is(
                                whitespace_before=cst.SimpleWhitespace("")
                            ),
                            comparator=cst.Name("c"),
                        ),
                    ),
                ),
                "at least one space around comparison operator",
            ),
            (
                lambda: cst.Comparison(
                    left=cst.Name("a"),
                    comparisons=(
                        cst.ComparisonTarget(operator=cst.Is(), comparator=cst.Name("b")),
                        cst.ComparisonTarget(
                            operator=cst.Is(
                                whitespace_after=cst.SimpleWhitespace("")
                            ),
                            comparator=cst.Name("c"),
                        ),
                    ),
                ),
                "at least one space around comparison operator",
            ),
            # whitespace around the comparison itself
            # a ifb > c else d
            (
                lambda: cst.IfExp(
                    body=cst.Name("a"),
                    test=cst.Comparison(
                        left=cst.Name("b"),
                        comparisons=(
                            cst.ComparisonTarget(
                                operator=cst.GreaterThan(), comparator=cst.Name("c")
                            ),
                        ),
                    ),
                    orelse=cst.Name("d"),
                    whitespace_after_if=cst.SimpleWhitespace(""),
                ),
                "Must have at least one space after 'if' keyword.",
            ),
            # a if b > celse d
            (
                lambda: cst.IfExp(
                    body=cst.Name("a"),
                    test=cst.Comparison(
                        left=cst.Name("b"),
                        comparisons=(
                            cst.ComparisonTarget(
                                operator=cst.GreaterThan(), comparator=cst.Name("c")
                            ),
                        ),
                    ),
                    orelse=cst.Name("d"),
                    whitespace_before_else=cst.SimpleWhitespace(""),
                ),
                "Must have at least one space before 'else' keyword.",
            ),
        )
    )
    def test_invalid(
        self, get_node: Callable[[], cst.CSTNode], expected_re: str
    ) -> None:
        """Check one invalid case via ``assert_invalid``.

        ``get_node`` is a zero-argument factory (a lambda, so construction
        happens inside the assertion rather than at decoration time).
        ``expected_re`` presumably is matched against the raised error's
        message; TODO confirm against ``CSTNodeTest.assert_invalid``.
        """
        self.assert_invalid(get_node, expected_re)