def test_assign_with_subscript(self) -> None:
    """Check expression contexts around subscripts.

    Two modules are visited in order: a subscript assignment
    (``a[b] = c[d]``) and a loaded subscript of an attribute using a slice
    plus an index (``x.y[start:end, idx]``).
    """
    cases = (
        (
            "a[b] = c[d]",
            dict(
                name_to_context={
                    "a": ExpressionContext.STORE,
                    "b": ExpressionContext.LOAD,
                    "c": ExpressionContext.LOAD,
                    "d": ExpressionContext.LOAD,
                },
                subscript_to_context={
                    "a[b]": ExpressionContext.STORE,
                    "c[d]": ExpressionContext.LOAD,
                },
            ),
        ),
        (
            "x.y[start:end, idx]",
            dict(
                name_to_context={
                    "x": ExpressionContext.LOAD,
                    # The attribute name itself carries no context here.
                    "y": None,
                    "start": ExpressionContext.LOAD,
                    "end": ExpressionContext.LOAD,
                    "idx": ExpressionContext.LOAD,
                },
                subscript_to_context={
                    "x.y[start:end, idx]": ExpressionContext.LOAD
                },
                attribute_to_context={"x.y": ExpressionContext.LOAD},
            ),
        ),
    )
    for source, expectations in cases:
        MetadataWrapper(parse_module(source)).visit(
            DependentVisitor(test=self, **expectations)
        )
def test_accesses(self) -> None:
    """Check that reads of ``foo`` resolve to the assignment in the right scope.

    The module-level ``foo`` must have exactly one assignment, read by the
    arguments of ``fn1`` and ``fn2``. The ``foo`` assigned inside ``fn_def``
    lives in a ``FunctionScope`` and must be read only by ``fn3``'s argument.
    Two import-containing modules are then visited with a default-constructed
    ``DependentVisitor``.
    """

    def first_call_arg(statement):
        # Unwrap `SimpleStatementLine -> Expr -> Call` and return the first Arg.
        line = ensure_type(statement, cst.SimpleStatementLine)
        expression = ensure_type(line.body[0], cst.Expr)
        return ensure_type(expression.value, cst.Call).args[0]

    module, scopes = get_scope_metadata_provider(
        """
        foo = 'toplevel'
        fn1(foo)
        fn2(foo)
        def fn_def():
            foo = 'shadow'
            fn3(foo)
        """
    )

    # Module scope: one assignment to `foo`, referenced by fn1 and fn2.
    global_scope = scopes[module]
    self.assertIsInstance(global_scope, GlobalScope)
    toplevel_assignments = list(global_scope["foo"])
    self.assertEqual(len(toplevel_assignments), 1)
    toplevel_foo = toplevel_assignments[0]
    self.assertEqual(len(toplevel_foo.references), 2)
    fn1_arg = first_call_arg(module.body[1])
    fn2_arg = first_call_arg(module.body[2])
    self.assertEqual(
        {access.node for access in toplevel_foo.references},
        {fn1_arg.value, fn2_arg.value},
    )

    # Function scope: the shadowing assignment is referenced only by fn3.
    fn_body = ensure_type(module.body[3], cst.FunctionDef).body
    function_scope = scopes[fn_body.body[0]]
    self.assertIsInstance(function_scope, FunctionScope)
    shadow_assignments = function_scope["foo"]
    self.assertEqual(len(shadow_assignments), 1)
    shadow_foo = list(shadow_assignments)[0]
    self.assertEqual(len(shadow_foo.references), 1)
    fn3_arg = first_call_arg(fn_body.body[1])
    self.assertEqual(
        {access.node for access in shadow_foo.references}, {fn3_arg.value}
    )

    # Import statements at module level and nested inside a function.
    for import_source in ("from a import b\n", "def a():\n from b import c\n\n"):
        MetadataWrapper(cst.parse_module(import_source)).visit(DependentVisitor())
def test_with_as(self) -> None:
    """``with a() as b``: the name ``a`` is LOAD and the ``as`` target ``b`` is STORE."""
    expected_names = {
        "a": ExpressionContext.LOAD,
        "b": ExpressionContext.STORE,
    }
    visitor_wrapper = MetadataWrapper(parse_module("with a() as b:\n pass"))
    visitor_wrapper.visit(
        DependentVisitor(test=self, name_to_context=expected_names)
    )
