def test_parameters(self) -> None:
    """Check that a placeholder inside a def's parameter list can be filled
    with a bare ``Name``, a single ``Param``, or a whole ``Parameters`` node
    (including one carrying several params plus a ``**`` parameter)."""
    cases = (
        # A plain Name is inserted as an ordinary parameter.
        (
            "def foo({arg}): pass",
            {"arg": cst.Name("bar")},
            "def foo(bar): pass\n",
        ),
        # A Param node is accepted as a special case.
        (
            "def foo({arg}): pass",
            {"arg": cst.Param(cst.Name("bar"))},
            "def foo(bar): pass\n",
        ),
        # A whole Parameters node is accepted as a special case.
        (
            "def foo({args}): pass",
            {"args": cst.Parameters(params=(cst.Param(cst.Name("bar")),))},
            "def foo(bar): pass\n",
        ),
        # A Parameters node can fill in several params and **rest at once.
        (
            "def foo({args}): pass",
            {
                "args": cst.Parameters(
                    params=(
                        cst.Param(cst.Name("bar")),
                        cst.Param(cst.Name("baz")),
                    ),
                    star_kwarg=cst.Param(cst.Name("rest")),
                )
            },
            "def foo(bar, baz, **rest): pass\n",
        ),
    )
    for template, substitutions, expected_code in cases:
        statement = parse_template_statement(template, **substitutions)
        self.assertEqual(self.code(statement), expected_code)
def test_deep_replace_complex(self) -> None:
    """Swap the innermost of three nested defs for a one-line def and check
    that ``deep_replace`` leaves the enclosing defs untouched."""
    source_before = dedent(
        """
        def a():
            def b():
                def c():
                    pass
        """
    )
    source_after = dedent(
        """
        def a():
            def b():
                def d(): break
        """
    )

    module = cst.parse_module(source_before)
    # Walk a -> b -> c by descending into each def's indented body twice.
    innermost = cst.ensure_type(module.body[0], cst.FunctionDef)
    for _ in range(2):
        block = cst.ensure_type(innermost.body, cst.IndentedBlock)
        innermost = cst.ensure_type(block.body[0], cst.FunctionDef)

    replacement = cst.FunctionDef(
        name=cst.Name("d"),
        params=cst.Parameters(),
        body=cst.SimpleStatementSuite(body=(cst.Break(),)),
    )
    rewritten = cst.ensure_type(
        module.deep_replace(innermost, replacement), cst.Module
    )
    self.assertEqual(rewritten.code, source_after)
def _reverse_params(
    node: cst.CSTNode,
    extraction: Dict[str, Union[cst.CSTNode, Sequence[cst.CSTNode]]],
) -> cst.CSTNode:
    """Return ``node`` with its parameter list replaced by the captured
    ``extraction["params"]`` nodes in reverse order.

    ``node`` is passed through ``cst.ensure_type(..., cst.FunctionDef)``,
    so a non-FunctionDef node is rejected by that call.
    """
    func_def = cst.ensure_type(node, cst.FunctionDef)
    # pyre-ignore We know "params" holds a sequence of Param nodes, but
    # asserting that to pyre is difficult.
    flipped = list(reversed(extraction["params"]))
    return func_def.with_changes(params=cst.Parameters(params=flipped))
def parameters(
    self, type: typing.Literal["function", "classmethod", "method"]
) -> cst.Parameters:
    """Render the recorded signature as a ``cst.Parameters`` node.

    Every parameter carries its ``annotation``; optional ones additionally
    get a ``...`` default. For ``"classmethod"`` / ``"method"`` an
    unannotated ``cls`` / ``self`` is prepended to the positional-only list.
    Optional positional-only and positional-or-keyword params are ordered
    via ``possibly_order_dict``; keyword-only optionals keep dict order.
    """

    def _param(name, info, optional=False):
        # Build one annotated Param; ``optional`` adds the ``...`` default.
        if optional:
            return cst.Param(
                cst.Name(name),
                cst.Annotation(info.annotation),
                default=cst.Ellipsis(),
            )
        return cst.Param(cst.Name(name), cst.Annotation(info.annotation))

    posonly = [_param(k, v) for k, v in self.pos_only_required.items()]
    posonly.extend(
        _param(k, v, optional=True)
        for k, v in possibly_order_dict(
            self.pos_only_optional, self.pos_only_optional_ordering
        ).items()
    )

    receiver = {"classmethod": "cls", "method": "self"}.get(type)
    if receiver is not None:
        posonly.insert(0, cst.Param(cst.Name(receiver)))

    positional = [_param(k, v) for k, v in self.pos_or_kw_required.items()]
    positional.extend(
        _param(k, v, optional=True)
        for k, v in possibly_order_dict(
            self.pos_or_kw_optional, self.pos_or_kw_optional_ordering
        ).items()
    )

    if self.var_pos:
        star_arg = _param(self.var_pos[0], self.var_pos[1])
    else:
        star_arg = cst.MaybeSentinel.DEFAULT

    if self.var_kw:
        star_kwarg = _param(self.var_kw[0], self.var_kw[1])
    else:
        star_kwarg = None

    kwonly = [_param(k, v) for k, v in self.kw_only_required.items()]
    kwonly.extend(
        _param(k, v, optional=True) for k, v in self.kw_only_optional.items()
    )

    return cst.Parameters(
        posonly_params=posonly,
        params=positional,
        star_arg=star_arg,
        star_kwarg=star_kwarg,
        kwonly_params=kwonly,
    )
class FooterBehaviorTest(UnitTest):
    """Check where the parser attaches trailing comments: the module header
    and footer, the footers of (possibly nested) indented blocks, or the
    ``leading_lines`` of the statement that follows.

    The ``code`` samples indent with four spaces. The nested ``if`` sample
    needs distinct indent levels to parse at all, and the expected trees are
    built with default indentation settings.
    """

    @data_provider(
        {
            # Literally the most basic example
            "simple_module": {
                "code": "\n",
                "expected_module": cst.Module(body=()),
            },
            # A module with a header comment
            "header_only_module": {
                "code": "# This is a header comment\n",
                "expected_module": cst.Module(
                    header=[
                        cst.EmptyLine(
                            comment=cst.Comment(value="# This is a header comment")
                        )
                    ],
                    body=[],
                ),
            },
            # A module with a header and footer
            "simple_header_footer_module": {
                "code": "# This is a header comment\npass\n# This is a footer comment\n",
                "expected_module": cst.Module(
                    header=[
                        cst.EmptyLine(
                            comment=cst.Comment(value="# This is a header comment")
                        )
                    ],
                    body=[cst.SimpleStatementLine([cst.Pass()])],
                    footer=[
                        cst.EmptyLine(
                            comment=cst.Comment(value="# This is a footer comment")
                        )
                    ],
                ),
            },
            # A module which should have a footer comment taken from the
            # if statement's indented block.
            "simple_reparented_footer_module": {
                "code": "# This is a header comment\nif True:\n    pass\n# This is a footer comment\n",
                "expected_module": cst.Module(
                    header=[
                        cst.EmptyLine(
                            comment=cst.Comment(value="# This is a header comment")
                        )
                    ],
                    body=[
                        cst.If(
                            test=cst.Name(value="True"),
                            body=cst.IndentedBlock(
                                header=cst.TrailingWhitespace(),
                                body=[
                                    cst.SimpleStatementLine(
                                        body=[cst.Pass()],
                                        trailing_whitespace=cst.TrailingWhitespace(),
                                    )
                                ],
                            ),
                        )
                    ],
                    footer=[
                        cst.EmptyLine(
                            comment=cst.Comment(value="# This is a footer comment")
                        )
                    ],
                ),
            },
            # Verifying that we properly parse and spread out footer comments to the
            # relative indents they go with.
            "complex_reparented_footer_module": {
                "code": (
                    "# This is a header comment\nif True:\n    if True:\n        pass"
                    + "\n        # This is an inner indented block comment\n    # This "
                    + "is an outer indented block comment\n# This is a footer comment\n"
                ),
                "expected_module": cst.Module(
                    body=[
                        cst.If(
                            test=cst.Name(value="True"),
                            body=cst.IndentedBlock(
                                body=[
                                    cst.If(
                                        test=cst.Name(value="True"),
                                        body=cst.IndentedBlock(
                                            body=[
                                                cst.SimpleStatementLine(
                                                    body=[cst.Pass()]
                                                )
                                            ],
                                            footer=[
                                                cst.EmptyLine(
                                                    comment=cst.Comment(
                                                        value="# This is an inner indented block comment"
                                                    )
                                                )
                                            ],
                                        ),
                                    )
                                ],
                                footer=[
                                    cst.EmptyLine(
                                        comment=cst.Comment(
                                            value="# This is an outer indented block comment"
                                        )
                                    )
                                ],
                            ),
                        )
                    ],
                    header=[
                        cst.EmptyLine(
                            comment=cst.Comment(value="# This is a header comment")
                        )
                    ],
                    footer=[
                        cst.EmptyLine(
                            comment=cst.Comment(value="# This is a footer comment")
                        )
                    ],
                ),
            },
            # Verify that comments belonging to statements are still owned even
            # after an indented block.
            "statement_comment_reparent": {
                "code": "if foo:\n    return\n# comment\nx = 7\n",
                "expected_module": cst.Module(
                    body=[
                        cst.If(
                            test=cst.Name(value="foo"),
                            body=cst.IndentedBlock(
                                body=[
                                    cst.SimpleStatementLine(
                                        body=[
                                            cst.Return(
                                                whitespace_after_return=cst.SimpleWhitespace(
                                                    value=""
                                                )
                                            )
                                        ]
                                    )
                                ]
                            ),
                        ),
                        cst.SimpleStatementLine(
                            body=[
                                cst.Assign(
                                    targets=[
                                        cst.AssignTarget(target=cst.Name(value="x"))
                                    ],
                                    value=cst.Integer(value="7"),
                                )
                            ],
                            leading_lines=[
                                cst.EmptyLine(comment=cst.Comment(value="# comment"))
                            ],
                        ),
                    ]
                ),
            },
            # Verify that even if there are completely empty lines, we give all lines
            # up to and including the last line that's indented correctly. That way
            # comments that line up with indented block's indentation level aren't
            # parented to the next line just because there's a blank line or two
            # between them.
            "statement_comment_with_empty_lines": {
                "code": (
                    "def foo():\n    if True:\n        pass\n\n        # Empty "
                    + "line before me\n\n    else:\n        pass\n"
                ),
                "expected_module": cst.Module(
                    body=[
                        cst.FunctionDef(
                            name=cst.Name(value="foo"),
                            params=cst.Parameters(),
                            body=cst.IndentedBlock(
                                body=[
                                    cst.If(
                                        test=cst.Name(value="True"),
                                        body=cst.IndentedBlock(
                                            body=[
                                                cst.SimpleStatementLine(
                                                    body=[cst.Pass()]
                                                )
                                            ],
                                            footer=[
                                                cst.EmptyLine(indent=False),
                                                cst.EmptyLine(
                                                    comment=cst.Comment(
                                                        value="# Empty line before me"
                                                    )
                                                ),
                                            ],
                                        ),
                                        orelse=cst.Else(
                                            body=cst.IndentedBlock(
                                                body=[
                                                    cst.SimpleStatementLine(
                                                        body=[cst.Pass()]
                                                    )
                                                ]
                                            ),
                                            leading_lines=[cst.EmptyLine(indent=False)],
                                        ),
                                    )
                                ]
                            ),
                        )
                    ]
                ),
            },
        }
    )
    def test_parsers(self, code: str, expected_module: cst.CSTNode) -> None:
        """Parse ``code`` as a module and require a deep match with
        ``expected_module``; the failure message shows both reprs."""
        parsed_module = parse_module(dedent(code))
        self.assertTrue(
            deep_equals(parsed_module, expected_module),
            msg=f"\n{parsed_module!r}\nis not deeply equal to \n{expected_module!r}",
        )
