def leave_ClassDef(self, original_node: cst.ClassDef, updated_node: cst.ClassDef) -> cst.ClassDef:
    """Rewrite a class that subclasses the configured NamedTuple into a frozen ``@dataclass``.

    The NamedTuple base is dropped, ``from dataclasses import dataclass`` is
    scheduled for addition, the NamedTuple import is scheduled for removal if
    it becomes unused, and ``@dataclass(frozen=True)`` is appended after any
    existing decorators.

    Args:
        original_node: The class before its children were transformed; its
            bases are the nodes the qualified-name metadata was computed for.
        updated_node: The class after its children were transformed.

    Returns:
        ``updated_node`` unchanged when no base resolves to
        ``self.qualified_namedtuple``; otherwise the rewritten class.
    """
    new_bases: List[cst.Arg] = []
    namedtuple_base: Optional[cst.Arg] = None

    # Examine the original node's bases since they are directly tied to the
    # import metadata used by QualifiedNameProvider.
    for base_class in original_node.bases:
        if QualifiedNameProvider.has_name(self, base_class.value, self.qualified_namedtuple):
            namedtuple_base = base_class
        else:
            # Keep all bases that are not the NamedTuple base.
            new_bases.append(base_class)

    # Still return the updated node so modifications to its children survive.
    if namedtuple_base is None:
        return updated_node

    # Dropping a trailing NamedTuple base would leave the previous base's
    # explicit comma behind (``class X(Base, ):``). When nothing follows it,
    # let libcst choose the comma instead.
    if new_bases and not updated_node.keywords:
        new_bases[-1] = new_bases[-1].with_changes(comma=cst.MaybeSentinel.DEFAULT)

    # Only the standard-library decorator is needed; importing ``dataclass``
    # from several modules would bind the same name twice and add
    # third-party dependencies to the rewritten code.
    AddImportsVisitor.add_needed_import(self.context, "dataclasses", "dataclass")
    RemoveImportsVisitor.remove_unused_import_by_node(self.context, namedtuple_base.value)

    # NamedTuple instances are immutable and hashable; frozen=True keeps both
    # properties (a non-frozen dataclass with eq=True gets __hash__ = None).
    call = cst.ensure_type(
        cst.parse_expression("dataclass(frozen=True)", config=self.module.config_for_parsing),
        cst.Call,
    )
    return updated_node.with_changes(
        lpar=cst.MaybeSentinel.DEFAULT,
        rpar=cst.MaybeSentinel.DEFAULT,
        bases=new_bases,
        decorators=[*original_node.decorators, cst.Decorator(decorator=call)],
    )
def function_def(
    self,
    name: str,
    type: typing.Literal["function", "classmethod", "method"],
    indent=0,
    overload=False,
) -> cst.FunctionDef:
    """Build a ``cst.FunctionDef`` named *name* for the requested function kind.

    Decorators are ``@overload`` (when *overload* is true) followed by
    ``@classmethod`` (when *type* is ``"classmethod"``). Parameters come from
    ``self.parameters(type)``; each small statement yielded by
    ``self.body(indent)`` is wrapped in its own ``SimpleStatementLine``; the
    return annotation is ``self.return_type_annotation``.
    """
    decorator_names: typing.List[str] = []
    if overload:
        decorator_names.append("overload")
    if type == "classmethod":
        decorator_names.append("classmethod")
    decorators = [cst.Decorator(cst.Name(label)) for label in decorator_names]

    # Keep the original evaluation order: name, parameters, body, annotation.
    func_name = cst.Name(name)
    params = self.parameters(type)
    lines = [cst.SimpleStatementLine([small_stmt]) for small_stmt in self.body(indent)]
    return cst.FunctionDef(
        name=func_name,
        params=params,
        body=cst.IndentedBlock(lines),
        decorators=decorators,
        returns=self.return_type_annotation,
    )
def leave_ClassDef(self, original_node: cst.ClassDef) -> None:
    """Report a NamedTuple-based class with a frozen-``@dataclass`` replacement.

    ``self.partition_bases`` splits the bases into the NamedTuple base (or
    ``None``) and the remaining bases. When a NamedTuple base is found, the
    suggested replacement keeps the other bases, lets libcst decide the
    parentheses, and appends ``@dataclass(frozen=True)`` to the decorators.
    """
    namedtuple_base, remaining_bases = self.partition_bases(original_node.bases)
    if namedtuple_base is None:
        return

    frozen_dataclass = ensure_type(parse_expression("dataclass(frozen=True)"), cst.Call)
    decorators = [*original_node.decorators, cst.Decorator(decorator=frozen_dataclass)]
    suggestion = original_node.with_changes(
        lpar=MaybeSentinel.DEFAULT,
        rpar=MaybeSentinel.DEFAULT,
        bases=remaining_bases,
        decorators=decorators,
    )
    self.report(original_node, replacement=suggestion)
