def _get_async_expr_replacement(
        self, node: cst.CSTNode) -> Optional[cst.CSTNode]:
    """Return an async-aware replacement for ``node``, or ``None`` if none applies.

    Calls and attribute accesses are delegated to the dedicated helpers.
    ``not <expr>`` and boolean operations are handled recursively: they are
    rebuilt only when at least one operand produced a replacement, and any
    operand without a replacement is kept as-is.
    """
    if m.matches(node, m.Call()):
        return self._get_async_call_replacement(cast(cst.Call, node))
    if m.matches(node, m.Attribute()):
        return self._get_async_attr_replacement(cast(cst.Attribute, node))

    if m.matches(node, m.UnaryOperation(operator=m.Not())):
        negation = cast(cst.UnaryOperation, node)
        inner = self._get_async_expr_replacement(negation.expression)
        if inner is None:
            return None
        return negation.with_changes(expression=inner)

    if m.matches(node, m.BooleanOperation()):
        bool_op = cast(cst.BooleanOperation, node)
        # Left is resolved before right, matching source order.
        new_left = self._get_async_expr_replacement(bool_op.left)
        new_right = self._get_async_expr_replacement(bool_op.right)
        if new_left is None and new_right is None:
            return None
        return bool_op.with_changes(
            left=bool_op.left if new_left is None else new_left,
            right=bool_op.right if new_right is None else new_right,
        )

    return None
def _extract_static_bool(cls, node: cst.BaseExpression) -> Optional[bool]:
    """Statically evaluate ``node`` as a boolean when its outcome is fixed.

    Returns ``True``/``False`` when the value does not depend on runtime
    state, and ``None`` when it cannot be determined. The recognised forms are:

    * the names ``True`` and ``False``;
    * ``not <expr>``, when ``<expr>`` is itself static;
    * ``<expr> or <expr>``, when either operand is statically ``True``;
    * ``<expr> and <expr>``, when either operand is statically ``False``.

    Function calls are never reasoned about.
    """
    if m.matches(node, m.Call()):
        # cannot reason about function calls
        return None

    if m.matches(node, m.UnaryOperation(operator=m.Not())):
        operand = cls._extract_static_bool(
            cst.ensure_type(node, cst.UnaryOperation).expression)
        return None if operand is None else not operand

    if m.matches(node, m.Name("True")):
        return True
    if m.matches(node, m.Name("False")):
        return False

    if not m.matches(node, m.BooleanOperation()):
        return None

    bool_op = cst.ensure_type(node, cst.BooleanOperation)
    operands = (
        cls._extract_static_bool(bool_op.left),
        cls._extract_static_bool(bool_op.right),
    )
    # Only short-circuiting outcomes are reported. Other combinations
    # (e.g. `False or False`) are conservatively treated as unknown.
    if m.matches(bool_op.operator, m.Or()) and any(
            value is True for value in operands):
        return True
    if m.matches(bool_op.operator, m.And()) and any(
            value is False for value in operands):
        return False
    return None
def visit_UnaryOperation(self, node: cst.UnaryOperation) -> None:
    """Report ``not x`` on an Optional-typed name, suggesting ``x is None``.

    Only a bare name directly under ``not`` is considered. Whether the name
    is Optional is decided by ``self._is_optional_type``.
    """
    negates_bare_name = m.matches(
        node, m.UnaryOperation(operator=m.Not(), expression=m.Name()))
    if not negates_bare_name:
        return
    # Eg: "not x".
    name: cst.Name = cast(cst.Name, node.expression)
    if not self._is_optional_type(name):
        return
    suggestion = self._gen_comparison_to_none(
        variable_name=name.value, operator=cst.Is())
    self.report(node, replacement=suggestion)
class NotIsConditionTransformer(NoqaAwareTransformer):
    """Rewrite ``not a is b`` as the equivalent ``a is not b``."""

    @m.leave(
        m.UnaryOperation(
            operator=m.Not(),
            expression=m.Comparison(
                comparisons=[m.ComparisonTarget(operator=m.Is())]),
        ))
    def replace_not_in_condition(
            self, _, updated_node: cst.UnaryOperation) -> cst.BaseExpression:
        # NOTE: despite its name (kept for compatibility), this handles
        # `not ... is ...`, not `not ... in ...`.
        comparison = cst.ensure_type(updated_node.expression, cst.Comparison)
        # The matcher guarantees exactly one comparison target.
        negated_target = comparison.comparisons[0].with_changes(
            operator=cst.IsNot())
        # TODO: Support chained comparisons such as 'not a is b is c',
        # even if they do not make any sense in practice.
        # The outer `not` expression's parentheses are kept; the inner
        # comparison's own parentheses are dropped.
        return cst.Comparison(
            left=comparison.left,
            comparisons=[negated_target],
            lpar=updated_node.lpar,
            rpar=updated_node.rpar,
        )
