def test_class_scope(self) -> None:
    """Class-body names are visible only inside the class scope.

    Module-level names (``global_var``, ``Cls``) are visible from the module,
    the class body and the method ``f``; the class attribute ``cls_attr`` is
    visible in the class scope but neither in the module scope nor in ``f``.
    """
    m, scopes = get_scope_metadata_provider("""
         global_var = None
         @cls_attr
         class Cls(cls_attr, kwarg=cls_attr):
             cls_attr = 5
             def f():
                 pass
         """)
    scope_of_module = scopes[m]
    self.assertIsInstance(scope_of_module, GlobalScope)
    # The collection returned by a scope lookup may be an unordered set
    # (not indexable) depending on the LibCST version; materialize it first.
    cls_assignments = list(scope_of_module["Cls"])
    self.assertEqual(len(cls_assignments), 1)
    cls_assignment = cast(Assignment, cls_assignments[0])
    cls_def = ensure_type(m.body[1], cst.ClassDef)
    self.assertEqual(cls_assignment.node, cls_def)
    cls_body = cls_def.body
    scope_of_class = scopes[cls_body.body[0]]
    self.assertIsInstance(scope_of_class, ClassScope)
    func_body = ensure_type(cls_body.body[1], cst.FunctionDef).body
    scope_of_func = scopes[func_body.body[0]]
    self.assertIsInstance(scope_of_func, FunctionScope)

    for scope in (scope_of_module, scope_of_class, scope_of_func):
        self.assertIn("global_var", scope)
        self.assertIn("Cls", scope)
    self.assertNotIn("cls_attr", scope_of_module)
    self.assertIn("cls_attr", scope_of_class)
    self.assertNotIn("cls_attr", scope_of_func)
    def test_nested_comprehension_scope(self) -> None:
        """Map each part of ``[y for x in iterator for y in x]`` to its scope.

        The comprehension node, its outer ``for`` clause and that clause's
        iterable belong to the module scope; the element, both targets and
        the inner ``for`` clause (including its iterable) belong to the
        comprehension scope.
        """
        m, scopes = get_scope_metadata_provider("""
            [y for x in iterator for y in x]
            """)
        scope_of_module = scopes[m]
        self.assertIsInstance(scope_of_module, GlobalScope)

        statement = ensure_type(m.body[0], cst.SimpleStatementLine)
        expression = ensure_type(statement.body[0], cst.Expr)
        list_comp = ensure_type(expression.value, cst.ListComp)
        scope_of_list_comp = scopes[list_comp.elt]
        self.assertIsInstance(scope_of_list_comp, ComprehensionScope)

        self.assertIs(scopes[list_comp], scope_of_module)
        self.assertIs(scopes[list_comp.elt], scope_of_list_comp)

        outer_for_in = list_comp.for_in
        self.assertIs(scopes[outer_for_in], scope_of_module)
        self.assertIs(scopes[outer_for_in.iter], scope_of_module)
        self.assertIs(scopes[outer_for_in.target], scope_of_list_comp)

        inner_for_in = ensure_type(outer_for_in.inner_for_in, cst.CompFor)
        for inner_part in (inner_for_in, inner_for_in.iter, inner_for_in.target):
            self.assertIs(scopes[inner_part], scope_of_list_comp)
 def test_multiple_assignments(self) -> None:
     """A name imported in both branches of an ``if`` resolves to both origins."""
     m, scopes = get_scope_metadata_provider("""
             if 1:
                 from a import b as c
             elif 2:
                 from d import e as c
             c()
         """)
     call_line = ensure_type(m.body[1], cst.SimpleStatementLine)
     call = ensure_type(call_line.body[0], cst.Expr).value
     scope = scopes[call]
     self.assertIsInstance(scope, GlobalScope)
     # Both import branches contribute a binding for `c`.
     expected_names = {
         QualifiedName(name="a.b", source=QualifiedNameSource.IMPORT),
         QualifiedName(name="d.e", source=QualifiedNameSource.IMPORT),
     }
     self.assertEqual(scope.get_qualified_names_for(call), expected_names)
     self.assertEqual(scope.get_qualified_names_for("c"), expected_names)
    def test_with_statement(self) -> None:
        """Names used in and after a ``with`` resolve through a dotted import."""
        m, scopes = get_scope_metadata_provider("""
                import unittest.mock

                with unittest.mock.patch("something") as obj:
                    obj.f1()

                unittest.mock
            """)
        import_line = ensure_type(m.body[0], cst.SimpleStatementLine)
        import_ = import_line.body[0]
        unittest_assignments = list(scopes[import_]["unittest"])
        self.assertEqual(len(unittest_assignments), 1)
        self.assertEqual(cast(Assignment, unittest_assignments[0]).node, import_)

        with_ = ensure_type(m.body[1], cst.With)
        fn_call = with_.items[0].item
        call_scope = scopes[fn_call]
        self.assertEqual(
            call_scope.get_qualified_names_for(fn_call),
            {
                QualifiedName(name="unittest.mock.patch",
                              source=QualifiedNameSource.IMPORT)
            },
        )

        mock_line = ensure_type(m.body[2], cst.SimpleStatementLine)
        mock = ensure_type(mock_line.body[0], cst.Expr).value
        self.assertEqual(
            call_scope.get_qualified_names_for(mock),
            {
                QualifiedName(name="unittest.mock",
                              source=QualifiedNameSource.IMPORT)
            },
        )
