def leave_Module(
    self,
    original_node: cst.Module,
    updated_node: cst.Module,
) -> cst.Module:
    """Add pending top-level annotations, TypeVars and class definitions.

    New statements are placed between the two statement groups returned by
    ``self._split_module``. When there is nothing to add, ``updated_node`` is
    returned unchanged.

    Args:
        original_node: The module as it was before any transformation.
        updated_node: The module after its children were transformed.

    Returns:
        ``updated_node`` itself, or a copy whose ``body`` contains the new
        statements.
    """
    fresh_class_definitions = [
        class_def
        for class_name, class_def in self.annotations.class_definitions.items()
        if class_name not in self.visited_classes
    ]

    # NOTE: The entire change will also be abandoned if
    # self.annotation_counts is all 0s, so if adding any new category make
    # sure to record it there.
    nothing_to_add = (
        not self.toplevel_annotations
        and not fresh_class_definitions
        and not self.annotations.typevars
    )
    if nothing_to_add:
        return updated_node

    new_statements = []

    # First, find the insertion point for imports.
    before_imports, after_imports = self._split_module(original_node, updated_node)

    # Make sure there's at least one empty line before the first non-import.
    after_imports = self._insert_empty_line(after_imports)

    new_statements.extend(
        cst.SimpleStatementLine(
            [
                self._apply_annotation_to_attribute_or_global(
                    name=name,
                    annotation=annotation,
                    value=None,
                )
            ]
        )
        for name, annotation in self.toplevel_annotations.items()
    )

    # TypeVar definitions could be scattered through the file, so do not
    # attempt to put new ones with existing ones, just add them at the top.
    missing_typevar_stmts = [
        stmt
        for key, stmt in self.annotations.typevars.items()
        if key not in self.typevars
    ]
    if missing_typevar_stmts:
        for stmt in missing_typevar_stmts:
            new_statements.append(cst.Newline())
            new_statements.append(stmt)
            self.annotation_counts.typevars_and_generics_added += 1
        new_statements.append(cst.Newline())

    self.annotation_counts.classes_added = len(fresh_class_definitions)
    new_statements.extend(fresh_class_definitions)

    return updated_node.with_changes(
        body=[*before_imports, *new_statements, *after_imports]
    )
def _create_empty_line():
    """Return an indented ``cst.EmptyLine`` with no whitespace and no comment."""
    no_whitespace = cst.SimpleWhitespace(value='')
    # A value of None leaves the newline unspecified rather than pinning
    # a concrete line ending.
    unspecified_newline = cst.Newline(value=None)
    return cst.EmptyLine(
        indent=True,
        whitespace=no_whitespace,
        comment=None,
        newline=unspecified_newline,
    )
class TrailingWhitespaceTest(CSTNodeTest):
    """Checks the code generated for ``cst.TrailingWhitespace`` nodes."""

    @data_provider(
        (
            # All defaults: just a newline.
            (cst.TrailingWhitespace(), "\n"),
            # Whitespace only.
            (
                cst.TrailingWhitespace(whitespace=cst.SimpleWhitespace(value=" ")),
                " \n",
            ),
            # Comment only.
            (
                cst.TrailingWhitespace(comment=cst.Comment(value="# comment")),
                "# comment\n",
            ),
            # Explicit newline only.
            (
                cst.TrailingWhitespace(newline=cst.Newline(value="\r\n")),
                "\r\n",
            ),
            # Everything at once.
            (
                cst.TrailingWhitespace(
                    whitespace=cst.SimpleWhitespace(value=" "),
                    comment=cst.Comment(value="# comment"),
                    newline=cst.Newline(value="\r\n"),
                ),
                " # comment\r\n",
            ),
        )
    )
    def test_valid(self, node: cst.CSTNode, code: str) -> None:
        # Each (node, code) pair is handed to the shared validation helper.
        self.validate_node(node, code)
class EmptyLineTest(CSTNodeTest):
    """Checks the code generated for ``cst.EmptyLine`` nodes."""

    @data_provider(
        (
            # All defaults: just a newline.
            (cst.EmptyLine(), "\n"),
            # Whitespace only.
            (cst.EmptyLine(whitespace=cst.SimpleWhitespace(value=" ")), " \n"),
            # Comment only.
            (cst.EmptyLine(comment=cst.Comment(value="# comment")), "# comment\n"),
            # Explicit newline only.
            (cst.EmptyLine(newline=cst.Newline(value="\r\n")), "\r\n"),
            # Inside an indented block the line carries the block's indent...
            (DummyIndentedBlock(" ", cst.EmptyLine()), " \n"),
            # ...unless indent is switched off.
            (DummyIndentedBlock(" ", cst.EmptyLine(indent=False)), "\n"),
            # Indent, whitespace, comment and newline combined.
            (
                DummyIndentedBlock(
                    "\t",
                    cst.EmptyLine(
                        whitespace=cst.SimpleWhitespace(value=" "),
                        comment=cst.Comment(value="# comment"),
                        newline=cst.Newline(value="\r\n"),
                    ),
                ),
                "\t # comment\r\n",
            ),
        )
    )
    def test_valid(self, node: cst.CSTNode, code: str) -> None:
        # Each (node, code) pair is handed to the shared validation helper.
        self.validate_node(node, code)
def test_module_config_for_parsing(self) -> None:
    # Parse a module that uses "\r" line endings, then reuse its parsing
    # config for a separate statement.
    carriage_return_module = parse_module("pass\r")
    parsed_statement = parse_statement(
        "if True:\r pass",
        config=carriage_return_module.config_for_parsing,
    )

    # This would be "\r" if we didn't pass the module config forward.
    header_newline = cst.Newline(value=None)
    expected_statement = cst.If(
        test=cst.Name(value="True"),
        body=cst.IndentedBlock(
            body=[cst.SimpleStatementLine(body=[cst.Pass()])],
            header=cst.TrailingWhitespace(newline=header_newline),
        ),
    )

    self.assertEqual(parsed_statement, expected_statement)