class ShedFixers(VisitorBasedCodemodCommand):
    """Fix a variety of small problems.

    Replaces `raise NotImplemented` with `raise NotImplementedError`,
    and converts always-failing assert statements to explicit `raise` statements.

    Also includes code closely modelled on pybetter's fixers, because it's
    considerably faster to run all transforms in a single pass if possible.
    """

    DESCRIPTION = "Fix a variety of style, performance, and correctness issues."

    @m.call_if_inside(m.Raise(exc=m.Name(value="NotImplemented")))
    def leave_Name(self, _, updated_node):  # noqa
        # Every Name inside a matching `raise` statement reaches this method,
        # including the cause in `raise NotImplemented from err`. Only the
        # `NotImplemented` name itself may be renamed.
        if updated_node.value != "NotImplemented":
            return updated_node
        return updated_node.with_changes(value="NotImplementedError")

    def leave_Assert(self, _, updated_node):  # noqa
        # Only asserts whose test is a Python literal can be decided statically.
        test_code = cst.Module("").code_for_node(updated_node.test)
        try:
            test_literal = literal_eval(test_code)
        except Exception:
            return updated_node
        if test_literal:
            # Always passes, so the statement does nothing.
            return cst.RemovalSentinel.REMOVE
        # Always fails: raise explicitly, which is not stripped by `python -O`.
        if updated_node.msg is None:
            return cst.Raise(cst.Name("AssertionError"))
        return cst.Raise(
            cst.Call(cst.Name("AssertionError"),
                     args=[cst.Arg(updated_node.msg)]))

    @m.leave(
        m.ComparisonTarget(comparator=oneof_names("None", "False", "True"),
                           operator=m.Equal()))
    def convert_none_cmp(self, _, updated_node):
        """Inspired by Pybetter."""
        # `x == None` -> `x is None` (likewise for True / False).
        return updated_node.with_changes(operator=cst.Is())

    @m.leave(
        m.UnaryOperation(
            operator=m.Not(),
            expression=m.Comparison(
                comparisons=[m.ComparisonTarget(operator=m.In())]),
        ))
    def replace_not_in_condition(self, _, updated_node):
        """Also inspired by Pybetter."""
        # `not a in b` -> `a not in b`.
        expr = cst.ensure_type(updated_node.expression, cst.Comparison)
        return cst.Comparison(
            left=expr.left,
            lpar=updated_node.lpar,
            rpar=updated_node.rpar,
            comparisons=[
                expr.comparisons[0].with_changes(operator=cst.NotIn())
            ],
        )

    @m.leave(
        m.Call(
            lpar=[m.AtLeastN(n=1, matcher=m.LeftParen())],
            rpar=[m.AtLeastN(n=1, matcher=m.RightParen())],
        ))
    def remove_pointless_parens_around_call(self, _, updated_node):
        # This is *probably* valid, but we might have e.g. a multi-line parenthesised
        # chain of attribute accesses ("fluent interface"), where we need the parens.
        noparens = updated_node.with_changes(lpar=[], rpar=[])
        try:
            compile(self.module.code_for_node(noparens), "<string>", "eval")
            return noparens
        except SyntaxError:
            return updated_node

    # The following methods fix https://pypi.org/project/flake8-comprehensions/

    @m.leave(
        m.Call(func=m.Name("list"), args=[m.Arg(m.GeneratorExp(), star="")]))
    def replace_generator_in_call_with_comprehension(self, _, updated_node):
        """Fix flake8-comprehensions C400-402 and 403-404.

        C400-402: Unnecessary generator - rewrite as a <list/set/dict> comprehension.
        Note that set and dict conversions are handled by pyupgrade!
        """
        # `star=""` in the matcher: `list(*gen)` unpacks the generator into
        # positional arguments and must not become `[... for ...]`.
        return cst.ListComp(elt=updated_node.args[0].value.elt,
                            for_in=updated_node.args[0].value.for_in)

    @m.leave(
        m.Call(func=m.Name("list"), args=[m.Arg(m.ListComp(), star="")])
        | m.Call(func=m.Name("set"), args=[m.Arg(m.SetComp(), star="")])
        | m.Call(
            func=m.Name("list"),
            args=[m.Arg(m.Call(func=oneof_names("sorted", "list")), star="")],
        ))
    def replace_unnecessary_list_around_sorted(self, _, updated_node):
        """Fix flake8-comprehensions C411 and C413.

        Unnecessary <list/reversed> call around sorted().

        Also covers C411 Unnecessary list call around list comprehension
        for lists and sets.
        """
        return updated_node.args[0].value

    @m.leave(
        m.Call(
            func=m.Name("reversed"),
            args=[m.Arg(m.Call(func=m.Name("sorted")), star="")],
        ))
    def replace_unnecessary_reversed_around_sorted(self, _, updated_node):
        """Fix flake8-comprehensions C413.

        Unnecessary reversed call around sorted().
        """
        call = updated_node.args[0].value
        args = list(call.args)
        for i, arg in enumerate(args):
            if m.matches(arg.keyword, m.Name("reverse")):
                try:
                    val = bool(
                        literal_eval(self.module.code_for_node(arg.value)))
                except Exception:
                    # Not a literal: negate the expression. `not` binds
                    # tighter than `and`/`or`/conditional expressions/lambda,
                    # so those need parentheses to keep their meaning:
                    # `not (a or b)` rather than `(not a) or b`.
                    value = arg.value
                    if not value.lpar and m.matches(
                            value,
                            m.BooleanOperation() | m.IfExp() | m.Lambda()):
                        value = value.with_changes(lpar=[cst.LeftParen()],
                                                   rpar=[cst.RightParen()])
                    args[i] = arg.with_changes(
                        value=cst.UnaryOperation(cst.Not(), value))
                else:
                    if not val:
                        args[i] = arg.with_changes(value=cst.Name("True"))
                    else:
                        del args[i]
                        # Drop the now-dangling separator only when the removed
                        # `reverse=` was the last argument. Otherwise the
                        # preceding comma still separates the remaining
                        # arguments (`sorted(x, reverse=True, key=f)`).
                        if args and i == len(args):
                            args[-1] = remove_trailing_comma(args[-1])
                break
        else:
            args.append(
                cst.Arg(keyword=cst.Name("reverse"), value=cst.Name("True")))
        return call.with_changes(args=args)

    _sets = oneof_names("set", "frozenset")
    _seqs = oneof_names("list", "reversed", "sorted", "tuple")

    @m.leave(
        m.Call(func=_sets, args=[m.Arg(m.Call(func=_sets | _seqs), star="")])
        | m.Call(
            func=oneof_names("list", "tuple"),
            args=[m.Arg(m.Call(func=oneof_names("list", "tuple")), star="")],
        )
        | m.Call(
            func=m.Name("sorted"),
            args=[m.Arg(m.Call(func=_seqs), star=""), m.ZeroOrMore()],
        ))
    def replace_unnecessary_nested_calls(self, _, updated_node):
        """Fix flake8-comprehensions C414.

        Unnecessary <list/reversed/sorted/tuple> call within <list/set/sorted/tuple>()..
        """
        inner = updated_node.args[0].value
        # Only unwrap an inner call whose first argument is plain positional.
        # `set(list())` has nothing to unwrap (indexing would raise IndexError),
        # and `list(tuple(*xs))` is not `list(xs)`.
        if (not inner.args or inner.args[0].star
                or inner.args[0].keyword is not None):
            return updated_node
        return updated_node.with_changes(
            args=[cst.Arg(inner.args[0].value)] + list(updated_node.args[1:]),
        )

    @m.leave(
        m.Call(
            func=oneof_names("reversed", "set", "sorted"),
            args=[
                m.Arg(m.Subscript(slice=[m.SubscriptElement(ALL_ELEMS_SLICE)]),
                      star="")
            ],
        ))
    def replace_unnecessary_subscript_reversal(self, _, updated_node):
        """Fix flake8-comprehensions C415.

        Unnecessary subscript reversal of iterable within <reversed/set/sorted>().
        """
        subscript = updated_node.args[0].value
        step = subscript.slice[0].slice.step
        if (m.matches(updated_node.func, m.Name("reversed"))
                and step is not None
                and m.matches(step, m.UnaryOperation(m.Minus(),
                                                     m.Integer("1")))):
            # `reversed(x[::-1])` iterates x forwards. Rewriting it to
            # `reversed(x)` would flip the order, so leave it alone.
            return updated_node
        return updated_node.with_changes(
            args=[cst.Arg(subscript.value)],
        )

    @m.leave(
        multi(
            m.ListComp,
            m.SetComp,
            elt=m.Name(),
            for_in=m.CompFor(target=m.Name(),
                             ifs=[],
                             inner_for_in=None,
                             asynchronous=None),
        ))
    def replace_unnecessary_listcomp_or_setcomp(self, _, updated_node):
        """Fix flake8-comprehensions C416.

        Unnecessary <list/set> comprehension - rewrite using <list/set>().
        """
        if updated_node.elt.value == updated_node.for_in.target.value:
            func = cst.Name(
                "list" if isinstance(updated_node, cst.ListComp) else "set")
            return cst.Call(func=func,
                            args=[cst.Arg(updated_node.for_in.iter)])
        return updated_node

    @m.leave(m.Subscript(oneof_names("Union", "Literal")))
    def reorder_union_literal_contents_none_last(self, _, updated_node):
        subscript = list(updated_node.slice)
        try:
            # Stable sort: `None` moves to the end, other members keep their order.
            subscript.sort(key=lambda elt: elt.slice.value.value == "None")
            subscript[-1] = remove_trailing_comma(subscript[-1])
            return updated_node.with_changes(slice=subscript)
        except Exception:  # Single-element literals are not slices, etc.
            return updated_node

    @m.call_if_inside(m.Annotation(annotation=m.BinaryOperation()))
    @m.leave(
        m.BinaryOperation(
            left=m.Name("None") | m.BinaryOperation(),
            operator=m.BitOr(),
            right=m.DoNotCare(),
        ))
    def reorder_union_operator_contents_none_last(self, _, updated_node):

        def _has_none(node):
            if m.matches(node, m.Name("None")):
                return True
            elif m.matches(node, m.BinaryOperation()):
                return _has_none(node.left) or _has_none(node.right)
            else:
                return False

        node_left = updated_node.left
        if _has_none(node_left):
            return updated_node.with_changes(left=updated_node.right,
                                             right=node_left)
        else:
            return updated_node

    @m.leave(m.Subscript(value=m.Name("Literal")))
    def flatten_literal_subscript(self, _, updated_node):
        # `Literal[Literal[1], 2]` -> `Literal[1, 2]`
        new_slice = []
        for item in updated_node.slice:
            if m.matches(item.slice.value, m.Subscript(m.Name("Literal"))):
                new_slice += item.slice.value.slice
            else:
                new_slice.append(item)
        return updated_node.with_changes(slice=new_slice)

    @m.leave(m.Subscript(value=m.Name("Union")))
    def flatten_union_subscript(self, _, updated_node):
        # Inline nested `Optional[...]` / `Union[...]` members and collect every
        # `None` into a single trailing `None` member.
        new_slice = []
        has_none = False
        for item in updated_node.slice:
            if m.matches(item.slice.value, m.Subscript(m.Name("Optional"))):
                new_slice += item.slice.value.slice  # peel off "Optional"
                has_none = True
            elif m.matches(item.slice.value,
                           m.Subscript(m.Name("Union"))) and m.matches(
                               updated_node.value, item.slice.value.value):
                new_slice += item.slice.value.slice  # peel off nested "Union"
            elif m.matches(item.slice.value, m.Name("None")):
                has_none = True
            else:
                new_slice.append(item)
        if has_none:
            new_slice.append(
                cst.SubscriptElement(slice=cst.Index(cst.Name("None"))))
        return updated_node.with_changes(slice=new_slice)

    @m.leave(m.Else(m.IndentedBlock([m.SimpleStatementLine([m.Pass()])])))
    def discard_empty_else_blocks(self, _, updated_node):
        # An `else: pass` block can always simply be discarded, and libcst ensures
        # that an Else node can only ever occur attached to an If, While, For, or Try
        # node; in each case `None` is the valid way to represent "no else block".
        if m.findall(updated_node, m.Comment()):
            return updated_node  # If there are any comments, keep the node
        return cst.RemoveFromParent()

    @m.leave(
        m.Lambda(params=m.MatchIfTrue(lambda node: (
            node.star_kwarg is None and not node.kwonly_params and not node.
            posonly_params and isinstance(node.star_arg, cst.MaybeSentinel)
            and all(param.default is None for param in node.params)))))
    def remove_lambda_indirection(self, _, updated_node):
        # `lambda a, b: f(a, b)` -> `f`
        param_names = {
            param.name.value
            for param in updated_node.params.params
        }
        same_args = [
            m.Arg(m.Name(param.name.value), star="", keyword=None)
            for param in updated_node.params.params
        ]
        if not m.matches(updated_node.body, m.Call(args=same_args)):
            return updated_node
        func = cst.ensure_type(updated_node.body, cst.Call).func
        # `lambda: make()()` -> `make()` would call `make` once, eagerly,
        # instead of on every call of the lambda.
        if m.findall(func, m.Call()):
            return updated_node
        # `lambda x: x.method(x)` -> `x.method` would reference a name that is
        # only bound inside the lambda. Attribute names (the `y` in `obj.y`)
        # are not variable references, so they are ignored.
        attr_name_ids = {
            id(attribute.attr)
            for attribute in m.findall(func, m.Attribute())
        }
        if any(name.value in param_names and id(name) not in attr_name_ids
               for name in m.findall(func, m.Name())):
            return updated_node
        return func

    @m.leave(
        m.BooleanOperation(
            left=m.Call(m.Name("isinstance"), [m.Arg(), m.Arg()]),
            operator=m.Or(),
            right=m.Call(m.Name("isinstance"), [m.Arg(), m.Arg()]),
        ))
    def collapse_isinstance_checks(self, _, updated_node):
        # `isinstance(x, A) or isinstance(x, B)` -> `isinstance(x, (A, B))`
        left_target, left_type = updated_node.left.args
        right_target, right_type = updated_node.right.args
        if left_target.deep_equals(right_target):
            merged_type = cst.Arg(
                cst.Tuple([
                    cst.Element(left_type.value),
                    cst.Element(right_type.value)
                ]))
            return updated_node.left.with_changes(
                args=[left_target, merged_type])
        return updated_node
