def get(
    wrapper: MetadataWrapper,
    original_node: cst.CSTNode,
    replacement_node: Union[cst.CSTNode, cst.RemovalSentinel],
) -> "LintPatch":
    """Build a LintPatch that replaces (or removes) ``original_node`` in ``wrapper.module``.

    The smallest enclosing node for which ExperimentalReentrantCodegenProvider
    has a codegen partial is located by walking up ParentNodeProvider data.
    Only that statement's code is regenerated. If no such ancestor exists, the
    whole module is regenerated and the patch starts at offset 0 /
    ``CodePosition(1, 0)``.

    Args:
        wrapper: Metadata wrapper around the module being patched.
        original_node: Node to replace; presumably must belong to
            ``wrapper.module`` (TODO confirm — a foreign node would likely fail
            the parent lookups below).
        replacement_node: The new node, or a ``cst.RemovalSentinel`` to delete
            ``original_node``.

    Returns:
        A ``LintPatch``. NOTE(review): its positional arguments look like
        (start offset, start position, original text, patched text) —
        confirm against the LintPatch definition.

    Raises:
        Exception: if asked to remove the entire module.
    """
    # Batch the execution of these position providers
    wrapper.resolve_many(
        [
            ParentNodeProvider,
            ExperimentalReentrantCodegenProvider,
            WhitespaceInclusivePositionProvider,
        ]
    )
    # Use the resolve() API to fetch the data, because it's typed better than
    # resolve_many() is.
    parents = wrapper.resolve(ParentNodeProvider)
    positions = wrapper.resolve(WhitespaceInclusivePositionProvider)
    codegen_partials = wrapper.resolve(ExperimentalReentrantCodegenProvider)

    # A module has no parent to regenerate, so removing it can't be expressed.
    if isinstance(original_node, cst.Module) and isinstance(
        replacement_node, cst.RemovalSentinel
    ):
        raise Exception("Removing the entire module is not possible")

    # The reentrant codegen provider can only rewrite entire statements at a time,
    # so we need to walk up through our parents until we find a statement or the
    # module.
    possible_statement = original_node
    if isinstance(replacement_node, cst.RemovalSentinel):
        # Reentrant codegen doesn't support RemovalSentinel, so start from the
        # parent instead.
        possible_statement = parents[possible_statement]
    while True:
        if possible_statement in codegen_partials:
            partial = codegen_partials[possible_statement]
            # Rebuild just this statement with the replacement applied;
            # _replace_or_remove presumably swaps/removes original_node inside
            # the given subtree (TODO confirm).
            patched_statement = cst.ensure_type(
                _replace_or_remove(
                    possible_statement, original_node, replacement_node
                ),
                cst.BaseStatement,
            )
            original_str = partial.get_original_statement_code()
            patched_str = partial.get_modified_statement_code(patched_statement)
            return LintPatch(
                partial.start_offset,
                positions[possible_statement].start,
                original_str,
                patched_str,
            )
        elif possible_statement in parents:
            # Not a statement with a codegen partial yet; keep climbing.
            possible_statement = parents[possible_statement]
        else:
            # There are no more parents, so we have to fall back to replacing the
            # whole module.
            original_str = wrapper.module.code
            patched_module = cst.ensure_type(
                _replace_or_remove(wrapper.module, original_node, replacement_node),
                cst.Module,
            )
            patched_str = patched_module.code
            return LintPatch(0, CodePosition(1, 0), original_str, patched_str)
def test_equality_by_identity(self) -> None:
    """Wrappers around the same module are equal only to themselves."""
    module = cst.parse_module("pass")
    first, second = MetadataWrapper(module), MetadataWrapper(module)
    # Each wrapper is equal to itself...
    for wrapper in (first, second):
        self.assertEqual(wrapper, wrapper)
    # ...but two distinct wrappers differ even though they wrap the same module.
    self.assertNotEqual(first, second)
def get_scope_metadata_provider(
    module_str: str,
) -> Tuple[cst.Module, Mapping[cst.CSTNode, Scope]]:
    """Parse ``module_str`` (after dedenting) and resolve its ScopeProvider data.

    Returns the wrapper's module together with its node -> Scope mapping.
    """
    source = dedent(module_str)
    wrapper = MetadataWrapper(cst.parse_module(source))
    scope_map = wrapper.resolve(ScopeProvider)
    # The cast records that every node is expected to have an associated scope.
    return wrapper.module, cast(Mapping[cst.CSTNode, Scope], scope_map)
