class ExperimentalReentrantCodegenProviderTest(UnitTest):
    """Tests for the partial-codegen objects resolved from
    ``ExperimentalReentrantCodegenProvider``.

    Each case parses ``old_module``, resolves the provider for one statement
    in it, and checks both the regenerated original source and the source
    produced when that statement is swapped for ``new_node``.
    """

    @data_provider(
        {
            "simple_top_level_statement": {
                "old_module": (
                    """\
                    import math
                    c = math.sqrt(a*a + b*b)
                    """
                ),
                "new_module": (
                    """\
                    import math
                    c = math.hypot(a, b)
                    """
                ),
                "old_node": lambda m: m.body[1],
                "new_node": cst.parse_statement("c = math.hypot(a, b)"),
            },
            "replacement_inside_block": {
                "old_module": (
                    """\
                    import math
                    def do_math(a, b):
                        c = math.sqrt(a*a + b*b)
                        return c
                    """
                ),
                "new_module": (
                    """\
                    import math
                    def do_math(a, b):
                        c = math.hypot(a, b)
                        return c
                    """
                ),
                "old_node": lambda m: m.body[1].body.body[0],
                "new_node": cst.parse_statement("c = math.hypot(a, b)"),
            },
            "missing_trailing_newline": {
                "old_module": "old_fn()",  # this module has no trailing newline
                "new_module": "new_fn()",
                "old_node": lambda m: m.body[0],
                "new_node": cst.parse_statement("new_fn()\n"),
            },
            "nested_blocks_with_missing_trailing_newline": {
                "old_module": (
                    """\
                    if outer:
                        if inner:
                            old_fn()"""  # this module has no trailing newline
                ),
                "new_module": (
                    """\
                    if outer:
                        if inner:
                            new_fn()"""
                ),
                "old_node": lambda m: m.body[0].body.body[0].body.body[0],
                "new_node": cst.parse_statement("new_fn()\n"),
            },
        }
    )
    def test_provider(
        self,
        old_module: str,
        new_module: str,
        old_node: Callable[[cst.Module], cst.CSTNode],
        new_node: cst.BaseStatement,
    ) -> None:
        """Replace one statement and compare the regenerated module text.

        ``old_node`` selects the statement to replace from the parsed module;
        ``new_node`` is the statement substituted for it.
        """
        # The fixtures are indented to match this file; strip that indentation
        # before parsing and comparing.
        old_module = dedent(old_module)
        new_module = dedent(new_module)

        mw = MetadataWrapper(cst.parse_module(old_module))
        codegen_partial = mw.resolve(ExperimentalReentrantCodegenProvider)[
            old_node(mw.module)
        ]

        self.assertEqual(codegen_partial.get_original_module_code(), old_module)
        self.assertEqual(
            codegen_partial.get_modified_module_code(new_node), new_module
        )

    def test_byte_conversion(
        self,
    ) -> None:
        """Check the ``*_bytes`` accessors encode with the configured encoding."""
        module_bytes = "fn()\n".encode("utf-16")
        mw = MetadataWrapper(
            cst.parse_module("fn()\n", cst.PartialParserConfig(encoding="utf-16"))
        )
        codegen_partial = mw.resolve(ExperimentalReentrantCodegenProvider)[
            mw.module.body[0]
        ]

        self.assertEqual(codegen_partial.get_original_module_bytes(), module_bytes)
        self.assertEqual(
            codegen_partial.get_modified_module_bytes(
                cst.parse_statement("fn2()\n")
            ),
            "fn2()\n".encode("utf-16"),
        )
def leave_Module(self, original_node: libcst.Module,
                 updated_node: libcst.Module) -> libcst.Module:
    """Insert the requested import statements into the module body.

    Reads ``self.module_imports`` (``import x``), ``self.module_aliases``
    (``import x as y``), ``self.module_mapping`` (``from x import a``) and
    ``self.alias_mapping`` (``from x import a as b``). ``__future__`` imports
    are emitted first; all other new imports go at the split point returned
    by ``self._split_module``.

    Returns ``updated_node`` unchanged when nothing was requested, otherwise
    a copy of it with a new ``body``.
    """
    # Don't try to modify if we have nothing to do
    if not (self.module_imports or self.module_mapping
            or self.module_aliases or self.alias_mapping):
        return updated_node

    # First, find the insertion point for imports
    statements_before_imports, statements_after_imports = self._split_module(
        original_node, updated_node)

    # Make sure there's at least one empty line before the first non-import
    statements_after_imports = self._insert_empty_line(
        statements_after_imports)

    # Map each module to the (object, alias) pairs to import from it; the
    # alias is None for a plain ``from module import object``.
    module_and_alias_mapping = defaultdict(list)
    for module, aliases in self.alias_mapping.items():
        module_and_alias_mapping[module].extend(aliases)
    for module, imports in self.module_mapping.items():
        module_and_alias_mapping[module].extend(
            (obj, None) for obj in imports)
    # Sort with an explicit key: comparing the raw tuples raises TypeError
    # when one object appears both with and without an alias (None vs str).
    # Treating None as "" keeps the unaliased entry first and leaves every
    # other ordering as it was.
    module_and_alias_mapping = {
        module: sorted(aliases, key=lambda pair: (pair[0], pair[1] or ""))
        for module, aliases in module_and_alias_mapping.items()
    }

    parse_config = updated_node.config_for_parsing

    def from_import(module, aliases):
        # Build one ``from module import a, b as c`` statement.
        names = ", ".join(
            obj if alias is None else f"{obj} as {alias}"
            for (obj, alias) in aliases)
        return parse_statement(f"from {module} import " + names,
                               config=parse_config)

    # Now, add all of the imports we need!
    return updated_node.with_changes(body=(
        # __future__ imports are placed ahead of everything else.
        *[
            from_import(module, aliases)
            for module, aliases in module_and_alias_mapping.items()
            if module == "__future__"
        ],
        *statements_before_imports,
        *[
            parse_statement(f"import {module}", config=parse_config)
            for module in sorted(self.module_imports)
        ],
        *[
            parse_statement(f"import {module} as {asname}",
                            config=parse_config)
            for (module, asname) in self.module_aliases.items()
        ],
        *[
            from_import(module, aliases)
            for module, aliases in module_and_alias_mapping.items()
            if module != "__future__"
        ],
        *statements_after_imports,
    ))