class AwaitTest(CSTNodeTest):
    """Codegen, parse and validation tests for ``cst.Await``.

    Python 3.7 accepts a bare ``await`` expression. Under 3.6 the same
    expressions are exercised inside an ``async def`` body (four-space
    indent, matching the default ``IndentedBlock``).
    """

    @data_provider(
        (
            # Some simple calls
            {
                "node": cst.Await(cst.Name("test")),
                "code": "await test",
                "parser": lambda code: parse_expression(
                    code, config=PartialParserConfig(python_version="3.7")
                ),
                "expected_position": None,
            },
            {
                "node": cst.Await(cst.Call(cst.Name("test"))),
                "code": "await test()",
                "parser": lambda code: parse_expression(
                    code, config=PartialParserConfig(python_version="3.7")
                ),
                "expected_position": None,
            },
            # Whitespace
            {
                "node": cst.Await(
                    cst.Name("test"),
                    # Two spaces (non-default) put ``test`` at columns 9-13,
                    # which is what expected_position below asserts.
                    whitespace_after_await=cst.SimpleWhitespace("  "),
                    lpar=(
                        cst.LeftParen(whitespace_after=cst.SimpleWhitespace(" ")),
                    ),
                    rpar=(
                        cst.RightParen(whitespace_before=cst.SimpleWhitespace(" ")),
                    ),
                ),
                "code": "( await  test )",
                "parser": lambda code: parse_expression(
                    code, config=PartialParserConfig(python_version="3.7")
                ),
                "expected_position": CodeRange((1, 2), (1, 13)),
            },
        )
    )
    def test_valid_py37(self, **kwargs: Any) -> None:
        # We don't have sentinel nodes for atoms, so we know that 100% of atoms
        # can be parsed identically to their creation.
        self.validate_node(**kwargs)

    @data_provider(
        (
            # Some simple calls
            {
                "node": cst.FunctionDef(
                    cst.Name("foo"),
                    cst.Parameters(),
                    cst.IndentedBlock(
                        (
                            cst.SimpleStatementLine(
                                (cst.Expr(cst.Await(cst.Name("test"))),)
                            ),
                        )
                    ),
                    asynchronous=cst.Asynchronous(),
                ),
                "code": "async def foo():\n    await test\n",
                "parser": lambda code: parse_statement(
                    code, config=PartialParserConfig(python_version="3.6")
                ),
                "expected_position": None,
            },
            {
                "node": cst.FunctionDef(
                    cst.Name("foo"),
                    cst.Parameters(),
                    cst.IndentedBlock(
                        (
                            cst.SimpleStatementLine(
                                (cst.Expr(cst.Await(cst.Call(cst.Name("test")))),)
                            ),
                        )
                    ),
                    asynchronous=cst.Asynchronous(),
                ),
                "code": "async def foo():\n    await test()\n",
                "parser": lambda code: parse_statement(
                    code, config=PartialParserConfig(python_version="3.6")
                ),
                "expected_position": None,
            },
            # Whitespace
            {
                "node": cst.FunctionDef(
                    cst.Name("foo"),
                    cst.Parameters(),
                    cst.IndentedBlock(
                        (
                            cst.SimpleStatementLine(
                                (
                                    cst.Expr(
                                        cst.Await(
                                            cst.Name("test"),
                                            whitespace_after_await=cst.SimpleWhitespace(
                                                "  "
                                            ),
                                            lpar=(
                                                cst.LeftParen(
                                                    whitespace_after=cst.SimpleWhitespace(
                                                        " "
                                                    )
                                                ),
                                            ),
                                            rpar=(
                                                cst.RightParen(
                                                    whitespace_before=cst.SimpleWhitespace(
                                                        " "
                                                    )
                                                ),
                                            ),
                                        )
                                    ),
                                )
                            ),
                        )
                    ),
                    asynchronous=cst.Asynchronous(),
                ),
                "code": "async def foo():\n    ( await  test )\n",
                "parser": lambda code: parse_statement(
                    code, config=PartialParserConfig(python_version="3.6")
                ),
                "expected_position": None,
            },
        )
    )
    def test_valid_py36(self, **kwargs: Any) -> None:
        # We don't have sentinel nodes for atoms, so we know that 100% of atoms
        # can be parsed identically to their creation.
        self.validate_node(**kwargs)

    @data_provider(
        (
            # Expression wrapping parenthesis rules
            {
                "get_node": (
                    lambda: cst.Await(cst.Name("foo"), lpar=(cst.LeftParen(),))
                ),
                "expected_re": "left paren without right paren",
            },
            {
                "get_node": (
                    lambda: cst.Await(cst.Name("foo"), rpar=(cst.RightParen(),))
                ),
                "expected_re": "right paren without left paren",
            },
            {
                "get_node": (
                    lambda: cst.Await(
                        cst.Name("foo"),
                        whitespace_after_await=cst.SimpleWhitespace(""),
                    )
                ),
                "expected_re": "at least one space after await",
            },
        )
    )
    def test_invalid(self, **kwargs: Any) -> None:
        self.assert_invalid(**kwargs)
class ForTest(CSTNodeTest):
    """Codegen, parse and validation tests for ``cst.For``.

    Code samples with an indented body use four-space indentation.
    ``DummyIndentedBlock`` cases use a four-space prefix so that the asserted
    positions start at column 4.
    """

    @data_provider(
        (
            # Simple for block
            {
                "node": cst.For(
                    cst.Name("target"),
                    cst.Call(cst.Name("iter")),
                    cst.SimpleStatementSuite((cst.Pass(),)),
                ),
                "code": "for target in iter(): pass\n",
                "parser": parse_statement,
            },
            # Simple async for block
            {
                "node": cst.For(
                    cst.Name("target"),
                    cst.Call(cst.Name("iter")),
                    cst.SimpleStatementSuite((cst.Pass(),)),
                    asynchronous=cst.Asynchronous(),
                ),
                "code": "async for target in iter(): pass\n",
                "parser": lambda code: parse_statement(
                    code, config=PartialParserConfig(python_version="3.7")
                ),
            },
            # Python 3.6 async for block
            {
                "node": cst.FunctionDef(
                    cst.Name("foo"),
                    cst.Parameters(),
                    cst.IndentedBlock(
                        (
                            cst.For(
                                cst.Name("target"),
                                cst.Call(cst.Name("iter")),
                                cst.SimpleStatementSuite((cst.Pass(),)),
                                asynchronous=cst.Asynchronous(),
                            ),
                        )
                    ),
                    asynchronous=cst.Asynchronous(),
                ),
                "code": "async def foo():\n    async for target in iter(): pass\n",
                "parser": lambda code: parse_statement(
                    code, config=PartialParserConfig(python_version="3.6")
                ),
            },
            # For block with else
            {
                "node": cst.For(
                    cst.Name("target"),
                    cst.Call(cst.Name("iter")),
                    cst.SimpleStatementSuite((cst.Pass(),)),
                    cst.Else(cst.SimpleStatementSuite((cst.Pass(),))),
                ),
                "code": "for target in iter(): pass\nelse: pass\n",
                "parser": parse_statement,
            },
            # indentation
            {
                "node": DummyIndentedBlock(
                    "    ",
                    cst.For(
                        cst.Name("target"),
                        cst.Call(cst.Name("iter")),
                        cst.SimpleStatementSuite((cst.Pass(),)),
                    ),
                ),
                "code": "    for target in iter(): pass\n",
                "parser": None,
            },
            # for an indented body
            {
                "node": DummyIndentedBlock(
                    "    ",
                    cst.For(
                        cst.Name("target"),
                        cst.Call(cst.Name("iter")),
                        cst.IndentedBlock((cst.SimpleStatementLine((cst.Pass(),)),)),
                    ),
                ),
                "code": "    for target in iter():\n        pass\n",
                "parser": None,
                "expected_position": CodeRange((1, 4), (2, 12)),
            },
            # leading_lines
            {
                "node": cst.For(
                    cst.Name("target"),
                    cst.Call(cst.Name("iter")),
                    cst.IndentedBlock((cst.SimpleStatementLine((cst.Pass(),)),)),
                    cst.Else(
                        cst.IndentedBlock((cst.SimpleStatementLine((cst.Pass(),)),)),
                        leading_lines=(
                            cst.EmptyLine(comment=cst.Comment("# else comment")),
                        ),
                    ),
                    leading_lines=(
                        cst.EmptyLine(comment=cst.Comment("# leading comment")),
                    ),
                ),
                "code": "# leading comment\nfor target in iter():\n    pass\n# else comment\nelse:\n    pass\n",
                "parser": None,
                "expected_position": CodeRange((2, 0), (6, 8)),
            },
            # Weird spacing rules
            {
                "node": cst.For(
                    cst.Name(
                        "target", lpar=(cst.LeftParen(),), rpar=(cst.RightParen(),)
                    ),
                    cst.Call(
                        cst.Name("iter"),
                        lpar=(cst.LeftParen(),),
                        rpar=(cst.RightParen(),),
                    ),
                    cst.SimpleStatementSuite((cst.Pass(),)),
                    whitespace_after_for=cst.SimpleWhitespace(""),
                    whitespace_before_in=cst.SimpleWhitespace(""),
                    whitespace_after_in=cst.SimpleWhitespace(""),
                ),
                "code": "for(target)in(iter()): pass\n",
                "parser": parse_statement,
                "expected_position": CodeRange((1, 0), (1, 27)),
            },
            # Whitespace: two spaces around every keyword and before the colon,
            # making the statement 31 columns wide.
            {
                "node": cst.For(
                    cst.Name("target"),
                    cst.Call(cst.Name("iter")),
                    cst.SimpleStatementSuite((cst.Pass(),)),
                    whitespace_after_for=cst.SimpleWhitespace("  "),
                    whitespace_before_in=cst.SimpleWhitespace("  "),
                    whitespace_after_in=cst.SimpleWhitespace("  "),
                    whitespace_before_colon=cst.SimpleWhitespace("  "),
                ),
                "code": "for  target  in  iter()  : pass\n",
                "parser": parse_statement,
                "expected_position": CodeRange((1, 0), (1, 31)),
            },
        )
    )
    def test_valid(self, **kwargs: Any) -> None:
        self.validate_node(**kwargs)

    @data_provider(
        (
            {
                "get_node": lambda: cst.For(
                    cst.Name("target"),
                    cst.Call(cst.Name("iter")),
                    cst.SimpleStatementSuite((cst.Pass(),)),
                    whitespace_after_for=cst.SimpleWhitespace(""),
                ),
                "expected_re": "Must have at least one space after 'for' keyword",
            },
            {
                "get_node": lambda: cst.For(
                    cst.Name("target"),
                    cst.Call(cst.Name("iter")),
                    cst.SimpleStatementSuite((cst.Pass(),)),
                    whitespace_before_in=cst.SimpleWhitespace(""),
                ),
                "expected_re": "Must have at least one space before 'in' keyword",
            },
            {
                "get_node": lambda: cst.For(
                    cst.Name("target"),
                    cst.Call(cst.Name("iter")),
                    cst.SimpleStatementSuite((cst.Pass(),)),
                    whitespace_after_in=cst.SimpleWhitespace(""),
                ),
                "expected_re": "Must have at least one space after 'in' keyword",
            },
        )
    )
    def test_invalid(self, **kwargs: Any) -> None:
        self.assert_invalid(**kwargs)
