def test_class_scope(self) -> None:
    """Class-body bindings are visible only inside the ClassScope.

    ``global_var`` and ``Cls`` must be found in the module, class and method
    scopes. ``cls_attr`` must be found only in the class scope, even though
    the decorator and base-class list reference it.
    """
    m, scopes = get_scope_metadata_provider(
        """
        global_var = None
        @cls_attr
        class Cls(cls_attr, kwarg=cls_attr):
            cls_attr = 5
            def f():
                pass
        """
    )
    module_scope = scopes[m]
    self.assertIsInstance(module_scope, GlobalScope)

    cls_assignments = module_scope["Cls"]
    self.assertEqual(len(cls_assignments), 1)
    class_def = ensure_type(m.body[1], cst.ClassDef)
    self.assertEqual(cast(Assignment, cls_assignments[0]).node, class_def)

    class_statements = class_def.body
    class_scope = scopes[class_statements.body[0]]
    self.assertIsInstance(class_scope, ClassScope)

    method_body = ensure_type(class_statements.body[1], cst.FunctionDef).body
    method_scope = scopes[method_body.body[0]]
    self.assertIsInstance(method_scope, FunctionScope)

    for scope in (module_scope, class_scope, method_scope):
        self.assertIn("global_var", scope)
    for scope in (module_scope, class_scope, method_scope):
        self.assertIn("Cls", scope)

    # The class attribute is bound only in the class body.
    self.assertNotIn("cls_attr", module_scope)
    self.assertIn("cls_attr", class_scope)
    self.assertNotIn("cls_attr", method_scope)
def test_nested_comprehension_scope(self) -> None:
    """Check which scope owns each part of a doubly nested list comprehension.

    The comprehension node, the outer ``for`` clause and its iterable are
    expected in the module scope. The element, both targets and the inner
    ``for`` clause (with its iterable) are expected in the ComprehensionScope.
    """
    m, scopes = get_scope_metadata_provider(
        """
        [y for x in iterator for y in x]
        """
    )
    module_scope = scopes[m]
    self.assertIsInstance(module_scope, GlobalScope)

    statement = ensure_type(m.body[0], cst.SimpleStatementLine)
    expression = ensure_type(statement.body[0], cst.Expr).value
    list_comp = ensure_type(expression, cst.ListComp)

    comp_scope = scopes[list_comp.elt]
    self.assertIsInstance(comp_scope, ComprehensionScope)

    outer_for = list_comp.for_in
    expected_outer = [
        (list_comp, module_scope),
        (list_comp.elt, comp_scope),
        (outer_for, module_scope),
        (outer_for.iter, module_scope),
        (outer_for.target, comp_scope),
    ]
    for node, expected_scope in expected_outer:
        self.assertIs(scopes[node], expected_scope)

    inner_for = ensure_type(outer_for.inner_for_in, cst.CompFor)
    for node in (inner_for, inner_for.iter, inner_for.target):
        self.assertIs(scopes[node], comp_scope)
def test_multiple_assignments(self) -> None:
    """A name imported in both branches of an if/elif resolves to both sources.

    The call ``c()`` and the bare name ``"c"`` must each report both
    ``a.b`` and ``d.e`` as qualified names.
    """
    m, scopes = get_scope_metadata_provider(
        """
        if 1:
            from a import b as c
        elif 2:
            from d import e as c
        c()
        """
    )
    call_line = ensure_type(m.body[1], cst.SimpleStatementLine)
    call = ensure_type(call_line.body[0], cst.Expr).value
    scope = scopes[call]
    self.assertIsInstance(scope, GlobalScope)

    both_imports = {
        QualifiedName(name="a.b", source=QualifiedNameSource.IMPORT),
        QualifiedName(name="d.e", source=QualifiedNameSource.IMPORT),
    }
    for lookup in (call, "c"):
        self.assertEqual(scope.get_qualified_names_for(lookup), both_imports)
def test_with_statement(self) -> None:
    """A module-level import resolves in a ``with`` item and after the ``with``."""
    m, scopes = get_scope_metadata_provider(
        """
        import unittest.mock
        with unittest.mock.patch("something") as obj:
            obj.f1()
        unittest.mock
        """
    )
    import_ = ensure_type(m.body[0], cst.SimpleStatementLine).body[0]
    unittest_assignments = scopes[import_]["unittest"]
    self.assertEqual(len(unittest_assignments), 1)
    only_assignment = cast(Assignment, next(iter(unittest_assignments)))
    self.assertEqual(only_assignment.node, import_)

    with_stmt = ensure_type(m.body[1], cst.With)
    patch_call = with_stmt.items[0].item
    item_scope = scopes[patch_call]
    self.assertEqual(
        item_scope.get_qualified_names_for(patch_call),
        {QualifiedName(name="unittest.mock.patch", source=QualifiedNameSource.IMPORT)},
    )

    # The trailing bare `unittest.mock` expression is looked up through the
    # same scope object as the `with` item (both are module-level here).
    trailing_line = ensure_type(m.body[2], cst.SimpleStatementLine)
    mock_ref = ensure_type(trailing_line.body[0], cst.Expr).value
    self.assertEqual(
        item_scope.get_qualified_names_for(mock_ref),
        {QualifiedName(name="unittest.mock", source=QualifiedNameSource.IMPORT)},
    )
