Esempio n. 1
0
    def leave_ClassDef(self, original_node: cst.ClassDef, updated_node: cst.ClassDef) -> cst.ClassDef:
        """Turn a class deriving from the configured NamedTuple into a ``@dataclass`` class.

        Every base of ``original_node`` is compared against
        ``self.qualified_namedtuple``. When none matches, ``updated_node`` is
        returned untouched. Otherwise the matching base is dropped, a
        ``dataclass(frozen=False)`` decorator is appended after the existing
        decorators, and the import bookkeeping is scheduled on ``self.context``.

        Args:
            original_node: The class as it appeared before its children were
                visited; used for the qualified-name lookups.
            updated_node: The class after its children were visited; the
                returned node is derived from it.

        Returns:
            ``updated_node`` itself, or a copy with the NamedTuple base removed
            and the extra decorator appended.
        """
        new_bases: List[cst.Arg] = []
        namedtuple_base: Optional[cst.Arg] = None

        # Need to examine the original node's bases since they are directly tied to import metadata
        # NOTE(review): the kept bases are therefore the *original* Arg nodes, so
        # edits made to a base while visiting children are not carried over.
        for base_class in original_node.bases:
            # Compare the base class's qualified name against the expected typing.NamedTuple
            if not QualifiedNameProvider.has_name(self, base_class.value, self.qualified_namedtuple):
                # Keep all bases that are not of type typing.NamedTuple
                new_bases.append(base_class)
            else:
                namedtuple_base = base_class

        # We still want to return the updated node in case some of its children have been modified
        if namedtuple_base is None:
            return updated_node

        # NOTE(review): both requests bind the same name ``dataclass`` (from
        # ``attr`` and from ``pydantic.dataclasses``) - confirm that both are wanted.
        AddImportsVisitor.add_needed_import(self.context, "attr", "dataclass")
        AddImportsVisitor.add_needed_import(self.context, "pydantic.dataclasses", "dataclass")
        RemoveImportsVisitor.remove_unused_import_by_node(self.context, namedtuple_base.value)

        call = cst.ensure_type(
            cst.parse_expression("dataclass(frozen=False)", config=self.module.config_for_parsing),
            cst.Call,
        )
        return updated_node.with_changes(
            # Default sentinels for the parentheses, since the list of bases has changed.
            lpar=cst.MaybeSentinel.DEFAULT,
            rpar=cst.MaybeSentinel.DEFAULT,
            bases=new_bases,
            # Build on the *updated* decorators so that changes made to them while
            # visiting this class's children are not thrown away.
            decorators=[*updated_node.decorators, cst.Decorator(decorator=call)],
        )
Esempio n. 2
0
 def function_def(
     self,
     name: str,
     type: typing.Literal["function", "classmethod", "method"],
     indent=0,
     overload=False,
 ) -> cst.FunctionDef:
     """Assemble a ``cst.FunctionDef`` called ``name``.

     Args:
         name: Identifier used for the generated function.
         type: Kind of callable; forwarded to ``self.parameters`` and, when it
             equals ``"classmethod"``, adds a ``@classmethod`` decorator.
         indent: Forwarded unchanged to ``self.body``.
         overload: When true, an ``@overload`` decorator is emitted first.

     Returns:
         The function definition, with each item produced by ``self.body``
         wrapped in its own statement line and ``self.return_type_annotation``
         as the return annotation.
     """
     # Collect decorator names first; ``overload`` precedes ``classmethod``.
     decorator_names: typing.List[str] = []
     if overload:
         decorator_names.append("overload")
     if type == "classmethod":
         decorator_names.append("classmethod")
     decorators: typing.List[cst.Decorator] = [
         cst.Decorator(cst.Name(label)) for label in decorator_names
     ]

     # Parameters are built before the body, matching the argument order of the call.
     function_name = cst.Name(name)
     params = self.parameters(type)
     block = cst.IndentedBlock(
         [cst.SimpleStatementLine([stmt]) for stmt in self.body(indent)]
     )
     return cst.FunctionDef(
         name=function_name,
         params=params,
         body=block,
         decorators=decorators,
         returns=self.return_type_annotation,
     )
