예제 #1
0
    def get(
        wrapper: MetadataWrapper,
        original_node: cst.CSTNode,
        replacement_node: Union[cst.CSTNode, cst.RemovalSentinel],
    ) -> "LintPatch":
        """Build a LintPatch that replaces (or removes) ``original_node``.

        The patch is scoped to the nearest enclosing statement that the
        reentrant codegen provider has a partial for. If no such statement is
        found while walking up the parent chain, the whole module is
        regenerated instead.

        Args:
            wrapper: Metadata wrapper for the module that contains
                ``original_node``.
            original_node: The node to replace or remove.
            replacement_node: The new node, or a ``cst.RemovalSentinel`` to
                delete ``original_node``.

        Returns:
            A LintPatch built from the start offset and start position of the
            rewritten region, plus its original and patched source text.

        Raises:
            Exception: if ``original_node`` is the module itself and
                ``replacement_node`` is a ``RemovalSentinel``.
        """
        # Validate before resolving any metadata. Removing the whole module can
        # never yield a patch, so there is no reason to run the providers first.
        if isinstance(original_node, cst.Module) and isinstance(
            replacement_node, cst.RemovalSentinel
        ):
            raise Exception("Removing the entire module is not possible")

        # Batch the execution of these position providers
        wrapper.resolve_many(
            [
                ParentNodeProvider,
                ExperimentalReentrantCodegenProvider,
                WhitespaceInclusivePositionProvider,
            ]
        )

        # Use the resolve() API to fetch the data, because it's typed better than
        # resolve_many() is.
        parents = wrapper.resolve(ParentNodeProvider)
        positions = wrapper.resolve(WhitespaceInclusivePositionProvider)
        codegen_partials = wrapper.resolve(ExperimentalReentrantCodegenProvider)

        # The reentrant codegen provider can only rewrite entire statements at a
        # time, so walk up the parents until we find a statement it has a partial
        # for, or run out of parents.
        possible_statement = original_node
        if isinstance(replacement_node, cst.RemovalSentinel):
            # Reentrant codegen doesn't support RemovalSentinel, so start from the
            # parent and perform the removal inside it.
            possible_statement = parents[possible_statement]

        while possible_statement not in codegen_partials:
            if possible_statement not in parents:
                # There are no more parents, so fall back to replacing the whole
                # module.
                patched_module = cst.ensure_type(
                    _replace_or_remove(wrapper.module, original_node, replacement_node),
                    cst.Module,
                )
                return LintPatch(
                    0, CodePosition(1, 0), wrapper.module.code, patched_module.code
                )
            possible_statement = parents[possible_statement]

        partial = codegen_partials[possible_statement]
        patched_statement = cst.ensure_type(
            _replace_or_remove(possible_statement, original_node, replacement_node),
            cst.BaseStatement,
        )
        return LintPatch(
            partial.start_offset,
            positions[possible_statement].start,
            partial.get_original_statement_code(),
            partial.get_modified_statement_code(patched_statement),
        )
예제 #2
0
 def test_equality_by_identity(self) -> None:
     """Wrappers over the same module are each equal to themselves only."""
     module = cst.parse_module("pass")
     first, second = MetadataWrapper(module), MetadataWrapper(module)
     for wrapper in (first, second):
         self.assertEqual(wrapper, wrapper)
     self.assertNotEqual(first, second)
예제 #3
0
def get_scope_metadata_provider(
    module_str: str,
) -> Tuple[cst.Module, Mapping[cst.CSTNode, Scope]]:
    """Parse the dedented ``module_str`` and return its module and scope map.

    The ScopeProvider result is cast to a plain node -> Scope mapping, since
    every node in the parsed module is expected to have an associated scope.
    """
    wrapper = MetadataWrapper(cst.parse_module(dedent(module_str)))
    scopes = wrapper.resolve(ScopeProvider)
    # we're sure every node has an associated scope
    return wrapper.module, cast(Mapping[cst.CSTNode, Scope], scopes)
 def test_with_as(self) -> None:
     """``with a() as b`` loads ``a`` and stores into ``b``."""
     expected_names = {
         "a": ExpressionContext.LOAD,
         "b": ExpressionContext.STORE,
     }
     module = parse_module("with a() as b:\n    pass")
     MetadataWrapper(module).visit(
         DependentVisitor(test=self, name_to_context=expected_names))
 def test_simple_assign(self) -> None:
     """``a = b`` stores into ``a`` and loads ``b``."""
     expected_names = {
         "a": ExpressionContext.STORE,
         "b": ExpressionContext.LOAD,
     }
     module = parse_module("a = b")
     MetadataWrapper(module).visit(
         DependentVisitor(test=self, name_to_context=expected_names))