def test_extract_simple(self) -> None:
    """``m.extract`` returns the captured nodes on a match and ``None`` on a miss."""
    expression = cst.parse_expression("a + b[c], d(e, f * g)")

    def tuple_with_left(left_matcher):
        # (<binop whose left operand is captured as "left">, <any call>)
        return m.Tuple(
            elements=[
                m.Element(m.BinaryOperation(left=m.SaveMatchedNode(left_matcher, "left"))),
                m.Element(m.Call()),
            ]
        )

    # Verify true behavior: the left operand `a` is a Name, so it is captured.
    captured = m.extract(expression, tuple_with_left(m.Name()))
    first_element = cst.ensure_type(expression, cst.Tuple).elements[0]
    addition = cst.ensure_type(first_element.value, cst.BinaryOperation)
    self.assertEqual(captured, {"left": addition.left})

    # Verify false behavior: the left operand is not a Subscript, so nothing matches.
    captured = m.extract(expression, tuple_with_left(m.Subscript()))
    self.assertIsNone(captured)
def test_dotted_import_access(self) -> None:
    """Dotted imports record accesses against the longest matching dotted name.

    ``a.b.c(...)`` is an access of ``a.b.c`` (not ``a``). ``x.z`` is an access
    of ``x``, because ``x.z`` was never imported. ``x.y`` is never accessed.
    """
    m, scopes = get_scope_metadata_provider(
        """
        import a.b.c, x.y
        a.b.c(x.z)
        """
    )
    module_scope = scopes[m]
    call_line = ensure_type(m.body[1], cst.SimpleStatementLine)
    call = ensure_type(ensure_type(call_line.body[0], cst.Expr).value, cst.Call)

    self.assertIn("a.b.c", module_scope)
    self.assertIn("a", module_scope)
    self.assertEqual(module_scope.accesses["a"], set())

    abc_assignment = cast(Assignment, next(iter(module_scope["a.b.c"])))
    abc_access = next(iter(abc_assignment.references))
    self.assertEqual(module_scope.accesses["a.b.c"], {abc_access})
    self.assertEqual(abc_access.node, call.func)

    x_assignment = cast(Assignment, next(iter(module_scope["x"])))
    x_access = next(iter(x_assignment.references))
    self.assertEqual(module_scope.accesses["x"], {x_access})
    x_dot_z = ensure_type(call.args[0].value, cst.Attribute)
    self.assertEqual(x_access.node, x_dot_z.value)

    self.assertIn("x.y", module_scope)
    self.assertEqual(next(iter(module_scope["x.y"])).references, set())
    self.assertEqual(module_scope.accesses["x.y"], set())
def test_extract_sequence_element(self) -> None:
    """A SaveMatchedNode placed inside an ``args`` list captures the matched run."""
    expression = cst.parse_expression("a + b[c], d(e, f * g, h.i.j)")

    def tuple_capturing_args(run_matcher):
        # (<anything>, <call whose whole argument run is captured as "args">)
        return m.Tuple(
            elements=[
                m.DoNotCare(),
                m.Element(m.Call(args=[m.SaveMatchedNode(run_matcher, "args")])),
            ]
        )

    # Verify true behavior: ZeroOrMore() absorbs every argument.
    captured = m.extract(expression, tuple_capturing_args(m.ZeroOrMore()))
    second_element = cst.ensure_type(expression, cst.Tuple).elements[1]
    call_args = cst.ensure_type(second_element.value, cst.Call).args
    self.assertEqual(captured, {"args": call_args})

    # Verify false behavior: not every argument is a Subscript.
    captured = m.extract(
        expression, tuple_capturing_args(m.ZeroOrMore(m.Arg(m.Subscript())))
    )
    self.assertIsNone(captured)
def test_local_scope_shadowing_with_functions(self) -> None:
    """Each ``def f`` binds ``f`` in the scope that encloses it.

    The outer ``f`` is assigned in the module scope, and the inner ``f`` is
    assigned in the outer function's scope.
    """
    m, scopes = get_scope_metadata_provider(
        """
        def f():
            def f():
                f = ...
        """
    )
    module_scope = scopes[m]
    self.assertIsInstance(module_scope, GlobalScope)
    self.assertIn("f", module_scope)

    outer_def = ensure_type(m.body[0], cst.FunctionDef)
    outer_scope = scopes[outer_def.body.body[0]]
    self.assertIsInstance(outer_scope, FunctionScope)
    self.assertIn("f", outer_scope)
    outer_binding = cast(Assignment, module_scope["f"][0])
    self.assertEqual(outer_binding.node, outer_def)

    inner_def = ensure_type(outer_def.body.body[0], cst.FunctionDef)
    inner_scope = scopes[inner_def.body.body[0]]
    self.assertIsInstance(inner_scope, FunctionScope)
    self.assertIn("f", inner_scope)
    inner_binding = cast(Assignment, outer_scope["f"][0])
    self.assertEqual(inner_binding.node, inner_def)