Beispiel #5
0
    def test_extract_simple(self) -> None:
        """``m.extract`` returns the captured nodes on a match, ``None`` otherwise."""
        expression = cst.parse_expression("a + b[c], d(e, f * g)")

        def tuple_pattern(left):
            # (BinaryOperation whose left operand is captured as "left", any Call)
            return m.Tuple(elements=[
                m.Element(
                    m.BinaryOperation(left=m.SaveMatchedNode(left, "left"))),
                m.Element(m.Call()),
            ])

        # Verify true behavior: `a` is a Name, so it is captured.
        nodes = m.extract(expression, tuple_pattern(m.Name()))
        binary_op = cst.ensure_type(
            cst.ensure_type(expression, cst.Tuple).elements[0].value,
            cst.BinaryOperation,
        )
        self.assertEqual(nodes, {"left": binary_op.left})

        # Verify false behavior: `a` is not a Subscript, so nothing matches.
        nodes = m.extract(expression, tuple_pattern(m.Subscript()))
        self.assertIsNone(nodes)
    def test_dotted_import_access(self) -> None:
        """Accesses through dotted imports are attributed to the bound dotted name.

        ``a.b.c(...)`` is recorded as an access of ``a.b.c`` (``a`` itself has
        no accesses), ``x.z`` is recorded as an access of ``x``, and the unused
        ``x.y`` import has neither references nor accesses.
        """
        m, scopes = get_scope_metadata_provider("""
            import a.b.c, x.y
            a.b.c(x.z)
            """)
        scope_of_module = scopes[m]
        call_statement = ensure_type(m.body[1], cst.SimpleStatementLine)
        call = ensure_type(
            ensure_type(call_statement.body[0], cst.Expr).value, cst.Call)
        self.assertIn("a.b.c", scope_of_module)
        self.assertIn("a", scope_of_module)
        self.assertEqual(scope_of_module.accesses["a"], set())

        a_b_c_assignment = cast(Assignment, list(scope_of_module["a.b.c"])[0])
        a_b_c_access = list(a_b_c_assignment.references)[0]
        self.assertEqual(scope_of_module.accesses["a.b.c"], {a_b_c_access})
        self.assertEqual(a_b_c_access.node, call.func)

        x_assignment = cast(Assignment, list(scope_of_module["x"])[0])
        x_access = list(x_assignment.references)[0]
        self.assertEqual(scope_of_module.accesses["x"], {x_access})
        self.assertEqual(x_access.node,
                         ensure_type(call.args[0].value, cst.Attribute).value)

        self.assertIn("x.y", scope_of_module)
        self.assertEqual(list(scope_of_module["x.y"])[0].references, set())
        self.assertEqual(scope_of_module.accesses["x.y"], set())
Beispiel #7
0
    def test_extract_sequence_element(self) -> None:
        """A ``SaveMatchedNode`` around a wildcard captures the matched args."""
        expression = cst.parse_expression("a + b[c], d(e, f * g, h.i.j)")

        def call_args_pattern(args_matcher):
            # (anything, a Call whose args are captured as "args")
            return m.Tuple(elements=[
                m.DoNotCare(),
                m.Element(
                    m.Call(args=[m.SaveMatchedNode(args_matcher, "args")])),
            ])

        # Verify true behavior: ZeroOrMore() accepts every argument.
        nodes = m.extract(expression, call_args_pattern(m.ZeroOrMore()))
        call = cst.ensure_type(
            cst.ensure_type(expression, cst.Tuple).elements[1].value,
            cst.Call)
        self.assertEqual(nodes, {"args": call.args})

        # Verify false behavior: the arguments are not all Subscripts.
        nodes = m.extract(
            expression,
            call_args_pattern(m.ZeroOrMore(m.Arg(m.Subscript()))),
        )
        self.assertIsNone(nodes)