def remove_trailing_comma(node):
    """Return ``node`` with its explicit trailing comma removed.

    ``node`` is a libcst node that has a ``comma`` field. It is returned
    unchanged when it has no explicit comma, or when a comment appears
    anywhere inside it, so that no comment is dropped along with the comma.
    """
    has_explicit_comma = node.comma is not cst.MaybeSentinel.DEFAULT
    if has_explicit_comma and not m.findall(node, m.Comment()):
        return node.with_changes(comma=cst.MaybeSentinel.DEFAULT)
    return node


# Matches a missing (None) child node.
MATCH_NONE = m.MatchIfTrue(lambda value: value is None)

# Matches a slice covering the whole sequence: `[:]`, `[::]`, `[None:None]`,
# and the unit steps `[::1]` and `[::-1]` (the latter reverses the order).
ALL_ELEMS_SLICE = m.Slice(
    lower=MATCH_NONE | m.Name("None"),
    upper=MATCH_NONE | m.Name("None"),
    step=m.OneOf(
        MATCH_NONE,
        m.Name("None"),
        m.Integer("1"),
        m.UnaryOperation(m.Minus(), m.Integer("1")),
    ),
)


class ShedFixers(VisitorBasedCodemodCommand):
    """Fix a variety of small problems.

    Rewrites ``raise NotImplemented`` as ``raise NotImplementedError`` and
    turns assert statements that always fail into explicit ``raise``
    statements.

    Several transforms are closely modelled on pybetter's fixers, because
    running them all in a single pass is considerably faster.
    """

    DESCRIPTION = "Fix a variety of style, performance, and correctness issues."