def _extract_static_bool(cls, node: cst.BaseExpression) -> Optional[bool]:
    """Statically evaluate the truthiness of ``node`` when it can be known.

    Recognises the names ``True``/``False``, ``not <expr>``, and
    ``and``/``or`` over operands that are themselves statically known.
    Function calls are never evaluated.

    Returns:
        ``True`` or ``False`` when the truth value can be decided without
        running the code; ``None`` otherwise.
    """
    if m.matches(node, m.Call()):
        # cannot reason about function calls
        return None

    if m.matches(node, m.UnaryOperation(operator=m.Not())):
        sub_value = cls._extract_static_bool(
            cst.ensure_type(node, cst.UnaryOperation).expression
        )
        if sub_value is None:
            return None
        return not sub_value

    if m.matches(node, m.Name("True")):
        return True

    if m.matches(node, m.Name("False")):
        return False

    if m.matches(node, m.BooleanOperation()):
        bool_op = cst.ensure_type(node, cst.BooleanOperation)
        left_value = cls._extract_static_bool(bool_op.left)
        right_value = cls._extract_static_bool(bool_op.right)

        if m.matches(bool_op.operator, m.Or()):
            # One statically truthy operand makes the whole `or` truthy.
            if left_value is True or right_value is True:
                return True
            # Both operands statically falsy: the `or` is falsy too.
            if left_value is False and right_value is False:
                return False

        if m.matches(bool_op.operator, m.And()):
            # One statically falsy operand makes the whole `and` falsy.
            if left_value is False or right_value is False:
                return False
            # Both operands statically truthy: the `and` is truthy too.
            if left_value is True and right_value is True:
                return True

    return None
def _get_clean_type(typeobj: object) -> str:
    """
    Turn a dataclass field type object into the type string emitted by the
    matcher codegen.

    The type's ``repr`` is parsed with LibCST and widened so that a
    DoNotCareSentinel is accepted. It then goes through the sequence/union
    rewriting visitors and is printed back to source code.

    Raises:
        Exception: if the parsed type is neither a ``Union``/``Literal``/
            ``Sequence`` subscript, a bare name, nor a string literal.
    """
    # A plain class reprs as "<class 'pkg.Name'>"; keep only the dotted path.
    class_prefix = "<class '"
    class_suffix = "'>"
    typestr = repr(typeobj)
    if typestr.startswith(class_prefix) and typestr.endswith(class_suffix):
        typestr = typestr[len(class_prefix):-len(class_suffix)]

    cleanser = CleanseFullTypeNames()
    typecst = parse_expression(typestr).visit(cleanser)

    # Widen the type so a DoNotCareSentinel value is also allowed.
    clean_type: Optional[cst.CSTNode] = None
    if isinstance(typecst, cst.Subscript):
        subscripted = typecst.value
        if subscripted.deep_equals(cst.Name("Union")):
            # An existing Union can simply grow one more member.
            clean_type = typecst.with_changes(
                slice=[*typecst.slice, _get_do_not_care()]
            )
        elif subscripted.deep_equals(cst.Name("Literal")) or subscripted.deep_equals(
            cst.Name("Sequence")
        ):
            clean_type = _get_wrapped_union_type(typecst, _get_do_not_care())
    elif isinstance(typecst, (cst.Name, cst.SimpleString)):
        clean_type = _get_wrapped_union_type(typecst, _get_do_not_care())

    if clean_type is None:
        # An unfamiliar node shape: fail loudly so it can be triaged.
        raise Exception(f"Don't support {typecst}")

    # The passes run in this exact order; each one relies on the previous.
    for transformer_cls in (
        # Let sequences be given partially, with explicit DoNotCare() slots.
        AddDoNotCareToSequences,
        # Double-quote any types we parsed and repr'd, for consistency.
        DoubleQuoteStrings,
        # Permit OneOf/AllOf/MatchIfTrue wherever a SomeType was allowed.
        AddLogicAndLambdaMatcherToUnions,
        # Permit AtMostN/AtLeastN in Sequence[Union[...]]; needs the pass above.
        AddWildcardsToSequenceUnions,
    ):
        clean_type = ensure_type(clean_type.visit(transformer_cls()), cst.CSTNode)

    # Render the node using a default Module's formatting.
    return cst.Module(body=()).code_for_node(clean_type)