class ParseErrorsTest(UnitTest):
    """Check that invalid source raises ``cst.ParserSyntaxError`` and, for the
    non-native parser, that ``str()`` of the error matches the expected text.

    Each expected message has the form: a ``Syntax Error @ line:column.``
    header, the message, a blank line, the offending source line, and a caret
    line. NOTE(review): the caret columns below were reconstructed from the
    ``line:column`` in each message; confirm against the pure-Python parser.
    """

    @data_provider(
        {
            # _wrapped_tokenize raises these exceptions
            "wrapped_tokenize__invalid_token": (
                lambda: cst.parse_module("'"),
                dedent(
                    """
                    Syntax Error @ 1:1.
                    "'" is not a valid token.

                    '
                    ^
                    """
                ).strip(),
            ),
            "wrapped_tokenize__expected_dedent": (
                lambda: cst.parse_module("if False:\n    pass\n  pass"),
                dedent(
                    """
                    Syntax Error @ 3:1.
                    Inconsistent indentation. Expected a dedent.

                      pass
                    ^
                    """
                ).strip(),
            ),
            "wrapped_tokenize__mismatched_braces": (
                lambda: cst.parse_module("abcd)"),
                dedent(
                    """
                    Syntax Error @ 1:5.
                    Encountered a closing brace without a matching opening brace.

                    abcd)
                        ^
                    """
                ).strip(),
            ),
            # _base_parser raises these exceptions
            "base_parser__unexpected_indent": (
                lambda: cst.parse_module("    abcd"),
                dedent(
                    """
                    Syntax Error @ 1:5.
                    Incomplete input. Unexpectedly encountered an indent.

                        abcd
                        ^
                    """
                ).strip(),
            ),
            "base_parser__unexpected_dedent": (
                lambda: cst.parse_module("if False:\n    (el for el\n"),
                dedent(
                    """
                    Syntax Error @ 3:1.
                    Incomplete input. Encountered a dedent, but expected 'in'.

                        (el for el
                                  ^
                    """
                ).strip(),
            ),
            "base_parser__multiple_possibilities": (
                lambda: cst.parse_module("try: pass"),
                dedent(
                    """
                    Syntax Error @ 2:1.
                    Incomplete input. Encountered end of file (EOF), but expected 'except', or 'finally'.

                    try: pass
                             ^
                    """
                ).strip(),
            ),
            # conversion functions raise these exceptions.
            # `_base_parser` is responsible for attaching location information.
            "convert_nonterminal__dict_unpacking": (
                lambda: cst.parse_expression("{**el for el in []}"),
                dedent(
                    """
                    Syntax Error @ 1:19.
                    dict unpacking cannot be used in dict comprehension

                    {**el for el in []}
                                      ^
                    """
                ).strip(),
            ),
            "convert_nonterminal__arglist_non_default_after_default": (
                lambda: cst.parse_statement("def fn(first=None, second): ..."),
                dedent(
                    """
                    Syntax Error @ 1:26.
                    Cannot have a non-default argument following a default argument.

                    def fn(first=None, second): ...
                                             ^
                    """
                ).strip(),
            ),
            "convert_nonterminal__arglist_trailing_param_star_without_comma": (
                lambda: cst.parse_statement("def fn(abc, *): ..."),
                dedent(
                    """
                    Syntax Error @ 1:14.
                    Named (keyword) arguments must follow a bare *.

                    def fn(abc, *): ...
                                 ^
                    """
                ).strip(),
            ),
            "convert_nonterminal__arglist_trailing_param_star_with_comma": (
                lambda: cst.parse_statement("def fn(abc, *,): ..."),
                dedent(
                    """
                    Syntax Error @ 1:15.
                    Named (keyword) arguments must follow a bare *.

                    def fn(abc, *,): ...
                                  ^
                    """
                ).strip(),
            ),
            "convert_nonterminal__class_arg_positional_after_keyword": (
                lambda: cst.parse_statement("class Cls(first=None, second): ..."),
                dedent(
                    """
                    Syntax Error @ 2:1.
                    Positional argument follows keyword argument.

                    class Cls(first=None, second): ...
                                                      ^
                    """
                ).strip(),
            ),
            "convert_nonterminal__class_arg_positional_expansion_after_keyword": (
                lambda: cst.parse_statement("class Cls(first=None, *second): ..."),
                dedent(
                    """
                    Syntax Error @ 2:1.
                    Positional argument follows keyword argument.

                    class Cls(first=None, *second): ...
                                                       ^
                    """
                ).strip(),
            ),
        }
    )
    def test_parser_syntax_error_str(
        self, parse_fn: Callable[[], object], expected: str
    ) -> None:
        """``parse_fn`` must raise ``ParserSyntaxError``; compare its text."""
        with self.assertRaises(cst.ParserSyntaxError) as cm:
            parse_fn()
        # The exact message text is only asserted when the native parser is
        # not in use.
        if not is_native():
            self.assertEqual(str(cm.exception), expected)

    def test_native_fallible_into_py(self) -> None:
        """A validation failure during parsing must surface as a syntax error."""
        with patch("libcst._nodes.expression.Name._validate") as await_validate:
            # Make every Name node fail validation while parsing.
            await_validate.side_effect = CSTValidationError("validate is broken")
            with self.assertRaises(Exception) as e:
                cst.parse_module("foo")
            self.assertIsInstance(
                e.exception, (SyntaxError, cst.ParserSyntaxError)
            )
