def test_assign_with_subscript(self) -> None:
    """Check expression contexts for subscripts.

    First, subscripts on both sides of an assignment: the target subscript and
    its base are stored, everything else is loaded. Second, a load-only
    subscript over an attribute with a slice/index tuple.
    """
    store = ExpressionContext.STORE
    load = ExpressionContext.LOAD

    assign_wrapper = MetadataWrapper(parse_module("a[b] = c[d]"))
    assign_wrapper.visit(
        DependentVisitor(
            test=self,
            name_to_context={"a": store, "b": load, "c": load, "d": load},
            subscript_to_context={"a[b]": store, "c[d]": load},
        )
    )

    slice_wrapper = MetadataWrapper(parse_module("x.y[start:end, idx]"))
    slice_wrapper.visit(
        DependentVisitor(
            test=self,
            # `y` is only an attribute name, so no context is expected for it.
            name_to_context={
                "x": load,
                "y": None,
                "start": load,
                "end": load,
                "idx": load,
            },
            subscript_to_context={"x.y[start:end, idx]": load},
            attribute_to_context={"x.y": load},
        )
    )
def test_metadata_cache(self) -> None:
    """A provider that declares ``gen_cache`` needs a cache entry to be resolved.

    Without an entry, ``resolve`` must raise. With one, the provider receives
    the cached object in ``__init__`` and here publishes it as the metadata of
    the ``pass`` node.
    """

    class DummyMetadataProvider(BatchableMetadataProvider[None]):
        gen_cache = tuple

    module = cst.parse_module("pass")
    uncached_wrapper = MetadataWrapper(module)
    with self.assertRaisesRegex(
        Exception, "Cache is required for initializing DummyMetadataProvider."
    ):
        uncached_wrapper.resolve(DummyMetadataProvider)

    class SimpleCacheMetadataProvider(BatchableMetadataProvider[object]):
        gen_cache = tuple

        def __init__(self, cache: object) -> None:
            super().__init__(cache)
            self.cache = cache

        def visit_Pass(self, node: cst.Pass) -> Optional[bool]:
            self.set_metadata(node, self.cache)

    sentinel = object()
    # pyre-fixme[6]: Expected `Mapping[Type[BaseMetadataProvider[object]],
    #  object]` for 2nd param but got `Dict[Type[SimpleCacheMetadataProvider],
    #  object]`.
    cached_wrapper = MetadataWrapper(module, cache={SimpleCacheMetadataProvider: sentinel})
    first_line = cst.ensure_type(cached_wrapper.module.body[0], cst.SimpleStatementLine)
    pass_node = first_line.body[0]
    self.assertEqual(
        cached_wrapper.resolve(SimpleCacheMetadataProvider)[pass_node], sentinel
    )
def test_equality_by_identity(self) -> None:
    """Two wrappers over the same module are each equal only to themselves."""
    module = cst.parse_module("pass")
    first, second = MetadataWrapper(module), MetadataWrapper(module)
    for wrapper in (first, second):
        self.assertEqual(wrapper, wrapper)
    self.assertNotEqual(first, second)