def test_with_changes(self) -> None:
    # with_changes must return a modified copy and leave the source node alone.
    original = cst.TrailingWhitespace(
        whitespace=cst.SimpleWhitespace(" \\\n "),
        comment=cst.Comment("# initial"),
        newline=cst.Newline("\r\n"),
    )
    replacement_comment = cst.Comment("# new comment")
    changed = original.with_changes(comment=replacement_comment)

    # see that we have the updated fields
    changed_comment = none_throws(changed.comment)
    self.assertEqual(changed_comment.value, "# new comment")

    # and that the old fields are still there
    self.assertEqual(changed.whitespace.value, " \\\n ")
    self.assertEqual(changed.newline.value, "\r\n")

    # ensure no mutation actually happened
    original_comment = none_throws(original.comment)
    self.assertEqual(original_comment.value, "# initial")
def leave_SimpleWhitespace(
    self,
    original_node: libcst.SimpleWhitespace,
    updated_node: libcst.SimpleWhitespace,
) -> Union[libcst.SimpleWhitespace, libcst.ParenthesizedWhitespace]:
    """Turn whitespace containing a newline into ``ParenthesizedWhitespace``.

    Backslashes are stripped from the original whitespace value first. If no
    newline remains, ``updated_node`` is returned untouched.

    Args:
        original_node: The whitespace node before transformation; its value
            is the one inspected.
        updated_node: The whitespace node after transformation.

    Returns:
        ``updated_node`` when the whitespace has no newline, otherwise a new
        ``ParenthesizedWhitespace`` built from its first two lines.
    """
    without_backslashes = original_node.value.replace("\\", "")
    first_segment, newline_found, remainder = without_backslashes.partition("\n")
    if not newline_found:
        return updated_node

    # Only the text between the first and second newline is kept as the last
    # line. NOTE(review): anything after a second newline is dropped here --
    # confirm that is intended for multiple continuations.
    second_segment = remainder.partition("\n")[0]

    trailing = libcst.TrailingWhitespace(
        whitespace=libcst.SimpleWhitespace(value=first_segment.rstrip()),
        comment=None,
        newline=libcst.Newline(),
    )
    return libcst.ParenthesizedWhitespace(
        first_line=trailing,
        empty_lines=[],
        indent=True,
        last_line=libcst.SimpleWhitespace(value=second_segment),
    )
class NewlineTest(CSTNodeTest):
    """Checks accepted and rejected values for ``cst.Newline``."""

    @data_provider(
        (
            (cst.Newline(value="\r\n"), "\r\n"),
            (cst.Newline(value="\r"), "\r"),
            (cst.Newline(value="\n"), "\n"),
        )
    )
    def test_valid(self, node: cst.CSTNode, code: str) -> None:
        # Each (node, code) pair is handed to the shared validation helper.
        self.validate_node(node, code)

    @data_provider(
        (
            # Construction is deferred behind a lambda so the failure happens
            # inside assert_invalid rather than while building this table.
            (lambda: cst.Newline(value="bad input"), "invalid value"),
            (lambda: cst.Newline(value="\nbad input\n"), "invalid value"),
            (lambda: cst.Newline(value="\n\n"), "invalid value"),
        )
    )
    def test_invalid(
        self, get_node: Callable[[], cst.CSTNode], expected_re: str
    ) -> None:
        self.assert_invalid(get_node, expected_re)