Beispiel #8
0
    def test_local_scope_shadowing_with_functions(self) -> None:
        """Each nested ``def f`` is bound in the scope that directly encloses it."""
        m, scopes = get_scope_metadata_provider(
            """
            def f():
                def f():
                    f = ...
            """
        )
        scope_of_module = scopes[m]
        self.assertIsInstance(scope_of_module, GlobalScope)
        self.assertIn("f", scope_of_module)

        outer_f = ensure_type(m.body[0], cst.FunctionDef)
        scope_of_outer_f = scopes[outer_f.body.body[0]]
        self.assertIsInstance(scope_of_outer_f, FunctionScope)
        self.assertIn("f", scope_of_outer_f)
        # The collection returned by a scope lookup may be an unordered set
        # (not indexable) depending on the LibCST version; materialize it first.
        out_f_assignment = list(scope_of_module["f"])[0]
        self.assertEqual(cast(Assignment, out_f_assignment).node, outer_f)

        inner_f = ensure_type(outer_f.body.body[0], cst.FunctionDef)
        scope_of_inner_f = scopes[inner_f.body.body[0]]
        self.assertIsInstance(scope_of_inner_f, FunctionScope)
        self.assertIn("f", scope_of_inner_f)
        inner_f_assignment = list(scope_of_outer_f["f"])[0]
        self.assertEqual(cast(Assignment, inner_f_assignment).node, inner_f)
    def _extract_static_bool(cls, node: cst.BaseExpression) -> Optional[bool]:
        """Statically evaluate the truthiness of *node*, if possible.

        Only the names ``True``/``False`` combined with ``not``, ``and`` and
        ``or`` are understood. Returns ``True``/``False`` when the truthiness
        is determined by that structure alone, and ``None`` otherwise
        (including for any function call).
        """
        if m.matches(node, m.Call()):
            # cannot reason about function calls
            return None
        if m.matches(node, m.UnaryOperation(operator=m.Not())):
            sub_value = cls._extract_static_bool(
                cst.ensure_type(node, cst.UnaryOperation).expression)
            if sub_value is None:
                return None
            return not sub_value

        if m.matches(node, m.Name("True")):
            return True

        if m.matches(node, m.Name("False")):
            return False

        if m.matches(node, m.BooleanOperation()):
            node = cst.ensure_type(node, cst.BooleanOperation)
            left_value = cls._extract_static_bool(node.left)
            right_value = cls._extract_static_bool(node.right)
            if m.matches(node.operator, m.Or()):
                # `or` is truthy as soon as one side is; it is falsy only
                # when both sides are known to be falsy.
                if right_value is True or left_value is True:
                    return True
                if right_value is False and left_value is False:
                    return False

            if m.matches(node.operator, m.And()):
                # `and` is falsy as soon as one side is; it is truthy only
                # when both sides are known to be truthy.
                if right_value is False or left_value is False:
                    return False
                if right_value is True and left_value is True:
                    return True

        return None
Beispiel #10
0
def _get_clean_type(typeobj: object) -> str:
    """
    Convert a type object reported by dataclasses into the sanitized type
    string emitted by the codegen below.

    Raises an Exception for type shapes that are not handled yet, so new cases
    can be triaged.
    """

    # Plain classes repr as "<class 'mod.Name'>"; strip that wrapper so the
    # result can be parsed as an expression.
    typestr = repr(typeobj)
    class_prefix = "<class '"
    class_suffix = "'>"
    if typestr.startswith(class_prefix) and typestr.endswith(class_suffix):
        typestr = typestr[len(class_prefix):-len(class_suffix)]

    # Parse with LibCST and strip fully-qualified type names.
    cleanser = CleanseFullTypeNames()
    typecst = parse_expression(typestr).visit(cleanser)

    # Make room for DoNotCareSentinel values in the type.
    clean_type: Optional[cst.CSTNode] = None
    if isinstance(typecst, cst.Subscript):
        subscripted = typecst.value
        if subscripted.deep_equals(cst.Name("Union")):
            # Already a Union: append the sentinel as another member.
            clean_type = typecst.with_changes(
                slice=[*typecst.slice, _get_do_not_care()]
            )
        elif any(
            subscripted.deep_equals(cst.Name(wrapper))
            for wrapper in ("Literal", "Sequence")
        ):
            clean_type = _get_wrapped_union_type(typecst, _get_do_not_care())
    elif isinstance(typecst, (cst.Name, cst.SimpleString)):
        clean_type = _get_wrapped_union_type(typecst, _get_do_not_care())

    if clean_type is None:
        raise Exception(f"Don't support {typecst}")

    # Post-processing passes, applied in this order:
    #  1. add DoNotCareSentinel to all sequences (partial sequence matching);
    #  2. double-quote any types we parsed and repr'd, for consistency;
    #  3. insert OneOf/AllOf and MatchIfTrue into unions;
    #  4. insert AtMostN/AtLeastN into sequence unions -- this relies on pass 3
    #     having turned every relevant sequence into Sequence[Union[<x>]].
    for transform_cls in (
        AddDoNotCareToSequences,
        DoubleQuoteStrings,
        AddLogicAndLambdaMatcherToUnions,
        AddWildcardsToSequenceUnions,
    ):
        clean_type = ensure_type(clean_type.visit(transform_cls()), cst.CSTNode)

    # Render the code for the node using a default Module.
    return cst.Module(body=()).code_for_node(clean_type)
