def test_accesses(self) -> None:
        """
        Checks that ScopeProvider attributes each read of `foo` to the
        assignment it resolves to, both at module scope and inside a function
        that shadows the module-level name.
        """
        module, scopes = get_scope_metadata_provider("""
            foo = 'toplevel'
            fn1(foo)
            fn2(foo)
            def fn_def():
                foo = 'shadow'
                fn3(foo)
            """)

        def first_call_arg(statement: cst.CSTNode) -> cst.Arg:
            # Unwraps SimpleStatementLine -> Expr -> Call, checking each step
            # with ensure_type, and returns the call's first argument.
            line = ensure_type(statement, cst.SimpleStatementLine)
            expr = ensure_type(line.body[0], cst.Expr)
            call = ensure_type(expr.value, cst.Call)
            return call.args[0]

        # Module level: one assignment of `foo`, read by fn1 and fn2.
        module_scope = scopes[module]
        self.assertIsInstance(module_scope, GlobalScope)
        toplevel_assignments = module_scope["foo"]
        self.assertEqual(len(toplevel_assignments), 1)
        toplevel_foo = toplevel_assignments[0]
        self.assertEqual(len(toplevel_foo.accesses), 2)
        fn1_arg = first_call_arg(module.body[1])
        self.assertEqual(toplevel_foo.accesses[0].node, fn1_arg.value)
        fn2_arg = first_call_arg(module.body[2])
        self.assertEqual(toplevel_foo.accesses[1].node, fn2_arg.value)

        # Function level: the shadowing assignment, read only by fn3.
        fn_block = ensure_type(module.body[3], cst.FunctionDef).body
        fn_scope = scopes[fn_block.body[0]]
        self.assertIsInstance(fn_scope, FunctionScope)
        shadow_assignments = fn_scope["foo"]
        self.assertEqual(len(shadow_assignments), 1)
        shadow_foo = shadow_assignments[0]
        self.assertEqual(len(shadow_foo.accesses), 1)
        fn3_arg = first_call_arg(fn_block.body[1])
        self.assertEqual(shadow_foo.accesses[0].node, fn3_arg.value)

        # These only check that visiting import statements (module level and
        # inside a function) completes without raising.
        for source in ("from a import b\n",
                       "def a():\n    from b import c\n\n"):
            MetadataWrapper(cst.parse_module(source)).visit(DependentVisitor())
    def test_batchable_provider_inherited_metadata(self) -> None:
        """
        Tests that batchable providers inherit access to metadata declared by
        their base classes.

        ProviderC declares no dependencies itself; it reads ProviderA through
        the METADATA_DEPENDENCIES inherited from ProviderB.
        """
        outer = self
        calls = Mock()

        class ProviderA(VisitorMetadataProvider[int]):
            def visit_Pass(self, node: cst.Pass) -> None:
                calls.visited_a()
                self.set_metadata(node, 1)

        class ProviderB(BatchableMetadataProvider[int]):
            METADATA_DEPENDENCIES = (ProviderA,)

        class ProviderC(ProviderB):
            def visit_Pass(self, node: cst.Pass) -> None:
                calls.visited_c()
                outer.assertEqual(self.get_metadata(ProviderA, node), 1)

        class VisitorA(CSTTransformer):
            METADATA_DEPENDENCIES = (ProviderC,)

        MetadataWrapper(parse_module("pass")).visit(VisitorA())

        # ProviderA and ProviderC must each have visited the node exactly once.
        for recorded in (calls.visited_a, calls.visited_c):
            recorded.assert_called_once()