class ModuleTest(CSTNodeTest):
    """Tests for ``cst.Module``: code generation, parsing and positions."""

    @data_provider(
        (
            # simplest possible program
            (cst.Module((cst.SimpleStatementLine((cst.Pass(),)),)), "pass\n"),
            # test default_newline
            (
                cst.Module(
                    (cst.SimpleStatementLine((cst.Pass(),)),),
                    default_newline="\r",
                ),
                "pass\r",
            ),
            # test header/footer
            (
                cst.Module(
                    (cst.SimpleStatementLine((cst.Pass(),)),),
                    header=(cst.EmptyLine(comment=cst.Comment("# header")),),
                    footer=(cst.EmptyLine(comment=cst.Comment("# footer")),),
                ),
                "# header\npass\n# footer\n",
            ),
            # test has_trailing_newline
            (
                cst.Module(
                    (cst.SimpleStatementLine((cst.Pass(),)),),
                    has_trailing_newline=False,
                ),
                "pass",
            ),
            # an empty file
            (cst.Module((), has_trailing_newline=False), ""),
            # a file with only comments
            (
                cst.Module(
                    (),
                    header=(
                        cst.EmptyLine(comment=cst.Comment("# nothing to see here")),
                    ),
                ),
                "# nothing to see here\n",
            ),
            # TODO: test default_indent
        )
    )
    def test_code_and_bytes_properties(self, module: cst.Module, expected: str) -> None:
        # `code` and `bytes` must agree, with bytes being the UTF-8 encoding.
        self.assertEqual(module.code, expected)
        self.assertEqual(module.bytes, expected.encode("utf-8"))

    @data_provider(
        (
            (cst.Module(()), cst.Newline(), "\n"),
            (cst.Module((), default_newline="\r\n"), cst.Newline(), "\r\n"),
            # has_trailing_newline has no effect on code_for_node
            (cst.Module((), has_trailing_newline=False), cst.Newline(), "\n"),
            # TODO: test default_indent
        )
    )
    def test_code_for_node(
        self, module: cst.Module, node: cst.CSTNode, expected: str
    ) -> None:
        self.assertEqual(module.code_for_node(node), expected)

    @data_provider(
        {
            "empty_program": {
                "code": "",
                "expected": cst.Module([], has_trailing_newline=False),
            },
            "empty_program_with_newline": {
                "code": "\n",
                "expected": cst.Module([], has_trailing_newline=True),
                "enabled_for_native": False,
            },
            "empty_program_with_comments": {
                "code": "# some comment\n",
                "expected": cst.Module(
                    [], header=[cst.EmptyLine(comment=cst.Comment("# some comment"))]
                ),
            },
            "simple_pass": {
                "code": "pass\n",
                "expected": cst.Module([cst.SimpleStatementLine([cst.Pass()])]),
            },
            "simple_pass_with_header_footer": {
                "code": "# header\npass # trailing\n# footer\n",
                "expected": cst.Module(
                    [
                        cst.SimpleStatementLine(
                            [cst.Pass()],
                            trailing_whitespace=cst.TrailingWhitespace(
                                whitespace=cst.SimpleWhitespace(" "),
                                comment=cst.Comment("# trailing"),
                            ),
                        )
                    ],
                    header=[cst.EmptyLine(comment=cst.Comment("# header"))],
                    footer=[cst.EmptyLine(comment=cst.Comment("# footer"))],
                ),
            },
        }
    )
    def test_parser(
        self, *, code: str, expected: cst.Module, enabled_for_native: bool = True
    ) -> None:
        # Cases flagged enabled_for_native=False are skipped under the native
        # parser.
        if is_native() and not enabled_for_native:
            self.skipTest("Disabled for native parser")
        self.assertEqual(parse_module(code), expected)

    @data_provider(
        {
            "empty": {"code": "", "expected": CodeRange((1, 0), (1, 0))},
            "empty_with_newline": {
                "code": "\n",
                "expected": CodeRange((1, 0), (2, 0)),
            },
            "empty_program_with_comments": {
                "code": "# 2345",
                "expected": CodeRange((1, 0), (2, 0)),
            },
            "simple_pass": {"code": "pass\n", "expected": CodeRange((1, 0), (2, 0))},
            "simple_pass_with_header_footer": {
                "code": "# header\npass # trailing\n# footer\n",
                "expected": CodeRange((1, 0), (4, 0)),
            },
        }
    )
    def test_module_position(self, *, code: str, expected: CodeRange) -> None:
        wrapper = MetadataWrapper(parse_module(code))
        positions = wrapper.resolve(PositionProvider)

        self.assertEqual(positions[wrapper.module], expected)

    def cmp_position(
        self, actual: CodeRange, start: Tuple[int, int], end: Tuple[int, int]
    ) -> None:
        # Helper: compare a resolved position against (line, column) pairs.
        self.assertEqual(actual, CodeRange(start, end))

    def test_function_position(self) -> None:
        # The body is indented by four spaces so that `pass` occupies
        # columns 4..8 on line 2, as asserted below.
        wrapper = MetadataWrapper(parse_module("def foo():\n    pass"))
        module = wrapper.module
        positions = wrapper.resolve(PositionProvider)

        fn = cast(cst.FunctionDef, module.body[0])
        stmt = cast(cst.SimpleStatementLine, fn.body.body[0])
        pass_stmt = cast(cst.Pass, stmt.body[0])
        self.cmp_position(positions[stmt], (2, 4), (2, 8))
        self.cmp_position(positions[pass_stmt], (2, 4), (2, 8))

    def test_nested_indent_position(self) -> None:
        # Indentation is four spaces per level; the asserted columns
        # (inner `if` at 4, assignment at 8, `return` at 4) depend on it.
        wrapper = MetadataWrapper(
            parse_module(
                "if True:\n    if False:\n        x = 1\nelse:\n    return"
            )
        )
        module = wrapper.module
        positions = wrapper.resolve(PositionProvider)

        outer_if = cast(cst.If, module.body[0])
        inner_if = cast(cst.If, outer_if.body.body[0])
        assign = cast(cst.SimpleStatementLine, inner_if.body.body[0]).body[0]

        outer_else = cast(cst.Else, outer_if.orelse)
        return_stmt = cast(cst.SimpleStatementLine, outer_else.body.body[0]).body[0]

        self.cmp_position(positions[outer_if], (1, 0), (5, 10))
        self.cmp_position(positions[inner_if], (2, 4), (3, 13))
        self.cmp_position(positions[assign], (3, 8), (3, 13))
        self.cmp_position(positions[outer_else], (4, 0), (5, 10))
        self.cmp_position(positions[return_stmt], (5, 4), (5, 10))

    def test_multiline_string_position(self) -> None:
        # Two string literals joined by a backslash continuation.
        wrapper = MetadataWrapper(parse_module('"abc"\\\n"def"'))
        module = wrapper.module
        positions = wrapper.resolve(PositionProvider)

        stmt = cast(cst.SimpleStatementLine, module.body[0])
        expr = cast(cst.Expr, stmt.body[0])
        string = expr.value

        self.cmp_position(positions[stmt], (1, 0), (2, 5))
        self.cmp_position(positions[expr], (1, 0), (2, 5))
        self.cmp_position(positions[string], (1, 0), (2, 5))

    def test_module_config_for_parsing(self) -> None:
        # Reuse the parsing config of a "\r"-terminated module for a
        # separately parsed statement.
        module = parse_module("pass\r")
        statement = parse_statement(
            "if True:\r pass", config=module.config_for_parsing
        )
        self.assertEqual(
            statement,
            cst.If(
                test=cst.Name(value="True"),
                body=cst.IndentedBlock(
                    body=[cst.SimpleStatementLine(body=[cst.Pass()])],
                    header=cst.TrailingWhitespace(
                        newline=cst.Newline(
                            # This would be "\r" if we didn't pass the module
                            # config forward.
                            value=None
                        )
                    ),
                ),
            ),
        )