def test_accesses(self) -> None:
    """References to ``foo`` are attributed to the assignment in the right scope.

    The module-level ``foo`` is read by ``fn1`` and ``fn2``. The shadowing
    ``foo`` inside ``fn_def`` is read only by ``fn3``. Afterwards,
    ``DependentVisitor`` is run over a module-level and a function-level import.
    """

    def first_call_arg_value(statement: cst.CSTNode) -> cst.BaseExpression:
        # `fnN(foo)` statement line -> the expression passed as its first argument.
        line = ensure_type(statement, cst.SimpleStatementLine)
        expr = ensure_type(line.body[0], cst.Expr)
        return ensure_type(expr.value, cst.Call).args[0].value

    # The source must keep its line breaks. Each top-level statement becomes
    # one entry of `m.body` (indexed below), and `fn_def` needs an indented body.
    m, scopes = get_scope_metadata_provider(
        """
        foo = 'toplevel'
        fn1(foo)
        fn2(foo)
        def fn_def():
            foo = 'shadow'
            fn3(foo)
        """
    )

    # Module scope: exactly one assignment to `foo`, referenced by fn1 and fn2.
    scope_of_module = scopes[m]
    self.assertIsInstance(scope_of_module, GlobalScope)
    global_foo_assignments = list(scope_of_module["foo"])
    self.assertEqual(len(global_foo_assignments), 1)
    foo_assignment = global_foo_assignments[0]
    self.assertEqual(len(foo_assignment.references), 2)
    self.assertEqual(
        {access.node for access in foo_assignment.references},
        {first_call_arg_value(m.body[1]), first_call_arg_value(m.body[2])},
    )

    # Function scope: the shadowing assignment is referenced only by fn3.
    func_body = ensure_type(m.body[3], cst.FunctionDef).body
    scope_of_func_statement = scopes[func_body.body[0]]
    self.assertIsInstance(scope_of_func_statement, FunctionScope)
    func_foo_assignments = scope_of_func_statement["foo"]
    self.assertEqual(len(func_foo_assignments), 1)
    foo_assignment = list(func_foo_assignments)[0]
    self.assertEqual(len(foo_assignment.references), 1)
    self.assertEqual(
        {access.node for access in foo_assignment.references},
        {first_call_arg_value(func_body.body[1])},
    )

    # NOTE(review): `DependentVisitor` is defined outside this view and is
    # called without arguments here; confirm it is the intended visitor.
    wrapper = MetadataWrapper(cst.parse_module("from a import b\n"))
    wrapper.visit(DependentVisitor())
    wrapper = MetadataWrapper(
        cst.parse_module("def a():\n from b import c\n\n"))
    wrapper.visit(DependentVisitor())
def test_line_count(self) -> None:
    """AnnotationCollector counts one more line when the source ends in a newline."""
    cases = (
        ("# No trailing newline", 1),
        ("# With trailing newline\n", 2),
    )
    for text, expected_lines in cases:
        wrapper = MetadataWrapper(cst.parse_module(text))
        collector = AnnotationCollector()
        wrapper.visit(collector)
        self.assertEqual(collector.line_count, expected_lines)
def test_hash_by_identity(self) -> None:
    """Wrapper hashes are stable per instance and differ between instances.

    This holds even when several wrappers share the same (uncopied) module.
    """
    module = cst.parse_module("pass")
    wrappers = [
        MetadataWrapper(module),
        MetadataWrapper(module, unsafe_skip_copy=True),
        MetadataWrapper(module, unsafe_skip_copy=True),
    ]
    for wrapper in wrappers:
        self.assertEqual(hash(wrapper), hash(wrapper))
    for position, left in enumerate(wrappers):
        for right in wrappers[position + 1:]:
            self.assertNotEqual(hash(left), hash(right))
def get_formatted_reports_for_path(
    path: Path,
    opts: LintOpts,
    metadata_cache: Optional[Mapping["ProviderT", object]] = None,
) -> Iterable[str]:
    """Lint one file and render every report with ``opts.formatter``.

    The file is read as bytes. When ``metadata_cache`` is provided, the parsed
    module is wrapped (without copying) together with that cache and passed to
    ``lint_file``. Unused suppressions are reported as well.

    Returns:
        The formatted reports. Returns ``[]`` after printing the error when the
        source fails to parse (``SyntaxError`` / ``ParserSyntaxError``).
    """
    with open(path, "rb") as handle:
        source = handle.read()

    try:
        wrapper = (
            MetadataWrapper(
                parse_module(source), unsafe_skip_copy=True, cache=metadata_cache
            )
            if metadata_cache is not None
            else None
        )
        reports = lint_file(
            path,
            source,
            rules=opts.rules,
            use_ignore_byte_markers=opts.use_ignore_byte_markers,
            use_ignore_comments=opts.use_ignore_comments,
            cst_wrapper=wrapper,
            find_unused_suppressions=True,
        )
    except (SyntaxError, ParserSyntaxError) as error:
        print_red(
            f"Encountered the following error while parsing source code in file {path}:"
        )
        print(error)
        return []
    else:
        # The linter completed successfully.
        return [opts.formatter.format(report) for report in reports]