Beispiel #11
0
    def test_accesses(self) -> None:
        """Each access of ``foo`` is recorded on the assignment it resolves to.

        The module-level ``foo`` is referenced by the ``fn1``/``fn2`` arguments;
        the shadowing ``foo`` inside ``fn_def`` is referenced only by ``fn3``.
        """
        m, scopes = get_scope_metadata_provider("""
            foo = 'toplevel'
            fn1(foo)
            fn2(foo)
            def fn_def():
                foo = 'shadow'
                fn3(foo)
            """)

        def first_call_arg(statement: cst.BaseStatement) -> cst.Arg:
            # `statement` must be an expression line of the form `fn(arg, ...)`.
            expr = ensure_type(
                ensure_type(statement, cst.SimpleStatementLine).body[0],
                cst.Expr)
            return ensure_type(expr.value, cst.Call).args[0]

        scope_of_module = scopes[m]
        self.assertIsInstance(scope_of_module, GlobalScope)
        global_foo_assignments = list(scope_of_module["foo"])
        self.assertEqual(len(global_foo_assignments), 1)
        global_foo = global_foo_assignments[0]
        self.assertEqual(len(global_foo.references), 2)
        fn1_call_arg = first_call_arg(m.body[1])
        fn2_call_arg = first_call_arg(m.body[2])
        self.assertEqual(
            {access.node for access in global_foo.references},
            {fn1_call_arg.value, fn2_call_arg.value},
        )

        func_body = ensure_type(m.body[3], cst.FunctionDef).body
        scope_of_func_statement = scopes[func_body.body[0]]
        self.assertIsInstance(scope_of_func_statement, FunctionScope)
        local_foo_assignments = list(scope_of_func_statement["foo"])
        self.assertEqual(len(local_foo_assignments), 1)
        local_foo = local_foo_assignments[0]
        self.assertEqual(len(local_foo.references), 1)
        fn3_call_arg = first_call_arg(func_body.body[1])
        self.assertEqual(
            {access.node for access in local_foo.references},
            {fn3_call_arg.value},
        )

        # Smoke check: visiting import-only modules (at module level and
        # inside a function) with DependentVisitor should not raise.
        for code in ("from a import b\n", "def a():\n    from b import c\n\n"):
            wrapper = MetadataWrapper(cst.parse_module(code))
            wrapper.visit(DependentVisitor())
Beispiel #12
0
 def test_keyword_arg_in_call(self) -> None:
     """A keyword argument name in a call does not create an assignment."""
     m, scopes = get_scope_metadata_provider("call(arg=val)")
     call_line = ensure_type(m.body[0], cst.SimpleStatementLine)
     call = ensure_type(call_line.body[0], cst.Expr).value
     scope = scopes[call]
     self.assertIsInstance(scope, GlobalScope)
     # `arg=` only names a parameter of the callee; nothing is bound.
     self.assertEqual(len(scope["arg"]), 0)
Beispiel #13
0
def clean_generated_code(code: str) -> str:
    """
    Apply generic cosmetic clean-ups to generated code, e.g. simplifying a
    ``Union[SingleType]``. These passes only change the form of the code,
    not its functionality.
    """
    module = parse_module(code)
    # Simplify unions first, then normalize quoting of forward references.
    for transformer_cls in (
        SimplifyUnionsTransformer,
        DoubleQuoteForwardRefsTransformer,
    ):
        module = ensure_type(module.visit(transformer_cls()), cst.Module)
    return module.code
Beispiel #14
0
 def test_extractall_simple(self) -> None:
     """``extractall`` yields one capture dict per matching ``Arg``, in order."""
     expression = cst.parse_expression("a + b[c], d(e, f * g, h.i.j)")
     non_name_arg = m.Arg(m.SaveMatchedNode(~m.Name(), "expr"))
     matches = extractall(expression, non_name_arg)
     call = cst.ensure_type(
         cst.ensure_type(expression, cst.Tuple).elements[1].value, cst.Call
     )
     # `e` is a Name and is skipped; `f * g` and `h.i.j` are captured.
     expected = [{"expr": arg.value} for arg in call.args[1:]]
     self.assertEqual(matches, expected)
Beispiel #15
0
def _test_simple_class_helper(test: UnitTest,
                              wrapper: MetadataWrapper) -> None:
    """Assert the types inferred for the ``simple_class`` fixture module.

    Presumably (TODO confirm against the fixture) ``m.body[1]`` is the
    ``simple_class.Item`` class whose first method starts with an annotated
    ``self.<attr>: int = ...`` assignment, ``m.body[3]`` assigns an
    ``ItemCollector`` and ``m.body[4]`` annotates a ``Sequence[Item]``.
    """
    types = wrapper.resolve(TypeInferenceProvider)
    m = wrapper.module
    class_def = cst.ensure_type(m.body[1], cst.ClassDef)
    first_method = cst.ensure_type(
        cst.ensure_type(class_def.body, cst.IndentedBlock).body[0],
        cst.FunctionDef,
    )
    assign = cst.ensure_type(
        cst.ensure_type(
            first_method.body.body[0], cst.SimpleStatementLine
        ).body[0],
        cst.AnnAssign,
    )
    self_number_attr = cst.ensure_type(assign.target, cst.Attribute)
    test.assertEqual(types[self_number_attr], "int")

    # An annotated assignment may have no value (`self.x: int`).
    value = assign.value
    if value is not None:
        test.assertEqual(types[value], "int")

    # self
    test.assertEqual(types[self_number_attr.value], "simple_class.Item")
    collector_assign = cst.ensure_type(
        cst.ensure_type(m.body[3], cst.SimpleStatementLine).body[0],
        cst.Assign)
    collector = collector_assign.targets[0].target
    test.assertEqual(types[collector], "simple_class.ItemCollector")
    items_assign = cst.ensure_type(
        cst.ensure_type(m.body[4], cst.SimpleStatementLine).body[0],
        cst.AnnAssign)
    items = items_assign.target
    test.assertEqual(types[items], "typing.Sequence[simple_class.Item]")