class WithTest(CSTNodeTest):
    """Codegen, parse, validation and version-gating tests for ``cst.With``.

    Code samples with an indented body use four-space indentation.
    ``DummyIndentedBlock`` cases use a four-space prefix so that the asserted
    positions start at column 4.
    """

    @data_provider(
        (
            # Simple with block
            {
                "node": cst.With(
                    (cst.WithItem(cst.Call(cst.Name("context_mgr"))),),
                    cst.SimpleStatementSuite((cst.Pass(),)),
                ),
                "code": "with context_mgr(): pass\n",
                "parser": parse_statement,
                "expected_position": CodeRange((1, 0), (1, 24)),
            },
            # Simple async with block
            {
                "node": cst.With(
                    (cst.WithItem(cst.Call(cst.Name("context_mgr"))),),
                    cst.SimpleStatementSuite((cst.Pass(),)),
                    asynchronous=cst.Asynchronous(),
                ),
                "code": "async with context_mgr(): pass\n",
                "parser": lambda code: parse_statement(
                    code, config=PartialParserConfig(python_version="3.7")
                ),
            },
            # Python 3.6 async with block
            {
                "node": cst.FunctionDef(
                    cst.Name("foo"),
                    cst.Parameters(),
                    cst.IndentedBlock(
                        (
                            cst.With(
                                (cst.WithItem(cst.Call(cst.Name("context_mgr"))),),
                                cst.SimpleStatementSuite((cst.Pass(),)),
                                asynchronous=cst.Asynchronous(),
                            ),
                        )
                    ),
                    asynchronous=cst.Asynchronous(),
                ),
                "code": "async def foo():\n    async with context_mgr(): pass\n",
                "parser": lambda code: parse_statement(
                    code, config=PartialParserConfig(python_version="3.6")
                ),
            },
            # Multiple context managers
            {
                "node": cst.With(
                    (
                        cst.WithItem(cst.Call(cst.Name("foo"))),
                        cst.WithItem(cst.Call(cst.Name("bar"))),
                    ),
                    cst.SimpleStatementSuite((cst.Pass(),)),
                ),
                "code": "with foo(), bar(): pass\n",
                "parser": None,
            },
            {
                "node": cst.With(
                    (
                        cst.WithItem(
                            cst.Call(cst.Name("foo")),
                            comma=cst.Comma(
                                whitespace_after=cst.SimpleWhitespace(" ")
                            ),
                        ),
                        cst.WithItem(cst.Call(cst.Name("bar"))),
                    ),
                    cst.SimpleStatementSuite((cst.Pass(),)),
                ),
                "code": "with foo(), bar(): pass\n",
                "parser": parse_statement,
            },
            # With block containing variable for context manager.
            {
                "node": cst.With(
                    (
                        cst.WithItem(
                            cst.Call(cst.Name("context_mgr")),
                            cst.AsName(cst.Name("ctx")),
                        ),
                    ),
                    cst.SimpleStatementSuite((cst.Pass(),)),
                ),
                "code": "with context_mgr() as ctx: pass\n",
                "parser": parse_statement,
            },
            # indentation
            {
                "node": DummyIndentedBlock(
                    "    ",
                    cst.With(
                        (cst.WithItem(cst.Call(cst.Name("context_mgr"))),),
                        cst.SimpleStatementSuite((cst.Pass(),)),
                    ),
                ),
                "code": "    with context_mgr(): pass\n",
                "parser": None,
                "expected_position": CodeRange((1, 4), (1, 28)),
            },
            # with an indented body
            {
                "node": DummyIndentedBlock(
                    "    ",
                    cst.With(
                        (cst.WithItem(cst.Call(cst.Name("context_mgr"))),),
                        cst.IndentedBlock((cst.SimpleStatementLine((cst.Pass(),)),)),
                    ),
                ),
                "code": "    with context_mgr():\n        pass\n",
                "parser": None,
                "expected_position": CodeRange((1, 4), (2, 12)),
            },
            # leading_lines
            {
                "node": cst.With(
                    (cst.WithItem(cst.Call(cst.Name("context_mgr"))),),
                    cst.SimpleStatementSuite((cst.Pass(),)),
                    leading_lines=(
                        cst.EmptyLine(comment=cst.Comment("# leading comment")),
                    ),
                ),
                "code": "# leading comment\nwith context_mgr(): pass\n",
                "parser": parse_statement,
                "expected_position": CodeRange((2, 0), (2, 24)),
            },
            # Weird spacing rules
            {
                "node": cst.With(
                    (
                        cst.WithItem(
                            cst.Call(
                                cst.Name("context_mgr"),
                                lpar=(cst.LeftParen(),),
                                rpar=(cst.RightParen(),),
                            )
                        ),
                    ),
                    cst.SimpleStatementSuite((cst.Pass(),)),
                    whitespace_after_with=cst.SimpleWhitespace(""),
                ),
                "code": "with(context_mgr()): pass\n",
                "parser": parse_statement,
                "expected_position": CodeRange((1, 0), (1, 25)),
            },
            # Whitespace: two spaces around ``with``/``as`` and before the
            # colon, making the statement 36 columns wide.
            {
                "node": cst.With(
                    (
                        cst.WithItem(
                            cst.Call(cst.Name("context_mgr")),
                            cst.AsName(
                                cst.Name("ctx"),
                                whitespace_before_as=cst.SimpleWhitespace("  "),
                                whitespace_after_as=cst.SimpleWhitespace("  "),
                            ),
                        ),
                    ),
                    cst.SimpleStatementSuite((cst.Pass(),)),
                    whitespace_after_with=cst.SimpleWhitespace("  "),
                    whitespace_before_colon=cst.SimpleWhitespace("  "),
                ),
                "code": "with  context_mgr()  as  ctx  : pass\n",
                "parser": parse_statement,
                "expected_position": CodeRange((1, 0), (1, 36)),
            },
        )
    )
    def test_valid(self, **kwargs: Any) -> None:
        self.validate_node(**kwargs)

    @data_provider(
        (
            {
                "get_node": lambda: cst.With(
                    (), cst.IndentedBlock((cst.SimpleStatementLine((cst.Pass(),)),))
                ),
                "expected_re": "A With statement must have at least one WithItem",
            },
            {
                "get_node": lambda: cst.With(
                    (
                        cst.WithItem(
                            cst.Call(cst.Name("foo")),
                            comma=cst.Comma(
                                whitespace_after=cst.SimpleWhitespace(" ")
                            ),
                        ),
                    ),
                    cst.IndentedBlock((cst.SimpleStatementLine((cst.Pass(),)),)),
                ),
                "expected_re": "The last WithItem in a With cannot have a trailing comma",
            },
            {
                "get_node": lambda: cst.With(
                    (cst.WithItem(cst.Call(cst.Name("context_mgr"))),),
                    cst.SimpleStatementSuite((cst.Pass(),)),
                    whitespace_after_with=cst.SimpleWhitespace(""),
                ),
                "expected_re": "Must have at least one space after with keyword",
            },
        )
    )
    def test_invalid(self, **kwargs: Any) -> None:
        self.assert_invalid(**kwargs)

    @data_provider(
        (
            {
                "code": "with a, b: pass",
                "parser": parse_statement_as(python_version="3.1"),
                "expect_success": True,
            },
            {
                "code": "with a, b: pass",
                "parser": parse_statement_as(python_version="3.0"),
                "expect_success": False,
            },
        )
    )
    def test_versions(self, **kwargs: Any) -> None:
        self.assert_parses(**kwargs)