class ExpressionTest(UnitTest):
    """Tests for full-name extraction and literal ``evaluated_value`` properties."""

    @data_provider(
        (
            ("a string", "a string"),
            (cst.Name("a_name"), "a_name"),
            (cst.parse_expression("a.b.c"), "a.b.c"),
            (cst.parse_expression("a.b()"), "a.b"),
            (cst.parse_expression("a.b.c[i]"), "a.b.c"),
            (cst.parse_statement("def fun(): pass"), "fun"),
            (cst.parse_statement("class cls: pass"), "cls"),
            (
                cst.Decorator(
                    ensure_type(cst.parse_expression("a.b.c.d"), cst.Attribute)
                ),
                "a.b.c.d",
            ),
            # Not a supported node type: no name is expected.
            (cst.parse_statement("(a.b()).c()"), None),
        )
    )
    def test_get_full_name_for_expression(
        self,
        input: Union[str, cst.CSTNode],
        output: Optional[str],
    ) -> None:
        self.assertEqual(get_full_name_for_node(input), output)
        if output is not None:
            self.assertEqual(get_full_name_for_node_or_raise(input), output)
            return
        # Unsupported input: the raising variant must raise instead of
        # returning None.
        with self.assertRaises(Exception):
            get_full_name_for_node_or_raise(input)

    def _check_evaluated_value(self, raw: str, node_type: type) -> None:
        """Parse *raw* as a *node_type* node and compare it with ``literal_eval``.

        Checks that ``value`` keeps the exact source text and that
        ``evaluated_value`` equals ``literal_eval(raw)``.
        """
        parsed = ensure_type(cst.parse_expression(raw), node_type)
        self.assertEqual(parsed.value, raw)
        self.assertEqual(parsed.evaluated_value, literal_eval(raw))

    def test_simplestring_evaluated_value(self) -> None:
        self._check_evaluated_value('"a string."', cst.SimpleString)

    def test_integer_evaluated_value(self) -> None:
        self._check_evaluated_value("5", cst.Integer)

    def test_float_evaluated_value(self) -> None:
        self._check_evaluated_value("5.5", cst.Float)

    def test_complex_evaluated_value(self) -> None:
        self._check_evaluated_value("5j", cst.Imaginary)
def test_decorators(self) -> None:
    """A ``{decorator}`` placeholder accepts a bare expression or a full Decorator.

    Both forms must render to the same ``@bar`` source.
    """
    template = "@{decorator}\ndef foo(): pass\n"
    expected = "@bar\ndef foo(): pass\n"
    for replacement in (cst.Name("bar"), cst.Decorator(cst.Name("bar"))):
        statement = parse_template_statement(template, decorator=replacement)
        self.assertEqual(self.code(statement), expected)
