def test_gen_dotted_names(self) -> None:
    """Check the dotted-name prefixes yielded by ``_gen_dotted_names``."""

    def collect(expr: cst.BaseExpression) -> set:
        # Only the dotted names are compared; the yielded nodes are ignored.
        return {dotted for dotted, _node in _gen_dotted_names(expr)}

    # A bare name yields just itself.
    self.assertEqual(collect(cst.Name(value="a")), {"a"})

    # `a.b` yields the full dotted name plus its prefix.
    self.assertEqual(
        collect(cst.Attribute(value=cst.Name(value="a"), attr=cst.Name(value="b"))),
        {"a.b", "a"},
    )

    # `a.b.c().d`: the expected set stops at `a.b.c`, i.e. nothing past the
    # call is reported.
    a_b = cst.Attribute(value=cst.Name(value="a"), attr=cst.Name(value="b"))
    a_b_c_call = cst.Call(
        func=cst.Attribute(value=a_b, attr=cst.Name(value="c")),
        args=[],
    )
    self.assertEqual(
        collect(cst.Attribute(value=a_b_c_call, attr=cst.Name(value="d"))),
        {"a.b.c", "a.b", "a"},
    )
def datetime_datetime_replace(
    self, original_node: cst.Call, updated_node: cst.Call
) -> cst.Call:
    """Replace the matched call with ``datetime.datetime.now(UTC)``.

    ``self._update_imports()`` is invoked first; it is defined elsewhere
    (NOTE(review): presumably it makes ``UTC`` importable — confirm).
    Only ``func`` and ``args`` of ``updated_node`` are replaced.
    """
    self._update_imports()
    datetime_cls = cst.Attribute(
        value=cst.Name(value="datetime"),
        attr=cst.Name(value="datetime"),
    )
    now_method = cst.Attribute(value=datetime_cls, attr=cst.Name(value="now"))
    utc_argument = cst.Arg(value=cst.Name(value="UTC"))
    return updated_node.with_changes(func=now_method, args=[utc_argument])
def sampleop_to_logpdf(cst_generator, *args, **kwargs):
    """Build ``<generated>.logpdf_sum(<var_name>)``.

    ``var_name`` is removed from ``kwargs`` (KeyError if missing) and becomes
    the single call argument; the remaining ``args``/``kwargs`` are forwarded
    to ``cst_generator``, whose result is the receiver of ``logpdf_sum``.
    """
    var_name = kwargs.pop("var_name")
    receiver = cst_generator(*args, **kwargs)
    method = cst.Attribute(receiver, cst.Name("logpdf_sum"))
    return cst.Call(method, [cst.Arg(var_name)])
def leave_SimpleString(
    self, original_node: cst.SimpleString, updated_node: cst.SimpleString
) -> Union[cst.SimpleString, cst.Attribute]:
    """Turn a string literal naming a ``CST_DIR`` entry into ``cst.<name>``.

    Any other literal is returned unchanged.
    """
    literal = updated_node.evaluated_value
    if literal not in CST_DIR:
        return updated_node
    return cst.Attribute(cst.Name("cst"), cst.Name(literal))
def choice_ast(rng_key):
    """Build ``mcx.jax.choice(rng_key, <nodes[0].name>.shape[0])``.

    NOTE(review): ``nodes`` is not a parameter; it is resolved from an
    enclosing or module scope not visible here — confirm its origin.
    """
    mcx_jax = cst.Attribute(cst.Name("mcx"), cst.Name("jax"))
    first_dim = cst.Subscript(
        cst.Attribute(cst.Name(nodes[0].name), cst.Name("shape")),
        [cst.SubscriptElement(cst.Index(cst.Integer("0")))],
    )
    return cst.Call(
        func=cst.Attribute(value=mcx_jax, attr=cst.Name("choice")),
        args=[cst.Arg(rng_key), cst.Arg(first_dim)],
    )
def to_sampler(cst_generator, *args, **kwargs):
    """Build ``<generated>.sample(<rng_key>)``.

    ``rng_key`` is removed from ``kwargs`` (KeyError if missing); everything
    else is forwarded to ``cst_generator`` to produce the receiver.
    """
    key = kwargs.pop("rng_key")
    sample_method = cst.Attribute(
        value=cst_generator(*args, **kwargs), attr=cst.Name("sample")
    )
    return cst.Call(func=sample_method, args=[cst.Arg(value=key)])
def test_with_dots(self) -> None:
    """``util.with_dots`` joins Name/Attribute chains and rejects plain strings."""
    self.assertEqual("foo", util.with_dots(cst.Name(value="foo")))

    foo_bar = cst.Attribute(value=cst.Name("foo"), attr=cst.Name("bar"))
    foo_bar_baz = cst.Attribute(value=foo_bar, attr=cst.Name("baz"))
    self.assertEqual("foo.bar.baz", util.with_dots(foo_bar_baz))

    # A str is not a CST node and must raise TypeError.
    with self.assertRaisesRegex(TypeError, "Can't with_dots"):
        util.with_dots("foo.bar")  # type: ignore
def pluck_asyncio_gather_expression_from_yield_list_or_list_comp(
    node: cst.Yield,
) -> cst.BaseExpression:
    """Turn ``yield <expr>`` into ``asyncio.gather(*<expr>)``.

    NOTE(review): ``node.value`` is used as-is; a bare ``yield`` (value None)
    or ``yield from`` is not guarded against — confirm callers only pass
    yields of a list / list comprehension, as the name suggests.
    """
    gather = cst.Attribute(value=cst.Name("asyncio"), attr=cst.Name("gather"))
    unpacked = cst.Arg(value=node.value, star="*")
    return cst.Call(func=gather, args=[unpacked])
def get_name_node(name: str) -> Union[cst.Name, cst.Attribute]:
    """Build the Name/Attribute chain for a dotted path (inverse of `_get_alias_name`).

    ``"a"`` -> ``Name("a")``; ``"a.b.c"`` ->
    ``Attribute(Attribute(Name("a"), Name("b")), Name("c"))``.
    """
    head, *tail = name.split(".")
    node: Union[cst.Name, cst.Attribute] = cst.Name(head)
    # Fold left-to-right so each new segment wraps the chain built so far.
    for segment in tail:
        node = cst.Attribute(value=node, attr=cst.Name(segment))
    return node
def annotation(self) -> typing.Union[cst.Name, cst.Attribute]:
    """Return this symbol's dotted path as a Name/Attribute chain.

    The path is ``self.module`` split on dots followed by ``self.name`` (just
    ``self.name`` when ``self.module`` is empty). If libcst rejects any
    component (e.g. an empty segment), ``cst.Name("Unknown")`` is returned.
    """
    parts = self.module.split(".") + [self.name] if self.module else [self.name]
    first_name, *rest = parts
    try:
        expr: typing.Union[cst.Name, cst.Attribute] = cst.Name(first_name)
        for name in rest:
            expr = cst.Attribute(expr, cst.Name(name))
    # Use the public export of libcst's validation error rather than reaching
    # into the private `cst._nodes.base` module.
    except cst.CSTValidationError:
        return cst.Name("Unknown")
    return expr
def leave_Call(self, node: cst.Call, updated_node: cst.Call) -> cst.Call:
    """Inside a coroutine, retarget calls matching ``gen_sleep_matcher`` to ``asyncio.sleep``.

    When a rewrite happens, ``"asyncio"`` is recorded in ``self.required_imports``.
    """
    if self.in_coroutine(self.coroutine_stack) and m.matches(
        updated_node, gen_sleep_matcher
    ):
        self.required_imports.add("asyncio")
        asyncio_sleep = cst.Attribute(value=cst.Name("asyncio"), attr=cst.Name("sleep"))
        return updated_node.with_changes(func=asyncio_sleep)
    return updated_node