def test_accesses(self) -> None:
    """Accesses are attributed to the nearest assignment of the name.

    The module-level ``foo`` is referenced by ``fn1(foo)`` and ``fn2(foo)``.
    The shadowing ``foo`` inside ``fn_def`` is referenced only by ``fn3(foo)``.
    """
    m, scopes = get_scope_metadata_provider(
        """
        foo = 'toplevel'
        fn1(foo)
        fn2(foo)
        def fn_def():
            foo = 'shadow'
            fn3(foo)
        """
    )
    module_scope = scopes[m]
    self.assertIsInstance(module_scope, GlobalScope)

    def first_call_arg(statement: cst.CSTNode) -> cst.Arg:
        """Return the first argument of the call expression in ``statement``."""
        line = ensure_type(statement, cst.SimpleStatementLine)
        expression = ensure_type(line.body[0], cst.Expr).value
        return ensure_type(expression, cst.Call).args[0]

    module_foo_assignments = list(module_scope["foo"])
    self.assertEqual(len(module_foo_assignments), 1)
    module_foo = module_foo_assignments[0]
    self.assertEqual(len(module_foo.references), 2)
    fn1_arg = first_call_arg(m.body[1])
    fn2_arg = first_call_arg(m.body[2])
    self.assertEqual(
        {access.node for access in module_foo.references},
        {fn1_arg.value, fn2_arg.value},
    )

    fn_body = ensure_type(m.body[3], cst.FunctionDef).body
    fn_scope = scopes[fn_body.body[0]]
    self.assertIsInstance(fn_scope, FunctionScope)
    shadow_foo_assignments = fn_scope["foo"]
    self.assertEqual(len(shadow_foo_assignments), 1)
    shadow_foo = list(shadow_foo_assignments)[0]
    self.assertEqual(len(shadow_foo.references), 1)
    fn3_arg = first_call_arg(fn_body.body[1])
    self.assertEqual(
        {access.node for access in shadow_foo.references}, {fn3_arg.value}
    )

    # Visiting modules that contain imports with DependentVisitor (defined
    # outside this view) must complete without raising.
    for source in ("from a import b\n", "def a():\n from b import c\n\n"):
        wrapper = MetadataWrapper(cst.parse_module(source))
        wrapper.visit(DependentVisitor())
def test_keyword_arg_in_call(self) -> None:
    """A keyword-argument name in a call creates no assignment."""
    m, scopes = get_scope_metadata_provider("call(arg=val)")
    statement = ensure_type(m.body[0], cst.SimpleStatementLine)
    call = ensure_type(statement.body[0], cst.Expr).value
    scope = scopes[call]
    self.assertIsInstance(scope, GlobalScope)
    # `arg=` only names the callee's parameter; nothing is bound here.
    self.assertEqual(len(scope["arg"]), 0)
def clean_generated_code(code: str) -> str:
    """
    Tidy up generated code, fixing issues such as ``Union[SingleType]``.

    The rewrites change form only, never functionality. Unions are simplified
    first, then forward references are double-quoted.
    """
    module = parse_module(code)
    for transformer_cls in (
        SimplifyUnionsTransformer,
        DoubleQuoteForwardRefsTransformer,
    ):
        module = ensure_type(module.visit(transformer_cls()), cst.Module)
    return module.code
def test_extractall_simple(self) -> None:
    """``extractall`` yields one capture dict per call argument that is not a Name."""
    expression = cst.parse_expression("a + b[c], d(e, f * g, h.i.j)")
    non_name_arg = m.Arg(m.SaveMatchedNode(~m.Name(), "expr"))
    matches = extractall(expression, non_name_arg)

    second_element = cst.ensure_type(expression, cst.Tuple).elements[1]
    call_args = cst.ensure_type(second_element.value, cst.Call).args
    # `e` is a Name and is skipped; `f * g` and `h.i.j` are captured.
    expected = [{"expr": arg.value} for arg in call_args[1:3]]
    self.assertEqual(matches, expected)
def _test_simple_class_helper(test: UnitTest, wrapper: MetadataWrapper) -> None:
    """Assert the inferred types for the ``simple_class`` fixture module.

    Module layout, as implied by the indices used below (TODO confirm
    against the fixture):
    - ``body[1]`` is a class whose first method starts with an annotated
      assignment to a ``self`` attribute;
    - ``body[3]`` is a plain ``collector = ...`` assignment;
    - ``body[4]`` is an annotated ``items`` assignment.
    """
    types = wrapper.resolve(TypeInferenceProvider)
    m = wrapper.module

    # Narrow explicitly instead of reading `.body` off an untyped statement.
    item_class = cst.ensure_type(m.body[1], cst.ClassDef)
    class_body = cst.ensure_type(item_class.body, cst.IndentedBlock)
    method = cst.ensure_type(class_body.body[0], cst.FunctionDef)
    first_line = cst.ensure_type(method.body.body[0], cst.SimpleStatementLine)
    assign = cst.ensure_type(first_line.body[0], cst.AnnAssign)

    self_number_attr = cst.ensure_type(assign.target, cst.Attribute)
    test.assertEqual(types[self_number_attr], "int")

    # AnnAssign.value is optional (`x: int` has none); check for None explicitly.
    value = assign.value
    if value is not None:
        test.assertEqual(types[value], "int")

    # self
    test.assertEqual(types[self_number_attr.value], "simple_class.Item")

    collector_assign = cst.ensure_type(
        cst.ensure_type(m.body[3], cst.SimpleStatementLine).body[0], cst.Assign
    )
    collector = collector_assign.targets[0].target
    test.assertEqual(types[collector], "simple_class.ItemCollector")

    items_assign = cst.ensure_type(
        cst.ensure_type(m.body[4], cst.SimpleStatementLine).body[0], cst.AnnAssign
    )
    items = items_assign.target
    test.assertEqual(types[items], "typing.Sequence[simple_class.Item]")