def obf_universal(self, node: cst.CSTNode, *types):
    """Return ``node`` with the names inside it obfuscated, recursing into children.

    A bare ``Name`` is replaced with ``self.get_new_cst_name(node)`` when
    ``self.can_rename(name, *types)`` allows it. ``types`` defaults to
    ``('a', 'ca', 'v', 'cv')`` when empty; presumably these are name-category
    codes understood by ``can_rename`` (TODO confirm). ``types`` only applies
    to a ``Name`` passed in directly, because recursive calls on children use
    the default.

    Calls are delegated to ``self.new_obf_function_name`` /
    ``self.obf_function_args``, gated by the ``change_*`` flags. Subscript
    slices are delegated to ``self.obf_slice``. Node types not handled below
    are returned unchanged.
    """
    if m.matches(node, m.Name()):
        types = ('a', 'ca', 'v', 'cv') if not types else types
        node = cst.ensure_type(node, cst.Name)
        if self.can_rename(node.value, *types):
            node = self.get_new_cst_name(node)
    elif m.matches(node, m.NameItem()):
        node = cst.ensure_type(node, cst.NameItem)
        node = node.with_changes(name=self.obf_universal(node.name))
    elif m.matches(node, m.Call()):
        node = cst.ensure_type(node, cst.Call)
        if self.change_methods or self.change_functions:
            node = self.new_obf_function_name(node)
        if self.change_arguments or self.change_method_arguments:
            node = self.obf_function_args(node)
    elif m.matches(node, m.Attribute()):
        node = cst.ensure_type(node, cst.Attribute)
        # NOTE(review): the results of these calls are discarded, so the
        # Attribute is returned unchanged. Renaming `node.value` (e.g. a
        # renamed local in `local.attr`) looks intended. Confirm whether the
        # results should be written back via `with_changes`.
        self.obf_universal(node.value)
        self.obf_universal(node.attr)
    elif m.matches(node, m.AssignTarget()):
        node = cst.ensure_type(node, cst.AssignTarget)
        node = node.with_changes(target=self.obf_universal(node.target))
    elif m.matches(node, m.List() | m.Tuple()):
        node = cst.ensure_type(node, cst.List) if m.matches(
            node, m.List()) else cst.ensure_type(node, cst.Tuple)
        new_elements = [self.obf_universal(el) for el in node.elements]
        node = node.with_changes(elements=new_elements)
    elif m.matches(node, m.Subscript()):
        node = cst.ensure_type(node, cst.Subscript)
        new_slice = [
            el.with_changes(slice=self.obf_slice(el.slice))
            for el in node.slice
        ]
        node = node.with_changes(slice=new_slice)
        node = node.with_changes(value=self.obf_universal(node.value))
    elif m.matches(node, m.Element()):
        node = cst.ensure_type(node, cst.Element)
        node = node.with_changes(value=self.obf_universal(node.value))
    elif m.matches(node, m.Dict()):
        node = cst.ensure_type(node, cst.Dict)
        new_elements = [self.obf_universal(el) for el in node.elements]
        node = node.with_changes(elements=new_elements)
    elif m.matches(node, m.DictElement()):
        node = cst.ensure_type(node, cst.DictElement)
        new_key = self.obf_universal(node.key)
        new_val = self.obf_universal(node.value)
        node = node.with_changes(key=new_key, value=new_val)
    elif m.matches(node, m.StarredDictElement()):
        node = cst.ensure_type(node, cst.StarredDictElement)
        node = node.with_changes(value=self.obf_universal(node.value))
    elif m.matches(node, m.If() | m.While()):
        # The previous code matched against `cst.If | cst.IfExp` (a class
        # union, not a matcher) and then coerced to IfExp, so every `If`
        # node raised. Coerce to the node's real type instead.
        node = cst.ensure_type(node, cst.If) if m.matches(
            node, m.If()) else cst.ensure_type(node, cst.While)
        node = node.with_changes(test=self.obf_universal(node.test))
    elif m.matches(node, m.IfExp()):
        node = cst.ensure_type(node, cst.IfExp)
        node = node.with_changes(body=self.obf_universal(node.body))
        node = node.with_changes(test=self.obf_universal(node.test))
        node = node.with_changes(orelse=self.obf_universal(node.orelse))
    elif m.matches(node, m.Comparison()):
        node = cst.ensure_type(node, cst.Comparison)
        # Comparison targets are processed before the left operand.
        new_compars = [self.obf_universal(target) for target in node.comparisons]
        node = node.with_changes(left=self.obf_universal(node.left))
        node = node.with_changes(comparisons=new_compars)
    elif m.matches(node, m.ComparisonTarget()):
        node = cst.ensure_type(node, cst.ComparisonTarget)
        node = node.with_changes(
            comparator=self.obf_universal(node.comparator))
    elif m.matches(node, m.FormattedString()):
        node = cst.ensure_type(node, cst.FormattedString)
        new_parts = [self.obf_universal(part) for part in node.parts]
        node = node.with_changes(parts=new_parts)
    elif m.matches(node, m.FormattedStringExpression()):
        node = cst.ensure_type(node, cst.FormattedStringExpression)
        node = node.with_changes(
            expression=self.obf_universal(node.expression))
    elif m.matches(node, m.BinaryOperation() | m.BooleanOperation()):
        node = cst.ensure_type(node, cst.BinaryOperation) if m.matches(
            node, m.BinaryOperation()) else cst.ensure_type(
                node, cst.BooleanOperation)
        node = node.with_changes(left=self.obf_universal(node.left),
                                 right=self.obf_universal(node.right))
    elif m.matches(node, m.UnaryOperation()):
        node = cst.ensure_type(node, cst.UnaryOperation)
        node = node.with_changes(
            expression=self.obf_universal(node.expression))
    elif m.matches(node, m.ListComp() | m.SetComp() | m.GeneratorExp()):
        # All three comprehension kinds share the `elt` / `for_in` shape. Set
        # comprehensions and generator expressions used to be skipped, which
        # left references to renamed names inside them untouched.
        node = node.with_changes(elt=self.obf_universal(node.elt))
        node = node.with_changes(for_in=self.obf_universal(node.for_in))
    elif m.matches(node, m.DictComp()):
        node = cst.ensure_type(node, cst.DictComp)
        node = node.with_changes(key=self.obf_universal(node.key))
        node = node.with_changes(value=self.obf_universal(node.value))
        node = node.with_changes(for_in=self.obf_universal(node.for_in))
    elif m.matches(node, m.CompFor()):
        node = cst.ensure_type(node, cst.CompFor)
        node = node.with_changes(target=self.obf_universal(node.target))
        node = node.with_changes(iter=self.obf_universal(node.iter))
        new_ifs = [self.obf_universal(el) for el in node.ifs]
        node = node.with_changes(ifs=new_ifs)
        # Nested `for` clauses (`[y for x in xs for y in x]`) must be renamed
        # too, or they keep referring to the old names of outer targets.
        if node.inner_for_in is not None:
            node = node.with_changes(
                inner_for_in=self.obf_universal(node.inner_for_in))
    elif m.matches(node, m.CompIf()):
        node = cst.ensure_type(node, cst.CompIf)
        node = node.with_changes(test=self.obf_universal(node.test))
    # Literals (Integer, Float, SimpleString) and any other node types are
    # returned unchanged.
    return node
