def get(
    wrapper: MetadataWrapper,
    original_node: cst.CSTNode,
    replacement_node: Union[cst.CSTNode, cst.RemovalSentinel],
) -> "LintPatch":
    """Build a ``LintPatch`` that replaces or removes ``original_node``.

    The reentrant codegen provider can only rewrite entire statements, so this
    walks up the parent chain from the affected node until it reaches a node
    the provider has a partial for, and patches only that statement. If the
    walk runs out of parents, the whole module is patched instead.

    Args:
        wrapper: Metadata wrapper for the module that contains
            ``original_node``.
        original_node: The node being replaced or removed.
        replacement_node: The new node, or a ``cst.RemovalSentinel`` to delete
            ``original_node``.

    Returns:
        A ``LintPatch`` built from the statement's start offset, its
        whitespace-inclusive start position, and its original and patched
        source. When falling back to the whole module, the offset is 0 and
        the position is ``CodePosition(1, 0)``.

    Raises:
        ValueError: If asked to remove the module itself. ``ValueError`` is a
            subclass of ``Exception``, so existing ``except Exception``
            handlers still catch it.
    """
    # Reject the impossible request before doing any metadata work.
    if isinstance(original_node, cst.Module) and isinstance(
        replacement_node, cst.RemovalSentinel
    ):
        raise ValueError("Removing the entire module is not possible")

    # Resolve the parent, codegen and position providers together in one batch.
    wrapper.resolve_many(
        [
            ParentNodeProvider,
            ExperimentalReentrantCodegenProvider,
            WhitespaceInclusivePositionProvider,
        ]
    )
    # Fetch the data through resolve(), which is typed better than
    # resolve_many() is.
    parents = wrapper.resolve(ParentNodeProvider)
    positions = wrapper.resolve(WhitespaceInclusivePositionProvider)
    codegen_partials = wrapper.resolve(ExperimentalReentrantCodegenProvider)

    candidate = original_node
    if isinstance(replacement_node, cst.RemovalSentinel):
        # Reentrant codegen doesn't support RemovalSentinel, so start from the
        # parent instead (the module-removal case was rejected above).
        candidate = parents[candidate]

    # Climb until we reach a statement the codegen provider has a partial for.
    while candidate not in codegen_partials:
        if candidate not in parents:
            # There are no more parents, so fall back to replacing the whole
            # module.
            original_str = wrapper.module.code
            patched_module = cst.ensure_type(
                _replace_or_remove(wrapper.module, original_node, replacement_node),
                cst.Module,
            )
            return LintPatch(0, CodePosition(1, 0), original_str, patched_module.code)
        candidate = parents[candidate]

    partial = codegen_partials[candidate]
    patched_statement = cst.ensure_type(
        _replace_or_remove(candidate, original_node, replacement_node),
        cst.BaseStatement,
    )
    original_str = partial.get_original_statement_code()
    patched_str = partial.get_modified_statement_code(patched_statement)
    return LintPatch(
        partial.start_offset,
        positions[candidate].start,
        original_str,
        patched_str,
    )
def test_metadata_cache(self) -> None:
    """A provider with ``gen_cache`` fails without a cache and reads it when given one."""

    class DummyMetadataProvider(BatchableMetadataProvider[None]):
        gen_cache = tuple

    module = cst.parse_module("pass")
    wrapper = MetadataWrapper(module)
    # No cache entry was supplied for the provider, so resolving must fail.
    with self.assertRaisesRegex(
        Exception, "Cache is required for initializing DummyMetadataProvider."
    ):
        wrapper.resolve(DummyMetadataProvider)

    class SimpleCacheMetadataProvider(BatchableMetadataProvider[object]):
        gen_cache = tuple

        def __init__(self, cache: object) -> None:
            super().__init__(cache)
            self.cache = cache

        def visit_Pass(self, node: cst.Pass) -> Optional[bool]:
            # Attach whatever cache object the provider was built with.
            self.set_metadata(node, self.cache)

    sentinel = object()
    # pyre-fixme[6]: Expected `Mapping[Type[BaseMetadataProvider[object]],
    # object]` for 2nd param but got `Dict[Type[SimpleCacheMetadataProvider],
    # object]`.
    wrapper = MetadataWrapper(module, cache={SimpleCacheMetadataProvider: sentinel})
    statement = cst.ensure_type(wrapper.module.body[0], cst.SimpleStatementLine)
    pass_node = statement.body[0]
    resolved = wrapper.resolve(SimpleCacheMetadataProvider)
    self.assertEqual(resolved[pass_node], sentinel)
def get_scope_metadata_provider(
    module_str: str,
) -> Tuple[cst.Module, Mapping[cst.CSTNode, Scope]]:
    """Parse ``module_str`` (after dedenting) and return the module with its scopes.

    Returns:
        The wrapper's module and the ``ScopeProvider`` mapping for it.
    """
    source = dedent(module_str)
    wrapper = MetadataWrapper(cst.parse_module(source))
    module = wrapper.module
    # The cast is unchecked: it asserts that every looked-up node has an
    # associated Scope.
    scopes = cast(Mapping[cst.CSTNode, Scope], wrapper.resolve(ScopeProvider))
    return module, scopes
def test_resolve_provider_twice(self) -> None:
    """Resolving a provider a second time must not visit the tree again."""
    tracker = Mock()

    class ProviderA(VisitorMetadataProvider[bool]):
        def visit_Pass(self, node: cst.Pass) -> None:
            tracker.visited_a()

    wrapper = MetadataWrapper(cst.parse_module("pass"))
    # After both the first and the repeated resolve, exactly one visit is recorded.
    for _ in range(2):
        wrapper.resolve(ProviderA)
        tracker.visited_a.assert_called_once()