Beispiel #16
0
 def _replace_nested(
     node: cst.CSTNode,
     extraction: Dict[str, Union[cst.CSTNode, Sequence[cst.CSTNode]]],
 ) -> cst.CSTNode:
     """Replace the arguments of the call *node* with ``<inner>_immediate``.

     ``extraction["inner"]`` must be a ``Call`` whose ``func`` is a ``Name``;
     that name with an ``_immediate`` suffix becomes the sole argument.
     """
     outer_call = cst.ensure_type(node, cst.Call)
     inner_call = cst.ensure_type(extraction["inner"], cst.Call)
     inner_name = cst.ensure_type(inner_call.func, cst.Name).value
     replacement_arg = cst.Arg(cst.Name(value=f"{inner_name}_immediate"))
     return outer_call.with_changes(args=[replacement_arg])
Beispiel #17
0
 def test_parse_import_simple(self) -> None:
     """``util.parse_import("import a")`` yields an Import whose first alias is ``a``."""
     node = util.parse_import("import a")
     statement_line = cst.ensure_type(node, cst.SimpleStatementLine)
     import_stmt = cst.ensure_type(statement_line.body[0], cst.Import)
     first_alias = cst.ensure_type(import_stmt.names[0], cst.ImportAlias)
     self.assertEqual(first_alias.name.value, "a")
Beispiel #18
0
    def test_nested_qualified_names(self) -> None:
        """Qualified names of nested definitions follow ``__qualname__`` style.

        Names nested directly in a class are joined with ``.``; names defined
        inside a function body get a ``.<locals>.`` segment.
        """
        m, names = get_qualified_name_metadata_provider(
            """
            class A:
                def f1(self):
                    def f2():
                        pass
                    f2()

                def f3(self):
                    class B():
                        ...
                    B()
            def f4():
                def f5():
                    class C:
                        pass
                    C()
                f5()
            """
        )

        def local(qualified_name):
            # Every definition in the fixture is local to the parsed module.
            return {QualifiedName(qualified_name, QualifiedNameSource.LOCAL)}

        def expression_value(statement):
            # The expression (here: a call) of a bare expression statement.
            line = ensure_type(statement, cst.SimpleStatementLine)
            return ensure_type(line.body[0], cst.Expr).value

        cls_a = ensure_type(m.body[0], cst.ClassDef)
        self.assertEqual(names[cls_a], local("A"))

        func_f1 = ensure_type(cls_a.body.body[0], cst.FunctionDef)
        self.assertEqual(names[func_f1], local("A.f1"))
        func_f2_call = expression_value(func_f1.body.body[1])
        self.assertEqual(names[func_f2_call], local("A.f1.<locals>.f2"))

        func_f3 = ensure_type(cls_a.body.body[1], cst.FunctionDef)
        self.assertEqual(names[func_f3], local("A.f3"))
        call_b = expression_value(func_f3.body.body[1])
        self.assertEqual(names[call_b], local("A.f3.<locals>.B"))

        func_f4 = ensure_type(m.body[1], cst.FunctionDef)
        self.assertEqual(names[func_f4], local("f4"))
        func_f5 = ensure_type(func_f4.body.body[0], cst.FunctionDef)
        self.assertEqual(names[func_f5], local("f4.<locals>.f5"))
        cls_c = func_f5.body.body[0]
        self.assertEqual(names[cls_c], local("f4.<locals>.f5.<locals>.C"))
Beispiel #19
0
 def _add_one_to_arg(
     node: cst.CSTNode,
     extraction: Dict[str, Union[cst.CSTNode, Sequence[cst.CSTNode]]],
 ) -> cst.CSTNode:
     """Return *node* with the captured ``Integer`` ``extraction["arg"]`` incremented."""
     # This can be either a node or a sequence, pyre doesn't know.
     old_arg = cst.ensure_type(extraction["arg"], cst.CSTNode)
     # Grab the arg and add one to its value.
     incremented = int(cst.ensure_type(extraction["arg"], cst.Integer).value) + 1
     return node.deep_replace(old_arg, cst.Integer(str(incremented)))
Beispiel #20
0
 def _make_fixture(
         self,
         code: str) -> Tuple[cst.BaseExpression, meta.MetadataWrapper]:
     """Parse dedented *code* and return its first expression and the wrapper.

     The expression is read from ``wrapper.module`` rather than from the
     freshly parsed module, so it belongs to the tree the wrapper resolves
     metadata for.
     """
     wrapper = cst.MetadataWrapper(cst.parse_module(dedent(code)))
     first_line = cst.ensure_type(wrapper.module.body[0],
                                  cst.SimpleStatementLine)
     expression = cst.ensure_type(first_line.body[0], cst.Expr).value
     return expression, wrapper