def visit_Call(self, node: cst.Call) -> None:
    """Suggest ``assertIsNone`` / ``assertIsNotNone`` for ``is None`` assertions.

    Reports, with an autofix, single-argument ``self.assertTrue`` /
    ``self.assertFalse`` calls on a comparison against ``None``:

    * ``self.assertTrue(x is not None)``  -> ``self.assertIsNotNone(x)``
    * ``self.assertTrue(not x is None)``  -> ``self.assertIsNotNone(x)``
    * ``self.assertFalse(x is None)``     -> ``self.assertIsNotNone(x)``
    * ``self.assertTrue(x is None)``      -> ``self.assertIsNone(x)``
    * ``self.assertFalse(x is not None)`` -> ``self.assertIsNone(x)``
    * ``self.assertFalse(not x is None)`` -> ``self.assertIsNone(x)``

    The replacement call receives only the tested expression ``x``.
    """

    def compares_to_none(operator):
        # Matches a single comparison `<expr> <operator> None`.
        return m.Comparison(comparisons=[
            m.ComparisonTarget(operator, comparator=m.Name("None"))
        ])

    def self_call(method_name, argument):
        # Matches `self.<method_name>(<argument>)` with exactly one argument.
        return m.Call(
            func=m.Attribute(value=m.Name("self"), attr=m.Name(method_name)),
            args=[m.Arg(argument)],
        )

    def negated(expression):
        return m.UnaryOperation(operator=m.Not(), expression=expression)

    # (call matcher, replacement method, argument wrapped in `not`?).
    # The matchers are mutually exclusive, so their order does not matter.
    rewrites = (
        (self_call("assertTrue", compares_to_none(m.IsNot())),
         "assertIsNotNone", False),
        (self_call("assertTrue", negated(compares_to_none(m.Is()))),
         "assertIsNotNone", True),
        (self_call("assertFalse", compares_to_none(m.Is())),
         "assertIsNotNone", False),
        (self_call("assertTrue", compares_to_none(m.Is())),
         "assertIsNone", False),
        (self_call("assertFalse", compares_to_none(m.IsNot())),
         "assertIsNone", False),
        (self_call("assertFalse", negated(compares_to_none(m.Is()))),
         "assertIsNone", True),
    )
    for matcher, new_attr, is_negated in rewrites:
        if not m.matches(node, matcher):
            continue
        argument = node.args[0].value
        if is_negated:
            argument = ensure_type(argument, cst.UnaryOperation).expression
        new_call = node.with_changes(
            func=cst.Attribute(value=cst.Name("self"),
                               attr=cst.Name(new_attr)),
            args=[cst.Arg(ensure_type(argument, cst.Comparison).left)],
        )
        self.report(node, replacement=new_call)
        return
