class TestVisitor(MatcherDecoratableTransformer):
    """Transformer that records which decorated visit/leave hooks were invoked.

    Two ``@visit`` hooks target functions named ``foo`` and two ``@leave``
    hooks target functions named ``bar``; each hook stores the function name
    suffixed with its own index.
    """

    def __init__(self) -> None:
        super().__init__()
        self.leaves: Set[str] = set()
        self.visits: Set[str] = set()

    @visit(m.FunctionDef(m.Name("foo")))
    def visit_function1(self, node: cst.FunctionDef) -> None:
        """First visit hook for ``foo``."""
        name = node.name.value
        self.visits.add(name + "1")

    @visit(m.FunctionDef(m.Name("foo")))
    def visit_function2(self, node: cst.FunctionDef) -> None:
        """Second visit hook for ``foo``."""
        name = node.name.value
        self.visits.add(name + "2")

    @leave(m.FunctionDef(m.Name("bar")))
    def leave_function1(
        self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef
    ) -> cst.FunctionDef:
        """First leave hook for ``bar``; the node is returned unchanged."""
        name = updated_node.name.value
        self.leaves.add(name + "1")
        return updated_node

    @leave(m.FunctionDef(m.Name("bar")))
    def leave_function2(
        self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef
    ) -> cst.FunctionDef:
        """Second leave hook for ``bar``; the node is returned unchanged."""
        name = updated_node.name.value
        self.leaves.add(name + "2")
        return updated_node
def test_lambda_metadata_matcher(self) -> None:
    """``MatchMetadataIfTrue`` accepts a lambda evaluated on provider metadata."""
    # Match on qualified name provider
    source = "from typing import List\n\ndef foo() -> None: pass\n"
    wrapper = cst.MetadataWrapper(cst.parse_module(source))
    functiondef = cst.ensure_type(wrapper.module.body[1], cst.FunctionDef)

    # One of the candidate names is "foo", so this matcher must succeed.
    matcher_with_foo = m.FunctionDef(
        name=m.MatchMetadataIfTrue(
            meta.QualifiedNameProvider,
            lambda qualnames: any(
                n.name in {"foo", "bar", "baz"} for n in qualnames
            ),
        )
    )
    self.assertTrue(
        matches(functiondef, matcher_with_foo, metadata_resolver=wrapper)
    )

    # None of the candidate names is "foo", so this matcher must fail.
    matcher_without_foo = m.FunctionDef(
        name=m.MatchMetadataIfTrue(
            meta.QualifiedNameProvider,
            lambda qualnames: any(n.name in {"bar", "baz"} for n in qualnames),
        )
    )
    self.assertFalse(
        matches(functiondef, matcher_without_foo, metadata_resolver=wrapper)
    )
class TestVisitor(MatcherDecoratableVisitor):
    """Visitor recording string literals seen by context-gated visit/leave hooks.

    Each hook combines ``@call_if_inside`` / ``@call_if_not_inside`` with a
    ``SimpleString`` matcher and stores the evaluated literal plus a suffix
    identifying the hook.
    """

    def __init__(self) -> None:
        super().__init__()
        self.leaves: Set[str] = set()
        self.visits: Set[str] = set()

    @call_if_inside(m.FunctionDef(m.Name("foo")))
    @visit(m.SimpleString())
    def visit_string1(self, node: cst.SimpleString) -> None:
        """Visit hook gated on being inside ``foo``."""
        text = literal_eval(node.value)
        self.visits.add(text + "1")

    @call_if_not_inside(m.FunctionDef(m.Name("bar")))
    @visit(m.SimpleString())
    def visit_string2(self, node: cst.SimpleString) -> None:
        """Visit hook gated on not being inside ``bar``."""
        text = literal_eval(node.value)
        self.visits.add(text + "2")

    @call_if_inside(m.FunctionDef(m.Name("baz")))
    @leave(m.SimpleString())
    def leave_string1(self, original_node: cst.SimpleString) -> None:
        """Leave hook gated on being inside ``baz``."""
        text = literal_eval(original_node.value)
        self.leaves.add(text + "1")

    @call_if_not_inside(m.FunctionDef(m.Name("foo")))
    @leave(m.SimpleString())
    def leave_string2(self, original_node: cst.SimpleString) -> None:
        """Leave hook gated on not being inside ``foo``."""
        text = literal_eval(original_node.value)
        self.leaves.add(text + "2")
class TestVisitor(MatcherDecoratableVisitor):
    """Visitor recording function names matched by ``|``-combined name matchers."""

    def __init__(self) -> None:
        super().__init__()
        self.leaves: List[str] = []
        self.visits: List[str] = []

    @visit(m.FunctionDef(m.Name("foo") | m.Name("bar")))
    def visit_function(self, node: cst.FunctionDef) -> None:
        """Record the name of a visited function called ``foo`` or ``bar``."""
        function_name = node.name.value
        self.visits.append(function_name)

    @leave(m.FunctionDef(m.Name("bar") | m.Name("baz")))
    def leave_function(self, original_node: cst.FunctionDef) -> None:
        """Record the name of a left function called ``bar`` or ``baz``."""
        function_name = original_node.name.value
        self.leaves.append(function_name)
class TestVisitor(MatcherDecoratableTransformer):
    """Transformer recording string values from gated ``lpar`` attribute hooks."""

    def __init__(self) -> None:
        super().__init__()
        self.leaves: List[str] = []
        self.visits: List[str] = []

    @call_if_not_inside(m.FunctionDef(m.Name("foo")))
    def visit_SimpleString_lpar(self, node: cst.SimpleString) -> None:
        """Record the string's source text unless inside a function named ``foo``."""
        raw = node.value
        self.visits.append(raw)

    @call_if_not_inside(m.FunctionDef())
    def leave_SimpleString_lpar(self, node: cst.SimpleString) -> None:
        """Record the string's source text unless inside any function."""
        raw = node.value
        self.leaves.append(raw)
def test_replace_sequence_extract(self) -> None:
    """``m.replace`` passes sequences saved with ``SaveMatchedNode`` to the callback."""

    def _reverse_params(
        node: cst.CSTNode,
        extraction: Dict[str, Union[cst.CSTNode, Sequence[cst.CSTNode]]],
    ) -> cst.CSTNode:
        funcdef = cst.ensure_type(node, cst.FunctionDef)
        # pyre-ignore We know "params" is a Sequence[Parameters] but asserting that
        # to pyre is difficult.
        flipped = list(reversed(extraction["params"]))
        return funcdef.with_changes(params=cst.Parameters(params=flipped))

    # Verify that we can still extract sequences with replace.
    original = cst.parse_module(
        "def bar(baz: int, foo: int, ) -> int:\n return baz + foo\n"
    )
    params_matcher = m.SaveMatchedNode([m.ZeroOrMore(m.Param())], "params")
    function_matcher = m.FunctionDef(params=m.Parameters(params=params_matcher))
    new_module = cst.ensure_type(
        m.replace(original, function_matcher, _reverse_params),
        cst.Module,
    )
    replaced = new_module.code
    self.assertEqual(
        replaced, "def bar(foo: int, baz: int, ) -> int:\n return baz + foo\n"
    )