Esempio n. 3
0
    def leave_ClassDef(self, original_node: cst.ClassDef) -> None:
        """Report a class that has a NamedTuple base, offering a dataclass rewrite.

        ``self.partition_bases`` splits the bases into the NamedTuple base (or
        ``None``) and the remaining ones. When such a base exists, the class is
        reported together with a replacement that drops that base and appends a
        ``dataclass(frozen=True)`` decorator after the existing decorators.
        """
        namedtuple_base, remaining_bases = self.partition_bases(original_node.bases)
        if namedtuple_base is None:
            # Nothing to report for classes without a NamedTuple base.
            return

        dataclass_call = ensure_type(parse_expression("dataclass(frozen=True)"), cst.Call)
        decorators = [*original_node.decorators, cst.Decorator(decorator=dataclass_call)]
        self.report(
            original_node,
            replacement=original_node.with_changes(
                lpar=MaybeSentinel.DEFAULT,
                rpar=MaybeSentinel.DEFAULT,
                bases=remaining_bases,
                decorators=decorators,
            ),
        )
Esempio n. 4
0
class ExpressionTest(UnitTest):
    """Tests for full-name extraction and for ``evaluated_value`` on literal nodes."""

    @data_provider(
        (
            ("a string", "a string"),
            (cst.Name("a_name"), "a_name"),
            (cst.parse_expression("a.b.c"), "a.b.c"),
            (cst.parse_expression("a.b()"), "a.b"),
            (cst.parse_expression("a.b.c[i]"), "a.b.c"),
            (cst.parse_statement("def fun():  pass"), "fun"),
            (cst.parse_statement("class cls:  pass"), "cls"),
            (
                cst.Decorator(
                    ensure_type(cst.parse_expression("a.b.c.d"), cst.Attribute)
                ),
                "a.b.c.d",
            ),
            # not a supported Node type
            (cst.parse_statement("(a.b()).c()"), None),
        )
    )
    def test_get_full_name_for_expression(
        self,
        input: Union[str, cst.CSTNode],
        output: Optional[str],
    ) -> None:
        """Check both name helpers against each ``(input, expected name)`` pair.

        ``get_full_name_for_node`` must return ``output``; the ``_or_raise``
        variant must return the same value, or raise when ``output`` is ``None``.
        """
        self.assertEqual(get_full_name_for_node(input), output)
        if output is not None:
            self.assertEqual(get_full_name_for_node_or_raise(input), output)
            return
        with self.assertRaises(Exception):
            get_full_name_for_node_or_raise(input)

    def test_simplestring_evaluated_value(self) -> None:
        """A string literal keeps its source text and evaluates like ``literal_eval``."""
        source = '"a string."'
        parsed = ensure_type(cst.parse_expression(source), cst.SimpleString)
        self.assertEqual(parsed.value, source)
        self.assertEqual(parsed.evaluated_value, literal_eval(source))

    def test_integer_evaluated_value(self) -> None:
        """An integer literal keeps its source text and evaluates like ``literal_eval``."""
        source = "5"
        parsed = ensure_type(cst.parse_expression(source), cst.Integer)
        self.assertEqual(parsed.value, source)
        self.assertEqual(parsed.evaluated_value, literal_eval(source))

    def test_float_evaluated_value(self) -> None:
        """A float literal keeps its source text and evaluates like ``literal_eval``."""
        source = "5.5"
        parsed = ensure_type(cst.parse_expression(source), cst.Float)
        self.assertEqual(parsed.value, source)
        self.assertEqual(parsed.evaluated_value, literal_eval(source))

    def test_complex_evaluated_value(self) -> None:
        """An imaginary literal keeps its source text and evaluates like ``literal_eval``."""
        source = "5j"
        parsed = ensure_type(cst.parse_expression(source), cst.Imaginary)
        self.assertEqual(parsed.value, source)
        self.assertEqual(parsed.evaluated_value, literal_eval(source))
 def test_decorators(self) -> None:
     """Check that the ``{decorator}`` placeholder accepts a name or a decorator node.

     The same template is filled once with a bare ``cst.Name`` and once with a
     ``cst.Decorator`` wrapping that name; both must render as ``@bar``.
     """
     template = "@{decorator}\ndef foo(): pass\n"
     rendered = "@bar\ndef foo(): pass\n"

     # Test that we can special-case decorators when needed.
     for substitute in (cst.Name("bar"), cst.Decorator(cst.Name("bar"))):
         statement = parse_template_statement(template, decorator=substitute)
         self.assertEqual(self.code(statement), rendered)