def _collect_statistics(self, modules: Dict[str, cst.Module]) -> Dict[str, Any]: modules_with_metadata = { path: MetadataWrapper(module) for path, module in modules.items() } # pyre-fixme[6]: Expected `Dict[str, Union[cst._nodes.module.Module, # cst.metadata.wrapper.MetadataWrapper]]` for 1st param but got `Dict[str, # cst.metadata.wrapper.MetadataWrapper]`. annotations = _path_wise_counts(modules_with_metadata, AnnotationCountCollector) # pyre-fixme[6]: Expected `Dict[str, Union[cst._nodes.module.Module, # cst.metadata.wrapper.MetadataWrapper]]` for 1st param but got `Dict[str, # cst._nodes.module.Module]`. fixmes = _path_wise_counts(modules, FixmeCountCollector) # pyre-fixme[6]: Expected `Dict[str, Union[cst._nodes.module.Module, # cst.metadata.wrapper.MetadataWrapper]]` for 1st param but got `Dict[str, # cst._nodes.module.Module]`. ignores = _path_wise_counts(modules, IgnoreCountCollector) strict_files = _path_wise_counts( # pyre-fixme[6]: Expected `Dict[str, Union[cst._nodes.module.Module, # cst.metadata.wrapper.MetadataWrapper]]` for 1st param but got `Dict[str, # cst._nodes.module.Module]`. modules, StrictCountCollector, self._configuration.strict, ) return { "annotations": { path: counts.build_json() for path, counts in annotations.items() }, "fixmes": {path: counts.build_json() for path, counts in fixmes.items()}, "ignores": {path: counts.build_json() for path, counts in ignores.items()}, "strict": { path: counts.build_json() for path, counts in strict_files.items() }, }
def test_resolve_dependent_provider_twice(self) -> None:
    """Resolving a provider that already ran must not run it again.

    This covers both resolving it directly and pulling it in as a dependency.
    """
    calls = Mock()

    class ProviderA(VisitorMetadataProvider[bool]):
        def visit_Pass(self, node: cst.Pass) -> None:
            calls.visited_a()

    class ProviderB(VisitorMetadataProvider[bool]):
        METADATA_DEPENDENCIES = (ProviderA, )

        def visit_Pass(self, node: cst.Pass) -> None:
            calls.visited_b()

    wrapper = MetadataWrapper(cst.parse_module("pass"))

    wrapper.resolve(ProviderA)
    calls.visited_a.assert_called_once()

    # B depends on the already-resolved A; then A is requested once more.
    for provider in (ProviderB, ProviderA):
        wrapper.resolve(provider)
        calls.visited_a.assert_called_once()
        calls.visited_b.assert_called_once()
def get_file_lint_result_json(
    path: Path,
    opts: LintOpts,
    metadata_cache: Optional[Mapping["ProviderT", object]] = None,
) -> Sequence[str]:
    """Lint ``path`` and return each resulting report serialized as JSON.

    Reading, parsing (only when ``metadata_cache`` is given) and linting all
    happen inside one ``try``. Any exception is turned into reports built by
    ``opts.failure_report`` from the formatted traceback; it is not propagated.
    Both report factories receive ``opts.extra`` as keyword arguments.
    """
    try:
        with open(path, "rb") as handle:
            source = handle.read()
        wrapper = None
        if metadata_cache is not None:
            wrapper = MetadataWrapper(
                cst.parse_module(source),
                unsafe_skip_copy=True,
                cache=metadata_cache,
            )
        lint_reports = lint_file(
            path,
            source,
            rules=opts.rules,
            config=opts.config,
            cst_wrapper=wrapper,
        )
        results = opts.success_report.create_reports(path, lint_reports, **opts.extra)
    except Exception:
        # Turn the failure into a report instead of propagating it.
        results = opts.failure_report.create_reports(
            path, traceback.format_exc(), **opts.extra
        )
    return [json.dumps(asdict(report)) for report in results]