def _replace_nested(
    node: cst.CSTNode,
    extraction: Dict[str, Union[cst.CSTNode, Sequence[cst.CSTNode]]],
) -> cst.CSTNode:
    """Replace the arguments of the call ``node`` with ``<inner>_immediate``.

    ``extraction["inner"]`` must be a Call whose callee is a plain Name. The
    result is ``node`` (which must itself be a Call) with a single argument,
    a Name made from that callee name plus ``"_immediate"``.
    """
    outer_call = cst.ensure_type(node, cst.Call)
    inner_call = cst.ensure_type(extraction["inner"], cst.Call)
    inner_callee = cst.ensure_type(inner_call.func, cst.Name).value
    new_arg = cst.Arg(cst.Name(value=f"{inner_callee}_immediate"))
    return outer_call.with_changes(args=[new_arg])
def test_parse_import_simple(self) -> None:
    """``util.parse_import("import a")`` produces an Import of the name ``a``."""
    node = util.parse_import("import a")
    statement = cst.ensure_type(node, cst.SimpleStatementLine)
    import_node = cst.ensure_type(statement.body[0], cst.Import)
    alias = cst.ensure_type(import_node.names[0], cst.ImportAlias)
    self.assertEqual(alias.name.value, "a")
def test_nested_qualified_names(self) -> None:
    """Qualified names of nested definitions include ``<locals>`` segments.

    Methods are qualified through their class (``A.f1``). Anything defined
    inside a function body gets a ``<function>.<locals>.`` prefix.
    """
    m, names = get_qualified_name_metadata_provider(
        """
        class A:
            def f1(self):
                def f2():
                    pass
                f2()

            def f3(self):
                class B():
                    ...
                B()
        def f4():
            def f5():
                class C:
                    pass
                C()
            f5()
        """
    )
    cls_a = ensure_type(m.body[0], cst.ClassDef)
    self.assertEqual(names[cls_a], {QualifiedName("A", QualifiedNameSource.LOCAL)})

    func_f1 = ensure_type(cls_a.body.body[0], cst.FunctionDef)
    self.assertEqual(
        names[func_f1], {QualifiedName("A.f1", QualifiedNameSource.LOCAL)}
    )
    func_f2_call = ensure_type(
        ensure_type(func_f1.body.body[1], cst.SimpleStatementLine).body[0], cst.Expr
    ).value
    self.assertEqual(
        names[func_f2_call],
        {QualifiedName("A.f1.<locals>.f2", QualifiedNameSource.LOCAL)},
    )

    func_f3 = ensure_type(cls_a.body.body[1], cst.FunctionDef)
    self.assertEqual(
        names[func_f3], {QualifiedName("A.f3", QualifiedNameSource.LOCAL)}
    )
    call_b = ensure_type(
        ensure_type(func_f3.body.body[1], cst.SimpleStatementLine).body[0], cst.Expr
    ).value
    self.assertEqual(
        names[call_b], {QualifiedName("A.f3.<locals>.B", QualifiedNameSource.LOCAL)}
    )

    func_f4 = ensure_type(m.body[1], cst.FunctionDef)
    self.assertEqual(
        names[func_f4], {QualifiedName("f4", QualifiedNameSource.LOCAL)}
    )
    func_f5 = ensure_type(func_f4.body.body[0], cst.FunctionDef)
    self.assertEqual(
        names[func_f5], {QualifiedName("f4.<locals>.f5", QualifiedNameSource.LOCAL)}
    )
    # Narrow like every other node above, so a wrong index fails clearly.
    cls_c = ensure_type(func_f5.body.body[0], cst.ClassDef)
    self.assertEqual(
        names[cls_c],
        {QualifiedName("f4.<locals>.f5.<locals>.C", QualifiedNameSource.LOCAL)},
    )
def _add_one_to_arg(
    node: cst.CSTNode,
    extraction: Dict[str, Union[cst.CSTNode, Sequence[cst.CSTNode]]],
) -> cst.CSTNode:
    """Return ``node`` with the captured Integer ``extraction["arg"]`` incremented.

    The captured node is located inside ``node`` by identity and replaced by a
    new Integer whose value is one larger.
    """
    # The extraction value may be a node or a sequence; pyre can't tell which.
    captured = cst.ensure_type(extraction["arg"], cst.CSTNode)
    current = int(cst.ensure_type(extraction["arg"], cst.Integer).value)
    incremented = cst.Integer(str(current + 1))
    return node.deep_replace(captured, incremented)