class LambdaParserTest(CSTNodeTest):
    """Round-trip tests for ``cst.Lambda`` nodes spelled out in full.

    Each case is ``(node, code)``; ``test_valid`` passes both to
    ``validate_node`` together with ``parse_expression``. Every ``Param``
    states its ``star`` prefix and trailing ``Comma`` explicitly — presumably
    to mirror what the parser fills in (TODO confirm).
    """

    @data_provider((
        # Simple lambda
        (cst.Lambda(cst.Parameters(), cst.Integer("5")), "lambda: 5"),
        # Test basic positional params
        (
            cst.Lambda(
                cst.Parameters(params=(
                    cst.Param(
                        cst.Name("bar"),
                        star="",
                        comma=cst.Comma(
                            whitespace_after=cst.SimpleWhitespace(" ")),
                    ),
                    cst.Param(cst.Name("baz"), star=""),
                )),
                cst.Integer("5"),
                whitespace_after_lambda=cst.SimpleWhitespace(" "),
            ),
            "lambda bar, baz: 5",
        ),
        # Test basic positional default params
        (
            cst.Lambda(
                cst.Parameters(default_params=(
                    cst.Param(
                        cst.Name("bar"),
                        default=cst.SimpleString('"one"'),
                        equal=cst.AssignEqual(),
                        star="",
                        comma=cst.Comma(
                            whitespace_after=cst.SimpleWhitespace(" ")),
                    ),
                    cst.Param(
                        cst.Name("baz"),
                        default=cst.Integer("5"),
                        equal=cst.AssignEqual(),
                        star="",
                    ),
                )),
                cst.Integer("5"),
                whitespace_after_lambda=cst.SimpleWhitespace(" "),
            ),
            'lambda bar = "one", baz = 5: 5',
        ),
        # Mixed positional and default params.
        (
            cst.Lambda(
                cst.Parameters(
                    params=(cst.Param(
                        cst.Name("bar"),
                        star="",
                        comma=cst.Comma(
                            whitespace_after=cst.SimpleWhitespace(" ")),
                    ), ),
                    default_params=(cst.Param(
                        cst.Name("baz"),
                        default=cst.Integer("5"),
                        equal=cst.AssignEqual(),
                        star="",
                    ), ),
                ),
                cst.Integer("5"),
                whitespace_after_lambda=cst.SimpleWhitespace(" "),
            ),
            "lambda bar, baz = 5: 5",
        ),
        # Test kwonly params
        (
            cst.Lambda(
                cst.Parameters(
                    star_arg=cst.ParamStar(),
                    kwonly_params=(
                        cst.Param(
                            cst.Name("bar"),
                            default=cst.SimpleString('"one"'),
                            equal=cst.AssignEqual(),
                            star="",
                            comma=cst.Comma(
                                whitespace_after=cst.SimpleWhitespace(" ")),
                        ),
                        cst.Param(cst.Name("baz"), star=""),
                    ),
                ),
                cst.Integer("5"),
                whitespace_after_lambda=cst.SimpleWhitespace(" "),
            ),
            'lambda *, bar = "one", baz: 5',
        ),
        # Mixed params and kwonly_params
        (
            cst.Lambda(
                cst.Parameters(
                    params=(
                        cst.Param(
                            cst.Name("first"),
                            star="",
                            comma=cst.Comma(
                                whitespace_after=cst.SimpleWhitespace(" ")),
                        ),
                        cst.Param(
                            cst.Name("second"),
                            star="",
                            comma=cst.Comma(
                                whitespace_after=cst.SimpleWhitespace(" ")),
                        ),
                    ),
                    star_arg=cst.ParamStar(),
                    kwonly_params=(
                        cst.Param(
                            cst.Name("bar"),
                            default=cst.SimpleString('"one"'),
                            equal=cst.AssignEqual(),
                            star="",
                            comma=cst.Comma(
                                whitespace_after=cst.SimpleWhitespace(" ")),
                        ),
                        cst.Param(
                            cst.Name("baz"),
                            star="",
                            comma=cst.Comma(
                                whitespace_after=cst.SimpleWhitespace(" ")),
                        ),
                        cst.Param(
                            cst.Name("biz"),
                            default=cst.SimpleString('"two"'),
                            equal=cst.AssignEqual(),
                            star="",
                        ),
                    ),
                ),
                cst.Integer("5"),
                whitespace_after_lambda=cst.SimpleWhitespace(" "),
            ),
            'lambda first, second, *, bar = "one", baz, biz = "two": 5',
        ),
        # Mixed default_params and kwonly_params
        (
            cst.Lambda(
                cst.Parameters(
                    default_params=(
                        cst.Param(
                            cst.Name("first"),
                            default=cst.Float("1.0"),
                            equal=cst.AssignEqual(),
                            star="",
                            comma=cst.Comma(
                                whitespace_after=cst.SimpleWhitespace(" ")),
                        ),
                        cst.Param(
                            cst.Name("second"),
                            default=cst.Float("1.5"),
                            equal=cst.AssignEqual(),
                            star="",
                            comma=cst.Comma(
                                whitespace_after=cst.SimpleWhitespace(" ")),
                        ),
                    ),
                    star_arg=cst.ParamStar(),
                    kwonly_params=(
                        cst.Param(
                            cst.Name("bar"),
                            default=cst.SimpleString('"one"'),
                            equal=cst.AssignEqual(),
                            star="",
                            comma=cst.Comma(
                                whitespace_after=cst.SimpleWhitespace(" ")),
                        ),
                        cst.Param(
                            cst.Name("baz"),
                            star="",
                            comma=cst.Comma(
                                whitespace_after=cst.SimpleWhitespace(" ")),
                        ),
                        cst.Param(
                            cst.Name("biz"),
                            default=cst.SimpleString('"two"'),
                            equal=cst.AssignEqual(),
                            star="",
                        ),
                    ),
                ),
                cst.Integer("5"),
                whitespace_after_lambda=cst.SimpleWhitespace(" "),
            ),
            'lambda first = 1.0, second = 1.5, *, bar = "one", baz, biz = "two": 5',
        ),
        # Mixed params, default_params, and kwonly_params
        (
            cst.Lambda(
                cst.Parameters(
                    params=(
                        cst.Param(
                            cst.Name("first"),
                            star="",
                            comma=cst.Comma(
                                whitespace_after=cst.SimpleWhitespace(" ")),
                        ),
                        cst.Param(
                            cst.Name("second"),
                            star="",
                            comma=cst.Comma(
                                whitespace_after=cst.SimpleWhitespace(" ")),
                        ),
                    ),
                    default_params=(
                        cst.Param(
                            cst.Name("third"),
                            default=cst.Float("1.0"),
                            equal=cst.AssignEqual(),
                            star="",
                            comma=cst.Comma(
                                whitespace_after=cst.SimpleWhitespace(" ")),
                        ),
                        cst.Param(
                            cst.Name("fourth"),
                            default=cst.Float("1.5"),
                            equal=cst.AssignEqual(),
                            star="",
                            comma=cst.Comma(
                                whitespace_after=cst.SimpleWhitespace(" ")),
                        ),
                    ),
                    star_arg=cst.ParamStar(),
                    kwonly_params=(
                        cst.Param(
                            cst.Name("bar"),
                            default=cst.SimpleString('"one"'),
                            equal=cst.AssignEqual(),
                            star="",
                            comma=cst.Comma(
                                whitespace_after=cst.SimpleWhitespace(" ")),
                        ),
                        cst.Param(
                            cst.Name("baz"),
                            star="",
                            comma=cst.Comma(
                                whitespace_after=cst.SimpleWhitespace(" ")),
                        ),
                        cst.Param(
                            cst.Name("biz"),
                            default=cst.SimpleString('"two"'),
                            equal=cst.AssignEqual(),
                            star="",
                        ),
                    ),
                ),
                cst.Integer("5"),
                whitespace_after_lambda=cst.SimpleWhitespace(" "),
            ),
            'lambda first, second, third = 1.0, fourth = 1.5, *, bar = "one", baz, biz = "two": 5',
        ),
        # Test star_arg
        (
            cst.Lambda(
                cst.Parameters(
                    star_arg=cst.Param(cst.Name("params"), star="*")),
                cst.Integer("5"),
                whitespace_after_lambda=cst.SimpleWhitespace(" "),
            ),
            "lambda *params: 5",
        ),
        # star_arg followed by kwonly_params (no annotations involved)
        (
            cst.Lambda(
                cst.Parameters(
                    star_arg=cst.Param(
                        cst.Name("params"),
                        star="*",
                        comma=cst.Comma(
                            whitespace_after=cst.SimpleWhitespace(" ")),
                    ),
                    kwonly_params=(
                        cst.Param(
                            cst.Name("bar"),
                            default=cst.SimpleString('"one"'),
                            equal=cst.AssignEqual(),
                            star="",
                            comma=cst.Comma(
                                whitespace_after=cst.SimpleWhitespace(" ")),
                        ),
                        cst.Param(
                            cst.Name("baz"),
                            star="",
                            comma=cst.Comma(
                                whitespace_after=cst.SimpleWhitespace(" ")),
                        ),
                        cst.Param(
                            cst.Name("biz"),
                            default=cst.SimpleString('"two"'),
                            equal=cst.AssignEqual(),
                            star="",
                        ),
                    ),
                ),
                cst.Integer("5"),
                whitespace_after_lambda=cst.SimpleWhitespace(" "),
            ),
            'lambda *params, bar = "one", baz, biz = "two": 5',
        ),
        # Mixed params default_params, star_arg and kwonly_params
        (
            cst.Lambda(
                cst.Parameters(
                    params=(
                        cst.Param(
                            cst.Name("first"),
                            star="",
                            comma=cst.Comma(
                                whitespace_after=cst.SimpleWhitespace(" ")),
                        ),
                        cst.Param(
                            cst.Name("second"),
                            star="",
                            comma=cst.Comma(
                                whitespace_after=cst.SimpleWhitespace(" ")),
                        ),
                    ),
                    default_params=(
                        cst.Param(
                            cst.Name("third"),
                            default=cst.Float("1.0"),
                            equal=cst.AssignEqual(),
                            star="",
                            comma=cst.Comma(
                                whitespace_after=cst.SimpleWhitespace(" ")),
                        ),
                        cst.Param(
                            cst.Name("fourth"),
                            default=cst.Float("1.5"),
                            equal=cst.AssignEqual(),
                            star="",
                            comma=cst.Comma(
                                whitespace_after=cst.SimpleWhitespace(" ")),
                        ),
                    ),
                    star_arg=cst.Param(
                        cst.Name("params"),
                        star="*",
                        comma=cst.Comma(
                            whitespace_after=cst.SimpleWhitespace(" ")),
                    ),
                    kwonly_params=(
                        cst.Param(
                            cst.Name("bar"),
                            default=cst.SimpleString('"one"'),
                            equal=cst.AssignEqual(),
                            star="",
                            comma=cst.Comma(
                                whitespace_after=cst.SimpleWhitespace(" ")),
                        ),
                        cst.Param(
                            cst.Name("baz"),
                            star="",
                            comma=cst.Comma(
                                whitespace_after=cst.SimpleWhitespace(" ")),
                        ),
                        cst.Param(
                            cst.Name("biz"),
                            default=cst.SimpleString('"two"'),
                            equal=cst.AssignEqual(),
                            star="",
                        ),
                    ),
                ),
                cst.Integer("5"),
                whitespace_after_lambda=cst.SimpleWhitespace(" "),
            ),
            'lambda first, second, third = 1.0, fourth = 1.5, *params, bar = "one", baz, biz = "two": 5',
        ),
        # Test star_kwarg on its own
        (
            cst.Lambda(
                cst.Parameters(
                    star_kwarg=cst.Param(cst.Name("kwparams"), star="**")),
                cst.Integer("5"),
                whitespace_after_lambda=cst.SimpleWhitespace(" "),
            ),
            "lambda **kwparams: 5",
        ),
        # Test star_arg and star_kwarg together
        (
            cst.Lambda(
                cst.Parameters(
                    star_arg=cst.Param(
                        cst.Name("params"),
                        star="*",
                        comma=cst.Comma(
                            whitespace_after=cst.SimpleWhitespace(" ")),
                    ),
                    star_kwarg=cst.Param(cst.Name("kwparams"), star="**"),
                ),
                cst.Integer("5"),
                whitespace_after_lambda=cst.SimpleWhitespace(" "),
            ),
            "lambda *params, **kwparams: 5",
        ),
        # Inner whitespace: with no params, the spaces around ':' are set on
        # the Colon node here rather than via whitespace_after_lambda.
        (
            cst.Lambda(
                lpar=(cst.LeftParen(
                    whitespace_after=cst.SimpleWhitespace(" ")), ),
                params=cst.Parameters(),
                colon=cst.Colon(
                    whitespace_before=cst.SimpleWhitespace(" "),
                    whitespace_after=cst.SimpleWhitespace(" "),
                ),
                body=cst.Integer("5"),
                rpar=(cst.RightParen(
                    whitespace_before=cst.SimpleWhitespace(" ")), ),
            ),
            "( lambda : 5 )",
        ),
    ))
    def test_valid(self,
                   node: cst.CSTNode,
                   code: str,
                   position: Optional[CodeRange] = None) -> None:
        # ``position`` defaults to None because none of the cases above
        # supply a third element.
        self.validate_node(node, code, parse_expression, position)
