Exemple #1
0
    def get(
        wrapper: MetadataWrapper,
        original_node: cst.CSTNode,
        replacement_node: Union[cst.CSTNode, cst.RemovalSentinel],
    ) -> "LintPatch":
        """
        Build a ``LintPatch`` that swaps ``original_node`` for ``replacement_node``.

        The patch is generated for the closest enclosing node that the reentrant
        codegen provider has an entry for. If no such ancestor exists, the patch
        replaces the code of the whole module instead.

        Raises ``Exception`` if the request is to remove the module itself.
        """
        # Run the three providers as one batch.
        wrapper.resolve_many(
            [
                ParentNodeProvider,
                ExperimentalReentrantCodegenProvider,
                WhitespaceInclusivePositionProvider,
            ]
        )

        # The data is read back through resolve() because its typing is better
        # than that of resolve_many().
        parent_of = wrapper.resolve(ParentNodeProvider)
        position_of = wrapper.resolve(WhitespaceInclusivePositionProvider)
        partial_for = wrapper.resolve(ExperimentalReentrantCodegenProvider)

        is_removal = isinstance(replacement_node, cst.RemovalSentinel)
        if is_removal and isinstance(original_node, cst.Module):
            raise Exception("Removing the entire module is not possible")

        # Reentrant codegen rewrites whole statements only, so start from the node
        # itself (or from its parent for a removal, since RemovalSentinel is not
        # supported there) and climb until a node with a codegen partial is found.
        candidate = parent_of[original_node] if is_removal else original_node
        while candidate not in partial_for:
            if candidate not in parent_of:
                # Ran out of ancestors: regenerate the entire module.
                module_before = wrapper.module.code
                new_module = cst.ensure_type(
                    _replace_or_remove(wrapper.module, original_node, replacement_node),
                    cst.Module,
                )
                module_after = new_module.code
                return LintPatch(0, CodePosition(1, 0), module_before, module_after)
            candidate = parent_of[candidate]

        statement_partial = partial_for[candidate]
        new_statement = cst.ensure_type(
            _replace_or_remove(candidate, original_node, replacement_node),
            cst.BaseStatement,
        )
        statement_before = statement_partial.get_original_statement_code()
        statement_after = statement_partial.get_modified_statement_code(new_statement)
        return LintPatch(
            statement_partial.start_offset,
            position_of[candidate].start,
            statement_before,
            statement_after,
        )
    def test_metadata_cache(self) -> None:
        """
        Checks that a provider defining ``gen_cache`` fails to resolve when the
        wrapper was given no cache, and that a supplied cache value is the one
        the provider receives and records as metadata.
        """

        class DummyMetadataProvider(BatchableMetadataProvider[None]):
            gen_cache = tuple

        module = cst.parse_module("pass")
        wrapper = MetadataWrapper(module)

        # This wrapper has no cache, so resolving the provider is expected to raise.
        missing_cache_message = (
            "Cache is required for initializing DummyMetadataProvider."
        )
        with self.assertRaisesRegex(Exception, missing_cache_message):
            wrapper.resolve(DummyMetadataProvider)

        class SimpleCacheMetadataProvider(BatchableMetadataProvider[object]):
            gen_cache = tuple

            def __init__(self, cache: object) -> None:
                super().__init__(cache)
                self.cache = cache

            def visit_Pass(self, node: cst.Pass) -> Optional[bool]:
                # Record the cache object itself as the metadata of the pass node.
                self.set_metadata(node, self.cache)

        sentinel = object()
        # pyre-fixme[6]: Expected `Mapping[Type[BaseMetadataProvider[object]],
        #  object]` for 2nd param but got `Dict[Type[SimpleCacheMetadataProvider],
        #  object]`.
        wrapper = MetadataWrapper(module, cache={SimpleCacheMetadataProvider: sentinel})
        first_line = cst.ensure_type(wrapper.module.body[0], cst.SimpleStatementLine)
        pass_node = first_line.body[0]
        resolved = wrapper.resolve(SimpleCacheMetadataProvider)
        self.assertEqual(resolved[pass_node], sentinel)
Exemple #3
0
def get_scope_metadata_provider(
    module_str: str,
) -> Tuple[cst.Module, Mapping[cst.CSTNode, Scope]]:
    """
    Parse the dedented ``module_str`` and return the wrapper's module together
    with the ``ScopeProvider`` metadata resolved for it.
    """
    source = dedent(module_str)
    wrapper = MetadataWrapper(cst.parse_module(source))
    module = wrapper.module
    # The cast pins the mapping's value type to Scope: every node is expected to
    # have an associated scope.
    scopes = cast(Mapping[cst.CSTNode, Scope], wrapper.resolve(ScopeProvider))
    return module, scopes
Exemple #4
0
    def test_resolve_provider_twice(self) -> None:
        """
        Resolving one provider twice on the same wrapper must visit only once.
        """
        visit_recorder = Mock()

        class ProviderA(VisitorMetadataProvider[bool]):
            def visit_Pass(self, node: cst.Pass) -> None:
                visit_recorder.visited_a()

        wrapper = MetadataWrapper(cst.parse_module("pass"))

        # After the first and after the second resolve, exactly one visit is recorded.
        for _ in range(2):
            wrapper.resolve(ProviderA)
            visit_recorder.visited_a.assert_called_once()