Exemple #3
0
    def test_batchable_provider(self) -> None:
        """
        Checks that _gen_batchable runs a batchable provider and that its
        results are visible both through the provider and in the returned
        mapping.
        """
        class SimpleProvider(BatchableMetadataProvider[int]):
            """
            Sets metadata on every pass node to 1 and every return node to 2.
            """
            def visit_Pass(self, node: cst.Pass) -> None:
                self.set_metadata(node, 1)

            def visit_Return(self, node: cst.Return) -> None:
                self.set_metadata(node, 2)

        wrapper = MetadataWrapper(parse_module("pass; return; pass"))
        statements = cast(cst.SimpleStatementLine, wrapper.module.body[0]).body
        expectations = (
            (statements[0], 1),  # first `pass`
            (statements[1], 2),  # `return`
            (statements[2], 1),  # second `pass`
        )

        provider = SimpleProvider()
        metadata = _gen_batchable(wrapper, [provider])

        # Check access on provider
        for node, value in expectations:
            self.assertEqual(provider.get_metadata(SimpleProvider, node), value)

        # Check returned mapping
        for node, value in expectations:
            self.assertEqual(metadata[SimpleProvider][node], value)
    def test_inherited_metadata(self) -> None:
        """
        Tests that classes inherit access to metadata declared by their base
        classes.

        VisitorB declares nothing itself and reads SimpleProvider through the
        METADATA_DEPENDENCIES inherited from VisitorA.
        """
        outer = self
        calls = Mock()

        class SimpleProvider(VisitorMetadataProvider[int]):
            def visit_Pass(self, node: cst.Pass) -> None:
                calls.visited_simple()
                self.set_metadata(node, 1)

        class VisitorA(CSTTransformer):
            METADATA_DEPENDENCIES = (SimpleProvider,)

        class VisitorB(VisitorA):
            def visit_Pass(self, node: cst.Pass) -> None:
                outer.assertEqual(self.get_metadata(SimpleProvider, node), 1)

        MetadataWrapper(parse_module("pass")).visit(VisitorB())

        # The provider must have visited the node exactly once.
        calls.visited_simple.assert_called_once()
Exemple #5
0
    def test_visitor_provider(self) -> None:
        """
        Checks that running a visitor provider via _gen sets metadata on the
        module and its statements, readable from the provider and from the
        returned mapping.
        """
        class SimpleProvider(VisitorMetadataProvider[int]):
            """
            Sets metadata on every node to 1.
            """
            def on_visit(self, node: cst.CSTNode) -> bool:
                self.set_metadata(node, 1)
                return True

        wrapper = MetadataWrapper(parse_module("pass; return"))
        module = wrapper.module
        line = cast(cst.SimpleStatementLine, module.body[0])
        checked_nodes = (module, line.body[0], line.body[1])

        provider = SimpleProvider()
        metadata = provider._gen(wrapper)

        # Check access on provider
        for node in checked_nodes:
            self.assertEqual(provider.get_metadata(SimpleProvider, node), 1)

        # Check returned mapping
        for node in checked_nodes:
            self.assertEqual(metadata[node], 1)