def test_function_position(self) -> None:
    """The body statement of ``foo`` and its ``pass`` cover line 2, columns 4-8."""
    # The body is indented by four spaces so that ``pass`` starts at column 4,
    # which is what the assertions below expect.
    wrapper = MetadataWrapper(parse_module("def foo():\n    pass"))
    module = wrapper.module
    positions = wrapper.resolve(PositionProvider)
    fn = cast(cst.FunctionDef, module.body[0])
    stmt = cast(cst.SimpleStatementLine, fn.body.body[0])
    pass_stmt = cast(cst.Pass, stmt.body[0])
    self.cmp_position(positions[stmt], (2, 4), (2, 8))
    self.cmp_position(positions[pass_stmt], (2, 4), (2, 8))
def get_qualified_name_metadata_provider(
    module_str: str
) -> Tuple[cst.Module, Mapping[cst.CSTNode, Collection[QualifiedName]]]:
    """Parse ``module_str`` (after dedenting) and return the module with its qualified names.

    Returns:
        The wrapper's module and the ``QualifiedNameProvider`` mapping for it.
    """
    wrapper = MetadataWrapper(cst.parse_module(dedent(module_str)))
    module = wrapper.module
    # Unchecked cast to the concrete value type expected by callers.
    qualified_names = cast(
        Mapping[cst.CSTNode, Collection[QualifiedName]],
        wrapper.resolve(QualifiedNameProvider),
    )
    return module, qualified_names
def test_multiline_string_position(self) -> None:
    """Two string literals joined by a backslash continuation span (1, 0)-(2, 5)."""
    wrapper = MetadataWrapper(parse_module('"abc"\\\n"def"'))
    module = wrapper.module
    positions = wrapper.resolve(PositionProvider)
    line = cast(cst.SimpleStatementLine, module.body[0])
    expression = cast(cst.Expr, line.body[0])
    # The statement, the expression and its string value all share one span.
    for node in (line, expression, expression.value):
        self.cmp_position(positions[node], (1, 0), (2, 5))
def test_byte_conversion(self) -> None:
    """Codegen partials report module bytes in the module's UTF-16 encoding."""
    expected_original = "fn()\n".encode("utf-16")
    config = cst.PartialParserConfig(encoding="utf-16")
    wrapper = MetadataWrapper(cst.parse_module("fn()\n", config))
    partials = wrapper.resolve(ExperimentalReentrantCodegenProvider)
    partial = partials[wrapper.module.body[0]]
    self.assertEqual(partial.get_original_module_bytes(), expected_original)
    replacement = cst.parse_statement("fn2()\n")
    self.assertEqual(
        partial.get_modified_module_bytes(replacement),
        "fn2()\n".encode("utf-16"),
    )
def test_provider(
    self,
    old_module: str,
    new_module: str,
    old_node: Callable[[cst.Module], cst.CSTNode],
    new_node: cst.BaseStatement,
) -> None:
    """Swapping the selected statement for ``new_node`` yields ``new_module``.

    ``old_node`` picks the statement to rewrite out of the parsed module. Both
    source snippets are dedented before use.
    """
    source = dedent(old_module)
    expected = dedent(new_module)
    wrapper = MetadataWrapper(cst.parse_module(source))
    partials = wrapper.resolve(ExperimentalReentrantCodegenProvider)
    partial = partials[old_node(wrapper.module)]
    self.assertEqual(partial.get_original_module_code(), source)
    self.assertEqual(partial.get_modified_module_code(new_node), expected)
def test_nested_indent_position(self) -> None:
    """Positions of nested if/else blocks account for each level of indentation."""
    # Each nesting level is indented by four spaces. This is required for the
    # snippet to parse at all, and it yields the columns (4 and 8) asserted
    # below.
    wrapper = MetadataWrapper(
        parse_module("if True:\n    if False:\n        x = 1\nelse:\n    return")
    )
    module = wrapper.module
    positions = wrapper.resolve(PositionProvider)
    outer_if = cast(cst.If, module.body[0])
    inner_if = cast(cst.If, outer_if.body.body[0])
    assign = cast(cst.SimpleStatementLine, inner_if.body.body[0]).body[0]
    outer_else = cast(cst.Else, outer_if.orelse)
    return_stmt = cast(cst.SimpleStatementLine, outer_else.body.body[0]).body[0]
    self.cmp_position(positions[outer_if], (1, 0), (5, 10))
    self.cmp_position(positions[inner_if], (2, 4), (3, 13))
    self.cmp_position(positions[assign], (3, 8), (3, 13))
    self.cmp_position(positions[outer_else], (4, 0), (5, 10))
    self.cmp_position(positions[return_stmt], (5, 4), (5, 10))
def test_module_position(self, *, code: str, expected: CodeRange) -> None:
    """The module node as a whole occupies exactly ``expected``."""
    wrapper = MetadataWrapper(parse_module(code))
    module_range = wrapper.resolve(PositionProvider)[wrapper.module]
    self.assertEqual(module_range, expected)
def get_qualified_name_metadata_provider(
    module_str: str,
) -> Tuple[cst.Module, Mapping[cst.CSTNode, Collection[QualifiedName]]]:
    """Parse ``module_str`` (after dedenting) and return the module with its qualified names.

    Returns:
        The wrapper's module and the mapping resolved from
        ``QualifiedNameProvider``, returned without a cast.
    """
    source = dedent(module_str)
    wrapper = MetadataWrapper(cst.parse_module(source))
    module = wrapper.module
    qualified_names = wrapper.resolve(QualifiedNameProvider)
    return module, qualified_names