class LambdaCreationTest(CSTNodeTest):
    """Codegen and validation tests for ``cst.Lambda`` nodes that rely on the
    constructor defaults (``star``, commas, ``=`` spacing) being filled in.

    ``test_valid`` cases are ``(node, code[, position])``; ``test_invalid``
    cases are ``(node factory, error-message regex)``.
    """

    @data_provider(
        (
            # Simple lambda
            (cst.Lambda(cst.Parameters(), cst.Integer("5")), "lambda: 5"),
            # Test basic positional params
            (
                cst.Lambda(
                    cst.Parameters(
                        params=(cst.Param(cst.Name("bar")), cst.Param(cst.Name("baz")))
                    ),
                    cst.Integer("5"),
                ),
                "lambda bar, baz: 5",
            ),
            # Test basic positional default params
            (
                cst.Lambda(
                    cst.Parameters(
                        default_params=(
                            cst.Param(cst.Name("bar"), default=cst.SimpleString('"one"')),
                            cst.Param(cst.Name("baz"), default=cst.Integer("5")),
                        )
                    ),
                    cst.Integer("5"),
                ),
                'lambda bar = "one", baz = 5: 5',
            ),
            # Mixed positional and default params.
            (
                cst.Lambda(
                    cst.Parameters(
                        params=(cst.Param(cst.Name("bar")),),
                        default_params=(
                            cst.Param(cst.Name("baz"), default=cst.Integer("5")),
                        ),
                    ),
                    cst.Integer("5"),
                ),
                "lambda bar, baz = 5: 5",
            ),
            # Test kwonly params
            (
                cst.Lambda(
                    cst.Parameters(
                        kwonly_params=(
                            cst.Param(cst.Name("bar"), default=cst.SimpleString('"one"')),
                            cst.Param(cst.Name("baz")),
                        )
                    ),
                    cst.Integer("5"),
                ),
                'lambda *, bar = "one", baz: 5',
            ),
            # Mixed params and kwonly_params
            (
                cst.Lambda(
                    cst.Parameters(
                        params=(
                            cst.Param(cst.Name("first")),
                            cst.Param(cst.Name("second")),
                        ),
                        kwonly_params=(
                            cst.Param(cst.Name("bar"), default=cst.SimpleString('"one"')),
                            cst.Param(cst.Name("baz")),
                            cst.Param(cst.Name("biz"), default=cst.SimpleString('"two"')),
                        ),
                    ),
                    cst.Integer("5"),
                ),
                'lambda first, second, *, bar = "one", baz, biz = "two": 5',
            ),
            # Mixed default_params and kwonly_params
            (
                cst.Lambda(
                    cst.Parameters(
                        default_params=(
                            cst.Param(cst.Name("first"), default=cst.Float("1.0")),
                            cst.Param(cst.Name("second"), default=cst.Float("1.5")),
                        ),
                        kwonly_params=(
                            cst.Param(cst.Name("bar"), default=cst.SimpleString('"one"')),
                            cst.Param(cst.Name("baz")),
                            cst.Param(cst.Name("biz"), default=cst.SimpleString('"two"')),
                        ),
                    ),
                    cst.Integer("5"),
                ),
                'lambda first = 1.0, second = 1.5, *, bar = "one", baz, biz = "two": 5',
            ),
            # Mixed params, default_params, and kwonly_params
            (
                cst.Lambda(
                    cst.Parameters(
                        params=(
                            cst.Param(cst.Name("first")),
                            cst.Param(cst.Name("second")),
                        ),
                        default_params=(
                            cst.Param(cst.Name("third"), default=cst.Float("1.0")),
                            cst.Param(cst.Name("fourth"), default=cst.Float("1.5")),
                        ),
                        kwonly_params=(
                            cst.Param(cst.Name("bar"), default=cst.SimpleString('"one"')),
                            cst.Param(cst.Name("baz")),
                            cst.Param(cst.Name("biz"), default=cst.SimpleString('"two"')),
                        ),
                    ),
                    cst.Integer("5"),
                ),
                'lambda first, second, third = 1.0, fourth = 1.5, *, bar = "one", baz, biz = "two": 5',
                CodeRange((1, 0), (1, 84)),
            ),
            # Test star_arg
            (
                cst.Lambda(
                    cst.Parameters(star_arg=cst.Param(cst.Name("params"))),
                    cst.Integer("5"),
                ),
                "lambda *params: 5",
            ),
            # star_arg followed by kwonly_params (no annotations involved)
            (
                cst.Lambda(
                    cst.Parameters(
                        star_arg=cst.Param(cst.Name("params")),
                        kwonly_params=(
                            cst.Param(cst.Name("bar"), default=cst.SimpleString('"one"')),
                            cst.Param(cst.Name("baz")),
                            cst.Param(cst.Name("biz"), default=cst.SimpleString('"two"')),
                        ),
                    ),
                    cst.Integer("5"),
                ),
                'lambda *params, bar = "one", baz, biz = "two": 5',
            ),
            # Mixed params default_params, star_arg and kwonly_params
            (
                cst.Lambda(
                    cst.Parameters(
                        params=(
                            cst.Param(cst.Name("first")),
                            cst.Param(cst.Name("second")),
                        ),
                        default_params=(
                            cst.Param(cst.Name("third"), default=cst.Float("1.0")),
                            cst.Param(cst.Name("fourth"), default=cst.Float("1.5")),
                        ),
                        star_arg=cst.Param(cst.Name("params")),
                        kwonly_params=(
                            cst.Param(cst.Name("bar"), default=cst.SimpleString('"one"')),
                            cst.Param(cst.Name("baz")),
                            cst.Param(cst.Name("biz"), default=cst.SimpleString('"two"')),
                        ),
                    ),
                    cst.Integer("5"),
                ),
                'lambda first, second, third = 1.0, fourth = 1.5, *params, bar = "one", baz, biz = "two": 5',
            ),
            # Test star_kwarg on its own
            (
                cst.Lambda(
                    cst.Parameters(star_kwarg=cst.Param(cst.Name("kwparams"))),
                    cst.Integer("5"),
                ),
                "lambda **kwparams: 5",
            ),
            # Test star_arg and star_kwarg together
            (
                cst.Lambda(
                    cst.Parameters(
                        star_arg=cst.Param(cst.Name("params")),
                        star_kwarg=cst.Param(cst.Name("kwparams")),
                    ),
                    cst.Integer("5"),
                ),
                "lambda *params, **kwparams: 5",
            ),
            # Inner whitespace: two spaces after ``lambda`` put the body ``5``
            # at column 12, so the node (inside the parens) ends at column 13.
            (
                cst.Lambda(
                    lpar=(cst.LeftParen(whitespace_after=cst.SimpleWhitespace(" ")),),
                    whitespace_after_lambda=cst.SimpleWhitespace("  "),
                    params=cst.Parameters(),
                    colon=cst.Colon(whitespace_after=cst.SimpleWhitespace(" ")),
                    body=cst.Integer("5"),
                    rpar=(cst.RightParen(whitespace_before=cst.SimpleWhitespace(" ")),),
                ),
                "( lambda  : 5 )",
                CodeRange((1, 2), (1, 13)),
            ),
        )
    )
    def test_valid(
        self, node: cst.CSTNode, code: str, position: Optional[CodeRange] = None
    ) -> None:
        self.validate_node(node, code, expected_position=position)

    @data_provider(
        (
            (
                lambda: cst.Lambda(
                    cst.Parameters(params=(cst.Param(cst.Name("arg")),)),
                    cst.Integer("5"),
                    lpar=(cst.LeftParen(),),
                ),
                "left paren without right paren",
            ),
            (
                lambda: cst.Lambda(
                    cst.Parameters(params=(cst.Param(cst.Name("arg")),)),
                    cst.Integer("5"),
                    rpar=(cst.RightParen(),),
                ),
                "right paren without left paren",
            ),
            (
                lambda: cst.Lambda(
                    cst.Parameters(params=(cst.Param(cst.Name("arg")),)),
                    cst.Integer("5"),
                    whitespace_after_lambda=cst.SimpleWhitespace(""),
                ),
                "at least one space after lambda",
            ),
            (
                lambda: cst.Lambda(
                    cst.Parameters(
                        default_params=(
                            cst.Param(cst.Name("arg"), default=cst.Integer("5")),
                        )
                    ),
                    cst.Integer("5"),
                    whitespace_after_lambda=cst.SimpleWhitespace(""),
                ),
                "at least one space after lambda",
            ),
            (
                lambda: cst.Lambda(
                    cst.Parameters(star_arg=cst.Param(cst.Name("arg"))),
                    cst.Integer("5"),
                    whitespace_after_lambda=cst.SimpleWhitespace(""),
                ),
                "at least one space after lambda",
            ),
            (
                lambda: cst.Lambda(
                    cst.Parameters(kwonly_params=(cst.Param(cst.Name("arg")),)),
                    cst.Integer("5"),
                    whitespace_after_lambda=cst.SimpleWhitespace(""),
                ),
                "at least one space after lambda",
            ),
            (
                lambda: cst.Lambda(
                    cst.Parameters(star_kwarg=cst.Param(cst.Name("arg"))),
                    cst.Integer("5"),
                    whitespace_after_lambda=cst.SimpleWhitespace(""),
                ),
                "at least one space after lambda",
            ),
            (
                lambda: cst.Lambda(
                    cst.Parameters(
                        star_kwarg=cst.Param(cst.Name("bar"), equal=cst.AssignEqual())
                    ),
                    cst.Integer("5"),
                ),
                "Must have a default when specifying an AssignEqual.",
            ),
            (
                lambda: cst.Lambda(
                    cst.Parameters(star_kwarg=cst.Param(cst.Name("bar"), star="***")),
                    cst.Integer("5"),
                ),
                r"Must specify either '', '\*' or '\*\*' for star.",
            ),
            (
                lambda: cst.Lambda(
                    cst.Parameters(
                        params=(
                            cst.Param(cst.Name("bar"), default=cst.SimpleString('"one"')),
                        )
                    ),
                    cst.Integer("5"),
                ),
                "Cannot have defaults for params",
            ),
            (
                lambda: cst.Lambda(
                    cst.Parameters(default_params=(cst.Param(cst.Name("bar")),)),
                    cst.Integer("5"),
                ),
                "Must have defaults for default_params",
            ),
            (
                lambda: cst.Lambda(
                    cst.Parameters(star_arg=cst.ParamStar()), cst.Integer("5")
                ),
                "Must have at least one kwonly param if ParamStar is used.",
            ),
            (
                lambda: cst.Lambda(
                    cst.Parameters(params=(cst.Param(cst.Name("bar"), star="*"),)),
                    cst.Integer("5"),
                ),
                "Expecting a star prefix of ''",
            ),
            (
                lambda: cst.Lambda(
                    cst.Parameters(
                        default_params=(
                            cst.Param(
                                cst.Name("bar"),
                                default=cst.SimpleString('"one"'),
                                star="*",
                            ),
                        )
                    ),
                    cst.Integer("5"),
                ),
                "Expecting a star prefix of ''",
            ),
            (
                lambda: cst.Lambda(
                    cst.Parameters(
                        kwonly_params=(cst.Param(cst.Name("bar"), star="*"),)
                    ),
                    cst.Integer("5"),
                ),
                "Expecting a star prefix of ''",
            ),
            (
                lambda: cst.Lambda(
                    cst.Parameters(star_arg=cst.Param(cst.Name("bar"), star="**")),
                    cst.Integer("5"),
                ),
                r"Expecting a star prefix of '\*'",
            ),
            (
                lambda: cst.Lambda(
                    cst.Parameters(star_kwarg=cst.Param(cst.Name("bar"), star="*")),
                    cst.Integer("5"),
                ),
                r"Expecting a star prefix of '\*\*'",
            ),
            (
                lambda: cst.Lambda(
                    cst.Parameters(
                        params=(
                            cst.Param(
                                cst.Name("arg"),
                                annotation=cst.Annotation(cst.Name("str")),
                            ),
                        )
                    ),
                    cst.Integer("5"),
                    whitespace_after_lambda=cst.SimpleWhitespace(""),
                ),
                "Lambda params cannot have type annotations",
            ),
            (
                lambda: cst.Lambda(
                    cst.Parameters(
                        default_params=(
                            cst.Param(
                                cst.Name("arg"),
                                default=cst.Integer("5"),
                                annotation=cst.Annotation(cst.Name("str")),
                            ),
                        )
                    ),
                    cst.Integer("5"),
                    whitespace_after_lambda=cst.SimpleWhitespace(""),
                ),
                "Lambda params cannot have type annotations",
            ),
            (
                lambda: cst.Lambda(
                    cst.Parameters(
                        star_arg=cst.Param(
                            cst.Name("arg"), annotation=cst.Annotation(cst.Name("str"))
                        )
                    ),
                    cst.Integer("5"),
                    whitespace_after_lambda=cst.SimpleWhitespace(""),
                ),
                "Lambda params cannot have type annotations",
            ),
            (
                lambda: cst.Lambda(
                    cst.Parameters(
                        kwonly_params=(
                            cst.Param(
                                cst.Name("arg"),
                                annotation=cst.Annotation(cst.Name("str")),
                            ),
                        )
                    ),
                    cst.Integer("5"),
                    whitespace_after_lambda=cst.SimpleWhitespace(""),
                ),
                "Lambda params cannot have type annotations",
            ),
            (
                lambda: cst.Lambda(
                    cst.Parameters(
                        star_kwarg=cst.Param(
                            cst.Name("arg"), annotation=cst.Annotation(cst.Name("str"))
                        )
                    ),
                    cst.Integer("5"),
                    whitespace_after_lambda=cst.SimpleWhitespace(""),
                ),
                "Lambda params cannot have type annotations",
            ),
        )
    )
    def test_invalid(
        self, get_node: Callable[[], cst.CSTNode], expected_re: str
    ) -> None:
        self.assert_invalid(get_node, expected_re)