class ForTest(CSTNodeTest):
    """Round-trip and validation tests for ``cst.For``.

    Valid cases pair a node with the code it corresponds to (and optionally a
    parser and an expected position); invalid cases pair a node factory with
    a regex for the expected validation error.
    """

    @data_provider(
        (
            # Simple for block
            {
                "node": cst.For(
                    cst.Name("target"),
                    cst.Call(cst.Name("iter")),
                    cst.SimpleStatementSuite((cst.Pass(),)),
                ),
                "code": "for target in iter(): pass\n",
                "parser": parse_statement,
            },
            # Simple async for block
            {
                "node": cst.For(
                    cst.Name("target"),
                    cst.Call(cst.Name("iter")),
                    cst.SimpleStatementSuite((cst.Pass(),)),
                    asynchronous=cst.Asynchronous(),
                ),
                "code": "async for target in iter(): pass\n",
                "parser": lambda code: parse_statement(
                    code, config=PartialParserConfig(python_version="3.7")
                ),
            },
            # Python 3.6 async for block
            {
                "node": cst.FunctionDef(
                    cst.Name("foo"),
                    cst.Parameters(),
                    cst.IndentedBlock(
                        (
                            cst.For(
                                cst.Name("target"),
                                cst.Call(cst.Name("iter")),
                                cst.SimpleStatementSuite((cst.Pass(),)),
                                asynchronous=cst.Asynchronous(),
                            ),
                        )
                    ),
                    asynchronous=cst.Asynchronous(),
                ),
                "code": "async def foo():\n    async for target in iter(): pass\n",
                "parser": lambda code: parse_statement(
                    code, config=PartialParserConfig(python_version="3.6")
                ),
            },
            # For block with else
            {
                "node": cst.For(
                    cst.Name("target"),
                    cst.Call(cst.Name("iter")),
                    cst.SimpleStatementSuite((cst.Pass(),)),
                    cst.Else(cst.SimpleStatementSuite((cst.Pass(),))),
                ),
                "code": "for target in iter(): pass\nelse: pass\n",
                "parser": parse_statement,
            },
            # indentation
            {
                "node": DummyIndentedBlock(
                    "    ",
                    cst.For(
                        cst.Name("target"),
                        cst.Call(cst.Name("iter")),
                        cst.SimpleStatementSuite((cst.Pass(),)),
                    ),
                ),
                "code": "    for target in iter(): pass\n",
                "parser": None,
            },
            # for an indented body
            {
                "node": DummyIndentedBlock(
                    "    ",
                    cst.For(
                        cst.Name("target"),
                        cst.Call(cst.Name("iter")),
                        cst.IndentedBlock((cst.SimpleStatementLine((cst.Pass(),)),)),
                    ),
                ),
                "code": "    for target in iter():\n        pass\n",
                "parser": None,
                # `for` starts at column 4; `pass` ends at column 12 on line 2.
                "expected_position": CodeRange((1, 4), (2, 12)),
            },
            # leading_lines
            {
                "node": cst.For(
                    cst.Name("target"),
                    cst.Call(cst.Name("iter")),
                    cst.IndentedBlock((cst.SimpleStatementLine((cst.Pass(),)),)),
                    cst.Else(
                        cst.IndentedBlock((cst.SimpleStatementLine((cst.Pass(),)),)),
                        leading_lines=(
                            cst.EmptyLine(comment=cst.Comment("# else comment")),
                        ),
                    ),
                    leading_lines=(
                        cst.EmptyLine(comment=cst.Comment("# leading comment")),
                    ),
                ),
                "code": "# leading comment\nfor target in iter():\n    pass\n# else comment\nelse:\n    pass\n",
                "parser": None,
                # The position starts on line 2, after the leading comment line.
                "expected_position": CodeRange((2, 0), (6, 8)),
            },
            # Weird spacing rules
            {
                "node": cst.For(
                    cst.Name(
                        "target", lpar=(cst.LeftParen(),), rpar=(cst.RightParen(),)
                    ),
                    cst.Call(
                        cst.Name("iter"),
                        lpar=(cst.LeftParen(),),
                        rpar=(cst.RightParen(),),
                    ),
                    cst.SimpleStatementSuite((cst.Pass(),)),
                    whitespace_after_for=cst.SimpleWhitespace(""),
                    whitespace_before_in=cst.SimpleWhitespace(""),
                    whitespace_after_in=cst.SimpleWhitespace(""),
                ),
                "code": "for(target)in(iter()): pass\n",
                "parser": parse_statement,
                "expected_position": CodeRange((1, 0), (1, 27)),
            },
            # Whitespace
            {
                "node": cst.For(
                    cst.Name("target"),
                    cst.Call(cst.Name("iter")),
                    cst.SimpleStatementSuite((cst.Pass(),)),
                    whitespace_after_for=cst.SimpleWhitespace("  "),
                    whitespace_before_in=cst.SimpleWhitespace("  "),
                    whitespace_after_in=cst.SimpleWhitespace("  "),
                    whitespace_before_colon=cst.SimpleWhitespace("  "),
                ),
                # Four two-space runs make the statement 31 columns wide.
                "code": "for  target  in  iter()  : pass\n",
                "parser": parse_statement,
                "expected_position": CodeRange((1, 0), (1, 31)),
            },
        )
    )
    def test_valid(self, **kwargs: Any) -> None:
        """Forward each valid case to ``validate_node``."""
        self.validate_node(**kwargs)

    @data_provider(
        (
            {
                "get_node": lambda: cst.For(
                    cst.Name("target"),
                    cst.Call(cst.Name("iter")),
                    cst.SimpleStatementSuite((cst.Pass(),)),
                    whitespace_after_for=cst.SimpleWhitespace(""),
                ),
                "expected_re": "Must have at least one space after 'for' keyword",
            },
            {
                "get_node": lambda: cst.For(
                    cst.Name("target"),
                    cst.Call(cst.Name("iter")),
                    cst.SimpleStatementSuite((cst.Pass(),)),
                    whitespace_before_in=cst.SimpleWhitespace(""),
                ),
                "expected_re": "Must have at least one space before 'in' keyword",
            },
            {
                "get_node": lambda: cst.For(
                    cst.Name("target"),
                    cst.Call(cst.Name("iter")),
                    cst.SimpleStatementSuite((cst.Pass(),)),
                    whitespace_after_in=cst.SimpleWhitespace(""),
                ),
                "expected_re": "Must have at least one space after 'in' keyword",
            },
        )
    )
    def test_invalid(self, **kwargs: Any) -> None:
        """Forward each invalid case to ``assert_invalid``."""
        self.assert_invalid(**kwargs)
