def test_does_not_match_operator_true(self) -> None:
    """Inverted (``~``) matchers succeed exactly when the wrapped matcher fails."""

    def call_foo(arg_value: cst.BaseExpression) -> cst.Call:
        # Build the node ``foo(<arg_value>)``.
        return cst.Call(func=cst.Name("foo"), args=(cst.Arg(arg_value),))

    cases = (
        # Any single-argument call whose argument value isn't the name None.
        (call_foo(cst.Name("True")), m.Call(args=(m.Arg(value=~(m.Name("None"))),))),
        # Same idea, but inverting the whole Arg matcher instead of its value.
        (call_foo(cst.Integer("1")), m.Call(args=(~(m.Arg(m.Name("None"))),))),
        # An argument which is neither True nor False (inverted OR).
        (
            call_foo(cst.Integer("1")),
            m.Call(args=(m.Arg(value=~(m.Name("True") | m.Name("False"))),)),
        ),
        # The same condition expressed as an AND of two inverted matchers.
        (
            call_foo(cst.Name("None")),
            m.Call(args=(m.Arg(value=(~(m.Name("True"))) & (~(m.Name("False")))),)),
        ),
        # Roundabout check that OR works with inverted operands.
        (
            call_foo(cst.Name("False")),
            m.Call(args=(m.Arg(value=(~(m.Name("True"))) | (~(m.Name("True")))),)),
        ),
        # Roundabout check that inversion works on an AllOf.
        (
            call_foo(cst.Integer("1")),
            m.Call(args=(m.Arg(value=~(m.Name() & m.Name("True"))),)),
        ),
        # A Name whose value doesn't match the regex for True.
        (cst.Name("False"), m.Name(value=~(m.MatchRegex(r"True")))),
    )
    for node, matcher in cases:
        self.assertTrue(matches(node, matcher))
def visit_Lambda(self, node: cst.Lambda) -> None:
    """Report lambdas that only forward their parameters to another callable.

    ``lambda a, b: f(a, b)`` can be replaced by ``f`` itself. The suggestion
    is made only when all of the following hold:

    - ``self._is_simple_parameter_spec`` accepts the lambda's parameters;
    - the body is a call whose arguments are exactly the lambda's parameters,
      in order, passed positionally without ``*``/``**``;
    - the called expression does not itself reference any of those parameters.
      For example, ``lambda x: x.method(x)`` must not become ``x.method``,
      because ``x`` would then be unbound.
    """
    params = node.params.params
    if not m.matches(
        node,
        m.Lambda(
            params=m.MatchIfTrue(self._is_simple_parameter_spec),
            body=m.Call(
                args=[
                    m.Arg(value=m.Name(value=param.name.value), star="", keyword=None)
                    for param in params
                ]
            ),
        ),
    ):
        return

    call = cst.ensure_type(node.body, cst.Call)

    class _NameCollector(cst.CSTVisitor):
        # Collects identifiers read by the callee expression. The name after a
        # dot (``obj.attr``) is not a variable reference, so only the object
        # part of an Attribute is descended into. Over-collection only
        # suppresses a suggestion; it never produces a wrong one.
        def __init__(self) -> None:
            super().__init__()
            self.names = set()

        def visit_Name(self, name_node: cst.Name) -> None:
            self.names.add(name_node.value)

        def visit_Attribute(self, attr_node: cst.Attribute) -> bool:
            attr_node.value.visit(self)
            return False

    collector = _NameCollector()
    call.func.visit(collector)
    if collector.names & {param.name.value for param in params}:
        # The callee depends on a lambda parameter; removing the lambda
        # would leave that name unbound.
        return

    full_name = get_full_name_for_node(call)
    if full_name is None:
        full_name = "function"
    self.report(
        node,
        UNNECESSARY_LAMBDA.format(function=full_name),
        replacement=call.func,
    )
def test_or_matcher_true(self) -> None:
    """``m.OneOf`` matches as soon as any one of its alternatives matches."""
    true_or_false = m.OneOf(m.Name("True"), m.Name("False"))

    # A bare identifier that is either True or False.
    self.assertTrue(matches(libcst.Name("True"), true_or_false))

    # An assignment of True or False to an unspecified target.
    assign_true = libcst.Assign(
        (libcst.AssignTarget(libcst.Name("x")),), libcst.Name("True")
    )
    self.assertTrue(matches(assign_true, m.Assign(value=true_or_false)))

    # foo(1, 2, 3) against "args are either (3, 2, 1) or (1, 2, 3)"; the
    # alternatives are whole argument sequences.
    call = libcst.Call(
        libcst.Name("foo"),
        tuple(libcst.Arg(libcst.Integer(digit)) for digit in ("1", "2", "3")),
    )
    descending = tuple(m.Arg(m.Integer(digit)) for digit in ("3", "2", "1"))
    ascending = tuple(m.Arg(m.Integer(digit)) for digit in ("1", "2", "3"))
    self.assertTrue(
        matches(call, m.Call(m.Name("foo"), m.OneOf(descending, ascending)))
    )
def test_extract_multiple(self) -> None:
    """``m.extract`` returns every node captured by ``SaveMatchedNode``."""
    expression = cst.parse_expression("a + b[c], d(e, f * g)")
    matcher = m.Tuple(
        elements=[
            m.Element(m.BinaryOperation(left=m.SaveMatchedNode(m.Name(), "left"))),
            m.Element(m.Call(func=m.SaveMatchedNode(m.Name(), "func"))),
        ]
    )
    nodes = m.extract(expression, matcher)

    # Pull the expected nodes straight out of the parsed tree.
    first, second = cst.ensure_type(expression, cst.Tuple).elements
    expected = {
        "left": cst.ensure_type(first.value, cst.BinaryOperation).left,
        "func": cst.ensure_type(second.value, cst.Call).func,
    }
    self.assertEqual(nodes, expected)
def test_or_matcher_false(self) -> None:
    """``m.OneOf`` fails when none of its alternatives match."""
    true_or_false = m.OneOf(m.Name("True"), m.Name("False"))

    # None is neither True nor False.
    self.assertFalse(matches(libcst.Name("None"), true_or_false))

    # Assigning None to a target is not assigning True or False.
    assign_none = libcst.Assign(
        (libcst.AssignTarget(libcst.Name("x")),), libcst.Name("None")
    )
    self.assertFalse(matches(assign_none, m.Assign(value=true_or_false)))

    # foo(1, 2, 3) has neither the argument sequence (3, 2, 1) nor (4, 5, 6).
    call = libcst.Call(
        libcst.Name("foo"),
        tuple(libcst.Arg(libcst.Integer(digit)) for digit in ("1", "2", "3")),
    )
    descending = tuple(m.Arg(m.Integer(digit)) for digit in ("3", "2", "1"))
    other_values = tuple(m.Arg(m.Integer(digit)) for digit in ("4", "5", "6"))
    self.assertFalse(
        matches(call, m.Call(m.Name("foo"), m.OneOf(descending, other_values)))
    )
def collect_targets(
    self, stack: Tuple[cst.BaseExpression, ...]
) -> Tuple[
    List[cst.BaseExpression], Dict[cst.BaseExpression, List[cst.BaseExpression]]
]:
    """Group ``isinstance(target, cls)`` operands by structurally equal target.

    Returns ``(operands, targets)``:

    - ``operands`` keeps, in original order, every operand that is not a
      mergeable isinstance call, plus the target of the first isinstance call
      seen for each distinct target;
    - ``targets`` maps each such target node to the class expressions tested
      against it, in order of appearance.
    """
    targets: Dict[cst.BaseExpression, List[cst.BaseExpression]] = {}
    operands: List[cst.BaseExpression] = []
    # Exactly two plain positional arguments, the second not already a tuple.
    # Starred arguments are excluded: in ``isinstance(x, *types)`` the star
    # would be lost once the class expression is pulled out of the call.
    two_arg_call = m.Call(
        func=m.DoNotCare(),
        args=[m.Arg(star=""), m.Arg(value=~m.Tuple(), star="")],
    )

    for operand in stack:
        if not m.matches(operand, two_arg_call):
            operands.append(operand)
            continue
        call = cst.ensure_type(operand, cst.Call)
        if not QualifiedNameProvider.has_name(self, call, _ISINSTANCE):
            operands.append(operand)
            continue

        target, match = call.args[0].value, call.args[1].value
        # Nodes hash by identity, so look for a structurally equal key.
        for possible_target in targets:
            if target.deep_equals(possible_target):
                targets[possible_target].append(match)
                break
        else:
            operands.append(target)
            targets[target] = [match]

    return operands, targets
class ToolbarAddToolCommand(VisitorBasedCodemodCommand):
    """Rename ``<obj>.DoAddTool(...)`` calls to ``AddTool`` and rename their keywords."""

    DESCRIPTION: str = "Transforms wx.Toolbar.DoAddTool method into AddTool"

    # Old keyword argument name -> new keyword argument name.
    args_map = {"id": "toolId"}
    # One Arg matcher per old keyword, paired with its replacement name.
    args_matchers_map = {
        matchers.Arg(keyword=matchers.Name(value=old_name)): new_name
        for old_name, new_name in args_map.items()
    }

    # `DoAddTool` calls that pass at least one keyword listed in `args_map`.
    call_matcher = matchers.Call(
        func=matchers.Attribute(attr=matchers.Name(value="DoAddTool")),
        args=matchers.MatchIfTrue(
            lambda args: bool(
                {
                    arg.keyword.value for arg in args if arg and arg.keyword
                }.intersection(ToolbarAddToolCommand.args_map.keys())
            )
        ),
    )

    def leave_Call(self, original_node: cst.Call, updated_node: cst.Call) -> cst.Call:
        if not matchers.matches(updated_node, self.call_matcher):
            return updated_node

        # DoAddTool(...) -> AddTool(...)
        renamed_method = updated_node.func.with_changes(attr=cst.Name(value="AddTool"))
        updated_node = updated_node.with_changes(func=renamed_method)

        # Rename keyword arguments. Each argument is tested against every
        # matcher; if several matched, the last one in the map would win.
        new_args = []
        for node_arg in updated_node.args:
            replacement = node_arg
            for arg_matcher, renamed in self.args_matchers_map.items():
                if matchers.matches(node_arg, arg_matcher):
                    replacement = node_arg.with_changes(keyword=cst.Name(value=renamed))
            new_args.append(replacement)
        return updated_node.with_changes(args=new_args)