def leave_SimpleString(
    self, original_node: cst.SimpleString, updated_node: cst.SimpleString
) -> Union[cst.SimpleString, cst.Attribute]:
    """Turn a string literal naming a ``CST_DIR`` entry into ``cst.<name>``.

    The literal is evaluated with ``ast.literal_eval``; literals that fail
    with SyntaxError, and values not in ``CST_DIR``, are returned unchanged.
    """
    try:
        literal = ast.literal_eval(updated_node.value)
    except SyntaxError:
        return updated_node
    if literal not in CST_DIR:
        return updated_node
    return cst.Attribute(cst.Name("cst"), cst.Name(literal))
def leave_Attribute(
    self, original: libcst.Attribute, updated: libcst.Attribute
) -> Any:
    """Rewrite ``six.<name>`` to ``io.<name>`` when ``<name>`` is in ``_IO_OBJECTS``."""
    is_six_member = m.matches(updated.value, m.Name("six")) and m.matches(
        updated.attr, m.Name()
    )
    if not is_six_member or updated.attr.value not in _IO_OBJECTS:
        return updated
    return libcst.Attribute(value=libcst.Name("io"), attr=updated.attr)
def leave_Call(self, original_node: cst.Call, updated_node: cst.Call) -> cst.Call:
    """Rewrite calls to renamed wx symbols.

    Two passes, first match wins:

    1. ``self.matchers_short_map`` - calls to a bare symbol that appears in
       ``self.wx_imports``. The symbol's import is scheduled for removal,
       ``import wx`` is scheduled, and the callee becomes ``wx.<renamed>``.
    2. ``self.matchers_full_map`` - calls already spelled through ``wx``;
       the callee's attribute is renamed in place.

    ``renamed`` is either a name string or a ``(submodule, name)`` tuple,
    which produces ``wx.<submodule>.<name>``.
    """

    def build_wx_func(renamed) -> cst.Attribute:
        # `renamed` -> wx.renamed, or (sub, name) -> wx.sub.name
        wx_name = cst.Name(value="wx")
        if isinstance(renamed, tuple):
            return cst.Attribute(
                value=cst.Attribute(value=wx_name, attr=cst.Name(value=renamed[0])),
                attr=cst.Name(value=renamed[1]),
            )
        return cst.Attribute(value=wx_name, attr=cst.Name(value=renamed))

    for symbol, matcher, renamed in self.matchers_short_map:
        if symbol not in self.wx_imports or not matchers.matches(updated_node, matcher):
            continue
        # Drop the direct import of the symbol and depend on top-level `wx`.
        RemoveImportsVisitor.remove_unused_import_by_node(self.context, original_node)
        AddImportsVisitor.add_needed_import(self.context, "wx")
        return updated_node.with_changes(func=build_wx_func(renamed))

    for matcher, renamed in self.matchers_full_map:
        if not matchers.matches(updated_node, matcher):
            continue
        if isinstance(renamed, tuple):
            return updated_node.with_changes(func=build_wx_func(renamed))
        renamed_func = updated_node.func.with_changes(attr=cst.Name(value=renamed))
        return updated_node.with_changes(func=renamed_func)

    return updated_node
def leave_Attribute(self, original_node: cst.Attribute, updated_node: cst.Attribute) -> cst.Attribute:
    """Re-root attributes matched by any of ``self.matchers`` onto ``wx.adv``.

    On a match, ``wx.adv`` is scheduled for import and the attribute's
    ``value`` is replaced with ``wx.adv``; otherwise the node is unchanged.
    """
    if not any(matchers.matches(updated_node, matcher) for matcher in self.matchers):
        return updated_node
    AddImportsVisitor.add_needed_import(self.context, "wx.adv")
    wx_adv = cst.Attribute(value=cst.Name(value="wx"), attr=cst.Name(value="adv"))
    return updated_node.with_changes(value=wx_adv)
def leave_Call(self, original: libcst.Call, updated: libcst.Call) -> libcst.Call:
    """Rewrite ``six.<name>(obj)`` into ``obj.<method>()``.

    ``<method>`` is ``_CONVERSION_MAP[<name>]``. Calls that do not have
    exactly one argument, or whose argument is unpacked (``*x`` / ``**x``),
    are left unchanged and a warning is emitted.
    """
    func = updated.func
    # Match on the callee itself: reading `updated.func.value` directly raises
    # AttributeError for callees without a `.value` (e.g. `f()()`).
    if not isinstance(func, libcst.Attribute) or not m.matches(func.value, m.Name("six")):
        return updated
    orig_name = func.attr.value
    if orig_name not in _CONVERSION_MAP:
        return updated
    updated_name = _CONVERSION_MAP[orig_name]

    if len(updated.args) != 1:
        self.warn(
            f"Odd six.{orig_name} call does not have one argument. Cannot perform substitution."
        )
        return updated
    arg = updated.args[0]
    if arg.star:
        # `six.<name>(*d)` cannot become `d.<method>()`.
        self.warn(
            f"six.{orig_name} call uses argument unpacking. Cannot perform substitution."
        )
        return updated

    value = arg.value
    # Expressions that bind looser than attribute access (`a or b`, `x + y`,
    # `lambda: d`, integer literals, ...) must be parenthesized, otherwise
    # `.method` would attach to only part of the expression.
    atoms = (
        libcst.Name,
        libcst.Attribute,
        libcst.Call,
        libcst.Subscript,
        libcst.List,
        libcst.Dict,
        libcst.Set,
        libcst.ListComp,
        libcst.SetComp,
        libcst.DictComp,
        libcst.SimpleString,
        libcst.FormattedString,
    )
    if not value.lpar and not isinstance(value, atoms):
        value = value.with_changes(lpar=[libcst.LeftParen()], rpar=[libcst.RightParen()])

    return libcst.Call(
        func=libcst.Attribute(value=value, attr=libcst.Name(value=updated_name))
    )
def leave_Subscript(
    self,
    original_node: libcst.Subscript,
    updated_node: Union[libcst.Subscript, libcst.SimpleString],
) -> Union[libcst.Subscript, libcst.SimpleString]:
    """Rewrite ``PathLike[...]`` into the string annotation ``"os.PathLike[...]"``.

    Subscripts whose base is not the bare name ``PathLike`` are returned
    unchanged.
    """
    if not libcst.matchers.matches(original_node.value, libcst.matchers.Name("PathLike")):
        return updated_node
    qualified = libcst.Attribute(
        value=libcst.Name(value="os", lpar=[], rpar=[]),
        attr=libcst.Name(value="PathLike"),
    )
    rendered = libcst.parse_module("").code_for_node(
        updated_node.with_changes(value=qualified)
    )
    # repr() picks a quote character (and escapes) so the literal stays valid
    # even when the rendered code already contains quotes, e.g. a nested
    # PathLike whose inner subscript was already turned into '...'. For
    # quote-free code it yields the same '...' form as before.
    return libcst.SimpleString(repr(rendered))
def test_collect_targets():
    """``collect_targets`` reports ``x``, ``x[0]`` and ``x.attr`` in source order."""
    tree = cst.parse_module('''
x = [0, 1]
x[0] = 1
x.attr = 2
''')
    x = cst.Name(value='x')
    expected = (
        x,
        cst.Subscript(
            value=x,
            slice=[cst.SubscriptElement(slice=cst.Index(value=cst.Integer('0')))],
        ),
        cst.Attribute(value=x, attr=cst.Name('attr')),
    )
    targets = collect_targets(tree)
    # NOTE(review): zip() stops at the shorter sequence, so missing or extra
    # targets are not detected by this assertion.
    assert all(found.deep_equals(want) for found, want in zip(targets, expected))
