def test_suite(self) -> None:
    """Substitute each kind of statement suite into a ``{suite}`` slot.

    Both a one-line ``SimpleStatementSuite`` and an ``IndentedBlock`` are
    rendered into two templates -- one with the placeholder on the ``if``
    line and one with it on the following line -- and the generated module
    code is compared against the expected text.
    """
    # (template, suite to substitute, expected module code)
    cases = (
        (
            "if x is True: {suite}\n",
            cst.SimpleStatementSuite(body=(cst.Pass(),)),
            "if x is True: pass\n",
        ),
        (
            "if x is True: {suite}\n",
            cst.IndentedBlock(body=(cst.SimpleStatementLine((cst.Pass(),)),)),
            "if x is True:\n pass\n",
        ),
        (
            "if x is True:\n {suite}\n",
            cst.SimpleStatementSuite(body=(cst.Pass(),)),
            "if x is True: pass\n",
        ),
        (
            "if x is True:\n {suite}\n",
            cst.IndentedBlock(body=(cst.SimpleStatementLine((cst.Pass(),)),)),
            "if x is True:\n pass\n",
        ),
    )
    for template, suite, expected_code in cases:
        module = parse_template_module(template, suite=suite)
        self.assertEqual(module.code, expected_code)
def leave_SimpleStatementSuite(
    self,
    original_node: cst.SimpleStatementSuite,
    updated_node: cst.SimpleStatementSuite,
) -> cst.IndentedBlock:
    """Replace a same-line suite with an equivalent ``IndentedBlock``.

    Every small statement of the suite is converted with ``to_stmt``. The
    freshly built block is then visited by this same transformer before it
    is returned.
    """
    converted = tuple(map(to_stmt, updated_node.body))
    block = cst.IndentedBlock(body=converted)
    return block.visit(self)
def _(
    self,
    original_node: cst.If,
    updated_node: cst.If,
) -> cst.If:
    """Rewrite an ``elif`` branch as an ``else:`` block containing an ``if``.

    An ``If`` stored in ``orelse`` renders as ``elif``; here it is moved into
    the body of a new ``Else`` clause instead.

    Nodes whose ``orelse`` is not an ``If`` are returned unchanged. That covers
    no else/elif at all and a plain ``else:``. Wrapping ``None`` or an ``Else``
    node in an ``IndentedBlock`` would build an invalid tree.

    Returns:
        The rewritten ``If`` node, or ``updated_node`` when there is no
        ``elif`` to rewrite.
    """
    nested = updated_node.orelse
    if not isinstance(nested, cst.If):
        return updated_node
    orelse = cst.Else(body=cst.IndentedBlock(body=[nested]))
    return updated_node.with_changes(orelse=orelse)
def test_module_config_for_parsing(self) -> None:
    """A module's ``config_for_parsing`` can be forwarded to ``parse_statement``.

    The module and the statement are both parsed from source that uses
    carriage-return line endings. The statement is parsed with the module's
    config and compared against an explicitly built ``If`` node.
    """
    module = parse_module("pass\r")
    parsed = parse_statement("if True:\r pass", config=module.config_for_parsing)

    # This would be "\r" if we didn't pass the module config forward.
    header = cst.TrailingWhitespace(newline=cst.Newline(value=None))
    expected = cst.If(
        test=cst.Name(value="True"),
        body=cst.IndentedBlock(
            body=[cst.SimpleStatementLine(body=[cst.Pass()])],
            header=header,
        ),
    )
    self.assertEqual(parsed, expected)
def function_def(
    self,
    name: str,
    type: typing.Literal["function", "classmethod", "method"],
    indent=0,
    overload=False,
) -> cst.FunctionDef:
    """Build a ``FunctionDef`` node called *name*.

    Decorators are ``@overload`` (when *overload* is true) followed by
    ``@classmethod`` (when *type* is ``"classmethod"``). Parameters come from
    ``self.parameters(type)``. Each item yielded by ``self.body(indent)``
    becomes its own ``SimpleStatementLine``. The return annotation is
    ``self.return_type_annotation``.
    """
    decorator_names = []
    if overload:
        decorator_names.append("overload")
    if type == "classmethod":
        decorator_names.append("classmethod")
    decorators: typing.List[cst.Decorator] = [
        cst.Decorator(cst.Name(decorator_name)) for decorator_name in decorator_names
    ]
    # Keyword order mirrors the original positional order, so the helper
    # calls are evaluated in the same sequence.
    return cst.FunctionDef(
        name=cst.Name(name),
        params=self.parameters(type),
        body=cst.IndentedBlock(
            [cst.SimpleStatementLine([small]) for small in self.body(indent)]
        ),
        decorators=decorators,
        returns=self.return_type_annotation,
    )