def _make_fixture(
    self, code: str
) -> Tuple[cst.BaseExpression, meta.MetadataWrapper]:
    """Parse dedented ``code`` and return ``(first expression, wrapper)``.

    The first statement of ``code`` must be an expression statement. The
    expression is taken from ``wrapper.module`` rather than from the freshly
    parsed module, because MetadataWrapper copies the module by default.
    Only nodes from the wrapped tree work as metadata keys.
    """
    wrapper = cst.MetadataWrapper(cst.parse_module(dedent(code)))
    first_line = cst.ensure_type(wrapper.module.body[0], cst.SimpleStatementLine)
    expression = cst.ensure_type(first_line.body[0], cst.Expr).value
    return expression, wrapper
def test_deep_replace_simple(self) -> None:
    """``deep_replace`` swaps the single ``pass`` statement for ``break``."""
    old_code = """
        pass
    """
    new_code = """
        break
    """
    module = cst.parse_module(dedent(old_code))
    statement = cst.ensure_type(module.body[0], cst.SimpleStatementLine)
    pass_stmt = statement.body[0]
    replaced = module.deep_replace(pass_stmt, cst.Break())
    self.assertEqual(cst.ensure_type(replaced, cst.Module).code, dedent(new_code))
def visit_Call(self, node: cst.Call) -> None:
    """Report ``list``/``set``/``dict`` calls that wrap a single comprehension.

    ``list(...)`` and ``set(...)`` around a generator or list comprehension
    become the matching comprehension literal. ``dict(...)`` becomes a dict
    comprehension only when each produced element is a two-item tuple or
    list ``(key, value)``.
    """
    if not m.matches(
        node,
        m.Call(
            func=m.Name("list") | m.Name("set") | m.Name("dict"),
            # Exactly one plain positional argument. `list(*[...])` and
            # `dict(x=[...])` mean something else and must not be rewritten.
            args=[
                m.Arg(
                    value=m.GeneratorExp() | m.ListComp(),
                    keyword=None,
                    star="",
                )
            ],
        ),
    ):
        return

    call_name = cst.ensure_type(node.func, cst.Name).value
    arg_value = node.args[0].value
    if m.matches(arg_value, m.GeneratorExp()):
        exp = cst.ensure_type(arg_value, cst.GeneratorExp)
        message_formatter = UNNECESSARY_GENERATOR
    else:
        exp = cst.ensure_type(arg_value, cst.ListComp)
        message_formatter = UNNECESSARY_LIST_COMPREHENSION

    # The comprehension replaces the whole call node, so it is the replacement itself.
    replacement = None
    if call_name == "list":
        replacement = cst.ListComp(elt=exp.elt, for_in=exp.for_in)
    elif call_name == "set":
        replacement = cst.SetComp(elt=exp.elt, for_in=exp.for_in)
    elif call_name == "dict":
        elt = exp.elt
        # `elements=[...]` pins the length to exactly two plain elements.
        # A positional `m.Tuple(m.DoNotCare(), m.DoNotCare())` would bind
        # `elements` and `lpar` and match tuples of ANY length.
        key_value_pair = [m.Element(), m.Element()]
        if m.matches(elt, m.Tuple(elements=key_value_pair)):
            elements = cst.ensure_type(elt, cst.Tuple).elements
        elif m.matches(elt, m.List(elements=key_value_pair)):
            elements = cst.ensure_type(elt, cst.List).elements
        else:
            # Unrecognized form: not a two-item (key, value) pair.
            return
        key = elements[0].value
        value = elements[1].value
        # pyre-fixme[6]: Expected `BaseAssignTargetExpression` for 1st
        # param but got `BaseExpression`.
        replacement = cst.DictComp(key=key, value=value, for_in=exp.for_in)

    self.report(
        node, message_formatter.format(func=call_name), replacement=replacement
    )
def test_with_asname(self) -> None:
    """The ``as`` target of a module-level ``with`` is assigned in module scope."""
    m, scopes = get_scope_metadata_provider(
        """
        with open(file_name) as f:
            ...
        """
    )
    module_scope = scopes[m]
    self.assertIsInstance(module_scope, GlobalScope)
    self.assertIn("f", module_scope)

    f_binding = cast(Assignment, module_scope["f"][0])
    with_stmt = ensure_type(m.body[0], cst.With)
    as_clause = ensure_type(with_stmt.items[0].asname, cst.AsName)
    self.assertEqual(f_binding.node, as_clause.name)