class ModuleTest(CSTNodeTest):
    """Tests for ``cst.Module``: code generation, parsing and positions."""

    @data_provider(
        (
            # simplest possible program
            (cst.Module((cst.SimpleStatementLine((cst.Pass(),)),)), "pass\n"),
            # test default_newline
            (
                cst.Module(
                    (cst.SimpleStatementLine((cst.Pass(),)),),
                    default_newline="\r",
                ),
                "pass\r",
            ),
            # test header/footer
            (
                cst.Module(
                    (cst.SimpleStatementLine((cst.Pass(),)),),
                    header=(cst.EmptyLine(comment=cst.Comment("# header")),),
                    footer=(cst.EmptyLine(comment=cst.Comment("# footer")),),
                ),
                "# header\npass\n# footer\n",
            ),
            # test has_trailing_newline
            (
                cst.Module(
                    (cst.SimpleStatementLine((cst.Pass(),)),),
                    has_trailing_newline=False,
                ),
                "pass",
            ),
            # an empty file
            (cst.Module((), has_trailing_newline=False), ""),
            # a file with only comments
            (
                cst.Module(
                    (),
                    header=(
                        cst.EmptyLine(comment=cst.Comment("# nothing to see here")),
                    ),
                ),
                "# nothing to see here\n",
            ),
            # TODO: test default_indent
        )
    )
    def test_code_and_bytes_properties(self, module: cst.Module, expected: str) -> None:
        # `code` and `bytes` must agree, with bytes being the UTF-8 encoding.
        self.assertEqual(module.code, expected)
        self.assertEqual(module.bytes, expected.encode("utf-8"))

    @data_provider(
        (
            (cst.Module(()), cst.Newline(), "\n"),
            (cst.Module((), default_newline="\r\n"), cst.Newline(), "\r\n"),
            # has_trailing_newline has no effect on code_for_node
            (cst.Module((), has_trailing_newline=False), cst.Newline(), "\n"),
            # TODO: test default_indent
        )
    )
    def test_code_for_node(
        self, module: cst.Module, node: cst.CSTNode, expected: str
    ) -> None:
        self.assertEqual(module.code_for_node(node), expected)

    @data_provider(
        {
            "empty_program": {
                "code": "",
                "expected": cst.Module([], has_trailing_newline=False),
            },
            "empty_program_with_newline": {
                "code": "\n",
                "expected": cst.Module([], has_trailing_newline=True),
            },
            "empty_program_with_comments": {
                "code": "# some comment\n",
                "expected": cst.Module(
                    [], header=[cst.EmptyLine(comment=cst.Comment("# some comment"))]
                ),
            },
            "simple_pass": {
                "code": "pass\n",
                "expected": cst.Module([cst.SimpleStatementLine([cst.Pass()])]),
            },
            "simple_pass_with_header_footer": {
                "code": "# header\npass # trailing\n# footer\n",
                "expected": cst.Module(
                    [
                        cst.SimpleStatementLine(
                            [cst.Pass()],
                            trailing_whitespace=cst.TrailingWhitespace(
                                whitespace=cst.SimpleWhitespace(" "),
                                comment=cst.Comment("# trailing"),
                            ),
                        )
                    ],
                    header=[cst.EmptyLine(comment=cst.Comment("# header"))],
                    footer=[cst.EmptyLine(comment=cst.Comment("# footer"))],
                ),
            },
        }
    )
    def test_parser(self, *, code: str, expected: cst.Module) -> None:
        self.assertEqual(parse_module(code), expected)

    @data_provider(
        {
            "empty": {"code": "", "expected": CodeRange.create((1, 0), (1, 0))},
            "empty_with_newline": {
                "code": "\n",
                "expected": CodeRange.create((1, 0), (2, 0)),
            },
            "empty_program_with_comments": {
                "code": "# 2345",
                "expected": CodeRange.create((1, 0), (2, 0)),
            },
            "simple_pass": {
                "code": "pass\n",
                "expected": CodeRange.create((1, 0), (2, 0)),
            },
            "simple_pass_with_header_footer": {
                "code": "# header\npass # trailing\n# footer\n",
                "expected": CodeRange.create((1, 0), (4, 0)),
            },
        }
    )
    def test_module_position(self, *, code: str, expected: CodeRange) -> None:
        module = parse_module(code)
        provider = SyntacticPositionProvider()
        # Generating code with the provider attached is what fills in the
        # positions read back below.
        module.code_for_node(module, provider)

        self.assertEqual(provider._computed[module], expected)

        # TODO: remove this
        self.assertEqual(module._metadata[SyntacticPositionProvider], expected)

    def cmp_position(
        self, actual: CodeRange, start: Tuple[int, int], end: Tuple[int, int]
    ) -> None:
        # Helper: compare a computed position against (line, column) pairs.
        self.assertEqual(actual, CodeRange.create(start, end))

    def test_function_position(self) -> None:
        # The body is indented by four spaces so that `pass` occupies
        # columns 4..8 on line 2, as asserted below.
        module = parse_module("def foo():\n    pass")
        provider = SyntacticPositionProvider()
        module.code_for_node(module, provider)

        fn = cast(cst.FunctionDef, module.body[0])
        stmt = cast(cst.SimpleStatementLine, fn.body.body[0])
        pass_stmt = cast(cst.Pass, stmt.body[0])
        self.cmp_position(provider._computed[stmt], (2, 4), (2, 8))
        self.cmp_position(provider._computed[pass_stmt], (2, 4), (2, 8))

    def test_nested_indent_position(self) -> None:
        # Indentation is four spaces per level; the asserted columns
        # (inner `if` at 4, assignment at 8, `return` at 4) depend on it.
        module = parse_module(
            "if True:\n    if False:\n        x = 1\nelse:\n    return"
        )
        provider = SyntacticPositionProvider()
        module.code_for_node(module, provider)

        outer_if = cast(cst.If, module.body[0])
        inner_if = cast(cst.If, outer_if.body.body[0])
        assign = cast(cst.SimpleStatementLine, inner_if.body.body[0]).body[0]

        outer_else = cast(cst.Else, outer_if.orelse)
        return_stmt = cast(cst.SimpleStatementLine, outer_else.body.body[0]).body[0]

        self.cmp_position(provider._computed[outer_if], (1, 0), (5, 10))
        self.cmp_position(provider._computed[inner_if], (2, 4), (3, 13))
        self.cmp_position(provider._computed[assign], (3, 8), (3, 13))
        self.cmp_position(provider._computed[outer_else], (4, 0), (5, 10))
        self.cmp_position(provider._computed[return_stmt], (5, 4), (5, 10))

    def test_multiline_string_position(self) -> None:
        # Two string literals joined by a backslash continuation.
        module = parse_module('"abc"\\\n"def"')
        provider = SyntacticPositionProvider()
        module.code_for_node(module, provider)

        stmt = cast(cst.SimpleStatementLine, module.body[0])
        expr = cast(cst.Expr, stmt.body[0])
        string = expr.value

        self.cmp_position(provider._computed[stmt], (1, 0), (2, 5))
        self.cmp_position(provider._computed[expr], (1, 0), (2, 5))
        self.cmp_position(provider._computed[string], (1, 0), (2, 5))