def visit_Call(self, node: cst.Call) -> None:
    """Suggest ``assertIn`` / ``assertNotIn`` for membership assertions.

    Reports, with an autofix, these single-argument forms:

    * ``self.assertTrue(a in b)``     -> ``self.assertIn(a, b)``
    * ``self.assertTrue(not a in b)`` -> ``self.assertNotIn(a, b)``
    * ``self.assertTrue(a not in b)`` -> ``self.assertNotIn(a, b)``
    * ``self.assertFalse(a in b)``    -> ``self.assertNotIn(a, b)``
    """

    def on_self(method_name):
        return m.Attribute(value=m.Name("self"), attr=m.Name(method_name))

    def membership(operator):
        # A single comparison `<a> <operator> <b>`.
        return m.Comparison(comparisons=[m.ComparisonTarget(operator=operator)])

    # (callee, argument shape, argument wrapped in `not`?, replacement method)
    candidates = [
        (on_self("assertTrue"), membership(m.In()), False, "assertIn"),
        (on_self("assertTrue"),
         m.UnaryOperation(operator=m.Not(), expression=membership(m.In())),
         True, "assertNotIn"),
        (on_self("assertTrue"), membership(m.NotIn()), False, "assertNotIn"),
        (on_self("assertFalse"), membership(m.In()), False, "assertNotIn"),
    ]
    for callee, shape, wrapped_in_not, replacement in candidates:
        if not m.matches(node, m.Call(func=callee, args=[m.Arg(shape)])):
            continue
        expression = node.args[0].value
        if wrapped_in_not:
            expression = ensure_type(expression, cst.UnaryOperation).expression
        comparison = ensure_type(expression, cst.Comparison)
        new_call = node.with_changes(
            func=cst.Attribute(value=cst.Name("self"),
                               attr=cst.Name(replacement)),
            args=[
                cst.Arg(comparison.left),
                cst.Arg(comparison.comparisons[0].comparator),
            ],
        )
        self.report(node, replacement=new_call)
        break