def test_annotation_access(self) -> None:
    """References used as annotations or TypeVar bounds are flagged ``is_annotation``.

    ``A`` (annotation), ``"B"`` (string annotation), ``D`` and ``"E"``
    (TypeVar bounds) each have exactly one reference, flagged as an
    annotation. ``"C"`` inside ``Literal[...]`` and the TypeVar name ``"F"``
    produce no reference.
    """
    m, scopes = get_scope_metadata_provider(
        """
        from typing import Literal, TypeVar
        from a import A, B, C, D, E, F
        def x(a: A):
            pass
        def y(b: "B"):
            pass
        def z(c: Literal["C"]):
            pass
        DType = TypeVar("DType", bound=D)
        EType = TypeVar("EType", bound="E")
        FType = TypeVar("F")
        """
    )
    import_line = ensure_type(m.body[1], cst.SimpleStatementLine)
    imp = ensure_type(import_line.body[0], cst.ImportFrom)
    scope = scopes[imp]

    expected_reference_counts = {"A": 1, "B": 1, "C": 0, "D": 1, "E": 1, "F": 0}
    for name, expected_count in expected_reference_counts.items():
        assignment = list(scope[name])[0]
        self.assertIsInstance(assignment, Assignment)
        self.assertEqual(len(assignment.references), expected_count)
        for reference in assignment.references:
            self.assertTrue(reference.is_annotation)
def test_extract_sequence(self) -> None:
    """A SaveMatchedNode wrapping a whole sequence pattern captures all args."""
    expression = cst.parse_expression("a + b[c], d(e, f * g, h.i.j)")
    all_args = m.SaveMatchedNode([m.ZeroOrMore()], "args")
    pattern = m.Tuple(elements=[m.DoNotCare(), m.Element(m.Call(args=all_args))])
    captured = m.extract(expression, pattern)

    second_element = cst.ensure_type(expression, cst.Tuple).elements[1]
    call = cst.ensure_type(second_element.value, cst.Call)
    self.assertEqual(captured, {"args": call.args})
def leave_With(
    self, original_node: cst.With, updated_node: cst.With
) -> Union[cst.BaseStatement, cst.RemovalSentinel]:
    """Collapse directly nested ``with`` statements into one compound ``with``.

    ``with a:`` whose body is exactly one ``with b:`` becomes ``with a, b:``,
    repeated down the chain. The walk stops at a ``with`` that has a leading
    comment, has an inline comment on its header line, or is ``async``. It
    also stops at a body that is not a single nested ``with``.

    The walk starts from ``updated_node``, so rewrites already applied to
    children are kept. These include deeper chains collapsed by earlier
    ``leave_With`` calls. Returning ``original_node`` would discard them.
    """
    candidate_with: cst.With = updated_node
    compound_items: List[cst.WithItem] = []
    final_body: cst.BaseSuite = candidate_with.body
    # Number of `with` statements whose items were absorbed; merging needs >= 2.
    merged_levels = 0

    while True:
        # There is no way to meaningfully represent comments inside
        # multi-line `with` statements due to how Python grammar is
        # written, so we do not try to transform such `with` statements
        # lest we lose something important in the comments.
        if has_leading_comment(candidate_with):
            break
        if has_inline_comment(candidate_with.body):
            break
        # There is no meaningful way `async with` can be merged into
        # the compound `with` statement.
        if candidate_with.asynchronous:
            break

        compound_items.extend(candidate_with.items)
        merged_levels += 1
        final_body = candidate_with.body

        # Only descend when the body is exactly one nested `with`.
        first_statement = final_body.body[0]
        if len(final_body.body) > 1 or not isinstance(first_statement, cst.With):
            break
        candidate_with = first_statement

    if merged_levels <= 1:
        # Nothing nested to merge (this also covers a lone `with a, b: ...`).
        return updated_node

    # A nested `with` can only sit in an indented block, so this cannot fail.
    topmost_body = cst.ensure_type(updated_node.body, cst.IndentedBlock)
    if isinstance(final_body, cst.IndentedBlock):
        if has_footer_comment(topmost_body) and not has_footer_comment(final_body):
            final_body = final_body.with_changes(
                footer=(*final_body.footer, *topmost_body.footer)
            )
    elif has_footer_comment(topmost_body):
        # The innermost body is a one-line suite (`with b: pass`), which has
        # no footer to carry the outer block's trailing comments. Leave the
        # statement alone rather than drop those comments.
        return updated_node

    return updated_node.with_changes(body=final_body, items=compound_items)