def get_scope_metadata_provider(
        module_str: str) -> Tuple[cst.Module, Mapping[cst.CSTNode, Scope]]:
    """
    Dedent and parse ``module_str``, then resolve ScopeProvider over it.

    Returns a pair of the wrapper's module and the node -> Scope mapping
    produced by resolving ScopeProvider.
    """
    wrapper = MetadataWrapper(cst.parse_module(dedent(module_str)))
    parsed = wrapper.module
    # The cast narrows the resolved values to Scope: every node is expected to
    # have an associated scope.
    scopes = cast(Mapping[cst.CSTNode, Scope], wrapper.resolve(ScopeProvider))
    return parsed, scopes
 def test_simple_assign(self) -> None:
     """For `a = b`, expects `a` as STORE and `b` as LOAD."""
     wrapper = MetadataWrapper(parse_module("a = b"))
     expected_names = {
         "a": ExpressionContext.STORE,
         "b": ExpressionContext.LOAD,
     }
     wrapper.visit(
         DependentVisitor(test=self, name_to_context=expected_names))
 def test_with_as(self) -> None:
     """For `with a() as b:`, expects `a` as LOAD and the target `b` as STORE."""
     wrapper = MetadataWrapper(parse_module("with a() as b:\n    pass"))
     expected_names = {
         "a": ExpressionContext.LOAD,
         "b": ExpressionContext.STORE,
     }
     wrapper.visit(
         DependentVisitor(test=self, name_to_context=expected_names))
 def test_starred_element_with_assign(self) -> None:
     """
     For `*a = b`, expects both Names as LOAD while the StarredElement
     wrapping `a` is expected as STORE.
     """
     wrapper = MetadataWrapper(parse_module("*a = b"))
     expected_names = {
         "a": ExpressionContext.LOAD,
         "b": ExpressionContext.LOAD,
     }
     expected_starred = {"a": ExpressionContext.STORE}
     wrapper.visit(
         DependentVisitor(
             test=self,
             name_to_context=expected_names,
             starred_element_to_context=expected_starred,
         ))
 def test_del_with_subscript(self) -> None:
     """
     For `del a[b]`, expects both Names as LOAD and the Subscript keyed by
     `a` as DEL.
     """
     wrapper = MetadataWrapper(parse_module("del a[b]"))
     expected_names = {
         "a": ExpressionContext.LOAD,
         "b": ExpressionContext.LOAD,
     }
     expected_subscripts = {"a": ExpressionContext.DEL}
     wrapper.visit(
         DependentVisitor(
             test=self,
             name_to_context=expected_names,
             subscript_to_context=expected_subscripts,
         ))
 def test_del_with_tuple(self) -> None:
     """For `del a, b`, expects both Names and the enclosing Tuple as DEL."""
     wrapper = MetadataWrapper(parse_module("del a, b"))
     expected_names = {
         "a": ExpressionContext.DEL,
         "b": ExpressionContext.DEL,
     }
     expected_tuples = {("a", "b"): ExpressionContext.DEL}
     wrapper.visit(
         DependentVisitor(
             test=self,
             name_to_context=expected_names,
             tuple_to_context=expected_tuples,
         ))
 def test_except_as(self) -> None:
     """
     For `except Exception as ex:`, expects `Exception` as LOAD and the bound
     name `ex` as STORE.
     """
     source = "try:    ...\nexcept Exception as ex:\n    pass"
     wrapper = MetadataWrapper(parse_module(source))
     expected_names = {
         "Exception": ExpressionContext.LOAD,
         "ex": ExpressionContext.STORE,
     }
     wrapper.visit(
         DependentVisitor(test=self, name_to_context=expected_names))
    def test_batchable_provider(self) -> None:
        """
        Tests that a BatchableCSTVisitor run via MetadataWrapper.visit_batched
        can read its declared SyntacticPositionProvider metadata: the lone
        `pass` is expected to span (1, 0) to (1, 4).
        """
        test = self

        class ABatchable(BatchableCSTVisitor):
            METADATA_DEPENDENCIES = (SyntacticPositionProvider,)

            def visit_Pass(self, node: cst.Pass) -> None:
                # Named `code_range` so the builtin `range` is not shadowed.
                code_range = self.get_metadata(SyntacticPositionProvider, node)
                test.assertEqual(code_range, CodeRange((1, 0), (1, 4)))

        wrapper = MetadataWrapper(parse_module("pass"))
        wrapper.visit_batched([ABatchable()])
 def test_list_with_assing(self) -> None:
     """
     For `[a] = [b]`, expects the target list and `a` as STORE, and the value
     list and `b` as LOAD.
     """
     wrapper = MetadataWrapper(parse_module("[a] = [b]"))
     expected_names = {
         "a": ExpressionContext.STORE,
         "b": ExpressionContext.LOAD,
     }
     # List nodes are keyed by the tuple of their element names.
     expected_lists = {
         ("a",): ExpressionContext.STORE,
         ("b",): ExpressionContext.LOAD,
     }
     wrapper.visit(
         DependentVisitor(
             test=self,
             name_to_context=expected_names,
             list_to_context=expected_lists,
         ))