def __extract_assign_newtype(self, node: cst.Assign):
    """
    Attempts extracting a NewType declaration from the provided Assign node.

    If the Assign node has the shape ``<name> = NewType(<string literal>, ...)``
    with exactly one target, the assigned name is appended to
    ``self.class_defs``. Otherwise nothing happens.
    """
    # Exactly one assignment target: a plain Name whose text is saved as "type".
    newtype_target = match.AssignTarget(
        target=match.Name(
            value=match.SaveMatchedNode(match.MatchRegex(r'(.)+'), "type")
        )
    )
    # Right-hand side: NewType(<string literal>, <any further arguments>).
    newtype_call = match.Call(
        func=match.Name(value="NewType"),
        args=[match.Arg(value=match.SimpleString()), match.ZeroOrMore()],
    )
    captured = match.extract(
        node, match.Assign(targets=[newtype_target], value=newtype_call)
    )
    if captured is None:
        return
    # TODO: Either rename class defs, or create new list for additional types
    self.class_defs.append(captured["type"].strip("\'"))
def leave_Attribute(self, original_node: cst.Attribute, updated_node: cst.Attribute):
    """Obfuscate both sides of an attribute access ``<tail>.<head>``.

    The part after the dot goes through ``self.obf_var`` (or
    ``self.obf_function_name`` if it is a Call). The object part, when it is a
    plain Name that ``self.can_rename`` accepts, is replaced by
    ``self.get_new_cst_name``.
    """
    self.attribute_stack.pop()

    # Layout: ``<tail>.<head>``, e.g. ``x.y`` -> tail ``x``, head ``y``.
    tail = updated_node.value
    head = updated_node.attr
    # NOTE(review): the result is never used; kept in case split_attribute
    # has side effects — confirm and drop if not.
    attrs = split_attribute(tail)

    # Obfuscate the method/field name.
    if m.matches(head, m.Name()):
        updated_node = self.obf_var(cst.ensure_type(head, cst.Name), updated_node)
    elif m.matches(head, m.Call()):
        updated_node = self.obf_function_name(cst.ensure_type(head, cst.Call), updated_node)

    # Obfuscate the object name.
    if m.matches(tail, m.Name()):
        tail_name = cst.ensure_type(tail, cst.Name)
        if self.can_rename(tail_name.value, 'v', 'a', 'ca'):
            updated_node = updated_node.with_changes(
                value=self.get_new_cst_name(tail_name.value)
            )
    return updated_node
class HttpRequestXReadLinesTransformer(BaseDjCodemodTransformer):
    """Replace `HttpRequest.xreadlines()` by iterating over the request."""

    deprecated_in = DJANGO_2_0
    removed_in = DJANGO_3_0

    # Deliberately conservative: only `.xreadlines()` called on
    # - a variable named `request` or `req`, or
    # - an attribute named `request` or `req` (e.g. `self.request`, `view.req`)
    # is rewritten.
    matcher = m.Call(
        func=m.Attribute(
            value=m.OneOf(
                m.Name(value="request"),
                m.Name(value="req"),
                m.Attribute(attr=m.Name(value="request")),
                m.Attribute(attr=m.Name(value="req")),
            ),
            attr=m.Name(value="xreadlines"),
        )
    )

    def leave_Call(self, original_node: Call, updated_node: Call) -> BaseExpression:
        if not m.matches(updated_node, self.matcher):
            return super().leave_Call(original_node, updated_node)
        # `<request>.xreadlines()` becomes just `<request>`.
        method_access = updated_node.func
        return method_access.value
def test_at_least_n_matcher_args_false(self) -> None:
    """``m.AtLeastN`` fails when too few trailing arguments satisfy it."""
    # Every case is matched against the same call: foo(1, 2, 3).
    call = cst.Call(
        func=cst.Name("foo"),
        args=tuple(cst.Arg(cst.Integer(digit)) for digit in ("1", "2", "3")),
    )
    first_is_one = m.Arg(m.Integer("1"))
    rejecting_arg_patterns = (
        # A leading 1, then at least two string arguments.
        (first_is_one, m.AtLeastN(m.Arg(m.SimpleString()), n=2)),
        # A leading 1, then at least three wildcard arguments.
        (first_is_one, m.AtLeastN(m.Arg(), n=3)),
        # A leading 1, then at least two arguments that are the integer 2.
        (first_is_one, m.AtLeastN(m.Arg(m.Integer("2")), n=2)),
    )
    for arg_pattern in rejecting_arg_patterns:
        self.assertFalse(matches(call, m.Call(func=m.Name("foo"), args=arg_pattern)))
def test_zero_or_more_matcher_args_true(self) -> None:
    """``m.ZeroOrMore`` absorbs any number of matching arguments."""
    # Every case is matched against the same call: foo(1, 2, 3).
    call = cst.Call(
        func=cst.Name("foo"),
        args=tuple(cst.Arg(cst.Integer(digit)) for digit in ("1", "2", "3")),
    )
    accepted_arg_patterns = (
        # A leading 1, then wildcard arguments.
        (m.Arg(m.Integer("1")), m.ZeroOrMore(m.Arg())),
        # A leading 1, then integer arguments of any value.
        (m.Arg(m.Integer("1")), m.ZeroOrMore(m.Arg(m.Integer()))),
        # Optional 1s/2s, then a 2, then anything. This checks the matcher is
        # not greedy: the first ZeroOrMore must leave the 2 for the next matcher.
        (
            m.ZeroOrMore(m.Arg(m.OneOf(m.Integer("1"), m.Integer("2")))),
            m.Arg(m.Integer("2")),
            m.ZeroOrMore(),
        ),
        # A leading 1, then arguments that are each the integer 2 or 3.
        (
            m.Arg(m.Integer("1")),
            m.ZeroOrMore(m.Arg(m.OneOf(m.Integer("2"), m.Integer("3")))),
        ),
    )
    for arg_pattern in accepted_arg_patterns:
        self.assertTrue(matches(call, m.Call(func=m.Name("foo"), args=arg_pattern)))
def test_at_least_n_matcher_no_args_true(self) -> None:
    """``m.AtLeastN`` without an inner matcher counts wildcard arguments."""
    # Every case is matched against the same call: foo(1, 2, 3).
    call = cst.Call(
        func=cst.Name("foo"),
        args=tuple(cst.Arg(cst.Integer(digit)) for digit in ("1", "2", "3")),
    )
    accepted_arg_patterns = (
        # At least one, two and three arguments of any kind.
        (m.AtLeastN(n=1),),
        (m.AtLeastN(n=2),),
        (m.AtLeastN(n=3),),
        # A leading 1, then at least one / at least two more arguments.
        (m.Arg(m.Integer("1")), m.AtLeastN(n=1)),
        (m.Arg(m.Integer("1")), m.AtLeastN(n=2)),
        # An argument equal to 2 with at least one argument before and after it.
        (m.AtLeastN(n=1), m.Arg(m.Integer("2")), m.AtLeastN(n=1)),
        # At least one / at least two arguments before a trailing 3.
        (m.AtLeastN(n=1), m.Arg(m.Integer("3"))),
        (m.AtLeastN(n=2), m.Arg(m.Integer("3"))),
    )
    for arg_pattern in accepted_arg_patterns:
        self.assertTrue(matches(call, m.Call(func=m.Name("foo"), args=arg_pattern)))
class Checker(m.MatcherDecoratableVisitor):
    """Flag Python 2 constructs that need attention when targeting Python 3.

    Each finding is printed as ``<path>:<line>:<column>: <message>`` and sets
    ``self.errors``. Individual checks are silenced by listing their key in
    ``ignored``: ``"division"``, ``"sys.maxint"``, ``"assertEquals"`` or
    ``"assertItemsEqual"``.
    """

    METADATA_DEPENDENCIES = (PositionProvider,)

    def __init__(
        self, path: Path, verbose: bool = False, ignored: Optional[List[str]] = None
    ):
        super().__init__()
        self.path = path
        self.verbose = verbose
        self.ignored = set(ignored or [])
        # Becomes True once `from __future__ import division` has been visited.
        self.future_division = False
        self.errors = False
        # Names of the enclosing classes/functions, innermost last.
        self.stack: List[str] = []

    def _report(self, node, message: str) -> None:
        """Print ``message`` prefixed with the start position of ``node``; record an error."""
        start = self.get_metadata(PositionProvider, node).start
        print(f"{self.path}:{start.line}:{start.column}: {message}")
        self.errors = True

    @m.call_if_inside(m.ImportFrom(module=m.Name("__future__")))
    @m.visit(m.ImportAlias(name=m.Name("division")))
    def import_div(self, node: ImportAlias) -> None:
        self.future_division = True

    @m.visit(m.BinaryOperation(operator=m.Divide()))
    def check_div(self, node: BinaryOperation) -> None:
        if "division" in self.ignored or self.future_division:
            return
        self._report(node, "division without `from __future__ import division`")

    @m.visit(m.Attribute(attr=m.Name("maxint"), value=m.Name("sys")))
    def check_maxint(self, node: Attribute) -> None:
        if "sys.maxint" not in self.ignored:
            self._report(node, "use of sys.maxint")

    def visit_ClassDef(self, node: ClassDef) -> None:
        self.stack.append(node.name.value)

    def leave_ClassDef(self, node: ClassDef) -> None:
        self.stack.pop()

    def visit_FunctionDef(self, node: FunctionDef) -> None:
        self.stack.append(node.name.value)

    def leave_FunctionDef(self, node: FunctionDef) -> None:
        self.stack.pop()

    def visit_ClassDef_bases(self, node: "ClassDef") -> None:
        return

    @m.visit(
        m.Call(
            func=m.Attribute(attr=m.Name("assertEquals") | m.Name("assertItemsEqual"))
        )
    )
    def visit_old_assert(self, node: Call) -> None:
        method_name = ensure_type(node.func, Attribute).attr.value
        if method_name not in self.ignored:
            self._report(node, f"use of {method_name}")
