예제 #1
0
    def get_metadata_wrapper_for_path(self, path: str) -> MetadataWrapper:
        """
        Create a :class:`~libcst.metadata.MetadataWrapper` given a source file path.
        The path needs to be a path relative to project root directory.
        The source code is read and parsed as :class:`~libcst.Module` for
        :class:`~libcst.metadata.MetadataWrapper`.

        The file is read as raw bytes and handed to :func:`libcst.parse_module`,
        so the source encoding is determined from the file itself (PEP 263
        coding declaration, otherwise UTF-8) rather than from the platform's
        locale encoding.

        .. code-block:: python

            manager = FullRepoManager(".", {"a.py", "b.py"}, {TypeInferenceProvider})
            wrapper = manager.get_metadata_wrapper_for_path("a.py")
        """
        # ``read_text()`` without an explicit encoding decodes with the locale's
        # preferred encoding (e.g. cp1252 on Windows), which breaks UTF-8
        # sources and ignores coding cookies. Bytes let libcst detect it.
        source = (self.root_path / path).read_bytes()
        module = cst.parse_module(source)
        cache = self.get_cache_for_path(path)
        # NOTE(review): the positional ``True`` is presumably MetadataWrapper's
        # ``unsafe_skip_copy`` flag (the module was just parsed and is not
        # shared) — confirm against the MetadataWrapper signature.
        return MetadataWrapper(module, True, cache)
예제 #2
0
    def test_circular_dependency(self) -> None:
        """
        A provider that lists itself among its own dependencies must be
        rejected with a MetadataException once a visitor requires it.
        """

        class ProviderA(VisitorMetadataProvider[str]):
            pass

        # Close a one-node cycle: ProviderA depends on ProviderA.
        ProviderA.METADATA_DEPENDENCIES = (ProviderA,)

        class BadVisitor(CSTTransformer):
            METADATA_DEPENDENCIES = (ProviderA,)

        expected_message = "Detected circular dependencies in ProviderA"
        with self.assertRaisesRegex(MetadataException, expected_message):
            wrapper = MetadataWrapper(cst.Module([]))
            wrapper.visit(BadVisitor())
예제 #3
0
 def _find_line_range_for_function_call(
         self, file_contents: str, line_num_1idx: int) -> Tuple[int, int]:
     """Return the line span of the narrowest function call covering a line.

     Parses ``file_contents``, collects every function call together with its
     source range via ``_FunctionCallFinder``, keeps the calls whose range
     contains ``line_num_1idx``, and picks the one spanning the fewest lines
     (the first such call on ties).

     Args:
         file_contents: Python source text to search.
         line_num_1idx: 1-indexed line number the call must cover.

     Returns:
         ``(start, end)`` line numbers of the chosen call, 0-indexed and
         inclusive.

     Raises:
         ValueError: if no function call covers ``line_num_1idx``.
     """
     tree = libcst.parse_module(file_contents)
     function_call_finder = _FunctionCallFinder()
     MetadataWrapper(tree).visit(function_call_finder)
     # Only the ranges matter for choosing the span; the nodes are unused.
     candidate_ranges = [
         node_range
         for _node, node_range in function_call_finder.function_calls
         if node_range.start.line <= line_num_1idx <= node_range.end.line
     ]
     if not candidate_ranges:
         # Fail with a descriptive message instead of min()'s generic
         # "arg is an empty sequence" error; still a ValueError for callers.
         raise ValueError(
             f"No function call found spanning line {line_num_1idx}")
     # Fewest covered lines wins; min() keeps the first candidate on ties.
     node_range = min(
         candidate_ranges,
         key=lambda rng: rng.end.line - rng.start.line,
     )
     # Range lines are 1-indexed; convert both ends to 0-indexed (inclusive).
     return (node_range.start.line - 1, node_range.end.line - 1)