def test_valid(self, **kwargs: Any) -> None:
    """Validate a node using a parser that yields the first small statement
    of the parsed ``SimpleStatementLine``."""

    def first_small_statement(code):
        # parse_statement returns the whole line; the node under test is the
        # first entry of its body.
        line = ensure_type(parse_statement(code), cst.SimpleStatementLine)
        return line.body[0]

    self.validate_node(parser=first_small_statement, **kwargs)
def inner(code: str) -> cst.BaseStatement:
    """Parse ``code`` as a statement using a parser config built from the
    enclosing ``config`` mapping."""
    parser_config = cst.PartialParserConfig(**config)
    return cst.parse_statement(code, config=parser_config)
def _assert_parser(code: str) -> cst.Assert:
    """Parse ``code`` and return its first small statement, checked to be a
    ``cst.Assert``."""
    statement_line = ensure_type(parse_statement(code), cst.SimpleStatementLine)
    first_statement = statement_line.body[0]
    return ensure_type(first_statement, cst.Assert)
class SimpleCompTest(CSTNodeTest):
    """Round-trip and validation tests for ``GeneratorExp``, ``ListComp`` and
    ``SetComp`` together with their ``CompFor`` / ``CompIf`` clauses.

    Valid cases pair a node with its code (and optionally a parser and an
    expected position); invalid cases pair a node factory with a regex for
    the expected validation error.
    """

    @data_provider(
        [
            # simple GeneratorExp
            {
                "node": cst.GeneratorExp(
                    cst.Name("a"),
                    cst.CompFor(target=cst.Name("b"), iter=cst.Name("c")),
                ),
                "code": "(a for b in c)",
                "parser": parse_expression,
                "expected_position": CodeRange((1, 1), (1, 13)),
            },
            # simple ListComp
            {
                "node": cst.ListComp(
                    cst.Name("a"),
                    cst.CompFor(target=cst.Name("b"), iter=cst.Name("c")),
                ),
                "code": "[a for b in c]",
                "parser": parse_expression,
                "expected_position": CodeRange((1, 0), (1, 14)),
            },
            # simple SetComp
            {
                "node": cst.SetComp(
                    cst.Name("a"),
                    cst.CompFor(target=cst.Name("b"), iter=cst.Name("c")),
                ),
                "code": "{a for b in c}",
                "parser": parse_expression,
            },
            # async GeneratorExp
            {
                "node": cst.GeneratorExp(
                    cst.Name("a"),
                    cst.CompFor(
                        target=cst.Name("b"),
                        iter=cst.Name("c"),
                        asynchronous=cst.Asynchronous(),
                    ),
                ),
                "code": "(a async for b in c)",
                "parser": lambda code: parse_expression(
                    code, config=PartialParserConfig(python_version="3.7")
                ),
            },
            # Python 3.6 async GeneratorExp
            {
                "node": cst.FunctionDef(
                    cst.Name("foo"),
                    cst.Parameters(),
                    cst.IndentedBlock(
                        (
                            cst.SimpleStatementLine(
                                (
                                    cst.Expr(
                                        cst.GeneratorExp(
                                            cst.Name("a"),
                                            cst.CompFor(
                                                target=cst.Name("b"),
                                                iter=cst.Name("c"),
                                                asynchronous=cst.Asynchronous(),
                                            ),
                                        )
                                    ),
                                )
                            ),
                        )
                    ),
                    asynchronous=cst.Asynchronous(),
                ),
                "code": "async def foo():\n    (a async for b in c)\n",
                "parser": lambda code: parse_statement(
                    code, config=PartialParserConfig(python_version="3.6")
                ),
            },
            # a generator doesn't have to own its own parentheses
            {
                "node": cst.Call(
                    cst.Name("func"),
                    [
                        cst.Arg(
                            cst.GeneratorExp(
                                cst.Name("a"),
                                cst.CompFor(
                                    target=cst.Name("b"), iter=cst.Name("c")
                                ),
                                lpar=[],
                                rpar=[],
                            )
                        )
                    ],
                ),
                "code": "func(a for b in c)",
                "parser": parse_expression,
            },
            # add a few 'if' clauses
            {
                "node": cst.GeneratorExp(
                    cst.Name("a"),
                    cst.CompFor(
                        target=cst.Name("b"),
                        iter=cst.Name("c"),
                        ifs=[
                            cst.CompIf(cst.Name("d")),
                            cst.CompIf(cst.Name("e")),
                            cst.CompIf(cst.Name("f")),
                        ],
                    ),
                ),
                "code": "(a for b in c if d if e if f)",
                "parser": parse_expression,
                "expected_position": CodeRange((1, 1), (1, 28)),
            },
            # nested/inner for-in clause
            {
                "node": cst.GeneratorExp(
                    cst.Name("a"),
                    cst.CompFor(
                        target=cst.Name("b"),
                        iter=cst.Name("c"),
                        inner_for_in=cst.CompFor(
                            target=cst.Name("d"), iter=cst.Name("e")
                        ),
                    ),
                ),
                "code": "(a for b in c for d in e)",
                "parser": parse_expression,
            },
            # nested/inner for-in clause with an 'if' clause
            {
                "node": cst.GeneratorExp(
                    cst.Name("a"),
                    cst.CompFor(
                        target=cst.Name("b"),
                        iter=cst.Name("c"),
                        ifs=[cst.CompIf(cst.Name("d"))],
                        inner_for_in=cst.CompFor(
                            target=cst.Name("e"), iter=cst.Name("f")
                        ),
                    ),
                ),
                "code": "(a for b in c if d for e in f)",
                "parser": parse_expression,
            },
            # custom whitespace
            {
                "node": cst.GeneratorExp(
                    cst.Name("a"),
                    cst.CompFor(
                        target=cst.Name("b"),
                        iter=cst.Name("c"),
                        ifs=[
                            cst.CompIf(
                                cst.Name("d"),
                                whitespace_before=cst.SimpleWhitespace("\t"),
                                whitespace_before_test=cst.SimpleWhitespace("\t\t"),
                            )
                        ],
                        # Runs of 2, 3, 4 and 5 spaces, so that each slot is
                        # distinguishable and the expression ends at column 30.
                        whitespace_before=cst.SimpleWhitespace("  "),
                        whitespace_after_for=cst.SimpleWhitespace("   "),
                        whitespace_before_in=cst.SimpleWhitespace("    "),
                        whitespace_after_in=cst.SimpleWhitespace("     "),
                    ),
                    lpar=[cst.LeftParen(whitespace_after=cst.SimpleWhitespace("\f"))],
                    rpar=[
                        cst.RightParen(whitespace_before=cst.SimpleWhitespace("\f\f"))
                    ],
                ),
                "code": "(\fa  for   b    in     c\tif\t\td\f\f)",
                "parser": parse_expression,
                "expected_position": CodeRange((1, 2), (1, 30)),
            },
            # custom whitespace around ListComp's brackets
            {
                "node": cst.ListComp(
                    cst.Name("a"),
                    cst.CompFor(target=cst.Name("b"), iter=cst.Name("c")),
                    lbracket=cst.LeftSquareBracket(
                        whitespace_after=cst.SimpleWhitespace("\t")
                    ),
                    rbracket=cst.RightSquareBracket(
                        whitespace_before=cst.SimpleWhitespace("\t\t")
                    ),
                    lpar=[cst.LeftParen(whitespace_after=cst.SimpleWhitespace("\f"))],
                    rpar=[
                        cst.RightParen(whitespace_before=cst.SimpleWhitespace("\f\f"))
                    ],
                ),
                "code": "(\f[\ta for b in c\t\t]\f\f)",
                "parser": parse_expression,
                "expected_position": CodeRange((1, 2), (1, 19)),
            },
            # custom whitespace around SetComp's braces
            {
                "node": cst.SetComp(
                    cst.Name("a"),
                    cst.CompFor(target=cst.Name("b"), iter=cst.Name("c")),
                    lbrace=cst.LeftCurlyBrace(
                        whitespace_after=cst.SimpleWhitespace("\t")
                    ),
                    rbrace=cst.RightCurlyBrace(
                        whitespace_before=cst.SimpleWhitespace("\t\t")
                    ),
                    lpar=[cst.LeftParen(whitespace_after=cst.SimpleWhitespace("\f"))],
                    rpar=[
                        cst.RightParen(whitespace_before=cst.SimpleWhitespace("\f\f"))
                    ],
                ),
                "code": "(\f{\ta for b in c\t\t}\f\f)",
                "parser": parse_expression,
            },
            # no whitespace between elements
            {
                "node": cst.GeneratorExp(
                    cst.Name("a", lpar=[cst.LeftParen()], rpar=[cst.RightParen()]),
                    cst.CompFor(
                        target=cst.Name(
                            "b", lpar=[cst.LeftParen()], rpar=[cst.RightParen()]
                        ),
                        iter=cst.Name(
                            "c", lpar=[cst.LeftParen()], rpar=[cst.RightParen()]
                        ),
                        ifs=[
                            cst.CompIf(
                                cst.Name(
                                    "d",
                                    lpar=[cst.LeftParen()],
                                    rpar=[cst.RightParen()],
                                ),
                                whitespace_before=cst.SimpleWhitespace(""),
                                whitespace_before_test=cst.SimpleWhitespace(""),
                            )
                        ],
                        inner_for_in=cst.CompFor(
                            target=cst.Name(
                                "e", lpar=[cst.LeftParen()], rpar=[cst.RightParen()]
                            ),
                            iter=cst.Name(
                                "f", lpar=[cst.LeftParen()], rpar=[cst.RightParen()]
                            ),
                            whitespace_before=cst.SimpleWhitespace(""),
                            whitespace_after_for=cst.SimpleWhitespace(""),
                            whitespace_before_in=cst.SimpleWhitespace(""),
                            whitespace_after_in=cst.SimpleWhitespace(""),
                        ),
                        whitespace_before=cst.SimpleWhitespace(""),
                        whitespace_after_for=cst.SimpleWhitespace(""),
                        whitespace_before_in=cst.SimpleWhitespace(""),
                        whitespace_after_in=cst.SimpleWhitespace(""),
                    ),
                    lpar=[cst.LeftParen()],
                    rpar=[cst.RightParen()],
                ),
                "code": "((a)for(b)in(c)if(d)for(e)in(f))",
                "parser": parse_expression,
                "expected_position": CodeRange((1, 1), (1, 31)),
            },
            # no whitespace before/after GeneratorExp is valid
            {
                "node": cst.Comparison(
                    cst.GeneratorExp(
                        cst.Name("a"),
                        cst.CompFor(target=cst.Name("b"), iter=cst.Name("c")),
                    ),
                    [
                        cst.ComparisonTarget(
                            cst.Is(
                                whitespace_before=cst.SimpleWhitespace(""),
                                whitespace_after=cst.SimpleWhitespace(""),
                            ),
                            cst.GeneratorExp(
                                cst.Name("d"),
                                cst.CompFor(
                                    target=cst.Name("e"), iter=cst.Name("f")
                                ),
                            ),
                        )
                    ],
                ),
                "code": "(a for b in c)is(d for e in f)",
                "parser": parse_expression,
            },
            # no whitespace before/after ListComp is valid
            {
                "node": cst.Comparison(
                    cst.ListComp(
                        cst.Name("a"),
                        cst.CompFor(target=cst.Name("b"), iter=cst.Name("c")),
                    ),
                    [
                        cst.ComparisonTarget(
                            cst.Is(
                                whitespace_before=cst.SimpleWhitespace(""),
                                whitespace_after=cst.SimpleWhitespace(""),
                            ),
                            cst.ListComp(
                                cst.Name("d"),
                                cst.CompFor(
                                    target=cst.Name("e"), iter=cst.Name("f")
                                ),
                            ),
                        )
                    ],
                ),
                "code": "[a for b in c]is[d for e in f]",
                "parser": parse_expression,
            },
            # no whitespace before/after SetComp is valid
            {
                "node": cst.Comparison(
                    cst.SetComp(
                        cst.Name("a"),
                        cst.CompFor(target=cst.Name("b"), iter=cst.Name("c")),
                    ),
                    [
                        cst.ComparisonTarget(
                            cst.Is(
                                whitespace_before=cst.SimpleWhitespace(""),
                                whitespace_after=cst.SimpleWhitespace(""),
                            ),
                            cst.SetComp(
                                cst.Name("d"),
                                cst.CompFor(
                                    target=cst.Name("e"), iter=cst.Name("f")
                                ),
                            ),
                        )
                    ],
                ),
                "code": "{a for b in c}is{d for e in f}",
                "parser": parse_expression,
            },
        ]
    )
    def test_valid(self, **kwargs: Any) -> None:
        """Forward each valid case to ``validate_node``."""
        self.validate_node(**kwargs)

    @data_provider(
        (
            (
                lambda: cst.GeneratorExp(
                    cst.Name("a"),
                    cst.CompFor(target=cst.Name("b"), iter=cst.Name("c")),
                    lpar=[cst.LeftParen(), cst.LeftParen()],
                    rpar=[cst.RightParen()],
                ),
                "unbalanced parens",
            ),
            (
                lambda: cst.ListComp(
                    cst.Name("a"),
                    cst.CompFor(target=cst.Name("b"), iter=cst.Name("c")),
                    lpar=[cst.LeftParen(), cst.LeftParen()],
                    rpar=[cst.RightParen()],
                ),
                "unbalanced parens",
            ),
            (
                lambda: cst.SetComp(
                    cst.Name("a"),
                    cst.CompFor(target=cst.Name("b"), iter=cst.Name("c")),
                    lpar=[cst.LeftParen(), cst.LeftParen()],
                    rpar=[cst.RightParen()],
                ),
                "unbalanced parens",
            ),
            (
                lambda: cst.GeneratorExp(
                    cst.Name("a"),
                    cst.CompFor(
                        target=cst.Name("b"),
                        iter=cst.Name("c"),
                        whitespace_before=cst.SimpleWhitespace(""),
                    ),
                ),
                "Must have at least one space before 'for' keyword.",
            ),
            (
                lambda: cst.GeneratorExp(
                    cst.Name("a"),
                    cst.CompFor(
                        target=cst.Name("b"),
                        iter=cst.Name("c"),
                        asynchronous=cst.Asynchronous(),
                        whitespace_before=cst.SimpleWhitespace(""),
                    ),
                ),
                "Must have at least one space before 'async' keyword.",
            ),
            (
                lambda: cst.GeneratorExp(
                    cst.Name("a"),
                    cst.CompFor(
                        target=cst.Name("b"),
                        iter=cst.Name("c"),
                        whitespace_after_for=cst.SimpleWhitespace(""),
                    ),
                ),
                "Must have at least one space after 'for' keyword.",
            ),
            (
                lambda: cst.GeneratorExp(
                    cst.Name("a"),
                    cst.CompFor(
                        target=cst.Name("b"),
                        iter=cst.Name("c"),
                        whitespace_before_in=cst.SimpleWhitespace(""),
                    ),
                ),
                "Must have at least one space before 'in' keyword.",
            ),
            (
                lambda: cst.GeneratorExp(
                    cst.Name("a"),
                    cst.CompFor(
                        target=cst.Name("b"),
                        iter=cst.Name("c"),
                        whitespace_after_in=cst.SimpleWhitespace(""),
                    ),
                ),
                "Must have at least one space after 'in' keyword.",
            ),
            (
                lambda: cst.GeneratorExp(
                    cst.Name("a"),
                    cst.CompFor(
                        target=cst.Name("b"),
                        iter=cst.Name("c"),
                        ifs=[
                            cst.CompIf(
                                cst.Name("d"),
                                whitespace_before=cst.SimpleWhitespace(""),
                            )
                        ],
                    ),
                ),
                "Must have at least one space before 'if' keyword.",
            ),
            (
                lambda: cst.GeneratorExp(
                    cst.Name("a"),
                    cst.CompFor(
                        target=cst.Name("b"),
                        iter=cst.Name("c"),
                        ifs=[
                            cst.CompIf(
                                cst.Name("d"),
                                whitespace_before_test=cst.SimpleWhitespace(""),
                            )
                        ],
                    ),
                ),
                "Must have at least one space after 'if' keyword.",
            ),
            (
                lambda: cst.GeneratorExp(
                    cst.Name("a"),
                    cst.CompFor(
                        target=cst.Name("b"),
                        iter=cst.Name("c"),
                        inner_for_in=cst.CompFor(
                            target=cst.Name("d"),
                            iter=cst.Name("e"),
                            whitespace_before=cst.SimpleWhitespace(""),
                        ),
                    ),
                ),
                "Must have at least one space before 'for' keyword.",
            ),
            (
                lambda: cst.GeneratorExp(
                    cst.Name("a"),
                    cst.CompFor(
                        target=cst.Name("b"),
                        iter=cst.Name("c"),
                        inner_for_in=cst.CompFor(
                            target=cst.Name("d"),
                            iter=cst.Name("e"),
                            asynchronous=cst.Asynchronous(),
                            whitespace_before=cst.SimpleWhitespace(""),
                        ),
                    ),
                ),
                "Must have at least one space before 'async' keyword.",
            ),
        )
    )
    def test_invalid(
        self, get_node: Callable[[], cst.CSTNode], expected_re: str
    ) -> None:
        """Building the node must fail validation with ``expected_re``."""
        self.assert_invalid(get_node, expected_re)