from typing import List, Optional, Set, Tuple, Union

import libcst as cst
from libcst import matchers as m

from tornado_async_transformer.helpers import (
    name_attr_possibilities,
    some_version_of,
    with_added_imports,
)

# matchers
# NOTE(review): `some_version_of` comes from tornado_async_transformer.helpers;
# presumably it builds a matcher accepting the dotted name in its different
# import spellings — confirm against the helper.

# `raise <X>` where the raised expression itself matches `tornado.gen.Return`.
gen_return_statement_matcher = m.Raise(
    exc=some_version_of("tornado.gen.Return"))
# `raise <X>(...)` called with at least one argument.
gen_return_call_with_args_matcher = m.Raise(exc=m.Call(
    func=some_version_of("tornado.gen.Return"), args=[m.AtLeastN(n=1)]))
# `raise <X>(...)` called with any arguments, including none.
gen_return_call_matcher = m.Raise(exc=m.Call(
    func=some_version_of("tornado.gen.Return")))
# Either of the two raise forms above.
gen_return_matcher = gen_return_statement_matcher | gen_return_call_matcher

# Calls whose callee matches `gen.sleep` / `gen.Task`.
gen_sleep_matcher = m.Call(func=some_version_of("gen.sleep"))
gen_task_matcher = m.Call(func=some_version_of("gen.Task"))

# Decorators matching `tornado.gen.coroutine` or `tornado.testing.gen_test`.
gen_coroutine_decorator_matcher = m.Decorator(
    decorator=some_version_of("tornado.gen.coroutine"))
gen_test_coroutine_decorator = m.Decorator(
    decorator=some_version_of("tornado.testing.gen_test"))
coroutine_decorator_matcher = (gen_coroutine_decorator_matcher
                               | gen_test_coroutine_decorator)