from typing import Any

import libcst
import libcst.matchers as m
from libcst.codemod import CodemodContext, VisitorBasedCodemodCommand
from libcst.codemod.visitors import AddImportsVisitor

# Maps a `six` constant name to the expression that replaces `six.<name>`.
_CONVERSION_MAP = {
    "class_types": libcst.Name("type"),
    "integer_types": libcst.Name("int"),
    "string_types": libcst.Name("str"),
    "text_type": libcst.Name("str"),
    "binary_type": libcst.Name("bytes"),
    # Requires `sys`; ConvertSixConstants schedules that import when used.
    "MAXSIZE": libcst.Attribute(value=libcst.Name("sys"), attr=libcst.Name("maxsize")),
}


class ConvertSixConstants(VisitorBasedCodemodCommand):
    """Replace ``six.<constant>`` attribute accesses using ``_CONVERSION_MAP``."""

    def leave_Attribute(self, original: libcst.Attribute, updated: libcst.Attribute) -> Any:
        if not m.matches(updated.value, m.Name("six")):
            return updated
        if not m.matches(updated.attr, m.Name()):
            return updated
        replacement = _CONVERSION_MAP.get(updated.attr.value)
        if replacement is None:
            return updated
        if updated.attr.value == "MAXSIZE":
            # `sys.maxsize` would be an unbound name without this import.
            AddImportsVisitor.add_needed_import(self.context, "sys")
        # The map holds one shared template node per constant; clone it so the
        # same node object is never inserted at several places in the tree.
        return replacement.deep_clone()
class ImportCreateTest(CSTNodeTest):
    """Tests for hand-constructed ``cst.Import`` nodes.

    ``test_valid`` feeds each case to ``CSTNodeTest.validate_node`` (defined
    elsewhere; presumably renders the node and compares it with ``code`` and,
    when given, ``expected_position``). ``test_invalid`` expects building the
    node to fail with an error matching ``expected_re``.
    """

    @data_provider(
        (
            # Simple import statement
            {
                "node": cst.Import(names=(cst.ImportAlias(cst.Name("foo")),)),
                "code": "import foo",
            },
            {
                "node": cst.Import(
                    names=(
                        cst.ImportAlias(
                            cst.Attribute(cst.Name("foo"), cst.Name("bar"))
                        ),
                    )
                ),
                "code": "import foo.bar",
            },
            # NOTE(review): this case is an exact duplicate of the previous one.
            {
                "node": cst.Import(
                    names=(
                        cst.ImportAlias(
                            cst.Attribute(cst.Name("foo"), cst.Name("bar"))
                        ),
                    )
                ),
                "code": "import foo.bar",
            },
            # Comma-separated list of imports (no explicit `comma=`: the
            # expected code shows ", " is rendered between aliases anyway).
            {
                "node": cst.Import(
                    names=(
                        cst.ImportAlias(
                            cst.Attribute(cst.Name("foo"), cst.Name("bar"))
                        ),
                        cst.ImportAlias(
                            cst.Attribute(cst.Name("foo"), cst.Name("baz"))
                        ),
                    )
                ),
                "code": "import foo.bar, foo.baz",
                "expected_position": CodeRange((1, 0), (1, 23)),
            },
            # Import with an alias
            {
                "node": cst.Import(
                    names=(
                        cst.ImportAlias(
                            cst.Attribute(cst.Name("foo"), cst.Name("bar")),
                            asname=cst.AsName(cst.Name("baz")),
                        ),
                    )
                ),
                "code": "import foo.bar as baz",
            },
            # Import with an alias, comma separated
            {
                "node": cst.Import(
                    names=(
                        cst.ImportAlias(
                            cst.Attribute(cst.Name("foo"), cst.Name("bar")),
                            asname=cst.AsName(cst.Name("baz")),
                        ),
                        cst.ImportAlias(
                            cst.Attribute(cst.Name("foo"), cst.Name("baz")),
                            asname=cst.AsName(cst.Name("bar")),
                        ),
                    )
                ),
                "code": "import foo.bar as baz, foo.baz as bar",
            },
            # Combine for fun and profit
            {
                "node": cst.Import(
                    names=(
                        cst.ImportAlias(
                            cst.Attribute(cst.Name("foo"), cst.Name("bar")),
                            asname=cst.AsName(cst.Name("baz")),
                        ),
                        cst.ImportAlias(
                            cst.Attribute(cst.Name("insta"), cst.Name("gram"))
                        ),
                        cst.ImportAlias(
                            cst.Attribute(cst.Name("foo"), cst.Name("baz"))
                        ),
                        cst.ImportAlias(
                            cst.Name("unittest"), asname=cst.AsName(cst.Name("ut"))
                        ),
                    )
                ),
                "code": "import foo.bar as baz, insta.gram, foo.baz, unittest as ut",
            },
            # Verify whitespace works everywhere.
            # NOTE(review): with the single-space literals below the code string
            # is 40 characters long, but expected_position ends at column 46 —
            # confirm the whitespace literals were not collapsed.
            {
                "node": cst.Import(
                    names=(
                        cst.ImportAlias(
                            cst.Attribute(
                                cst.Name("foo"),
                                cst.Name("bar"),
                                dot=cst.Dot(
                                    whitespace_before=cst.SimpleWhitespace(" "),
                                    whitespace_after=cst.SimpleWhitespace(" "),
                                ),
                            ),
                            asname=cst.AsName(
                                cst.Name("baz"),
                                whitespace_before_as=cst.SimpleWhitespace(" "),
                                whitespace_after_as=cst.SimpleWhitespace(" "),
                            ),
                            comma=cst.Comma(
                                whitespace_before=cst.SimpleWhitespace(" "),
                                whitespace_after=cst.SimpleWhitespace(" "),
                            ),
                        ),
                        cst.ImportAlias(
                            cst.Name("unittest"),
                            asname=cst.AsName(
                                cst.Name("ut"),
                                whitespace_before_as=cst.SimpleWhitespace(" "),
                                whitespace_after_as=cst.SimpleWhitespace(" "),
                            ),
                        ),
                    ),
                    whitespace_after_import=cst.SimpleWhitespace(" "),
                ),
                "code": "import foo . bar as baz , unittest as ut",
                "expected_position": CodeRange((1, 0), (1, 46)),
            },
        )
    )
    def test_valid(self, **kwargs: Any) -> None:
        self.validate_node(**kwargs)

    @data_provider(
        (
            # An import must name at least one module.
            {
                "get_node": lambda: cst.Import(names=()),
                "expected_re": "at least one ImportAlias",
            },
            # Empty identifiers are rejected, whether bare or inside a dotted name.
            {
                "get_node": lambda: cst.Import(names=(cst.ImportAlias(cst.Name("")),)),
                "expected_re": "empty name identifier",
            },
            {
                "get_node": lambda: cst.Import(
                    names=(
                        cst.ImportAlias(cst.Attribute(cst.Name(""), cst.Name("bla"))),
                    )
                ),
                "expected_re": "empty name identifier",
            },
            {
                "get_node": lambda: cst.Import(
                    names=(
                        cst.ImportAlias(cst.Attribute(cst.Name("bla"), cst.Name(""))),
                    )
                ),
                "expected_re": "empty name identifier",
            },
            # The last alias may not carry a comma.
            {
                "get_node": lambda: cst.Import(
                    names=(
                        cst.ImportAlias(
                            cst.Attribute(cst.Name("foo"), cst.Name("bar")),
                            comma=cst.Comma(),
                        ),
                    )
                ),
                "expected_re": "trailing comma",
            },
            # `import` must be followed by whitespace.
            {
                "get_node": lambda: cst.Import(
                    names=(
                        cst.ImportAlias(
                            cst.Attribute(cst.Name("foo"), cst.Name("bar"))
                        ),
                    ),
                    whitespace_after_import=cst.SimpleWhitespace(""),
                ),
                "expected_re": "at least one space",
            },
        )
    )
    def test_invalid(self, **kwargs: Any) -> None:
        self.assert_invalid(**kwargs)