class TestVisitor(MatcherDecoratableVisitor):
    """Visitor recording strings gated by two stacked ``@call_if_inside`` decorators."""

    def __init__(self) -> None:
        super().__init__()
        self.visits: List[str] = []

    @call_if_inside(m.ClassDef(m.Name("A")))
    @call_if_inside(m.FunctionDef(m.Name("foo")))
    def visit_SimpleString(self, node: cst.SimpleString) -> None:
        """Record the string's source text."""
        raw = node.value
        self.visits.append(raw)
class MakeModalCommand(VisitorBasedCodemodCommand):
    """Codemod that adds a ``MakeModal`` helper to classes calling ``self.MakeModal(...)``.

    While traversing a class, every call matching ``call_matcher`` marks the
    innermost enclosing class as needing the helper. When that class is left,
    the helper in ``method_cst`` is appended to its body unless the class
    already defines a method matching ``method_matcher``.
    """

    DESCRIPTION: str = "Replace built-in method MAkeModal with helper"

    # A method named ``MakeModal`` whose first parameter is ``self``.
    method_matcher = matchers.FunctionDef(
        name=matchers.Name(value="MakeModal"),
        params=matchers.Parameters(
            params=[
                matchers.Param(name=matchers.Name(value="self")),
                matchers.ZeroOrMore(),
            ]
        ),
    )
    # A call of the form ``self.MakeModal(...)``.
    call_matcher = matchers.Call(
        func=matchers.Attribute(
            value=matchers.Name(value="self"),
            attr=matchers.Name(value="MakeModal"),
        )
    )

    # Helper method appended to classes that need it.
    method_cst = cst.parse_statement(
        textwrap.dedent(
            """
            def MakeModal(self, modal=True):
                if modal and not hasattr(self, '_disabler'):
                    self._disabler = wx.WindowDisabler(self)
                if not modal and hasattr(self, '_disabler'):
                    del self._disabler
            """
        )
    )

    def __init__(self, context: CodemodContext):
        super().__init__(context)
        # Classes currently being traversed, innermost last.
        self.stack: List[cst.ClassDef] = []
        # Parallel to ``stack``: True once a ``self.MakeModal(...)`` call was
        # seen directly inside the corresponding class.
        self._needs_make_modal: List[bool] = []

    def visit_ClassDef(self, node: cst.ClassDef) -> None:
        """Start tracking a class."""
        self.stack.append(node)
        self._needs_make_modal.append(False)

    def leave_ClassDef(
        self, original_node: cst.ClassDef, updated_node: cst.ClassDef
    ) -> cst.ClassDef:
        """Stop tracking a class and append the helper method if it is required.

        The result is always derived from ``updated_node`` so that changes
        made to children (for example a helper added to a nested class) are
        preserved.
        """
        self.stack.pop()
        if not self._needs_make_modal.pop():
            return updated_node

        statements = updated_node.body.body
        already_defined = any(
            matchers.matches(statement, self.method_matcher)
            for statement in statements
        )
        if already_defined:
            return updated_node

        return updated_node.with_changes(
            body=updated_node.body.with_changes(body=[*statements, self.method_cst])
        )

    def leave_Call(self, original_node: cst.Call, updated_node: cst.Call) -> cst.Call:
        """Flag the innermost enclosing class when ``self.MakeModal(...)`` is called.

        Calls found outside of any class are ignored: there is no class to add
        the helper to. The call itself is never modified.
        """
        if self.stack and matchers.matches(updated_node, self.call_matcher):
            self._needs_make_modal[-1] = True
        return updated_node
class TestVisitor(MatcherDecoratableTransformer):
    """Transformer recording strings inside ``foo`` matched with a params wildcard."""

    def __init__(self) -> None:
        super().__init__()
        self.visits: List[str] = []

    @call_if_inside(
        m.FunctionDef(m.Name("foo"), params=m.Parameters([m.ZeroOrMore()]))
    )
    def visit_SimpleString(self, node: cst.SimpleString) -> None:
        """Record the string's source text."""
        raw = node.value
        self.visits.append(raw)
class TestVisitor(MatcherDecoratableTransformer):
    """Transformer recording function names and, inside ``foo``, string literals."""

    def __init__(self) -> None:
        super().__init__()
        self.str_visits: List[str] = []
        self.func_visits: List[str] = []

    @call_if_inside(m.FunctionDef(m.Name("foo")))
    def visit_SimpleString(self, node: cst.SimpleString) -> None:
        """Record the string's source text."""
        raw = node.value
        self.str_visits.append(raw)

    def visit_FunctionDef(self, node: cst.FunctionDef) -> None:
        """Record the name of every visited function."""
        function_name = node.name.value
        self.func_visits.append(function_name)
def leave_FunctionDef(
    self, original_node: FunctionDef, updated_node: FunctionDef
) -> Union[BaseStatement, FlattenSentinel[BaseStatement], RemovalSentinel]:
    """Append an ``obj=None`` parameter to two-parameter ``has_add_permission``.

    The rewrite applies only while ``self.is_visiting_subclass`` is truthy;
    every other function is delegated to the parent implementation.
    """
    needs_obj_param = (
        self.is_visiting_subclass
        and m.matches(
            updated_node, m.FunctionDef(name=m.Name("has_add_permission"))
        )
        and len(updated_node.params.params) == 2
    )
    if not needs_obj_param:
        return super().leave_FunctionDef(original_node, updated_node)

    current = updated_node.params
    extended = current.with_changes(
        params=(*current.params, parse_param("obj=None"))
    )
    return updated_node.with_changes(params=extended)
class DefaultFunctionReturnTypeCommand(VisitorBasedCodemodCommand):
    """Codemod annotating functions that lack a return annotation with ``-> None``."""

    DESCRIPTION = "Adds a default return type of None for functions without a return type"

    # Matches function definitions whose ``returns`` field is None.
    matcher = matchers.FunctionDef(returns=None)

    def leave_FunctionDef(
        self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef
    ) -> cst.FunctionDef:
        """Return the function with a ``None`` return annotation when it has none."""
        if not matchers.matches(updated_node, self.matcher):
            return updated_node
        none_annotation = cst.Annotation(cst.Name(value="None"))
        return updated_node.with_changes(returns=none_annotation)