예제 #4
0
def visit_batched(
        node: CSTNodeT,
        visitors: Iterable[BatchableCSTVisitor],
        before_visit: Optional[VisitorMethod] = None,
        after_leave: Optional[VisitorMethod] = None,
        use_compatible: bool = True,  # TODO: remove this
) -> CSTNodeT:
    """
    Returns the result of running all visitors [visitors] over [node].

    [before_visit] and [after_leave] are provided as optional hooks to
    execute before visit_* and after leave_* methods are executed by the
    batched visitor.

    [visitors] may be any iterable, including a one-shot generator: it is
    materialized into a list up front because the compatibility path
    iterates it twice (once to resolve metadata, once to batch).

    When [use_compatible] is true and [node] is a Module, the node is wrapped
    in a MetadataWrapper, every visitor's metadata is resolved against that
    wrapper, and the wrapper's module is visited.
    """
    # Materialize once: a generator would otherwise be exhausted by the
    # resolve loop below, leaving make_batched with no visitors.
    visitors = list(visitors)

    # TODO: remove compatibility hack
    if use_compatible:
        from libcst._nodes._module import Module

        if isinstance(node, Module):
            from contextlib import ExitStack
            from libcst.metadata.wrapper import MetadataWrapper

            wrapper = MetadataWrapper(node)
            with ExitStack() as stack:
                # Resolve dependencies of visitors; these contexts stay open
                # until the batched visit below has returned.
                for v in visitors:
                    stack.enter_context(v.resolve(wrapper))

                batched_visitor = make_batched(visitors, before_visit,
                                               after_leave)
                # Metadata was resolved against `wrapper`, so visit its module.
                return cast(
                    CSTNodeT,
                    wrapper.module.visit(batched_visitor,
                                         use_compatible=False),
                )

        batched_visitor = make_batched(visitors, before_visit, after_leave)
        return cast(CSTNodeT, node.visit(batched_visitor))
    # end compatible

    batched_visitor = make_batched(visitors, before_visit, after_leave)
    return cast(CSTNodeT, node.visit(batched_visitor, use_compatible=False))
예제 #5
0
    def test_visitor_provider(self) -> None:
        """
        Check that visitor-based providers, including one that depends on
        another, are resolved before the consuming visitor runs.

        Every node receives two metadata values:
            SimpleProvider    -> 1
            DependentProvider -> 2  (SimpleProvider's value + 1)
        """
        case = self

        class SimpleProvider(VisitorMetadataProvider[int]):
            def on_visit(self, node: cst.CSTNode) -> bool:
                self.set_metadata(node, 1)
                return True

        class DependentProvider(VisitorMetadataProvider[int]):
            METADATA_DEPENDENCIES = (SimpleProvider,)

            def on_visit(self, node: cst.CSTNode) -> bool:
                base = self.get_metadata(SimpleProvider, node)
                self.set_metadata(node, base + 1)
                return True

        def _assert_both(visitor: CSTTransformer, node: cst.CSTNode) -> None:
            # Both providers must already have populated this node.
            case.assertEqual(visitor.get_metadata(SimpleProvider, node), 1)
            case.assertEqual(visitor.get_metadata(DependentProvider, node), 2)

        class DependentVisitor(CSTTransformer):
            # Both providers are declared so this visitor may read both kinds
            # of metadata.
            METADATA_DEPENDENCIES = (DependentProvider, SimpleProvider)

            def visit_Module(self, node: cst.Module) -> None:
                _assert_both(self, node)

            def visit_Pass(self, node: cst.Pass) -> None:
                _assert_both(self, node)

        MetadataWrapper(parse_module("pass")).visit(DependentVisitor())
예제 #6
0
    def test_undeclared_metadata(self) -> None:
        """
        Reading metadata from a provider the visitor never declared as a
        dependency must raise KeyError naming the provider and the visitor.
        """

        class ProviderA(VisitorMetadataProvider[bool]):
            pass

        class ProviderB(VisitorMetadataProvider[bool]):
            pass

        class AVisitor(CSTTransformer):
            # Only ProviderA is declared; ProviderB is deliberately omitted.
            METADATA_DEPENDENCIES = (ProviderA,)

            def on_visit(self, node: cst.CSTNode) -> bool:
                # Declared provider: allowed, with True as the fallback value.
                self.get_metadata(ProviderA, node, True)
                # Undeclared provider: expected to raise.
                self.get_metadata(ProviderB, node)
                return True

        expected = "ProviderB is not declared as a dependency from AVisitor"
        with self.assertRaisesRegex(KeyError, expected):
            MetadataWrapper(cst.Module([])).visit(AVisitor())