Beispiel #21
0
    def test_deep_replace_simple(self) -> None:
        """Replacing the lone ``pass`` statement with ``break`` rewrites the code."""
        old_code = """
            pass
        """
        new_code = """
            break
        """

        module = cst.parse_module(dedent(old_code))
        first_line = cst.ensure_type(module.body[0], cst.SimpleStatementLine)
        pass_stmt = first_line.body[0]
        replaced = module.deep_replace(pass_stmt, cst.Break())
        new_module = cst.ensure_type(replaced, cst.Module)
        self.assertEqual(new_module.code, dedent(new_code))
    def visit_Call(self, node: cst.Call) -> None:
        """Report ``list``/``set``/``dict`` calls wrapping a generator or list comp.

        ``list(...)`` and ``set(...)`` are reported with the equivalent
        list/set comprehension as replacement. ``dict(...)`` is reported only
        when the produced element is a literal two-item tuple or list (a
        ``key, value`` pair); any other form is skipped without a report.
        """
        if not m.matches(
                node,
                m.Call(
                    func=m.Name("list") | m.Name("set") | m.Name("dict"),
                    args=[m.Arg(value=m.GeneratorExp() | m.ListComp())],
                ),
        ):
            return

        call_name = cst.ensure_type(node.func, cst.Name).value

        if m.matches(node.args[0].value, m.GeneratorExp()):
            exp = cst.ensure_type(node.args[0].value, cst.GeneratorExp)
            message_formatter = UNNECESSARY_GENERATOR
        else:
            exp = cst.ensure_type(node.args[0].value, cst.ListComp)
            message_formatter = UNNECESSARY_LIST_COMPREHENSION

        # The replacement comprehension is built directly; replacing `node`
        # inside itself via `deep_replace` would only return the new node.
        replacement = None
        if call_name == "list":
            replacement = cst.ListComp(elt=exp.elt, for_in=exp.for_in)
        elif call_name == "set":
            replacement = cst.SetComp(elt=exp.elt, for_in=exp.for_in)
        elif call_name == "dict":
            elt = exp.elt
            # The element matchers must be passed as one `elements` list: the
            # positional fields of m.Tuple/m.List are (elements, lpar, ...).
            # m.Element() also rejects starred items such as `(*k, v)`.
            key_value_pair = [m.Element(), m.Element()]
            if m.matches(elt, m.Tuple(elements=key_value_pair)):
                elements = cst.ensure_type(elt, cst.Tuple).elements
            elif m.matches(elt, m.List(elements=key_value_pair)):
                elements = cst.ensure_type(elt, cst.List).elements
            else:
                # Unrecognized form (not exactly two plain elements)
                return

            # pyre-fixme[6]: Expected `BaseAssignTargetExpression` for 1st
            #  param but got `BaseExpression`.
            replacement = cst.DictComp(
                key=elements[0].value,
                value=elements[1].value,
                for_in=exp.for_in,
            )

        self.report(node,
                    message_formatter.format(func=call_name),
                    replacement=replacement)
 def test_with_asname(self) -> None:
     """``with ... as f`` binds ``f`` in module scope to the ``AsName`` target."""
     m, scopes = get_scope_metadata_provider("""
         with open(file_name) as f:
             ...
         """)
     scope_of_module = scopes[m]
     self.assertIsInstance(scope_of_module, GlobalScope)
     self.assertIn("f", scope_of_module)
     with_stmt = ensure_type(m.body[0], cst.With)
     as_name = ensure_type(with_stmt.items[0].asname, cst.AsName)
     # The collection returned by a scope lookup may be an unordered set
     # (not indexable) depending on the LibCST version; materialize it first.
     f_assignment = cast(Assignment, list(scope_of_module["f"])[0])
     self.assertEqual(f_assignment.node, as_name.name)
    def test_annotation_access(self) -> None:
        """Uses of imported names in type positions are flagged ``is_annotation``.

        ``A`` (plain annotation), ``B`` (string annotation), ``D`` and ``E``
        (``TypeVar`` bounds, plain and string) each get one annotation
        reference; ``C`` (inside ``Literal[...]``) and ``F`` (only appearing as
        a ``TypeVar`` name string) get no references.
        """
        m, scopes = get_scope_metadata_provider("""
                from typing import Literal, TypeVar
                from a import A, B, C, D, E, F
                def x(a: A):
                    pass
                def y(b: "B"):
                    pass
                def z(c: Literal["C"]):
                    pass
                DType = TypeVar("DType", bound=D)
                EType = TypeVar("EType", bound="E")
                FType = TypeVar("F")
            """)
        import_line = ensure_type(m.body[1], cst.SimpleStatementLine)
        imp = ensure_type(import_line.body[0], cst.ImportFrom)
        scope = scopes[imp]

        expected_reference_counts = {
            "A": 1, "B": 1, "C": 0, "D": 1, "E": 1, "F": 0,
        }
        for name, expected_count in expected_reference_counts.items():
            assignment = list(scope[name])[0]
            self.assertIsInstance(assignment, Assignment)
            self.assertEqual(len(assignment.references), expected_count)
            if expected_count:
                references = list(assignment.references)
                self.assertTrue(references[0].is_annotation)