def test_simple_assign(self) -> None:
    """``a = b``: the target ``a`` is STORE and the value ``b`` is LOAD."""
    module_wrapper = MetadataWrapper(parse_module("a = b"))
    checker = DependentVisitor(
        test=self,
        name_to_context=dict(
            a=ExpressionContext.STORE,
            b=ExpressionContext.LOAD,
        ),
    )
    module_wrapper.visit(checker)
def test_for(self) -> None:
    """``for`` loop: the loop variable and body target are STORE, the iterable LOAD."""
    module_wrapper = MetadataWrapper(parse_module("for i in items:\n j = 1"))
    expected_names = {
        "i": ExpressionContext.STORE,      # loop variable
        "items": ExpressionContext.LOAD,   # iterable being read
        "j": ExpressionContext.STORE,      # assignment inside the body
    }
    module_wrapper.visit(
        DependentVisitor(test=self, name_to_context=expected_names)
    )
def test_except_as(self) -> None:
    """``except Exception as ex``: the exception type is LOAD, the bound name STORE."""
    source = "try: ...\nexcept Exception as ex:\n pass"
    module_wrapper = MetadataWrapper(parse_module(source))
    expected_names = {
        "Exception": ExpressionContext.LOAD,
        "ex": ExpressionContext.STORE,
    }
    module_wrapper.visit(
        DependentVisitor(test=self, name_to_context=expected_names)
    )
def test_starred_element_with_assign(self) -> None:
    """``*a = b``: the starred target is STORE while names ``a`` and ``b`` are LOAD."""
    module_wrapper = MetadataWrapper(parse_module("*a = b"))
    expected_names = dict(a=ExpressionContext.LOAD, b=ExpressionContext.LOAD)
    expected_starred = {"a": ExpressionContext.STORE}
    module_wrapper.visit(
        DependentVisitor(
            test=self,
            name_to_context=expected_names,
            starred_element_to_context=expected_starred,
        )
    )
def test_del_with_tuple(self) -> None:
    """``del a, b``: both names and the enclosing tuple are DEL."""
    module_wrapper = MetadataWrapper(parse_module("del a, b"))
    deleted = ExpressionContext.DEL
    module_wrapper.visit(
        DependentVisitor(
            test=self,
            name_to_context={"a": deleted, "b": deleted},
            tuple_to_context={("a", "b"): deleted},
        )
    )
def test_line_count(self) -> None:
    """Check ``AnnotationCollector.line_count`` with and without a trailing newline."""
    expectations = (
        ("# No trailing newline", 1),
        ("# With trailing newline\n", 2),
    )
    for source, expected_line_count in expectations:
        module_wrapper = MetadataWrapper(cst.parse_module(source))
        annotation_collector = AnnotationCollector()
        module_wrapper.visit(annotation_collector)
        self.assertEqual(annotation_collector.line_count, expected_line_count)
def test_del_with_subscript(self) -> None:
    """``del a[b]``: the subscript is DEL while the names inside it are LOAD.

    The subscript expectation is keyed by ``"a"`` (presumably the subscripted
    value, as matched by ``DependentVisitor`` — confirm there).
    """
    module_wrapper = MetadataWrapper(parse_module("del a[b]"))
    loaded = ExpressionContext.LOAD
    module_wrapper.visit(
        DependentVisitor(
            test=self,
            name_to_context={"a": loaded, "b": loaded},
            subscript_to_context={"a": ExpressionContext.DEL},
        )
    )
def test_expressions_with_assign(self) -> None:
    """``f(a)[b] = c``: the subscript target is STORE; every name involved is LOAD."""
    module_wrapper = MetadataWrapper(parse_module("f(a)[b] = c"))
    loaded = ExpressionContext.LOAD
    expected_names = {name: loaded for name in ("a", "b", "c", "f")}
    module_wrapper.visit(
        DependentVisitor(
            test=self,
            name_to_context=expected_names,
            subscript_to_context={"f(a)[b]": ExpressionContext.STORE},
        )
    )
def test_list_with_assing(self) -> None:
    """``[a] = [b]``: the left list and ``a`` are STORE; the right list and ``b`` are LOAD."""
    # NOTE(review): "assing" is a typo for "assign"; the name is kept because
    # renaming would change the test id.
    module_wrapper = MetadataWrapper(parse_module("[a] = [b]"))
    stored, loaded = ExpressionContext.STORE, ExpressionContext.LOAD
    module_wrapper.visit(
        DependentVisitor(
            test=self,
            name_to_context={"a": stored, "b": loaded},
            list_to_context={("a",): stored, ("b",): loaded},
        )
    )