def test_inherited_metadata(self) -> None:
    """A visitor subclass can read metadata declared only on its base class."""
    outer = self
    calls = Mock()

    class SimpleProvider(VisitorMetadataProvider[int]):
        def visit_Pass(self, node: cst.Pass) -> None:
            calls.visited_simple()
            self.set_metadata(node, 1)

    class VisitorA(CSTTransformer):
        # Only the base class declares the dependency...
        METADATA_DEPENDENCIES = (SimpleProvider,)

    class VisitorB(VisitorA):
        # ...but the subclass queries it.
        def visit_Pass(self, node: cst.Pass) -> None:
            outer.assertEqual(self.get_metadata(SimpleProvider, node), 1)

    MetadataWrapper(parse_module("pass")).visit(VisitorB())

    # The provider must have run exactly once.
    calls.visited_simple.assert_called_once()
def test_batchable_provider(self) -> None:
    """A batchable provider's values are visible in two places.

    They can be read through ``get_metadata`` on the provider and from the
    mapping returned by ``_gen_batchable``.
    """

    class SimpleProvider(BatchableMetadataProvider[int]):
        """Tags every Pass node with 1 and every Return node with 2."""

        def visit_Pass(self, node: cst.Pass) -> None:
            self.set_metadata(node, 1)

        def visit_Return(self, node: cst.Return) -> None:
            self.set_metadata(node, 2)

    wrapper = MetadataWrapper(parse_module("pass; return; pass"))
    line = cast(cst.SimpleStatementLine, wrapper.module.body[0])
    expectations = (
        (line.body[0], 1),  # first pass
        (line.body[1], 2),  # return
        (line.body[2], 1),  # second pass
    )

    provider = SimpleProvider()
    metadata = _gen_batchable(wrapper, [provider])

    for node, value in expectations:
        self.assertEqual(provider.get_metadata(SimpleProvider, node), value)
    for node, value in expectations:
        self.assertEqual(metadata[SimpleProvider][node], value)
def test_batchable_provider_inherited_metadata(self) -> None:
    """A batchable provider subclass can read metadata declared by its base.

    Each provider involved must run exactly once.
    """
    outer = self
    calls = Mock()

    class ProviderA(VisitorMetadataProvider[int]):
        def visit_Pass(self, node: cst.Pass) -> None:
            calls.visited_a()
            self.set_metadata(node, 1)

    class ProviderB(BatchableMetadataProvider[int]):
        # Dependency declared on the base provider only.
        METADATA_DEPENDENCIES = (ProviderA,)

    class ProviderC(ProviderB):
        def visit_Pass(self, node: cst.Pass) -> None:
            calls.visited_c()
            outer.assertEqual(self.get_metadata(ProviderA, node), 1)

    class VisitorA(CSTTransformer):
        METADATA_DEPENDENCIES = (ProviderC,)

    MetadataWrapper(parse_module("pass")).visit(VisitorA())

    calls.visited_a.assert_called_once()
    calls.visited_c.assert_called_once()