def test_with_as(self) -> None:
    """``with a() as b``: ``a`` is loaded and ``b`` is stored."""
    tree = parse_module("with a() as b:\n    pass")
    expected_contexts = {
        "a": ExpressionContext.LOAD,
        "b": ExpressionContext.STORE,
    }
    MetadataWrapper(tree).visit(
        DependentVisitor(test=self, name_to_context=expected_contexts)
    )
def test_simple_assign(self) -> None:
    """In ``a = b`` the target is stored and the value is loaded."""
    contexts = {"a": ExpressionContext.STORE, "b": ExpressionContext.LOAD}
    wrapper = MetadataWrapper(parse_module("a = b"))
    wrapper.visit(DependentVisitor(test=self, name_to_context=contexts))
def test_function_position(self) -> None:
    """The body statement of ``foo`` sits on line 2, columns 4 through 8."""
    wrapper = MetadataWrapper(parse_module("def foo():\n    pass"))
    positions = wrapper.resolve(PositionProvider)
    func_def = cast(cst.FunctionDef, wrapper.module.body[0])
    body_line = cast(cst.SimpleStatementLine, func_def.body.body[0])
    pass_node = cast(cst.Pass, body_line.body[0])
    # The statement line and the bare ``pass`` share the same range.
    for node in (body_line, pass_node):
        self.cmp_position(positions[node], (2, 4), (2, 8))
def test_del_with_subscript(self) -> None:
    """``del a[b]``: the subscript is deleted while its names are loaded."""
    wrapper = MetadataWrapper(parse_module("del a[b]"))
    name_contexts = {"a": ExpressionContext.LOAD, "b": ExpressionContext.LOAD}
    subscript_contexts = {"a": ExpressionContext.DEL}
    wrapper.visit(
        DependentVisitor(
            test=self,
            name_to_context=name_contexts,
            subscript_to_context=subscript_contexts,
        )
    )
def test_starred_element_with_assign(self) -> None:
    """``*a = b``: the starred element is a store; the bare names are loads."""
    wrapper = MetadataWrapper(parse_module("*a = b"))
    visitor = DependentVisitor(
        test=self,
        name_to_context=dict(a=ExpressionContext.LOAD, b=ExpressionContext.LOAD),
        starred_element_to_context=dict(a=ExpressionContext.STORE),
    )
    wrapper.visit(visitor)
def get_qualified_name_metadata_provider(
    module_str: str,
) -> Tuple[cst.Module, Mapping[cst.CSTNode, Collection[QualifiedName]]]:
    """Parse ``module_str`` (after dedenting) and resolve QualifiedNameProvider data.

    Returns the wrapper's module together with its node -> qualified-names mapping.
    """
    wrapper = MetadataWrapper(cst.parse_module(dedent(module_str)))
    qualified_names = wrapper.resolve(QualifiedNameProvider)
    # The cast records that every node is expected to map to a collection of
    # qualified names.
    return (
        wrapper.module,
        cast(Mapping[cst.CSTNode, Collection[QualifiedName]], qualified_names),
    )
def test_for(self) -> None:
    """Loop target and body assignment are stores; the iterable is a load."""
    wrapper = MetadataWrapper(parse_module("for i in items:\n    j = 1"))
    expected_contexts = {
        "i": ExpressionContext.STORE,
        "items": ExpressionContext.LOAD,
        "j": ExpressionContext.STORE,
    }
    wrapper.visit(DependentVisitor(test=self, name_to_context=expected_contexts))
def test_del_with_tuple(self) -> None:
    """``del a, b`` marks the tuple and each name in it as DEL."""
    wrapper = MetadataWrapper(parse_module("del a, b"))
    deleted = ExpressionContext.DEL
    wrapper.visit(
        DependentVisitor(
            test=self,
            name_to_context={"a": deleted, "b": deleted},
            tuple_to_context={("a", "b"): deleted},
        )
    )