Esempio n. 6
0
class ClassDefParserTest(CSTNodeTest):
    """Data-driven ``cst.ClassDef`` cases checked through ``validate_node``.

    Each case pairs a hand-built node with its source text; some also carry an
    ``expected_position``. ``parse_statement`` is supplied as the parser.
    """

    @data_provider(
        (
            # Simple classdef
            # pyre-fixme[6]: Incompatible parameter type
            {
                "node": cst.ClassDef(
                    cst.Name("Foo"), cst.SimpleStatementSuite((cst.Pass(),))
                ),
                "code": "class Foo: pass\n",
            },
            {
                "node": cst.ClassDef(
                    cst.Name("Foo"),
                    cst.SimpleStatementSuite((cst.Pass(),)),
                    lpar=cst.LeftParen(),
                    rpar=cst.RightParen(),
                ),
                "code": "class Foo(): pass\n",
            },
            # Positional arguments render test
            {
                "node": cst.ClassDef(
                    cst.Name("Foo"),
                    cst.SimpleStatementSuite((cst.Pass(),)),
                    lpar=cst.LeftParen(),
                    bases=(cst.Arg(cst.Name("obj")),),
                    rpar=cst.RightParen(),
                ),
                "code": "class Foo(obj): pass\n",
            },
            {
                "node": cst.ClassDef(
                    cst.Name("Foo"),
                    cst.SimpleStatementSuite((cst.Pass(),)),
                    lpar=cst.LeftParen(),
                    bases=(
                        cst.Arg(
                            cst.Name("Bar"),
                            comma=cst.Comma(whitespace_after=cst.SimpleWhitespace(" ")),
                        ),
                        cst.Arg(
                            cst.Name("Baz"),
                            comma=cst.Comma(whitespace_after=cst.SimpleWhitespace(" ")),
                        ),
                        cst.Arg(cst.Name("object")),
                    ),
                    rpar=cst.RightParen(),
                ),
                "code": "class Foo(Bar, Baz, object): pass\n",
            },
            # Keyword arguments render test
            {
                "node": cst.ClassDef(
                    cst.Name("Foo"),
                    cst.SimpleStatementSuite((cst.Pass(),)),
                    lpar=cst.LeftParen(),
                    keywords=(
                        cst.Arg(
                            keyword=cst.Name("metaclass"),
                            equal=cst.AssignEqual(),
                            value=cst.Name("Bar"),
                        ),
                    ),
                    rpar=cst.RightParen(),
                ),
                "code": "class Foo(metaclass = Bar): pass\n",
            },
            # Iterator expansion render test
            {
                "node": cst.ClassDef(
                    cst.Name("Foo"),
                    cst.SimpleStatementSuite((cst.Pass(),)),
                    lpar=cst.LeftParen(),
                    bases=(cst.Arg(star="*", value=cst.Name("one")),),
                    rpar=cst.RightParen(),
                ),
                "code": "class Foo(*one): pass\n",
            },
            {
                "node": cst.ClassDef(
                    cst.Name("Foo"),
                    cst.SimpleStatementSuite((cst.Pass(),)),
                    lpar=cst.LeftParen(),
                    bases=(
                        cst.Arg(
                            star="*",
                            value=cst.Name("one"),
                            comma=cst.Comma(whitespace_after=cst.SimpleWhitespace(" ")),
                        ),
                        cst.Arg(
                            star="*",
                            value=cst.Name("two"),
                            comma=cst.Comma(whitespace_after=cst.SimpleWhitespace(" ")),
                        ),
                        cst.Arg(star="*", value=cst.Name("three")),
                    ),
                    rpar=cst.RightParen(),
                ),
                "code": "class Foo(*one, *two, *three): pass\n",
            },
            # Dictionary expansion render test
            {
                "node": cst.ClassDef(
                    cst.Name("Foo"),
                    cst.SimpleStatementSuite((cst.Pass(),)),
                    lpar=cst.LeftParen(),
                    keywords=(cst.Arg(star="**", value=cst.Name("one")),),
                    rpar=cst.RightParen(),
                ),
                "code": "class Foo(**one): pass\n",
            },
            {
                "node": cst.ClassDef(
                    cst.Name("Foo"),
                    cst.SimpleStatementSuite((cst.Pass(),)),
                    lpar=cst.LeftParen(),
                    keywords=(
                        cst.Arg(
                            star="**",
                            value=cst.Name("one"),
                            comma=cst.Comma(whitespace_after=cst.SimpleWhitespace(" ")),
                        ),
                        cst.Arg(
                            star="**",
                            value=cst.Name("two"),
                            comma=cst.Comma(whitespace_after=cst.SimpleWhitespace(" ")),
                        ),
                        cst.Arg(star="**", value=cst.Name("three")),
                    ),
                    rpar=cst.RightParen(),
                ),
                "code": "class Foo(**one, **two, **three): pass\n",
            },
            # Decorator render tests
            {
                "node": cst.ClassDef(
                    cst.Name("Foo"),
                    cst.SimpleStatementSuite((cst.Pass(),)),
                    decorators=(cst.Decorator(cst.Name("foo")),),
                    lpar=cst.LeftParen(),
                    rpar=cst.RightParen(),
                ),
                "code": "@foo\nclass Foo(): pass\n",
                "expected_position": CodeRange((2, 0), (2, 17)),
            },
            # Decorators with leading comment lines, plus a comment after them
            {
                "node": cst.ClassDef(
                    leading_lines=(
                        cst.EmptyLine(),
                        cst.EmptyLine(comment=cst.Comment("# leading comment 1")),
                    ),
                    decorators=(
                        cst.Decorator(cst.Name("foo"), leading_lines=()),
                        cst.Decorator(
                            cst.Name("bar"),
                            leading_lines=(
                                cst.EmptyLine(
                                    comment=cst.Comment("# leading comment 2")
                                ),
                            ),
                        ),
                        cst.Decorator(
                            cst.Name("baz"),
                            leading_lines=(
                                cst.EmptyLine(
                                    comment=cst.Comment("# leading comment 3")
                                ),
                            ),
                        ),
                    ),
                    lines_after_decorators=(
                        cst.EmptyLine(comment=cst.Comment("# class comment")),
                    ),
                    name=cst.Name("Foo"),
                    body=cst.SimpleStatementSuite((cst.Pass(),)),
                    lpar=cst.LeftParen(),
                    rpar=cst.RightParen(),
                ),
                "code": "\n# leading comment 1\n@foo\n# leading comment 2\n@bar\n# leading comment 3\n@baz\n# class comment\nclass Foo(): pass\n",
                "expected_position": CodeRange((9, 0), (9, 17)),
            },
        )
    )
    def test_valid(self, **kwargs: Any) -> None:
        """Pass one case's fields to ``validate_node`` with ``parse_statement`` as parser."""
        self.validate_node(**kwargs, parser=parse_statement)