def leave_FunctionDef(
    self, original_node: FunctionDef, updated_node: FunctionDef
) -> Union[BaseStatement, RemovalSentinel]:
    """Append an ``obj=None`` parameter to two-parameter ``has_add_permission``.

    The rewrite applies only when ``self._is_context_right`` is truthy; every
    other function is delegated to the parent implementation.
    """
    should_extend = (
        m.matches(updated_node, m.FunctionDef(name=m.Name("has_add_permission")))
        and self._is_context_right
        and len(updated_node.params.params) == 2
    )
    if should_extend:
        current = updated_node.params
        obj_param = Param(name=Name("obj"), default=Name("None"))
        extended = current.with_changes(params=(*current.params, obj_param))
        return updated_node.with_changes(params=extended)
    return super().leave_FunctionDef(original_node, updated_node)
class TestVisitor(MatcherDecoratableTransformer):
    """Transformer with three leave hooks that rewrite strings inside ``bar``.

    Two hooks are matcher-decorated (adding a prefix and a suffix) and one is
    the regular ``leave_SimpleString`` (reversing the text).
    """

    @call_if_inside(m.FunctionDef(m.Name("bar")))
    @leave(m.SimpleString())
    def leave_string1(
        self, original_node: cst.SimpleString, updated_node: cst.SimpleString
    ) -> cst.SimpleString:
        """Prepend ``prefix`` to the string's contents."""
        new_value = f'"prefix{literal_eval(updated_node.value)}"'
        return updated_node.with_changes(value=new_value)

    @call_if_inside(m.FunctionDef(m.Name("bar")))
    @leave(m.SimpleString())
    def leave_string2(
        self, original_node: cst.SimpleString, updated_node: cst.SimpleString
    ) -> cst.SimpleString:
        """Append ``suffix`` to the string's contents."""
        new_value = f'"{literal_eval(updated_node.value)}suffix"'
        return updated_node.with_changes(value=new_value)

    @call_if_inside(m.FunctionDef(m.Name("bar")))
    def leave_SimpleString(
        self, original_node: cst.SimpleString, updated_node: cst.SimpleString
    ) -> cst.SimpleString:
        """Reverse the string's contents."""
        new_value = f'"{"".join(reversed(literal_eval(updated_node.value)))}"'
        return updated_node.with_changes(value=new_value)
def _has_testnode(node: cst.Module) -> bool:
    """Return whether the module body has a ``test_*`` function or ``Test*`` class.

    Only the statements directly in the module's ``body`` are inspected.
    """
    test_function = m.FunctionDef(
        name=m.Name(value=m.MatchIfTrue(lambda value: value.startswith("test_")))
    )
    test_class = m.ClassDef(
        name=m.Name(value=m.MatchIfTrue(lambda value: value.startswith("Test")))
    )
    one_or_more_tests = m.AtLeastN(n=1, matcher=m.OneOf(test_function, test_class))

    # Sequence matchers never match partial sequences implicitly, so the
    # pattern has to describe the whole body: anything, then at least one
    # test node, then anything.
    module_pattern = m.Module(
        body=[m.ZeroOrMore(), one_or_more_tests, m.ZeroOrMore()]
    )
    return m.matches(node, module_pattern)
def test_complex_matcher_false(self) -> None:
    """Matchers that differ from ``foo(1, 2, 3)`` in some respect must not match."""

    def _foo_call() -> cst.Call:
        # Build a fresh ``foo(1, 2, 3)`` node for every assertion.
        return cst.Call(
            func=cst.Name("foo"),
            args=(
                cst.Arg(cst.Integer("1")),
                cst.Arg(cst.Integer("2")),
                cst.Arg(cst.Integer("3")),
            ),
        )

    # Fail to match since this is a Call, not a FunctionDef.
    self.assertFalse(matches(_foo_call(), m.FunctionDef()))

    # Fail to match a function named "bar".
    self.assertFalse(matches(_foo_call(), m.Call(m.Name("bar"))))

    # Fail to match a function named "foo" with two arguments.
    two_args = m.Call(func=m.Name("foo"), args=(m.Arg(), m.Arg()))
    self.assertFalse(matches(_foo_call(), two_args))

    # Fail to match a function named "foo" with three integer arguments
    # 3, 2, 1.
    reversed_args = m.Call(
        func=m.Name("foo"),
        args=(
            m.Arg(m.Integer("3")),
            m.Arg(m.Integer("2")),
            m.Arg(m.Integer("1")),
        ),
    )
    self.assertFalse(matches(_foo_call(), reversed_args))

    # Fail to match a function named "foo" with three arguments, the last one
    # being the integer 1.
    last_is_one = m.Call(
        func=m.Name("foo"),
        args=(m.DoNotCare(), m.DoNotCare(), m.Arg(m.Integer("1"))),
    )
    self.assertFalse(matches(_foo_call(), last_is_one))