def visit_Call(self, node: cst.Call) -> None:
    """Suggest ``assertIsNone`` / ``assertIsNotNone`` for ``None`` checks.

    Handles ``self.assertTrue(...)`` / ``self.assertFalse(...)`` whose single
    argument is ``x is None``, ``x is not None``, or either of those under a
    leading ``not``. Each of ``assertFalse``, the leading ``not`` and
    ``is not`` inverts the check. An even number of inversions yields
    ``assertIsNone(x)``, an odd number ``assertIsNotNone(x)``.
    """
    none_target = m.ComparisonTarget(
        m.SaveMatchedNode(m.OneOf(m.Is(), m.IsNot()), "comparison_type"),
        comparator=m.Name("None"),
    )
    plain_check = m.Comparison(comparisons=[none_target])
    negated_check = m.UnaryOperation(operator=m.Not(), expression=plain_check)

    captured = m.extract(
        node,
        m.Call(
            func=m.Attribute(
                value=m.Name("self"),
                attr=m.SaveMatchedNode(
                    m.OneOf(m.Name("assertTrue"), m.Name("assertFalse")),
                    "assertion_name",
                ),
            ),
            args=[
                m.Arg(
                    m.SaveMatchedNode(m.OneOf(plain_check, negated_check),
                                      "argument"))
            ],
        ),
    )
    if not captured:
        return

    def single(key):
        # A capture may come back as a sequence; only its first node is used.
        value = captured[key]
        return value[0] if isinstance(value, Sequence) else value

    assertion_name = single("assertion_name")
    argument = single("argument")
    comparison_type = single("comparison_type")

    if m.matches(argument, m.Comparison()):
        tested = ensure_type(argument, cst.Comparison).left
    else:
        inner = ensure_type(argument, cst.UnaryOperation).expression
        tested = ensure_type(inner, cst.Comparison).left

    inversions = sum((
        m.matches(assertion_name, m.Name("assertFalse")),
        m.matches(argument, m.UnaryOperation()),
        m.matches(comparison_type, m.IsNot()),
    ))
    new_attr = "assertIsNotNone" if inversions % 2 else "assertIsNone"
    self.report(
        node,
        replacement=node.with_changes(
            func=cst.Attribute(value=cst.Name("self"), attr=cst.Name(new_attr)),
            args=[cst.Arg(tested)],
        ),
    )