class ClassDefParserTest(CSTNodeTest):
    """Parser round-trip cases for ``class`` statements.

    Each data-provider entry pairs a hand-built ``cst.ClassDef`` ("node") with
    its expected source text ("code"), optionally with an "expected_position"
    CodeRange. NOTE(review): the exact comparison is performed by
    ``CSTNodeTest.validate_node`` (not visible here); presumably it checks that
    parsing "code" with ``parse_statement`` yields "node" — confirm there.
    """

    @data_provider(
        (
            # Simple classdef
            # pyre-fixme[6]: Incompatible parameter type
            {
                "node": cst.ClassDef(cst.Name("Foo"), cst.SimpleStatementSuite((cst.Pass(), ))),
                "code": "class Foo: pass\n",
            },
            # Empty parentheses must be represented by explicit lpar/rpar.
            {
                "node": cst.ClassDef(
                    cst.Name("Foo"),
                    cst.SimpleStatementSuite((cst.Pass(), )),
                    lpar=cst.LeftParen(),
                    rpar=cst.RightParen(),
                ),
                "code": "class Foo(): pass\n",
            },
            # Positional arguments render test
            {
                "node": cst.ClassDef(
                    cst.Name("Foo"),
                    cst.SimpleStatementSuite((cst.Pass(), )),
                    lpar=cst.LeftParen(),
                    bases=(cst.Arg(cst.Name("obj")), ),
                    rpar=cst.RightParen(),
                ),
                "code": "class Foo(obj): pass\n",
            },
            # Several bases: every base but the last carries an explicit
            # comma followed by a single space.
            {
                "node": cst.ClassDef(
                    cst.Name("Foo"),
                    cst.SimpleStatementSuite((cst.Pass(), )),
                    lpar=cst.LeftParen(),
                    bases=(
                        cst.Arg(
                            cst.Name("Bar"),
                            comma=cst.Comma(
                                whitespace_after=cst.SimpleWhitespace(" ")),
                        ),
                        cst.Arg(
                            cst.Name("Baz"),
                            comma=cst.Comma(
                                whitespace_after=cst.SimpleWhitespace(" ")),
                        ),
                        cst.Arg(cst.Name("object")),
                    ),
                    rpar=cst.RightParen(),
                ),
                "code": "class Foo(Bar, Baz, object): pass\n",
            },
            # Keyword arguments render test
            # (the expected code shows the spacing of a default AssignEqual()).
            {
                "node": cst.ClassDef(
                    cst.Name("Foo"),
                    cst.SimpleStatementSuite((cst.Pass(), )),
                    lpar=cst.LeftParen(),
                    keywords=(cst.Arg(
                        keyword=cst.Name("metaclass"),
                        equal=cst.AssignEqual(),
                        value=cst.Name("Bar"),
                    ), ),
                    rpar=cst.RightParen(),
                ),
                "code": "class Foo(metaclass = Bar): pass\n",
            },
            # Iterator expansion render test
            {
                "node": cst.ClassDef(
                    cst.Name("Foo"),
                    cst.SimpleStatementSuite((cst.Pass(), )),
                    lpar=cst.LeftParen(),
                    bases=(cst.Arg(star="*", value=cst.Name("one")), ),
                    rpar=cst.RightParen(),
                ),
                "code": "class Foo(*one): pass\n",
            },
            # Several starred bases, separated like the positional case above.
            {
                "node": cst.ClassDef(
                    cst.Name("Foo"),
                    cst.SimpleStatementSuite((cst.Pass(), )),
                    lpar=cst.LeftParen(),
                    bases=(
                        cst.Arg(
                            star="*",
                            value=cst.Name("one"),
                            comma=cst.Comma(
                                whitespace_after=cst.SimpleWhitespace(" ")),
                        ),
                        cst.Arg(
                            star="*",
                            value=cst.Name("two"),
                            comma=cst.Comma(
                                whitespace_after=cst.SimpleWhitespace(" ")),
                        ),
                        cst.Arg(star="*", value=cst.Name("three")),
                    ),
                    rpar=cst.RightParen(),
                ),
                "code": "class Foo(*one, *two, *three): pass\n",
            },
            # Dictionary expansion render test
            # (double-star arguments are stored in `keywords`, not `bases`).
            {
                "node": cst.ClassDef(
                    cst.Name("Foo"),
                    cst.SimpleStatementSuite((cst.Pass(), )),
                    lpar=cst.LeftParen(),
                    keywords=(cst.Arg(star="**", value=cst.Name("one")), ),
                    rpar=cst.RightParen(),
                ),
                "code": "class Foo(**one): pass\n",
            },
            {
                "node": cst.ClassDef(
                    cst.Name("Foo"),
                    cst.SimpleStatementSuite((cst.Pass(), )),
                    lpar=cst.LeftParen(),
                    keywords=(
                        cst.Arg(
                            star="**",
                            value=cst.Name("one"),
                            comma=cst.Comma(
                                whitespace_after=cst.SimpleWhitespace(" ")),
                        ),
                        cst.Arg(
                            star="**",
                            value=cst.Name("two"),
                            comma=cst.Comma(
                                whitespace_after=cst.SimpleWhitespace(" ")),
                        ),
                        cst.Arg(star="**", value=cst.Name("three")),
                    ),
                    rpar=cst.RightParen(),
                ),
                "code": "class Foo(**one, **two, **three): pass\n",
            },
            # Decorator render tests
            # expected_position covers only the `class ...` line that follows
            # the decorator: line 2, columns 0-17 (len("class Foo(): pass")).
            {
                "node": cst.ClassDef(
                    cst.Name("Foo"),
                    cst.SimpleStatementSuite((cst.Pass(), )),
                    decorators=(cst.Decorator(cst.Name("foo")), ),
                    lpar=cst.LeftParen(),
                    rpar=cst.RightParen(),
                ),
                "code": "@foo\nclass Foo(): pass\n",
                "expected_position": CodeRange((2, 0), (2, 17)),
            },
            # Comments before the class, between decorators and after the last
            # decorator attach to leading_lines, Decorator.leading_lines and
            # lines_after_decorators respectively. The class line is line 9.
            {
                "node": cst.ClassDef(
                    leading_lines=(
                        cst.EmptyLine(),
                        cst.EmptyLine(comment=cst.Comment("# leading comment 1")),
                    ),
                    decorators=(
                        cst.Decorator(cst.Name("foo"), leading_lines=()),
                        cst.Decorator(
                            cst.Name("bar"),
                            leading_lines=(cst.EmptyLine(
                                comment=cst.Comment("# leading comment 2")), ),
                        ),
                        cst.Decorator(
                            cst.Name("baz"),
                            leading_lines=(cst.EmptyLine(
                                comment=cst.Comment("# leading comment 3")), ),
                        ),
                    ),
                    lines_after_decorators=(cst.EmptyLine(
                        comment=cst.Comment("# class comment")), ),
                    name=cst.Name("Foo"),
                    body=cst.SimpleStatementSuite((cst.Pass(), )),
                    lpar=cst.LeftParen(),
                    rpar=cst.RightParen(),
                ),
                "code": "\n# leading comment 1\n@foo\n# leading comment 2\n@bar\n# leading comment 3\n@baz\n# class comment\nclass Foo(): pass\n",
                "expected_position": CodeRange((9, 0), (9, 17)),
            },
        )
    )
    def test_valid(self, **kwargs: Any) -> None:
        """Validate one fixture, using ``parse_statement`` as the parser."""
        self.validate_node(**kwargs, parser=parse_statement)