Beispiel #25
0
 def test_extract_sequence(self) -> None:
     """A ``SaveMatchedNode`` wrapping a whole sequence matcher captures all args."""
     expression = cst.parse_expression("a + b[c], d(e, f * g, h.i.j)")
     all_args = m.SaveMatchedNode([m.ZeroOrMore()], "args")
     pattern = m.Tuple(elements=[
         m.DoNotCare(),
         m.Element(m.Call(args=all_args)),
     ])
     nodes = m.extract(expression, pattern)
     call = cst.ensure_type(
         cst.ensure_type(expression, cst.Tuple).elements[1].value,
         cst.Call)
     self.assertEqual(nodes, {"args": call.args})
Beispiel #26
0
    def leave_With(
        self, original_node: cst.With, updated_node: cst.With
    ) -> Union[cst.BaseStatement, cst.RemovalSentinel]:
        """Merge directly nested ``with`` statements into one compound ``with``.

        Starting at *original_node*, items are collected while each body holds
        exactly one statement that is itself a ``with``. Collection stops at an
        ``async with`` or at a leading/inline comment, which the compound form
        could not preserve. With fewer than two collected items the node is
        returned unchanged.
        """

        candidate_with: cst.With = original_node
        compound_items: List[cst.WithItem] = []
        final_body: cst.BaseSuite = candidate_with.body

        while True:
            # There is no way to meaningfully represent comments inside
            # multi-line `with` statements due to how Python grammar is
            # written, so we do not try to transform such `with` statements
            # lest we lose something important in the comments.
            if has_leading_comment(candidate_with):
                break

            if has_inline_comment(candidate_with.body):
                break

            # There is no meaningful way `async with` can be merged into
            # the compound `with` statement.
            if candidate_with.asynchronous:
                break

            compound_items.extend(candidate_with.items)
            final_body = candidate_with.body

            if not isinstance(final_body.body[0], cst.With):
                break

            if len(final_body.body) > 1:
                break

            candidate_with = cst.ensure_type(candidate_with.body.body[0],
                                             cst.With)

        if len(compound_items) <= 1:
            return original_node

        # More than one item means the outer body contained a nested `with`,
        # so it is an indented block. The innermost body, however, may be a
        # one-line suite (`with b: pass`).
        topmost_body = cst.ensure_type(original_node.body, cst.IndentedBlock)

        if has_footer_comment(topmost_body):
            if not isinstance(final_body, cst.IndentedBlock):
                # A one-line suite has no footer to carry the outer block's
                # trailing comments; merging would silently drop them.
                return original_node
            if not has_footer_comment(final_body):
                final_body = final_body.with_changes(
                    footer=(*final_body.footer, *topmost_body.footer))

        return updated_node.with_changes(body=final_body, items=compound_items)
Beispiel #27
0
    def test_func_param_scope(self) -> None:
        """Parameters are bound in the function scope.

        The decorator, parameter annotations, defaults and the return
        annotation all map to the module scope.
        """
        m, scopes = get_scope_metadata_provider("""
            @decorator
            def f(x: T=1, *vararg, y: T=2, z, **kwarg) -> RET:
                pass
            """)
        scope_of_module = scopes[m]
        self.assertIsInstance(scope_of_module, GlobalScope)
        self.assertIn("f", scope_of_module)

        f = ensure_type(m.body[0], cst.FunctionDef)
        scope_of_f = scopes[f.body.body[0]]
        self.assertIsInstance(scope_of_f, FunctionScope)

        decorator = f.decorators[0]
        x = f.params.params[0]
        xT = ensure_type(x.annotation, cst.Annotation)
        one = ensure_type(x.default, cst.BaseExpression)
        vararg = ensure_type(f.params.star_arg, cst.Param)
        y = f.params.kwonly_params[0]
        yT = ensure_type(y.annotation, cst.Annotation)
        two = ensure_type(y.default, cst.BaseExpression)
        z = f.params.kwonly_params[1]
        kwarg = ensure_type(f.params.star_kwarg, cst.Param)
        ret = ensure_type(f.returns, cst.Annotation).annotation

        self.assertEqual(scopes[decorator], scope_of_module)
        self.assertEqual(scopes[x], scope_of_f)
        self.assertEqual(scopes[xT], scope_of_module)
        self.assertEqual(scopes[one], scope_of_module)
        self.assertEqual(scopes[vararg], scope_of_f)
        self.assertEqual(scopes[y], scope_of_f)
        self.assertEqual(scopes[yT], scope_of_module)
        self.assertEqual(scopes[z], scope_of_f)
        self.assertEqual(scopes[two], scope_of_module)
        self.assertEqual(scopes[kwarg], scope_of_f)
        self.assertEqual(scopes[ret], scope_of_module)

        params = {"x": x, "vararg": vararg, "y": y, "z": z, "kwarg": kwarg}
        for name in params:
            self.assertNotIn(name, scope_of_module)
            self.assertIn(name, scope_of_f)

        # Each parameter's single assignment points back at its Param node.
        for name, param in params.items():
            self.assertEqual(
                cast(Assignment, list(scope_of_f[name])[0]).node, param)