# A non-async function carrying one of those decorators anywhere in its
# decorator list.
# NOTE(review): this definition is cut off in the reviewed chunk; its
# remaining fields are not visible here.
coroutine_matcher = m.FunctionDef(
    asynchronous=None,
    decorators=[m.ZeroOrMore(), coroutine_decorator_matcher, m.ZeroOrMore()],
def is_foreign_key(node: Call) -> bool:
    """Return True if ``node`` calls an attribute named ``ForeignKey``.

    Matches e.g. ``models.ForeignKey(...)``. A bare ``ForeignKey(...)`` call,
    with no attribute access, is not matched.
    """
    foreign_key_call = m.Call(func=m.Attribute(attr=m.Name(value="ForeignKey")))
    return m.matches(node, foreign_key_call)
def leave_Call(self, original_node: Call, updated_node: Call) -> BaseExpression:
    """Rename calls to ``url(...)`` into ``re_path(...)``.

    Only the callee name changes. Arguments, parentheses and the whitespace
    around the call are carried over from the original node, so multi-line
    ``url(...)`` calls keep their layout. Building a fresh ``Call`` would reset
    that formatting to defaults.
    """
    if m.matches(updated_node, m.Call(func=m.Name("url"))):
        return updated_node.with_changes(
            func=updated_node.func.with_changes(value="re_path")
        )
    return super().leave_Call(original_node, updated_node)
class Modernizer(m.MatcherDecoratableTransformer):
    """Rewrite Python 2 idioms into code that also runs on Python 3.

    Rewrites performed:

    - ``map``/``filter``/``zip``/``range`` calls are wrapped in ``list(...)``
      unless inside list/set/tuple/``.join`` calls, comprehensions or ``for``
      loops, or unless the name is already imported from ``builtins``;
    - ``xrange``/``raw_input`` become ``range``/``input``;
    - ``d.iterkeys()``-style calls become ``iterkeys(d)``;
    - ``d.keys()``-style calls are wrapped in ``list(...)``;
    - ``unicode`` becomes ``text_type``.

    The needed ``from builtins import ...`` and ``from future.utils import ...``
    lines are added or extended in ``leave_Module``.
    """

    METADATA_DEPENDENCIES = (PositionProvider,)

    # FIXME use a stack of e.g. SimpleStatementLine then proper visit_Import/ImportFrom to store the ssl node

    def __init__(
        self, path: Path, verbose: bool = False, ignored: Optional[List[str]] = None
    ):
        super().__init__()
        self.path = path
        self.verbose = verbose
        self.ignored = set(ignored or [])
        self.errors = False
        self.stack: List[Tuple[str, ...]] = []
        self.annotations: Dict[
            Tuple[str, ...], Comment  # key: tuple of canonical variable name
        ] = {}
        self.python_future_updated_node: Optional[SimpleStatementLine] = None
        self.python_future_imports: Dict[str, str] = {}
        self.python_future_new_imports: Set[str] = set()
        # Names already imported from `builtins` (imported name -> local name),
        # names that still need importing, and the existing import line.
        self.builtins_imports: Dict[str, str] = {}
        self.builtins_new_imports: Set[str] = set()
        self.builtins_updated_node: Optional[SimpleStatementLine] = None
        # Same bookkeeping for `future.utils`.
        self.future_utils_imports: Dict[str, str] = {}
        self.future_utils_new_imports: Set[str] = set()
        self.future_utils_updated_node: Optional[SimpleStatementLine] = None
        # Last statement line (outside class/function/if bodies) containing an
        # import, as it appears in the original tree.
        self.last_import_node_stmt: Optional[CSTNode] = None
        # Index, in the updated module body, of the import line most recently
        # generated by `update_imports`. Lets a second generated line go right
        # after the first once original-module indices no longer line up.
        self._generated_import_index: Optional[int] = None

    @m.call_if_inside(m.ImportFrom(module=m.Name("builtins")))
    @m.visit(m.ImportAlias() | m.ImportStar())
    def import_builtins_check(self, node: Union[ImportAlias, ImportStar]) -> None:
        self.add_import(self.builtins_imports, node)

    @m.call_if_inside(
        m.ImportFrom(module=m.Attribute(value=m.Name("future"), attr=m.Name("utils")))
    )
    @m.visit(m.ImportAlias() | m.ImportStar())
    def import_future_utils_check(self, node: Union[ImportAlias, ImportStar]) -> None:
        self.add_import(self.future_utils_imports, node)

    @staticmethod
    def add_import(
        imports: Dict[str, str], node: Union[ImportAlias, ImportStar]
    ) -> None:
        """Record ``node`` in ``imports`` as imported name -> local name ("*" for a star import)."""
        if isinstance(node, ImportAlias):
            imports[node.name.value] = (
                node.asname.name.value if node.asname else node.name.value
            )
        else:
            imports["*"] = "*"

    @m.call_if_not_inside(m.ClassDef() | m.FunctionDef() | m.If())
    def visit_SimpleStatementLine(self, node: SimpleStatementLine) -> Optional[bool]:
        for n in node.body:
            if m.matches(n, m.Import() | m.ImportFrom()):
                self.last_import_node_stmt = node
        return None

    @m.call_if_not_inside(m.ClassDef() | m.FunctionDef() | m.If())
    def leave_SimpleStatementLine(
        self, original_node: SimpleStatementLine, updated_node: SimpleStatementLine
    ) -> Union[BaseStatement, RemovalSentinel]:
        # Remember the (updated) line holding each import we may extend later.
        for n in updated_node.body:
            if m.matches(n, m.ImportFrom(module=m.Name("__future__"))):
                self.python_future_updated_node = updated_node
            elif m.matches(n, m.ImportFrom(module=m.Name("builtins"))):
                self.builtins_updated_node = updated_node
            elif m.matches(
                n,
                m.ImportFrom(
                    module=m.Attribute(value=m.Name("future"), attr=m.Name("utils"))
                ),
            ):
                self.future_utils_updated_node = updated_node
        return updated_node

    def visit_Param(self, node: Param) -> Optional[bool]:
        class Visitor(m.MatcherDecoratableVisitor):
            def __init__(self):
                super().__init__()
                self.ptype: Optional[str] = None

            def visit_TrailingWhitespace_comment(
                self, node: "TrailingWhitespace"
            ) -> None:
                if node.comment and "type:" in node.comment.value:
                    mo = re.match(r"#\s*type:\s*(\S*)", node.comment.value)
                    self.ptype = mo.group(1) if mo else None
                return None

        v = Visitor()
        node.visit(v)
        if self.verbose:
            pos = self.get_metadata(PositionProvider, node).start
            print(
                f"{self.path}:{pos.line}:{pos.column}: parameter {node.name.value}: {v.ptype or 'unknown type'}"
            )
        return None

    @m.visit(m.SimpleStatementLine())
    def visit_simple_stmt(self, node: SimpleStatementLine) -> None:
        assign = None
        for c in node.children:
            if m.matches(c, m.Assign()):
                assign = ensure_type(c, Assign)
        if not assign:
            return

        # Read the type only from lines whose trailing comment mentions `type:`.
        vtype: Optional[str] = None
        trailing_comment = node.trailing_whitespace.comment
        if trailing_comment and "type:" in trailing_comment.value:

            class TypingVisitor(m.MatcherDecoratableVisitor):
                def __init__(self):
                    super().__init__()
                    self.vtype: Optional[str] = None

                def visit_TrailingWhitespace_comment(
                    self, node: "TrailingWhitespace"
                ) -> None:
                    if node.comment:
                        mo = re.match(r"#\s*type:\s*(\S*)", node.comment.value)
                        if mo:
                            self.vtype = mo.group(1)
                    return None

            tv = TypingVisitor()
            node.visit(tv)
            vtype = tv.vtype

        class NameVisitor(m.MatcherDecoratableVisitor):
            def __init__(self):
                super().__init__()
                self.names: List[str] = []

            def visit_Name(self, node: Name) -> Optional[bool]:
                self.names.append(node.value)
                return None

        if self.verbose:
            pos = self.get_metadata(PositionProvider, node).start
            for target in assign.targets:
                v = NameVisitor()
                target.visit(v)
                for name in v.names:
                    print(
                        f"{self.path}:{pos.line}:{pos.column}: variable {name}: {vtype or 'unknown type'}"
                    )

    def visit_FunctionDef_body(self, node: FunctionDef) -> None:
        class Visitor(m.MatcherDecoratableVisitor):
            def __init__(self):
                super().__init__()

            def visit_EmptyLine_comment(self, node: "EmptyLine") -> None:
                # FIXME too many matches on test_param_02
                if not node.comment:
                    return
                # TODO: use comment.value
                return None

        v = Visitor()
        node.visit(v)
        return None

    map_matcher = m.Call(
        func=m.Name("filter") | m.Name("map") | m.Name("zip") | m.Name("range")
    )

    @m.visit(map_matcher)
    def visit_map(self, node: Call) -> None:
        func_name = ensure_type(node.func, Name).value
        if func_name not in self.builtins_imports:
            self.builtins_new_imports.add(func_name)

    @m.call_if_not_inside(
        m.Call(
            func=m.Name("list")
            | m.Name("set")
            | m.Name("tuple")
            | m.Attribute(attr=m.Name("join"))
        )
        | m.CompFor()
        | m.For()
    )
    @m.leave(map_matcher)
    def fix_map(self, original_node: Call, updated_node: Call) -> BaseExpression:
        # TODO test with CompFor etc.
        # TODO improve join test
        func_name = ensure_type(updated_node.func, Name).value
        if func_name not in self.builtins_imports:
            updated_node = Call(func=Name("list"), args=[Arg(updated_node)])
        return updated_node

    @m.visit(m.Call(func=m.Name("xrange") | m.Name("raw_input")))
    def visit_xrange(self, node: Call) -> None:
        orig_func_name = ensure_type(node.func, Name).value
        func_name = "range" if orig_func_name == "xrange" else "input"
        if func_name not in self.builtins_imports:
            self.builtins_new_imports.add(func_name)

    @m.leave(m.Call(func=m.Name("xrange") | m.Name("raw_input")))
    def fix_xrange(self, original_node: Call, updated_node: Call) -> BaseExpression:
        orig_func_name = ensure_type(updated_node.func, Name).value
        func_name = "range" if orig_func_name == "xrange" else "input"
        return updated_node.with_changes(func=Name(func_name))

    iter_matcher = m.Call(
        func=m.Attribute(
            attr=m.Name("iterkeys") | m.Name("itervalues") | m.Name("iteritems")
        )
    )

    @m.visit(iter_matcher)
    def visit_iter(self, node: Call) -> None:
        func_name = ensure_type(node.func, Attribute).attr.value
        if func_name not in self.future_utils_imports:
            self.future_utils_new_imports.add(func_name)

    @m.leave(iter_matcher)
    def fix_iter(self, original_node: Call, updated_node: Call) -> BaseExpression:
        # `d.iteritems()` -> `iteritems(d)`
        attribute = ensure_type(updated_node.func, Attribute)
        func_name = attribute.attr
        dict_name = attribute.value
        return updated_node.with_changes(func=func_name, args=[Arg(dict_name)])

    not_iter_matcher = m.Call(
        func=m.Attribute(attr=m.Name("keys") | m.Name("values") | m.Name("items"))
    )

    @m.call_if_not_inside(
        m.Call(
            func=m.Name("list")
            | m.Name("set")
            | m.Name("tuple")
            | m.Attribute(attr=m.Name("join"))
        )
        | m.CompFor()
        | m.For()
    )
    @m.leave(not_iter_matcher)
    def fix_not_iter(self, original_node: Call, updated_node: Call) -> BaseExpression:
        updated_node = Call(func=Name("list"), args=[Arg(updated_node)])
        return updated_node

    @m.call_if_not_inside(m.Import() | m.ImportFrom())
    @m.leave(m.Name(value="unicode"))
    def fix_unicode(self, original_node: Name, updated_node: Name) -> BaseExpression:
        value = "text_type"
        if value not in self.future_utils_imports:
            self.future_utils_new_imports.add(value)
        return updated_node.with_changes(value=value)

    def leave_Module(self, original_node: Module, updated_node: Module) -> Module:
        updated_node = self.update_imports(
            original_node,
            updated_node,
            "builtins",
            self.builtins_updated_node,
            self.builtins_imports,
            self.builtins_new_imports,
            True,
        )
        updated_node = self.update_imports(
            original_node,
            updated_node,
            "future.utils",
            self.future_utils_updated_node,
            self.future_utils_imports,
            self.future_utils_new_imports,
            False,
        )
        return updated_node

    def _import_insertion_index(self, original_module: Module) -> int:
        """Return the index in ``original_module.body`` to insert after (-1 = top).

        Prefers ``self.last_import_node_stmt``. If that line is not a direct
        child of the module (e.g. an import nested in ``try``/``with``), falls
        back to the last module-level statement line containing an import, so
        the new import never lands after the code that uses it.
        """
        import_line = m.SimpleStatementLine(
            body=[m.ZeroOrMore(), m.Import() | m.ImportFrom(), m.ZeroOrMore()]
        )
        last_module_level_import = -1
        for idx, stmt in enumerate(original_module.body):
            if stmt is self.last_import_node_stmt:
                return idx
            if m.matches(stmt, import_line):
                last_module_level_import = idx
        return last_module_level_import

    def update_imports(
        self,
        original_module: Module,
        updated_module: Module,
        import_name: str,
        updated_import_node: Optional[SimpleStatementLine],
        current_imports: Dict[str, str],
        new_imports: Set[str],
        noqa: bool,
    ) -> Module:
        """Make ``updated_module`` import ``new_imports`` from ``import_name``.

        If a ``from <import_name> import ...`` line already exists
        (``updated_import_node``), it is replaced by one importing the sorted
        union of old and new names; an existing star import is left alone.
        Otherwise a new line is inserted in this order of preference:
        right after a line generated by an earlier call, after the last import
        line, or at the top of the module followed by two blank lines.
        """
        if not new_imports:
            return updated_module
        noqa_comment = " # noqa" if noqa else ""

        if updated_import_node:
            if "*" not in current_imports:
                current_imports_set = {
                    f"{k}" if k == v else f"{k} as {v}"
                    for k, v in current_imports.items()
                }
                stmt = parse_statement(
                    f"from {import_name} import {', '.join(sorted(new_imports | current_imports_set))}{noqa_comment}"
                )
                return updated_module.deep_replace(updated_import_node, stmt)
            return updated_module

        body = list(updated_module.body)
        if self._generated_import_index is not None:
            # An earlier call already inserted a line, so original-module
            # indices are stale; place this one directly after that line.
            insert_at = self._generated_import_index + 1
            blank_lines = ""
        else:
            insert_at = self._import_insertion_index(original_module) + 1
            blank_lines = "" if self.last_import_node_stmt else "\n\n"
        stmt = parse_module(
            f"from {import_name} import {', '.join(sorted(new_imports))}{noqa_comment}\n{blank_lines}",
            config=updated_module.config_for_parsing,
        )
        self._generated_import_index = insert_at
        return updated_module.with_changes(
            body=body[:insert_at] + list(stmt.children) + body[insert_at:]
        )
def visit_Call(self, node: cst.Call) -> None:
    """Suggest literals for ``tuple``/``list``/``set``/``dict`` calls.

    Two shapes are handled:

    - a call whose only argument is a positional list/tuple literal,
      e.g. ``list((1, 2))`` -> ``[1, 2]``. ``set([])``/``set(())`` become
      ``set()``. ``dict(...)`` is rewritten only when every element is a
      two-item ``(key, value)`` tuple or list;
    - an empty ``tuple()``/``list()``/``dict()`` call -> ``()``/``[]``/``{}``.
    """
    # The literal must be a single plain positional argument: in
    # `dict(a=[(1, 2)])` or `list(*[[1, 2]])` it means something else.
    literal_arg_call = m.Call(
        func=m.Name("tuple") | m.Name("list") | m.Name("set") | m.Name("dict"),
        args=[m.Arg(value=m.List() | m.Tuple(), keyword=None, star="")],
    )
    empty_call = m.Call(
        func=m.Name("tuple") | m.Name("list") | m.Name("dict"), args=[]
    )
    if not (m.matches(node, literal_arg_call) or m.matches(node, empty_call)):
        return

    # Any number of two-item tuples/lists made of plain (non-starred) elements.
    pair_item = m.Element(m.Tuple(elements=[m.Element(), m.Element()])) | m.Element(
        m.List(elements=[m.Element(), m.Element()])
    )
    pairs_matcher = m.ZeroOrMore(pair_item)

    exp = cst.ensure_type(node, cst.Call)
    call_name = cst.ensure_type(exp.func, cst.Name).value

    # An empty call is an "unnecessary call" rewritten to the empty literal;
    # set() never gets here because `empty_call` excludes it.
    if not exp.args:
        elements = []
        message_formatter = UNNCESSARY_CALL
    else:
        arg = exp.args[0].value
        elements = cst.ensure_type(
            arg, cst.List if isinstance(arg, cst.List) else cst.Tuple
        ).elements
        message_formatter = UNNECESSARY_LITERAL

    if call_name == "tuple":
        new_node = cst.Tuple(elements=elements)
    elif call_name == "list":
        new_node = cst.List(elements=elements)
    elif call_name == "set":
        # set() has no equivalent empty literal. Reaching this branch means an
        # empty literal was passed, so suggest plain set() instead.
        if len(elements) == 0:
            self.report(
                node,
                UNNECESSARY_LITERAL.format(func=call_name),
                replacement=node.deep_replace(node, cst.Call(func=cst.Name("set"))),
            )
            return
        new_node = cst.Set(elements=elements)
    elif len(elements) == 0 or m.matches(
        exp.args[0].value,
        m.Tuple(elements=[pairs_matcher]) | m.List(elements=[pairs_matcher]),
    ):
        dict_elements = []
        for ele in elements:
            pair = cst.ensure_type(
                ele.value, cst.Tuple if isinstance(ele.value, cst.Tuple) else cst.List
            )
            dict_elements.append(
                cst.DictElement(pair.elements[0].value, pair.elements[1].value)
            )
        new_node = cst.Dict(elements=dict_elements)
    else:
        # Unrecognized form
        return

    self.report(
        node,
        message_formatter.format(func=call_name),
        replacement=node.deep_replace(node, new_node),
    )
def leave_Call(self, original_node: Call, updated_node: Call) -> BaseExpression:
    """Replace ``<old_name>(...)`` with the bare name ``<new_name>``.

    This applies only when ``self.is_entity_imported`` is set. The whole call,
    including its arguments, is replaced by the Name. Any other call is handed
    to the parent transformer.
    """
    if not self.is_entity_imported:
        return super().leave_Call(original_node, updated_node)
    if not m.matches(updated_node, m.Call(func=m.Name(self.old_name))):
        return super().leave_Call(original_node, updated_node)
    return Name(self.new_name)
class ShedFixers(VisitorBasedCodemodCommand):
    """Fix a variety of small problems.

    Replaces `raise NotImplemented` with `raise NotImplementedError`,
    and converts always-failing assert statements to explicit `raise` statements.

    Also includes code closely modelled on pybetter's fixers, because it's
    considerably faster to run all transforms in a single pass if possible.
    """

    DESCRIPTION = "Fix a variety of style, performance, and correctness issues."

    @m.call_if_inside(m.Raise(exc=m.Name(value="NotImplemented")))
    def leave_Name(self, _, updated_node):  # noqa
        # call_if_inside fires for *every* Name within the Raise, including the
        # cause in `raise NotImplemented from err`; only rename the exception.
        if updated_node.value != "NotImplemented":
            return updated_node
        return updated_node.with_changes(value="NotImplementedError")

    def leave_Assert(self, _, updated_node):  # noqa
        # Only asserts whose test is a Python literal are touched: always-true
        # ones are removed, always-false ones become an explicit raise.
        test_code = cst.Module("").code_for_node(updated_node.test)
        try:
            test_literal = literal_eval(test_code)
        except Exception:
            return updated_node
        if test_literal:
            return cst.RemovalSentinel.REMOVE
        if updated_node.msg is None:
            return cst.Raise(cst.Name("AssertionError"))
        return cst.Raise(
            cst.Call(cst.Name("AssertionError"), args=[cst.Arg(updated_node.msg)])
        )

    @m.leave(
        m.ComparisonTarget(
            comparator=oneof_names("None", "False", "True"), operator=m.Equal()
        )
    )
    def convert_none_cmp(self, _, updated_node):
        """Inspired by Pybetter."""
        return updated_node.with_changes(operator=cst.Is())

    @m.leave(
        m.UnaryOperation(
            operator=m.Not(),
            expression=m.Comparison(comparisons=[m.ComparisonTarget(operator=m.In())]),
        )
    )
    def replace_not_in_condition(self, _, updated_node):
        """Also inspired by Pybetter."""
        expr = cst.ensure_type(updated_node.expression, cst.Comparison)
        return cst.Comparison(
            left=expr.left,
            lpar=updated_node.lpar,
            rpar=updated_node.rpar,
            comparisons=[expr.comparisons[0].with_changes(operator=cst.NotIn())],
        )

    @m.leave(
        m.Call(
            lpar=[m.AtLeastN(n=1, matcher=m.LeftParen())],
            rpar=[m.AtLeastN(n=1, matcher=m.RightParen())],
        )
    )
    def remove_pointless_parens_around_call(self, _, updated_node):
        # This is *probably* valid, but we might have e.g. a multi-line parenthesised
        # chain of attribute accesses ("fluent interface"), where we need the parens.
        noparens = updated_node.with_changes(lpar=[], rpar=[])
        try:
            compile(self.module.code_for_node(noparens), "<string>", "eval")
            return noparens
        except SyntaxError:
            return updated_node

    # The following methods fix https://pypi.org/project/flake8-comprehensions/

    @m.leave(m.Call(func=m.Name("list"), args=[m.Arg(m.GeneratorExp())]))
    def replace_generator_in_call_with_comprehension(self, _, updated_node):
        """Fix flake8-comprehensions C400-402 and 403-404.

        C400-402: Unnecessary generator - rewrite as a <list/set/dict> comprehension.
        Note that set and dict conversions are handled by pyupgrade!
        """
        return cst.ListComp(
            elt=updated_node.args[0].value.elt, for_in=updated_node.args[0].value.for_in
        )

    @m.leave(
        m.Call(func=m.Name("list"), args=[m.Arg(m.ListComp(), star="")])
        | m.Call(func=m.Name("set"), args=[m.Arg(m.SetComp(), star="")])
        | m.Call(
            func=m.Name("list"),
            args=[m.Arg(m.Call(func=oneof_names("sorted", "list")), star="")],
        )
    )
    def replace_unnecessary_list_around_sorted(self, _, updated_node):
        """Fix flake8-comprehensions C411 and C413.

        Unnecessary <list/reversed> call around sorted().

        Also covers C411 Unnecessary list call around list comprehension
        for lists and sets.
        """
        return updated_node.args[0].value

    @m.leave(
        m.Call(
            func=m.Name("reversed"),
            args=[m.Arg(m.Call(func=m.Name("sorted")), star="")],
        )
    )
    def replace_unnecessary_reversed_around_sorted(self, _, updated_node):
        """Fix flake8-comprehensions C413.

        Unnecessary reversed call around sorted().
        """
        call = updated_node.args[0].value
        args = list(call.args)
        for i, arg in enumerate(args):
            if m.matches(arg.keyword, m.Name("reverse")):
                try:
                    val = bool(literal_eval(self.module.code_for_node(arg.value)))
                except Exception:
                    # Not a literal, so negate it at runtime. `not` binds more
                    # tightly than `and`/`or`, conditional expressions and
                    # lambdas, so those need parentheses first; otherwise
                    # `reverse=a or b` would become `(not a) or b`.
                    operand = arg.value
                    if not operand.lpar and m.matches(
                        operand, m.BooleanOperation() | m.IfExp() | m.Lambda()
                    ):
                        operand = operand.with_changes(
                            lpar=[cst.LeftParen()], rpar=[cst.RightParen()]
                        )
                    args[i] = arg.with_changes(
                        value=cst.UnaryOperation(cst.Not(), operand)
                    )
                else:
                    if not val:
                        args[i] = arg.with_changes(value=cst.Name("True"))
                    else:
                        # reversed(sorted(x, reverse=True)) is plain sorted(x).
                        del args[i]
                        args[i - 1] = remove_trailing_comma(args[i - 1])
                break
        else:
            args.append(cst.Arg(keyword=cst.Name("reverse"), value=cst.Name("True")))
        return call.with_changes(args=args)

    _sets = oneof_names("set", "frozenset")
    _seqs = oneof_names("list", "reversed", "sorted", "tuple")

    @m.leave(
        m.Call(func=_sets, args=[m.Arg(m.Call(func=_sets | _seqs), star="")])
        | m.Call(
            func=oneof_names("list", "tuple"),
            args=[m.Arg(m.Call(func=oneof_names("list", "tuple")), star="")],
        )
        | m.Call(
            func=m.Name("sorted"),
            args=[m.Arg(m.Call(func=_seqs), star=""), m.ZeroOrMore()],
        )
    )
    def replace_unnecessary_nested_calls(self, _, updated_node):
        """Fix flake8-comprehensions C414.

        Unnecessary <list/reversed/sorted/tuple> call within <list/set/sorted/tuple>()..
        """
        inner_args = updated_node.args[0].value.args
        # The matcher accepts inner calls with any arguments. Only unwrap when
        # the inner call's first argument is a plain positional one: there is
        # nothing to unwrap in `set(list())` (this used to raise IndexError),
        # and `set(list(*xs))` would silently lose its star.
        if not inner_args or inner_args[0].star or inner_args[0].keyword is not None:
            return updated_node
        return updated_node.with_changes(
            args=[cst.Arg(inner_args[0].value)] + list(updated_node.args[1:]),
        )

    @m.leave(
        m.Call(
            func=oneof_names("reversed", "set", "sorted"),
            args=[m.Arg(m.Subscript(slice=[m.SubscriptElement(ALL_ELEMS_SLICE)]))],
        )
    )
    def replace_unnecessary_subscript_reversal(self, _, updated_node):
        """Fix flake8-comprehensions C415.

        Unnecessary subscript reversal of iterable within <reversed/set/sorted>().
        """
        return updated_node.with_changes(
            args=[cst.Arg(updated_node.args[0].value.value)],
        )

    @m.leave(
        multi(
            m.ListComp,
            m.SetComp,
            elt=m.Name(),
            for_in=m.CompFor(
                target=m.Name(), ifs=[], inner_for_in=None, asynchronous=None
            ),
        )
    )
    def replace_unnecessary_listcomp_or_setcomp(self, _, updated_node):
        """Fix flake8-comprehensions C416.

        Unnecessary <list/set> comprehension - rewrite using <list/set>().
        """
        if updated_node.elt.value == updated_node.for_in.target.value:
            func = cst.Name("list" if isinstance(updated_node, cst.ListComp) else "set")
            return cst.Call(func=func, args=[cst.Arg(updated_node.for_in.iter)])
        return updated_node

    @m.leave(m.Subscript(oneof_names("Union", "Literal")))
    def reorder_union_literal_contents_none_last(self, _, updated_node):
        subscript = list(updated_node.slice)
        try:
            subscript.sort(key=lambda elt: elt.slice.value.value == "None")
            subscript[-1] = remove_trailing_comma(subscript[-1])
            return updated_node.with_changes(slice=subscript)
        except Exception:  # Single-element literals are not slices, etc.
            return updated_node

    @m.call_if_inside(m.Annotation(annotation=m.BinaryOperation()))
    @m.leave(
        m.BinaryOperation(
            left=m.Name("None") | m.BinaryOperation(),
            operator=m.BitOr(),
            right=m.DoNotCare(),
        )
    )
    def reorder_union_operator_contents_none_last(self, _, updated_node):
        def _has_none(node):
            if m.matches(node, m.Name("None")):
                return True
            elif m.matches(node, m.BinaryOperation()):
                return _has_none(node.left) or _has_none(node.right)
            else:
                return False

        node_left = updated_node.left
        if _has_none(node_left):
            return updated_node.with_changes(left=updated_node.right, right=node_left)
        else:
            return updated_node

    @m.leave(m.Subscript(value=m.Name("Literal")))
    def flatten_literal_subscript(self, _, updated_node):
        new_slice = []
        for item in updated_node.slice:
            if m.matches(item.slice.value, m.Subscript(m.Name("Literal"))):
                new_slice += item.slice.value.slice
            else:
                new_slice.append(item)
        return updated_node.with_changes(slice=new_slice)

    @m.leave(m.Subscript(value=m.Name("Union")))
    def flatten_union_subscript(self, _, updated_node):
        new_slice = []
        has_none = False
        for item in updated_node.slice:
            if m.matches(item.slice.value, m.Subscript(m.Name("Optional"))):
                new_slice += item.slice.value.slice  # peel off "Optional"
                has_none = True
            elif m.matches(
                item.slice.value, m.Subscript(m.Name("Union"))
            ) and m.matches(updated_node.value, item.slice.value.value):
                new_slice += item.slice.value.slice  # peel off nested "Union"
            elif m.matches(item.slice.value, m.Name("None")):
                has_none = True
            else:
                new_slice.append(item)
        if has_none:
            new_slice.append(cst.SubscriptElement(slice=cst.Index(cst.Name("None"))))
        return updated_node.with_changes(slice=new_slice)

    @m.leave(m.Else(m.IndentedBlock([m.SimpleStatementLine([m.Pass()])])))
    def discard_empty_else_blocks(self, _, updated_node):
        # An `else: pass` block can always simply be discarded, and libcst ensures
        # that an Else node can only ever occur attached to an If, While, For, or Try
        # node; in each case `None` is the valid way to represent "no else block".
        if m.findall(updated_node, m.Comment()):
            return updated_node  # If there are any comments, keep the node
        return cst.RemoveFromParent()

    @m.leave(
        m.Lambda(
            params=m.MatchIfTrue(
                lambda node: (
                    node.star_kwarg is None
                    and not node.kwonly_params
                    and not node.posonly_params
                    and isinstance(node.star_arg, cst.MaybeSentinel)
                    and all(param.default is None for param in node.params)
                )
            )
        )
    )
    def remove_lambda_indirection(self, _, updated_node):
        # Rewrite `lambda a, b: f(a, b)` to `f`. Note that the callee
        # expression is then evaluated where the lambda was defined rather
        # than at each call.
        same_args = [
            m.Arg(m.Name(param.name.value), star="", keyword=None)
            for param in updated_node.params.params
        ]
        if not m.matches(updated_node.body, m.Call(args=same_args)):
            return updated_node
        func = cst.ensure_type(updated_node.body, cst.Call).func
        # If the callee itself mentions a parameter (`lambda x: x.f(x)`,
        # `lambda f: f(f)`), that name would be unbound once the lambda is
        # gone. Any same-named Name counts, which is deliberately conservative.
        param_names = {param.name.value for param in updated_node.params.params}
        if any(name.value in param_names for name in m.findall(func, m.Name())):
            return updated_node
        return func

    @m.leave(
        m.BooleanOperation(
            left=m.Call(m.Name("isinstance"), [m.Arg(), m.Arg()]),
            operator=m.Or(),
            right=m.Call(m.Name("isinstance"), [m.Arg(), m.Arg()]),
        )
    )
    def collapse_isinstance_checks(self, _, updated_node):
        # isinstance(x, A) or isinstance(x, B)  ->  isinstance(x, (A, B))
        left_target, left_type = updated_node.left.args
        right_target, right_type = updated_node.right.args
        if left_target.deep_equals(right_target):
            merged_type = cst.Arg(
                cst.Tuple([cst.Element(left_type.value), cst.Element(right_type.value)])
            )
            return updated_node.left.with_changes(args=[left_target, merged_type])
        return updated_node
def test_at_most_n_matcher_no_args_true(self) -> None:
    """AtMostN/ZeroOrOne without an inner matcher accept up to n arbitrary args."""
    # Match a function call to "foo" with at most two arguments.
    self.assertTrue(
        matches(
            cst.Call(func=cst.Name("foo"), args=(cst.Arg(cst.Integer("1")),)),
            m.Call(func=m.Name("foo"), args=(m.AtMostN(n=2),)),
        )
    )
    # Match a function call to "foo" with at most two arguments.
    self.assertTrue(
        matches(
            cst.Call(
                func=cst.Name("foo"),
                args=(cst.Arg(cst.Integer("1")), cst.Arg(cst.Integer("2"))),
            ),
            m.Call(func=m.Name("foo"), args=(m.AtMostN(n=2),)),
        )
    )
    # Match a function call to "foo" with at most six arguments, the last
    # one being the integer 1.
    self.assertTrue(
        matches(
            cst.Call(func=cst.Name("foo"), args=(cst.Arg(cst.Integer("1")),)),
            m.Call(
                func=m.Name("foo"), args=[m.AtMostN(n=5), m.Arg(m.Integer("1"))]
            ),
        )
    )
    # Match a function call to "foo" with at most six arguments, the last
    # one being the integer 2.
    self.assertTrue(
        matches(
            cst.Call(
                func=cst.Name("foo"),
                args=(cst.Arg(cst.Integer("1")), cst.Arg(cst.Integer("2"))),
            ),
            m.Call(
                func=m.Name("foo"), args=(m.AtMostN(n=5), m.Arg(m.Integer("2")))
            ),
        )
    )
    # Match a function call to "foo" with at most six arguments, the first
    # one being the integer 1.
    self.assertTrue(
        matches(
            cst.Call(
                func=cst.Name("foo"),
                args=(cst.Arg(cst.Integer("1")), cst.Arg(cst.Integer("2"))),
            ),
            m.Call(
                func=m.Name("foo"), args=(m.Arg(m.Integer("1")), m.AtMostN(n=5))
            ),
        )
    )
    # Match a function call to "foo" with at most two arguments, the first
    # one being the integer 1 (ZeroOrOne allows at most one more argument).
    self.assertTrue(
        matches(
            cst.Call(
                func=cst.Name("foo"),
                args=(cst.Arg(cst.Integer("1")), cst.Arg(cst.Integer("2"))),
            ),
            m.Call(func=m.Name("foo"), args=(m.Arg(m.Integer("1")), m.ZeroOrOne())),
        )
    )
def leave_Call(  # noqa: C901
    self, original_node: cst.Call, updated_node: cst.Call
) -> cst.BaseExpression:
    """Rewrite ``"<literal>".format(...)`` into an equivalent f-string.

    Each replacement field of the string literal is resolved against the
    call's arguments with ``_find_expr_from_field_name``. The call is returned
    unchanged, after ``self.warn``, whenever the conversion is not known to be
    safe. That covers:

    * a non-empty format spec;
    * an unresolvable field name;
    * a comment, nested f-string or ``await`` in an argument;
    * a backslash in the rendered argument;
    * an embedded string using the outer quote;
    * a bytes literal.
    """
    # Let's figure out if this is a "".format() call
    extraction = self.extract(
        updated_node,
        m.Call(
            func=m.Attribute(
                value=m.SaveMatchedNode(m.SimpleString(), "string"),
                attr=m.Name("format"),
            )
        ),
    )
    if extraction is None:
        return updated_node

    stringnode = cst.ensure_type(extraction["string"], cst.SimpleString)

    # `u` cannot be combined with `f` (`fu"..."` is a syntax error) and is a
    # no-op in Python 3, so drop it. Bytes literals have no f-string form.
    prefix = stringnode.prefix.replace("u", "").replace("U", "")
    if "b" in prefix.lower():
        self.warn("Unsupported bytes literal in format() call")
        return updated_node

    fstring: List[cst.BaseFormattedStringContent] = []
    inserted_sequence: int = 0
    tokens = _get_tokens(stringnode.raw_value)
    for (literal_text, field_name, format_spec, conversion) in tokens:
        if literal_text:
            fstring.append(cst.FormattedStringText(literal_text))
        if field_name is None:
            # This is not a format-specification
            continue
        if format_spec:
            # TODO: This is supportable since format specs are compatible
            # with f-string format specs, but it would require matching
            # format specifier expansions.
            self.warn(f"Unsupported format_spec {format_spec} in format() call")
            return updated_node

        # Auto-insert field sequence if it is empty
        if field_name == "":
            field_name = str(inserted_sequence)
            inserted_sequence += 1
        expr = _find_expr_from_field_name(field_name, updated_node.args)
        if expr is None:
            # Most likely they used * expansion in a format.
            self.warn(f"Unsupported field_name {field_name} in format() call")
            return updated_node

        # Verify that we don't have any comments or newlines. Comments aren't
        # allowed in f-strings, and newlines need parenthesization. We can
        # have formattedstrings inside other formattedstrings, but I chose not
        # to deal with that for now.
        if self.findall(expr, m.Comment()):
            # We could strip comments, but this is a formatting change so
            # we choose not to for now.
            self.warn("Unsupported comment in format() call")
            return updated_node
        if self.findall(expr, m.FormattedString()):
            self.warn("Unsupported f-string in format() call")
            return updated_node
        if self.findall(expr, m.Await()):
            # This is fixed in 3.7 but we don't currently have a flag
            # to enable/disable it.
            self.warn("Unsupported await in format() call")
            return updated_node

        # Stripping newlines is effectively a format-only change.
        expr = cst.ensure_type(
            expr.visit(StripNewlinesTransformer(self.context)),
            cst.BaseExpression,
        )

        # Try our best to swap quotes on any strings that won't fit
        expr = cst.ensure_type(
            expr.visit(
                SwitchStringQuotesTransformer(self.context, stringnode.quote[0])
            ),
            cst.BaseExpression,
        )

        # Verify that the resulting expression doesn't have a backslash
        # in it.
        raw_expr_string = self.module.code_for_node(expr)
        if "\\" in raw_expr_string:
            self.warn("Unsupported backslash in format expression")
            return updated_node

        # For safety sake, if this is a dict/set or dict/set comprehension,
        # wrap it in parens so that it doesn't accidentally create an
        # escape.
        if (raw_expr_string.startswith("{") or raw_expr_string.endswith("}")) and (
            not expr.lpar or not expr.rpar
        ):
            expr = expr.with_changes(lpar=[cst.LeftParen()], rpar=[cst.RightParen()])

        # Verify that any strings we insert don't have the same quote
        quote_gatherer = StringQuoteGatherer(self.context)
        expr.visit(quote_gatherer)
        for stringend in quote_gatherer.stringends:
            if stringend in stringnode.quote:
                self.warn("Cannot embed string with same quote from format() call")
                return updated_node

        fstring.append(
            cst.FormattedStringExpression(expression=expr, conversion=conversion)
        )
    return cst.FormattedString(
        parts=fstring,
        start=f"f{prefix}{stringnode.quote}",
        end=stringnode.quote,
    )
def test_complex_matcher_false(self) -> None:
    """Matchers that disagree with the call ``foo(1, 2, 3)`` must not match it."""
    # One immutable target node, reused for every check below.
    target = cst.Call(
        func=cst.Name("foo"),
        args=(
            cst.Arg(cst.Integer("1")),
            cst.Arg(cst.Integer("2")),
            cst.Arg(cst.Integer("3")),
        ),
    )
    rejecting_matchers = [
        # It is a Call, not a FunctionDef.
        m.FunctionDef(),
        # The callee is "foo", not "bar".
        m.Call(m.Name("bar")),
        # "foo" is called with three arguments, not two.
        m.Call(func=m.Name("foo"), args=(m.Arg(), m.Arg())),
        # The integer arguments are 1, 2, 3 - not 3, 2, 1.
        m.Call(
            func=m.Name("foo"),
            args=(
                m.Arg(m.Integer("3")),
                m.Arg(m.Integer("2")),
                m.Arg(m.Integer("1")),
            ),
        ),
        # Of the three arguments, the last one is 3, not 1.
        m.Call(
            func=m.Name("foo"),
            args=(m.DoNotCare(), m.DoNotCare(), m.Arg(m.Integer("1"))),
        ),
    ]
    for matcher in rejecting_matchers:
        self.assertFalse(matches(target, matcher))
class DeprecationWarningsCommand(VisitorBasedCodemodCommand):
    """Rewrite calls to deprecated wxPython factory functions.

    Each ``deprecated_symbols_map`` entry pairs a deprecated name with its
    replacement. The replacement is either one attribute of ``wx`` (``"Bitmap"``
    becomes ``wx.Bitmap``) or a two-part path (``("DateTime", "FromDMY")``
    becomes ``wx.DateTime.FromDMY``).

    ``wx.Old(...)`` calls are always rewritten. A bare ``Old(...)`` call is
    rewritten only when ``Old`` is among the names ``GatherImportsVisitor``
    records as imported from ``wx``. In that case imports referenced by the
    call are handed to ``RemoveImportsVisitor`` and ``wx`` is added as a needed
    import.
    """

    DESCRIPTION: str = "Rename deprecated methods"

    deprecated_symbols_map: List[Tuple[str, Union[str, Tuple[str, str]]]] = [
        ("BitmapFromImage", "Bitmap"),
        ("ImageFromStream", "Image"),
        ("EmptyIcon", "Icon"),
        ("DateTimeFromDMY", ("DateTime", "FromDMY")),
    ]

    # (symbol, matcher for a bare `symbol(...)` call, replacement)
    matchers_short_map = {
        (old_name, matchers.Call(func=matchers.Name(value=old_name)), new_name)
        for old_name, new_name in deprecated_symbols_map
    }

    # (matcher for a `wx.symbol(...)` call, replacement)
    matchers_full_map = {
        (
            matchers.Call(
                func=matchers.Attribute(
                    value=matchers.Name(value="wx"),
                    attr=matchers.Name(value=old_name),
                )
            ),
            new_name,
        )
        for old_name, new_name in deprecated_symbols_map
    }

    def __init__(self, context: CodemodContext):
        super().__init__(context)
        # Names recorded as imported from `wx`; populated in visit_Module.
        self.wx_imports: Set[str] = set()

    def visit_Module(self, node: cst.Module) -> None:
        """Record which symbols the module imports from ``wx``."""
        import_gatherer = GatherImportsVisitor(self.context)
        node.visit(import_gatherer)
        self.wx_imports = import_gatherer.object_mapping.get("wx", set())

    @staticmethod
    def _wx_path(renamed: Union[str, Tuple[str, str]]) -> cst.Attribute:
        """Build ``wx.<renamed>``, or ``wx.<renamed[0]>.<renamed[1]>`` for pairs."""
        if not isinstance(renamed, tuple):
            return cst.Attribute(value=cst.Name(value="wx"), attr=cst.Name(value=renamed))
        owner = cst.Attribute(value=cst.Name(value="wx"), attr=cst.Name(value=renamed[0]))
        return cst.Attribute(value=owner, attr=cst.Name(value=renamed[1]))

    def leave_Call(self, original_node: cst.Call, updated_node: cst.Call) -> cst.Call:
        """Rewrite a deprecated wx call; any other call is returned unchanged."""
        # Bare calls such as `BitmapFromImage(img)`, only when imported from wx.
        for symbol, matcher, renamed in self.matchers_short_map:
            if symbol not in self.wx_imports:
                continue
            if not matchers.matches(updated_node, matcher):
                continue
            # Drop the now-unneeded symbol import and make sure `wx` itself is imported.
            RemoveImportsVisitor.remove_unused_import_by_node(self.context, original_node)
            AddImportsVisitor.add_needed_import(self.context, "wx")
            return updated_node.with_changes(func=self._wx_path(renamed))

        # Fully qualified calls such as `wx.BitmapFromImage(img)`.
        for matcher, renamed in self.matchers_full_map:
            if not matchers.matches(updated_node, matcher):
                continue
            if isinstance(renamed, tuple):
                return updated_node.with_changes(func=self._wx_path(renamed))
            # Single-name replacement: swap only the attribute and keep the
            # existing `wx` prefix node.
            return updated_node.with_changes(
                func=updated_node.func.with_changes(attr=cst.Name(value=renamed))
            )

        return updated_node
def test_complex_matcher_true(self) -> None:
    """Progressively stricter matchers all accept the call ``foo(1, 2, 3)``."""
    # One immutable target node, reused for every check below.
    target = cst.Call(
        func=cst.Name("foo"),
        args=(
            cst.Arg(cst.Integer("1")),
            cst.Arg(cst.Integer("2")),
            cst.Arg(cst.Integer("3")),
        ),
    )
    accepting_matchers = [
        # Any Call, not caring about arguments.
        m.Call(),
        # Any Call to a function named "foo".
        m.Call(m.Name("foo")),
        # A call to "foo" with exactly three arguments.
        m.Call(func=m.Name("foo"), args=(m.Arg(), m.Arg(), m.Arg())),
        # A call to "foo" with three integer arguments.
        m.Call(
            func=m.Name("foo"),
            args=(m.Arg(m.Integer()), m.Arg(m.Integer()), m.Arg(m.Integer())),
        ),
        # A call to "foo" with the integer arguments 1, 2, 3.
        m.Call(
            func=m.Name("foo"),
            args=(
                m.Arg(m.Integer("1")),
                m.Arg(m.Integer("2")),
                m.Arg(m.Integer("3")),
            ),
        ),
        # A call to "foo" with three arguments, the last one being the integer 3.
        m.Call(
            func=m.Name("foo"),
            args=(m.DoNotCare(), m.DoNotCare(), m.Arg(m.Integer("3"))),
        ),
    ]
    for matcher in accepting_matchers:
        self.assertTrue(matches(target, matcher))
def is_one_to_one_field(node: Call) -> bool:
    """Return whether *node* calls an attribute named ``OneToOneField``.

    Only the attribute form (``<anything>.OneToOneField(...)``) is recognised.
    The object the attribute is looked up on is not checked, and a bare
    ``OneToOneField(...)`` call does not match.
    """
    one_to_one_callee = m.Attribute(attr=m.Name(value="OneToOneField"))
    return m.matches(node, m.Call(func=one_to_one_callee))
def test_zero_or_more_matcher_no_args_true(self) -> None:
    """ZeroOrMore wildcards let fixed argument matchers float within ``foo(1, 2, 3)``."""
    # One immutable target node, reused for every check below.
    target = cst.Call(
        func=cst.Name("foo"),
        args=(
            cst.Arg(cst.Integer("1")),
            cst.Arg(cst.Integer("2")),
            cst.Arg(cst.Integer("3")),
        ),
    )
    # Each entry is the `args` pattern of an m.Call(func=m.Name("foo"), ...).
    arg_patterns = [
        # The first argument is the integer 1; anything may follow.
        (m.Arg(m.Integer("1")), m.ZeroOrMore()),
        # The integer 1 appears somewhere in the argument list.
        (m.ZeroOrMore(), m.Arg(m.Integer("1")), m.ZeroOrMore()),
        # The integer 2 appears somewhere in the argument list.
        (m.ZeroOrMore(), m.Arg(m.Integer("2")), m.ZeroOrMore()),
        # The integer 3 appears somewhere in the argument list.
        (m.ZeroOrMore(), m.Arg(m.Integer("3")), m.ZeroOrMore()),
        # The last argument is the integer 3.
        (m.ZeroOrMore(), m.Arg(m.Integer("3"))),
        # 1 and then 3 appear anywhere, respecting order.
        (
            m.ZeroOrMore(),
            m.Arg(m.Integer("1")),
            m.ZeroOrMore(),
            m.Arg(m.Integer("3")),
            m.ZeroOrMore(),
        ),
        # 1, 2 and then 3 appear anywhere, respecting order.
        (
            m.ZeroOrMore(),
            m.Arg(m.Integer("1")),
            m.ZeroOrMore(),
            m.Arg(m.Integer("2")),
            m.ZeroOrMore(),
            m.Arg(m.Integer("3")),
            m.ZeroOrMore(),
        ),
    ]
    for pattern in arg_patterns:
        self.assertTrue(matches(target, m.Call(func=m.Name("foo"), args=pattern)))
def leave_Call(self, original_node: Call, updated_node: Call) -> BaseExpression:
    """Replace a call whose callee matches ``self.name_matcher`` with the new name.

    The whole call expression, arguments included, becomes
    ``Name(self.new_name)``. Any other call is delegated to the parent class.
    """
    is_target_call = m.matches(updated_node, m.Call(func=self.name_matcher))
    if not is_target_call:
        return super().leave_Call(original_node, updated_node)
    return Name(self.new_name)
class DatetimeUtcnow_(VisitorBasedCodemodCommand):
    """Replace naive ``utcnow()`` calls with timezone-aware ``now(UTC)`` calls.

    ``UTC`` is imported from ``bulb.platform.common.timezones``. Rewrites:

    * ``datetime.utcnow()`` -> ``datetime.now(UTC)``
    * ``datetime.datetime.utcnow()`` -> ``datetime.datetime.now(UTC)``
    * ``<either of the above>.replace(tzinfo=<utc>)`` -> the matching ``now(UTC)`` call
    * ``(<utcnow call> + <expr>).replace(tzinfo=<utc>)`` -> the parenthesised sum
    * ``UTC.localize(<utcnow call>)`` -> the (already rewritten) inner call

    where ``<utc>`` is ``timezone.utc``, ``utc``, ``UTC`` or ``pytz.UTC``.
    Leave hooks run bottom-up, so inner ``utcnow()`` calls have already been
    rewritten when the outer ``replace``/``localize`` hooks fire.
    """

    DESCRIPTION: str = "Converts from datetime.utcnow() to datetime.now(UTC)"

    # `tzinfo=timezone.utc`
    timezone_utc_matcher = m.Arg(
        value=m.Attribute(value=m.Name(value="timezone"), attr=m.Name(value="utc")),
        keyword=m.Name(value="tzinfo"),
    )

    # `tzinfo=utc`, `tzinfo=UTC` or `tzinfo=pytz.UTC`
    utc_matcher = m.Arg(
        value=m.OneOf(
            m.Name(value="utc"),
            m.Name(value="UTC"),
            m.Attribute(value=m.Name(value="pytz"), attr=m.Name(value="UTC")),
        ),
        keyword=m.Name(value="tzinfo"),
    )

    # `datetime.utcnow()` with no arguments
    datetime_utcnow_matcher = m.Call(
        func=m.Attribute(value=m.Name(value="datetime"), attr=m.Name(value="utcnow")),
        args=[],
    )

    # `datetime.datetime.utcnow()` with no arguments
    datetime_datetime_utcnow_matcher = m.Call(
        func=m.Attribute(
            value=m.Attribute(
                value=m.Name(value="datetime"), attr=m.Name(value="datetime")
            ),
            attr=m.Name(value="utcnow"),
        ),
        args=[],
    )

    # `datetime.utcnow().replace(<utc tzinfo>)`
    datetime_replace_matcher = m.Call(
        func=m.Attribute(value=datetime_utcnow_matcher, attr=m.Name(value="replace")),
        args=[m.OneOf(timezone_utc_matcher, utc_matcher)],
    )

    # `datetime.datetime.utcnow().replace(<utc tzinfo>)`
    datetime_datetime_replace_matcher = m.Call(
        func=m.Attribute(
            value=datetime_datetime_utcnow_matcher,
            attr=m.Name(value="replace"),
        ),
        args=[m.OneOf(timezone_utc_matcher, utc_matcher)],
    )

    # `(<utcnow call> + <anything>).replace(<utc tzinfo>)`
    timedelta_replace_matcher = m.Call(
        func=m.Attribute(
            value=m.BinaryOperation(
                left=m.OneOf(datetime_utcnow_matcher, datetime_datetime_utcnow_matcher),
                operator=m.Add(),
            ),
            attr=m.Name(value="replace"),
        ),
        args=[m.OneOf(timezone_utc_matcher, utc_matcher)],
    )

    # `UTC.localize(<utcnow call>)`
    utc_localize_matcher = m.Call(
        func=m.Attribute(value=m.Name(value="UTC"), attr=m.Name(value="localize")),
        args=[
            m.Arg(
                value=m.OneOf(datetime_utcnow_matcher, datetime_datetime_utcnow_matcher)
            )
        ],
    )

    def _update_imports(self):
        # Schedule pytz / datetime.timezone imports for removal (RemoveImportsVisitor
        # only drops them if unused) and make sure the project's UTC is imported.
        RemoveImportsVisitor.remove_unused_import(self.context, "pytz")
        RemoveImportsVisitor.remove_unused_import(self.context, "pytz", "utc")
        RemoveImportsVisitor.remove_unused_import(self.context, "pytz", "UTC")
        RemoveImportsVisitor.remove_unused_import(self.context, "datetime", "timezone")
        AddImportsVisitor.add_needed_import(
            self.context, "bulb.platform.common.timezones", "UTC"
        )

    @staticmethod
    def _datetime_now_utc(updated_node: cst.Call) -> cst.Call:
        """Turn *updated_node* into ``datetime.now(UTC)``."""
        return updated_node.with_changes(
            func=cst.Attribute(value=cst.Name(value="datetime"), attr=cst.Name("now")),
            args=[cst.Arg(value=cst.Name(value="UTC"))],
        )

    @staticmethod
    def _datetime_datetime_now_utc(updated_node: cst.Call) -> cst.Call:
        """Turn *updated_node* into ``datetime.datetime.now(UTC)``."""
        return updated_node.with_changes(
            func=cst.Attribute(
                value=cst.Attribute(
                    value=cst.Name(value="datetime"),
                    attr=cst.Name(value="datetime"),
                ),
                attr=cst.Name(value="now"),
            ),
            args=[cst.Arg(value=cst.Name(value="UTC"))],
        )

    @m.leave(datetime_utcnow_matcher)
    def datetime_utcnow_call(
        self, original_node: cst.Call, updated_node: cst.Call
    ) -> cst.Call:
        self._update_imports()
        return self._datetime_now_utc(updated_node)

    @m.leave(datetime_datetime_utcnow_matcher)
    def datetime_datetime_utcnow_call(
        self, original_node: cst.Call, updated_node: cst.Call
    ) -> cst.Call:
        self._update_imports()
        return self._datetime_datetime_now_utc(updated_node)

    @m.leave(datetime_replace_matcher)
    def datetime_replace(
        self, original_node: cst.Call, updated_node: cst.Call
    ) -> cst.Call:
        # The whole `.replace(...)` call collapses to `datetime.now(UTC)`.
        self._update_imports()
        return self._datetime_now_utc(updated_node)

    @m.leave(datetime_datetime_replace_matcher)
    def datetime_datetime_replace(
        self, original_node: cst.Call, updated_node: cst.Call
    ) -> cst.Call:
        # The whole `.replace(...)` call collapses to `datetime.datetime.now(UTC)`.
        self._update_imports()
        return self._datetime_datetime_now_utc(updated_node)

    @m.leave(timedelta_replace_matcher)
    def timedelta_replace(
        self, original_node: cst.Call, updated_node: cst.Call
    ) -> cst.BinaryOperation:
        # Drop `.replace(...)` and keep the (parenthesised) sum, whose utcnow()
        # operand has already been rewritten.
        self._update_imports()
        return cast(
            cst.BinaryOperation,
            cast(cst.Attribute, cast(cst.Call, updated_node).func).value,
        )

    @m.leave(utc_localize_matcher)
    def utc_localize(
        self, original_node: cst.Call, updated_node: cst.Call
    ) -> cst.Call:
        # `UTC.localize(x)` is redundant once x is already aware: return x.
        self._update_imports()
        return cast(cst.Call, updated_node.args[0].value)