def test_except_as(self) -> None:
    """``except Exception as ex``: the class is loaded and ``ex`` is stored."""
    source = "try: ...\nexcept Exception as ex:\n    pass"
    name_contexts = {
        "Exception": ExpressionContext.LOAD,
        "ex": ExpressionContext.STORE,
    }
    MetadataWrapper(parse_module(source)).visit(
        DependentVisitor(test=self, name_to_context=name_contexts)
    )
def test_hash_by_identity(self) -> None:
    """Wrapper hashes follow object identity, even for wrappers of one module."""
    module = cst.parse_module("pass")
    default_wrapper = MetadataWrapper(module)
    # These two skip the module copy, yet their hashes must still be distinct.
    no_copy_a = MetadataWrapper(module, unsafe_skip_copy=True)
    no_copy_b = MetadataWrapper(module, unsafe_skip_copy=True)

    for wrapper in (default_wrapper, no_copy_a, no_copy_b):
        self.assertEqual(hash(wrapper), hash(wrapper))

    distinct_pairs = (
        (default_wrapper, no_copy_a),
        (default_wrapper, no_copy_b),
        (no_copy_a, no_copy_b),
    )
    for left, right in distinct_pairs:
        self.assertNotEqual(hash(left), hash(right))
def refactor_string(source, unused_imports):
    """Return ``source`` with ``unused_imports`` removed via RemoveUnusedImportTransformer.

    If ``source`` cannot be parsed, the parser error is printed (via ``Color``,
    in red) and ``source`` is returned unchanged. If ``unused_imports`` is
    falsy, ``source`` is also returned unchanged.
    """
    try:
        tree = cst.parse_module(source)
        wrapper = MetadataWrapper(tree)
    except cst.ParserSyntaxError as err:
        print(Color(str(err)).red)
        return source
    if not unused_imports:
        return source
    transformer = RemoveUnusedImportTransformer(unused_imports)
    return wrapper.visit(transformer).code
def test_batchable_provider(self) -> None:
    """A batchable visitor can read PositionProvider data for a ``pass`` node."""
    case = self

    class ABatchable(BatchableCSTVisitor):
        METADATA_DEPENDENCIES = (PositionProvider,)

        def visit_Pass(self, node: cst.Pass) -> None:
            # ``pass`` occupies line 1, columns 0 through 4.
            actual_range = self.get_metadata(PositionProvider, node)
            case.assertEqual(actual_range, CodeRange((1, 0), (1, 4)))

    MetadataWrapper(parse_module("pass")).visit_batched([ABatchable()])
def test_multiline_string_position(self) -> None:
    """A backslash-continued string concatenation spans line 1 col 0 to line 2 col 5."""
    wrapper = MetadataWrapper(parse_module('"abc"\\\n"def"'))
    positions = wrapper.resolve(PositionProvider)
    statement = cast(cst.SimpleStatementLine, wrapper.module.body[0])
    expression = cast(cst.Expr, statement.body[0])
    # Statement, expression and the string node itself all share one range.
    for node in (statement, expression, expression.value):
        self.cmp_position(positions[node], (1, 0), (2, 5))
def test_expressions_with_assign(self) -> None:
    """In ``f(a)[b] = c`` only the subscript target is a store; all names are loads."""
    wrapper = MetadataWrapper(parse_module("f(a)[b] = c"))
    load = ExpressionContext.LOAD
    wrapper.visit(
        DependentVisitor(
            test=self,
            name_to_context={name: load for name in ("a", "b", "c", "f")},
            subscript_to_context={"f(a)[b]": ExpressionContext.STORE},
        )
    )
def test_list_with_assing(self) -> None:
    """``[a] = [b]``: the target list and ``a`` are stores; the value side is loaded."""
    store, load = ExpressionContext.STORE, ExpressionContext.LOAD
    wrapper = MetadataWrapper(parse_module("[a] = [b]"))
    wrapper.visit(
        DependentVisitor(
            test=self,
            name_to_context={"a": store, "b": load},
            list_to_context={("a",): store, ("b",): load},
        )
    )