def test_function(self) -> None:
    """Function definition: the function and parameter names are STORE.

    The annotation, default value and return annotation names are LOAD.
    """
    source = "def foo(x: int = y) -> None: pass"
    module_wrapper = MetadataWrapper(parse_module(source))
    stored, loaded = ExpressionContext.STORE, ExpressionContext.LOAD
    expected_names = {
        "foo": stored,
        "x": stored,
        "int": loaded,
        "y": loaded,
        "None": loaded,
    }
    module_wrapper.visit(
        DependentVisitor(test=self, name_to_context=expected_names)
    )
def test_assign_with_subscript(self) -> None:
    """``a[b] = c[d]``: all names are LOAD; the target subscript is STORE.

    Subscript expectations are keyed by ``"a"`` / ``"c"`` (presumably the
    subscripted value as matched by ``DependentVisitor`` — confirm there);
    the one on ``c`` is LOAD.
    """
    module_wrapper = MetadataWrapper(parse_module("a[b] = c[d]"))
    loaded = ExpressionContext.LOAD
    module_wrapper.visit(
        DependentVisitor(
            test=self,
            name_to_context=dict.fromkeys(("a", "b", "c", "d"), loaded),
            subscript_to_context={
                "a": ExpressionContext.STORE,
                "c": loaded,
            },
        )
    )
def test_assign_to_attribute(self) -> None:
    """``a.b = c.d``: the stored attribute and its ``b`` name are STORE; the rest LOAD."""
    module_wrapper = MetadataWrapper(parse_module("a.b = c.d"))
    stored, loaded = ExpressionContext.STORE, ExpressionContext.LOAD
    module_wrapper.visit(
        DependentVisitor(
            test=self,
            name_to_context={"a": loaded, "b": stored, "c": loaded, "d": loaded},
            # Keyed by "b" / "d" (presumably the attribute name — confirm in
            # DependentVisitor).
            attribute_to_context={"b": stored, "d": loaded},
        )
    )
def test_nested_tuple_with_assign(self) -> None:
    """Nested tuple unpacking: target names and target tuples are STORE.

    The literal tuples on the right-hand side are LOAD.
    """
    module_wrapper = MetadataWrapper(parse_module("((a, b), c) = ((1, 2), 3)"))
    stored, loaded = ExpressionContext.STORE, ExpressionContext.LOAD
    module_wrapper.visit(
        DependentVisitor(
            test=self,
            name_to_context=dict.fromkeys(("a", "b", "c"), stored),
            tuple_to_context={
                "(a, b)": stored,
                "((a, b), c)": stored,
                "(1, 2)": loaded,
                "((1, 2), 3)": loaded,
            },
        )
    )
def test_class(self) -> None:
    """Class definition: class name and class-body target are STORE.

    The base class and the assigned value are LOAD.
    """
    source = dedent(
        """
        class Foo(Bar):
            x = y
        """
    )
    module_wrapper = MetadataWrapper(parse_module(source))
    stored, loaded = ExpressionContext.STORE, ExpressionContext.LOAD
    module_wrapper.visit(
        DependentVisitor(
            test=self,
            name_to_context={"Foo": stored, "Bar": loaded, "x": stored, "y": loaded},
        )
    )
def test_visitor_provider(self) -> None:
    """Check that a visitor depending on ``PositionProvider`` can read positions.

    The visitor declares ``PositionProvider`` in ``METADATA_DEPENDENCIES`` and,
    on the single ``Pass`` node of the module ``"pass"``, asserts that
    ``get_metadata`` returns ``CodeRange((1, 0), (1, 4))``. The test also
    asserts that ``visit_Pass`` actually ran, so it cannot pass vacuously.
    """
    test = self
    # Every Pass node the visitor reaches; guards against a vacuous pass.
    visited_passes = []

    class PositionCheckingVisitor(CSTTransformer):
        METADATA_DEPENDENCIES = (PositionProvider,)

        def visit_Pass(self, node: cst.Pass) -> None:
            visited_passes.append(node)
            test.assertEqual(
                self.get_metadata(PositionProvider, node),
                CodeRange((1, 0), (1, 4)),
            )

    wrapper = MetadataWrapper(parse_module("pass"))
    wrapper.visit(PositionCheckingVisitor())
    self.assertEqual(len(visited_passes), 1)
def test_walrus(self) -> None:
    """``if x := y``: the walrus target ``x`` is STORE and ``y`` is LOAD (Python 3.8 grammar)."""
    source = dedent(
        """
        if x := y:
            pass
        """
    )
    parser_config = cst.PartialParserConfig(python_version="3.8")
    module_wrapper = MetadataWrapper(parse_module(source, config=parser_config))
    expected_names = dict(x=ExpressionContext.STORE, y=ExpressionContext.LOAD)
    module_wrapper.visit(
        DependentVisitor(test=self, name_to_context=expected_names)
    )