예제 #6
0
    def test_function_position(self) -> None:
        """The body statement of ``foo`` and its ``pass`` span (2, 4)-(2, 8)."""
        wrapper = MetadataWrapper(parse_module("def foo():\n    pass"))
        position_of = wrapper.resolve(PositionProvider)

        func_def = cast(cst.FunctionDef, wrapper.module.body[0])
        body_line = cast(cst.SimpleStatementLine, func_def.body.body[0])
        pass_node = cast(cst.Pass, body_line.body[0])
        for node in (body_line, pass_node):
            self.cmp_position(position_of[node], (2, 4), (2, 8))
 def test_del_with_subscript(self) -> None:
     """``del a[b]`` loads both names and deletes through the subscript."""
     # NOTE(review): other subscript tests (e.g. test_assign_with_subscript)
     # key subscript_to_context by the full subscript source such as "a[b]";
     # confirm that "a" is the key the DependentVisitor in use expects here.
     visitor = DependentVisitor(
         test=self,
         name_to_context={
             "a": ExpressionContext.LOAD,
             "b": ExpressionContext.LOAD,
         },
         subscript_to_context={"a": ExpressionContext.DEL},
     )
     MetadataWrapper(parse_module("del a[b]")).visit(visitor)
 def test_starred_element_with_assign(self) -> None:
     """In ``*a = b`` the starred target stores; the bare names load."""
     visitor = DependentVisitor(
         test=self,
         name_to_context={
             "a": ExpressionContext.LOAD,
             "b": ExpressionContext.LOAD,
         },
         starred_element_to_context={"a": ExpressionContext.STORE},
     )
     MetadataWrapper(parse_module("*a = b")).visit(visitor)
def get_qualified_name_metadata_provider(
    module_str: str
) -> Tuple[cst.Module, Mapping[cst.CSTNode, Collection[QualifiedName]]]:
    """Parse the dedented ``module_str`` and return its module and the
    QualifiedNameProvider metadata (node -> collection of qualified names).
    """
    wrapper = MetadataWrapper(cst.parse_module(dedent(module_str)))
    qualified_names = wrapper.resolve(QualifiedNameProvider)
    # typing.cast only narrows the static type for callers; it performs no
    # runtime conversion.
    return wrapper.module, cast(
        Mapping[cst.CSTNode, Collection[QualifiedName]], qualified_names
    )
 def test_for(self) -> None:
     """The loop target and body assignment store; the iterable loads."""
     expected_names = {
         "i": ExpressionContext.STORE,
         "items": ExpressionContext.LOAD,
         "j": ExpressionContext.STORE,
     }
     module = parse_module("for i in items:\n    j = 1")
     MetadataWrapper(module).visit(
         DependentVisitor(test=self, name_to_context=expected_names))
 def test_del_with_tuple(self) -> None:
     """``del a, b`` marks both names and the enclosing tuple as DEL."""
     deleted = ExpressionContext.DEL
     visitor = DependentVisitor(
         test=self,
         name_to_context={"a": deleted, "b": deleted},
         tuple_to_context={("a", "b"): deleted},
     )
     MetadataWrapper(parse_module("del a, b")).visit(visitor)
 def test_except_as(self) -> None:
     """``except Exception as ex`` loads the type and stores into ``ex``."""
     source = "try:    ...\nexcept Exception as ex:\n    pass"
     expected_names = {
         "Exception": ExpressionContext.LOAD,
         "ex": ExpressionContext.STORE,
     }
     wrapper = MetadataWrapper(parse_module(source))
     wrapper.visit(DependentVisitor(test=self, name_to_context=expected_names))