def _collect_statistics(
        self, modules: Mapping[str, cst.Module]) -> Dict[str, Any]:
    """Count annotations, fixmes, ignores and strict markers for every module.

    Annotation counting runs on ``MetadataWrapper``-wrapped modules. The other
    collectors run on the raw modules. ``StrictCountCollector`` additionally
    receives ``self._configuration.strict``.

    Returns:
        A dict with keys ``"annotations"``, ``"fixmes"``, ``"ignores"`` and
        ``"strict"``, each mapping a path to its ``build_json()`` result.
    """
    wrapped: Mapping[str, cst.MetadataWrapper] = {
        path: MetadataWrapper(module) for path, module in modules.items()
    }
    annotations = _path_wise_counts(wrapped, AnnotationCountCollector)
    fixmes = _path_wise_counts(modules, FixmeCountCollector)
    ignores = _path_wise_counts(modules, IgnoreCountCollector)
    strict_files = _path_wise_counts(
        modules,
        StrictCountCollector,
        self._configuration.strict,
    )

    def as_json(counts_by_path: Mapping[str, Any]) -> Dict[str, Any]:
        # Serialize each per-path counts object.
        return {path: counts.build_json() for path, counts in counts_by_path.items()}

    return {
        "annotations": as_json(annotations),
        "fixmes": as_json(fixmes),
        "ignores": as_json(ignores),
        "strict": as_json(strict_files),
    }
def test_visitor_provider(self) -> None:
    """A visitor provider that tags every node reaches the module and each statement."""

    class SimpleProvider(VisitorMetadataProvider[int]):
        """Tags every visited node with 1."""

        def on_visit(self, node: cst.CSTNode) -> bool:
            self.set_metadata(node, 1)
            return True

    wrapper = MetadataWrapper(parse_module("pass; return"))
    module = wrapper.module
    line = cast(cst.SimpleStatementLine, module.body[0])
    tagged_nodes = (module, line.body[0], line.body[1])

    provider = SimpleProvider()
    metadata = provider._gen(wrapper)

    # Values are visible through the provider itself...
    for node in tagged_nodes:
        self.assertEqual(provider.get_metadata(SimpleProvider, node), 1)
    # ...and through the mapping returned by _gen.
    for node in tagged_nodes:
        self.assertEqual(metadata[node], 1)
def get_one_patchable_report_for_path(
    path: Path,
    source: bytes,
    rules: LintRuleCollectionT,
    use_ignore_byte_markers: bool,
    use_ignore_comments: bool,
    metadata_cache: Optional[Mapping["ProviderT", object]],
) -> LintRuleReportsWithAppliedPatches:
    """Lint ``source`` and apply patches for a single iteration.

    When ``metadata_cache`` is given, the parsed module is wrapped (without
    copying) together with that cache. Unused suppressions are reported too.
    """
    cst_wrapper: Optional[MetadataWrapper] = (
        None
        if metadata_cache is None
        else MetadataWrapper(
            parse_module(source), unsafe_skip_copy=True, cache=metadata_cache
        )
    )
    return lint_file_and_apply_patches(
        path,
        source,
        rules=rules,
        use_ignore_byte_markers=use_ignore_byte_markers,
        use_ignore_comments=use_ignore_comments,
        # The metadata cache would have to be regenerated after every applied
        # patch, so only one iteration is run.
        max_iter=1,
        cst_wrapper=cst_wrapper,
        find_unused_suppressions=True,
    )