class ImportParseTest(CSTNodeTest):
    """Tests that parsing ``code`` as a statement yields exactly ``node``.

    Unlike the construction tests, every non-final alias here spells out its
    ``comma=`` (with the whitespace after it), presumably because parsed nodes
    record the comma tokens explicitly.
    """

    @data_provider(
        (
            # Simple import statement
            {
                "node": cst.Import(names=(cst.ImportAlias(cst.Name("foo")),)),
                "code": "import foo",
            },
            {
                "node": cst.Import(
                    names=(
                        cst.ImportAlias(
                            cst.Attribute(cst.Name("foo"), cst.Name("bar"))
                        ),
                    )
                ),
                "code": "import foo.bar",
            },
            # NOTE(review): this case is an exact duplicate of the previous one.
            {
                "node": cst.Import(
                    names=(
                        cst.ImportAlias(
                            cst.Attribute(cst.Name("foo"), cst.Name("bar"))
                        ),
                    )
                ),
                "code": "import foo.bar",
            },
            # Comma-separated list of imports
            {
                "node": cst.Import(
                    names=(
                        cst.ImportAlias(
                            cst.Attribute(cst.Name("foo"), cst.Name("bar")),
                            comma=cst.Comma(whitespace_after=cst.SimpleWhitespace(" ")),
                        ),
                        cst.ImportAlias(
                            cst.Attribute(cst.Name("foo"), cst.Name("baz"))
                        ),
                    )
                ),
                "code": "import foo.bar, foo.baz",
            },
            # Import with an alias
            {
                "node": cst.Import(
                    names=(
                        cst.ImportAlias(
                            cst.Attribute(cst.Name("foo"), cst.Name("bar")),
                            asname=cst.AsName(cst.Name("baz")),
                        ),
                    )
                ),
                "code": "import foo.bar as baz",
            },
            # Import with an alias, comma separated
            {
                "node": cst.Import(
                    names=(
                        cst.ImportAlias(
                            cst.Attribute(cst.Name("foo"), cst.Name("bar")),
                            asname=cst.AsName(cst.Name("baz")),
                            comma=cst.Comma(whitespace_after=cst.SimpleWhitespace(" ")),
                        ),
                        cst.ImportAlias(
                            cst.Attribute(cst.Name("foo"), cst.Name("baz")),
                            asname=cst.AsName(cst.Name("bar")),
                        ),
                    )
                ),
                "code": "import foo.bar as baz, foo.baz as bar",
            },
            # Combine for fun and profit
            {
                "node": cst.Import(
                    names=(
                        cst.ImportAlias(
                            cst.Attribute(cst.Name("foo"), cst.Name("bar")),
                            asname=cst.AsName(cst.Name("baz")),
                            comma=cst.Comma(whitespace_after=cst.SimpleWhitespace(" ")),
                        ),
                        cst.ImportAlias(
                            cst.Attribute(cst.Name("insta"), cst.Name("gram")),
                            comma=cst.Comma(whitespace_after=cst.SimpleWhitespace(" ")),
                        ),
                        cst.ImportAlias(
                            cst.Attribute(cst.Name("foo"), cst.Name("baz")),
                            comma=cst.Comma(whitespace_after=cst.SimpleWhitespace(" ")),
                        ),
                        cst.ImportAlias(
                            cst.Name("unittest"), asname=cst.AsName(cst.Name("ut"))
                        ),
                    )
                ),
                "code": "import foo.bar as baz, insta.gram, foo.baz, unittest as ut",
            },
            # Verify whitespace works everywhere.
            {
                "node": cst.Import(
                    names=(
                        cst.ImportAlias(
                            cst.Attribute(
                                cst.Name("foo"),
                                cst.Name("bar"),
                                dot=cst.Dot(
                                    whitespace_before=cst.SimpleWhitespace(" "),
                                    whitespace_after=cst.SimpleWhitespace(" "),
                                ),
                            ),
                            asname=cst.AsName(
                                cst.Name("baz"),
                                whitespace_before_as=cst.SimpleWhitespace(" "),
                                whitespace_after_as=cst.SimpleWhitespace(" "),
                            ),
                            comma=cst.Comma(
                                whitespace_before=cst.SimpleWhitespace(" "),
                                whitespace_after=cst.SimpleWhitespace(" "),
                            ),
                        ),
                        cst.ImportAlias(
                            cst.Name("unittest"),
                            asname=cst.AsName(
                                cst.Name("ut"),
                                whitespace_before_as=cst.SimpleWhitespace(" "),
                                whitespace_after_as=cst.SimpleWhitespace(" "),
                            ),
                        ),
                    ),
                    whitespace_after_import=cst.SimpleWhitespace(" "),
                ),
                "code": "import foo . bar as baz , unittest as ut",
            },
        )
    )
    def test_valid(self, **kwargs: Any) -> None:
        # Parse `code` as a single simple statement and compare its only
        # small statement against the expected Import node.
        self.validate_node(
            parser=lambda code: ensure_type(
                parse_statement(code), cst.SimpleStatementLine
            ).body[0],
            **kwargs,
        )
def to_attribute_cst(value, attr):
    """Return the ``cst.Attribute`` node for ``<value>.<attr>``."""
    return cst.Attribute(value=value, attr=attr)
def visit_Call(self, node: cst.Call) -> None:
    """Report membership asserts that should use ``assertIn`` / ``assertNotIn``.

    * ``self.assertTrue(a in b)``      -> ``self.assertIn(a, b)``
    * ``self.assertTrue(not a in b)``  -> ``self.assertNotIn(a, b)``
    * ``self.assertTrue(a not in b)``  -> ``self.assertNotIn(a, b)``
    * ``self.assertFalse(a in b)``     -> ``self.assertNotIn(a, b)``

    Only calls with exactly one argument are matched.
    """
    # TODO: these checks could be folded into a single `m.extract` call.

    def self_method(name):
        # Matcher for the callee `self.<name>`.
        return m.Attribute(value=m.Name("self"), attr=m.Name(name))

    def single_comparison(op):
        # Matcher for a comparison with exactly one `<op>` target.
        return m.Comparison(comparisons=[m.ComparisonTarget(operator=op)])

    def operands(expr):
        # `left <op> right` -> (Arg(left), Arg(right)) for the replacement call.
        comparison = ensure_type(expr, cst.Comparison)
        return cst.Arg(comparison.left), cst.Arg(comparison.comparisons[0].comparator)

    if m.matches(
        node,
        m.Call(func=self_method("assertTrue"), args=[m.Arg(single_comparison(m.In()))]),
    ):
        left, right = operands(node.args[0].value)
        self.report(
            node,
            replacement=node.with_changes(
                func=cst.Attribute(value=cst.Name("self"), attr=cst.Name("assertIn")),
                args=[left, right],
            ),
        )
        return

    negated_in = m.UnaryOperation(operator=m.Not(), expression=single_comparison(m.In()))
    if m.matches(node, m.Call(func=self_method("assertTrue"), args=[m.Arg(negated_in)])):
        pair = operands(ensure_type(node.args[0].value, cst.UnaryOperation).expression)
    elif m.matches(
        node,
        m.Call(func=self_method("assertTrue"), args=[m.Arg(single_comparison(m.NotIn()))]),
    ):
        pair = operands(node.args[0].value)
    elif m.matches(
        node,
        m.Call(func=self_method("assertFalse"), args=[m.Arg(single_comparison(m.In()))]),
    ):
        pair = operands(node.args[0].value)
    else:
        return

    self.report(
        node,
        replacement=node.with_changes(
            func=cst.Attribute(value=cst.Name("self"), attr=cst.Name("assertNotIn")),
            args=list(pair),
        ),
    )