def test_repr(self) -> None:
    # repr() of a node with nested nodes and node sequences is compared
    # against a multi-line, indented rendering. The expected text is written
    # as an indented block and normalised with dedent().strip().
    node = cst.SimpleStatementLine(
        body=[cst.Pass()],
        # tuple with multiple items
        leading_lines=(
            cst.EmptyLine(
                indent=True,
                whitespace=cst.SimpleWhitespace(""),
                comment=None,
                newline=cst.Newline(),
            ),
            cst.EmptyLine(
                indent=True,
                whitespace=cst.SimpleWhitespace(""),
                comment=None,
                newline=cst.Newline(),
            ),
        ),
        trailing_whitespace=cst.TrailingWhitespace(
            whitespace=cst.SimpleWhitespace(" "),
            comment=cst.Comment("# comment"),
            newline=cst.Newline(),
        ),
    )
    # NOTE(review): four spaces per nesting level is assumed for the repr
    # layout -- confirm against the node __repr__ implementation.
    expected = dedent(
        """
        SimpleStatementLine(
            body=[
                Pass(
                    semicolon=MaybeSentinel.DEFAULT,
                ),
            ],
            leading_lines=[
                EmptyLine(
                    indent=True,
                    whitespace=SimpleWhitespace(
                        value='',
                    ),
                    comment=None,
                    newline=Newline(
                        value=None,
                    ),
                ),
                EmptyLine(
                    indent=True,
                    whitespace=SimpleWhitespace(
                        value='',
                    ),
                    comment=None,
                    newline=Newline(
                        value=None,
                    ),
                ),
            ],
            trailing_whitespace=TrailingWhitespace(
                whitespace=SimpleWhitespace(
                    value=' ',
                ),
                comment=Comment(
                    value='# comment',
                ),
                newline=Newline(
                    value=None,
                ),
            ),
        )
        """
    ).strip()
    self.assertEqual(repr(node), expected)