Exemple #15
0
    def get_metadata_wrapper_for_path(self, path: str) -> MetadataWrapper:
        """
        Create a :class:`~libcst.metadata.MetadataWrapper` given a source file path.
        The path needs to be a path relative to project root directory.
        The source code is read and parsed as :class:`~libcst.Module` for
        :class:`~libcst.metadata.MetadataWrapper`.

        .. code-block:: python

            manager = FullRepoManager(".", {"a.py", "b.py"}, {TypeInferenceProvider})
            wrapper = manager.get_metadata_wrapper_for_path("a.py")
        """
        # Read raw bytes and let the parser detect the source encoding (BOM /
        # PEP 263 coding cookie). ``Path.read_text()`` without an explicit
        # encoding decodes with the locale's preferred encoding, which can
        # mis-decode or reject valid UTF-8 Python source on some platforms.
        source = (self.root_path / path).read_bytes()
        module = cst.parse_module(source)
        cache = self.get_cache_for_path(path)
        # NOTE(review): the positional `True` presumably disables
        # MetadataWrapper's defensive copy of the freshly parsed module, and
        # `cache` supplies precomputed metadata for this path — confirm
        # against MetadataWrapper's signature.
        return MetadataWrapper(module, True, cache)
    def test_circular_dependency(self) -> None:
        """
        Tests that circular dependencies are detected.

        ProviderA is made to depend on itself, so visiting with a visitor that
        requires ProviderA is expected to raise MetadataException.
        """
        class ProviderA(VisitorMetadataProvider[str]):
            pass

        # The self-reference can only be assigned once the class exists.
        ProviderA.METADATA_DEPENDENCIES = (ProviderA,)

        class BadVisitor(CSTTransformer):
            METADATA_DEPENDENCIES = (ProviderA,)

        expected_message = "Detected circular dependencies in ProviderA"
        with self.assertRaisesRegex(MetadataException, expected_message):
            MetadataWrapper(cst.Module([])).visit(BadVisitor())
 def test_assign_to_attribute(self) -> None:
     """
     For `a.b = c.d`, expects only the assigned attribute (`b`, and its
     Attribute node) as STORE; every other Name and Attribute is LOAD.
     """
     wrapper = MetadataWrapper(parse_module("a.b = c.d"))
     expected_names = {
         "a": ExpressionContext.LOAD,
         "b": ExpressionContext.STORE,
         "c": ExpressionContext.LOAD,
         "d": ExpressionContext.LOAD,
     }
     expected_attributes = {
         "b": ExpressionContext.STORE,
         "d": ExpressionContext.LOAD,
     }
     wrapper.visit(
         DependentVisitor(
             test=self,
             name_to_context=expected_names,
             attribute_to_context=expected_attributes,
         ))
 def test_assign_with_subscript(self) -> None:
     """
     For `a[b] = c[d]`, expects every Name as LOAD, the target Subscript
     (keyed by `a`) as STORE and the value Subscript (keyed by `c`) as LOAD.
     """
     wrapper = MetadataWrapper(parse_module("a[b] = c[d]"))
     expected_names = {
         "a": ExpressionContext.LOAD,
         "b": ExpressionContext.LOAD,
         "c": ExpressionContext.LOAD,
         "d": ExpressionContext.LOAD,
     }
     expected_subscripts = {
         "a": ExpressionContext.STORE,
         "c": ExpressionContext.LOAD,
     }
     wrapper.visit(
         DependentVisitor(
             test=self,
             name_to_context=expected_names,
             subscript_to_context=expected_subscripts,
         ))
    def test_visitor_provider(self) -> None:
        """
        Tests that a plain CSTTransformer declaring SyntacticPositionProvider
        as a dependency can read node positions: the lone `pass` is expected
        to span (1, 0) to (1, 4).
        """
        test = self

        class DependentVisitor(CSTTransformer):
            METADATA_DEPENDENCIES = (SyntacticPositionProvider,)

            def visit_Pass(self, node: cst.Pass) -> None:
                # Named `code_range` so the builtin `range` is not shadowed.
                code_range = self.get_metadata(SyntacticPositionProvider, node)
                test.assertEqual(code_range, CodeRange((1, 0), (1, 4)))

        wrapper = MetadataWrapper(parse_module("pass"))
        wrapper.visit(DependentVisitor())