def _get_assert_replacement(self, node: cst.Assert): message = node.msg or str(cst.Module(body=[node]).code) return cst.If( test=cst.UnaryOperation( operator=cst.Not(), expression=node.test, # Todo: parenthesize? ), body=cst.IndentedBlock(body=[ cst.SimpleStatementLine(body=[ cst.Raise(exc=cst.Call( func=cst.Name(value="AssertionError", ), args=[ cst.Arg(value=cst.SimpleString(value=repr(message), ), ), ], ), ), ]), ], ), )
class IfTest(CSTNodeTest):
    """Round-trip cases for ``cst.If`` including ``elif``/``else`` chains.

    Each case is forwarded to ``validate_node`` (inherited from
    ``CSTNodeTest``). A case supplies the node, its expected source text, an
    optional parser and an optional expected ``CodeRange``.
    """

    @data_provider(
        (
            # Simple if without elif or else
            {
                "node": cst.If(
                    cst.Name("conditional"), cst.SimpleStatementSuite((cst.Pass(),))
                ),
                "code": "if conditional: pass\n",
                "parser": parse_statement,
                "expected_position": CodeRange((1, 0), (1, 20)),
            },
            # else clause
            {
                "node": cst.If(
                    cst.Name("conditional"),
                    cst.SimpleStatementSuite((cst.Pass(),)),
                    orelse=cst.Else(cst.SimpleStatementSuite((cst.Pass(),))),
                ),
                "code": "if conditional: pass\nelse: pass\n",
                "parser": parse_statement,
                "expected_position": CodeRange((1, 0), (2, 10)),
            },
            # elif clause (an If stored in orelse renders as "elif")
            {
                "node": cst.If(
                    cst.Name("conditional"),
                    cst.SimpleStatementSuite((cst.Pass(),)),
                    orelse=cst.If(
                        cst.Name("other_conditional"),
                        cst.SimpleStatementSuite((cst.Pass(),)),
                        orelse=cst.Else(cst.SimpleStatementSuite((cst.Pass(),))),
                    ),
                ),
                "code": "if conditional: pass\nelif other_conditional: pass\nelse: pass\n",
                "parser": parse_statement,
                "expected_position": CodeRange((1, 0), (3, 10)),
            },
            # indentation
            {
                "node": DummyIndentedBlock(
                    " ",
                    cst.If(
                        cst.Name("conditional"),
                        cst.SimpleStatementSuite((cst.Pass(),)),
                        orelse=cst.Else(cst.SimpleStatementSuite((cst.Pass(),))),
                    ),
                ),
                "code": " if conditional: pass\n else: pass\n",
                "parser": None,
                "expected_position": CodeRange((1, 4), (2, 14)),
            },
            # with an indented body
            {
                "node": DummyIndentedBlock(
                    " ",
                    cst.If(
                        cst.Name("conditional"),
                        cst.IndentedBlock((cst.SimpleStatementLine((cst.Pass(),)),)),
                    ),
                ),
                "code": " if conditional:\n pass\n",
                "parser": None,
                "expected_position": CodeRange((1, 4), (2, 12)),
            },
            # leading_lines
            {
                "node": cst.If(
                    cst.Name("conditional"),
                    cst.SimpleStatementSuite((cst.Pass(),)),
                    leading_lines=(cst.EmptyLine(comment=cst.Comment("# leading comment")),),
                ),
                "code": "# leading comment\nif conditional: pass\n",
                "parser": parse_statement,
                # The position starts on line 2: leading lines are not included.
                "expected_position": CodeRange((2, 0), (2, 20)),
            },
            # whitespace before/after test and else
            {
                "node": cst.If(
                    cst.Name("conditional"),
                    cst.SimpleStatementSuite((cst.Pass(),)),
                    whitespace_before_test=cst.SimpleWhitespace(" "),
                    whitespace_after_test=cst.SimpleWhitespace(" "),
                    orelse=cst.Else(
                        cst.SimpleStatementSuite((cst.Pass(),)),
                        whitespace_before_colon=cst.SimpleWhitespace(" "),
                    ),
                ),
                "code": "if conditional : pass\nelse : pass\n",
                "parser": parse_statement,
                "expected_position": CodeRange((1, 0), (2, 11)),
            },
            # empty lines between if/elif/else clauses, not captured by the suite.
            {
                "node": cst.If(
                    cst.Name("test_a"),
                    cst.SimpleStatementSuite((cst.Pass(),)),
                    orelse=cst.If(
                        cst.Name("test_b"),
                        cst.SimpleStatementSuite((cst.Pass(),)),
                        leading_lines=(cst.EmptyLine(),),
                        orelse=cst.Else(
                            cst.SimpleStatementSuite((cst.Pass(),)),
                            leading_lines=(cst.EmptyLine(),),
                        ),
                    ),
                ),
                "code": "if test_a: pass\n\nelif test_b: pass\n\nelse: pass\n",
                "parser": parse_statement,
                "expected_position": CodeRange((1, 0), (5, 10)),
            },
        )
    )
    def test_valid(self, **kwargs: Any) -> None:
        # Each data_provider dict is forwarded verbatim as keyword arguments.
        self.validate_node(**kwargs)
def leave_Module(self, original_node: libcst.Module, updated_node: libcst.Module) -> libcst.Module:
    """Insert every pending import into the module and return the result.

    Output order of the new module body:

    * ``statements_before_imports`` followed by any ``from __future__ import``;
    * ``statements_until_add_imports``;
    * ``import m`` (sorted by module) and ``import m as a``;
    * ``from typing import TYPE_CHECKING`` plus an ``if TYPE_CHECKING:`` block
      holding every other ``from m import ...``. Both are emitted only when
      that block would be non-empty;
    * ``from m import ...`` for the import-cycle-safe typing modules;
    * ``statements_after_imports``.

    Returns ``updated_node`` unchanged when there is nothing to add.
    """
    # Don't try to modify if we have nothing to do
    if not (
        self.module_imports
        or self.module_mapping
        or self.module_aliases
        or self.alias_mapping
    ):
        return updated_node

    config = updated_node.config_for_parsing

    def from_import(module, aliases):
        # Parse ``from <module> import a, b as c`` for (name, alias-or-None) pairs.
        names = ", ".join(
            obj if alias is None else f"{obj} as {alias}" for obj, alias in aliases
        )
        return parse_statement(f"from {module} import {names}", config=config)

    # First, find the insertion point for imports
    (
        statements_before_imports,
        statements_until_add_imports,
        statements_after_imports,
    ) = self._split_module(original_node, updated_node)

    # Make sure there's at least one empty line before the first non-import
    statements_after_imports = self._insert_empty_line(statements_after_imports)

    # Mapping of modules we're adding to the objects (with and without an
    # alias) they should import.
    grouped = defaultdict(list)
    for module, aliases in self.alias_mapping.items():
        grouped[module].extend(aliases)
    for module, imports in self.module_mapping.items():
        grouped[module].extend((obj, None) for obj in imports)
    # A missing alias sorts as "", so a name imported both with and without
    # an alias does not compare None with str (which raises TypeError).
    module_and_alias_mapping = {
        module: sorted(aliases, key=lambda pair: (pair[0], pair[1] or ""))
        for module, aliases in grouped.items()
    }

    # Modules treated as import-cycle safe are imported unconditionally;
    # every other non-__future__ module goes under ``if TYPE_CHECKING:``.
    import_cycle_safe_module_names = frozenset(
        {"mypy_extensions", "typing", "typing_extensions"}
    )

    type_checking_imports = [
        from_import(module, aliases)
        for module, aliases in module_and_alias_mapping.items()
        if module != "__future__" and module not in import_cycle_safe_module_names
    ]
    if type_checking_imports:
        type_checking_statements = (
            parse_statement("from typing import TYPE_CHECKING", config=config),
            libcst.If(
                test=libcst.Name("TYPE_CHECKING"),
                body=libcst.IndentedBlock(body=type_checking_imports),
            ),
        )
    else:
        # Nothing to guard: add nothing. (``EmptyLine`` is not a statement and
        # is not valid inside ``Module.body``.)
        type_checking_statements = ()

    future_imports = [
        from_import(module, aliases)
        for module, aliases in module_and_alias_mapping.items()
        if module == "__future__"
    ]
    cycle_safe_imports = [
        from_import(module, aliases)
        for module, aliases in module_and_alias_mapping.items()
        if module in import_cycle_safe_module_names
    ]

    # Now, add all of the imports we need!
    return updated_node.with_changes(
        body=(
            *statements_before_imports,
            *future_imports,
            *statements_until_add_imports,
            *[
                parse_statement(f"import {module}", config=config)
                for module in sorted(self.module_imports)
            ],
            *[
                parse_statement(f"import {module} as {asname}", config=config)
                for (module, asname) in self.module_aliases.items()
            ],
            # TODO: forward references could also be resolved with
            # `from __future__ import annotations`; it could be added here or
            # by another tool.
            *type_checking_statements,
            *cycle_safe_imports,
            *statements_after_imports,
        )
    )
class AwaitTest(CSTNodeTest):
    """Cases for ``cst.Await``.

    Covers valid nodes parsed as expressions under Python 3.7, valid nodes
    inside an ``async def`` parsed under Python 3.6, and nodes whose
    construction must be rejected.
    """

    @data_provider(
        (
            # Some simple calls
            {
                "node": cst.Await(cst.Name("test")),
                "code": "await test",
                "parser": lambda code: parse_expression(
                    code, config=PartialParserConfig(python_version="3.7")
                ),
                "expected_position": None,
            },
            {
                "node": cst.Await(cst.Call(cst.Name("test"))),
                "code": "await test()",
                "parser": lambda code: parse_expression(
                    code, config=PartialParserConfig(python_version="3.7")
                ),
                "expected_position": None,
            },
            # Whitespace
            {
                "node": cst.Await(
                    cst.Name("test"),
                    whitespace_after_await=cst.SimpleWhitespace(" "),
                    lpar=(cst.LeftParen(whitespace_after=cst.SimpleWhitespace(" ")),),
                    rpar=(cst.RightParen(whitespace_before=cst.SimpleWhitespace(" ")),),
                ),
                "code": "( await test )",
                "parser": lambda code: parse_expression(
                    code, config=PartialParserConfig(python_version="3.7")
                ),
                "expected_position": CodeRange((1, 2), (1, 13)),
            },
        )
    )
    def test_valid_py37(self, **kwargs: Any) -> None:
        # We don't have sentinel nodes for atoms, so we know that 100% of atoms
        # can be parsed identically to their creation.
        self.validate_node(**kwargs)

    # Under Python 3.6, each ``await`` case is wrapped in an ``async def``
    # and parsed as a statement.
    @data_provider(
        (
            # Some simple calls
            {
                "node": cst.FunctionDef(
                    cst.Name("foo"),
                    cst.Parameters(),
                    cst.IndentedBlock(
                        (cst.SimpleStatementLine((cst.Expr(cst.Await(cst.Name("test"))),)),)
                    ),
                    asynchronous=cst.Asynchronous(),
                ),
                "code": "async def foo():\n await test\n",
                "parser": lambda code: parse_statement(
                    code, config=PartialParserConfig(python_version="3.6")
                ),
                "expected_position": None,
            },
            {
                "node": cst.FunctionDef(
                    cst.Name("foo"),
                    cst.Parameters(),
                    cst.IndentedBlock(
                        (
                            cst.SimpleStatementLine(
                                (cst.Expr(cst.Await(cst.Call(cst.Name("test")))),)
                            ),
                        )
                    ),
                    asynchronous=cst.Asynchronous(),
                ),
                "code": "async def foo():\n await test()\n",
                "parser": lambda code: parse_statement(
                    code, config=PartialParserConfig(python_version="3.6")
                ),
                "expected_position": None,
            },
            # Whitespace
            {
                "node": cst.FunctionDef(
                    cst.Name("foo"),
                    cst.Parameters(),
                    cst.IndentedBlock(
                        (
                            cst.SimpleStatementLine(
                                (
                                    cst.Expr(
                                        cst.Await(
                                            cst.Name("test"),
                                            whitespace_after_await=cst.SimpleWhitespace(" "),
                                            lpar=(
                                                cst.LeftParen(
                                                    whitespace_after=cst.SimpleWhitespace(" ")
                                                ),
                                            ),
                                            rpar=(
                                                cst.RightParen(
                                                    whitespace_before=cst.SimpleWhitespace(" ")
                                                ),
                                            ),
                                        )
                                    ),
                                )
                            ),
                        )
                    ),
                    asynchronous=cst.Asynchronous(),
                ),
                "code": "async def foo():\n ( await test )\n",
                "parser": lambda code: parse_statement(
                    code, config=PartialParserConfig(python_version="3.6")
                ),
                "expected_position": None,
            },
        )
    )
    def test_valid_py36(self, **kwargs: Any) -> None:
        # We don't have sentinel nodes for atoms, so we know that 100% of atoms
        # can be parsed identically to their creation.
        self.validate_node(**kwargs)

    # Each get_node builds a node whose construction should fail with an
    # error matching expected_re.
    @data_provider(
        (
            # Expression wrapping parenthesis rules
            {
                "get_node": (lambda: cst.Await(cst.Name("foo"), lpar=(cst.LeftParen(),))),
                "expected_re": "left paren without right paren",
            },
            {
                "get_node": (lambda: cst.Await(cst.Name("foo"), rpar=(cst.RightParen(),))),
                "expected_re": "right paren without left paren",
            },
            {
                "get_node": (
                    lambda: cst.Await(
                        cst.Name("foo"), whitespace_after_await=cst.SimpleWhitespace("")
                    )
                ),
                "expected_re": "at least one space after await",
            },
        )
    )
    def test_invalid(self, **kwargs: Any) -> None:
        self.assert_invalid(**kwargs)
class ForTest(CSTNodeTest):
    """Cases for ``cst.For``.

    Covers sync and async loops, ``else`` clauses, indentation, leading
    lines, whitespace, and invalid whitespace around ``for``/``in``.
    """

    @data_provider(
        (
            # Simple for block
            {
                "node": cst.For(
                    cst.Name("target"),
                    cst.Call(cst.Name("iter")),
                    cst.SimpleStatementSuite((cst.Pass(),)),
                ),
                "code": "for target in iter(): pass\n",
                "parser": parse_statement,
            },
            # Simple async for block
            {
                "node": cst.For(
                    cst.Name("target"),
                    cst.Call(cst.Name("iter")),
                    cst.SimpleStatementSuite((cst.Pass(),)),
                    asynchronous=cst.Asynchronous(),
                ),
                "code": "async for target in iter(): pass\n",
                "parser": lambda code: parse_statement(
                    code, config=PartialParserConfig(python_version="3.7")
                ),
            },
            # Python 3.6 async for block (must sit inside an async def)
            {
                "node": cst.FunctionDef(
                    cst.Name("foo"),
                    cst.Parameters(),
                    cst.IndentedBlock(
                        (
                            cst.For(
                                cst.Name("target"),
                                cst.Call(cst.Name("iter")),
                                cst.SimpleStatementSuite((cst.Pass(),)),
                                asynchronous=cst.Asynchronous(),
                            ),
                        )
                    ),
                    asynchronous=cst.Asynchronous(),
                ),
                "code": "async def foo():\n async for target in iter(): pass\n",
                "parser": lambda code: parse_statement(
                    code, config=PartialParserConfig(python_version="3.6")
                ),
            },
            # For block with else
            {
                "node": cst.For(
                    cst.Name("target"),
                    cst.Call(cst.Name("iter")),
                    cst.SimpleStatementSuite((cst.Pass(),)),
                    cst.Else(cst.SimpleStatementSuite((cst.Pass(),))),
                ),
                "code": "for target in iter(): pass\nelse: pass\n",
                "parser": parse_statement,
            },
            # indentation
            {
                "node": DummyIndentedBlock(
                    " ",
                    cst.For(
                        cst.Name("target"),
                        cst.Call(cst.Name("iter")),
                        cst.SimpleStatementSuite((cst.Pass(),)),
                    ),
                ),
                "code": " for target in iter(): pass\n",
                "parser": None,
            },
            # for an indented body
            {
                "node": DummyIndentedBlock(
                    " ",
                    cst.For(
                        cst.Name("target"),
                        cst.Call(cst.Name("iter")),
                        cst.IndentedBlock((cst.SimpleStatementLine((cst.Pass(),)),)),
                    ),
                ),
                "code": " for target in iter():\n pass\n",
                "parser": None,
                "expected_position": CodeRange((1, 4), (2, 12)),
            },
            # leading_lines (on both the for and its else clause)
            {
                "node": cst.For(
                    cst.Name("target"),
                    cst.Call(cst.Name("iter")),
                    cst.IndentedBlock((cst.SimpleStatementLine((cst.Pass(),)),)),
                    cst.Else(
                        cst.IndentedBlock((cst.SimpleStatementLine((cst.Pass(),)),)),
                        leading_lines=(cst.EmptyLine(comment=cst.Comment("# else comment")),),
                    ),
                    leading_lines=(cst.EmptyLine(comment=cst.Comment("# leading comment")),),
                ),
                "code": "# leading comment\nfor target in iter():\n pass\n# else comment\nelse:\n pass\n",
                "parser": None,
                "expected_position": CodeRange((2, 0), (6, 8)),
            },
            # Weird spacing rules: parentheses allow zero whitespace around keywords
            {
                "node": cst.For(
                    cst.Name("target", lpar=(cst.LeftParen(),), rpar=(cst.RightParen(),)),
                    cst.Call(
                        cst.Name("iter"),
                        lpar=(cst.LeftParen(),),
                        rpar=(cst.RightParen(),),
                    ),
                    cst.SimpleStatementSuite((cst.Pass(),)),
                    whitespace_after_for=cst.SimpleWhitespace(""),
                    whitespace_before_in=cst.SimpleWhitespace(""),
                    whitespace_after_in=cst.SimpleWhitespace(""),
                ),
                "code": "for(target)in(iter()): pass\n",
                "parser": parse_statement,
                "expected_position": CodeRange((1, 0), (1, 27)),
            },
            # Whitespace
            {
                "node": cst.For(
                    cst.Name("target"),
                    cst.Call(cst.Name("iter")),
                    cst.SimpleStatementSuite((cst.Pass(),)),
                    whitespace_after_for=cst.SimpleWhitespace(" "),
                    whitespace_before_in=cst.SimpleWhitespace(" "),
                    whitespace_after_in=cst.SimpleWhitespace(" "),
                    whitespace_before_colon=cst.SimpleWhitespace(" "),
                ),
                "code": "for target in iter() : pass\n",
                "parser": parse_statement,
                "expected_position": CodeRange((1, 0), (1, 31)),
            },
        )
    )
    def test_valid(self, **kwargs: Any) -> None:
        self.validate_node(**kwargs)

    # Without parentheses, removing the whitespace around for/in must be rejected.
    @data_provider(
        (
            {
                "get_node": lambda: cst.For(
                    cst.Name("target"),
                    cst.Call(cst.Name("iter")),
                    cst.SimpleStatementSuite((cst.Pass(),)),
                    whitespace_after_for=cst.SimpleWhitespace(""),
                ),
                "expected_re": "Must have at least one space after 'for' keyword",
            },
            {
                "get_node": lambda: cst.For(
                    cst.Name("target"),
                    cst.Call(cst.Name("iter")),
                    cst.SimpleStatementSuite((cst.Pass(),)),
                    whitespace_before_in=cst.SimpleWhitespace(""),
                ),
                "expected_re": "Must have at least one space before 'in' keyword",
            },
            {
                "get_node": lambda: cst.For(
                    cst.Name("target"),
                    cst.Call(cst.Name("iter")),
                    cst.SimpleStatementSuite((cst.Pass(),)),
                    whitespace_after_in=cst.SimpleWhitespace(""),
                ),
                "expected_re": "Must have at least one space after 'in' keyword",
            },
        )
    )
    def test_invalid(self, **kwargs: Any) -> None:
        self.assert_invalid(**kwargs)
class WithTest(CSTNodeTest):
    """Cases for ``cst.With``.

    Covers sync and async ``with``, multiple items, ``as`` targets,
    indentation and whitespace. Also covers invalid constructions and the
    Python version where multiple context managers became parseable.
    """

    @data_provider(
        (
            # Simple with block
            {
                "node": cst.With(
                    (cst.WithItem(cst.Call(cst.Name("context_mgr"))),),
                    cst.SimpleStatementSuite((cst.Pass(),)),
                ),
                "code": "with context_mgr(): pass\n",
                "parser": parse_statement,
                "expected_position": CodeRange((1, 0), (1, 24)),
            },
            # Simple async with block
            {
                "node": cst.With(
                    (cst.WithItem(cst.Call(cst.Name("context_mgr"))),),
                    cst.SimpleStatementSuite((cst.Pass(),)),
                    asynchronous=cst.Asynchronous(),
                ),
                "code": "async with context_mgr(): pass\n",
                "parser": lambda code: parse_statement(
                    code, config=PartialParserConfig(python_version="3.7")
                ),
            },
            # Python 3.6 async with block (must sit inside an async def)
            {
                "node": cst.FunctionDef(
                    cst.Name("foo"),
                    cst.Parameters(),
                    cst.IndentedBlock(
                        (
                            cst.With(
                                (cst.WithItem(cst.Call(cst.Name("context_mgr"))),),
                                cst.SimpleStatementSuite((cst.Pass(),)),
                                asynchronous=cst.Asynchronous(),
                            ),
                        )
                    ),
                    asynchronous=cst.Asynchronous(),
                ),
                "code": "async def foo():\n async with context_mgr(): pass\n",
                "parser": lambda code: parse_statement(
                    code, config=PartialParserConfig(python_version="3.6")
                ),
            },
            # Multiple context managers (default commas; render-only, no parser)
            {
                "node": cst.With(
                    (
                        cst.WithItem(cst.Call(cst.Name("foo"))),
                        cst.WithItem(cst.Call(cst.Name("bar"))),
                    ),
                    cst.SimpleStatementSuite((cst.Pass(),)),
                ),
                "code": "with foo(), bar(): pass\n",
                "parser": None,
            },
            # Same code with an explicit comma, so it can round-trip through the parser
            {
                "node": cst.With(
                    (
                        cst.WithItem(
                            cst.Call(cst.Name("foo")),
                            comma=cst.Comma(whitespace_after=cst.SimpleWhitespace(" ")),
                        ),
                        cst.WithItem(cst.Call(cst.Name("bar"))),
                    ),
                    cst.SimpleStatementSuite((cst.Pass(),)),
                ),
                "code": "with foo(), bar(): pass\n",
                "parser": parse_statement,
            },
            # With block containing variable for context manager.
            {
                "node": cst.With(
                    (
                        cst.WithItem(
                            cst.Call(cst.Name("context_mgr")),
                            cst.AsName(cst.Name("ctx")),
                        ),
                    ),
                    cst.SimpleStatementSuite((cst.Pass(),)),
                ),
                "code": "with context_mgr() as ctx: pass\n",
                "parser": parse_statement,
            },
            # indentation
            {
                "node": DummyIndentedBlock(
                    " ",
                    cst.With(
                        (cst.WithItem(cst.Call(cst.Name("context_mgr"))),),
                        cst.SimpleStatementSuite((cst.Pass(),)),
                    ),
                ),
                "code": " with context_mgr(): pass\n",
                "parser": None,
                "expected_position": CodeRange((1, 4), (1, 28)),
            },
            # with an indented body
            {
                "node": DummyIndentedBlock(
                    " ",
                    cst.With(
                        (cst.WithItem(cst.Call(cst.Name("context_mgr"))),),
                        cst.IndentedBlock((cst.SimpleStatementLine((cst.Pass(),)),)),
                    ),
                ),
                "code": " with context_mgr():\n pass\n",
                "parser": None,
                "expected_position": CodeRange((1, 4), (2, 12)),
            },
            # leading_lines
            {
                "node": cst.With(
                    (cst.WithItem(cst.Call(cst.Name("context_mgr"))),),
                    cst.SimpleStatementSuite((cst.Pass(),)),
                    leading_lines=(cst.EmptyLine(comment=cst.Comment("# leading comment")),),
                ),
                "code": "# leading comment\nwith context_mgr(): pass\n",
                "parser": parse_statement,
                "expected_position": CodeRange((2, 0), (2, 24)),
            },
            # Weird spacing rules: a parenthesized item needs no space after "with"
            {
                "node": cst.With(
                    (
                        cst.WithItem(
                            cst.Call(
                                cst.Name("context_mgr"),
                                lpar=(cst.LeftParen(),),
                                rpar=(cst.RightParen(),),
                            )
                        ),
                    ),
                    cst.SimpleStatementSuite((cst.Pass(),)),
                    whitespace_after_with=cst.SimpleWhitespace(""),
                ),
                "code": "with(context_mgr()): pass\n",
                "parser": parse_statement,
                "expected_position": CodeRange((1, 0), (1, 25)),
            },
            # Whitespace
            {
                "node": cst.With(
                    (
                        cst.WithItem(
                            cst.Call(cst.Name("context_mgr")),
                            cst.AsName(
                                cst.Name("ctx"),
                                whitespace_before_as=cst.SimpleWhitespace(" "),
                                whitespace_after_as=cst.SimpleWhitespace(" "),
                            ),
                        ),
                    ),
                    cst.SimpleStatementSuite((cst.Pass(),)),
                    whitespace_after_with=cst.SimpleWhitespace(" "),
                    whitespace_before_colon=cst.SimpleWhitespace(" "),
                ),
                "code": "with context_mgr() as ctx : pass\n",
                "parser": parse_statement,
                "expected_position": CodeRange((1, 0), (1, 36)),
            },
        )
    )
    def test_valid(self, **kwargs: Any) -> None:
        self.validate_node(**kwargs)

    @data_provider(
        (
            {
                "get_node": lambda: cst.With((), cst.IndentedBlock((cst.SimpleStatementLine((cst.Pass(),)),))),
                "expected_re": "A With statement must have at least one WithItem",
            },
            {
                "get_node": lambda: cst.With(
                    (
                        cst.WithItem(
                            cst.Call(cst.Name("foo")),
                            comma=cst.Comma(whitespace_after=cst.SimpleWhitespace(" ")),
                        ),
                    ),
                    cst.IndentedBlock((cst.SimpleStatementLine((cst.Pass(),)),)),
                ),
                "expected_re": "The last WithItem in a With cannot have a trailing comma",
            },
            {
                "get_node": lambda: cst.With(
                    (cst.WithItem(cst.Call(cst.Name("context_mgr"))),),
                    cst.SimpleStatementSuite((cst.Pass(),)),
                    whitespace_after_with=cst.SimpleWhitespace(""),
                ),
                "expected_re": "Must have at least one space after with keyword",
            },
        )
    )
    def test_invalid(self, **kwargs: Any) -> None:
        self.assert_invalid(**kwargs)

    # Multiple context managers in one with-statement parse under 3.1 but not 3.0.
    @data_provider(
        (
            {
                "code": "with a, b: pass",
                "parser": parse_statement_as(python_version="3.1"),
                "expect_success": True,
            },
            {
                "code": "with a, b: pass",
                "parser": parse_statement_as(python_version="3.0"),
                "expect_success": False,
            },
        )
    )
    def test_versions(self, **kwargs: Any) -> None:
        self.assert_parses(**kwargs)
] # XXX: currently the path matching only takes the first match, probably it should merge all (excluding body) # matches into one dict scope_elem_extra_kwargs = per_path_default_kwargs.get( scope_stack, lambda _: {})(scope_stack) scope_elem_type = only(scope_elem_types) # ambiguity error if not only # XXX: switch to giving TransformCtx CaptureReference's get_name_for_match name = scope_expr.capture.pattern.pattern return scope_elem_type( **{ **({'name': cst.Name(name), } if 'name' in scope_elem_type.__slots__ else {}), **({'body': cst.IndentedBlock(body=[cst.SimpleStatementLine(body=[cst.Pass()])]), } if 'body' in scope_elem_type.__slots__ else {}), **({ 'params': cst.Parameters(), 'decorators': (), **({ 'asynchronous': cst.Asynchronous() } if scope_expr.properties.get('async') else {}) } if scope_elem_type is cst.FunctionDef else {}), **({ 'targets': [cst.AssignTarget(target=cst.Name(name))], 'value': cst.Name("None") } if scope_elem_type is cst.Assign else {}), **scope_elem_extra_kwargs } )
class FooterBehaviorTest(UnitTest):
    """Checks how ``parse_module`` assigns comment lines to header/footer slots.

    The slots belong to the module and to nested ``IndentedBlock`` nodes, and
    also to ``leading_lines`` of the following statement.
    """

    @data_provider(
        {
            # Literally the most basic example
            "simple_module": {"code": "\n", "expected_module": cst.Module(body=())},
            # A module with a header comment
            "header_only_module": {
                "code": "# This is a header comment\n",
                "expected_module": cst.Module(
                    header=[cst.EmptyLine(comment=cst.Comment(value="# This is a header comment"))],
                    body=[],
                ),
            },
            # A module with a header and footer
            "simple_header_footer_module": {
                "code": "# This is a header comment\npass\n# This is a footer comment\n",
                "expected_module": cst.Module(
                    header=[cst.EmptyLine(comment=cst.Comment(value="# This is a header comment"))],
                    body=[cst.SimpleStatementLine([cst.Pass()])],
                    footer=[cst.EmptyLine(comment=cst.Comment(value="# This is a footer comment"))],
                ),
            },
            # A module which should have a footer comment taken from the
            # if statement's indented block.
            "simple_reparented_footer_module": {
                "code": "# This is a header comment\nif True:\n pass\n# This is a footer comment\n",
                "expected_module": cst.Module(
                    header=[cst.EmptyLine(comment=cst.Comment(value="# This is a header comment"))],
                    body=[
                        cst.If(
                            test=cst.Name(value="True"),
                            body=cst.IndentedBlock(
                                header=cst.TrailingWhitespace(),
                                body=[
                                    cst.SimpleStatementLine(
                                        body=[cst.Pass()],
                                        trailing_whitespace=cst.TrailingWhitespace(),
                                    )
                                ],
                            ),
                        )
                    ],
                    footer=[cst.EmptyLine(comment=cst.Comment(value="# This is a footer comment"))],
                ),
            },
            # Verifying that we properly parse and spread out footer comments to the
            # relative indents they go with.
            "complex_reparented_footer_module": {
                "code": (
                    "# This is a header comment\nif True:\n if True:\n pass"
                    + "\n # This is an inner indented block comment\n # This "
                    + "is an outer indented block comment\n# This is a footer comment\n"
                ),
                "expected_module": cst.Module(
                    body=[
                        cst.If(
                            test=cst.Name(value="True"),
                            body=cst.IndentedBlock(
                                body=[
                                    cst.If(
                                        test=cst.Name(value="True"),
                                        body=cst.IndentedBlock(
                                            body=[cst.SimpleStatementLine(body=[cst.Pass()])],
                                            footer=[
                                                cst.EmptyLine(
                                                    comment=cst.Comment(
                                                        value="# This is an inner indented block comment"
                                                    )
                                                )
                                            ],
                                        ),
                                    )
                                ],
                                footer=[
                                    cst.EmptyLine(
                                        comment=cst.Comment(
                                            value="# This is an outer indented block comment"
                                        )
                                    )
                                ],
                            ),
                        )
                    ],
                    header=[cst.EmptyLine(comment=cst.Comment(value="# This is a header comment"))],
                    footer=[cst.EmptyLine(comment=cst.Comment(value="# This is a footer comment"))],
                ),
            },
            # Verify that comments belonging to statements are still owned even
            # after an indented block.
            "statement_comment_reparent": {
                "code": "if foo:\n return\n# comment\nx = 7\n",
                "expected_module": cst.Module(
                    body=[
                        cst.If(
                            test=cst.Name(value="foo"),
                            body=cst.IndentedBlock(
                                body=[
                                    cst.SimpleStatementLine(
                                        body=[
                                            cst.Return(
                                                whitespace_after_return=cst.SimpleWhitespace(value="")
                                            )
                                        ]
                                    )
                                ]
                            ),
                        ),
                        cst.SimpleStatementLine(
                            body=[
                                cst.Assign(
                                    targets=[cst.AssignTarget(target=cst.Name(value="x"))],
                                    value=cst.Integer(value="7"),
                                )
                            ],
                            # The comment goes to the next statement, not the if-block footer.
                            leading_lines=[cst.EmptyLine(comment=cst.Comment(value="# comment"))],
                        ),
                    ]
                ),
            },
            # Verify that even if there are completely empty lines, we give all lines
            # up to and including the last line that's indented correctly. That way
            # comments that line up with indented block's indentation level aren't
            # parented to the next line just because there's a blank line or two
            # between them.
            "statement_comment_with_empty_lines": {
                "code": (
                    "def foo():\n if True:\n pass\n\n # Empty "
                    + "line before me\n\n else:\n pass\n"
                ),
                "expected_module": cst.Module(
                    body=[
                        cst.FunctionDef(
                            name=cst.Name(value="foo"),
                            params=cst.Parameters(),
                            body=cst.IndentedBlock(
                                body=[
                                    cst.If(
                                        test=cst.Name(value="True"),
                                        body=cst.IndentedBlock(
                                            body=[cst.SimpleStatementLine(body=[cst.Pass()])],
                                            footer=[
                                                cst.EmptyLine(indent=False),
                                                cst.EmptyLine(
                                                    comment=cst.Comment(value="# Empty line before me")
                                                ),
                                            ],
                                        ),
                                        orelse=cst.Else(
                                            body=cst.IndentedBlock(
                                                body=[cst.SimpleStatementLine(body=[cst.Pass()])]
                                            ),
                                            leading_lines=[cst.EmptyLine(indent=False)],
                                        ),
                                    )
                                ]
                            ),
                        )
                    ]
                ),
            },
        }
    )
    def test_parsers(self, code: str, expected_module: cst.CSTNode) -> None:
        # Dedent the source, parse it, and require a deep-equal match. Both
        # reprs go into the failure message to make the differences visible.
        parsed_module = parse_module(dedent(code))
        self.assertTrue(
            deep_equals(parsed_module, expected_module),
            msg=f"\n{parsed_module!r}\nis not deeply equal to \n{expected_module!r}",
        )
def class_def(self, name: str) -> cst.ClassDef:
    """Build ``class <name>:`` whose indented body holds the items of ``self.body``."""
    class_name = cst.Name(name)
    statements = list(self.body)
    return cst.ClassDef(
        name=class_name,
        body=cst.IndentedBlock(body=statements),
    )
class WithTest(CSTNodeTest):
    """Cases for ``cst.With``, including parenthesized ``with`` statements.

    Several expectations depend on ``is_native()``. The node shapes differ
    depending on whether the active grammar supports parenthesized with
    statements.
    """

    # Allow longer assertion diffs for the large node reprs compared here.
    maxDiff: int = 2000

    @data_provider(
        (
            # Simple with block
            {
                "node": cst.With(
                    (cst.WithItem(cst.Call(cst.Name("context_mgr"))),),
                    cst.SimpleStatementSuite((cst.Pass(),)),
                ),
                "code": "with context_mgr(): pass\n",
                "parser": parse_statement,
                "expected_position": CodeRange((1, 0), (1, 24)),
            },
            # Simple async with block
            {
                "node": cst.With(
                    (cst.WithItem(cst.Call(cst.Name("context_mgr"))),),
                    cst.SimpleStatementSuite((cst.Pass(),)),
                    asynchronous=cst.Asynchronous(),
                ),
                "code": "async with context_mgr(): pass\n",
                "parser": lambda code: parse_statement(
                    code, config=PartialParserConfig(python_version="3.7")
                ),
            },
            # Python 3.6 async with block (must sit inside an async def)
            {
                "node": cst.FunctionDef(
                    cst.Name("foo"),
                    cst.Parameters(),
                    cst.IndentedBlock(
                        (
                            cst.With(
                                (cst.WithItem(cst.Call(cst.Name("context_mgr"))),),
                                cst.SimpleStatementSuite((cst.Pass(),)),
                                asynchronous=cst.Asynchronous(),
                            ),
                        )
                    ),
                    asynchronous=cst.Asynchronous(),
                ),
                "code": "async def foo():\n async with context_mgr(): pass\n",
                "parser": lambda code: parse_statement(
                    code, config=PartialParserConfig(python_version="3.6")
                ),
            },
            # Multiple context managers (default commas; render-only, no parser)
            {
                "node": cst.With(
                    (
                        cst.WithItem(cst.Call(cst.Name("foo"))),
                        cst.WithItem(cst.Call(cst.Name("bar"))),
                    ),
                    cst.SimpleStatementSuite((cst.Pass(),)),
                ),
                "code": "with foo(), bar(): pass\n",
                "parser": None,
            },
            # Same code with an explicit comma, so it can round-trip through the parser
            {
                "node": cst.With(
                    (
                        cst.WithItem(
                            cst.Call(cst.Name("foo")),
                            comma=cst.Comma(whitespace_after=cst.SimpleWhitespace(" ")),
                        ),
                        cst.WithItem(cst.Call(cst.Name("bar"))),
                    ),
                    cst.SimpleStatementSuite((cst.Pass(),)),
                ),
                "code": "with foo(), bar(): pass\n",
                "parser": parse_statement,
            },
            # With block containing variable for context manager.
            {
                "node": cst.With(
                    (
                        cst.WithItem(
                            cst.Call(cst.Name("context_mgr")),
                            cst.AsName(cst.Name("ctx")),
                        ),
                    ),
                    cst.SimpleStatementSuite((cst.Pass(),)),
                ),
                "code": "with context_mgr() as ctx: pass\n",
                "parser": parse_statement,
            },
            # indentation
            {
                "node": DummyIndentedBlock(
                    " ",
                    cst.With(
                        (cst.WithItem(cst.Call(cst.Name("context_mgr"))),),
                        cst.SimpleStatementSuite((cst.Pass(),)),
                    ),
                ),
                "code": " with context_mgr(): pass\n",
                "parser": None,
                "expected_position": CodeRange((1, 4), (1, 28)),
            },
            # with an indented body
            {
                "node": DummyIndentedBlock(
                    " ",
                    cst.With(
                        (cst.WithItem(cst.Call(cst.Name("context_mgr"))),),
                        cst.IndentedBlock((cst.SimpleStatementLine((cst.Pass(),)),)),
                    ),
                ),
                "code": " with context_mgr():\n pass\n",
                "parser": None,
                "expected_position": CodeRange((1, 4), (2, 12)),
            },
            # leading_lines
            {
                "node": cst.With(
                    (cst.WithItem(cst.Call(cst.Name("context_mgr"))),),
                    cst.SimpleStatementSuite((cst.Pass(),)),
                    leading_lines=(cst.EmptyLine(comment=cst.Comment("# leading comment")),),
                ),
                "code": "# leading comment\nwith context_mgr(): pass\n",
                "parser": parse_statement,
                "expected_position": CodeRange((2, 0), (2, 24)),
            },
            # Whitespace
            {
                "node": cst.With(
                    (
                        cst.WithItem(
                            cst.Call(cst.Name("context_mgr")),
                            cst.AsName(
                                cst.Name("ctx"),
                                whitespace_before_as=cst.SimpleWhitespace(" "),
                                whitespace_after_as=cst.SimpleWhitespace(" "),
                            ),
                        ),
                    ),
                    cst.SimpleStatementSuite((cst.Pass(),)),
                    whitespace_after_with=cst.SimpleWhitespace(" "),
                    whitespace_before_colon=cst.SimpleWhitespace(" "),
                ),
                "code": "with context_mgr() as ctx : pass\n",
                "parser": parse_statement,
                "expected_position": CodeRange((1, 0), (1, 36)),
            },
            # Weird spacing rules, that parse differently depending on whether
            # we are using a grammar that included parenthesized with statements.
            # Under the native parser the parens belong to the With itself;
            # otherwise they belong to the call expression.
            {
                "node": cst.With(
                    (
                        cst.WithItem(
                            cst.Call(
                                cst.Name("context_mgr"),
                                lpar=() if is_native() else (cst.LeftParen(),),
                                rpar=() if is_native() else (cst.RightParen(),),
                            )
                        ),
                    ),
                    cst.SimpleStatementSuite((cst.Pass(),)),
                    lpar=(cst.LeftParen() if is_native() else MaybeSentinel.DEFAULT),
                    rpar=(cst.RightParen() if is_native() else MaybeSentinel.DEFAULT),
                    whitespace_after_with=cst.SimpleWhitespace(""),
                ),
                "code": "with(context_mgr()): pass\n",
                "parser": parse_statement,
                "expected_position": CodeRange((1, 0), (1, 25)),
            },
            # Multi-line parenthesized with.
            {
                "node": cst.With(
                    (
                        cst.WithItem(
                            cst.Call(cst.Name("foo")),
                            comma=cst.Comma(
                                whitespace_after=cst.ParenthesizedWhitespace(
                                    first_line=cst.TrailingWhitespace(
                                        whitespace=cst.SimpleWhitespace(value="",),
                                        comment=None,
                                        newline=cst.Newline(value=None,),
                                    ),
                                    empty_lines=[],
                                    indent=True,
                                    last_line=cst.SimpleWhitespace(value=" ",),
                                )
                            ),
                        ),
                        cst.WithItem(cst.Call(cst.Name("bar")), comma=cst.Comma()),
                    ),
                    cst.SimpleStatementSuite((cst.Pass(),)),
                    lpar=cst.LeftParen(whitespace_after=cst.SimpleWhitespace(" ")),
                    rpar=cst.RightParen(whitespace_before=cst.SimpleWhitespace(" ")),
                ),
                "code": ("with ( foo(),\n" " bar(), ): pass\n"),  # noqa
                # Only the native parser accepts parenthesized with statements.
                "parser": parse_statement if is_native() else None,
                "expected_position": CodeRange((1, 0), (2, 21)),
            },
        )
    )
    def test_valid(self, **kwargs: Any) -> None:
        self.validate_node(**kwargs)

    @data_provider(
        (
            {
                "get_node": lambda: cst.With((), cst.IndentedBlock((cst.SimpleStatementLine((cst.Pass(),)),))),
                "expected_re": "A With statement must have at least one WithItem",
            },
            {
                "get_node": lambda: cst.With(
                    (
                        cst.WithItem(
                            cst.Call(cst.Name("foo")),
                            comma=cst.Comma(whitespace_after=cst.SimpleWhitespace(" ")),
                        ),
                    ),
                    cst.IndentedBlock((cst.SimpleStatementLine((cst.Pass(),)),)),
                ),
                "expected_re": "The last WithItem in an unparenthesized With cannot "
                + "have a trailing comma.",
            },
            {
                "get_node": lambda: cst.With(
                    (cst.WithItem(cst.Call(cst.Name("context_mgr"))),),
                    cst.SimpleStatementSuite((cst.Pass(),)),
                    whitespace_after_with=cst.SimpleWhitespace(""),
                ),
                "expected_re": "Must have at least one space after with keyword",
            },
            # Only one of lpar/rpar given: the other stays a MaybeSentinel.
            {
                "get_node": lambda: cst.With(
                    (cst.WithItem(cst.Call(cst.Name("context_mgr"))),),
                    cst.SimpleStatementSuite((cst.Pass(),)),
                    whitespace_after_with=cst.SimpleWhitespace(""),
                    lpar=cst.LeftParen(),
                ),
                "expected_re": "Do not mix concrete LeftParen/RightParen with "
                + "MaybeSentinel",
            },
            {
                "get_node": lambda: cst.With(
                    (cst.WithItem(cst.Call(cst.Name("context_mgr"))),),
                    cst.SimpleStatementSuite((cst.Pass(),)),
                    whitespace_after_with=cst.SimpleWhitespace(""),
                    rpar=cst.RightParen(),
                ),
                "expected_re": "Do not mix concrete LeftParen/RightParen with "
                + "MaybeSentinel",
            },
        )
    )
    def test_invalid(self, **kwargs: Any) -> None:
        self.assert_invalid(**kwargs)

    # Multiple context managers in one with-statement parse under 3.1 but not 3.0.
    @data_provider(
        (
            {
                "code": "with a, b: pass",
                "parser": parse_statement_as(python_version="3.1"),
                "expect_success": True,
            },
            {
                "code": "with a, b: pass",
                "parser": parse_statement_as(python_version="3.0"),
                "expect_success": False,
            },
        )
    )
    def test_versions(self, **kwargs: Any) -> None:
        # Cases expecting a parse failure are skipped under the native parser.
        if is_native() and not kwargs.get("expect_success", True):
            self.skipTest("parse errors are disabled for native parser")
        self.assert_parses(**kwargs)

    def test_adding_parens(self) -> None:
        # Render a parenthesized With whose first comma uses a default
        # ParenthesizedWhitespace, and check the exact generated text.
        node = cst.With(
            (
                cst.WithItem(
                    cst.Call(cst.Name("foo")),
                    comma=cst.Comma(
                        whitespace_after=cst.ParenthesizedWhitespace(),
                    ),
                ),
                cst.WithItem(cst.Call(cst.Name("bar")), comma=cst.Comma()),
            ),
            cst.SimpleStatementSuite((cst.Pass(),)),
            lpar=cst.LeftParen(whitespace_after=cst.SimpleWhitespace(" ")),
            rpar=cst.RightParen(whitespace_before=cst.SimpleWhitespace(" ")),
        )
        module = cst.Module([])
        self.assertEqual(
            module.code_for_node(node),
            ("with ( foo(),\n" "bar(), ): pass\n")  # noqa
        )
class TryTest(CSTNodeTest):
    """Rendering, parsing and validation tests for ``cst.Try`` and its clauses.

    Whitespace inside the expected ``code`` strings is significant: it must
    agree exactly with the whitespace nodes (and with the default four-space
    block indent), and ``expected_position`` columns are character counts.
    """

    @data_provider(
        (
            # Simple try/except block
            {
                "node": cst.Try(
                    cst.SimpleStatementSuite((cst.Pass(),)),
                    handlers=(
                        cst.ExceptHandler(
                            cst.SimpleStatementSuite((cst.Pass(),)),
                            whitespace_after_except=cst.SimpleWhitespace(""),
                        ),
                    ),
                ),
                "code": "try: pass\nexcept: pass\n",
                "parser": parse_statement,
                "expected_position": CodeRange((1, 0), (2, 12)),
            },
            # Try/except with a class
            {
                "node": cst.Try(
                    cst.SimpleStatementSuite((cst.Pass(),)),
                    handlers=(
                        cst.ExceptHandler(
                            cst.SimpleStatementSuite((cst.Pass(),)),
                            type=cst.Name("Exception"),
                        ),
                    ),
                ),
                "code": "try: pass\nexcept Exception: pass\n",
                "parser": parse_statement,
            },
            # Try/except with a named class
            {
                "node": cst.Try(
                    cst.SimpleStatementSuite((cst.Pass(),)),
                    handlers=(
                        cst.ExceptHandler(
                            cst.SimpleStatementSuite((cst.Pass(),)),
                            type=cst.Name("Exception"),
                            name=cst.AsName(cst.Name("exc")),
                        ),
                    ),
                ),
                "code": "try: pass\nexcept Exception as exc: pass\n",
                "parser": parse_statement,
                "expected_position": CodeRange((1, 0), (2, 29)),
            },
            # Try/except with multiple clauses
            {
                "node": cst.Try(
                    cst.SimpleStatementSuite((cst.Pass(),)),
                    handlers=(
                        cst.ExceptHandler(
                            cst.SimpleStatementSuite((cst.Pass(),)),
                            type=cst.Name("TypeError"),
                            name=cst.AsName(cst.Name("e")),
                        ),
                        cst.ExceptHandler(
                            cst.SimpleStatementSuite((cst.Pass(),)),
                            type=cst.Name("KeyError"),
                            name=cst.AsName(cst.Name("e")),
                        ),
                        cst.ExceptHandler(
                            cst.SimpleStatementSuite((cst.Pass(),)),
                            whitespace_after_except=cst.SimpleWhitespace(""),
                        ),
                    ),
                ),
                "code": "try: pass\n"
                + "except TypeError as e: pass\n"
                + "except KeyError as e: pass\n"
                + "except: pass\n",
                "parser": parse_statement,
                "expected_position": CodeRange((1, 0), (4, 12)),
            },
            # Simple try/finally block
            {
                "node": cst.Try(
                    cst.SimpleStatementSuite((cst.Pass(),)),
                    finalbody=cst.Finally(cst.SimpleStatementSuite((cst.Pass(),))),
                ),
                "code": "try: pass\nfinally: pass\n",
                "parser": parse_statement,
                "expected_position": CodeRange((1, 0), (2, 13)),
            },
            # Simple try/except/finally block
            {
                "node": cst.Try(
                    cst.SimpleStatementSuite((cst.Pass(),)),
                    handlers=(
                        cst.ExceptHandler(
                            cst.SimpleStatementSuite((cst.Pass(),)),
                            whitespace_after_except=cst.SimpleWhitespace(""),
                        ),
                    ),
                    finalbody=cst.Finally(cst.SimpleStatementSuite((cst.Pass(),))),
                ),
                "code": "try: pass\nexcept: pass\nfinally: pass\n",
                "parser": parse_statement,
                "expected_position": CodeRange((1, 0), (3, 13)),
            },
            # Simple try/except/else block
            {
                "node": cst.Try(
                    cst.SimpleStatementSuite((cst.Pass(),)),
                    handlers=(
                        cst.ExceptHandler(
                            cst.SimpleStatementSuite((cst.Pass(),)),
                            whitespace_after_except=cst.SimpleWhitespace(""),
                        ),
                    ),
                    orelse=cst.Else(cst.SimpleStatementSuite((cst.Pass(),))),
                ),
                "code": "try: pass\nexcept: pass\nelse: pass\n",
                "parser": parse_statement,
                "expected_position": CodeRange((1, 0), (3, 10)),
            },
            # Simple try/except/else block/finally
            {
                "node": cst.Try(
                    cst.SimpleStatementSuite((cst.Pass(),)),
                    handlers=(
                        cst.ExceptHandler(
                            cst.SimpleStatementSuite((cst.Pass(),)),
                            whitespace_after_except=cst.SimpleWhitespace(""),
                        ),
                    ),
                    orelse=cst.Else(cst.SimpleStatementSuite((cst.Pass(),))),
                    finalbody=cst.Finally(cst.SimpleStatementSuite((cst.Pass(),))),
                ),
                "code": "try: pass\nexcept: pass\nelse: pass\nfinally: pass\n",
                "parser": parse_statement,
                "expected_position": CodeRange((1, 0), (4, 13)),
            },
            # Verify whitespace in various locations. Non-default (double)
            # spaces are used around 'except'/'as' so the case actually
            # exercises custom whitespace rather than the defaults.
            {
                "node": cst.Try(
                    leading_lines=(cst.EmptyLine(comment=cst.Comment("# 1")),),
                    body=cst.SimpleStatementSuite((cst.Pass(),)),
                    handlers=(
                        cst.ExceptHandler(
                            leading_lines=(cst.EmptyLine(comment=cst.Comment("# 2")),),
                            type=cst.Name("TypeError"),
                            name=cst.AsName(
                                cst.Name("e"),
                                whitespace_before_as=cst.SimpleWhitespace("  "),
                                whitespace_after_as=cst.SimpleWhitespace("  "),
                            ),
                            whitespace_after_except=cst.SimpleWhitespace("  "),
                            whitespace_before_colon=cst.SimpleWhitespace(" "),
                            body=cst.SimpleStatementSuite((cst.Pass(),)),
                        ),
                    ),
                    orelse=cst.Else(
                        leading_lines=(cst.EmptyLine(comment=cst.Comment("# 3")),),
                        body=cst.SimpleStatementSuite((cst.Pass(),)),
                        whitespace_before_colon=cst.SimpleWhitespace(" "),
                    ),
                    finalbody=cst.Finally(
                        leading_lines=(cst.EmptyLine(comment=cst.Comment("# 4")),),
                        body=cst.SimpleStatementSuite((cst.Pass(),)),
                        whitespace_before_colon=cst.SimpleWhitespace(" "),
                    ),
                    whitespace_before_colon=cst.SimpleWhitespace(" "),
                ),
                "code": "# 1\ntry : pass\n# 2\nexcept  TypeError  as  e : pass\n# 3\nelse : pass\n# 4\nfinally : pass\n",
                "parser": parse_statement,
                "expected_position": CodeRange((2, 0), (8, 14)),
            },
            # Please don't write code like this
            {
                "node": cst.Try(
                    cst.SimpleStatementSuite((cst.Pass(),)),
                    handlers=(
                        cst.ExceptHandler(
                            cst.SimpleStatementSuite((cst.Pass(),)),
                            type=cst.Name("TypeError"),
                            name=cst.AsName(cst.Name("e")),
                        ),
                        cst.ExceptHandler(
                            cst.SimpleStatementSuite((cst.Pass(),)),
                            type=cst.Name("KeyError"),
                            name=cst.AsName(cst.Name("e")),
                        ),
                        cst.ExceptHandler(
                            cst.SimpleStatementSuite((cst.Pass(),)),
                            whitespace_after_except=cst.SimpleWhitespace(""),
                        ),
                    ),
                    orelse=cst.Else(cst.SimpleStatementSuite((cst.Pass(),))),
                    finalbody=cst.Finally(cst.SimpleStatementSuite((cst.Pass(),))),
                ),
                "code": "try: pass\n"
                + "except TypeError as e: pass\n"
                + "except KeyError as e: pass\n"
                + "except: pass\n"
                + "else: pass\n"
                + "finally: pass\n",
                "parser": parse_statement,
                "expected_position": CodeRange((1, 0), (6, 13)),
            },
            # Verify indentation (the dummy block indents by four spaces)
            {
                "node": DummyIndentedBlock(
                    "    ",
                    cst.Try(
                        cst.SimpleStatementSuite((cst.Pass(),)),
                        handlers=(
                            cst.ExceptHandler(
                                cst.SimpleStatementSuite((cst.Pass(),)),
                                type=cst.Name("TypeError"),
                                name=cst.AsName(cst.Name("e")),
                            ),
                            cst.ExceptHandler(
                                cst.SimpleStatementSuite((cst.Pass(),)),
                                type=cst.Name("KeyError"),
                                name=cst.AsName(cst.Name("e")),
                            ),
                            cst.ExceptHandler(
                                cst.SimpleStatementSuite((cst.Pass(),)),
                                whitespace_after_except=cst.SimpleWhitespace(""),
                            ),
                        ),
                        orelse=cst.Else(cst.SimpleStatementSuite((cst.Pass(),))),
                        finalbody=cst.Finally(cst.SimpleStatementSuite((cst.Pass(),))),
                    ),
                ),
                "code": "    try: pass\n"
                + "    except TypeError as e: pass\n"
                + "    except KeyError as e: pass\n"
                + "    except: pass\n"
                + "    else: pass\n"
                + "    finally: pass\n",
                "parser": None,
            },
            # Verify indentation in bodies: the dummy block's four spaces plus
            # the IndentedBlock's default four spaces give eight.
            {
                "node": DummyIndentedBlock(
                    "    ",
                    cst.Try(
                        cst.IndentedBlock((cst.SimpleStatementLine((cst.Pass(),)),)),
                        handlers=(
                            cst.ExceptHandler(
                                cst.IndentedBlock(
                                    (cst.SimpleStatementLine((cst.Pass(),)),)
                                ),
                                whitespace_after_except=cst.SimpleWhitespace(""),
                            ),
                        ),
                        orelse=cst.Else(
                            cst.IndentedBlock((cst.SimpleStatementLine((cst.Pass(),)),))
                        ),
                        finalbody=cst.Finally(
                            cst.IndentedBlock((cst.SimpleStatementLine((cst.Pass(),)),))
                        ),
                    ),
                ),
                "code": "    try:\n"
                + "        pass\n"
                + "    except:\n"
                + "        pass\n"
                + "    else:\n"
                + "        pass\n"
                + "    finally:\n"
                + "        pass\n",
                "parser": None,
            },
            # No space when using grouping parens
            {
                "node": cst.Try(
                    cst.SimpleStatementSuite((cst.Pass(),)),
                    handlers=(
                        cst.ExceptHandler(
                            cst.SimpleStatementSuite((cst.Pass(),)),
                            whitespace_after_except=cst.SimpleWhitespace(""),
                            type=cst.Name(
                                "Exception",
                                lpar=(cst.LeftParen(),),
                                rpar=(cst.RightParen(),),
                            ),
                        ),
                    ),
                ),
                "code": "try: pass\nexcept(Exception): pass\n",
                "parser": parse_statement,
            },
            # No space when using tuple
            {
                "node": cst.Try(
                    cst.SimpleStatementSuite((cst.Pass(),)),
                    handlers=(
                        cst.ExceptHandler(
                            cst.SimpleStatementSuite((cst.Pass(),)),
                            whitespace_after_except=cst.SimpleWhitespace(""),
                            type=cst.Tuple(
                                [
                                    cst.Element(
                                        cst.Name("IOError"),
                                        comma=cst.Comma(
                                            whitespace_after=cst.SimpleWhitespace(" ")
                                        ),
                                    ),
                                    cst.Element(cst.Name("ImportError")),
                                ]
                            ),
                        ),
                    ),
                ),
                "code": "try: pass\nexcept(IOError, ImportError): pass\n",
                "parser": parse_statement,
            },
        )
    )
    def test_valid(self, **kwargs: Any) -> None:
        self.validate_node(**kwargs)

    # Each case builds an invalid node lazily and names the regex that the
    # resulting validation error message is expected to match.
    @data_provider(
        (
            {
                "get_node": lambda: cst.AsName(cst.Name("")),
                "expected_re": "empty name identifier",
            },
            {
                "get_node": lambda: cst.AsName(
                    cst.Name("bla"), whitespace_after_as=cst.SimpleWhitespace("")
                ),
                "expected_re": "between 'as'",
            },
            {
                "get_node": lambda: cst.AsName(
                    cst.Name("bla"), whitespace_before_as=cst.SimpleWhitespace("")
                ),
                "expected_re": "before 'as'",
            },
            {
                "get_node": lambda: cst.ExceptHandler(
                    cst.SimpleStatementSuite((cst.Pass(),)),
                    name=cst.AsName(cst.Name("bla")),
                ),
                "expected_re": "name for an empty type",
            },
            {
                "get_node": lambda: cst.ExceptHandler(
                    cst.SimpleStatementSuite((cst.Pass(),)),
                    type=cst.Name("TypeError"),
                    whitespace_after_except=cst.SimpleWhitespace(""),
                ),
                "expected_re": "at least one space after except",
            },
            {
                "get_node": lambda: cst.Try(cst.SimpleStatementSuite((cst.Pass(),))),
                "expected_re": "at least one ExceptHandler or Finally",
            },
            {
                "get_node": lambda: cst.Try(
                    cst.SimpleStatementSuite((cst.Pass(),)),
                    orelse=cst.Else(cst.SimpleStatementSuite((cst.Pass(),))),
                    finalbody=cst.Finally(cst.SimpleStatementSuite((cst.Pass(),))),
                ),
                "expected_re": "at least one ExceptHandler in order to have an Else",
            },
        )
    )
    def test_invalid(self, **kwargs: Any) -> None:
        self.assert_invalid(**kwargs)
class IndentedBlockTest(CSTNodeTest):
    """Rendering, parsing and validation tests for ``cst.IndentedBlock``.

    Each valid case is ``(node, expected_code, parser_or_None)``. An
    IndentedBlock without an explicit ``indent`` renders with the default
    four-space indent, so the whitespace in ``expected_code`` is significant.
    """

    @data_provider(
        (
            # Standard render
            (
                cst.IndentedBlock((cst.SimpleStatementLine((cst.Pass(),)),)),
                "\n    pass\n",
                None,
            ),
            # Render with empty
            (cst.IndentedBlock(()), "\n    pass\n", None),
            # Render with empty subnodes
            (
                cst.IndentedBlock((cst.SimpleStatementLine(()),)),
                "\n    pass\n",
                None,
            ),
            # Test render with custom indent
            (
                cst.IndentedBlock(
                    (cst.SimpleStatementLine((cst.Pass(),)),), indent="\t"
                ),
                "\n\tpass\n",
                None,
            ),
            # Test comments
            (
                cst.IndentedBlock(
                    (cst.SimpleStatementLine((cst.Pass(),)),),
                    header=cst.TrailingWhitespace(
                        whitespace=cst.SimpleWhitespace("  "),
                        comment=cst.Comment("# header comment"),
                    ),
                ),
                "  # header comment\n    pass\n",
                None,
            ),
            # Footer lines are rendered at the block's indentation level.
            (
                cst.IndentedBlock(
                    (cst.SimpleStatementLine((cst.Pass(),)),),
                    footer=(cst.EmptyLine(comment=cst.Comment("# footer comment")),),
                ),
                "\n    pass\n    # footer comment\n",
                None,
            ),
            # Extra footer whitespace is added after the block's indentation.
            (
                cst.IndentedBlock(
                    (cst.SimpleStatementLine((cst.Pass(),)),),
                    footer=(
                        cst.EmptyLine(
                            whitespace=cst.SimpleWhitespace("    "),
                            comment=cst.Comment("# footer comment"),
                        ),
                    ),
                ),
                "\n    pass\n        # footer comment\n",
                None,
            ),
            (
                cst.IndentedBlock(
                    (
                        cst.SimpleStatementLine((cst.Continue(),)),
                        cst.SimpleStatementLine((cst.Pass(),)),
                    )
                ),
                "\n    continue\n    pass\n",
                None,
            ),
            # Basic parsing test
            (
                cst.If(
                    cst.Name("conditional"),
                    cst.IndentedBlock((cst.SimpleStatementLine((cst.Pass(),)),)),
                ),
                "if conditional:\n    pass\n",
                parse_statement,
            ),
            # Multi-level parsing test
            (
                cst.If(
                    cst.Name("conditional"),
                    cst.IndentedBlock(
                        (
                            cst.SimpleStatementLine((cst.Pass(),)),
                            cst.If(
                                cst.Name("other_conditional"),
                                cst.IndentedBlock(
                                    (cst.SimpleStatementLine((cst.Pass(),)),)
                                ),
                            ),
                        )
                    ),
                ),
                "if conditional:\n    pass\n    if other_conditional:\n        pass\n",
                parse_statement,
            ),
            # Inconsistent indentation parsing test: the inner block uses an
            # eight-space indent on top of the outer block's four spaces.
            (
                cst.If(
                    cst.Name("conditional"),
                    cst.IndentedBlock(
                        (
                            cst.SimpleStatementLine((cst.Pass(),)),
                            cst.If(
                                cst.Name("other_conditional"),
                                cst.IndentedBlock(
                                    (cst.SimpleStatementLine((cst.Pass(),)),),
                                    indent="        ",
                                ),
                            ),
                        )
                    ),
                ),
                "if conditional:\n    pass\n    if other_conditional:\n            pass\n",
                parse_statement,
            ),
        )
    )
    def test_valid(
        self,
        node: cst.CSTNode,
        code: str,
        parser: Optional[Callable[[str], cst.CSTNode]],
    ) -> None:
        self.validate_node(node, code, parser)

    # Each case is (lazy node factory, regex the validation error must match).
    @data_provider(
        (
            (
                lambda: cst.IndentedBlock(
                    (cst.SimpleStatementLine((cst.Pass(),)),), indent=""
                ),
                "non-zero width indent",
            ),
            (
                lambda: cst.IndentedBlock(
                    (cst.SimpleStatementLine((cst.Pass(),)),),
                    indent="this isn't valid whitespace!",
                ),
                "only whitespace",
            ),
        )
    )
    def test_invalid(
        self, get_node: Callable[[], cst.CSTNode], expected_re: str
    ) -> None:
        self.assert_invalid(get_node, expected_re)
class SimpleCompTest(CSTNodeTest):
    """Tests for GeneratorExp, ListComp and SetComp with their CompFor/CompIf.

    ``expected_position`` columns are character offsets into ``code``, so the
    whitespace in the expected code must match the whitespace nodes exactly.
    """

    @data_provider(
        [
            # simple GeneratorExp
            {
                "node": cst.GeneratorExp(
                    cst.Name("a"), cst.CompFor(target=cst.Name("b"), iter=cst.Name("c"))
                ),
                "code": "(a for b in c)",
                "parser": parse_expression,
                "expected_position": CodeRange((1, 1), (1, 13)),
            },
            # simple ListComp
            {
                "node": cst.ListComp(
                    cst.Name("a"), cst.CompFor(target=cst.Name("b"), iter=cst.Name("c"))
                ),
                "code": "[a for b in c]",
                "parser": parse_expression,
                "expected_position": CodeRange((1, 0), (1, 14)),
            },
            # simple SetComp
            {
                "node": cst.SetComp(
                    cst.Name("a"), cst.CompFor(target=cst.Name("b"), iter=cst.Name("c"))
                ),
                "code": "{a for b in c}",
                "parser": parse_expression,
            },
            # async GeneratorExp
            {
                "node": cst.GeneratorExp(
                    cst.Name("a"),
                    cst.CompFor(
                        target=cst.Name("b"),
                        iter=cst.Name("c"),
                        asynchronous=cst.Asynchronous(),
                    ),
                ),
                "code": "(a async for b in c)",
                "parser": lambda code: parse_expression(
                    code, config=PartialParserConfig(python_version="3.7")
                ),
            },
            # Python 3.6 async GeneratorExp (only valid inside an async def);
            # the function body uses the default four-space indent.
            {
                "node": cst.FunctionDef(
                    cst.Name("foo"),
                    cst.Parameters(),
                    cst.IndentedBlock(
                        (
                            cst.SimpleStatementLine(
                                (
                                    cst.Expr(
                                        cst.GeneratorExp(
                                            cst.Name("a"),
                                            cst.CompFor(
                                                target=cst.Name("b"),
                                                iter=cst.Name("c"),
                                                asynchronous=cst.Asynchronous(),
                                            ),
                                        )
                                    ),
                                )
                            ),
                        )
                    ),
                    asynchronous=cst.Asynchronous(),
                ),
                "code": "async def foo():\n    (a async for b in c)\n",
                "parser": lambda code: parse_statement(
                    code, config=PartialParserConfig(python_version="3.6")
                ),
            },
            # a generator doesn't have to own its own parentheses
            {
                "node": cst.Call(
                    cst.Name("func"),
                    [
                        cst.Arg(
                            cst.GeneratorExp(
                                cst.Name("a"),
                                cst.CompFor(target=cst.Name("b"), iter=cst.Name("c")),
                                lpar=[],
                                rpar=[],
                            )
                        )
                    ],
                ),
                "code": "func(a for b in c)",
                "parser": parse_expression,
            },
            # add a few 'if' clauses
            {
                "node": cst.GeneratorExp(
                    cst.Name("a"),
                    cst.CompFor(
                        target=cst.Name("b"),
                        iter=cst.Name("c"),
                        ifs=[
                            cst.CompIf(cst.Name("d")),
                            cst.CompIf(cst.Name("e")),
                            cst.CompIf(cst.Name("f")),
                        ],
                    ),
                ),
                "code": "(a for b in c if d if e if f)",
                "parser": parse_expression,
                "expected_position": CodeRange((1, 1), (1, 28)),
            },
            # nested/inner for-in clause
            {
                "node": cst.GeneratorExp(
                    cst.Name("a"),
                    cst.CompFor(
                        target=cst.Name("b"),
                        iter=cst.Name("c"),
                        inner_for_in=cst.CompFor(
                            target=cst.Name("d"), iter=cst.Name("e")
                        ),
                    ),
                ),
                "code": "(a for b in c for d in e)",
                "parser": parse_expression,
            },
            # nested/inner for-in clause with an 'if' clause
            {
                "node": cst.GeneratorExp(
                    cst.Name("a"),
                    cst.CompFor(
                        target=cst.Name("b"),
                        iter=cst.Name("c"),
                        ifs=[cst.CompIf(cst.Name("d"))],
                        inner_for_in=cst.CompFor(
                            target=cst.Name("e"), iter=cst.Name("f")
                        ),
                    ),
                ),
                "code": "(a for b in c if d for e in f)",
                "parser": parse_expression,
            },
            # custom whitespace: distinct run lengths (2, 3, 4, 5 spaces) so a
            # mix-up between the four CompFor whitespace fields is detected.
            {
                "node": cst.GeneratorExp(
                    cst.Name("a"),
                    cst.CompFor(
                        target=cst.Name("b"),
                        iter=cst.Name("c"),
                        ifs=[
                            cst.CompIf(
                                cst.Name("d"),
                                whitespace_before=cst.SimpleWhitespace("\t"),
                                whitespace_before_test=cst.SimpleWhitespace("\t\t"),
                            )
                        ],
                        whitespace_before=cst.SimpleWhitespace("  "),
                        whitespace_after_for=cst.SimpleWhitespace("   "),
                        whitespace_before_in=cst.SimpleWhitespace("    "),
                        whitespace_after_in=cst.SimpleWhitespace("     "),
                    ),
                    lpar=[cst.LeftParen(whitespace_after=cst.SimpleWhitespace("\f"))],
                    rpar=[
                        cst.RightParen(whitespace_before=cst.SimpleWhitespace("\f\f"))
                    ],
                ),
                "code": "(\fa  for   b    in     c\tif\t\td\f\f)",
                "parser": parse_expression,
                "expected_position": CodeRange((1, 2), (1, 30)),
            },
            # custom whitespace around ListComp's brackets
            {
                "node": cst.ListComp(
                    cst.Name("a"),
                    cst.CompFor(target=cst.Name("b"), iter=cst.Name("c")),
                    lbracket=cst.LeftSquareBracket(
                        whitespace_after=cst.SimpleWhitespace("\t")
                    ),
                    rbracket=cst.RightSquareBracket(
                        whitespace_before=cst.SimpleWhitespace("\t\t")
                    ),
                    lpar=[cst.LeftParen(whitespace_after=cst.SimpleWhitespace("\f"))],
                    rpar=[
                        cst.RightParen(whitespace_before=cst.SimpleWhitespace("\f\f"))
                    ],
                ),
                "code": "(\f[\ta for b in c\t\t]\f\f)",
                "parser": parse_expression,
                "expected_position": CodeRange((1, 2), (1, 19)),
            },
            # custom whitespace around SetComp's braces
            {
                "node": cst.SetComp(
                    cst.Name("a"),
                    cst.CompFor(target=cst.Name("b"), iter=cst.Name("c")),
                    lbrace=cst.LeftCurlyBrace(
                        whitespace_after=cst.SimpleWhitespace("\t")
                    ),
                    rbrace=cst.RightCurlyBrace(
                        whitespace_before=cst.SimpleWhitespace("\t\t")
                    ),
                    lpar=[cst.LeftParen(whitespace_after=cst.SimpleWhitespace("\f"))],
                    rpar=[
                        cst.RightParen(whitespace_before=cst.SimpleWhitespace("\f\f"))
                    ],
                ),
                "code": "(\f{\ta for b in c\t\t}\f\f)",
                "parser": parse_expression,
            },
            # no whitespace between elements
            {
                "node": cst.GeneratorExp(
                    cst.Name("a", lpar=[cst.LeftParen()], rpar=[cst.RightParen()]),
                    cst.CompFor(
                        target=cst.Name(
                            "b", lpar=[cst.LeftParen()], rpar=[cst.RightParen()]
                        ),
                        iter=cst.Name(
                            "c", lpar=[cst.LeftParen()], rpar=[cst.RightParen()]
                        ),
                        ifs=[
                            cst.CompIf(
                                cst.Name(
                                    "d", lpar=[cst.LeftParen()], rpar=[cst.RightParen()]
                                ),
                                whitespace_before=cst.SimpleWhitespace(""),
                                whitespace_before_test=cst.SimpleWhitespace(""),
                            )
                        ],
                        inner_for_in=cst.CompFor(
                            target=cst.Name(
                                "e", lpar=[cst.LeftParen()], rpar=[cst.RightParen()]
                            ),
                            iter=cst.Name(
                                "f", lpar=[cst.LeftParen()], rpar=[cst.RightParen()]
                            ),
                            whitespace_before=cst.SimpleWhitespace(""),
                            whitespace_after_for=cst.SimpleWhitespace(""),
                            whitespace_before_in=cst.SimpleWhitespace(""),
                            whitespace_after_in=cst.SimpleWhitespace(""),
                        ),
                        whitespace_before=cst.SimpleWhitespace(""),
                        whitespace_after_for=cst.SimpleWhitespace(""),
                        whitespace_before_in=cst.SimpleWhitespace(""),
                        whitespace_after_in=cst.SimpleWhitespace(""),
                    ),
                    lpar=[cst.LeftParen()],
                    rpar=[cst.RightParen()],
                ),
                "code": "((a)for(b)in(c)if(d)for(e)in(f))",
                "parser": parse_expression,
                "expected_position": CodeRange((1, 1), (1, 31)),
            },
            # no whitespace before/after GeneratorExp is valid
            {
                "node": cst.Comparison(
                    cst.GeneratorExp(
                        cst.Name("a"),
                        cst.CompFor(target=cst.Name("b"), iter=cst.Name("c")),
                    ),
                    [
                        cst.ComparisonTarget(
                            cst.Is(
                                whitespace_before=cst.SimpleWhitespace(""),
                                whitespace_after=cst.SimpleWhitespace(""),
                            ),
                            cst.GeneratorExp(
                                cst.Name("d"),
                                cst.CompFor(target=cst.Name("e"), iter=cst.Name("f")),
                            ),
                        )
                    ],
                ),
                "code": "(a for b in c)is(d for e in f)",
                "parser": parse_expression,
            },
            # no whitespace before/after ListComp is valid
            {
                "node": cst.Comparison(
                    cst.ListComp(
                        cst.Name("a"),
                        cst.CompFor(target=cst.Name("b"), iter=cst.Name("c")),
                    ),
                    [
                        cst.ComparisonTarget(
                            cst.Is(
                                whitespace_before=cst.SimpleWhitespace(""),
                                whitespace_after=cst.SimpleWhitespace(""),
                            ),
                            cst.ListComp(
                                cst.Name("d"),
                                cst.CompFor(target=cst.Name("e"), iter=cst.Name("f")),
                            ),
                        )
                    ],
                ),
                "code": "[a for b in c]is[d for e in f]",
                "parser": parse_expression,
            },
            # no whitespace before/after SetComp is valid
            {
                "node": cst.Comparison(
                    cst.SetComp(
                        cst.Name("a"),
                        cst.CompFor(target=cst.Name("b"), iter=cst.Name("c")),
                    ),
                    [
                        cst.ComparisonTarget(
                            cst.Is(
                                whitespace_before=cst.SimpleWhitespace(""),
                                whitespace_after=cst.SimpleWhitespace(""),
                            ),
                            cst.SetComp(
                                cst.Name("d"),
                                cst.CompFor(target=cst.Name("e"), iter=cst.Name("f")),
                            ),
                        )
                    ],
                ),
                "code": "{a for b in c}is{d for e in f}",
                "parser": parse_expression,
            },
        ]
    )
    def test_valid(self, **kwargs: Any) -> None:
        self.validate_node(**kwargs)

    # Each case is (lazy node factory, regex the validation error must match).
    @data_provider(
        (
            (
                lambda: cst.GeneratorExp(
                    cst.Name("a"),
                    cst.CompFor(target=cst.Name("b"), iter=cst.Name("c")),
                    lpar=[cst.LeftParen(), cst.LeftParen()],
                    rpar=[cst.RightParen()],
                ),
                "unbalanced parens",
            ),
            (
                lambda: cst.ListComp(
                    cst.Name("a"),
                    cst.CompFor(target=cst.Name("b"), iter=cst.Name("c")),
                    lpar=[cst.LeftParen(), cst.LeftParen()],
                    rpar=[cst.RightParen()],
                ),
                "unbalanced parens",
            ),
            (
                lambda: cst.SetComp(
                    cst.Name("a"),
                    cst.CompFor(target=cst.Name("b"), iter=cst.Name("c")),
                    lpar=[cst.LeftParen(), cst.LeftParen()],
                    rpar=[cst.RightParen()],
                ),
                "unbalanced parens",
            ),
            (
                lambda: cst.GeneratorExp(
                    cst.Name("a"),
                    cst.CompFor(
                        target=cst.Name("b"),
                        iter=cst.Name("c"),
                        whitespace_before=cst.SimpleWhitespace(""),
                    ),
                ),
                "Must have at least one space before 'for' keyword.",
            ),
            (
                lambda: cst.GeneratorExp(
                    cst.Name("a"),
                    cst.CompFor(
                        target=cst.Name("b"),
                        iter=cst.Name("c"),
                        asynchronous=cst.Asynchronous(),
                        whitespace_before=cst.SimpleWhitespace(""),
                    ),
                ),
                "Must have at least one space before 'async' keyword.",
            ),
            (
                lambda: cst.GeneratorExp(
                    cst.Name("a"),
                    cst.CompFor(
                        target=cst.Name("b"),
                        iter=cst.Name("c"),
                        whitespace_after_for=cst.SimpleWhitespace(""),
                    ),
                ),
                "Must have at least one space after 'for' keyword.",
            ),
            (
                lambda: cst.GeneratorExp(
                    cst.Name("a"),
                    cst.CompFor(
                        target=cst.Name("b"),
                        iter=cst.Name("c"),
                        whitespace_before_in=cst.SimpleWhitespace(""),
                    ),
                ),
                "Must have at least one space before 'in' keyword.",
            ),
            (
                lambda: cst.GeneratorExp(
                    cst.Name("a"),
                    cst.CompFor(
                        target=cst.Name("b"),
                        iter=cst.Name("c"),
                        whitespace_after_in=cst.SimpleWhitespace(""),
                    ),
                ),
                "Must have at least one space after 'in' keyword.",
            ),
            (
                lambda: cst.GeneratorExp(
                    cst.Name("a"),
                    cst.CompFor(
                        target=cst.Name("b"),
                        iter=cst.Name("c"),
                        ifs=[
                            cst.CompIf(
                                cst.Name("d"),
                                whitespace_before=cst.SimpleWhitespace(""),
                            )
                        ],
                    ),
                ),
                "Must have at least one space before 'if' keyword.",
            ),
            (
                lambda: cst.GeneratorExp(
                    cst.Name("a"),
                    cst.CompFor(
                        target=cst.Name("b"),
                        iter=cst.Name("c"),
                        ifs=[
                            cst.CompIf(
                                cst.Name("d"),
                                whitespace_before_test=cst.SimpleWhitespace(""),
                            )
                        ],
                    ),
                ),
                "Must have at least one space after 'if' keyword.",
            ),
            (
                lambda: cst.GeneratorExp(
                    cst.Name("a"),
                    cst.CompFor(
                        target=cst.Name("b"),
                        iter=cst.Name("c"),
                        inner_for_in=cst.CompFor(
                            target=cst.Name("d"),
                            iter=cst.Name("e"),
                            whitespace_before=cst.SimpleWhitespace(""),
                        ),
                    ),
                ),
                "Must have at least one space before 'for' keyword.",
            ),
            (
                lambda: cst.GeneratorExp(
                    cst.Name("a"),
                    cst.CompFor(
                        target=cst.Name("b"),
                        iter=cst.Name("c"),
                        inner_for_in=cst.CompFor(
                            target=cst.Name("d"),
                            iter=cst.Name("e"),
                            asynchronous=cst.Asynchronous(),
                            whitespace_before=cst.SimpleWhitespace(""),
                        ),
                    ),
                ),
                "Must have at least one space before 'async' keyword.",
            ),
        )
    )
    def test_invalid(
        self, get_node: Callable[[], cst.CSTNode], expected_re: str
    ) -> None:
        self.assert_invalid(get_node, expected_re)
class WhileTest(CSTNodeTest):
    """Rendering, parsing and validation tests for ``cst.While``.

    ``expected_position`` columns are character offsets, so indentation and
    whitespace in the expected ``code`` must match the node exactly (the
    dummy block and the default IndentedBlock both indent by four spaces).
    """

    @data_provider(
        (
            # Simple while block
            # pyre-fixme[6]: Incompatible parameter type
            {
                "node": cst.While(
                    cst.Call(cst.Name("iter")), cst.SimpleStatementSuite((cst.Pass(),))
                ),
                "code": "while iter(): pass\n",
                "parser": parse_statement,
            },
            # While block with else
            {
                "node": cst.While(
                    cst.Call(cst.Name("iter")),
                    cst.SimpleStatementSuite((cst.Pass(),)),
                    cst.Else(cst.SimpleStatementSuite((cst.Pass(),))),
                ),
                "code": "while iter(): pass\nelse: pass\n",
                "parser": parse_statement,
            },
            # indentation
            {
                "node": DummyIndentedBlock(
                    "    ",
                    cst.While(
                        cst.Call(cst.Name("iter")),
                        cst.SimpleStatementSuite((cst.Pass(),)),
                    ),
                ),
                "code": "    while iter(): pass\n",
                "parser": None,
                "expected_position": CodeRange((1, 4), (1, 22)),
            },
            # while an indented body
            {
                "node": DummyIndentedBlock(
                    "    ",
                    cst.While(
                        cst.Call(cst.Name("iter")),
                        cst.IndentedBlock((cst.SimpleStatementLine((cst.Pass(),)),)),
                    ),
                ),
                "code": "    while iter():\n        pass\n",
                "parser": None,
                "expected_position": CodeRange((1, 4), (2, 12)),
            },
            # leading_lines
            {
                "node": cst.While(
                    cst.Call(cst.Name("iter")),
                    cst.IndentedBlock((cst.SimpleStatementLine((cst.Pass(),)),)),
                    leading_lines=(
                        cst.EmptyLine(comment=cst.Comment("# leading comment")),
                    ),
                ),
                "code": "# leading comment\nwhile iter():\n    pass\n",
                "parser": parse_statement,
                "expected_position": CodeRange((2, 0), (3, 8)),
            },
            {
                "node": cst.While(
                    cst.Call(cst.Name("iter")),
                    cst.IndentedBlock((cst.SimpleStatementLine((cst.Pass(),)),)),
                    cst.Else(
                        cst.IndentedBlock((cst.SimpleStatementLine((cst.Pass(),)),)),
                        leading_lines=(
                            cst.EmptyLine(comment=cst.Comment("# else comment")),
                        ),
                    ),
                    leading_lines=(
                        cst.EmptyLine(comment=cst.Comment("# leading comment")),
                    ),
                ),
                "code": "# leading comment\nwhile iter():\n    pass\n# else comment\nelse:\n    pass\n",
                "parser": None,
                "expected_position": CodeRange((2, 0), (6, 8)),
            },
            # Weird spacing rules: no space is needed when the test is
            # parenthesized.
            {
                "node": cst.While(
                    cst.Call(
                        cst.Name("iter"),
                        lpar=(cst.LeftParen(),),
                        rpar=(cst.RightParen(),),
                    ),
                    cst.SimpleStatementSuite((cst.Pass(),)),
                    whitespace_after_while=cst.SimpleWhitespace(""),
                ),
                "code": "while(iter()): pass\n",
                "parser": parse_statement,
                "expected_position": CodeRange((1, 0), (1, 19)),
            },
            # Whitespace: double spaces so non-default whitespace is exercised
            # ("while  iter()  : pass" is 21 characters wide).
            {
                "node": cst.While(
                    cst.Call(cst.Name("iter")),
                    cst.SimpleStatementSuite((cst.Pass(),)),
                    whitespace_after_while=cst.SimpleWhitespace("  "),
                    whitespace_before_colon=cst.SimpleWhitespace("  "),
                ),
                "code": "while  iter()  : pass\n",
                "parser": parse_statement,
                "expected_position": CodeRange((1, 0), (1, 21)),
            },
        )
    )
    def test_valid(self, **kwargs: Any) -> None:
        self.validate_node(**kwargs)

    # Each case builds an invalid node lazily and names the regex that the
    # resulting validation error message is expected to match.
    @data_provider(
        (
            {
                "get_node": lambda: cst.While(
                    cst.Call(cst.Name("iter")),
                    cst.SimpleStatementSuite((cst.Pass(),)),
                    whitespace_after_while=cst.SimpleWhitespace(""),
                ),
                "expected_re": "Must have at least one space after 'while' keyword",
            },
        )
    )
    def test_invalid(self, **kwargs: Any) -> None:
        self.assert_invalid(**kwargs)
def _assignment_statement(name, value) -> cst.SimpleStatementLine:
    # Build the single-line statement ``<name> = <value>``.
    return cst.SimpleStatementLine(
        body=[
            cst.Assign(
                targets=[cst.AssignTarget(target=cst.Name(value=name))],
                value=value,
            )
        ]
    )


def compile_graph(graph: GraphicalModel, namespace: dict, fn_name):
    """Compile MCX's graph into a python (executable) function.

    The graph is translated into a libcst function definition named
    ``fn_name``, rendered to source code, and executed with ``exec``.

    Args:
        graph: The graphical model to compile.
        namespace: Globals dictionary used to execute the generated code. It
            is mutated: after the call it contains ``fn_name`` (and anything
            else the generated code defines at module level).
        fn_name: Name of the generated function.

    Returns:
        A ``(fn, code)`` tuple: the compiled function object and the Python
        source it was compiled from.
    """
    # The generated function's parameters are ordered as follows:
    #  1. (samplers only) rng_key;
    #  2. the model's arguments, i.e. placeholders without a default value;
    #  3. (logpdf only) random variables, taken in reverse placeholder order;
    #  4. the model's keyword arguments, i.e. placeholders with a default.
    # Keyword arguments come last, presumably because Python forbids a
    # parameter without a default after one that has a default.
    maybe_rng_key = [
        compile_placeholder(node, graph)
        for node in graph.placeholders
        if node.name == "rng_key"
    ]
    maybe_random_variables = [
        compile_placeholder(node, graph)
        for node in reversed(list(graph.placeholders))
        if node.is_random_variable
    ]
    model_args = [
        compile_placeholder(node, graph)
        for node in graph.placeholders
        if not node.is_random_variable
        and node.name != "rng_key"
        and not node.has_default
    ]
    model_kwargs = [
        compile_placeholder(node, graph)
        for node in graph.placeholders
        if not node.is_random_variable
        and node.name != "rng_key"
        and node.has_default
    ]
    args = maybe_rng_key + model_args + maybe_random_variables + model_kwargs

    # Every named Constant or Op becomes an assignment, and every returned
    # node becomes a return statement. The topological sort ensures a name is
    # assigned before any statement that depends on it.
    stmts = []
    returns = []
    for node in nx.topological_sort(graph):
        if node.name is None:
            continue

        if isinstance(node, Constant):
            stmts.append(_assignment_statement(node.name, node.cst_generator()))
        if isinstance(node, Op):
            stmts.append(_assignment_statement(node.name, compile_op(node, graph)))

        if node.is_returned:
            returns.append(
                cst.SimpleStatementLine(
                    body=[cst.Return(value=cst.Name(value=node.name))]
                )
            )

    # Assemble the function's CST; return statements follow all assignments.
    ast_fn = cst.Module(
        body=[
            cst.FunctionDef(
                name=cst.Name(value=fn_name),
                params=cst.Parameters(params=args),
                body=cst.IndentedBlock(body=stmts + returns),
            )
        ]
    )

    code = ast_fn.code
    # The source is generated from the graph above; executing it defines
    # ``fn_name`` inside ``namespace``.
    exec(code, namespace)
    fn = namespace[fn_name]

    return fn, code