class SimpleCompTest(CSTNodeTest):
    """Tests for comprehension nodes.

    Covers ``GeneratorExp``, ``ListComp`` and ``SetComp`` together with their
    ``CompFor``/``CompIf`` clauses: rendering of valid trees (including custom
    and absent whitespace) and rejection of invalid ones.
    """

    # Valid cases. Each dict is forwarded as keyword arguments to
    # `validate_node` (keys: "node", "code", "parser" and, optionally,
    # "expected_position").
    @data_provider([
        # simple GeneratorExp
        {
            "node": cst.GeneratorExp(
                cst.Name("a"),
                cst.CompFor(target=cst.Name("b"), iter=cst.Name("c"))),
            "code": "(a for b in c)",
            "parser": parse_expression,
            "expected_position": CodeRange((1, 1), (1, 13)),
        },
        # simple ListComp
        {
            "node": cst.ListComp(cst.Name("a"),
                                 cst.CompFor(target=cst.Name("b"),
                                             iter=cst.Name("c"))),
            "code": "[a for b in c]",
            "parser": parse_expression,
            "expected_position": CodeRange((1, 0), (1, 14)),
        },
        # simple SetComp
        {
            "node": cst.SetComp(cst.Name("a"),
                                cst.CompFor(target=cst.Name("b"),
                                            iter=cst.Name("c"))),
            "code": "{a for b in c}",
            "parser": parse_expression,
        },
        # async GeneratorExp
        {
            "node": cst.GeneratorExp(
                cst.Name("a"),
                cst.CompFor(
                    target=cst.Name("b"),
                    iter=cst.Name("c"),
                    asynchronous=cst.Asynchronous(),
                ),
            ),
            "code": "(a async for b in c)",
            "parser": lambda code: parse_expression(
                code, config=PartialParserConfig(python_version="3.7")),
        },
        # Python 3.6 async GeneratorExp
        # NOTE(review): "code" indents the function body by a single space,
        # but the IndentedBlock below sets no explicit indent; this presumably
        # only round-trips if the default indent is one space. Runs of
        # whitespace in this file look collapsed -- confirm against upstream.
        {
            "node": cst.FunctionDef(
                cst.Name("foo"),
                cst.Parameters(),
                cst.IndentedBlock((cst.SimpleStatementLine((cst.Expr(
                    cst.GeneratorExp(
                        cst.Name("a"),
                        cst.CompFor(
                            target=cst.Name("b"),
                            iter=cst.Name("c"),
                            asynchronous=cst.Asynchronous(),
                        ),
                    )), )), )),
                asynchronous=cst.Asynchronous(),
            ),
            "code": "async def foo():\n (a async for b in c)\n",
            "parser": lambda code: parse_statement(
                code, config=PartialParserConfig(python_version="3.6")),
        },
        # a generator doesn't have to own its own parentheses
        {
            "node": cst.Call(
                cst.Name("func"),
                [
                    cst.Arg(
                        cst.GeneratorExp(
                            cst.Name("a"),
                            cst.CompFor(target=cst.Name("b"),
                                        iter=cst.Name("c")),
                            lpar=[],
                            rpar=[],
                        ))
                ],
            ),
            "code": "func(a for b in c)",
            "parser": parse_expression,
        },
        # add a few 'if' clauses
        {
            "node": cst.GeneratorExp(
                cst.Name("a"),
                cst.CompFor(
                    target=cst.Name("b"),
                    iter=cst.Name("c"),
                    ifs=[
                        cst.CompIf(cst.Name("d")),
                        cst.CompIf(cst.Name("e")),
                        cst.CompIf(cst.Name("f")),
                    ],
                ),
            ),
            "code": "(a for b in c if d if e if f)",
            "parser": parse_expression,
            "expected_position": CodeRange((1, 1), (1, 28)),
        },
        # nested/inner for-in clause
        {
            "node": cst.GeneratorExp(
                cst.Name("a"),
                cst.CompFor(
                    target=cst.Name("b"),
                    iter=cst.Name("c"),
                    inner_for_in=cst.CompFor(target=cst.Name("d"),
                                             iter=cst.Name("e")),
                ),
            ),
            "code": "(a for b in c for d in e)",
            "parser": parse_expression,
        },
        # nested/inner for-in clause with an 'if' clause
        {
            "node": cst.GeneratorExp(
                cst.Name("a"),
                cst.CompFor(
                    target=cst.Name("b"),
                    iter=cst.Name("c"),
                    ifs=[cst.CompIf(cst.Name("d"))],
                    inner_for_in=cst.CompFor(target=cst.Name("e"),
                                             iter=cst.Name("f")),
                ),
            ),
            "code": "(a for b in c if d for e in f)",
            "parser": parse_expression,
        },
        # custom whitespace
        # NOTE(review): with the single-space whitespace values below, the
        # GeneratorExp inside "code" spans columns 2-20, not the end column 30
        # given in expected_position. Runs of spaces in these literals appear
        # to have been collapsed -- restore them from the upstream test.
        {
            "node": cst.GeneratorExp(
                cst.Name("a"),
                cst.CompFor(
                    target=cst.Name("b"),
                    iter=cst.Name("c"),
                    ifs=[
                        cst.CompIf(
                            cst.Name("d"),
                            whitespace_before=cst.SimpleWhitespace("\t"),
                            whitespace_before_test=cst.SimpleWhitespace(
                                "\t\t"),
                        )
                    ],
                    whitespace_before=cst.SimpleWhitespace(" "),
                    whitespace_after_for=cst.SimpleWhitespace(" "),
                    whitespace_before_in=cst.SimpleWhitespace(" "),
                    whitespace_after_in=cst.SimpleWhitespace(" "),
                ),
                lpar=[
                    cst.LeftParen(whitespace_after=cst.SimpleWhitespace("\f"))
                ],
                rpar=[
                    cst.RightParen(
                        whitespace_before=cst.SimpleWhitespace("\f\f"))
                ],
            ),
            "code": "(\fa for b in c\tif\t\td\f\f)",
            "parser": parse_expression,
            "expected_position": CodeRange((1, 2), (1, 30)),
        },
        # custom whitespace around ListComp's brackets
        {
            "node": cst.ListComp(
                cst.Name("a"),
                cst.CompFor(target=cst.Name("b"), iter=cst.Name("c")),
                lbracket=cst.LeftSquareBracket(
                    whitespace_after=cst.SimpleWhitespace("\t")),
                rbracket=cst.RightSquareBracket(
                    whitespace_before=cst.SimpleWhitespace("\t\t")),
                lpar=[
                    cst.LeftParen(whitespace_after=cst.SimpleWhitespace("\f"))
                ],
                rpar=[
                    cst.RightParen(
                        whitespace_before=cst.SimpleWhitespace("\f\f"))
                ],
            ),
            "code": "(\f[\ta for b in c\t\t]\f\f)",
            "parser": parse_expression,
            "expected_position": CodeRange((1, 2), (1, 19)),
        },
        # custom whitespace around SetComp's braces
        {
            "node": cst.SetComp(
                cst.Name("a"),
                cst.CompFor(target=cst.Name("b"), iter=cst.Name("c")),
                lbrace=cst.LeftCurlyBrace(
                    whitespace_after=cst.SimpleWhitespace("\t")),
                rbrace=cst.RightCurlyBrace(
                    whitespace_before=cst.SimpleWhitespace("\t\t")),
                lpar=[
                    cst.LeftParen(whitespace_after=cst.SimpleWhitespace("\f"))
                ],
                rpar=[
                    cst.RightParen(
                        whitespace_before=cst.SimpleWhitespace("\f\f"))
                ],
            ),
            "code": "(\f{\ta for b in c\t\t}\f\f)",
            "parser": parse_expression,
        },
        # no whitespace between elements
        # (every operand is parenthesized so the keywords can touch them)
        {
            "node": cst.GeneratorExp(
                cst.Name("a", lpar=[cst.LeftParen()], rpar=[cst.RightParen()]),
                cst.CompFor(
                    target=cst.Name("b",
                                    lpar=[cst.LeftParen()],
                                    rpar=[cst.RightParen()]),
                    iter=cst.Name("c",
                                  lpar=[cst.LeftParen()],
                                  rpar=[cst.RightParen()]),
                    ifs=[
                        cst.CompIf(
                            cst.Name("d",
                                     lpar=[cst.LeftParen()],
                                     rpar=[cst.RightParen()]),
                            whitespace_before=cst.SimpleWhitespace(""),
                            whitespace_before_test=cst.SimpleWhitespace(""),
                        )
                    ],
                    inner_for_in=cst.CompFor(
                        target=cst.Name("e",
                                        lpar=[cst.LeftParen()],
                                        rpar=[cst.RightParen()]),
                        iter=cst.Name("f",
                                      lpar=[cst.LeftParen()],
                                      rpar=[cst.RightParen()]),
                        whitespace_before=cst.SimpleWhitespace(""),
                        whitespace_after_for=cst.SimpleWhitespace(""),
                        whitespace_before_in=cst.SimpleWhitespace(""),
                        whitespace_after_in=cst.SimpleWhitespace(""),
                    ),
                    whitespace_before=cst.SimpleWhitespace(""),
                    whitespace_after_for=cst.SimpleWhitespace(""),
                    whitespace_before_in=cst.SimpleWhitespace(""),
                    whitespace_after_in=cst.SimpleWhitespace(""),
                ),
                lpar=[cst.LeftParen()],
                rpar=[cst.RightParen()],
            ),
            "code": "((a)for(b)in(c)if(d)for(e)in(f))",
            "parser": parse_expression,
            "expected_position": CodeRange((1, 1), (1, 31)),
        },
        # no whitespace before/after GeneratorExp is valid
        {
            "node": cst.Comparison(
                cst.GeneratorExp(
                    cst.Name("a"),
                    cst.CompFor(target=cst.Name("b"), iter=cst.Name("c")),
                ),
                [
                    cst.ComparisonTarget(
                        cst.Is(
                            whitespace_before=cst.SimpleWhitespace(""),
                            whitespace_after=cst.SimpleWhitespace(""),
                        ),
                        cst.GeneratorExp(
                            cst.Name("d"),
                            cst.CompFor(target=cst.Name("e"),
                                        iter=cst.Name("f")),
                        ),
                    )
                ],
            ),
            "code": "(a for b in c)is(d for e in f)",
            "parser": parse_expression,
        },
        # no whitespace before/after ListComp is valid
        {
            "node": cst.Comparison(
                cst.ListComp(
                    cst.Name("a"),
                    cst.CompFor(target=cst.Name("b"), iter=cst.Name("c")),
                ),
                [
                    cst.ComparisonTarget(
                        cst.Is(
                            whitespace_before=cst.SimpleWhitespace(""),
                            whitespace_after=cst.SimpleWhitespace(""),
                        ),
                        cst.ListComp(
                            cst.Name("d"),
                            cst.CompFor(target=cst.Name("e"),
                                        iter=cst.Name("f")),
                        ),
                    )
                ],
            ),
            "code": "[a for b in c]is[d for e in f]",
            "parser": parse_expression,
        },
        # no whitespace before/after SetComp is valid
        {
            "node": cst.Comparison(
                cst.SetComp(
                    cst.Name("a"),
                    cst.CompFor(target=cst.Name("b"), iter=cst.Name("c")),
                ),
                [
                    cst.ComparisonTarget(
                        cst.Is(
                            whitespace_before=cst.SimpleWhitespace(""),
                            whitespace_after=cst.SimpleWhitespace(""),
                        ),
                        cst.SetComp(
                            cst.Name("d"),
                            cst.CompFor(target=cst.Name("e"),
                                        iter=cst.Name("f")),
                        ),
                    )
                ],
            ),
            "code": "{a for b in c}is{d for e in f}",
            "parser": parse_expression,
        },
    ])
    def test_valid(self, **kwargs: Any) -> None:
        """Check each valid case by delegating to `validate_node`."""
        self.validate_node(**kwargs)

    # Invalid cases: (get_node, expected_re) pairs forwarded to
    # `assert_invalid`. Each factory builds a node that violates one
    # parenthesization or whitespace rule.
    @data_provider((
        (
            lambda: cst.GeneratorExp(
                cst.Name("a"),
                cst.CompFor(target=cst.Name("b"), iter=cst.Name("c")),
                lpar=[cst.LeftParen(), cst.LeftParen()],
                rpar=[cst.RightParen()],
            ),
            "unbalanced parens",
        ),
        (
            lambda: cst.ListComp(
                cst.Name("a"),
                cst.CompFor(target=cst.Name("b"), iter=cst.Name("c")),
                lpar=[cst.LeftParen(), cst.LeftParen()],
                rpar=[cst.RightParen()],
            ),
            "unbalanced parens",
        ),
        (
            lambda: cst.SetComp(
                cst.Name("a"),
                cst.CompFor(target=cst.Name("b"), iter=cst.Name("c")),
                lpar=[cst.LeftParen(), cst.LeftParen()],
                rpar=[cst.RightParen()],
            ),
            "unbalanced parens",
        ),
        (
            lambda: cst.GeneratorExp(
                cst.Name("a"),
                cst.CompFor(
                    target=cst.Name("b"),
                    iter=cst.Name("c"),
                    whitespace_before=cst.SimpleWhitespace(""),
                ),
            ),
            "Must have at least one space before 'for' keyword.",
        ),
        (
            lambda: cst.GeneratorExp(
                cst.Name("a"),
                cst.CompFor(
                    target=cst.Name("b"),
                    iter=cst.Name("c"),
                    asynchronous=cst.Asynchronous(),
                    whitespace_before=cst.SimpleWhitespace(""),
                ),
            ),
            "Must have at least one space before 'async' keyword.",
        ),
        (
            lambda: cst.GeneratorExp(
                cst.Name("a"),
                cst.CompFor(
                    target=cst.Name("b"),
                    iter=cst.Name("c"),
                    whitespace_after_for=cst.SimpleWhitespace(""),
                ),
            ),
            "Must have at least one space after 'for' keyword.",
        ),
        (
            lambda: cst.GeneratorExp(
                cst.Name("a"),
                cst.CompFor(
                    target=cst.Name("b"),
                    iter=cst.Name("c"),
                    whitespace_before_in=cst.SimpleWhitespace(""),
                ),
            ),
            "Must have at least one space before 'in' keyword.",
        ),
        (
            lambda: cst.GeneratorExp(
                cst.Name("a"),
                cst.CompFor(
                    target=cst.Name("b"),
                    iter=cst.Name("c"),
                    whitespace_after_in=cst.SimpleWhitespace(""),
                ),
            ),
            "Must have at least one space after 'in' keyword.",
        ),
        (
            lambda: cst.GeneratorExp(
                cst.Name("a"),
                cst.CompFor(
                    target=cst.Name("b"),
                    iter=cst.Name("c"),
                    ifs=[
                        cst.CompIf(
                            cst.Name("d"),
                            whitespace_before=cst.SimpleWhitespace(""),
                        )
                    ],
                ),
            ),
            "Must have at least one space before 'if' keyword.",
        ),
        (
            lambda: cst.GeneratorExp(
                cst.Name("a"),
                cst.CompFor(
                    target=cst.Name("b"),
                    iter=cst.Name("c"),
                    ifs=[
                        cst.CompIf(
                            cst.Name("d"),
                            whitespace_before_test=cst.SimpleWhitespace(""),
                        )
                    ],
                ),
            ),
            "Must have at least one space after 'if' keyword.",
        ),
        (
            lambda: cst.GeneratorExp(
                cst.Name("a"),
                cst.CompFor(
                    target=cst.Name("b"),
                    iter=cst.Name("c"),
                    inner_for_in=cst.CompFor(
                        target=cst.Name("d"),
                        iter=cst.Name("e"),
                        whitespace_before=cst.SimpleWhitespace(""),
                    ),
                ),
            ),
            "Must have at least one space before 'for' keyword.",
        ),
        (
            lambda: cst.GeneratorExp(
                cst.Name("a"),
                cst.CompFor(
                    target=cst.Name("b"),
                    iter=cst.Name("c"),
                    inner_for_in=cst.CompFor(
                        target=cst.Name("d"),
                        iter=cst.Name("e"),
                        asynchronous=cst.Asynchronous(),
                        whitespace_before=cst.SimpleWhitespace(""),
                    ),
                ),
            ),
            "Must have at least one space before 'async' keyword.",
        ),
    ))
    def test_invalid(self, get_node: Callable[[], cst.CSTNode],
                     expected_re: str) -> None:
        """Check each invalid case by delegating to `assert_invalid`."""
        self.assert_invalid(get_node, expected_re)