class AttributeTest(CSTNodeTest):
    """Round-trip tests for ``cst.Attribute`` (``foo.bar``) nodes.

    The expected positions exclude surrounding parentheses and the whitespace
    just inside them, e.g. ``(foo.bar)`` spans columns 1-8.
    """

    @data_provider(
        (
            # Simple attribute access
            {
                "node": cst.Attribute(cst.Name("foo"), cst.Name("bar")),
                "code": "foo.bar",
                "parser": parse_expression,
                "expected_position": CodeRange((1, 0), (1, 7)),
            },
            # Parenthesized attribute access
            {
                "node": cst.Attribute(
                    lpar=(cst.LeftParen(),),
                    value=cst.Name("foo"),
                    attr=cst.Name("bar"),
                    rpar=(cst.RightParen(),),
                ),
                "code": "(foo.bar)",
                "parser": parse_expression,
                "expected_position": CodeRange((1, 1), (1, 8)),
            },
            # Make sure that spacing works
            {
                "node": cst.Attribute(
                    lpar=(cst.LeftParen(whitespace_after=cst.SimpleWhitespace(" ")),),
                    value=cst.Name("foo"),
                    dot=cst.Dot(
                        whitespace_before=cst.SimpleWhitespace(" "),
                        whitespace_after=cst.SimpleWhitespace(" "),
                    ),
                    attr=cst.Name("bar"),
                    rpar=(cst.RightParen(whitespace_before=cst.SimpleWhitespace(" ")),),
                ),
                "code": "( foo . bar )",
                "parser": parse_expression,
                "expected_position": CodeRange((1, 2), (1, 11)),
            },
        )
    )
    def test_valid(self, **kwargs: Any) -> None:
        self.validate_node(**kwargs)

    @data_provider(
        (
            # Unbalanced parentheses must be rejected.
            {
                "get_node": (
                    lambda: cst.Attribute(
                        cst.Name("foo"), cst.Name("bar"), lpar=(cst.LeftParen(),)
                    )
                ),
                "expected_re": "left paren without right paren",
            },
            {
                "get_node": (
                    lambda: cst.Attribute(
                        cst.Name("foo"), cst.Name("bar"), rpar=(cst.RightParen(),)
                    )
                ),
                "expected_re": "right paren without left paren",
            },
        )
    )
    def test_invalid(self, **kwargs: Any) -> None:
        self.assert_invalid(**kwargs)
class StatementTest(UnitTest):
    """Tests for import helpers.

    Covers ``get_absolute_module_for_import`` (and its ``_or_raise`` variant)
    and the ``ImportAlias.evaluated_name`` / ``evaluated_alias`` properties.
    """

    @data_provider(
        (
            # Simple imports that are already absolute.
            (None, "from a.b import c", "a.b"),
            ("x.y.z", "from a.b import c", "a.b"),
            # Relative import that can't be resolved due to missing module.
            (None, "from ..w import c", None),
            # Relative import that goes past the module level.
            ("x", "from ...y import z", None),
            ("x.y.z", "from .....w import c", None),
            ("x.y.z", "from ... import c", None),
            # Correct resolution of absolute from relative modules.
            ("x.y.z", "from . import c", "x.y"),
            ("x.y.z", "from .. import c", "x"),
            ("x.y.z", "from .w import c", "x.y.w"),
            ("x.y.z", "from ..w import c", "x.w"),
            ("x.y.z", "from ...w import c", "w"),
        )
    )
    def test_get_absolute_module(
        self,
        module: Optional[str],
        importfrom: str,
        output: Optional[str],
    ) -> None:
        # Each case is (current module, `from ... import` source, expected
        # absolute module, or None when it cannot be resolved).
        node = ensure_type(cst.parse_statement(importfrom), cst.SimpleStatementLine)
        assert len(node.body) == 1, "Unexpected number of statements!"
        import_node = ensure_type(node.body[0], cst.ImportFrom)
        self.assertEqual(get_absolute_module_for_import(module, import_node), output)
        # Where the plain helper returns None, the `_or_raise` variant must raise.
        if output is None:
            with self.assertRaises(Exception):
                get_absolute_module_for_import_or_raise(module, import_node)
        else:
            self.assertEqual(
                get_absolute_module_for_import_or_raise(module, import_node), output
            )

    @data_provider(
        (
            # Nodes without an asname
            (cst.ImportAlias(name=cst.Name("foo")), "foo", None),
            (
                cst.ImportAlias(name=cst.Attribute(cst.Name("foo"), cst.Name("bar"))),
                "foo.bar",
                None,
            ),
            # Nodes with an asname
            (
                cst.ImportAlias(
                    name=cst.Name("foo"), asname=cst.AsName(name=cst.Name("baz"))
                ),
                "foo",
                "baz",
            ),
            (
                cst.ImportAlias(
                    name=cst.Attribute(cst.Name("foo"), cst.Name("bar")),
                    asname=cst.AsName(name=cst.Name("baz")),
                ),
                "foo.bar",
                "baz",
            ),
        )
    )
    def test_importalias_helpers(
        self, alias_node: cst.ImportAlias, full_name: str, alias: Optional[str]
    ) -> None:
        # evaluated_name is the dotted module path; evaluated_alias is the
        # `as` name, or None when there is none.
        self.assertEqual(alias_node.evaluated_name, full_name)
        self.assertEqual(alias_node.evaluated_alias, alias)
def visit_Call(self, node: cst.Call) -> None:
    """Report ``assertTrue``/``assertFalse`` on ``x is [not] None`` as ``assertIs[Not]None``.

    An optional leading ``not`` is also accepted. ``assertFalse``, a leading
    ``not`` and ``is not`` each flip the meaning once: an even number of
    flips yields ``assertIsNone(x)``, an odd number ``assertIsNotNone(x)``.
    """

    def first(captured):
        # `m.extract` may return a sequence of captures; use the first one.
        return captured[0] if isinstance(captured, Sequence) else captured

    is_none_target = m.ComparisonTarget(
        m.SaveMatchedNode(m.OneOf(m.Is(), m.IsNot()), "comparison_type"),
        comparator=m.Name("None"),
    )
    compare_is_none = m.Comparison(comparisons=[is_none_target])
    assertion_matcher = m.SaveMatchedNode(
        m.OneOf(m.Name("assertTrue"), m.Name("assertFalse")),
        "assertion_name",
    )
    argument_matcher = m.SaveMatchedNode(
        m.OneOf(
            compare_is_none,
            m.UnaryOperation(operator=m.Not(), expression=compare_is_none),
        ),
        "argument",
    )
    captures = m.extract(
        node,
        m.Call(
            func=m.Attribute(value=m.Name("self"), attr=assertion_matcher),
            args=[m.Arg(argument_matcher)],
        ),
    )
    if not captures:
        return

    assertion_name = first(captures["assertion_name"])
    argument = first(captures["argument"])
    comparison_type = first(captures["comparison_type"])

    negated = not m.matches(argument, m.Comparison())
    if negated:
        comparison = ensure_type(
            ensure_type(argument, cst.UnaryOperation).expression, cst.Comparison
        )
    else:
        comparison = ensure_type(argument, cst.Comparison)

    flips = sum(
        (
            m.matches(assertion_name, m.Name("assertFalse")),
            m.matches(argument, m.UnaryOperation()),
            m.matches(comparison_type, m.IsNot()),
        )
    )
    new_attr = "assertIsNotNone" if flips % 2 else "assertIsNone"
    new_call = node.with_changes(
        func=cst.Attribute(value=cst.Name("self"), attr=cst.Name(new_attr)),
        args=[cst.Arg(comparison.left)],
    )
    if new_call is not node:
        self.report(node, replacement=new_call)