class AwaitTest(CSTNodeTest):
    """Round-trip and validation tests for ``cst.Await``.

    The 3.7 cases parse a bare ``await`` expression; the 3.6 cases wrap the
    same expression in an ``async def`` and parse it as a statement.
    """

    @data_provider(
        (
            # Some simple calls
            {
                "node": cst.Await(cst.Name("test")),
                "code": "await test",
                "parser": lambda code: parse_expression(
                    code, config=PartialParserConfig(python_version="3.7")
                ),
                "expected_position": None,
            },
            {
                "node": cst.Await(cst.Call(cst.Name("test"))),
                "code": "await test()",
                "parser": lambda code: parse_expression(
                    code, config=PartialParserConfig(python_version="3.7")
                ),
                "expected_position": None,
            },
            # Whitespace
            {
                "node": cst.Await(
                    cst.Name("test"),
                    whitespace_after_await=cst.SimpleWhitespace("  "),
                    lpar=(cst.LeftParen(whitespace_after=cst.SimpleWhitespace(" ")),),
                    rpar=(cst.RightParen(whitespace_before=cst.SimpleWhitespace(" ")),),
                ),
                # Two spaces after `await` put the end of `test` at column 13.
                "code": "( await  test )",
                "parser": lambda code: parse_expression(
                    code, config=PartialParserConfig(python_version="3.7")
                ),
                "expected_position": CodeRange((1, 2), (1, 13)),
            },
        )
    )
    def test_valid_py37(self, **kwargs: Any) -> None:
        """Forward each Python 3.7 case to ``validate_node``."""
        # We don't have sentinel nodes for atoms, so we know that 100% of atoms
        # can be parsed identically to their creation.
        self.validate_node(**kwargs)

    @data_provider(
        (
            # Some simple calls
            {
                "node": cst.FunctionDef(
                    cst.Name("foo"),
                    cst.Parameters(),
                    cst.IndentedBlock(
                        (
                            cst.SimpleStatementLine(
                                (cst.Expr(cst.Await(cst.Name("test"))),)
                            ),
                        )
                    ),
                    asynchronous=cst.Asynchronous(),
                ),
                "code": "async def foo():\n    await test\n",
                "parser": lambda code: parse_statement(
                    code, config=PartialParserConfig(python_version="3.6")
                ),
                "expected_position": None,
            },
            {
                "node": cst.FunctionDef(
                    cst.Name("foo"),
                    cst.Parameters(),
                    cst.IndentedBlock(
                        (
                            cst.SimpleStatementLine(
                                (cst.Expr(cst.Await(cst.Call(cst.Name("test")))),)
                            ),
                        )
                    ),
                    asynchronous=cst.Asynchronous(),
                ),
                "code": "async def foo():\n    await test()\n",
                "parser": lambda code: parse_statement(
                    code, config=PartialParserConfig(python_version="3.6")
                ),
                "expected_position": None,
            },
            # Whitespace
            {
                "node": cst.FunctionDef(
                    cst.Name("foo"),
                    cst.Parameters(),
                    cst.IndentedBlock(
                        (
                            cst.SimpleStatementLine(
                                (
                                    cst.Expr(
                                        cst.Await(
                                            cst.Name("test"),
                                            whitespace_after_await=cst.SimpleWhitespace(
                                                "  "
                                            ),
                                            lpar=(
                                                cst.LeftParen(
                                                    whitespace_after=cst.SimpleWhitespace(
                                                        " "
                                                    )
                                                ),
                                            ),
                                            rpar=(
                                                cst.RightParen(
                                                    whitespace_before=cst.SimpleWhitespace(
                                                        " "
                                                    )
                                                ),
                                            ),
                                        )
                                    ),
                                )
                            ),
                        )
                    ),
                    asynchronous=cst.Asynchronous(),
                ),
                "code": "async def foo():\n    ( await  test )\n",
                "parser": lambda code: parse_statement(
                    code, config=PartialParserConfig(python_version="3.6")
                ),
                "expected_position": None,
            },
        )
    )
    def test_valid_py36(self, **kwargs: Any) -> None:
        """Forward each Python 3.6 case to ``validate_node``."""
        # We don't have sentinel nodes for atoms, so we know that 100% of atoms
        # can be parsed identically to their creation.
        self.validate_node(**kwargs)

    @data_provider(
        (
            # Expression wrapping parenthesis rules
            {
                "get_node": (
                    lambda: cst.Await(cst.Name("foo"), lpar=(cst.LeftParen(),))
                ),
                "expected_re": "left paren without right paren",
            },
            {
                "get_node": (
                    lambda: cst.Await(cst.Name("foo"), rpar=(cst.RightParen(),))
                ),
                "expected_re": "right paren without left paren",
            },
            {
                "get_node": (
                    lambda: cst.Await(
                        cst.Name("foo"),
                        whitespace_after_await=cst.SimpleWhitespace(""),
                    )
                ),
                "expected_re": "at least one space after await",
            },
        )
    )
    def test_invalid(self, **kwargs: Any) -> None:
        """Forward each invalid case to ``assert_invalid``."""
        self.assert_invalid(**kwargs)
def runtest(self):
    """Parse the spec's code, match its first small statement against the
    matcher built from the query, and compare with the expected result."""
    source, pattern, outcome = self.spec
    statement = libcst.parse_statement(source)
    target = statement.body[0]
    matched = m.matches(target, create_matcher(pattern))
    assert matched == outcome