예제 #13
0
 def test_hash_by_identity(self) -> None:
     """Each wrapper hashes stably and differently from every other wrapper,
     whether or not ``unsafe_skip_copy`` was used."""
     module = cst.parse_module("pass")
     wrappers = [
         MetadataWrapper(module),
         MetadataWrapper(module, unsafe_skip_copy=True),
         MetadataWrapper(module, unsafe_skip_copy=True),
     ]
     for wrapper in wrappers:
         self.assertEqual(hash(wrapper), hash(wrapper))
     # Compare every distinct pair: (0, 1), (0, 2), (1, 2).
     for index, left in enumerate(wrappers):
         for right in wrappers[index + 1:]:
             self.assertNotEqual(hash(left), hash(right))
예제 #14
0
def refactor_string(source, unused_imports):
    """Return ``source`` with ``unused_imports`` removed.

    If ``source`` fails to parse, the parser error is printed (through
    ``Color(...).red``) and ``source`` is returned unchanged. ``source`` is
    also returned unchanged when ``unused_imports`` is empty or falsy.
    """
    try:
        wrapper = MetadataWrapper(cst.parse_module(source))
    except cst.ParserSyntaxError as err:
        print(Color(str(err)).red)
        return source
    if not unused_imports:
        return source
    transformer = RemoveUnusedImportTransformer(unused_imports)
    return wrapper.visit(transformer).code
예제 #15
0
    def test_batchable_provider(self) -> None:
        """A batched visitor can read PositionProvider data for ``pass``."""
        outer = self

        class PassPositionChecker(BatchableCSTVisitor):
            METADATA_DEPENDENCIES = (PositionProvider, )

            def visit_Pass(self, node: cst.Pass) -> None:
                actual = self.get_metadata(PositionProvider, node)
                outer.assertEqual(actual, CodeRange((1, 0), (1, 4)))

        wrapper = MetadataWrapper(parse_module("pass"))
        wrapper.visit_batched([PassPositionChecker()])
예제 #16
0
    def test_multiline_string_position(self) -> None:
        """A string continued with a backslash spans (1, 0)-(2, 5) for the
        statement line, the expression and the string node alike."""
        wrapper = MetadataWrapper(parse_module('"abc"\\\n"def"'))
        position_of = wrapper.resolve(PositionProvider)

        line = cast(cst.SimpleStatementLine, wrapper.module.body[0])
        expr = cast(cst.Expr, line.body[0])
        for node in (line, expr, expr.value):
            self.cmp_position(position_of[node], (1, 0), (2, 5))
 def test_expressions_with_assign(self) -> None:
     """Assigning to ``f(a)[b]`` stores through the subscript; every bare
     name involved is a load."""
     loaded = ExpressionContext.LOAD
     visitor = DependentVisitor(
         test=self,
         name_to_context={name: loaded for name in ("a", "b", "c", "f")},
         subscript_to_context={"f(a)[b]": ExpressionContext.STORE},
     )
     MetadataWrapper(parse_module("f(a)[b] = c")).visit(visitor)
 def test_list_with_assing(self) -> None:
     """In ``[a] = [b]`` the left list and ``a`` store; the right list and
     ``b`` load."""
     store, load = ExpressionContext.STORE, ExpressionContext.LOAD
     visitor = DependentVisitor(
         test=self,
         name_to_context={"a": store, "b": load},
         list_to_context={("a", ): store, ("b", ): load},
     )
     MetadataWrapper(parse_module("[a] = [b]")).visit(visitor)
    def test_assign_with_subscript(self) -> None:
        """Subscripts on both sides of an assignment, plus a bare subscript
        whose value is an attribute access."""
        cases = (
            (
                "a[b] = c[d]",
                {
                    "name_to_context": {
                        "a": ExpressionContext.STORE,
                        "b": ExpressionContext.LOAD,
                        "c": ExpressionContext.LOAD,
                        "d": ExpressionContext.LOAD,
                    },
                    "subscript_to_context": {
                        "a[b]": ExpressionContext.STORE,
                        "c[d]": ExpressionContext.LOAD,
                    },
                },
            ),
            (
                "x.y[start:end, idx]",
                {
                    "name_to_context": {
                        "x": ExpressionContext.LOAD,
                        # the attribute name `y` is expected to carry no context
                        "y": None,
                        "start": ExpressionContext.LOAD,
                        "end": ExpressionContext.LOAD,
                        "idx": ExpressionContext.LOAD,
                    },
                    "subscript_to_context": {
                        "x.y[start:end, idx]": ExpressionContext.LOAD
                    },
                    "attribute_to_context": {"x.y": ExpressionContext.LOAD},
                },
            ),
        )
        for source, expected_contexts in cases:
            wrapper = MetadataWrapper(parse_module(source))
            wrapper.visit(DependentVisitor(test=self, **expected_contexts))