class WithTest(CSTNodeTest):
    """Tests for the ``With`` statement node and its ``WithItem`` children.

    Covers rendering and parsing of valid trees, validation errors, grammar
    version support and codegen of a parenthesized multi-line ``with``.
    """

    # Presumably raises unittest's cap on the length of diffs shown in
    # assertion failures (assumes CSTNodeTest derives from unittest.TestCase).
    maxDiff: int = 2000

    # Valid cases. Each dict is forwarded as keyword arguments to
    # `validate_node`; a "parser" of None means the case has no parser
    # attached.
    @data_provider((
        # Simple with block
        {
            "node": cst.With(
                (cst.WithItem(cst.Call(cst.Name("context_mgr"))), ),
                cst.SimpleStatementSuite((cst.Pass(), )),
            ),
            "code": "with context_mgr(): pass\n",
            "parser": parse_statement,
            "expected_position": CodeRange((1, 0), (1, 24)),
        },
        # Simple async with block
        {
            "node": cst.With(
                (cst.WithItem(cst.Call(cst.Name("context_mgr"))), ),
                cst.SimpleStatementSuite((cst.Pass(), )),
                asynchronous=cst.Asynchronous(),
            ),
            "code": "async with context_mgr(): pass\n",
            "parser": lambda code: parse_statement(
                code, config=PartialParserConfig(python_version="3.7")),
        },
        # Python 3.6 async with block
        # NOTE(review): "code" indents the function body by a single space,
        # but the IndentedBlock below sets no explicit indent; this presumably
        # only round-trips if the default indent is one space. Runs of
        # whitespace in this file look collapsed -- confirm against upstream.
        {
            "node": cst.FunctionDef(
                cst.Name("foo"),
                cst.Parameters(),
                cst.IndentedBlock((cst.With(
                    (cst.WithItem(cst.Call(cst.Name("context_mgr"))), ),
                    cst.SimpleStatementSuite((cst.Pass(), )),
                    asynchronous=cst.Asynchronous(),
                ), )),
                asynchronous=cst.Asynchronous(),
            ),
            "code": "async def foo():\n async with context_mgr(): pass\n",
            "parser": lambda code: parse_statement(
                code, config=PartialParserConfig(python_version="3.6")),
        },
        # Multiple context managers
        {
            "node": cst.With(
                (
                    cst.WithItem(cst.Call(cst.Name("foo"))),
                    cst.WithItem(cst.Call(cst.Name("bar"))),
                ),
                cst.SimpleStatementSuite((cst.Pass(), )),
            ),
            "code": "with foo(), bar(): pass\n",
            "parser": None,
        },
        # Same statement with an explicit comma on the first item.
        {
            "node": cst.With(
                (
                    cst.WithItem(
                        cst.Call(cst.Name("foo")),
                        comma=cst.Comma(
                            whitespace_after=cst.SimpleWhitespace(" ")),
                    ),
                    cst.WithItem(cst.Call(cst.Name("bar"))),
                ),
                cst.SimpleStatementSuite((cst.Pass(), )),
            ),
            "code": "with foo(), bar(): pass\n",
            "parser": parse_statement,
        },
        # With block containing variable for context manager.
        {
            "node": cst.With(
                (cst.WithItem(
                    cst.Call(cst.Name("context_mgr")),
                    cst.AsName(cst.Name("ctx")),
                ), ),
                cst.SimpleStatementSuite((cst.Pass(), )),
            ),
            "code": "with context_mgr() as ctx: pass\n",
            "parser": parse_statement,
        },
        # indentation
        # NOTE(review): with a one-character indent the With node spans
        # columns 1-25 of "code", not the (1, 4)-(1, 28) expected below; the
        # indent string (and "code") look collapsed from four spaces.
        {
            "node": DummyIndentedBlock(
                " ",
                cst.With(
                    (cst.WithItem(cst.Call(cst.Name("context_mgr"))), ),
                    cst.SimpleStatementSuite((cst.Pass(), )),
                ),
            ),
            "code": " with context_mgr(): pass\n",
            "parser": None,
            "expected_position": CodeRange((1, 4), (1, 28)),
        },
        # with an indented body
        # NOTE(review): in "code" the body line " pass" ends at column 5,
        # not the column 12 expected below; whitespace looks collapsed here
        # as well -- confirm against upstream.
        {
            "node": DummyIndentedBlock(
                " ",
                cst.With(
                    (cst.WithItem(cst.Call(cst.Name("context_mgr"))), ),
                    cst.IndentedBlock((cst.SimpleStatementLine(
                        (cst.Pass(), )), )),
                ),
            ),
            "code": " with context_mgr():\n pass\n",
            "parser": None,
            "expected_position": CodeRange((1, 4), (2, 12)),
        },
        # leading_lines
        {
            "node": cst.With(
                (cst.WithItem(cst.Call(cst.Name("context_mgr"))), ),
                cst.SimpleStatementSuite((cst.Pass(), )),
                leading_lines=(cst.EmptyLine(
                    comment=cst.Comment("# leading comment")), ),
            ),
            "code": "# leading comment\nwith context_mgr(): pass\n",
            "parser": parse_statement,
            "expected_position": CodeRange((2, 0), (2, 24)),
        },
        # Whitespace
        # NOTE(review): "code" below is 32 columns wide before the newline,
        # but expected_position ends at column 36; the SimpleWhitespace
        # values and "code" appear to have lost spaces -- confirm upstream.
        {
            "node": cst.With(
                (cst.WithItem(
                    cst.Call(cst.Name("context_mgr")),
                    cst.AsName(
                        cst.Name("ctx"),
                        whitespace_before_as=cst.SimpleWhitespace(" "),
                        whitespace_after_as=cst.SimpleWhitespace(" "),
                    ),
                ), ),
                cst.SimpleStatementSuite((cst.Pass(), )),
                whitespace_after_with=cst.SimpleWhitespace(" "),
                whitespace_before_colon=cst.SimpleWhitespace(" "),
            ),
            "code": "with context_mgr() as ctx : pass\n",
            "parser": parse_statement,
            "expected_position": CodeRange((1, 0), (1, 36)),
        },
        # Weird spacing rules, that parse differently depending on whether
        # we are using a grammar that included parenthesized with statements.
        # Under the native parser the parentheses belong to the With itself;
        # otherwise they are modeled as parentheses around the Call.
        {
            "node": cst.With(
                (cst.WithItem(
                    cst.Call(
                        cst.Name("context_mgr"),
                        lpar=() if is_native() else (cst.LeftParen(), ),
                        rpar=() if is_native() else (cst.RightParen(), ),
                    )), ),
                cst.SimpleStatementSuite((cst.Pass(), )),
                lpar=(cst.LeftParen() if is_native() else MaybeSentinel.DEFAULT),
                rpar=(cst.RightParen() if is_native() else MaybeSentinel.DEFAULT),
                whitespace_after_with=cst.SimpleWhitespace(""),
            ),
            "code": "with(context_mgr()): pass\n",
            "parser": parse_statement,
            "expected_position": CodeRange((1, 0), (1, 25)),
        },
        # Multi-line parenthesized with.
        # NOTE(review): the second line of "code" is 15 columns wide, but
        # expected_position ends at (2, 21); the `last_line` whitespace and
        # the second string literal look collapsed -- confirm upstream.
        {
            "node": cst.With(
                (
                    cst.WithItem(
                        cst.Call(cst.Name("foo")),
                        comma=cst.Comma(
                            whitespace_after=cst.ParenthesizedWhitespace(
                                first_line=cst.TrailingWhitespace(
                                    whitespace=cst.SimpleWhitespace(value="", ),
                                    comment=None,
                                    newline=cst.Newline(value=None, ),
                                ),
                                empty_lines=[],
                                indent=True,
                                last_line=cst.SimpleWhitespace(value=" ", ),
                            )),
                    ),
                    cst.WithItem(cst.Call(cst.Name("bar")), comma=cst.Comma()),
                ),
                cst.SimpleStatementSuite((cst.Pass(), )),
                lpar=cst.LeftParen(whitespace_after=cst.SimpleWhitespace(" ")),
                rpar=cst.RightParen(
                    whitespace_before=cst.SimpleWhitespace(" ")),
            ),
            "code": ("with ( foo(),\n"
                     " bar(), ): pass\n"),  # noqa
            "parser": parse_statement if is_native() else None,
            "expected_position": CodeRange((1, 0), (2, 21)),
        },
    ))
    def test_valid(self, **kwargs: Any) -> None:
        """Check each valid case by delegating to `validate_node`."""
        self.validate_node(**kwargs)

    # Invalid cases, forwarded as keyword arguments (get_node, expected_re)
    # to `assert_invalid`.
    @data_provider((
        {
            "get_node":
            lambda: cst.With((), cst.IndentedBlock((cst.SimpleStatementLine(
                (cst.Pass(), )), ))),
            "expected_re": "A With statement must have at least one WithItem",
        },
        {
            "get_node":
            lambda: cst.With(
                (cst.WithItem(
                    cst.Call(cst.Name("foo")),
                    comma=cst.Comma(whitespace_after=cst.SimpleWhitespace(" ")),
                ), ),
                cst.IndentedBlock((cst.SimpleStatementLine((cst.Pass(), )), )),
            ),
            "expected_re": "The last WithItem in an unparenthesized With cannot "
            + "have a trailing comma.",
        },
        {
            "get_node":
            lambda: cst.With(
                (cst.WithItem(cst.Call(cst.Name("context_mgr"))), ),
                cst.SimpleStatementSuite((cst.Pass(), )),
                whitespace_after_with=cst.SimpleWhitespace(""),
            ),
            "expected_re": "Must have at least one space after with keyword",
        },
        # Only one of lpar/rpar is concrete; the other stays MaybeSentinel.
        {
            "get_node":
            lambda: cst.With(
                (cst.WithItem(cst.Call(cst.Name("context_mgr"))), ),
                cst.SimpleStatementSuite((cst.Pass(), )),
                whitespace_after_with=cst.SimpleWhitespace(""),
                lpar=cst.LeftParen(),
            ),
            "expected_re": "Do not mix concrete LeftParen/RightParen with "
            + "MaybeSentinel",
        },
        {
            "get_node":
            lambda: cst.With(
                (cst.WithItem(cst.Call(cst.Name("context_mgr"))), ),
                cst.SimpleStatementSuite((cst.Pass(), )),
                whitespace_after_with=cst.SimpleWhitespace(""),
                rpar=cst.RightParen(),
            ),
            "expected_re": "Do not mix concrete LeftParen/RightParen with "
            + "MaybeSentinel",
        },
    ))
    def test_invalid(self, **kwargs: Any) -> None:
        """Check each invalid case by delegating to `assert_invalid`."""
        self.assert_invalid(**kwargs)

    # Multiple context managers in a single `with` are expected to parse
    # under a 3.1 grammar and to fail under 3.0.
    @data_provider((
        {
            "code": "with a, b: pass",
            "parser": parse_statement_as(python_version="3.1"),
            "expect_success": True,
        },
        {
            "code": "with a, b: pass",
            "parser": parse_statement_as(python_version="3.0"),
            "expect_success": False,
        },
    ))
    def test_versions(self, **kwargs: Any) -> None:
        """Check grammar-version support via `assert_parses`.

        Cases expected to fail are skipped under the native parser, whose
        parse errors are disabled.
        """
        if is_native() and not kwargs.get("expect_success", True):
            self.skipTest("parse errors are disabled for native parser")
        self.assert_parses(**kwargs)

    def test_adding_parens(self) -> None:
        """Render a parenthesized With and check the exact generated code.

        The first item's comma uses a default-constructed
        ParenthesizedWhitespace, so the second item starts a new line.
        """
        node = cst.With(
            (
                cst.WithItem(
                    cst.Call(cst.Name("foo")),
                    comma=cst.Comma(
                        whitespace_after=cst.ParenthesizedWhitespace(),
                    ),
                ),
                cst.WithItem(cst.Call(cst.Name("bar")), comma=cst.Comma()),
            ),
            cst.SimpleStatementSuite((cst.Pass(), )),
            lpar=cst.LeftParen(whitespace_after=cst.SimpleWhitespace(" ")),
            rpar=cst.RightParen(whitespace_before=cst.SimpleWhitespace(" ")),
        )
        module = cst.Module([])
        self.assertEqual(
            module.code_for_node(node),
            ("with ( foo(),\n"
             "bar(), ): pass\n")  # noqa
        )