Exemple #20
0
 def _find_line_range_for_function_call(
         self, file_contents: str, line_num_1idx: int) -> Tuple[int, int]:
     """
     Return the 0-indexed, inclusive (start, end) line range of the function
     call in ``file_contents`` that spans 1-indexed line ``line_num_1idx``
     while covering the fewest lines.

     Ties between equally short calls go to the one listed first in
     ``_FunctionCallFinder.function_calls``.

     Raises:
         ValueError: if no function call spans ``line_num_1idx``.
     """
     tree = libcst.parse_module(file_contents)
     function_call_finder = _FunctionCallFinder()
     MetadataWrapper(tree).visit(function_call_finder)
     # Only the position ranges are needed; the call nodes are not used.
     containing_ranges = [
         node_range
         for _node, node_range in function_call_finder.function_calls
         if node_range.start.line <= line_num_1idx <= node_range.end.line
     ]
     if not containing_ranges:
         # Without this guard min() fails with an opaque "empty sequence" error.
         raise ValueError(f"no function call spans line {line_num_1idx}")
     node_range = min(
         containing_ranges,
         key=lambda candidate: candidate.end.line - candidate.start.line,
     )
     # Positions are 1-indexed; callers receive 0-indexed inclusive lines.
     start_line_num_0idx_incl = node_range.start.line - 1
     end_line_num_0idx_incl = node_range.end.line - 1
     return (start_line_num_0idx_incl, end_line_num_0idx_incl)
Exemple #21
0
def visit_batched(
        node: CSTNodeT,
        visitors: Iterable[BatchableCSTVisitor],
        before_visit: Optional[VisitorMethod] = None,
        after_leave: Optional[VisitorMethod] = None,
        use_compatible: bool = True,  # TODO: remove this
) -> CSTNodeT:
    """
    Returns the result of running all visitors [visitors] over [node].

    [before_visit] and [after_leave] are provided as optional hooks to
    execute before visit_* and after leave_* methods are executed by the
    batched visitor.

    [visitors] may be any iterable, including a one-shot generator; it is
    materialized into a list once before use.
    """
    # Materialize up front: the compatibility path iterates `visitors` once to
    # resolve metadata dependencies and then passes it to make_batched. A
    # generator would already be exhausted by then, so the batched visitor
    # would silently run no visitors at all.
    visitors = list(visitors)

    # TODO: remove compatibility hack
    if use_compatible:
        from libcst._nodes._module import Module

        if isinstance(node, Module):
            from contextlib import ExitStack
            from libcst.metadata.wrapper import MetadataWrapper

            wrapper = MetadataWrapper(node)
            with ExitStack() as stack:
                # Resolve dependencies of visitors; every resolve() context
                # stays entered until the batched visit below returns.
                for v in visitors:
                    stack.enter_context(v.resolve(wrapper))

                batched_visitor = make_batched(visitors, before_visit,
                                               after_leave)
                return cast(
                    CSTNodeT,
                    wrapper.module.visit(batched_visitor,
                                         use_compatible=False),
                )

        batched_visitor = make_batched(visitors, before_visit, after_leave)
        return cast(CSTNodeT, node.visit(batched_visitor))
    # end compatible

    batched_visitor = make_batched(visitors, before_visit, after_leave)
    return cast(CSTNodeT, node.visit(batched_visitor, use_compatible=False))
    def test_visitor_provider(self) -> None:
        """
        Tests that visitor providers are resolved correctly.

        Every node ends up with two metadata entries:
            SimpleProvider -> 1
            DependentProvider -> 2 (SimpleProvider's value plus one)
        """
        test = self

        class SimpleProvider(VisitorMetadataProvider[int]):
            def on_visit(self, node: cst.CSTNode) -> bool:
                self.set_metadata(node, 1)
                return True

        class DependentProvider(VisitorMetadataProvider[int]):
            METADATA_DEPENDENCIES = (SimpleProvider,)

            def on_visit(self, node: cst.CSTNode) -> bool:
                base_value = self.get_metadata(SimpleProvider, node)
                self.set_metadata(node, base_value + 1)
                return True

        class DependentVisitor(CSTTransformer):
            # Declares both providers so both kinds of metadata are readable.
            METADATA_DEPENDENCIES = (DependentProvider, SimpleProvider)

            def _check_metadata(self, node: cst.CSTNode) -> None:
                test.assertEqual(self.get_metadata(SimpleProvider, node), 1)
                test.assertEqual(self.get_metadata(DependentProvider, node), 2)

            def visit_Module(self, node: cst.Module) -> None:
                self._check_metadata(node)

            def visit_Pass(self, node: cst.Pass) -> None:
                self._check_metadata(node)

        MetadataWrapper(parse_module("pass")).visit(DependentVisitor())
    def test_undeclared_metadata(self) -> None:
        """
        Tests that access to undeclared metadata throws a key error.

        AVisitor declares only ProviderA, so its read of ProviderB is expected
        to raise KeyError naming both the provider and the visitor.
        """
        class ProviderA(VisitorMetadataProvider[bool]):
            pass

        class ProviderB(VisitorMetadataProvider[bool]):
            pass

        class AVisitor(CSTTransformer):
            METADATA_DEPENDENCIES = (ProviderA,)

            def on_visit(self, node: cst.CSTNode) -> bool:
                # Declared dependency, read with a default of True.
                self.get_metadata(ProviderA, node, True)
                # Undeclared dependency: this is the call expected to raise.
                self.get_metadata(ProviderB, node)
                return True

        expected_message = (
            "ProviderB is not declared as a dependency from AVisitor")
        with self.assertRaisesRegex(KeyError, expected_message):
            MetadataWrapper(cst.Module([])).visit(AVisitor())