예제 #20
0
 def test_byte_conversion(self) -> None:
     """Reentrant codegen emits module bytes in the encoding given to the
     parser, both for the original and for a modified statement."""
     encoding = "utf-16"
     config = cst.PartialParserConfig(encoding=encoding)
     wrapper = MetadataWrapper(cst.parse_module("fn()\n", config))
     partials = wrapper.resolve(ExperimentalReentrantCodegenProvider)
     first_partial = partials[wrapper.module.body[0]]

     self.assertEqual(first_partial.get_original_module_bytes(),
                      "fn()\n".encode(encoding))
     replacement = cst.parse_statement("fn2()\n")
     self.assertEqual(
         first_partial.get_modified_module_bytes(replacement),
         "fn2()\n".encode(encoding),
     )
예제 #21
0
    def test_resolve_dependent_provider_twice(self) -> None:
        """
        Resolving a provider that has already run, either directly or as a
        dependency, must not run it again.
        """
        calls = Mock()

        class ProviderA(VisitorMetadataProvider[bool]):
            def visit_Pass(self, node: cst.Pass) -> None:
                calls.visited_a()

        class ProviderB(VisitorMetadataProvider[bool]):
            METADATA_DEPENDENCIES = (ProviderA, )

            def visit_Pass(self, node: cst.Pass) -> None:
                calls.visited_b()

        wrapper = MetadataWrapper(cst.parse_module("pass"))

        # The first resolution of A runs it exactly once.
        wrapper.resolve(ProviderA)
        calls.visited_a.assert_called_once()

        # Resolving B (which depends on A) and then A again must leave A at one
        # run and run B exactly once.
        for provider in (ProviderB, ProviderA):
            wrapper.resolve(provider)
            calls.visited_a.assert_called_once()
            calls.visited_b.assert_called_once()
예제 #22
0
    def test_accesses(self) -> None:
        """Each use of ``foo`` resolves to the assignment in its own scope."""

        def first_call_arg(statement: cst.BaseStatement) -> cst.Arg:
            # `statement` is a one-expression line holding a call; return the
            # call's first argument.
            line = ensure_type(statement, cst.SimpleStatementLine)
            expr = ensure_type(line.body[0], cst.Expr)
            return ensure_type(expr.value, cst.Call).args[0]

        m, scopes = get_scope_metadata_provider("""
            foo = 'toplevel'
            fn1(foo)
            fn2(foo)
            def fn_def():
                foo = 'shadow'
                fn3(foo)
            """)

        # Module level: a single `foo` assignment, referenced by fn1 and fn2.
        global_scope = scopes[m]
        self.assertIsInstance(global_scope, GlobalScope)
        global_assignments = list(global_scope["foo"])
        self.assertEqual(len(global_assignments), 1)
        global_foo = global_assignments[0]
        self.assertEqual(len(global_foo.references), 2)
        fn1_arg = first_call_arg(m.body[1])
        fn2_arg = first_call_arg(m.body[2])
        self.assertEqual(
            {access.node for access in global_foo.references},
            {fn1_arg.value, fn2_arg.value},
        )

        # Function level: the shadowing `foo`, referenced only by fn3.
        func_body = ensure_type(m.body[3], cst.FunctionDef).body
        func_scope = scopes[func_body.body[0]]
        self.assertIsInstance(func_scope, FunctionScope)
        func_assignments = func_scope["foo"]
        self.assertEqual(len(func_assignments), 1)
        shadow_foo = list(func_assignments)[0]
        self.assertEqual(len(shadow_foo.references), 1)
        fn3_arg = first_call_arg(func_body.body[1])
        self.assertEqual({access.node for access in shadow_foo.references},
                         {fn3_arg.value})

        # Smoke-check that DependentVisitor runs over imports at module level
        # and inside a function.
        for source in ("from a import b\n", "def a():\n    from b import c\n\n"):
            wrapper = MetadataWrapper(cst.parse_module(source))
            wrapper.visit(DependentVisitor())
 def test_function(self) -> None:
     """The function and parameter names store; the annotation, default
     and return-annotation names load."""
     source = """def foo(x: int = y) -> None: pass"""
     expected_names = {
         "foo": ExpressionContext.STORE,
         "x": ExpressionContext.STORE,
         "int": ExpressionContext.LOAD,
         "y": ExpressionContext.LOAD,
         "None": ExpressionContext.LOAD,
     }
     wrapper = MetadataWrapper(parse_module(source))
     wrapper.visit(DependentVisitor(test=self, name_to_context=expected_names))