'async': lambda _: {cst.FunctionDef} } nesting_op_children_getter = { '.': lambda node: node.body.body if isinstance(node, cst.ClassDef) else (), '(': lambda node: node.params.params if isinstance(node, cst.FunctionDef) else (), ',': getNextParam, None: lambda node: node.children } per_path_default_kwargs = RelativePathDict({ (cst.ClassDef, cst.FunctionDef): lambda path: { # TODO: if node is not decorated as staticmethod or class method! 'params': cst.Parameters(params=[cst.Param(cst.Name('self'))]) }, }) # NOTE: it is bad design to allow users to craft coincidentally unambiguous # scope expressions, instead, it would be better to have a plenty of unambigous # property keys (i.e. func, class, async, static, etc) for the user to craft with def possibleElemTypes(scope_expr: ScopeExpr) -> Set[cst.CSTNode]: return reduce(and_, map(lambda key, val: possible_node_classes_per_prop[key](val), scope_expr.properties.keys(), scope_expr.properties.values()), node_types)
def compile_graph(graph: GraphicalModel, namespace: dict, fn_name):
    """Compile MCX's graph into a python (executable) function.

    Every named ``Constant`` and ``Op`` node of ``graph`` becomes an
    assignment statement, emitted in topological order. Every node flagged
    ``is_returned`` also produces a ``return`` statement, placed after all
    the assignments. The resulting function definition is rendered to source
    code and executed in ``namespace``.

    Args:
        graph: The graphical model to compile.
        namespace: Globals used to execute the generated source. It must
            provide every name the generated code refers to, and it receives
            the new function under ``fn_name``.
        fn_name: Name of the generated function.

    Returns:
        A ``(fn, code)`` tuple: the compiled function and its source code.
    """
    # The generated function's parameters are laid out in this order (see
    # the concatenation into `args` below):
    #  1. (samplers only) rng_key;
    #  2. the model's positional arguments: placeholders that are neither
    #     random variables nor `rng_key` and have no default;
    #  3. (logpdf only) random variables, collected by walking
    #     `graph.placeholders` in reverse;
    #  4. the model's keyword arguments: placeholders flagged `has_default`.
    #     They come last, presumably because Python rejects a parameter
    #     without a default after one that has a default.
    # NOTE(review): the reversed walk in (3) presumably restores the order
    # in which random variables appear in the model -- confirm how
    # `graph.placeholders` is ordered.
    maybe_rng_key = [
        compile_placeholder(node, graph)
        for node in graph.placeholders
        if node.name == "rng_key"
    ]
    maybe_random_variables = [
        compile_placeholder(node, graph)
        for node in reversed(list(graph.placeholders))
        if node.is_random_variable
    ]
    model_args = [
        compile_placeholder(node, graph)
        for node in graph.placeholders
        if not node.is_random_variable
        and node.name != "rng_key"
        and not node.has_default
    ]
    model_kwargs = [
        compile_placeholder(node, graph)
        for node in graph.placeholders
        if not node.is_random_variable
        and node.name != "rng_key"
        and node.has_default
    ]
    args = maybe_rng_key + model_args + maybe_random_variables + model_kwargs

    def assign(name, value) -> cst.SimpleStatementLine:
        """Build the one-line statement ``name = value``."""
        return cst.SimpleStatementLine(body=[
            cst.Assign(
                targets=[cst.AssignTarget(target=cst.Name(value=name))],
                value=value,
            )
        ])

    # Every statement in the function corresponds to either a constant
    # definition or a variable assignment. Walking the graph in topological
    # order emits each node after the nodes it depends on.
    stmts = []
    returns = []
    for node in nx.topological_sort(graph):
        # Unnamed nodes cannot be the target of an assignment or a return.
        if node.name is None:
            continue

        if isinstance(node, Constant):
            stmts.append(assign(node.name, node.cst_generator()))
        if isinstance(node, Op):
            stmts.append(assign(node.name, compile_op(node, graph)))

        if node.is_returned:
            # NOTE(review): returns are emitted after all assignments, so if
            # several nodes are flagged `is_returned` only the first `return`
            # is reachable -- confirm at most one node is returned.
            returns.append(
                cst.SimpleStatementLine(
                    body=[cst.Return(value=cst.Name(value=node.name))]))

    # Assemble the function's CST using the previously translated nodes.
    ast_fn = cst.Module(body=[
        cst.FunctionDef(
            name=cst.Name(value=fn_name),
            params=cst.Parameters(params=args),
            body=cst.IndentedBlock(body=stmts + returns),
        )
    ])
    code = ast_fn.code

    # Executing the source binds the new function to `fn_name` in `namespace`.
    exec(code, namespace)
    fn = namespace[fn_name]

    return fn, code