exc=some_version_of("tornado.gen.Return")) gen_return_call_with_args_matcher = m.Raise(exc=m.Call( func=some_version_of("tornado.gen.Return"), args=[m.AtLeastN(n=1)])) gen_return_call_matcher = m.Raise(exc=m.Call( func=some_version_of("tornado.gen.Return"))) gen_return_matcher = gen_return_statement_matcher | gen_return_call_matcher gen_sleep_matcher = m.Call(func=some_version_of("gen.sleep")) gen_task_matcher = m.Call(func=some_version_of("gen.Task")) gen_coroutine_decorator_matcher = m.Decorator( decorator=some_version_of("tornado.gen.coroutine")) gen_test_coroutine_decorator = m.Decorator( decorator=some_version_of("tornado.testing.gen_test")) coroutine_decorator_matcher = (gen_coroutine_decorator_matcher | gen_test_coroutine_decorator) gen_handler_methods_matcher = m.FunctionDef( name=m.OneOf(m.Name('get_page'), m.Name('post_page'), m.Name('put_page'), m.Name('delete_page'))) coroutine_matcher = (m.FunctionDef( asynchronous=None, decorators=[ m.ZeroOrMore(), (coroutine_decorator_matcher | gen_handler_methods_matcher), m.ZeroOrMore() ], ) | gen_handler_methods_matcher) class TransformError(Exception): """ Error raised upon encountering a known error while attempting to transform the tree.
class Modernizer(m.MatcherDecoratableTransformer):
    """Transformer rewriting Python 2 idioms and maintaining the imports they need.

    Rewrites performed by the ``fix_*`` hooks:

    - ``filter`` / ``map`` / ``zip`` / ``range`` calls are wrapped in
      ``list(...)`` unless already consumed by ``list``/``set``/``tuple``/
      ``.join``, a comprehension or a ``for`` loop, or imported from
      ``builtins``.
    - ``xrange`` becomes ``range`` and ``raw_input`` becomes ``input``.
    - ``d.iterkeys()`` / ``d.itervalues()`` / ``d.iteritems()`` become
      ``iterkeys(d)`` etc.
    - ``.keys()`` / ``.values()`` / ``.items()`` calls are wrapped in
      ``list(...)`` under the same context rule as above.
    - The name ``unicode`` becomes ``text_type``.

    Names that must be imported from ``builtins`` or ``future.utils`` are
    collected during traversal and added to the module in ``leave_Module``.
    With ``verbose`` set, parameters and assigned variables are printed
    together with the type found in their ``# type:`` comment.
    """

    METADATA_DEPENDENCIES = (PositionProvider,)

    # FIXME use a stack of e.g. SimpleStatementLine then proper visit_Import/ImportFrom to store the ssl node

    def __init__(
        self, path: Path, verbose: bool = False, ignored: Optional[List[str]] = None
    ):
        """Store the configuration and reset all import bookkeeping.

        Args:
            path: File being processed; only used to prefix verbose output.
            verbose: When true, print parameter/variable type information.
            ignored: Optional list of names, stored as a set in ``self.ignored``.
        """
        super().__init__()
        self.path = path
        self.verbose = verbose
        self.ignored = set(ignored or [])
        self.errors = False
        self.stack: List[Tuple[str, ...]] = []
        self.annotations: Dict[
            Tuple[str, ...], Comment  # key: tuple of canonical variable name
        ] = {}

        # For each tracked module: the names it already imports
        # (name -> alias), the names that still have to be imported, and the
        # updated statement line holding the existing import, if any.
        self.python_future_updated_node: Optional[SimpleStatementLine] = None
        self.python_future_imports: Dict[str, str] = {}
        self.python_future_new_imports: Set[str] = set()
        self.builtins_imports: Dict[str, str] = {}
        self.builtins_new_imports: Set[str] = set()
        self.builtins_updated_node: Optional[SimpleStatementLine] = None
        self.future_utils_imports: Dict[str, str] = {}
        self.future_utils_new_imports: Set[str] = set()
        self.future_utils_updated_node: Optional[SimpleStatementLine] = None

        # Last statement line containing an import that is not nested in a
        # class, function or ``if``.
        self.last_import_node_stmt: Optional[CSTNode] = None

    @m.call_if_inside(m.ImportFrom(module=m.Name("builtins")))
    @m.visit(m.ImportAlias() | m.ImportStar())
    def import_builtins_check(self, node: Union[ImportAlias, ImportStar]) -> None:
        """Record a name imported by ``from builtins import ...``."""
        self.add_import(self.builtins_imports, node)

    @m.call_if_inside(
        m.ImportFrom(module=m.Attribute(value=m.Name("future"), attr=m.Name("utils")))
    )
    @m.visit(m.ImportAlias() | m.ImportStar())
    def import_future_utils_check(self, node: Union[ImportAlias, ImportStar]) -> None:
        """Record a name imported by ``from future.utils import ...``."""
        self.add_import(self.future_utils_imports, node)

    @staticmethod
    def add_import(
        imports: Dict[str, str], node: Union[ImportAlias, ImportStar]
    ) -> None:
        """Store an imported name in ``imports``.

        An alias is stored as ``name -> alias`` (``name -> name`` without an
        ``as`` clause); a star import is stored as ``"*" -> "*"``.
        """
        if isinstance(node, ImportAlias):
            imports[node.name.value] = (
                node.asname.name.value if node.asname else node.name.value
            )
        else:
            imports["*"] = "*"

    @m.call_if_not_inside(m.ClassDef() | m.FunctionDef() | m.If())
    def visit_SimpleStatementLine(self, node: SimpleStatementLine) -> Optional[bool]:
        """Remember the statement line if it contains an import."""
        for n in node.body:
            if m.matches(n, m.Import() | m.ImportFrom()):
                self.last_import_node_stmt = node
        return None

    @m.call_if_not_inside(m.ClassDef() | m.FunctionDef() | m.If())
    def leave_SimpleStatementLine(
        self, original_node: SimpleStatementLine, updated_node: SimpleStatementLine
    ) -> Union[BaseStatement, RemovalSentinel]:
        """Remember the updated line of each tracked ``from ... import`` statement."""
        for n in updated_node.body:
            if m.matches(n, m.ImportFrom(module=m.Name("__future__"))):
                self.python_future_updated_node = updated_node
            elif m.matches(n, m.ImportFrom(module=m.Name("builtins"))):
                self.builtins_updated_node = updated_node
            elif m.matches(
                n,
                m.ImportFrom(
                    module=m.Attribute(value=m.Name("future"), attr=m.Name("utils"))
                ),
            ):
                self.future_utils_updated_node = updated_node
        return updated_node

    def visit_Param(self, node: Param) -> Optional[bool]:
        """In verbose mode, print the parameter and the type from its type comment."""

        class Visitor(m.MatcherDecoratableVisitor):
            def __init__(self):
                super().__init__()
                self.ptype: Optional[str] = None

            def visit_TrailingWhitespace_comment(
                self, node: "TrailingWhitespace"
            ) -> None:
                if node.comment and "type:" in node.comment.value:
                    mo = re.match(r"#\s*type:\s*(\S*)", node.comment.value)
                    self.ptype = mo.group(1) if mo else None
                return None

        v = Visitor()
        node.visit(v)
        if self.verbose:
            pos = self.get_metadata(PositionProvider, node).start
            print(
                f"{self.path}:{pos.line}:{pos.column}: parameter {node.name.value}: {v.ptype or 'unknown type'}"
            )
        return None

    @m.visit(m.SimpleStatementLine())
    def visit_simple_stmt(self, node: SimpleStatementLine) -> None:
        """In verbose mode, print assigned names with the line's ``# type:`` comment.

        The type is read from the trailing comment of the statement line
        itself; names without a parseable type comment are reported as
        ``unknown type``.
        """
        # Keep the last Assign among the children, as the original loop did.
        assign: Optional[Assign] = None
        for c in node.children:
            if m.matches(c, m.Assign()):
                assign = ensure_type(c, Assign)
        if assign is None or not self.verbose:
            # Printing is this hook's only effect.
            return

        vtype: Optional[str] = None
        comment = node.trailing_whitespace.comment
        if comment is not None and "type:" in comment.value:
            mo = re.match(r"#\s*type:\s*(\S*)", comment.value)
            if mo:
                vtype = mo.group(1)

        class NameVisitor(m.MatcherDecoratableVisitor):
            def __init__(self):
                super().__init__()
                self.names: List[str] = []

            def visit_Name(self, node: Name) -> Optional[bool]:
                self.names.append(node.value)
                return None

        pos = self.get_metadata(PositionProvider, node).start
        for target in assign.targets:
            v = NameVisitor()
            target.visit(v)
            for name in v.names:
                print(
                    f"{self.path}:{pos.line}:{pos.column}: variable {name}: {vtype or 'unknown type'}"
                )

    def visit_FunctionDef_body(self, node: FunctionDef) -> None:
        """Walk the function looking at empty-line comments (currently a no-op)."""

        class Visitor(m.MatcherDecoratableVisitor):
            def __init__(self):
                super().__init__()

            def visit_EmptyLine_comment(self, node: "EmptyLine") -> None:
                # FIXME too many matches on test_param_02
                if not node.comment:
                    return
                # TODO: use comment.value
                return None

        v = Visitor()
        node.visit(v)
        return None

    map_matcher = m.Call(
        func=m.Name("filter") | m.Name("map") | m.Name("zip") | m.Name("range")
    )

    @m.visit(map_matcher)
    def visit_map(self, node: Call) -> None:
        """Schedule a ``builtins`` import for the called function if it is missing."""
        func_name = ensure_type(node.func, Name).value
        if func_name not in self.builtins_imports:
            self.builtins_new_imports.add(func_name)

    @m.call_if_not_inside(
        m.Call(
            func=m.Name("list")
            | m.Name("set")
            | m.Name("tuple")
            | m.Attribute(attr=m.Name("join"))
        )
        | m.CompFor()
        | m.For()
    )
    @m.leave(map_matcher)
    def fix_map(self, original_node: Call, updated_node: Call) -> BaseExpression:
        """Wrap the call in ``list(...)`` unless its name is imported from ``builtins``."""
        # TODO test with CompFor etc.
        # TODO improve join test
        func_name = ensure_type(updated_node.func, Name).value
        if func_name not in self.builtins_imports:
            updated_node = Call(func=Name("list"), args=[Arg(updated_node)])
        return updated_node

    @m.visit(m.Call(func=m.Name("xrange") | m.Name("raw_input")))
    def visit_xrange(self, node: Call) -> None:
        """Schedule a ``builtins`` import for the replacement of ``xrange``/``raw_input``."""
        orig_func_name = ensure_type(node.func, Name).value
        func_name = "range" if orig_func_name == "xrange" else "input"
        if func_name not in self.builtins_imports:
            self.builtins_new_imports.add(func_name)

    @m.leave(m.Call(func=m.Name("xrange") | m.Name("raw_input")))
    def fix_xrange(self, original_node: Call, updated_node: Call) -> BaseExpression:
        """Rename ``xrange`` to ``range`` and ``raw_input`` to ``input``."""
        orig_func_name = ensure_type(updated_node.func, Name).value
        func_name = "range" if orig_func_name == "xrange" else "input"
        return updated_node.with_changes(func=Name(func_name))

    iter_matcher = m.Call(
        func=m.Attribute(
            attr=m.Name("iterkeys") | m.Name("itervalues") | m.Name("iteritems")
        )
    )

    @m.visit(iter_matcher)
    def visit_iter(self, node: Call) -> None:
        """Schedule a ``future.utils`` import for the called ``iter*`` helper if missing."""
        func_name = ensure_type(node.func, Attribute).attr.value
        if func_name not in self.future_utils_imports:
            self.future_utils_new_imports.add(func_name)

    @m.leave(iter_matcher)
    def fix_iter(self, original_node: Call, updated_node: Call) -> BaseExpression:
        """Turn ``obj.iterX()`` into ``iterX(obj)``."""
        attribute = ensure_type(updated_node.func, Attribute)
        func_name = attribute.attr
        dict_name = attribute.value
        return updated_node.with_changes(func=func_name, args=[Arg(dict_name)])

    not_iter_matcher = m.Call(
        func=m.Attribute(attr=m.Name("keys") | m.Name("values") | m.Name("items"))
    )

    @m.call_if_not_inside(
        m.Call(
            func=m.Name("list")
            | m.Name("set")
            | m.Name("tuple")
            | m.Attribute(attr=m.Name("join"))
        )
        | m.CompFor()
        | m.For()
    )
    @m.leave(not_iter_matcher)
    def fix_not_iter(self, original_node: Call, updated_node: Call) -> BaseExpression:
        """Wrap a ``.keys()`` / ``.values()`` / ``.items()`` call in ``list(...)``."""
        updated_node = Call(func=Name("list"), args=[Arg(updated_node)])
        return updated_node

    @m.call_if_not_inside(m.Import() | m.ImportFrom())
    @m.leave(m.Name(value="unicode"))
    def fix_unicode(self, original_node: Name, updated_node: Name) -> BaseExpression:
        """Rename ``unicode`` to ``text_type`` and schedule its ``future.utils`` import."""
        value = "text_type"
        if value not in self.future_utils_imports:
            self.future_utils_new_imports.add(value)
        return updated_node.with_changes(value=value)

    def leave_Module(self, original_node: Module, updated_node: Module) -> Module:
        """Add the scheduled ``builtins`` and ``future.utils`` imports to the module."""
        updated_node = self.update_imports(
            original_node,
            updated_node,
            "builtins",
            self.builtins_updated_node,
            self.builtins_imports,
            self.builtins_new_imports,
            True,
        )
        updated_node = self.update_imports(
            original_node,
            updated_node,
            "future.utils",
            self.future_utils_updated_node,
            self.future_utils_imports,
            self.future_utils_new_imports,
            False,
        )
        return updated_node

    def update_imports(
        self,
        original_module: Module,
        updated_module: Module,
        import_name: str,
        updated_import_node: Optional[SimpleStatementLine],
        current_imports: Dict[str, str],
        new_imports: Set[str],
        noqa: bool,
    ) -> Module:
        """Make sure ``new_imports`` are imported from ``import_name``.

        Args:
            original_module: Module before transformation; used to locate the
                last import statement by identity.
            updated_module: Module after transformation; the returned module
                is derived from it.
            import_name: Dotted module name to import from.
            updated_import_node: Existing ``from import_name import ...``
                statement line, or None if the module has none.
            current_imports: Names already imported (name -> alias).
            new_imports: Names that must additionally be imported.
            noqa: Whether to append a ``# noqa`` comment to the import.

        Returns:
            ``updated_module`` unchanged when nothing has to be imported or
            when the existing import is a star import; otherwise a module
            with a new or extended import statement.
        """
        if not new_imports:
            return updated_module
        noqa_comment = " # noqa" if noqa else ""
        if not updated_import_node:
            # No existing import: insert a new statement after the last
            # import seen, or at the top (i == -1) followed by blank lines.
            i = -1
            blank_lines = "\n\n"
            if self.last_import_node_stmt:
                blank_lines = ""
                for i, (original, updated) in enumerate(
                    zip(original_module.body, updated_module.body)
                ):
                    if original is self.last_import_node_stmt:
                        break
            stmt = parse_module(
                f"from {import_name} import {', '.join(sorted(new_imports))}{noqa_comment}\n{blank_lines}",
                config=updated_module.config_for_parsing,
            )
            body = list(updated_module.body)
            self.last_import_node_stmt = stmt
            return updated_module.with_changes(
                body=[*body[: i + 1], *stmt.children, *body[i + 1 :]]
            )
        else:
            if "*" not in current_imports:
                # Rebuild the existing import with old and new names merged.
                current_imports_set = {
                    f"{k}" if k == v else f"{k} as {v}"
                    for k, v in current_imports.items()
                }
                stmt = parse_statement(
                    f"from {import_name} import {', '.join(sorted(new_imports | current_imports_set))}{noqa_comment}"
                )
                return updated_module.deep_replace(updated_import_node, stmt)
        # A star import already provides every name.
        return updated_module
gen_return_call_with_args_matcher = m.Raise(exc=m.Call( func=some_version_of("tornado.gen.Return"), args=[m.AtLeastN(n=1)])) gen_return_call_matcher = m.Raise(exc=m.Call( func=some_version_of("tornado.gen.Return"))) gen_return_matcher = gen_return_statement_matcher | gen_return_call_matcher gen_sleep_matcher = m.Call(func=some_version_of("gen.sleep")) gen_task_matcher = m.Call(func=some_version_of("gen.Task")) gen_coroutine_decorator_matcher = m.Decorator( decorator=some_version_of("tornado.gen.coroutine")) gen_test_coroutine_decorator = m.Decorator( decorator=some_version_of("tornado.testing.gen_test")) coroutine_decorator_matcher = (gen_coroutine_decorator_matcher | gen_test_coroutine_decorator) coroutine_matcher = m.FunctionDef( asynchronous=None, decorators=[m.ZeroOrMore(), coroutine_decorator_matcher, m.ZeroOrMore()], ) class TransformError(Exception): """ Error raised upon encountering a known error while attempting to transform the tree. """ class TornadoAsyncTransformer(cst.CSTTransformer): """ A libcst transformer that replaces the legacy @gen.coroutine/yield async syntax with the python3.7 native async/await syntax.
class ConvertTypeComments(VisitorBasedCodemodCommand):
    """Codemod command that turns PEP 484 type comments into annotations.

    Handles type comments on assignments, ``for`` / ``with`` headers and
    function definitions. See ``DESCRIPTION`` for the user-facing details.
    """

    DESCRIPTION = """
    Codemod that converts type comments into Python 3.6+ style
    annotations.

    Notes:
    - This transform requires using the `ast` module, which is not compatible
      with multiprocessing. So you should run using a recent version of python,
      and set `--jobs=1` if using `python -m libcst.tool codemod ...` from the
      commandline.
    - This transform requires capabilities from `ast` that are not available
      prior to Python 3.9, so libcst must run on Python 3.9+. The code you are
      transforming can by Python 3.6+, this limitation applies only to libcst
      itself.

    We can handle type comments in the following statement types:
    - Assign
      - This is converted into a single AnnAssign when possible
      - In more complicated cases it will produce multiple AnnAssign
        nodes with no value (i.e. "type declaration" statements)
        followed by an Assign
    - For and With
      - We prepend both of these with type declaration statements.
    - FunctionDef
      - We apply all the types we can find. If we find several:
        - We prefer any existing annotations to type comments
        - For parameters, we prefer inline type comments to
          function-level type comments if we find both.

    We always apply the type comments as quote_annotations annotations, unless
    we know that it refers to a builtin. We do not guarantee that
    the resulting string annotations would parse, but they should
    never cause failures at module import time.

    We attempt to:
    - Always strip type comments for statements where we successfully
      applied types.
    - Never strip type comments for statements where we failed to
      apply types.

    There are many edge case possible where the arity of a type
    hint (which is either a tuple or a func_type) might not match
    the code. In these cases we generally give up:
    - For Assign, For, and With, we require that every target of
      bindings (e.g. a tuple of names being bound) must have exactly
      the same arity as the comment.
      - So, for example, we would skip an assignment statement such as
        ``x = y, z = 1, 2 # type: int, int`` because the arity
        of ``x`` does not match the arity of the hint.
    - For FunctionDef, we do *not* check arity of inline parameter
      type comments but we do skip the transform if the arity of
      the function does not match the function-level comment.
    """

    # Finding the location of a type comment in a FunctionDef is difficult.
    #
    # As a result, if when visiting a FunctionDef header we are able to
    # successfully extract type information then we aggressively strip type
    # comments until we reach the first statement in the body.
    #
    # Once we get there we have to stop, so that we don't unintentionally remove
    # unprocessed type comments.
    #
    # This state handles tracking everything we need for this.
    function_type_info_stack: List[FunctionTypeInfo]
    function_body_stack: List[cst.BaseSuite]
    aggressively_strip_type_comments: bool

    @staticmethod
    def add_args(arg_parser: argparse.ArgumentParser) -> None:
        """Register the ``--no-quote-annotations`` command line flag."""
        arg_parser.add_argument(
            "--no-quote-annotations",
            action="store_true",
            help=(
                "Add unquoted annotations. This leads to prettier code "
                + "but possibly more errors if type comments are invalid."
            ),
        )

    def __init__(
        self,
        context: CodemodContext,
        no_quote_annotations: bool = False,
    ) -> None:
        """Set up the codemod.

        Args:
            context: The codemod context passed on to the base class.
            no_quote_annotations: When true, annotations are requested
                unquoted (``quote_annotations`` is set to the negation).

        Raises:
            NotImplementedError: If libcst itself runs on Python < 3.9.
        """
        if sys.version_info < (3, 9):
            # The ast module did not get `unparse` until Python 3.9,
            # or `type_comments` until Python 3.8
            #
            # For earlier versions of python, raise early instead of failing
            # later. It might be possible to use libcst parsing and the
            # typed_ast library to support earlier python versions, but this is
            # not a high priority.
            raise NotImplementedError(
                "You are trying to run ConvertTypeComments, but libcst "
                + "needs to be running with Python 3.9+ in order to "
                + "do this. Try using Python 3.9+ to run your codemod. "
                + "Note that the target code can be using Python 3.6+, "
                + "it is only libcst that needs a new Python version."
            )
        super().__init__(context)
        # flags used to control overall behavior
        self.quote_annotations: bool = not no_quote_annotations
        # state used to manage how we traverse nodes in various contexts
        self.function_type_info_stack = []
        self.function_body_stack = []
        self.aggressively_strip_type_comments = False

    def _strip_TrailingWhitespace(
        self,
        node: cst.TrailingWhitespace,
    ) -> cst.TrailingWhitespace:
        """Return ``node`` without its comment and the whitespace before it."""
        return node.with_changes(
            # any whitespace came before the comment, so strip it.
            whitespace=cst.SimpleWhitespace(""),
            comment=None,
        )

    def leave_SimpleStatementLine(
        self,
        original_node: cst.SimpleStatementLine,
        updated_node: cst.SimpleStatementLine,
    ) -> Union[cst.SimpleStatementLine, cst.FlattenSentinel]:
        """
        Convert any SimpleStatementLine containing an Assign with a
        type comment into a one that uses a PEP 526 AnnAssign.
        """
        # determine whether to apply an annotation
        assign = updated_node.body[-1]
        if not isinstance(assign, cst.Assign):  # only Assign matters
            return updated_node
        annotation = _annotation_for_statement(original_node)
        if annotation is None:
            return updated_node
        # At this point have a single-line Assign with a type comment.
        # Convert it to an AnnAssign and strip the comment.
        converted = convert_Assign(
            node=assign,
            annotation=annotation,
            quote_annotations=self.quote_annotations,
        )
        if isinstance(converted, _FailedToApplyAnnotation):
            # We were unable to consume the type comment, so return the
            # original code unchanged.
            # TODO: allow stripping the invalid type comments via a flag
            return updated_node
        elif isinstance(converted, cst.AnnAssign):
            # We were able to convert the Assign into an AnnAssign, so
            # we can update the node.
            return updated_node.with_changes(
                body=[*updated_node.body[:-1], converted],
                trailing_whitespace=self._strip_TrailingWhitespace(
                    updated_node.trailing_whitespace,
                ),
            )
        elif isinstance(converted, list):
            # We need to inject two or more type declarations.
            #
            # In this case, we need to split across multiple lines, and
            # this also means we'll spread any multi-statement lines out
            # (multi-statement lines are PEP 8 violating anyway).
            #
            # We still preserve leading lines from before our transform.
            new_statements = [
                *(
                    statement.with_changes(
                        semicolon=cst.MaybeSentinel.DEFAULT,
                    )
                    for statement in updated_node.body[:-1]
                ),
                *converted,
            ]
            if len(new_statements) < 2:
                raise RuntimeError("Unreachable code.")
            return cst.FlattenSentinel(
                [
                    # The first statement reuses the original line, so it
                    # keeps the leading lines; only the comment is dropped.
                    updated_node.with_changes(
                        body=[new_statements[0]],
                        trailing_whitespace=self._strip_TrailingWhitespace(
                            updated_node.trailing_whitespace,
                        ),
                    ),
                    *(
                        cst.SimpleStatementLine(body=[statement])
                        for statement in new_statements[1:]
                    ),
                ]
            )
        else:
            raise RuntimeError(f"Unhandled value {converted}")

    def leave_For(
        self,
        original_node: cst.For,
        updated_node: cst.For,
    ) -> Union[cst.For, cst.FlattenSentinel]:
        """
        Convert a For with a type hint on the bound variable(s) to
        use type declarations.
        """
        # Type comments are only possible when the body is an indented
        # block, and we need this refinement to work with the header,
        # so we check and only then extract the type comment.
        body = updated_node.body
        if not isinstance(body, cst.IndentedBlock):
            return updated_node
        annotation = _annotation_for_statement(original_node)
        if annotation is None:
            return updated_node
        # Zip up the type hint and the bindings. If we hit an arity
        # error, abort.
        try:
            type_declarations = AnnotationSpreader.type_declaration_statements(
                bindings=AnnotationSpreader.unpack_target(updated_node.target),
                annotations=AnnotationSpreader.unpack_annotation(annotation),
                leading_lines=updated_node.leading_lines,
                quote_annotations=self.quote_annotations,
            )
        except _ArityError:
            return updated_node
        # There is no arity error, so we can add the type declaration(s).
        # The leading lines were handed to the declarations above, so the
        # loop itself must not repeat them.
        return cst.FlattenSentinel(
            [
                *type_declarations,
                updated_node.with_changes(
                    body=body.with_changes(
                        header=self._strip_TrailingWhitespace(body.header)
                    ),
                    leading_lines=[],
                ),
            ]
        )

    def leave_With(
        self,
        original_node: cst.With,
        updated_node: cst.With,
    ) -> Union[cst.With, cst.FlattenSentinel]:
        """
        Convert a With with a type hint on the bound variable(s) to
        use type declarations.
        """
        # Type comments are only possible when the body is an indented
        # block, and we need this refinement to work with the header,
        # so we check and only then extract the type comment.
        body = updated_node.body
        if not isinstance(body, cst.IndentedBlock):
            return updated_node
        annotation = _annotation_for_statement(original_node)
        if annotation is None:
            return updated_node
        # PEP 484 does not attempt to specify type comment semantics for
        # multiple with bindings (there's more than one sensible way to
        # do it), so we make no attempt to handle this
        targets = [
            item.asname.name for item in updated_node.items if item.asname is not None
        ]
        if len(targets) != 1:
            return updated_node
        target = targets[0]
        # Zip up the type hint and the bindings. If we hit an arity
        # error, abort.
        try:
            type_declarations = AnnotationSpreader.type_declaration_statements(
                bindings=AnnotationSpreader.unpack_target(target),
                annotations=AnnotationSpreader.unpack_annotation(annotation),
                leading_lines=updated_node.leading_lines,
                quote_annotations=self.quote_annotations,
            )
        except _ArityError:
            return updated_node
        # There is no arity error, so we can add the type declaration(s).
        # The leading lines were handed to the declarations above, so the
        # `with` itself must not repeat them.
        return cst.FlattenSentinel(
            [
                *type_declarations,
                updated_node.with_changes(
                    body=body.with_changes(
                        header=self._strip_TrailingWhitespace(body.header)
                    ),
                    leading_lines=[],
                ),
            ]
        )

    # Handle function definitions -------------------------

    # **Implementation Notes**
    #
    # It is much harder to predict where exactly type comments will live
    # in function definitions than in Assign / For / With.
    #
    # As a result, we use two different patterns:
    # (A) we aggressively strip out type comments from whitespace between the
    #     start of a function define and the start of the body, whenever we were
    #     able to extract type information. This is done via mutable state and the
    #     usual visitor pattern.
    # (B) we also manually reach down to the first statement inside of the
    #     function body and aggressively strip type comments from leading
    #     whitespaces
    #
    # PEP 484 underspecifies how to apply type comments to (non-static)
    # methods - it would be possible to provide a type for `self`, or to omit
    # it. So we accept either approach when interpreting type comments on
    # non-static methods: the first argument can have a type provided or not.

    def _visit_FunctionDef(
        self,
        node: cst.FunctionDef,
        is_method: bool,
    ) -> None:
        """
        Set up the data we need to handle function definitions:
        - Parse the type comments.
        - Store the resulting function type info on the stack, where it
          will remain until we use it in `leave_FunctionDef`
        - Set that we are aggressively stripping type comments, which will
          remain true until we visit the body.
        """
        function_type_info = FunctionTypeInfo.from_cst(node, is_method=is_method)
        self.aggressively_strip_type_comments = not function_type_info.is_empty()
        self.function_type_info_stack.append(function_type_info)
        self.function_body_stack.append(node.body)

    # NOTE: the names of the two hooks below are the opposite of what they
    # handle. `visit_method` only runs outside of any ClassDef (plain
    # functions) and `visit_function` only runs inside a ClassDef (methods).
    # The names are kept as-is so that the class interface stays unchanged.

    @m.call_if_not_inside(m.ClassDef())
    @m.visit(m.FunctionDef())
    def visit_method(
        self,
        node: cst.FunctionDef,
    ) -> None:
        """Handle a FunctionDef that is not nested in a class (never a method)."""
        return self._visit_FunctionDef(
            node=node,
            is_method=False,
        )

    @m.call_if_inside(m.ClassDef())
    @m.visit(m.FunctionDef())
    def visit_function(
        self,
        node: cst.FunctionDef,
    ) -> None:
        """Handle a FunctionDef inside a class.

        It is treated as a method unless one of its decorators is the plain
        name ``staticmethod``.
        """
        return self._visit_FunctionDef(
            node=node,
            is_method=not any(
                m.matches(d.decorator, m.Name("staticmethod")) for d in node.decorators
            ),
        )

    def leave_TrailingWhitespace(
        self,
        original_node: cst.TrailingWhitespace,
        updated_node: cst.TrailingWhitespace,
    ) -> cst.TrailingWhitespace:
        """Aggressively remove type comments when in header if we extracted types."""
        if self.aggressively_strip_type_comments and _is_type_comment(
            updated_node.comment
        ):
            return cst.TrailingWhitespace()
        else:
            return updated_node

    def leave_EmptyLine(
        self,
        original_node: cst.EmptyLine,
        updated_node: cst.EmptyLine,
    ) -> Union[cst.EmptyLine, cst.RemovalSentinel]:
        """Aggressively remove type comments when in header if we extracted types."""
        if self.aggressively_strip_type_comments and _is_type_comment(
            updated_node.comment
        ):
            return cst.RemovalSentinel.REMOVE
        else:
            return updated_node

    def visit_FunctionDef_body(
        self,
        node: cst.FunctionDef,
    ) -> None:
        """Turn off aggressive type comment removal when we've left the header."""
        self.aggressively_strip_type_comments = False

    def leave_IndentedBlock(
        self,
        original_node: cst.IndentedBlock,
        updated_node: cst.IndentedBlock,
    ) -> cst.IndentedBlock:
        """When appropriate, strip function type comment from the function body."""
        # abort unless this is the body of a function we are transforming
        if not self.function_body_stack:
            return updated_node
        if original_node is not self.function_body_stack[-1]:
            return updated_node
        if self.function_type_info_stack[-1].is_empty():
            return updated_node
        # The comment will be in the body header if it was on the same line
        # as the colon.
        if _is_type_comment(updated_node.header.comment):
            updated_node = updated_node.with_changes(
                header=cst.TrailingWhitespace(),
            )
        # The comment will be in a leading line of the first body statement
        # if it was on the first line after the colon.
        first_statement = updated_node.body[0]
        if not hasattr(first_statement, "leading_lines"):
            return updated_node
        return updated_node.with_changes(
            body=[
                first_statement.with_changes(
                    leading_lines=[
                        line  # pyre-ignore[16]: we refined via `hasattr`
                        for line in first_statement.leading_lines
                        if not _is_type_comment(line.comment)
                    ]
                ),
                *updated_node.body[1:],
            ]
        )

    # Methods for adding type annotations ----
    #
    # By the time we get here, all type comments should already be stripped.

    def leave_Param(
        self,
        original_node: cst.Param,
        updated_node: cst.Param,
    ) -> cst.Param:
        """Annotate a parameter from the enclosing function's type comments."""
        # ignore type comments if there's already an annotation
        if updated_node.annotation is not None:
            return updated_node
        # find out if there's a type comment and apply it if so
        function_type_info = self.function_type_info_stack[-1]
        raw_annotation = function_type_info.arguments.get(updated_node.name.value)
        if raw_annotation is not None:
            return updated_node.with_changes(
                annotation=_convert_annotation(
                    raw=raw_annotation,
                    quote_annotations=self.quote_annotations,
                )
            )
        else:
            return updated_node

    def leave_FunctionDef(
        self,
        original_node: cst.FunctionDef,
        updated_node: cst.FunctionDef,
    ) -> cst.FunctionDef:
        """Pop the per-function state and add a return annotation if we can.

        An existing return annotation is never overwritten.
        """
        self.function_body_stack.pop()
        function_type_info = self.function_type_info_stack.pop()
        if updated_node.returns is None and function_type_info.returns is not None:
            return updated_node.with_changes(
                returns=_convert_annotation(
                    raw=function_type_info.returns,
                    quote_annotations=self.quote_annotations,
                )
            )
        else:
            return updated_node

    def visit_Lambda(
        self,
        node: cst.Lambda,
    ) -> bool:
        """
        Disable traversing under lambdas. They don't have any statements
        nested inside them so there's no need, and they do have Params which
        we don't want to transform.
        """
        return False