예제 #24
0
    def test_inherited_metadata(self) -> None:
        """
        A visitor subclass can read metadata from a provider that only its
        base class lists in METADATA_DEPENDENCIES.
        """
        outer = self
        calls = Mock()

        class SimpleProvider(VisitorMetadataProvider[int]):
            def visit_Pass(self, node: cst.Pass) -> None:
                calls.visited_simple()
                self.set_metadata(node, 1)

        class VisitorA(CSTTransformer):
            METADATA_DEPENDENCIES = (SimpleProvider,)

        class VisitorB(VisitorA):
            # Declares no dependencies of its own; inherits VisitorA's.
            def visit_Pass(self, node: cst.Pass) -> None:
                outer.assertEqual(self.get_metadata(SimpleProvider, node), 1)

        wrapper = MetadataWrapper(parse_module("pass"))
        wrapper.visit(VisitorB())

        # The provider must have run exactly once.
        calls.visited_simple.assert_called_once()
예제 #25
0
    def test_batchable_provider(self) -> None:
        class SimpleProvider(BatchableMetadataProvider[int]):
            """
            Records 1 for each Pass node and 2 for each Return node.
            """
            def visit_Pass(self, node: cst.Pass) -> None:
                self.set_metadata(node, 1)

            def visit_Return(self, node: cst.Return) -> None:
                self.set_metadata(node, 2)

        wrapper = MetadataWrapper(parse_module("pass; return; pass"))
        line = cast(cst.SimpleStatementLine, wrapper.module.body[0])
        expected = (
            (line.body[0], 1),  # first `pass`
            (line.body[1], 2),  # `return`
            (line.body[2], 1),  # second `pass`
        )

        provider = SimpleProvider()
        metadata = _gen_batchable(wrapper, [provider])

        # Values are readable through the provider itself...
        for node, value in expected:
            self.assertEqual(provider.get_metadata(SimpleProvider, node), value)
        # ...and through the mapping _gen_batchable returns.
        for node, value in expected:
            self.assertEqual(metadata[SimpleProvider][node], value)
예제 #26
0
    def test_visitor_provider(self) -> None:
        class SimpleProvider(VisitorMetadataProvider[int]):
            """
            Records the value 1 for every node it visits.
            """
            def on_visit(self, node: cst.CSTNode) -> bool:
                self.set_metadata(node, 1)
                return True

        wrapper = MetadataWrapper(parse_module("pass; return"))
        module = wrapper.module
        line = cast(cst.SimpleStatementLine, module.body[0])
        checked_nodes = (module, line.body[0], line.body[1])

        provider = SimpleProvider()
        metadata = provider._gen(wrapper)

        # Values are readable through the provider itself...
        for node in checked_nodes:
            self.assertEqual(provider.get_metadata(SimpleProvider, node), 1)
        # ...and through the mapping _gen returns.
        for node in checked_nodes:
            self.assertEqual(metadata[node], 1)