class WhitespaceParserTest(UnitTest):
    # Table-driven test for the whitespace parsing helpers. Each case names
    # the parser to call, the Config (source lines plus default newline), the
    # State to start from, the State expected afterwards, and the node (or
    # list of nodes) the parser should return.
    #
    # NOTE(review): some column offsets do not fit the literals as written
    # here (e.g. "trailing_whitespace" ends at column 21 on a 20-character
    # line). Runs of spaces inside these literals may have been collapsed --
    # confirm against the canonical test data before relying on them.
    @data_provider(
        {
            "simple_whitespace_empty": {
                "parser": parse_simple_whitespace,
                "config": Config(
                    lines=["not whitespace\n", " another line\n"], default_newline="\n"
                ),
                "start_state": State(
                    line=1, column=0, absolute_indent="", is_parenthesized=False
                ),
                "end_state": State(
                    line=1, column=0, absolute_indent="", is_parenthesized=False
                ),
                "expected_node": cst.SimpleWhitespace(""),
            },
            "simple_whitespace_start_of_line": {
                "parser": parse_simple_whitespace,
                "config": Config(
                    lines=["\t <-- There's some whitespace there\n"],
                    default_newline="\n",
                ),
                "start_state": State(
                    line=1, column=0, absolute_indent="", is_parenthesized=False
                ),
                "end_state": State(
                    line=1, column=3, absolute_indent="", is_parenthesized=False
                ),
                "expected_node": cst.SimpleWhitespace("\t "),
            },
            "simple_whitespace_end_of_line": {
                "parser": parse_simple_whitespace,
                "config": Config(lines=["prefix "], default_newline="\n"),
                "start_state": State(
                    line=1, column=6, absolute_indent="", is_parenthesized=False
                ),
                "end_state": State(
                    line=1, column=9, absolute_indent="", is_parenthesized=False
                ),
                "expected_node": cst.SimpleWhitespace(" "),
            },
            # Backslash continuations: the whitespace node spans three lines.
            "simple_whitespace_line_continuation": {
                "parser": parse_simple_whitespace,
                "config": Config(
                    lines=["prefix \\\n", " \\\n", " # suffix\n"],
                    default_newline="\n",
                ),
                "start_state": State(
                    line=1, column=6, absolute_indent="", is_parenthesized=False
                ),
                "end_state": State(
                    line=3, column=4, absolute_indent="", is_parenthesized=False
                ),
                "expected_node": cst.SimpleWhitespace(" \\\n \\\n "),
            },
            "empty_lines_empty_list": {
                "parser": parse_empty_lines,
                "config": Config(
                    lines=["this is not an empty line"], default_newline="\n"
                ),
                "start_state": State(
                    line=1, column=0, absolute_indent="", is_parenthesized=False
                ),
                "end_state": State(
                    line=1, column=0, absolute_indent="", is_parenthesized=False
                ),
                "expected_node": [],
            },
            "empty_lines_single_line": {
                "parser": parse_empty_lines,
                "config": Config(
                    lines=[" # comment\n", "this is not an empty line\n"],
                    default_newline="\n",
                ),
                "start_state": State(
                    line=1, column=0, absolute_indent=" ", is_parenthesized=False
                ),
                "end_state": State(
                    line=2, column=0, absolute_indent=" ", is_parenthesized=False
                ),
                "expected_node": [
                    cst.EmptyLine(
                        indent=True,
                        whitespace=cst.SimpleWhitespace(""),
                        comment=cst.Comment("# comment"),
                        newline=cst.Newline(),
                    )
                ],
            },
            # A mix of blank lines and comment lines, with and without the
            # current indent.
            "empty_lines_multiple": {
                "parser": parse_empty_lines,
                "config": Config(
                    lines=[
                        "\n",
                        " \n",
                        " # comment with indent and whitespace\n",
                        "# comment without indent\n",
                        " # comment with no indent but some whitespace\n",
                    ],
                    default_newline="\n",
                ),
                "start_state": State(
                    line=1, column=0, absolute_indent=" ", is_parenthesized=False
                ),
                "end_state": State(
                    line=5, column=47, absolute_indent=" ", is_parenthesized=False
                ),
                "expected_node": [
                    cst.EmptyLine(
                        indent=False,
                        whitespace=cst.SimpleWhitespace(""),
                        comment=None,
                        newline=cst.Newline(),
                    ),
                    cst.EmptyLine(
                        indent=True,
                        whitespace=cst.SimpleWhitespace(""),
                        comment=None,
                        newline=cst.Newline(),
                    ),
                    cst.EmptyLine(
                        indent=True,
                        whitespace=cst.SimpleWhitespace(" "),
                        comment=cst.Comment("# comment with indent and whitespace"),
                        newline=cst.Newline(),
                    ),
                    cst.EmptyLine(
                        indent=False,
                        whitespace=cst.SimpleWhitespace(""),
                        comment=cst.Comment("# comment without indent"),
                        newline=cst.Newline(),
                    ),
                    cst.EmptyLine(
                        indent=False,
                        whitespace=cst.SimpleWhitespace(" "),
                        comment=cst.Comment(
                            "# comment with no indent but some whitespace"
                        ),
                        newline=cst.Newline(),
                    ),
                ],
            },
            # A line ending equal to default_newline is expected as
            # Newline(None); other line endings are expected verbatim.
            "empty_lines_non_default_newline": {
                "parser": parse_empty_lines,
                "config": Config(lines=["\n", "\r\n", "\r"], default_newline="\n"),
                "start_state": State(
                    line=1, column=0, absolute_indent="", is_parenthesized=False
                ),
                "end_state": State(
                    line=3, column=1, absolute_indent="", is_parenthesized=False
                ),
                "expected_node": [
                    cst.EmptyLine(
                        indent=True,
                        whitespace=cst.SimpleWhitespace(""),
                        comment=None,
                        newline=cst.Newline(None),  # default newline
                    ),
                    cst.EmptyLine(
                        indent=True,
                        whitespace=cst.SimpleWhitespace(""),
                        comment=None,
                        newline=cst.Newline("\r\n"),  # non-default
                    ),
                    cst.EmptyLine(
                        indent=True,
                        whitespace=cst.SimpleWhitespace(""),
                        comment=None,
                        newline=cst.Newline("\r"),  # non-default
                    ),
                ],
            },
            "trailing_whitespace": {
                "parser": parse_trailing_whitespace,
                "config": Config(
                    lines=["some code # comment\n"], default_newline="\n"
                ),
                "start_state": State(
                    line=1, column=9, absolute_indent="", is_parenthesized=False
                ),
                "end_state": State(
                    line=1, column=21, absolute_indent="", is_parenthesized=False
                ),
                "expected_node": cst.TrailingWhitespace(
                    whitespace=cst.SimpleWhitespace(" "),
                    comment=cst.Comment("# comment"),
                    newline=cst.Newline(),
                ),
            },
        }
    )
    def test_parsers(
        self,
        parser: Callable[[Config, State], _T],
        config: Config,
        start_state: State,
        end_state: State,
        expected_node: _T,
    ) -> None:
        # Run `parser` on `config` starting from `start_state`, then check
        # both the returned node(s) and the resulting state.
        # Uses internal `deep_equals` function instead of `CSTNode.deep_equals`, because
        # we need to compare sequences of nodes, and this is the easiest way. :/
        parsed_node = parser(config, start_state)
        self.assertTrue(
            deep_equals(parsed_node, expected_node),
            msg=f"\n{parsed_node!r}\nis not deeply equal to \n{expected_node!r}",
        )
        # The same State object that was passed to the parser is compared
        # with the expected end state, so the test expects the parser to
        # have updated it in place.
        self.assertEqual(start_state, end_state)
from itertools import takewhile, dropwhile
from typing import Union, Dict, List, Optional, Tuple

import libcst as cst
from libcst import matchers as m
from libcst.helpers import parse_template_statement

from pybetter.transformers.base import NoqaAwareTransformer

# Statement template with `{arg}` and `{init}` placeholders.
DEFAULT_INIT_TEMPLATE = """if {arg} is None: {arg} = {init} """

# If you do not explicitly set `indent` to False, then even an empty line
# will contain at least one indent's worth of whitespace.
EMPTY_LINE = cst.EmptyLine(indent=False, newline=cst.Newline())


def is_docstring(node):
    """Return whether `node` is a statement line holding only a simple string."""
    docstring_shape = m.SimpleStatementLine(
        body=[m.Expr(value=m.SimpleString())],
    )
    return m.matches(node, docstring_shape)


