def test_replace_sequence_extract(self) -> None:
    """Check that ``m.replace`` passes captured sequences to the replacement callback.

    The ``params`` sequence of a ``FunctionDef`` is captured with
    ``SaveMatchedNode``. The replacement callback rebuilds the function with
    those captured parameters in reverse order.
    """

    def _reverse_params(
        node: cst.CSTNode,
        extraction: Dict[str, Union[cst.CSTNode, Sequence[cst.CSTNode]]],
    ) -> cst.CSTNode:
        function_def = cst.ensure_type(node, cst.FunctionDef)
        # pyre-ignore We know "params" is a Sequence[Parameters] but asserting that
        # to pyre is difficult.
        flipped_params = cst.Parameters(params=list(reversed(extraction["params"])))
        return function_def.with_changes(params=flipped_params)

    source_code = "def bar(baz: int, foo: int, ) -> int:\n return baz + foo\n"
    expected_code = "def bar(foo: int, baz: int, ) -> int:\n return baz + foo\n"

    # Verify that we can still extract sequences with replace.
    parsed = cst.parse_module(source_code)
    params_capture = m.SaveMatchedNode([m.ZeroOrMore(m.Param())], "params")
    function_pattern = m.FunctionDef(params=m.Parameters(params=params_capture))
    transformed = m.replace(parsed, function_pattern, _reverse_params)
    actual_code = cst.ensure_type(transformed, cst.Module).code

    self.assertEqual(actual_code, expected_code)
class MakeModalCommand(VisitorBasedCodemodCommand):
    """Add a ``MakeModal`` helper method to classes that call ``self.MakeModal()``.

    Every class that contains a ``self.MakeModal(...)`` call gets the helper
    in ``method_cst`` appended as the last statement of its body, unless the
    body already directly defines ``def MakeModal(self, ...)``. A call is
    attributed to the innermost class that encloses it.
    """

    DESCRIPTION: str = "Replace built-in method MakeModal with helper"

    # Matches ``def MakeModal(self, ...)``: first parameter named ``self``,
    # followed by any number of further parameters.
    method_matcher = matchers.FunctionDef(
        name=matchers.Name(value="MakeModal"),
        params=matchers.Parameters(
            params=[
                matchers.Param(name=matchers.Name(value="self")),
                matchers.ZeroOrMore(),
            ]
        ),
    )

    # Matches a ``self.MakeModal(...)`` call.
    call_matcher = matchers.Call(
        func=matchers.Attribute(
            value=matchers.Name(value="self"),
            attr=matchers.Name(value="MakeModal"),
        )
    )

    # Helper appended to classes that call ``self.MakeModal`` without defining it.
    # NOTE(review): the generated code references ``wx``; this assumes the target
    # module already imports it — confirm.
    method_cst = cst.parse_statement(
        textwrap.dedent(
            """
            def MakeModal(self, modal=True):
                if modal and not hasattr(self, '_disabler'):
                    self._disabler = wx.WindowDisabler(self)
                if not modal and hasattr(self, '_disabler'):
                    del self._disabler
            """
        )
    )

    def __init__(self, context: CodemodContext) -> None:
        super().__init__(context)
        # Original ClassDef nodes currently being visited, innermost last.
        self.stack: List[cst.ClassDef] = []
        # Kept parallel to ``self.stack``: whether a ``self.MakeModal(...)`` call
        # was seen inside the class at that depth.
        self._calls_make_modal: List[bool] = []

    def visit_ClassDef(self, node: cst.ClassDef) -> None:
        """Enter a class: push it and start with no MakeModal call seen."""
        self.stack.append(node)
        self._calls_make_modal.append(False)

    def leave_ClassDef(
        self, original_node: cst.ClassDef, updated_node: cst.ClassDef
    ) -> cst.ClassDef:
        """Append the helper to ``updated_node`` if it calls but lacks ``MakeModal``.

        The result is built from ``updated_node``, not from the node saved on
        entry. That keeps edits already made to children, such as a helper
        added to a nested class.
        """
        self.stack.pop()
        needs_helper = self._calls_make_modal.pop()
        body = updated_node.body
        # A one-line suite (``class A: ...``) can only hold simple statements,
        # so a ``def`` cannot be appended to it; leave such classes untouched.
        if not needs_helper or not isinstance(body, cst.IndentedBlock):
            return updated_node
        if any(matchers.matches(stmt, self.method_matcher) for stmt in body.body):
            return updated_node
        return updated_node.with_changes(
            body=body.with_changes(body=[*body.body, self.method_cst])
        )

    def leave_Call(
        self, original_node: cst.Call, updated_node: cst.Call
    ) -> cst.Call:
        """Flag the innermost enclosing class when the call is ``self.MakeModal(...)``.

        The call node itself is returned unchanged. The class is patched in
        ``leave_ClassDef``.
        """
        # A call outside any class (e.g. in a module-level function that takes
        # ``self``) has no class to patch, so it is ignored instead of crashing
        # on an empty stack.
        if self.stack and matchers.matches(updated_node, self.call_matcher):
            self._calls_make_modal[-1] = True
        return updated_node
class TestVisitor(MatcherDecoratableTransformer):
    """Collect the ``value`` of SimpleString nodes visited inside a ``foo`` function."""

    def __init__(self) -> None:
        super().__init__()
        # ``value`` of each visited SimpleString, in visit order.
        self.visits: List[str] = []

    # Gated by ``call_if_inside``: the visitor is only invoked within a
    # FunctionDef named ``foo`` whose parameters match ``[ZeroOrMore()]``.
    @call_if_inside(
        m.FunctionDef(
            name=m.Name(value="foo"),
            params=m.Parameters(params=[m.ZeroOrMore()]),
        )
    )
    def visit_SimpleString(self, node: cst.SimpleString) -> None:
        """Record ``node.value`` in ``self.visits``."""
        literal_text = node.value
        self.visits.append(literal_text)