def test(
    self,
    *,
    source: bytes,
    rules_in_lint_run: Collection[Type[CstLintRule]],
    rules_without_report: Collection[Type[CstLintRule]],
    suppressed_line: int,
    expected_unused_suppressions_report_messages: Collection[str],
    expected_replacements: Optional[List[str]] = None,
) -> None:
    """Check what RemoveUnusedSuppressionsRule reports for ``source``.

    A fake report is created on ``suppressed_line`` for every rule in the run
    that is not listed in ``rules_without_report``. The reports are fed through
    ``IgnoreInfo.should_ignore_report``. The rule's messages are then compared
    with the expected ones. When ``expected_replacements`` is given, each patch
    applied to the decoded source must yield the matching replacement;
    otherwise no patches may be produced.
    """
    reported_rules = [
        rule for rule in rules_in_lint_run if rule not in rules_without_report
    ]
    reports = [
        CstLintRuleReport(
            file_path=FILE_PATH,
            node=cst.EmptyLine(),
            code=rule.__name__,
            message="message",
            line=suppressed_line,
            column=0,
            module=cst.MetadataWrapper(cst.parse_module(source)),
            module_bytes=source,
        )
        for rule in reported_rules
    ]

    tokens = _get_tokens(source)
    ignore_info = IgnoreInfo.compute(
        comment_info=CommentInfo.compute(tokens=tokens),
        line_mapping_info=LineMappingInfo.compute(tokens=tokens),
    )
    cst_wrapper = MetadataWrapper(cst.parse_module(source), unsafe_skip_copy=True)
    rule_settings = {
        "ignore_info": ignore_info,
        "rules": rules_in_lint_run,
    }
    config = LintConfig(
        rule_config={RemoveUnusedSuppressionsRule.__name__: rule_settings})
    context = CstContext(cst_wrapper, source, FILE_PATH, config)

    # Feed each fake report through the ignore logic before running the rule.
    for report in reports:
        ignore_info.should_ignore_report(report)

    _visit_cst_rules_with_context(cst_wrapper, [RemoveUnusedSuppressionsRule], context)

    emitted = list(context.reports)
    messages = [report.message for report in emitted]
    patches = [report.patch for report in emitted]

    self.assertEqual(messages, expected_unused_suppressions_report_messages)
    if expected_replacements is None:
        self.assertEqual(len(patches), 0)
        return
    self.assertEqual(len(patches), len(expected_replacements))
    for patch, expected in zip(patches, expected_replacements):
        self.assertEqual(patch.apply(source.decode()), expected)
def handle_any_string(
        self, node: Union[cst.SimpleString, cst.ConcatenatedString]) -> None:
    """Parse the string's evaluated value as code and record the names it mentions.

    Two kinds of names are recorded in ``self.names``:
    - bare ``Name`` nodes whose parent is not an ``Attribute``;
    - every dotted name that ``_gen_dotted_names`` yields for each ``Attribute``.

    Strings whose ``evaluated_value`` is ``None`` are ignored.
    """
    text = node.evaluated_value
    if text is None:
        return

    parsed = cst.parse_module(text)
    # A Name counts on its own only outside an Attribute. Attributes are
    # captured whole and expanded into dotted names below.
    standalone_name = m.Name(
        value=m.SaveMatchedNode(m.DoNotCare(), "name"),
        metadata=m.MatchMetadataIfTrue(
            cst.metadata.ParentNodeProvider,
            lambda parent: not isinstance(parent, cst.Attribute),
        ),
    )
    whole_attribute = m.SaveMatchedNode(m.Attribute(), "attribute")
    captures = m.extractall(
        parsed,
        standalone_name | whole_attribute,
        metadata_resolver=MetadataWrapper(parsed, unsafe_skip_copy=True),
    )

    found = set()
    for capture in captures:
        if "name" in capture:
            found.add(cast(str, capture["name"]))
    for capture in captures:
        if "attribute" in capture:
            attribute = cast(cst.Attribute, capture["attribute"])
            for dotted, _ in cst.metadata.scope_provider._gen_dotted_names(attribute):
                found.add(dotted)
    self.names.update(found)
def get_scope_metadata_provider(
    module_str: str,
) -> Tuple[cst.Module, Mapping[cst.CSTNode, Scope]]:
    """Parse dedented ``module_str`` and resolve ``ScopeProvider`` on it.

    Returns:
        The wrapper's module and the resolved node-to-scope mapping.
    """
    wrapper = MetadataWrapper(cst.parse_module(dedent(module_str)))
    module = wrapper.module
    # We're sure every node has an associated scope, hence the cast.
    scopes = cast(Mapping[cst.CSTNode, Scope], wrapper.resolve(ScopeProvider))
    return module, scopes