def test_func_param_scope(self) -> None:
    """Parameters bind in the function scope; the rest of the header does not.

    The decorator, annotations, defaults and return annotation are expected
    in the module scope. Every parameter (``x``, ``*vararg``, ``y``, ``z``,
    ``**kwarg``) is expected in the function scope and only there.
    """
    m, scopes = get_scope_metadata_provider(
        """
        @decorator
        def f(x: T=1, *vararg, y: T=2, z, **kwarg) -> RET:
            pass
        """
    )
    module_scope = scopes[m]
    self.assertIsInstance(module_scope, GlobalScope)
    self.assertIn("f", module_scope)

    f = ensure_type(m.body[0], cst.FunctionDef)
    f_scope = scopes[f.body.body[0]]
    self.assertIsInstance(f_scope, FunctionScope)

    params = f.params
    x = params.params[0]
    vararg = ensure_type(params.star_arg, cst.Param)
    y = params.kwonly_params[0]
    z = params.kwonly_params[1]
    kwarg = ensure_type(params.star_kwarg, cst.Param)

    header_nodes_in_module = (
        f.decorators[0],
        ensure_type(x.annotation, cst.Annotation),
        ensure_type(x.default, cst.BaseExpression),
        ensure_type(y.annotation, cst.Annotation),
        ensure_type(y.default, cst.BaseExpression),
        ensure_type(f.returns, cst.Annotation).annotation,
    )
    for header_node in header_nodes_in_module:
        self.assertEqual(scopes[header_node], module_scope)

    named_params = (
        ("x", x),
        ("vararg", vararg),
        ("y", y),
        ("z", z),
        ("kwarg", kwarg),
    )
    for param_name, param in named_params:
        self.assertEqual(scopes[param], f_scope)
        self.assertNotIn(param_name, module_scope)
        self.assertIn(param_name, f_scope)
        # The parameter's assignment points back at its Param node.
        self.assertEqual(
            cast(Assignment, list(f_scope[param_name])[0]).node, param
        )
def test_extract_metadata(self) -> None:
    """``m.extract`` can capture a node through a position-metadata matcher."""
    module = cst.parse_module("a + b[c], d(e, f * g)")
    wrapper = cst.MetadataWrapper(module)
    first_line = cst.ensure_type(wrapper.module.body[0], cst.SimpleStatementLine)
    expression = cst.ensure_type(first_line.body[0], cst.Expr).value

    def extract_left_ending_at(end_column: int):
        # Capture the binop's left Name if its position is (1, 0)-(1, end_column).
        position = m.MatchMetadata(
            meta.PositionProvider,
            self._make_coderange((1, 0), (1, end_column)),
        )
        return m.extract(
            expression,
            m.Tuple(
                elements=[
                    m.Element(
                        m.BinaryOperation(
                            left=m.Name(metadata=m.SaveMatchedNode(position, "left"))
                        )
                    ),
                    m.Element(m.Call()),
                ]
            ),
            metadata_resolver=wrapper,
        )

    # Verify true behavior: the range (1, 0)-(1, 1) is expected to match `a`.
    captured = extract_left_ending_at(1)
    first_element = cst.ensure_type(expression, cst.Tuple).elements[0]
    addition = cst.ensure_type(first_element.value, cst.BinaryOperation)
    self.assertEqual(captured, {"left": addition.left})

    # Verify false behavior: a wrong end column yields no match.
    self.assertIsNone(extract_left_ending_at(2))
def test_imoprt_from(self) -> None:
    """``from ... import`` binds only the imported (or aliased) names.

    The module path (``foo``, ``bar``, ``foo.bar``) and an aliased-away
    original name (``b``) must not be bound.
    """
    m, scopes = get_scope_metadata_provider(
        """
        from foo.bar import a, b as b_renamed
        from . import c
        from .foo import d
        """
    )
    module_scope = scopes[m]

    # Bound name -> index of the import statement that binds it.
    binding_statement = {"a": 0, "b_renamed": 0, "c": 1, "d": 2}
    for in_scope, idx in binding_statement.items():
        bindings = module_scope[in_scope]
        self.assertEqual(len(bindings), 1, f"{in_scope} should be in scope.")
        import_assignment = cast(Assignment, bindings[0])
        self.assertEqual(
            import_assignment.name,
            in_scope,
            f"The name of Assignment {import_assignment.name} should equal to {in_scope}.",
        )
        import_node = ensure_type(m.body[idx], cst.SimpleStatementLine).body[0]
        self.assertEqual(
            import_assignment.node,
            import_node,
            f"The node of Assignment {import_assignment.node} should equal to {import_node}",
        )

    for not_in_scope in ("foo", "bar", "foo.bar", "b"):
        self.assertEqual(
            len(module_scope[not_in_scope]),
            0,
            f"{not_in_scope} should not be in scope.",
        )
def test_import(self) -> None:
    """``import`` binds the top-level package name, or the alias when one is given."""
    m, scopes = get_scope_metadata_provider(
        """
        import foo.bar
        import fizz.buzz as fizzbuzz
        import a.b.c
        import d.e.f as g
        """
    )
    module_scope = scopes[m]

    # The i-th statement binds the i-th name.
    expected_names = ("foo", "fizzbuzz", "a", "g")
    for statement, in_scope in zip(m.body, expected_names):
        bindings = module_scope[in_scope]
        self.assertEqual(len(bindings), 1, f"{in_scope} should be in scope.")
        assignment = cast(Assignment, bindings[0])
        self.assertEqual(
            assignment.name,
            in_scope,
            f"Assignment name {assignment.name} should equal to {in_scope}.",
        )
        import_node = ensure_type(statement, cst.SimpleStatementLine).body[0]
        self.assertEqual(
            assignment.node,
            import_node,
            f"The node of Assignment {assignment.node} should equal to {import_node}",
        )