def test_assign_with_subscript(self) -> None:
    """A subscript assignment target is a store; other subscripts are loads."""
    load = ExpressionContext.LOAD
    store = ExpressionContext.STORE
    cases = (
        (
            "a[b] = c[d]",
            dict(
                name_to_context={"a": store, "b": load, "c": load, "d": load},
                subscript_to_context={"a[b]": store, "c[d]": load},
            ),
        ),
        (
            "x.y[start:end, idx]",
            dict(
                name_to_context={
                    "x": load,
                    # The attribute name ``y`` is expected to have no context.
                    "y": None,
                    "start": load,
                    "end": load,
                    "idx": load,
                },
                subscript_to_context={"x.y[start:end, idx]": load},
                attribute_to_context={"x.y": load},
            ),
        ),
    )
    for code, expectations in cases:
        wrapper = MetadataWrapper(parse_module(code))
        wrapper.visit(DependentVisitor(test=self, **expectations))
def test_byte_conversion(self) -> None:
    """Codegen partials produce module bytes in the parser config's encoding."""
    config = cst.PartialParserConfig(encoding="utf-16")
    wrapper = MetadataWrapper(cst.parse_module("fn()\n", config))
    first_statement = wrapper.module.body[0]
    partial = wrapper.resolve(ExperimentalReentrantCodegenProvider)[first_statement]

    expected_original = "fn()\n".encode("utf-16")
    self.assertEqual(partial.get_original_module_bytes(), expected_original)

    replacement = cst.parse_statement("fn2()\n")
    self.assertEqual(
        partial.get_modified_module_bytes(replacement),
        "fn2()\n".encode("utf-16"),
    )
def test_resolve_dependent_provider_twice(self) -> None:
    """
    Resolving a provider again, directly or as a dependency, does not re-run it.
    """
    calls = Mock()

    class ProviderA(VisitorMetadataProvider[bool]):
        def visit_Pass(self, node: cst.Pass) -> None:
            calls.visited_a()

    class ProviderB(VisitorMetadataProvider[bool]):
        METADATA_DEPENDENCIES = (ProviderA,)

        def visit_Pass(self, node: cst.Pass) -> None:
            calls.visited_b()

    wrapper = MetadataWrapper(cst.parse_module("pass"))

    def expect_single_visits(*names: str) -> None:
        for name in names:
            getattr(calls, name).assert_called_once()

    wrapper.resolve(ProviderA)
    expect_single_visits("visited_a")
    # Resolving B reuses A's earlier result instead of running A again.
    wrapper.resolve(ProviderB)
    expect_single_visits("visited_a", "visited_b")
    # Resolving A once more runs nothing.
    wrapper.resolve(ProviderA)
    expect_single_visits("visited_a", "visited_b")
def test_accesses(self) -> None:
    """Module-level and function-level ``foo`` each see only their own references."""
    m, scopes = get_scope_metadata_provider(
        """
        foo = 'toplevel'
        fn1(foo)
        fn2(foo)
        def fn_def():
            foo = 'shadow'
            fn3(foo)
        """
    )

    def first_call_arg(statement: cst.CSTNode) -> cst.Arg:
        # Unwrap SimpleStatementLine -> Expr -> Call and return the first argument.
        line = ensure_type(statement, cst.SimpleStatementLine)
        expression = ensure_type(line.body[0], cst.Expr)
        return ensure_type(expression.value, cst.Call).args[0]

    # Module scope: one ``foo`` assignment, referenced by fn1 and fn2.
    module_scope = scopes[m]
    self.assertIsInstance(module_scope, GlobalScope)
    global_assignments = list(module_scope["foo"])
    self.assertEqual(len(global_assignments), 1)
    top_level_foo = global_assignments[0]
    self.assertEqual(len(top_level_foo.references), 2)
    fn1_arg = first_call_arg(m.body[1])
    fn2_arg = first_call_arg(m.body[2])
    self.assertEqual(
        {ref.node for ref in top_level_foo.references},
        {fn1_arg.value, fn2_arg.value},
    )

    # Function scope: the shadowing ``foo`` is referenced only by fn3.
    function_body = ensure_type(m.body[3], cst.FunctionDef).body
    function_scope = scopes[function_body.body[0]]
    self.assertIsInstance(function_scope, FunctionScope)
    local_assignments = function_scope["foo"]
    self.assertEqual(len(local_assignments), 1)
    shadow_foo = list(local_assignments)[0]
    self.assertEqual(len(shadow_foo.references), 1)
    fn3_arg = first_call_arg(function_body.body[1])
    self.assertEqual({ref.node for ref in shadow_foo.references}, {fn3_arg.value})

    # Smoke-check that DependentVisitor can walk modules containing imports,
    # both at top level and nested inside a function.
    for source in ("from a import b\n", "def a():\n    from b import c\n\n"):
        MetadataWrapper(cst.parse_module(source)).visit(DependentVisitor())