Beispiel #28
0
    def test_extract_metadata(self) -> None:
        """``m.extract`` honours ``MatchMetadata`` nested in ``SaveMatchedNode``.

        The expression ``a + b[c], d(e, f * g)`` is matched against a tuple
        pattern whose first element is a binary operation with a ``Name`` on
        the left. That ``Name`` must carry a ``PositionProvider`` range that
        starts at (1, 0). With the true end (1, 1), the node is captured
        under ``"left"``. With a wrong end (1, 2), extraction fails and
        returns ``None``.
        """
        wrapper = cst.MetadataWrapper(cst.parse_module("a + b[c], d(e, f * g)"))
        first_statement = cst.ensure_type(
            wrapper.module.body[0], cst.SimpleStatementLine
        )
        expression = cst.ensure_type(first_statement.body[0], cst.Expr).value

        def extract_with_left_ending_at(end):
            # Build the tuple pattern, requiring the left operand's position
            # to span (1, 0) up to ``end``, and run it against ``expression``.
            left_position = m.MatchMetadata(
                meta.PositionProvider,
                self._make_coderange((1, 0), end),
            )
            left_name = m.Name(metadata=m.SaveMatchedNode(left_position, "left"))
            pattern = m.Tuple(
                elements=[
                    m.Element(m.BinaryOperation(left=left_name)),
                    m.Element(m.Call()),
                ]
            )
            return m.extract(expression, pattern, metadata_resolver=wrapper)

        # Positive case: ``a`` occupies exactly (1, 0)-(1, 1).
        captured = extract_with_left_ending_at((1, 1))
        first_element = cst.ensure_type(expression, cst.Tuple).elements[0].value
        expected_left = cst.ensure_type(first_element, cst.BinaryOperation).left
        self.assertEqual(captured, {"left": expected_left})

        # Negative case: a range one column too wide must not match.
        self.assertIsNone(extract_with_left_ending_at((1, 2)))
Beispiel #29
0
    def test_imoprt_from(self) -> None:
        """``from ... import`` binds only the imported or aliased names.

        Each bound name must map to exactly one ``Assignment``. That
        assignment's ``name`` must equal the bound name, and its ``node``
        must be the import statement that introduced it. Module path
        components (``foo``, ``bar``, ``foo.bar``) must not be bound, and
        neither may the un-aliased original name ``b``.
        """
        m, scopes = get_scope_metadata_provider(
            """
            from foo.bar import a, b as b_renamed
            from . import c
            from .foo import d
            """
        )
        scope_of_module = scopes[m]
        # Pairs of (index of the statement in ``m.body``, name it should bind).
        for idx, in_scope in [(0, "a"), (0, "b_renamed"), (1, "c"), (2, "d")]:
            assignments = scope_of_module[in_scope]
            self.assertEqual(len(assignments), 1, f"{in_scope} should be in scope.")
            # Newer LibCST returns a set from ``Scope.__getitem__``, which
            # cannot be indexed with ``[0]``. Iterating works for both a tuple
            # and a set, and the length check above guarantees one element.
            import_assignment = cast(Assignment, next(iter(assignments)))
            self.assertEqual(
                import_assignment.name,
                in_scope,
                f"The name of Assignment {import_assignment.name} should equal to {in_scope}.",
            )
            import_node = ensure_type(m.body[idx], cst.SimpleStatementLine).body[0]
            self.assertEqual(
                import_assignment.node,
                import_node,
                f"The node of Assignment {import_assignment.node} should equal to {import_node}",
            )

        for not_in_scope in ["foo", "bar", "foo.bar", "b"]:
            self.assertEqual(
                len(scope_of_module[not_in_scope]),
                0,
                f"{not_in_scope} should not be in scope.",
            )
Beispiel #30
0
    def test_import(self) -> None:
        m, scopes = get_scope_metadata_provider(
            """
            import foo.bar
            import fizz.buzz as fizzbuzz
            import a.b.c
            import d.e.f as g
            """
        )
        scope_of_module = scopes[m]
        for idx, in_scope in enumerate(["foo", "fizzbuzz", "a", "g"]):
            self.assertEqual(
                len(scope_of_module[in_scope]), 1, f"{in_scope} should be in scope."
            )

            assignment = cast(Assignment, scope_of_module[in_scope][0])
            self.assertEqual(
                assignment.name,
                in_scope,
                f"Assignment name {assignment.name} should equal to {in_scope}.",
            )
            import_node = ensure_type(m.body[idx], cst.SimpleStatementLine).body[0]
            self.assertEqual(
                assignment.node,
                import_node,
                f"The node of Assignment {assignment.node} should equal to {import_node}",
            )