예제 #7
0
    def test_batched_provider(self) -> None:
        """
        Check that batchable providers are resolved and that each one visits
        the tree exactly once.

        Metadata placed on the ``pass`` node:
            BatchedProviderA -> 1
            BatchedProviderB -> "a"
        """
        case = self
        calls = Mock()

        class BatchedProviderA(BatchableMetadataProvider[int]):
            def visit_Pass(self, node: cst.Pass) -> None:
                calls.visited_a()
                self.set_metadata(node, 1)

        class BatchedProviderB(BatchableMetadataProvider[str]):
            def visit_Pass(self, node: cst.Pass) -> None:
                calls.visited_b()
                self.set_metadata(node, "a")

        class DependentVisitor(CSTTransformer):
            METADATA_DEPENDENCIES = (BatchedProviderA, BatchedProviderB)

            def visit_Pass(self, node: cst.Pass) -> None:
                expectations = ((BatchedProviderA, 1), (BatchedProviderB, "a"))
                for provider, value in expectations:
                    case.assertEqual(self.get_metadata(provider, node), value)

        wrapper = MetadataWrapper(parse_module("pass"))
        wrapper.visit(DependentVisitor())

        # Batching must not make either provider handle the node twice.
        for recorded in (calls.visited_a, calls.visited_b):
            recorded.assert_called_once()
 def test_simple_load(self) -> None:
     """Check, via DependentVisitor, that the bare name ``a`` maps to LOAD."""
     tree = parse_module("a")
     wrapper = MetadataWrapper(tree)
     expected_contexts = {"a": ExpressionContext.LOAD}
     wrapper.visit(
         DependentVisitor(test=self, name_to_context=expected_contexts))
 def test_invalid_type_for_context(self) -> None:
     """
     Check, via DependentVisitor, that the callee name ``a`` in ``a()`` maps
     to a LOAD expression context.
     """
     # NOTE(review): the test name hints that it guards against some node in
     # the call expression getting an unexpected context — confirm intent.
     call_source = parse_module("a()")
     wrapper = MetadataWrapper(call_source)
     expected_contexts = {"a": ExpressionContext.LOAD}
     wrapper.visit(
         DependentVisitor(test=self, name_to_context=expected_contexts))
예제 #10
0
    def test_mixed_providers(self) -> None:
        """
        Check that a dependency graph mixing visitor and batchable providers
        is resolved correctly, with every provider running exactly once.

        Metadata placed on the ``pass`` node:
            SimpleProvider    -> 1
            BatchedProviderA  -> 2
            BatchedProviderB  -> 3
            BatchedProviderC  -> 4  (BatchedProviderA * 2)
            DependentProvider -> 5  (BatchedProviderA + BatchedProviderB)
        """
        case = self
        calls = Mock()

        class SimpleProvider(VisitorMetadataProvider[int]):
            def visit_Pass(self, node: cst.CSTNode) -> None:
                calls.visited_simple()
                self.set_metadata(node, 1)

        class BatchedProviderA(BatchableMetadataProvider[int]):
            METADATA_DEPENDENCIES = (SimpleProvider,)

            def visit_Pass(self, node: cst.Pass) -> None:
                calls.visited_a()
                self.set_metadata(node, 2)

        class BatchedProviderB(BatchableMetadataProvider[int]):
            METADATA_DEPENDENCIES = (SimpleProvider,)

            def visit_Pass(self, node: cst.Pass) -> None:
                calls.visited_b()
                self.set_metadata(node, 3)

        class DependentProvider(VisitorMetadataProvider[int]):
            METADATA_DEPENDENCIES = (BatchedProviderA, BatchedProviderB)

            def on_visit(self, node: cst.CSTNode) -> bool:
                # A and B only annotate ``pass``; other nodes fall back to 0.
                from_a = self.get_metadata(BatchedProviderA, node, 0)
                from_b = self.get_metadata(BatchedProviderB, node, 0)
                self.set_metadata(node, from_a + from_b)
                return True

        class BatchedProviderC(BatchableMetadataProvider[int]):
            METADATA_DEPENDENCIES = (BatchedProviderA,)

            def visit_Pass(self, node: cst.Pass) -> None:
                calls.visited_c()
                doubled = self.get_metadata(BatchedProviderA, node) * 2
                self.set_metadata(node, doubled)

        class DependentVisitor(CSTTransformer):
            METADATA_DEPENDENCIES = (
                BatchedProviderA,
                BatchedProviderB,
                BatchedProviderC,
                DependentProvider,
            )

            def visit_Module(self, node: cst.Module) -> None:
                # DependentProvider annotates every node, but for the module
                # both of its terms fall back to 0 (A/B only set ``pass``).
                case.assertEqual(self.get_metadata(DependentProvider, node), 0)

            def visit_Pass(self, node: cst.Pass) -> None:
                expected = (
                    (BatchedProviderA, 2),
                    (BatchedProviderB, 3),
                    (BatchedProviderC, 4),
                    (DependentProvider, 5),
                )
                for provider, value in expected:
                    case.assertEqual(self.get_metadata(provider, node), value)

        MetadataWrapper(parse_module("pass")).visit(DependentVisitor())

        # Every provider, batched or not, must have handled ``pass`` once.
        for recorded in (
            calls.visited_simple,
            calls.visited_a,
            calls.visited_b,
            calls.visited_c,
        ):
            recorded.assert_called_once()