def test_function(self) -> None:
    """Function and parameter names are stores; annotations and defaults are loads."""
    expected_contexts = {
        "foo": ExpressionContext.STORE,
        "x": ExpressionContext.STORE,
        "int": ExpressionContext.LOAD,
        "y": ExpressionContext.LOAD,
        "None": ExpressionContext.LOAD,
    }
    tree = parse_module("""def foo(x: int = y) -> None: pass""")
    MetadataWrapper(tree).visit(
        DependentVisitor(test=self, name_to_context=expected_contexts)
    )
def test_inherited_metadata(self) -> None:
    """
    A visitor subclass can read metadata that only its base class declared.
    """
    runner = self
    tracker = Mock()

    class SimpleProvider(VisitorMetadataProvider[int]):
        def visit_Pass(self, node: cst.Pass) -> None:
            tracker.visited_simple()
            self.set_metadata(node, 1)

    class VisitorA(CSTTransformer):
        # Only the base class declares the dependency...
        METADATA_DEPENDENCIES = (SimpleProvider,)

    class VisitorB(VisitorA):
        # ...yet the subclass can still fetch SimpleProvider's values.
        def visit_Pass(self, node: cst.Pass) -> None:
            runner.assertEqual(self.get_metadata(SimpleProvider, node), 1)

    MetadataWrapper(parse_module("pass")).visit(VisitorB())
    # The provider must have run exactly once.
    tracker.visited_simple.assert_called_once()
def test_batchable_provider(self) -> None:
    """A batchable provider's values are visible both on it and in _gen_batchable's result."""

    class SimpleProvider(BatchableMetadataProvider[int]):
        """
        Tags each ``pass`` node with 1 and each ``return`` node with 2.
        """

        def visit_Pass(self, node: cst.Pass) -> None:
            self.set_metadata(node, 1)

        def visit_Return(self, node: cst.Return) -> None:
            self.set_metadata(node, 2)

    wrapper = MetadataWrapper(parse_module("pass; return; pass"))
    small_statements = cast(cst.SimpleStatementLine, wrapper.module.body[0]).body
    pass_ = small_statements[0]
    return_ = small_statements[1]
    pass_2 = small_statements[2]

    provider = SimpleProvider()
    metadata = _gen_batchable(wrapper, [provider])

    expected = ((pass_, 1), (return_, 2), (pass_2, 1))
    # Values are readable through the provider itself...
    for node, value in expected:
        self.assertEqual(provider.get_metadata(SimpleProvider, node), value)
    # ...and through the mapping returned by _gen_batchable.
    for node, value in expected:
        self.assertEqual(metadata[SimpleProvider][node], value)
def test_visitor_provider(self) -> None:
    """A visitor provider's values are visible both on it and in its _gen result."""

    class SimpleProvider(VisitorMetadataProvider[int]):
        """
        Tags every visited node with 1.
        """

        def on_visit(self, node: cst.CSTNode) -> bool:
            self.set_metadata(node, 1)
            # Returning True lets traversal continue into the children.
            return True

    wrapper = MetadataWrapper(parse_module("pass; return"))
    module = wrapper.module
    small_statements = cast(cst.SimpleStatementLine, module.body[0]).body
    checked_nodes = (module, small_statements[0], small_statements[1])

    provider = SimpleProvider()
    metadata = provider._gen(wrapper)

    for node in checked_nodes:
        self.assertEqual(provider.get_metadata(SimpleProvider, node), 1)
    for node in checked_nodes:
        self.assertEqual(metadata[node], 1)