def test_nested_list_with_assign(self) -> None:
    """Nested list unpacking: left-hand names and lists are STORE; right-hand ones LOAD."""
    module_wrapper = MetadataWrapper(parse_module("[[a, b], c] = [[d, e], f]"))
    stored, loaded = ExpressionContext.STORE, ExpressionContext.LOAD
    expected_names = {
        **dict.fromkeys(("a", "b", "c"), stored),
        **dict.fromkeys(("d", "e", "f"), loaded),
    }
    expected_lists = {
        "[a, b]": stored,
        "[[a, b], c]": stored,
        "[d, e]": loaded,
        "[[d, e], f]": loaded,
    }
    module_wrapper.visit(
        DependentVisitor(
            test=self,
            name_to_context=expected_names,
            list_to_context=expected_lists,
        )
    )
def refactor_string(source, unused_imports):
    """Run ``RemoveUnusedImportTransformer`` over *source* and return the new code.

    If ``cst.parse_module`` raises ``ParserSyntaxError``, the error is printed
    in red and *source* is returned unchanged. *source* is also returned
    unchanged when *unused_imports* is falsy.
    """
    try:
        wrapper = MetadataWrapper(cst.parse_module(source))
    except cst.ParserSyntaxError as err:
        print(Color(str(err)).red)
        return source
    if not unused_imports:
        return source
    transformer = RemoveUnusedImportTransformer(unused_imports)
    return wrapper.visit(transformer).code
def refactor_string(
    source: str,
    unused_imports: "List[TYPE_IMPORT]",
    show_error: bool,
) -> str:
    """Run ``RemoveUnusedImportTransformer`` over *source* and return the new code.

    If ``cst.parse_module`` raises ``ParserSyntaxError``, the error is printed
    in red only when *show_error* is true, and *source* is returned unchanged.
    *source* is also returned unchanged when *unused_imports* is empty/falsy.
    """
    try:
        wrapper = MetadataWrapper(cst.parse_module(source))
    except cst.ParserSyntaxError as err:
        if show_error:
            print(Color(str(err)).red)
        return source
    if not unused_imports:
        return source
    transformer = RemoveUnusedImportTransformer(unused_imports)
    return wrapper.visit(transformer).code
def assert_counts(self, source: str, expected: Dict[str, int]) -> None:
    """Assert that ``AnnotationCountCollector.build_json()`` equals *expected*.

    *source* is first passed through ``self.format_files`` and the result is
    wrapped in a ``MetadataWrapper`` before the collector visits it.
    """
    formatted_module = self.format_files(source)
    metadata_wrapper = MetadataWrapper(formatted_module)
    counter = AnnotationCountCollector()
    metadata_wrapper.visit(counter)
    actual_counts = counter.build_json()
    self.assertEqual(actual_counts, expected)
def _build_and_visit_annotation_collector(
    self, source: str
) -> AnnotationCollector:
    """Parse *source* with ``parse_source``, visit it with a fresh
    ``AnnotationCollector`` and return that collector."""
    parsed = parse_source(source)
    metadata_wrapper = MetadataWrapper(parsed)
    annotation_collector = AnnotationCollector()
    metadata_wrapper.visit(annotation_collector)
    return annotation_collector
def test_invalid_type_for_context(self) -> None:
    """Visit ``a()`` and expect the callee name ``a`` to be LOAD."""
    module_wrapper = MetadataWrapper(parse_module("a()"))
    checker = DependentVisitor(
        test=self, name_to_context=dict(a=ExpressionContext.LOAD)
    )
    module_wrapper.visit(checker)
def test_parent_node_provier(self, code: str) -> None:
    """Dedent and parse *code*, then visit it with ``DependentVisitor``.

    The ``code`` argument is presumably supplied by a parametrizing decorator
    outside this view — confirm at the definition site.
    """
    # NOTE(review): "provier" is a typo for "provider"; kept to preserve the
    # test id.
    module = cst.parse_module(dedent(code))
    MetadataWrapper(module).visit(DependentVisitor(test=self))
def test_simple_load(self) -> None:
    """A bare name expression ``a`` is LOAD."""
    module_wrapper = MetadataWrapper(parse_module("a"))
    expected_names = {"a": ExpressionContext.LOAD}
    module_wrapper.visit(
        DependentVisitor(test=self, name_to_context=expected_names)
    )