def visit_Call(self, node: cst.Call) -> None:
    """Report ``None`` checks that should use ``assertIsNone`` / ``assertIsNotNone``.

    Handled forms (anything else is ignored):

    * ``self.assertTrue(x is not None)``  -> ``self.assertIsNotNone(x)``
    * ``self.assertTrue(not x is None)``  -> ``self.assertIsNotNone(x)``
    * ``self.assertFalse(x is None)``     -> ``self.assertIsNotNone(x)``
    * ``self.assertTrue(x is None)``      -> ``self.assertIsNone(x)``
    * ``self.assertFalse(x is not None)`` -> ``self.assertIsNone(x)``
    * ``self.assertFalse(not x is None)`` -> ``self.assertIsNone(x)``

    Only calls with exactly one argument match.
    """
    # (original assertion, wrapped in `not`?, comparison operator, replacement)
    cases = (
        ("assertTrue", False, m.IsNot(), "assertIsNotNone"),
        ("assertTrue", True, m.Is(), "assertIsNotNone"),
        ("assertFalse", False, m.Is(), "assertIsNotNone"),
        ("assertTrue", False, m.Is(), "assertIsNone"),
        ("assertFalse", False, m.IsNot(), "assertIsNone"),
        ("assertFalse", True, m.Is(), "assertIsNone"),
    )
    for assertion, negated, op, new_name in cases:
        comparison_matcher = m.Comparison(
            comparisons=[m.ComparisonTarget(op, comparator=m.Name("None"))]
        )
        argument_matcher = (
            m.UnaryOperation(operator=m.Not(), expression=comparison_matcher)
            if negated
            else comparison_matcher
        )
        call_matcher = m.Call(
            func=m.Attribute(value=m.Name("self"), attr=m.Name(assertion)),
            args=[m.Arg(argument_matcher)],
        )
        if not m.matches(node, call_matcher):
            continue
        value = node.args[0].value
        if negated:
            value = ensure_type(value, cst.UnaryOperation).expression
        subject = ensure_type(value, cst.Comparison).left
        new_call = node.with_changes(
            func=cst.Attribute(value=cst.Name("self"), attr=cst.Name(new_name)),
            args=[cst.Arg(subject)],
        )
        self.report(node, replacement=new_call)
        # The cases are mutually exclusive; stop at the first match.
        return