class ArgEmptyInitTransformer(NoqaAwareTransformer):
    """Transformer that remembers the visited module's parsing config."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Filled in by visit_Module.
        self.module_config = None

    def visit_Module(self, node: cst.Module) -> Optional[bool]:
        """Store the module's parsing config and keep visiting its children."""
        self.module_config = node.config_for_parsing
        return True
class WithTest(CSTNodeTest):
    """Tests for ``cst.With``: code generation, parsing, positions, validation."""

    maxDiff: int = 2000

    @data_provider(
        (
            # Simple with block
            {
                "node": cst.With(
                    (cst.WithItem(cst.Call(cst.Name("context_mgr"))),),
                    cst.SimpleStatementSuite((cst.Pass(),)),
                ),
                "code": "with context_mgr(): pass\n",
                "parser": parse_statement,
                "expected_position": CodeRange((1, 0), (1, 24)),
            },
            # Simple async with block
            {
                "node": cst.With(
                    (cst.WithItem(cst.Call(cst.Name("context_mgr"))),),
                    cst.SimpleStatementSuite((cst.Pass(),)),
                    asynchronous=cst.Asynchronous(),
                ),
                "code": "async with context_mgr(): pass\n",
                "parser": lambda code: parse_statement(
                    code, config=PartialParserConfig(python_version="3.7")
                ),
            },
            # Python 3.6 async with block
            {
                "node": cst.FunctionDef(
                    cst.Name("foo"),
                    cst.Parameters(),
                    cst.IndentedBlock(
                        (
                            cst.With(
                                (cst.WithItem(cst.Call(cst.Name("context_mgr"))),),
                                cst.SimpleStatementSuite((cst.Pass(),)),
                                asynchronous=cst.Asynchronous(),
                            ),
                        )
                    ),
                    asynchronous=cst.Asynchronous(),
                ),
                # The function body uses a four-space indent.
                "code": "async def foo():\n    async with context_mgr(): pass\n",
                "parser": lambda code: parse_statement(
                    code, config=PartialParserConfig(python_version="3.6")
                ),
            },
            # Multiple context managers
            {
                "node": cst.With(
                    (
                        cst.WithItem(cst.Call(cst.Name("foo"))),
                        cst.WithItem(cst.Call(cst.Name("bar"))),
                    ),
                    cst.SimpleStatementSuite((cst.Pass(),)),
                ),
                "code": "with foo(), bar(): pass\n",
                "parser": None,
            },
            {
                "node": cst.With(
                    (
                        cst.WithItem(
                            cst.Call(cst.Name("foo")),
                            comma=cst.Comma(
                                whitespace_after=cst.SimpleWhitespace(" ")
                            ),
                        ),
                        cst.WithItem(cst.Call(cst.Name("bar"))),
                    ),
                    cst.SimpleStatementSuite((cst.Pass(),)),
                ),
                "code": "with foo(), bar(): pass\n",
                "parser": parse_statement,
            },
            # With block containing variable for context manager.
            {
                "node": cst.With(
                    (
                        cst.WithItem(
                            cst.Call(cst.Name("context_mgr")),
                            cst.AsName(cst.Name("ctx")),
                        ),
                    ),
                    cst.SimpleStatementSuite((cst.Pass(),)),
                ),
                "code": "with context_mgr() as ctx: pass\n",
                "parser": parse_statement,
            },
            # indentation
            {
                # Four-space block indent, matching the start column (4) in
                # expected_position.
                "node": DummyIndentedBlock(
                    "    ",
                    cst.With(
                        (cst.WithItem(cst.Call(cst.Name("context_mgr"))),),
                        cst.SimpleStatementSuite((cst.Pass(),)),
                    ),
                ),
                "code": "    with context_mgr(): pass\n",
                "parser": None,
                "expected_position": CodeRange((1, 4), (1, 28)),
            },
            # with an indented body
            {
                # `pass` sits one further indent level in and ends at column
                # 12 on line 2.
                "node": DummyIndentedBlock(
                    "    ",
                    cst.With(
                        (cst.WithItem(cst.Call(cst.Name("context_mgr"))),),
                        cst.IndentedBlock(
                            (cst.SimpleStatementLine((cst.Pass(),)),)
                        ),
                    ),
                ),
                "code": "    with context_mgr():\n        pass\n",
                "parser": None,
                "expected_position": CodeRange((1, 4), (2, 12)),
            },
            # leading_lines
            {
                "node": cst.With(
                    (cst.WithItem(cst.Call(cst.Name("context_mgr"))),),
                    cst.SimpleStatementSuite((cst.Pass(),)),
                    leading_lines=(
                        cst.EmptyLine(comment=cst.Comment("# leading comment")),
                    ),
                ),
                "code": "# leading comment\nwith context_mgr(): pass\n",
                "parser": parse_statement,
                "expected_position": CodeRange((2, 0), (2, 24)),
            },
            # Whitespace
            {
                # Every customised gap is two spaces wide, which puts the end
                # of the statement at column 36.
                "node": cst.With(
                    (
                        cst.WithItem(
                            cst.Call(cst.Name("context_mgr")),
                            cst.AsName(
                                cst.Name("ctx"),
                                whitespace_before_as=cst.SimpleWhitespace("  "),
                                whitespace_after_as=cst.SimpleWhitespace("  "),
                            ),
                        ),
                    ),
                    cst.SimpleStatementSuite((cst.Pass(),)),
                    whitespace_after_with=cst.SimpleWhitespace("  "),
                    whitespace_before_colon=cst.SimpleWhitespace("  "),
                ),
                "code": "with  context_mgr()  as  ctx  : pass\n",
                "parser": parse_statement,
                "expected_position": CodeRange((1, 0), (1, 36)),
            },
            # Weird spacing rules, that parse differently depending on whether
            # we are using a grammar that included parenthesized with statements.
            {
                "node": cst.With(
                    (
                        cst.WithItem(
                            cst.Call(
                                cst.Name("context_mgr"),
                                lpar=() if is_native() else (cst.LeftParen(),),
                                rpar=() if is_native() else (cst.RightParen(),),
                            )
                        ),
                    ),
                    cst.SimpleStatementSuite((cst.Pass(),)),
                    lpar=(cst.LeftParen() if is_native() else MaybeSentinel.DEFAULT),
                    rpar=(cst.RightParen() if is_native() else MaybeSentinel.DEFAULT),
                    whitespace_after_with=cst.SimpleWhitespace(""),
                ),
                "code": "with(context_mgr()): pass\n",
                "parser": parse_statement,
                "expected_position": CodeRange((1, 0), (1, 25)),
            },
            # Multi-line parenthesized with.
            {
                # The continuation line is indented by seven spaces so the
                # statement ends at column 21 on line 2.
                "node": cst.With(
                    (
                        cst.WithItem(
                            cst.Call(cst.Name("foo")),
                            comma=cst.Comma(
                                whitespace_after=cst.ParenthesizedWhitespace(
                                    first_line=cst.TrailingWhitespace(
                                        whitespace=cst.SimpleWhitespace(
                                            value="",
                                        ),
                                        comment=None,
                                        newline=cst.Newline(
                                            value=None,
                                        ),
                                    ),
                                    empty_lines=[],
                                    indent=True,
                                    last_line=cst.SimpleWhitespace(
                                        value="       ",
                                    ),
                                )
                            ),
                        ),
                        cst.WithItem(cst.Call(cst.Name("bar")), comma=cst.Comma()),
                    ),
                    cst.SimpleStatementSuite((cst.Pass(),)),
                    lpar=cst.LeftParen(whitespace_after=cst.SimpleWhitespace(" ")),
                    rpar=cst.RightParen(whitespace_before=cst.SimpleWhitespace(" ")),
                ),
                "code": ("with ( foo(),\n" "       bar(), ): pass\n"),  # noqa
                "parser": parse_statement if is_native() else None,
                "expected_position": CodeRange((1, 0), (2, 21)),
            },
        )
    )
    def test_valid(self, **kwargs: Any) -> None:
        # Each case dict is forwarded as keyword arguments to validate_node.
        self.validate_node(**kwargs)

    @data_provider(
        (
            {
                "get_node": lambda: cst.With(
                    (),
                    cst.IndentedBlock((cst.SimpleStatementLine((cst.Pass(),)),)),
                ),
                "expected_re": "A With statement must have at least one WithItem",
            },
            {
                "get_node": lambda: cst.With(
                    (
                        cst.WithItem(
                            cst.Call(cst.Name("foo")),
                            comma=cst.Comma(
                                whitespace_after=cst.SimpleWhitespace(" ")
                            ),
                        ),
                    ),
                    cst.IndentedBlock((cst.SimpleStatementLine((cst.Pass(),)),)),
                ),
                "expected_re": "The last WithItem in an unparenthesized With cannot "
                + "have a trailing comma.",
            },
            {
                "get_node": lambda: cst.With(
                    (cst.WithItem(cst.Call(cst.Name("context_mgr"))),),
                    cst.SimpleStatementSuite((cst.Pass(),)),
                    whitespace_after_with=cst.SimpleWhitespace(""),
                ),
                "expected_re": "Must have at least one space after with keyword",
            },
            {
                "get_node": lambda: cst.With(
                    (cst.WithItem(cst.Call(cst.Name("context_mgr"))),),
                    cst.SimpleStatementSuite((cst.Pass(),)),
                    whitespace_after_with=cst.SimpleWhitespace(""),
                    lpar=cst.LeftParen(),
                ),
                "expected_re": "Do not mix concrete LeftParen/RightParen with "
                + "MaybeSentinel",
            },
            {
                "get_node": lambda: cst.With(
                    (cst.WithItem(cst.Call(cst.Name("context_mgr"))),),
                    cst.SimpleStatementSuite((cst.Pass(),)),
                    whitespace_after_with=cst.SimpleWhitespace(""),
                    rpar=cst.RightParen(),
                ),
                "expected_re": "Do not mix concrete LeftParen/RightParen with "
                + "MaybeSentinel",
            },
        )
    )
    def test_invalid(self, **kwargs: Any) -> None:
        # Node construction is deferred behind `get_node` so the validation
        # error is raised inside assert_invalid.
        self.assert_invalid(**kwargs)

    @data_provider(
        (
            {
                "code": "with a, b: pass",
                "parser": parse_statement_as(python_version="3.1"),
                "expect_success": True,
            },
            {
                "code": "with a, b: pass",
                "parser": parse_statement_as(python_version="3.0"),
                "expect_success": False,
            },
        )
    )
    def test_versions(self, **kwargs: Any) -> None:
        # Cases that expect a parse failure are skipped under the native
        # parser.
        if is_native() and not kwargs.get("expect_success", True):
            self.skipTest("parse errors are disabled for native parser")
        self.assert_parses(**kwargs)

    def test_adding_parens(self) -> None:
        # A default ParenthesizedWhitespace after the first comma breaks the
        # line; the second line therefore starts at column 0.
        node = cst.With(
            (
                cst.WithItem(
                    cst.Call(cst.Name("foo")),
                    comma=cst.Comma(
                        whitespace_after=cst.ParenthesizedWhitespace(),
                    ),
                ),
                cst.WithItem(cst.Call(cst.Name("bar")), comma=cst.Comma()),
            ),
            cst.SimpleStatementSuite((cst.Pass(),)),
            lpar=cst.LeftParen(whitespace_after=cst.SimpleWhitespace(" ")),
            rpar=cst.RightParen(whitespace_before=cst.SimpleWhitespace(" ")),
        )
        module = cst.Module([])
        self.assertEqual(
            module.code_for_node(node),
            ("with ( foo(),\n" "bar(), ): pass\n"),  # noqa
        )