def test_function_position(self) -> None:
    """Check position metadata for a function body statement.

    The statement line and its ``pass`` must both span line 2, columns 4-8.
    This requires the body to be indented by four spaces.
    """
    # Four-space indent so that `pass` starts at column 4, as asserted below.
    wrapper = MetadataWrapper(parse_module("def foo():\n    pass"))
    module = wrapper.module
    positions = wrapper.resolve(PositionProvider)
    fn = cast(cst.FunctionDef, module.body[0])
    stmt = cast(cst.SimpleStatementLine, fn.body.body[0])
    pass_stmt = cast(cst.Pass, stmt.body[0])
    self.cmp_position(positions[stmt], (2, 4), (2, 8))
    self.cmp_position(positions[pass_stmt], (2, 4), (2, 8))
def assert_annotation_count_equal(
    self,
    file_content: str,
    expected_counts: Dict[str, int],
) -> None:
    """Count annotations in dedented, stripped ``file_content`` and compare.

    The counts are compared with ``expected_counts`` via ``build_json()``.
    """
    source = textwrap.dedent(file_content).strip()
    wrapped = {"test.py": MetadataWrapper(cst.parse_module(source))}
    counts_by_path = _path_wise_counts(wrapped, AnnotationCountCollector)
    self.assertEqual(expected_counts, counts_by_path["test.py"].build_json())
def test_equal_range(self) -> None:
    """The `==` and `!=` operators in `var <op> 1` both span line 1, columns 4-6."""
    case = self
    expected_range = CodeRange((1, 4), (1, 6))

    class EqualPositionVisitor(CSTVisitor):
        METADATA_DEPENDENCIES = (PositionProvider,)

        def _check_range(self, node: cst.CSTNode) -> None:
            case.assertEqual(self.get_metadata(PositionProvider, node), expected_range)

        def visit_Equal(self, node: cst.Equal) -> None:
            self._check_range(node)

        def visit_NotEqual(self, node: cst.NotEqual) -> None:
            self._check_range(node)

    for code in ("var == 1", "var != 1"):
        MetadataWrapper(parse_module(code)).visit(EqualPositionVisitor())
def test_with_as(self) -> None:
    """In `with a() as b`, `a` is loaded and the target `b` is stored."""
    wrapper = MetadataWrapper(parse_module("with a() as b:\n pass"))
    expected_names = {
        "a": ExpressionContext.LOAD,
        "b": ExpressionContext.STORE,
    }
    wrapper.visit(DependentVisitor(test=self, name_to_context=expected_names))
def test_simple_assign(self) -> None:
    """In `a = b`, the target `a` is stored and the value `b` is loaded."""
    wrapper = MetadataWrapper(parse_module("a = b"))
    expected_names = {
        "a": ExpressionContext.STORE,
        "b": ExpressionContext.LOAD,
    }
    wrapper.visit(DependentVisitor(test=self, name_to_context=expected_names))