class CallTest(CSTNodeTest):
    """Render/parse tests for ``cst.Call`` and ``cst.Arg`` nodes.

    Cases with ``"parser": None`` only exercise rendering. The others also
    parse ``code`` with ``parse_expression`` (presumably comparing the
    result with ``node`` — ``validate_node`` is defined in ``CSTNodeTest``).
    Render cases omit ``comma=``/``equal=`` and rely on the defaults; parse
    cases spell them out.
    """

    @data_provider((
        # Simple call
        {
            "node": cst.Call(cst.Name("foo")),
            "code": "foo()",
            "parser": parse_expression,
            "expected_position": None,
        },
        {
            "node": cst.Call(cst.Name("foo"), whitespace_before_args=cst.SimpleWhitespace(" ")),
            "code": "foo( )",
            "parser": parse_expression,
            "expected_position": None,
        },
        # Call with attribute dereference
        {
            "node": cst.Call(cst.Attribute(cst.Name("foo"), cst.Name("bar"))),
            "code": "foo.bar()",
            "parser": parse_expression,
            "expected_position": None,
        },
        # Positional arguments render test
        {
            "node": cst.Call(cst.Name("foo"), (cst.Arg(cst.Integer("1")), )),
            "code": "foo(1)",
            "parser": None,
            "expected_position": None,
        },
        {
            "node": cst.Call(
                cst.Name("foo"),
                (
                    cst.Arg(cst.Integer("1")),
                    cst.Arg(cst.Integer("2")),
                    cst.Arg(cst.Integer("3")),
                ),
            ),
            "code": "foo(1, 2, 3)",
            "parser": None,
            "expected_position": None,
        },
        # Positional arguments parse test
        {
            "node": cst.Call(cst.Name("foo"), (cst.Arg(value=cst.Integer("1")), )),
            "code": "foo(1)",
            "parser": parse_expression,
            "expected_position": None,
        },
        {
            "node": cst.Call(
                cst.Name("foo"),
                (cst.Arg(
                    value=cst.Integer("1"),
                    whitespace_after_arg=cst.SimpleWhitespace(" "),
                ), ),
                whitespace_after_func=cst.SimpleWhitespace(" "),
                whitespace_before_args=cst.SimpleWhitespace(" "),
            ),
            "code": "foo ( 1 )",
            "parser": parse_expression,
            "expected_position": None,
        },
        {
            "node": cst.Call(
                cst.Name("foo"),
                (cst.Arg(
                    value=cst.Integer("1"),
                    comma=cst.Comma(
                        whitespace_after=cst.SimpleWhitespace(" ")),
                ), ),
                whitespace_after_func=cst.SimpleWhitespace(" "),
                whitespace_before_args=cst.SimpleWhitespace(" "),
            ),
            "code": "foo ( 1, )",
            "parser": parse_expression,
            "expected_position": None,
        },
        {
            "node": cst.Call(
                cst.Name("foo"),
                (
                    cst.Arg(
                        value=cst.Integer("1"),
                        comma=cst.Comma(
                            whitespace_after=cst.SimpleWhitespace(" ")),
                    ),
                    cst.Arg(
                        value=cst.Integer("2"),
                        comma=cst.Comma(
                            whitespace_after=cst.SimpleWhitespace(" ")),
                    ),
                    cst.Arg(value=cst.Integer("3")),
                ),
            ),
            "code": "foo(1, 2, 3)",
            "parser": parse_expression,
            "expected_position": None,
        },
        # Keyword arguments render test (the default `equal` renders as " = ")
        {
            "node": cst.Call(
                cst.Name("foo"),
                (cst.Arg(keyword=cst.Name("one"), value=cst.Integer("1")), ),
            ),
            "code": "foo(one = 1)",
            "parser": None,
            "expected_position": None,
        },
        {
            "node": cst.Call(
                cst.Name("foo"),
                (
                    cst.Arg(keyword=cst.Name("one"), value=cst.Integer("1")),
                    cst.Arg(keyword=cst.Name("two"), value=cst.Integer("2")),
                    cst.Arg(keyword=cst.Name("three"), value=cst.Integer("3")),
                ),
            ),
            "code": "foo(one = 1, two = 2, three = 3)",
            "parser": None,
            "expected_position": None,
        },
        # Keyword arguments parser test
        {
            "node": cst.Call(
                cst.Name("foo"),
                (cst.Arg(
                    keyword=cst.Name("one"),
                    equal=cst.AssignEqual(),
                    value=cst.Integer("1"),
                ), ),
            ),
            "code": "foo(one = 1)",
            "parser": parse_expression,
            "expected_position": None,
        },
        {
            "node": cst.Call(
                cst.Name("foo"),
                (
                    cst.Arg(
                        keyword=cst.Name("one"),
                        equal=cst.AssignEqual(),
                        value=cst.Integer("1"),
                        comma=cst.Comma(
                            whitespace_after=cst.SimpleWhitespace(" ")),
                    ),
                    cst.Arg(
                        keyword=cst.Name("two"),
                        equal=cst.AssignEqual(),
                        value=cst.Integer("2"),
                        comma=cst.Comma(
                            whitespace_after=cst.SimpleWhitespace(" ")),
                    ),
                    cst.Arg(
                        keyword=cst.Name("three"),
                        equal=cst.AssignEqual(),
                        value=cst.Integer("3"),
                    ),
                ),
            ),
            "code": "foo(one = 1, two = 2, three = 3)",
            "parser": parse_expression,
            "expected_position": None,
        },
        # Iterator expansion render test
        {
            "node": cst.Call(cst.Name("foo"), (cst.Arg(star="*", value=cst.Name("one")), )),
            "code": "foo(*one)",
            "parser": None,
            "expected_position": None,
        },
        {
            "node": cst.Call(
                cst.Name("foo"),
                (
                    cst.Arg(star="*", value=cst.Name("one")),
                    cst.Arg(star="*", value=cst.Name("two")),
                    cst.Arg(star="*", value=cst.Name("three")),
                ),
            ),
            "code": "foo(*one, *two, *three)",
            "parser": None,
            "expected_position": None,
        },
        # Iterator expansion parser test
        {
            "node": cst.Call(cst.Name("foo"), (cst.Arg(star="*", value=cst.Name("one")), )),
            "code": "foo(*one)",
            "parser": parse_expression,
            "expected_position": None,
        },
        {
            "node": cst.Call(
                cst.Name("foo"),
                (
                    cst.Arg(
                        star="*",
                        value=cst.Name("one"),
                        comma=cst.Comma(
                            whitespace_after=cst.SimpleWhitespace(" ")),
                    ),
                    cst.Arg(
                        star="*",
                        value=cst.Name("two"),
                        comma=cst.Comma(
                            whitespace_after=cst.SimpleWhitespace(" ")),
                    ),
                    cst.Arg(star="*", value=cst.Name("three")),
                ),
            ),
            "code": "foo(*one, *two, *three)",
            "parser": parse_expression,
            "expected_position": None,
        },
        # Dictionary expansion render test
        {
            "node": cst.Call(cst.Name("foo"), (cst.Arg(star="**", value=cst.Name("one")), )),
            "code": "foo(**one)",
            "parser": None,
            "expected_position": None,
        },
        {
            "node": cst.Call(
                cst.Name("foo"),
                (
                    cst.Arg(star="**", value=cst.Name("one")),
                    cst.Arg(star="**", value=cst.Name("two")),
                    cst.Arg(star="**", value=cst.Name("three")),
                ),
            ),
            "code": "foo(**one, **two, **three)",
            "parser": None,
            "expected_position": None,
        },
        # Dictionary expansion parser test
        {
            "node": cst.Call(cst.Name("foo"), (cst.Arg(star="**", value=cst.Name("one")), )),
            "code": "foo(**one)",
            "parser": parse_expression,
            "expected_position": None,
        },
        {
            "node": cst.Call(
                cst.Name("foo"),
                (
                    cst.Arg(
                        star="**",
                        value=cst.Name("one"),
                        comma=cst.Comma(
                            whitespace_after=cst.SimpleWhitespace(" ")),
                    ),
                    cst.Arg(
                        star="**",
                        value=cst.Name("two"),
                        comma=cst.Comma(
                            whitespace_after=cst.SimpleWhitespace(" ")),
                    ),
                    cst.Arg(star="**", value=cst.Name("three")),
                ),
            ),
            "code": "foo(**one, **two, **three)",
            "parser": parse_expression,
            "expected_position": None,
        },
        # Complicated mingling rules render test
        {
            "node": cst.Call(
                cst.Name("foo"),
                (
                    cst.Arg(value=cst.Name("pos1")),
                    cst.Arg(star="*", value=cst.Name("list1")),
                    cst.Arg(value=cst.Name("pos2")),
                    cst.Arg(value=cst.Name("pos3")),
                    cst.Arg(star="*", value=cst.Name("list2")),
                    cst.Arg(value=cst.Name("pos4")),
                    cst.Arg(star="*", value=cst.Name("list3")),
                    cst.Arg(keyword=cst.Name("kw1"), value=cst.Integer("1")),
                    cst.Arg(star="*", value=cst.Name("list4")),
                    cst.Arg(keyword=cst.Name("kw2"), value=cst.Integer("2")),
                    cst.Arg(star="*", value=cst.Name("list5")),
                    cst.Arg(keyword=cst.Name("kw3"), value=cst.Integer("3")),
                    cst.Arg(star="**", value=cst.Name("dict1")),
                    cst.Arg(keyword=cst.Name("kw4"), value=cst.Integer("4")),
                    cst.Arg(star="**", value=cst.Name("dict2")),
                ),
            ),
            "code": "foo(pos1, *list1, pos2, pos3, *list2, pos4, *list3, kw1 = 1, *list4, kw2 = 2, *list5, kw3 = 3, **dict1, kw4 = 4, **dict2)",
            "parser": None,
            "expected_position": None,
        },
        # Complicated mingling rules parser test
        {
            "node": cst.Call(
                cst.Name("foo"),
                (
                    cst.Arg(
                        value=cst.Name("pos1"),
                        comma=cst.Comma(
                            whitespace_after=cst.SimpleWhitespace(" ")),
                    ),
                    cst.Arg(
                        star="*",
                        value=cst.Name("list1"),
                        comma=cst.Comma(
                            whitespace_after=cst.SimpleWhitespace(" ")),
                    ),
                    cst.Arg(
                        value=cst.Name("pos2"),
                        comma=cst.Comma(
                            whitespace_after=cst.SimpleWhitespace(" ")),
                    ),
                    cst.Arg(
                        value=cst.Name("pos3"),
                        comma=cst.Comma(
                            whitespace_after=cst.SimpleWhitespace(" ")),
                    ),
                    cst.Arg(
                        star="*",
                        value=cst.Name("list2"),
                        comma=cst.Comma(
                            whitespace_after=cst.SimpleWhitespace(" ")),
                    ),
                    cst.Arg(
                        value=cst.Name("pos4"),
                        comma=cst.Comma(
                            whitespace_after=cst.SimpleWhitespace(" ")),
                    ),
                    cst.Arg(
                        star="*",
                        value=cst.Name("list3"),
                        comma=cst.Comma(
                            whitespace_after=cst.SimpleWhitespace(" ")),
                    ),
                    cst.Arg(
                        keyword=cst.Name("kw1"),
                        equal=cst.AssignEqual(),
                        value=cst.Integer("1"),
                        comma=cst.Comma(
                            whitespace_after=cst.SimpleWhitespace(" ")),
                    ),
                    cst.Arg(
                        star="*",
                        value=cst.Name("list4"),
                        comma=cst.Comma(
                            whitespace_after=cst.SimpleWhitespace(" ")),
                    ),
                    cst.Arg(
                        keyword=cst.Name("kw2"),
                        equal=cst.AssignEqual(),
                        value=cst.Integer("2"),
                        comma=cst.Comma(
                            whitespace_after=cst.SimpleWhitespace(" ")),
                    ),
                    cst.Arg(
                        star="*",
                        value=cst.Name("list5"),
                        comma=cst.Comma(
                            whitespace_after=cst.SimpleWhitespace(" ")),
                    ),
                    cst.Arg(
                        keyword=cst.Name("kw3"),
                        equal=cst.AssignEqual(),
                        value=cst.Integer("3"),
                        comma=cst.Comma(
                            whitespace_after=cst.SimpleWhitespace(" ")),
                    ),
                    cst.Arg(
                        star="**",
                        value=cst.Name("dict1"),
                        comma=cst.Comma(
                            whitespace_after=cst.SimpleWhitespace(" ")),
                    ),
                    cst.Arg(
                        keyword=cst.Name("kw4"),
                        equal=cst.AssignEqual(),
                        value=cst.Integer("4"),
                        comma=cst.Comma(
                            whitespace_after=cst.SimpleWhitespace(" ")),
                    ),
                    cst.Arg(star="**", value=cst.Name("dict2")),
                ),
            ),
            "code": "foo(pos1, *list1, pos2, pos3, *list2, pos4, *list3, kw1 = 1, *list4, kw2 = 2, *list5, kw3 = 3, **dict1, kw4 = 4, **dict2)",
            "parser": parse_expression,
            "expected_position": None,
        },
        # Test whitespace
        # NOTE(review): with the single-space literals below, the call itself
        # (excluding its outer parens) ends at column 41, not 43 — confirm the
        # whitespace literals were not collapsed.
        {
            "node": cst.Call(
                lpar=(cst.LeftParen(
                    whitespace_after=cst.SimpleWhitespace(" ")), ),
                func=cst.Name("foo"),
                whitespace_after_func=cst.SimpleWhitespace(" "),
                whitespace_before_args=cst.SimpleWhitespace(" "),
                args=(
                    cst.Arg(
                        keyword=None,
                        value=cst.Name("pos1"),
                        comma=cst.Comma(
                            whitespace_before=cst.SimpleWhitespace(" "),
                            whitespace_after=cst.SimpleWhitespace(" "),
                        ),
                    ),
                    cst.Arg(
                        star="*",
                        whitespace_after_star=cst.SimpleWhitespace(" "),
                        keyword=None,
                        value=cst.Name("list1"),
                        comma=cst.Comma(
                            whitespace_after=cst.SimpleWhitespace(" ")),
                    ),
                    cst.Arg(
                        keyword=cst.Name("kw1"),
                        equal=cst.AssignEqual(
                            whitespace_before=cst.SimpleWhitespace(""),
                            whitespace_after=cst.SimpleWhitespace(""),
                        ),
                        value=cst.Integer("1"),
                        comma=cst.Comma(
                            whitespace_after=cst.SimpleWhitespace(" ")),
                    ),
                    cst.Arg(
                        star="**",
                        keyword=None,
                        whitespace_after_star=cst.SimpleWhitespace(" "),
                        value=cst.Name("dict1"),
                        whitespace_after_arg=cst.SimpleWhitespace(" "),
                    ),
                ),
                rpar=(cst.RightParen(
                    whitespace_before=cst.SimpleWhitespace(" ")), ),
            ),
            "code": "( foo ( pos1 , * list1, kw1=1, ** dict1 ) )",
            "parser": parse_expression,
            "expected_position": CodeRange((1, 2), (1, 43)),
        },
        # Test args (a bare Arg node; its comma is rendered with it)
        {
            "node": cst.Arg(
                star="*",
                whitespace_after_star=cst.SimpleWhitespace(" "),
                keyword=None,
                value=cst.Name("list1"),
                comma=cst.Comma(whitespace_after=cst.SimpleWhitespace(" ")),
            ),
            "code": "* list1, ",
            "parser": None,
            "expected_position": CodeRange((1, 0), (1, 8)),
        },
    ))
    def test_valid(self, **kwargs: Any) -> None:
        self.validate_node(**kwargs)

    @data_provider((
        # Basic expression parenthesizing tests.
        {
            "get_node": lambda: cst.Call(func=cst.Name("foo"), lpar=(cst.LeftParen(), )),
            "expected_re": "left paren without right paren",
        },
        {
            "get_node": lambda: cst.Call(func=cst.Name("foo"), rpar=(cst.RightParen(), )),
            "expected_re": "right paren without left paren",
        },
        # Test that we handle keyword stuff correctly.
        {
            "get_node": lambda: cst.Call(
                func=cst.Name("foo"),
                args=(cst.Arg(equal=cst.AssignEqual(), value=cst.SimpleString("'baz'")), ),
            ),
            "expected_re": "Must have a keyword when specifying an AssignEqual",
        },
        # Test that we separate *, ** and keyword args correctly
        {
            "get_node": lambda: cst.Call(
                func=cst.Name("foo"),
                args=(cst.Arg(
                    star="*",
                    keyword=cst.Name("bar"),
                    value=cst.SimpleString("'baz'"),
                ), ),
            ),
            "expected_re": "Cannot specify a star and a keyword together",
        },
        # Test for expected star inputs only
        {
            "get_node": lambda: cst.Call(
                func=cst.Name("foo"),
                # pyre-ignore: Ignore type on 'star' since we're testing behavior
                # when somebody isn't using a type checker.
                args=(cst.Arg(star="***", value=cst.SimpleString("'baz'")), ),
            ),
            "expected_re": r"Must specify either '', '\*' or '\*\*' for star",
        },
        # Test ordering exceptions
        {
            "get_node": lambda: cst.Call(
                func=cst.Name("foo"),
                args=(
                    cst.Arg(star="**", value=cst.Name("bar")),
                    cst.Arg(star="*", value=cst.Name("baz")),
                ),
            ),
            "expected_re": "Cannot have iterable argument unpacking after keyword argument unpacking",
        },
        {
            "get_node": lambda: cst.Call(
                func=cst.Name("foo"),
                args=(
                    cst.Arg(star="**", value=cst.Name("bar")),
                    cst.Arg(value=cst.Name("baz")),
                ),
            ),
            "expected_re": "Cannot have positional argument after keyword argument unpacking",
        },
        {
            "get_node": lambda: cst.Call(
                func=cst.Name("foo"),
                args=(
                    cst.Arg(keyword=cst.Name("arg"), value=cst.SimpleString("'baz'")),
                    cst.Arg(value=cst.SimpleString("'bar'")),
                ),
            ),
            "expected_re": "Cannot have positional argument after keyword argument",
        },
    ))
    def test_invalid(self, **kwargs: Any) -> None:
        self.assert_invalid(**kwargs)
def name_to_node(name: str) -> Union[cst.Name, cst.Attribute]:
    """Convert a dotted name string into a libcst expression node.

    A name without dots becomes a plain ``cst.Name``. A dotted name such as
    ``"a.b.c"`` becomes nested ``cst.Attribute`` nodes,
    ``Attribute(Attribute(Name("a"), Name("b")), Name("c"))``, so the
    leftmost segment is the innermost node.

    :param name: The dotted name to convert.
    :returns: The outermost node of the resulting expression.
    """
    head, *tail = name.split(".")
    node: Union[cst.Name, cst.Attribute] = cst.Name(head)
    # Fold left to right: each later segment wraps the node built so far,
    # so the final segment ends up as the outermost attribute.
    for segment in tail:
        node = cst.Attribute(value=node, attr=cst.Name(segment))
    return node
def annotation(self):
    """Return a CST node for the dotted expression ``types.ModuleType``."""
    module_ref = cst.Name("types")
    return cst.Attribute(value=module_ref, attr=cst.Name("ModuleType"))