def get_file_lint_result_json(
    path: Path,
    opts: LintOpts,
    metadata_cache: Optional[Mapping["ProviderT", object]] = None,
) -> Sequence[str]:
    """Lint the file at ``path`` and return every report serialized as JSON.

    Any exception raised while reading, parsing or linting is turned into
    reports from ``opts.failure_report`` (built from the formatted traceback)
    rather than propagating.

    When ``metadata_cache`` is given, the parsed module is wrapped in a
    ``MetadataWrapper`` constructed with ``True`` and that cache, and passed to
    ``lint_file`` as ``cst_wrapper``; otherwise ``cst_wrapper`` is ``None``.
    """
    try:
        with open(path, "rb") as handle:
            source = handle.read()
        # NOTE(review): the positional ``True`` is presumably unsafe_skip_copy.
        cst_wrapper = (
            None
            if metadata_cache is None
            else MetadataWrapper(cst.parse_module(source), True, metadata_cache)
        )
        lint_results = lint_file(
            path,
            source,
            rules=opts.rules,
            config=opts.config,
            cst_wrapper=cst_wrapper,
        )
        results = opts.success_report.create_reports(path, lint_results, **opts.extra)
    except Exception:
        # Top-level boundary: report the failure instead of crashing the caller.
        results = opts.failure_report.create_reports(
            path, traceback.format_exc(), **opts.extra
        )
    return [json.dumps(asdict(report)) for report in results]
def _collect_statistics(self, modules: Mapping[str, cst.Module]) -> Dict[str, Any]:
    """Gather per-path annotation, fixme, ignore and strict counts as JSON-ready dicts.

    Only the annotation collector receives ``MetadataWrapper``-wrapped modules;
    the other collectors run on the raw modules. The strict collector also gets
    ``self._configuration.strict``.
    """
    wrapped: Mapping[str, cst.MetadataWrapper] = {
        path: MetadataWrapper(module) for path, module in modules.items()
    }
    counts_by_category = {
        "annotations": _path_wise_counts(wrapped, AnnotationCountCollector),
        "fixmes": _path_wise_counts(modules, FixmeCountCollector),
        "ignores": _path_wise_counts(modules, IgnoreCountCollector),
        "strict": _path_wise_counts(
            modules,
            StrictCountCollector,
            self._configuration.strict,
        ),
    }
    return {
        category: {path: counts.build_json() for path, counts in per_path.items()}
        for category, per_path in counts_by_category.items()
    }
def test_batchable_provider_inherited_metadata(self) -> None:
    """
    A batchable provider subclass can read metadata declared by its base class.
    """
    runner = self
    tracker = Mock()

    class ProviderA(VisitorMetadataProvider[int]):
        def visit_Pass(self, node: cst.Pass) -> None:
            tracker.visited_a()
            self.set_metadata(node, 1)

    class ProviderB(BatchableMetadataProvider[int]):
        # The dependency is declared only here, on the base class.
        METADATA_DEPENDENCIES = (ProviderA,)

    class ProviderC(ProviderB):
        def visit_Pass(self, node: cst.Pass) -> None:
            tracker.visited_c()
            runner.assertEqual(self.get_metadata(ProviderA, node), 1)

    class VisitorA(CSTTransformer):
        METADATA_DEPENDENCIES = (ProviderC,)

    MetadataWrapper(parse_module("pass")).visit(VisitorA())
    # Each provider must have run exactly once.
    for recorded_call in (tracker.visited_a, tracker.visited_c):
        recorded_call.assert_called_once()
def get_one_patchable_report_for_path(
    path: Path,
    source: bytes,
    rules: LintRuleCollectionT,
    use_ignore_byte_markers: bool,
    use_ignore_comments: bool,
    metadata_cache: Optional[Mapping["ProviderT", object]],
) -> LintRuleReportsWithAppliedPatches:
    """Run a single lint-and-patch iteration over ``source``.

    When ``metadata_cache`` is provided, the parsed module is wrapped in a
    ``MetadataWrapper`` built with ``True`` and that cache and passed on as
    ``cst_wrapper``; otherwise ``cst_wrapper`` is ``None``. Unused suppressions
    are also searched for (``find_unused_suppressions=True``).
    """
    # NOTE(review): the positional ``True`` is presumably unsafe_skip_copy.
    cst_wrapper: Optional[MetadataWrapper] = (
        MetadataWrapper(parse_module(source), True, metadata_cache)
        if metadata_cache is not None
        else None
    )
    return lint_file_and_apply_patches(
        path,
        source,
        rules=rules,
        use_ignore_byte_markers=use_ignore_byte_markers,
        use_ignore_comments=use_ignore_comments,
        # Only one pass: the metadata cache would have to be regenerated after
        # every applied patch.
        max_iter=1,
        cst_wrapper=cst_wrapper,
        find_unused_suppressions=True,
    )