Exemple #5
0
    def test_function_position(self) -> None:
        """
        Checks the positions reported for the ``pass`` statement line and the
        ``pass`` node inside a one-line function body.
        """
        wrapper = MetadataWrapper(parse_module("def foo():\n    pass"))
        tree = wrapper.module
        position_of = wrapper.resolve(PositionProvider)

        function_def = cast(cst.FunctionDef, tree.body[0])
        statement_line = cast(cst.SimpleStatementLine, function_def.body.body[0])
        pass_node = cast(cst.Pass, statement_line.body[0])

        # The statement line and the pass node are expected to share one range.
        for node in (statement_line, pass_node):
            self.cmp_position(position_of[node], (2, 4), (2, 8))
def get_qualified_name_metadata_provider(
    module_str: str,
) -> Tuple[cst.Module, Mapping[cst.CSTNode, Collection[QualifiedName]]]:
    """
    Parse the dedented ``module_str`` and return the wrapper's module together
    with the ``QualifiedNameProvider`` metadata resolved for it.
    """
    source = dedent(module_str)
    wrapper = MetadataWrapper(cst.parse_module(source))
    module = wrapper.module
    # The cast only narrows the static type of the resolved mapping.
    qualified_names = cast(
        Mapping[cst.CSTNode, Collection[QualifiedName]],
        wrapper.resolve(QualifiedNameProvider),
    )
    return module, qualified_names
Exemple #7
0
    def test_multiline_string_position(self) -> None:
        """
        Checks the positions reported for two string literals joined across a
        backslash line continuation.
        """
        wrapper = MetadataWrapper(parse_module('"abc"\\\n"def"'))
        tree = wrapper.module
        position_of = wrapper.resolve(PositionProvider)

        statement_line = cast(cst.SimpleStatementLine, tree.body[0])
        expression = cast(cst.Expr, statement_line.body[0])
        string_value = expression.value

        # The statement, the expression and its value are all expected to span
        # from (1, 0) to (2, 5).
        for node in (statement_line, expression, string_value):
            self.cmp_position(position_of[node], (1, 0), (2, 5))
 def test_byte_conversion(self) -> None:
     """
     Checks that the codegen partial returns module bytes in the encoding
     given to the parser config (utf-16), both for the original module and
     after substituting a new statement.
     """
     expected_original = "fn()\n".encode("utf-16")
     parser_config = cst.PartialParserConfig(encoding="utf-16")
     wrapper = MetadataWrapper(cst.parse_module("fn()\n", parser_config))

     partials = wrapper.resolve(ExperimentalReentrantCodegenProvider)
     first_statement_partial = partials[wrapper.module.body[0]]
     self.assertEqual(
         first_statement_partial.get_original_module_bytes(), expected_original
     )

     replacement = cst.parse_statement("fn2()\n")
     self.assertEqual(
         first_statement_partial.get_modified_module_bytes(replacement),
         "fn2()\n".encode("utf-16"),
     )
    def test_provider(
        self,
        old_module: str,
        new_module: str,
        old_node: Callable[[cst.Module], cst.CSTNode],
        new_node: cst.BaseStatement,
    ) -> None:
        """
        Checks that the codegen partial for the node picked by ``old_node``
        reproduces the dedented ``old_module`` unchanged, and yields the dedented
        ``new_module`` once ``new_node`` is substituted.
        """
        expected_original = dedent(old_module)
        expected_modified = dedent(new_module)

        wrapper = MetadataWrapper(cst.parse_module(expected_original))
        partials = wrapper.resolve(ExperimentalReentrantCodegenProvider)
        # old_node selects the target node out of the wrapper's module.
        target_partial = partials[old_node(wrapper.module)]

        self.assertEqual(target_partial.get_original_module_code(), expected_original)
        self.assertEqual(
            target_partial.get_modified_module_code(new_node), expected_modified
        )
Exemple #10
0
    def test_nested_indent_position(self) -> None:
        """
        Checks the positions reported for nodes at several indentation levels of
        a nested ``if`` / ``else`` module.
        """
        source = "if True:\n    if False:\n        x = 1\nelse:\n    return"
        wrapper = MetadataWrapper(parse_module(source))
        tree = wrapper.module
        position_of = wrapper.resolve(PositionProvider)

        outer_if = cast(cst.If, tree.body[0])
        inner_if = cast(cst.If, outer_if.body.body[0])
        assign = cast(cst.SimpleStatementLine, inner_if.body.body[0]).body[0]

        outer_else = cast(cst.Else, outer_if.orelse)
        return_stmt = cast(cst.SimpleStatementLine, outer_else.body.body[0]).body[0]

        # (node, expected start, expected end), checked in this order.
        expectations = (
            (outer_if, (1, 0), (5, 10)),
            (inner_if, (2, 4), (3, 13)),
            (assign, (3, 8), (3, 13)),
            (outer_else, (4, 0), (5, 10)),
            (return_stmt, (5, 4), (5, 10)),
        )
        for node, start, end in expectations:
            self.cmp_position(position_of[node], start, end)
Exemple #11
0
    def test_module_position(self, *, code: str, expected: CodeRange) -> None:
        """
        Checks that the position resolved for the module parsed from ``code``
        equals ``expected``.
        """
        wrapper = MetadataWrapper(parse_module(code))
        module_range = wrapper.resolve(PositionProvider)[wrapper.module]
        self.assertEqual(module_range, expected)
Exemple #12
0
def get_qualified_name_metadata_provider(
    module_str: str,
) -> Tuple[cst.Module, Mapping[cst.CSTNode, Collection[QualifiedName]]]:
    """
    Parse the dedented ``module_str`` and return the wrapper's module together
    with the ``QualifiedNameProvider`` metadata resolved for it.
    """
    parsed = cst.parse_module(dedent(module_str))
    wrapper = MetadataWrapper(parsed)
    module = wrapper.module
    qualified_names = wrapper.resolve(QualifiedNameProvider)
    return module, qualified_names