예제 #27
0
def get_file_lint_result_json(
    path: Path,
    opts: LintOpts,
    metadata_cache: Optional[Mapping["ProviderT", object]] = None,
) -> Sequence[str]:
    """Lint the file at ``path`` and return each report serialized as JSON.

    When ``metadata_cache`` is given, the source is pre-parsed into a
    MetadataWrapper seeded with that cache and passed to ``lint_file``.
    Any exception raised while reading, parsing, linting or building the
    success reports is turned into failure reports that carry the formatted
    traceback.
    """
    try:
        with open(path, "rb") as source_file:
            source = source_file.read()
        cst_wrapper = (
            None
            if metadata_cache is None
            else MetadataWrapper(cst.parse_module(source), True, metadata_cache)
        )
        results = opts.success_report.create_reports(
            path,
            lint_file(
                path,
                source,
                rules=opts.rules,
                config=opts.config,
                cst_wrapper=cst_wrapper,
            ),
            **opts.extra,
        )
    except Exception:
        results = opts.failure_report.create_reports(
            path, traceback.format_exc(), **opts.extra
        )
    return [json.dumps(asdict(report)) for report in results]
예제 #28
0
 def _collect_statistics(
         self, modules: Mapping[str, cst.Module]) -> Dict[str, Any]:
     """Run the annotation, fixme, ignore and strict count collectors.

     Returns a dict keyed by "annotations", "fixmes", "ignores" and "strict",
     where each value maps a module path to that collector's
     ``build_json()`` output.
     """
     # Annotation counting runs over MetadataWrappers; the other collectors
     # receive the raw modules.
     wrapped_modules: Mapping[str, cst.MetadataWrapper] = {
         path: MetadataWrapper(module) for path, module in modules.items()
     }
     counts_by_category = {
         "annotations": _path_wise_counts(wrapped_modules,
                                          AnnotationCountCollector),
         "fixmes": _path_wise_counts(modules, FixmeCountCollector),
         "ignores": _path_wise_counts(modules, IgnoreCountCollector),
         "strict": _path_wise_counts(
             modules,
             StrictCountCollector,
             self._configuration.strict,
         ),
     }
     return {
         category: {
             path: counts.build_json()
             for path, counts in per_path.items()
         }
         for category, per_path in counts_by_category.items()
     }
예제 #29
0
    def test_batchable_provider_inherited_metadata(self) -> None:
        """
        A batchable provider subclass can read metadata from a provider that
        only its base class lists in METADATA_DEPENDENCIES.
        """
        outer = self
        calls = Mock()

        class ProviderA(VisitorMetadataProvider[int]):
            def visit_Pass(self, node: cst.Pass) -> None:
                calls.visited_a()
                self.set_metadata(node, 1)

        class ProviderB(BatchableMetadataProvider[int]):
            METADATA_DEPENDENCIES = (ProviderA,)

        class ProviderC(ProviderB):
            # Declares no dependencies of its own; inherits ProviderB's.
            def visit_Pass(self, node: cst.Pass) -> None:
                calls.visited_c()
                outer.assertEqual(self.get_metadata(ProviderA, node), 1)

        class VisitorA(CSTTransformer):
            METADATA_DEPENDENCIES = (ProviderC,)

        wrapper = MetadataWrapper(parse_module("pass"))
        wrapper.visit(VisitorA())

        # Each provider's visit_Pass must have run exactly once.
        for recorded_call in (calls.visited_a, calls.visited_c):
            recorded_call.assert_called_once()
예제 #30
0
def get_one_patchable_report_for_path(
    path: Path,
    source: bytes,
    rules: LintRuleCollectionT,
    use_ignore_byte_markers: bool,
    use_ignore_comments: bool,
    metadata_cache: Optional[Mapping["ProviderT", object]],
) -> LintRuleReportsWithAppliedPatches:
    """Lint one file and apply patches for a single iteration only.

    When ``metadata_cache`` is provided, the source is pre-parsed into a
    MetadataWrapper seeded with that cache. Otherwise ``cst_wrapper`` is
    passed as None to ``lint_file_and_apply_patches``. Unused suppressions
    are always reported.
    """
    cst_wrapper: Optional[MetadataWrapper] = (
        None
        if metadata_cache is None
        else MetadataWrapper(parse_module(source), True, metadata_cache)
    )

    return lint_file_and_apply_patches(
        path,
        source,
        rules=rules,
        use_ignore_byte_markers=use_ignore_byte_markers,
        use_ignore_comments=use_ignore_comments,
        # A single iteration: the supplied metadata cache would have to be
        # regenerated after every applied patch.
        max_iter=1,
        cst_wrapper=cst_wrapper,
        find_unused_suppressions=True,
    )