def get_formatted_reports_for_path(
    path: Path,
    opts: InsertSuppressionsOpts,
    metadata_cache: Optional[Mapping["ProviderT", object]] = None,
) -> Iterable[str]:
    """Lint ``path`` with ``opts.rule`` and insert a suppression comment per report.

    The file is read as bytes and linted with the single rule ``opts.rule``.
    If parsing fails (``SyntaxError`` / ``ParserSyntaxError``), the error is
    printed and ``[]`` is returned without touching the file.

    Otherwise one ``SuppressionComment`` is built per report and inserted into
    the source. If that changes the source, the result is optionally run
    through the lint config's formatter (skipped when
    ``opts.skip_autoformatter`` is set) and written back to ``path``.

    Returns:
        The reports rendered with ``opts.formatter.format``.
    """
    with open(path, "rb") as f:
        source = f.read()

    try:
        cst_wrapper = None
        if metadata_cache is not None:
            # NOTE(review): `True` / `metadata_cache` are presumably
            # MetadataWrapper's `unsafe_skip_copy` and `cache` parameters — confirm.
            cst_wrapper = MetadataWrapper(
                parse_module(source),
                True,
                metadata_cache,
            )
        raw_reports = lint_file(
            path, source, rules={opts.rule}, cst_wrapper=cst_wrapper
        )
    except (SyntaxError, ParserSyntaxError) as e:
        print_red(
            f"Encountered the following error while parsing source code in file {path}:"
        )
        print(e)
        return []

    opts_message = opts.message
    comments = []
    for rr in raw_reports:
        # NOTE(review): the `str` check runs first. If MessageKind subclasses str,
        # its members would take this branch — confirm MessageKind's definition.
        if isinstance(opts_message, str):
            message = opts_message
        elif opts_message == MessageKind.USE_LINT_REPORT:
            message = rr.message
        else:  # opts_message == MessageKind.NO_MESSAGE
            message = None
        comments.append(
            SuppressionComment(opts.kind, rr.line, rr.code, message, opts.max_lines)
        )
    insert_suppressions_result = insert_suppressions(source, comments)
    updated_source = insert_suppressions_result.updated_source
    # Invariant check only: `assert` statements are stripped under `python -O`.
    assert (
        not insert_suppressions_result.failed_insertions
    ), "Failed to insert some comments. This should not be possible."

    # Only touch the file when a comment was actually inserted.
    if updated_source != source:
        if not opts.skip_autoformatter:
            # Format the code using the config file's formatter.
            updated_source = invoke_formatter(
                get_lint_config().formatter, updated_source
            )
        with open(path, "wb") as f:
            f.write(updated_source)

    # linter completed successfully
    return [opts.formatter.format(rr) for rr in raw_reports]
def test_starred_element_with_assign(self) -> None:
    """Check contexts in `*a = b`.

    The starred element is the stored target, while both names (including the
    `a` inside the star) are expected to be loaded.
    """
    wrapper = MetadataWrapper(parse_module("*a = b"))
    load = ExpressionContext.LOAD
    wrapper.visit(
        DependentVisitor(
            test=self,
            name_to_context={"a": load, "b": load},
            starred_element_to_context={"a": ExpressionContext.STORE},
        ))
def test_del_with_subscript(self) -> None:
    """Check contexts in `del a[b]`.

    Both names are loaded and the subscript (keyed "a" in DependentVisitor's
    expectations) is deleted.
    """
    wrapper = MetadataWrapper(parse_module("del a[b]"))
    load = ExpressionContext.LOAD
    wrapper.visit(
        DependentVisitor(
            test=self,
            name_to_context={"a": load, "b": load},
            subscript_to_context={"a": ExpressionContext.DEL},
        ))
def test_except_as(self) -> None:
    """In `except Exception as ex`, the type is loaded and `ex` is stored."""
    source = "try: ...\nexcept Exception as ex:\n pass"
    wrapper = MetadataWrapper(parse_module(source))
    expected_names = {
        "Exception": ExpressionContext.LOAD,
        "ex": ExpressionContext.STORE,
    }
    wrapper.visit(DependentVisitor(test=self, name_to_context=expected_names))
def get_qualified_name_metadata_provider(
    module_str: str
) -> Tuple[cst.Module, Mapping[cst.CSTNode, Collection[QualifiedName]]]:
    """Parse dedented ``module_str`` and resolve ``QualifiedNameProvider`` on it.

    Returns:
        The wrapper's module and the resolved node-to-qualified-names mapping.
    """
    wrapper = MetadataWrapper(cst.parse_module(dedent(module_str)))
    return (
        wrapper.module,
        # The cast only narrows the static type of resolve()'s result to the
        # declared return type; the runtime value is unchanged.
        cast(
            Mapping[cst.CSTNode, Collection[QualifiedName]],
            wrapper.resolve(QualifiedNameProvider),
        ),
    )
def test_del_with_tuple(self) -> None:
    """In `del a, b`, both names and the enclosing tuple have the DEL context."""
    wrapper = MetadataWrapper(parse_module("del a, b"))
    delete = ExpressionContext.DEL
    wrapper.visit(
        DependentVisitor(
            test=self,
            name_to_context={"a": delete, "b": delete},
            tuple_to_context={("a", "b"): delete},
        ))