예제 #1
0
    def test_replace_sequence_extract(self) -> None:
        """Check that a replacement callable receives extracted sequences.

        The matcher saves the function's parameters under the key
        ``"params"``; the callable rebuilds the function with those
        parameters in reverse order, and the generated code is compared
        against the expected text.
        """

        def _reverse_params(
            node: cst.CSTNode,
            extraction: Dict[str, Union[cst.CSTNode, Sequence[cst.CSTNode]]],
        ) -> cst.CSTNode:
            # Replacement callable: same function, saved parameters flipped.
            function_def = cst.ensure_type(node, cst.FunctionDef)
            # pyre-ignore We know "params" is a Sequence[Parameters] but asserting that
            # to pyre is difficult.
            flipped = list(reversed(extraction["params"]))
            return function_def.with_changes(
                params=cst.Parameters(params=flipped))

        source = "def bar(baz: int, foo: int, ) -> int:\n    return baz + foo\n"
        expected = (
            "def bar(foo: int, baz: int, ) -> int:\n    return baz + foo\n")

        # Verify that we can still extract sequences with replace.
        module = cst.parse_module(source)
        saved_params = m.SaveMatchedNode([m.ZeroOrMore(m.Param())], "params")
        function_matcher = m.FunctionDef(
            params=m.Parameters(params=saved_params))
        result = m.replace(module, function_matcher, _reverse_params)

        self.assertEqual(cst.ensure_type(result, cst.Module).code, expected)
예제 #2
0
class MakeModalCommand(VisitorBasedCodemodCommand):
    """Add a ``MakeModal`` helper method to classes that call it on ``self``.

    While a class is being traversed, every ``self.MakeModal(...)`` call
    marks the innermost enclosing class. When that class is left, the
    helper parsed into ``method_cst`` is appended to its body unless the
    body already defines a ``MakeModal(self, ...)`` method.

    The helper is appended to the *updated* class node, so changes made to
    anything inside the class (including helpers added to nested classes)
    are preserved rather than replaced by the node captured on entry.
    """

    DESCRIPTION: str = "Replace built-in method MAkeModal with helper"

    # Matches ``def MakeModal(self, ...)`` with any further parameters.
    method_matcher = matchers.FunctionDef(
        name=matchers.Name(value="MakeModal"),
        params=matchers.Parameters(params=[
            matchers.Param(name=matchers.Name(value="self")),
            matchers.ZeroOrMore()
        ]),
    )
    # Matches a call whose callee is the attribute ``self.MakeModal``.
    call_matcher = matchers.Call(
        func=matchers.Attribute(value=matchers.Name(value="self"),
                                attr=matchers.Name(value="MakeModal")))

    # Parsed helper method that gets appended to classes needing it.
    method_cst = cst.parse_statement(
        textwrap.dedent("""
            def MakeModal(self, modal=True):
                if modal and not hasattr(self, '_disabler'):
                    self._disabler = wx.WindowDisabler(self)
                if not modal and hasattr(self, '_disabler'):
                    del self._disabler
            """))

    def __init__(self, context: CodemodContext):
        """Initialise the per-class traversal state.

        Args:
            context: Codemod context forwarded to the base command.
        """
        super().__init__(context)

        # Classes currently being traversed, innermost last.
        self.stack: List[cst.ClassDef] = []
        # Parallel to ``stack``: True once a ``self.MakeModal(...)`` call has
        # been seen inside the corresponding class.
        self._calls_make_modal: List[bool] = []

    def visit_ClassDef(self, node: cst.ClassDef) -> None:
        """Start tracking ``node`` as the innermost enclosing class."""
        self.stack.append(node)
        self._calls_make_modal.append(False)

    def leave_ClassDef(self, original_node: cst.ClassDef,
                       updated_node: cst.ClassDef) -> cst.ClassDef:
        """Append the helper method to the class when it is needed.

        Args:
            original_node: The class as it was before traversal.
            updated_node: The class with all changes made to its children.

        Returns:
            ``updated_node``, with ``method_cst`` appended to its body when
            a ``self.MakeModal(...)`` call was seen inside the class and the
            body does not already define a matching method.
        """
        self.stack.pop()
        needs_helper = self._calls_make_modal.pop()
        if not needs_helper:
            return updated_node

        class_body = updated_node.body
        # A one-line suite (``class A: ...``) holds small statements only, so
        # a function definition cannot be appended to it.
        if not isinstance(class_body, cst.IndentedBlock):
            return updated_node

        already_defined = any(
            matchers.matches(statement, self.method_matcher)
            for statement in class_body.body)
        if already_defined:
            return updated_node

        return updated_node.with_changes(
            body=class_body.with_changes(
                body=[*class_body.body, self.method_cst]))

    def leave_Call(self, original_node: cst.Call,
                   updated_node: cst.Call) -> cst.Call:
        """Record ``self.MakeModal(...)`` calls against the enclosing class.

        Calls found outside of any class are ignored instead of failing on
        an empty stack.

        Returns:
            ``updated_node`` unchanged; the call itself is never rewritten.
        """
        inside_class = bool(self._calls_make_modal)
        if inside_class and matchers.matches(updated_node, self.call_matcher):
            self._calls_make_modal[-1] = True

        return updated_node
        # NOTE(review): this class follows a `return` at the same indentation,
        # so as placed here it is unreachable; it looks like the start of a
        # separate example whose header is missing -- confirm.
        # Transformer that records the value of each SimpleString it visits.
        class TestVisitor(MatcherDecoratableTransformer):
            def __init__(self) -> None:
                super().__init__()
                # Values of the visited SimpleString nodes, in visit order.
                self.visits: List[str] = []

            # Gate the visitor with a matcher for a function named "foo" that
            # takes any number of parameters (presumably so that only strings
            # inside such a function are recorded -- see call_if_inside docs).
            @call_if_inside(
                m.FunctionDef(m.Name("foo"),
                              params=m.Parameters([m.ZeroOrMore()])))
            def visit_SimpleString(self, node: cst.SimpleString) -> None:
                self.visits.append(node.value)