Exemple #24
0
    def visit(self: _ModuleSelfT,
              visitor: CSTVisitorT,
              use_compatible: bool = True) -> _ModuleSelfT:
        """
        Returns the result of running a visitor over this module.

        :class:`Module` overrides the default visitor entry point to resolve metadata
        dependencies declared by 'visitor'.

        If the visitor removes the module (returns a RemovalSentinel), a copy
        of this module with empty body, header and footer is returned instead.
        """
        # TODO: remove compatibility hack
        if not use_compatible:
            result = super(Module, self).visit(visitor)
        else:
            from libcst.metadata.wrapper import MetadataWrapper

            result = MetadataWrapper(self).visit(visitor)

        if not isinstance(result, RemovalSentinel):
            # Anything other than the removal sentinel is a Module.
            return cast(_ModuleSelfT, result)
        return self.with_changes(body=(), header=(), footer=())
    def test_batched_provider(self) -> None:
        """
        Tests that batchable providers are resolved correctly.

        On the `pass` node the visitor expects:
            BatchedProviderA -> 1
            BatchedProviderB -> "a"
        and each provider's visit_Pass must run exactly once.
        """
        test = self
        calls = Mock()

        class BatchedProviderA(BatchableMetadataProvider[int]):
            def visit_Pass(self, node: cst.Pass) -> None:
                calls.visited_a()
                self.set_metadata(node, 1)

        class BatchedProviderB(BatchableMetadataProvider[str]):
            def visit_Pass(self, node: cst.Pass) -> None:
                calls.visited_b()
                self.set_metadata(node, "a")

        class DependentVisitor(CSTTransformer):
            METADATA_DEPENDENCIES = (BatchedProviderA, BatchedProviderB)

            def visit_Pass(self, node: cst.Pass) -> None:
                # Check metadata is set
                a_value = self.get_metadata(BatchedProviderA, node)
                test.assertEqual(a_value, 1)
                b_value = self.get_metadata(BatchedProviderB, node)
                test.assertEqual(b_value, "a")

        MetadataWrapper(parse_module("pass")).visit(DependentVisitor())

        # Check that each batchable visitor is only called once
        for recorded in (calls.visited_a, calls.visited_b):
            recorded.assert_called_once()
 def test_simple_load(self) -> None:
     """For a bare expression `a`, expects `a` as LOAD."""
     wrapper = MetadataWrapper(parse_module("a"))
     expected_names = {"a": ExpressionContext.LOAD}
     wrapper.visit(
         DependentVisitor(test=self, name_to_context=expected_names))
 def test_invalid_type_for_context(self) -> None:
     """For the call `a()`, expects the callee Name `a` as LOAD."""
     wrapper = MetadataWrapper(parse_module("a()"))
     expected_names = {"a": ExpressionContext.LOAD}
     wrapper.visit(
         DependentVisitor(test=self, name_to_context=expected_names))
    def test_mixed_providers(self) -> None:
        """
        Tests that a mixed set of providers is resolved properly.

        Sets metadata on pass:
            SimpleProvider -> 1
            BatchedProviderA -> 2
            BatchedProviderB -> 3
            BatchedProviderC -> 4  (BatchedProviderA * 2)
            DependentProvider -> 5 (BatchedProviderA + BatchedProviderB)

        DependentProvider also sets 0 on nodes where neither batched provider
        has metadata (e.g. the module). Every provider must run exactly once.
        """
        test = self
        mock = Mock()

        class SimpleProvider(VisitorMetadataProvider[int]):
            def visit_Pass(self, node: cst.Pass) -> None:
                mock.visited_simple()
                self.set_metadata(node, 1)

        class BatchedProviderA(BatchableMetadataProvider[int]):
            METADATA_DEPENDENCIES = (SimpleProvider, )

            def visit_Pass(self, node: cst.Pass) -> None:
                mock.visited_a()
                self.set_metadata(node, 2)

        class BatchedProviderB(BatchableMetadataProvider[int]):
            METADATA_DEPENDENCIES = (SimpleProvider, )

            def visit_Pass(self, node: cst.Pass) -> None:
                mock.visited_b()
                self.set_metadata(node, 3)

        class DependentProvider(VisitorMetadataProvider[int]):
            METADATA_DEPENDENCIES = (BatchedProviderA, BatchedProviderB)

            def on_visit(self, node: cst.CSTNode) -> bool:
                # Default to 0 for nodes the batched providers did not tag.
                # Named `total` so the builtin `sum` is not shadowed.
                total = (self.get_metadata(BatchedProviderA, node, 0) +
                         self.get_metadata(BatchedProviderB, node, 0))
                self.set_metadata(node, total)
                return True

        class BatchedProviderC(BatchableMetadataProvider[int]):
            METADATA_DEPENDENCIES = (BatchedProviderA, )

            def visit_Pass(self, node: cst.Pass) -> None:
                mock.visited_c()
                self.set_metadata(
                    node,
                    self.get_metadata(BatchedProviderA, node) * 2)

        class DependentVisitor(CSTTransformer):
            METADATA_DEPENDENCIES = (
                BatchedProviderA,
                BatchedProviderB,
                BatchedProviderC,
                DependentProvider,
            )

            def visit_Module(self, node: cst.Module) -> None:
                # DependentProvider sets metadata on every node, but for the
                # module it falls back to 0 because BatchedProviderA/B only
                # set metadata on pass nodes.
                test.assertEqual(self.get_metadata(DependentProvider, node), 0)

            def visit_Pass(self, node: cst.Pass) -> None:
                # Check metadata is set
                test.assertEqual(self.get_metadata(BatchedProviderA, node), 2)
                test.assertEqual(self.get_metadata(BatchedProviderB, node), 3)
                test.assertEqual(self.get_metadata(BatchedProviderC, node), 4)
                test.assertEqual(self.get_metadata(DependentProvider, node), 5)

        module = parse_module("pass")
        MetadataWrapper(module).visit(DependentVisitor())

        # Check each visitor is called once
        mock.visited_simple.assert_called_once()
        mock.visited_a.assert_called_once()
        mock.visited_b.assert_called_once()
        mock.visited_c.assert_called_once()