def basic_parenthesize(
    node: libcst.CSTNode,
    whitespace: Optional[libcst.BaseParenthesizableWhitespace] = None,
) -> libcst.CSTNode:
    """Wrap ``node`` in exactly one pair of parentheses when it can carry them.

    Args:
        node: The node to parenthesize. Nodes without an ``lpar`` attribute
            are returned untouched.
        whitespace: Optional whitespace placed directly after the opening
            parenthesis. Ignored when falsy.

    Returns:
        ``node`` itself when it has no ``lpar`` attribute, otherwise the result
        of ``node.with_changes`` with ``lpar``/``rpar`` replaced by one
        parenthesis each.
    """
    if not hasattr(node, "lpar"):
        return node

    # Only the opening parenthesis depends on the requested whitespace.
    opening = (
        libcst.LeftParen(whitespace_after=whitespace)
        if whitespace
        else libcst.LeftParen()
    )
    return node.with_changes(lpar=[opening], rpar=[libcst.RightParen()])
def _multiline_rpar(self) -> cst.RightParen:
    """Return multiline `cst.RightParen`.

    The whitespace before the parenthesis comes from
    ``ImportTransformer._multiline_parenthesized_whitespace`` called with this
    instance's ``_indentation``.
    """
    leading_whitespace = ImportTransformer._multiline_parenthesized_whitespace(
        self._indentation
    )
    return cst.RightParen(whitespace_before=leading_whitespace)
def leave_Call(
    self, original_node: cst.Call, updated_node: cst.Call
) -> cst.BaseExpression:
    """Await calls to `method` of TelegraphApi.

    The dotted name of the callee is read from ``original_node.func`` and the
    call is wrapped in a parenthesized ``Await`` when that name ends in
    ``self.session``, ``self._telegraph.method`` or
    ``self._telegraph.upload_file``.

    Args:
        original_node: The call node as it appeared before child transforms.
        updated_node: The call node after child transforms.

    Returns:
        ``updated_node`` unchanged when the call should not be awaited,
        otherwise an ``Await`` node wrapping ``updated_node``.
    """
    # Walk the attribute chain from the outside in, so the collected names are
    # reversed: ``self._telegraph.method`` -> ["method", "_telegraph", "self"].
    path = []
    callee = original_node.func
    while isinstance(callee, (cst.Attribute, cst.Name)):
        if isinstance(callee, cst.Attribute):
            path.append(callee.attr.value)
        else:
            path.append(callee.value)
        # For a Name this yields a plain string, which ends the loop.
        callee = callee.value

    # await the call if it's API class method
    awaited_suffixes = (
        ["session", "self"],
        ["method", "_telegraph", "self"],
        ["upload_file", "_telegraph", "self"],
    )
    should_await = any(
        path[-len(suffix):] == suffix for suffix in awaited_suffixes
    )
    if not should_await:
        return updated_node

    self.fn_should_async = self.stack  # mark current fn as async on leave

    # await the call
    return Await(
        updated_node,
        lpar=[cst.LeftParen()],
        rpar=[cst.RightParen()],
    )
def get_rpar(rpar: Optional[cst.RightParen],
             location: CodeRange) -> Optional[cst.RightParen]:
    """Pick the closing parenthesis to use for a node at ``location``.

    Args:
        rpar: The existing closing parenthesis, or ``None``/falsy if absent.
        location: Code range whose ``start.line`` and ``end.line`` are compared.

    Returns:
        ``rpar`` unchanged when it is falsy or the range stays on one line;
        otherwise a new ``cst.RightParen`` preceded by a default
        ``cst.ParenthesizedWhitespace``.
    """
    spans_several_lines = bool(rpar) and (
        location.start.line != location.end.line)
    if spans_several_lines:
        return cst.RightParen(
            whitespace_before=cst.ParenthesizedWhitespace())
    return rpar
def leave_Assign(self, original_node: libcst.Assign,
                 updated_node: libcst.Assign) -> libcst.Assign:
    """Parenthesize the right-hand side of an assignment when possible.

    Args:
        original_node: The assignment before child transforms (unused).
        updated_node: The assignment after child transforms.

    Returns:
        ``updated_node`` unchanged when its value has no ``lpar`` attribute,
        otherwise a copy whose value carries exactly one pair of parentheses.
    """
    rhs = updated_node.value
    if not hasattr(rhs, "lpar"):
        return updated_node

    wrapped_rhs = rhs.with_changes(
        lpar=[libcst.LeftParen()],
        rpar=[libcst.RightParen()],
    )
    return updated_node.with_changes(value=wrapped_rhs)
def parenthesize_using_parent(node: T, parent: libcst.CSTNode) -> T:
    """Add parentheses to the given node if needed.

    Whether parentheses are required is decided by
    ``_needs_parentheses_parent`` from the node and its parent.

    Args:
        node: The node that may need to be wrapped.
        parent: The parent of ``node``.

    Returns:
        ``node`` itself when no parentheses are needed, otherwise a copy with
        one ``LeftParen``/``RightParen`` pair.
    """
    if not _needs_parentheses_parent(node, parent):
        return node
    single_pair = {
        "lpar": [libcst.LeftParen()],
        "rpar": [libcst.RightParen()],
    }
    return node.with_changes(**single_pair)
def parenthesize_using_previous(node: T, previous: libcst.CSTNode) -> T:
    """Add parentheses to the given node if needed.

    Whether parentheses are required is decided by
    ``_needs_parentheses_previous`` from the node and the node it replaces.

    Note: this function is not as precise as `parenthesize_using_parent`

    Args:
        node: The replacement node that may need to be wrapped.
        previous: The node that ``node`` is replacing.

    Returns:
        ``node`` itself when no parentheses are needed, otherwise a copy with
        one ``LeftParen``/``RightParen`` pair.
    """
    if not _needs_parentheses_previous(node, previous):
        return node
    single_pair = {
        "lpar": [libcst.LeftParen()],
        "rpar": [libcst.RightParen()],
    }
    return node.with_changes(**single_pair)
def leave_Yield(self, original_node, updated_node) -> cst.BaseExpression:
    """Replace a ``yield`` expression with ``<ret_var>.append(<value>)``.

    Args:
        original_node: The yield node before child transforms (unused).
        updated_node: The yield node after child transforms.

    Returns:
        A call expression, parsed from ``f'{self.ret_var}.append()'``, whose
        single argument is the yielded value.
    """
    append = parse_expr(f'{self.ret_var}.append()')
    yield_val = updated_node.value

    if yield_val is None:
        # A bare ``yield`` produces None; without this the call below would be
        # built around ``cst.Arg(None)``, which is not a valid tree.
        yield_val = cst.Name("None")
    elif m.matches(yield_val, m.Tuple()):
        # If original expr was "yield a, b" then yield_val compiles to
        # "a, b" (i.e. no parens) which errors if directly inserted into
        # foo.append(a, b). So we ensure that the tuple has parentheses.
        yield_val = yield_val.with_changes(lpar=[cst.LeftParen()],
                                           rpar=[cst.RightParen()])

    # NOTE(review): ``yield from x`` (value is a ``cst.From``) is still passed
    # through as-is; confirm whether callers ever feed such nodes here.
    return append.with_changes(args=[cst.Arg(yield_val)])
def visit_Attribute(self, node: cst.Attribute) -> None:
    """Report attributes that are not wrapped in parentheses.

    The rule is skipped when the rule configuration stored under this class's
    name is a dict whose ``"disabled"`` entry is truthy.

    Args:
        node: The attribute node being visited.
    """
    own_config = self.context.config.rule_config.get(
        self.__class__.__name__, {})
    if isinstance(own_config, dict) and own_config.get("disabled", False):
        return

    # Already parenthesized: nothing to report.
    if node.lpar:
        return

    self.report(
        node,
        "All attributes should be parenthesized.",
        replacement=node.with_changes(lpar=[cst.LeftParen()],
                                      rpar=[cst.RightParen()]),
    )
def test_simple_expression(self) -> None:
    """Template placeholders are filled and a parenthesized operand is kept."""
    # The third placeholder is a parenthesized multiplication.
    parenthesized_product = cst.BinaryOperation(
        lpar=(cst.LeftParen(),),
        left=cst.Name("three"),
        operator=cst.Multiply(),
        right=cst.Name("four"),
        rpar=(cst.RightParen(),),
    )
    substitutions = {
        "a": cst.Name("one"),
        "b": cst.Name("two"),
        "c": parenthesized_product,
    }
    expression = parse_template_expression("{a} + {b} + {c}", **substitutions)

    rendered = self.code(expression)
    self.assertEqual(rendered, "one + two + (three * four)")
def visit_ConcatenatedString(self, node: cst.ConcatenatedString) -> None:
    """Report an implicit string concatenation with an explicit ``+`` rewrite.

    Nested ``ConcatenatedString`` nodes are reported once, from the outermost
    node, with a replacement made of nested ``cst.BinaryOperation`` nodes.

    Args:
        node: The concatenated-string node being visited.
    """
    # Skip if our immediate parent is also a ConcatenatedString, since our parent
    # should've already reported this violation.
    if isinstance(self.context.node_stack[-2], cst.ConcatenatedString):
        return

    # Flatten the right-leaning chain, outermost first. Only ``right`` can be a
    # ConcatenatedString; ``left`` never is.
    chain: List[cst.ConcatenatedString] = []
    cursor = node
    while isinstance(cursor, cst.ConcatenatedString):
        chain.append(cursor)
        cursor = cursor.right

    # Rebuild from the innermost pair outwards.
    replacement = chain[-1].right
    for piece in chain[::-1]:
        plus = cst.Add(
            whitespace_before=piece.whitespace_between,
            whitespace_after=cst.SimpleWhitespace(" "),
        )
        replacement = cst.BinaryOperation(
            left=piece.left,
            operator=plus,
            right=replacement,
            lpar=piece.lpar,
            rpar=piece.rpar,
        )

    # A binary operation has a lower priority in the order-of-operations than an
    # implicitly concatenated string, so the replacement must be parenthesized
    # to keep the change safe.
    if not replacement.lpar:
        # Formatting may end up off here (e.g. children would need more
        # indentation); a later formatter run such as black is expected to
        # tidy it up.
        replacement = replacement.with_changes(
            lpar=[cst.LeftParen()],
            rpar=[cst.RightParen()],
        )

    self.report(node, replacement=replacement)
def test_adding_parens(self) -> None:
    """A ``With`` given explicit parens renders its items inside them."""
    first_item = cst.WithItem(
        cst.Call(cst.Name("foo")),
        comma=cst.Comma(whitespace_after=cst.ParenthesizedWhitespace()),
    )
    second_item = cst.WithItem(cst.Call(cst.Name("bar")), comma=cst.Comma())
    node = cst.With(
        (first_item, second_item),
        cst.SimpleStatementSuite((cst.Pass(), )),
        lpar=cst.LeftParen(whitespace_after=cst.SimpleWhitespace(" ")),
        rpar=cst.RightParen(whitespace_before=cst.SimpleWhitespace(" ")),
    )

    rendered = cst.Module([]).code_for_node(node)
    expected = ("with ( foo(),\n"
                "bar(), ): pass\n")  # noqa
    self.assertEqual(rendered, expected)
class YieldConstructionTest(CSTNodeTest):
    """Construction and validation cases for ``cst.Yield`` and ``cst.From``.

    ``test_valid`` passes each (node, code[, position]) row to
    ``self.validate_node``; ``test_invalid`` passes each (factory, regex) row
    to ``self.assert_invalid``.
    """

    # NOTE(review): some CodeRange end columns below (e.g. column 20 for
    # "( yield from bla() )") are larger than the literal shown here; the
    # whitespace inside these literals may have been collapsed -- confirm
    # against the upstream file.
    @data_provider((
        # Simple yield
        (cst.Yield(), "yield"),
        # yield expression
        (cst.Yield(cst.Name("a")), "yield a"),
        # yield from expression
        (cst.Yield(cst.From(cst.Call(cst.Name("a")))), "yield from a()"),
        # Parenthesizing tests
        (
            cst.Yield(
                lpar=(cst.LeftParen(), ),
                value=cst.Integer("5"),
                rpar=(cst.RightParen(), ),
            ),
            "(yield 5)",
        ),
        # Whitespace oddities tests
        (
            cst.Yield(
                cst.Name("a", lpar=(cst.LeftParen(), ), rpar=(cst.RightParen(), )),
                whitespace_after_yield=cst.SimpleWhitespace(""),
            ),
            "yield(a)",
            CodeRange((1, 0), (1, 8)),
        ),
        (
            cst.Yield(
                cst.From(
                    cst.Call(
                        cst.Name("a"),
                        lpar=(cst.LeftParen(), ),
                        rpar=(cst.RightParen(), ),
                    ),
                    whitespace_after_from=cst.SimpleWhitespace(""),
                )),
            "yield from(a())",
        ),
        # Whitespace rendering/parsing tests
        (
            cst.Yield(
                lpar=(cst.LeftParen(
                    whitespace_after=cst.SimpleWhitespace(" ")), ),
                value=cst.Integer("5"),
                whitespace_after_yield=cst.SimpleWhitespace(" "),
                rpar=(cst.RightParen(
                    whitespace_before=cst.SimpleWhitespace(" ")), ),
            ),
            "( yield 5 )",
        ),
        (
            cst.Yield(
                lpar=(cst.LeftParen(
                    whitespace_after=cst.SimpleWhitespace(" ")), ),
                value=cst.From(
                    cst.Call(cst.Name("bla")),
                    whitespace_after_from=cst.SimpleWhitespace(" "),
                ),
                whitespace_after_yield=cst.SimpleWhitespace(" "),
                rpar=(cst.RightParen(
                    whitespace_before=cst.SimpleWhitespace(" ")), ),
            ),
            "( yield from bla() )",
            CodeRange((1, 2), (1, 20)),
        ),
        # From expression position tests
        (
            cst.From(cst.Integer("5"),
                     whitespace_after_from=cst.SimpleWhitespace(" ")),
            "from 5",
            CodeRange((1, 0), (1, 6)),
        ),
    ))
    def test_valid(self,
                   node: cst.CSTNode,
                   code: str,
                   position: Optional[CodeRange] = None) -> None:
        # Each row supplies the node, its expected code and, optionally, the
        # expected position.
        self.validate_node(node, code, expected_position=position)

    @data_provider((
        # Paren validation
        (
            lambda: cst.Yield(lpar=(cst.LeftParen(), )),
            "left paren without right paren",
        ),
        (
            lambda: cst.Yield(rpar=(cst.RightParen(), )),
            "right paren without left paren",
        ),
        # Make sure we have adequate space after yield
        (
            lambda: cst.Yield(cst.Name("a"),
                              whitespace_after_yield=cst.SimpleWhitespace("")),
            "Must have at least one space after 'yield' keyword",
        ),
        (
            lambda: cst.Yield(
                cst.From(cst.Call(cst.Name("a"))),
                whitespace_after_yield=cst.SimpleWhitespace(""),
            ),
            "Must have at least one space after 'yield' keyword",
        ),
        # Make sure we have adequate space after from
        (
            lambda: cst.Yield(
                cst.From(
                    cst.Call(cst.Name("a")),
                    whitespace_after_from=cst.SimpleWhitespace(""),
                )),
            "Must have at least one space after 'from' keyword",
        ),
    ))
    def test_invalid(self, get_node: Callable[[], cst.CSTNode],
                     expected_re: str) -> None:
        # Node construction is deferred through a lambda so the failure happens
        # inside assert_invalid rather than while the table is built.
        self.assert_invalid(get_node, expected_re)
class YieldParsingTest(CSTNodeTest):
    """Parsing cases for ``cst.Yield``.

    ``test_valid`` hands ``self.validate_node`` a parser that parses a
    statement and extracts the value of its first ``cst.Expr``;
    ``test_versions`` forwards each row to ``self.assert_parses``.
    """

    @data_provider((
        # Simple yield
        (cst.Yield(), "yield"),
        # yield expression
        (
            cst.Yield(cst.Name("a"),
                      whitespace_after_yield=cst.SimpleWhitespace(" ")),
            "yield a",
        ),
        # yield from expression
        (
            cst.Yield(
                cst.From(
                    cst.Call(cst.Name("a")),
                    whitespace_after_from=cst.SimpleWhitespace(" "),
                ),
                whitespace_after_yield=cst.SimpleWhitespace(" "),
            ),
            "yield from a()",
        ),
        # Parenthesizing tests
        (
            cst.Yield(
                lpar=(cst.LeftParen(), ),
                whitespace_after_yield=cst.SimpleWhitespace(" "),
                value=cst.Integer("5"),
                rpar=(cst.RightParen(), ),
            ),
            "(yield 5)",
        ),
        # Whitespace oddities tests
        (
            cst.Yield(
                cst.Name("a", lpar=(cst.LeftParen(), ), rpar=(cst.RightParen(), )),
                whitespace_after_yield=cst.SimpleWhitespace(""),
            ),
            "yield(a)",
        ),
        (
            cst.Yield(
                cst.From(
                    cst.Call(
                        cst.Name("a"),
                        lpar=(cst.LeftParen(), ),
                        rpar=(cst.RightParen(), ),
                    ),
                    whitespace_after_from=cst.SimpleWhitespace(""),
                ),
                whitespace_after_yield=cst.SimpleWhitespace(" "),
            ),
            "yield from(a())",
        ),
        # Whitespace rendering/parsing tests
        (
            cst.Yield(
                lpar=(cst.LeftParen(
                    whitespace_after=cst.SimpleWhitespace(" ")), ),
                value=cst.Integer("5"),
                whitespace_after_yield=cst.SimpleWhitespace(" "),
                rpar=(cst.RightParen(
                    whitespace_before=cst.SimpleWhitespace(" ")), ),
            ),
            "( yield 5 )",
        ),
        (
            cst.Yield(
                lpar=(cst.LeftParen(
                    whitespace_after=cst.SimpleWhitespace(" ")), ),
                value=cst.From(
                    cst.Call(cst.Name("bla")),
                    whitespace_after_from=cst.SimpleWhitespace(" "),
                ),
                whitespace_after_yield=cst.SimpleWhitespace(" "),
                rpar=(cst.RightParen(
                    whitespace_before=cst.SimpleWhitespace(" ")), ),
            ),
            "( yield from bla() )",
        ),
    ))
    def test_valid(self,
                   node: cst.CSTNode,
                   code: str,
                   position: Optional[CodeRange] = None) -> None:
        # ``position`` is accepted but not forwarded to validate_node here.
        self.validate_node(
            node,
            code,
            # Parse a full statement line, then dig out the expression carried
            # by its first (Expr) statement.
            lambda code: ensure_type(
                ensure_type(parse_statement(code), cst.SimpleStatementLine).
                body[0],
                cst.Expr,
            ).value,
        )

    @data_provider((
        {
            "code": "yield from x",
            "parser": parse_statement_as(python_version="3.3"),
            "expect_success": True,
        },
        {
            "code": "yield from x",
            "parser": parse_statement_as(python_version="3.1"),
            "expect_success": False,
        },
    ))
    def test_versions(self, **kwargs: Any) -> None:
        # Same source parsed under two configured python versions, with
        # opposite expected outcomes.
        self.assert_parses(**kwargs)
class RaiseConstructionTest(CSTNodeTest):
    """Construction and validation cases for ``cst.Raise``.

    Rows are keyword dictionaries forwarded verbatim to
    ``self.validate_node`` (valid cases) or ``self.assert_invalid`` (invalid
    cases).
    """

    # NOTE(review): the last valid row expects an end column of 31 although
    # the literal shown is 28 characters long; whitespace inside these
    # literals may have been collapsed -- confirm against the upstream file.
    @data_provider((
        # Simple raise
        {
            "node": cst.Raise(),
            "code": "raise"
        },
        # Raise exception
        {
            "node": cst.Raise(cst.Call(cst.Name("Exception"))),
            "code": "raise Exception()",
            "expected_position": CodeRange((1, 0), (1, 17)),
        },
        # Raise exception from cause
        {
            "node": cst.Raise(cst.Call(cst.Name("Exception")),
                              cst.From(cst.Name("cause"))),
            "code": "raise Exception() from cause",
        },
        # Whitespace oddities test
        {
            "node": cst.Raise(
                cst.Call(
                    cst.Name("Exception"),
                    lpar=(cst.LeftParen(), ),
                    rpar=(cst.RightParen(), ),
                ),
                cst.From(
                    cst.Name("cause", lpar=(cst.LeftParen(), ), rpar=(cst.RightParen(), )),
                    whitespace_before_from=cst.SimpleWhitespace(""),
                    whitespace_after_from=cst.SimpleWhitespace(""),
                ),
                whitespace_after_raise=cst.SimpleWhitespace(""),
            ),
            "code": "raise(Exception())from(cause)",
            "expected_position": CodeRange((1, 0), (1, 29)),
        },
        {
            "node": cst.Raise(
                cst.Call(cst.Name("Exception")),
                cst.From(
                    cst.Name("cause"),
                    whitespace_before_from=cst.SimpleWhitespace(""),
                ),
            ),
            "code": "raise Exception()from cause",
            "expected_position": CodeRange((1, 0), (1, 27)),
        },
        # Whitespace rendering test
        {
            "node": cst.Raise(
                exc=cst.Call(cst.Name("Exception")),
                cause=cst.From(
                    cst.Name("cause"),
                    whitespace_before_from=cst.SimpleWhitespace(" "),
                    whitespace_after_from=cst.SimpleWhitespace(" "),
                ),
                whitespace_after_raise=cst.SimpleWhitespace(" "),
            ),
            "code": "raise Exception() from cause",
            "expected_position": CodeRange((1, 0), (1, 31)),
        },
    ))
    def test_valid(self, **kwargs: Any) -> None:
        self.validate_node(**kwargs)

    @data_provider((
        # Validate construction
        {
            "get_node": lambda: cst.Raise(cause=cst.From(cst.Name("cause"))),
            "expected_re": "Must have an 'exc' when specifying 'clause'. on Raise",
        },
        # Validate whitespace handling
        {
            "get_node": lambda: cst.Raise(
                cst.Call(cst.Name("Exception")),
                whitespace_after_raise=cst.SimpleWhitespace(""),
            ),
            "expected_re": "Must have at least one space after 'raise'",
        },
        {
            "get_node": lambda: cst.Raise(
                cst.Name("exc"),
                cst.From(
                    cst.Name("cause"),
                    whitespace_before_from=cst.SimpleWhitespace(""),
                ),
            ),
            "expected_re": "Must have at least one space before 'from'",
        },
        {
            "get_node": lambda: cst.Raise(
                cst.Name("exc"),
                cst.From(
                    cst.Name("cause"),
                    whitespace_after_from=cst.SimpleWhitespace(""),
                ),
            ),
            "expected_re": "Must have at least one space after 'from'",
        },
    ))
    def test_invalid(self, **kwargs: Any) -> None:
        # Node construction is deferred through ``get_node`` so the failure
        # happens inside assert_invalid rather than while the table is built.
        self.assert_invalid(**kwargs)
class RaiseParsingTest(CSTNodeTest):
    """Parsing cases for ``cst.Raise``.

    Each row is forwarded to ``self.validate_node`` together with a parser
    that parses a statement line and returns its first small statement.
    """

    @data_provider((
        # Simple raise
        {
            "node": cst.Raise(),
            "code": "raise"
        },
        # Raise exception
        {
            "node": cst.Raise(
                cst.Call(cst.Name("Exception")),
                whitespace_after_raise=cst.SimpleWhitespace(" "),
            ),
            "code": "raise Exception()",
        },
        # Raise exception from cause
        {
            "node": cst.Raise(
                cst.Call(cst.Name("Exception")),
                cst.From(
                    cst.Name("cause"),
                    whitespace_before_from=cst.SimpleWhitespace(" "),
                    whitespace_after_from=cst.SimpleWhitespace(" "),
                ),
                whitespace_after_raise=cst.SimpleWhitespace(" "),
            ),
            "code": "raise Exception() from cause",
        },
        # Whitespace oddities test
        {
            "node": cst.Raise(
                cst.Call(
                    cst.Name("Exception"),
                    lpar=(cst.LeftParen(), ),
                    rpar=(cst.RightParen(), ),
                ),
                cst.From(
                    cst.Name("cause", lpar=(cst.LeftParen(), ), rpar=(cst.RightParen(), )),
                    whitespace_before_from=cst.SimpleWhitespace(""),
                    whitespace_after_from=cst.SimpleWhitespace(""),
                ),
                whitespace_after_raise=cst.SimpleWhitespace(""),
            ),
            "code": "raise(Exception())from(cause)",
        },
        {
            "node": cst.Raise(
                cst.Call(cst.Name("Exception")),
                cst.From(
                    cst.Name("cause"),
                    whitespace_before_from=cst.SimpleWhitespace(""),
                    whitespace_after_from=cst.SimpleWhitespace(" "),
                ),
                whitespace_after_raise=cst.SimpleWhitespace(" "),
            ),
            "code": "raise Exception()from cause",
        },
        # Whitespace rendering test
        {
            "node": cst.Raise(
                exc=cst.Call(cst.Name("Exception")),
                cause=cst.From(
                    cst.Name("cause"),
                    whitespace_before_from=cst.SimpleWhitespace(" "),
                    whitespace_after_from=cst.SimpleWhitespace(" "),
                ),
                whitespace_after_raise=cst.SimpleWhitespace(" "),
            ),
            "code": "raise Exception() from cause",
        },
    ))
    def test_valid(self, **kwargs: Any) -> None:
        self.validate_node(
            # Parse a full statement line and compare against its first
            # small statement.
            parser=lambda code: ensure_type(parse_statement(code), cst.
                                            SimpleStatementLine).body[0],
            **kwargs,
        )
class ListTest(CSTNodeTest):
    """Construction, validation and version cases for ``cst.Set``.

    NOTE(review): despite the class name, every node built here is a
    ``cst.Set`` rendered with curly braces.
    """

    # A lot of Element/StarredElement tests are provided by the tests for Tuple, so
    # we don't need to duplicate them here.
    @data_provider([
        # one-element set, sentinel comma value
        {
            "node": cst.Set([cst.Element(cst.Name("single_element"))]),
            "code": "{single_element}",
            "parser": parse_expression,
        },
        # custom whitespace between braces
        {
            "node": cst.Set(
                [cst.Element(cst.Name("single_element"))],
                lbrace=cst.LeftCurlyBrace(
                    whitespace_after=cst.SimpleWhitespace("\t")),
                rbrace=cst.RightCurlyBrace(
                    whitespace_before=cst.SimpleWhitespace(" ")),
            ),
            "code": "{\tsingle_element }",
            "parser": parse_expression,
        },
        # two-element set, sentinel comma value
        {
            "node": cst.Set(
                [cst.Element(cst.Name("one")), cst.Element(cst.Name("two"))]),
            "code": "{one, two}",
            "parser": None,
        },
        # with parenthesis
        {
            "node": cst.Set(
                [cst.Element(cst.Name("one"))],
                lpar=[cst.LeftParen()],
                rpar=[cst.RightParen()],
            ),
            "code": "({one})",
            "parser": None,
        },
        # starred element
        {
            "node": cst.Set([
                cst.StarredElement(cst.Name("one")),
                cst.StarredElement(cst.Name("two")),
            ]),
            "code": "{*one, *two}",
            "parser": None,
        },
        # missing spaces around set, always okay
        {
            "node": cst.GeneratorExp(
                cst.Name("elt"),
                cst.CompFor(
                    target=cst.Name("elt"),
                    iter=cst.Set([
                        cst.Element(
                            cst.Name("one"),
                            cst.Comma(
                                whitespace_after=cst.SimpleWhitespace(" ")),
                        ),
                        cst.Element(cst.Name("two")),
                    ]),
                    ifs=[
                        cst.CompIf(
                            cst.Name("test"),
                            whitespace_before=cst.SimpleWhitespace(""),
                        )
                    ],
                    whitespace_after_in=cst.SimpleWhitespace(""),
                ),
            ),
            "code": "(elt for elt in{one, two}if test)",
            "parser": parse_expression,
        },
    ])
    def test_valid(self, **kwargs: Any) -> None:
        self.validate_node(**kwargs)

    @data_provider((
        (
            lambda: cst.Set(
                [cst.Element(cst.Name("mismatched"))],
                lpar=[cst.LeftParen(), cst.LeftParen()],
                rpar=[cst.RightParen()],
            ),
            "unbalanced parens",
        ),
        (lambda: cst.Set([]), "at least one element"),
    ))
    def test_invalid(self, get_node: Callable[[], cst.CSTNode],
                     expected_re: str) -> None:
        self.assert_invalid(get_node, expected_re)

    @data_provider((
        {
            "code": "{*x, 2}",
            "parser": parse_expression_as(python_version="3.5"),
            "expect_success": True,
        },
        {
            "code": "{*x, 2}",
            "parser": parse_expression_as(python_version="3.3"),
            "expect_success": False,
        },
    ))
    def test_versions(self, **kwargs: Any) -> None:
        # Expected-failure rows are skipped when is_native() is true.
        if is_native() and not kwargs.get("expect_success", True):
            self.skipTest("parse errors are disabled for native parser")
        self.assert_parses(**kwargs)
class WhileTest(CSTNodeTest):
    """Construction, parsing and validation cases for ``cst.While``.

    Rows are keyword dictionaries forwarded verbatim to
    ``self.validate_node`` (valid cases) or ``self.assert_invalid`` (invalid
    cases).
    """

    # NOTE(review): several expected positions (e.g. a start column of 4 for
    # the indented cases, an end column of 21 for the whitespace case) do not
    # match the literals shown here; whitespace inside these literals may have
    # been collapsed -- confirm against the upstream file.
    @data_provider((
        # Simple while block
        # pyre-fixme[6]: Incompatible parameter type
        {
            "node": cst.While(cst.Call(cst.Name("iter")),
                              cst.SimpleStatementSuite((cst.Pass(), ))),
            "code": "while iter(): pass\n",
            "parser": parse_statement,
        },
        # While block with else
        {
            "node": cst.While(
                cst.Call(cst.Name("iter")),
                cst.SimpleStatementSuite((cst.Pass(), )),
                cst.Else(cst.SimpleStatementSuite((cst.Pass(), ))),
            ),
            "code": "while iter(): pass\nelse: pass\n",
            "parser": parse_statement,
        },
        # indentation
        {
            "node": DummyIndentedBlock(
                " ",
                cst.While(
                    cst.Call(cst.Name("iter")),
                    cst.SimpleStatementSuite((cst.Pass(), )),
                ),
            ),
            "code": " while iter(): pass\n",
            "parser": None,
            "expected_position": CodeRange((1, 4), (1, 22)),
        },
        # while an indented body
        {
            "node": DummyIndentedBlock(
                " ",
                cst.While(
                    cst.Call(cst.Name("iter")),
                    cst.IndentedBlock((cst.SimpleStatementLine(
                        (cst.Pass(), )), )),
                ),
            ),
            "code": " while iter():\n pass\n",
            "parser": None,
            "expected_position": CodeRange((1, 4), (2, 12)),
        },
        # leading_lines
        {
            "node": cst.While(
                cst.Call(cst.Name("iter")),
                cst.IndentedBlock((cst.SimpleStatementLine((cst.Pass(), )), )),
                leading_lines=(cst.EmptyLine(
                    comment=cst.Comment("# leading comment")), ),
            ),
            "code": "# leading comment\nwhile iter():\n pass\n",
            "parser": parse_statement,
            "expected_position": CodeRange((2, 0), (3, 8)),
        },
        {
            "node": cst.While(
                cst.Call(cst.Name("iter")),
                cst.IndentedBlock((cst.SimpleStatementLine((cst.Pass(), )), )),
                cst.Else(
                    cst.IndentedBlock((cst.SimpleStatementLine(
                        (cst.Pass(), )), )),
                    leading_lines=(cst.EmptyLine(
                        comment=cst.Comment("# else comment")), ),
                ),
                leading_lines=(cst.EmptyLine(
                    comment=cst.Comment("# leading comment")), ),
            ),
            "code": "# leading comment\nwhile iter():\n pass\n# else comment\nelse:\n pass\n",
            "parser": None,
            "expected_position": CodeRange((2, 0), (6, 8)),
        },
        # Weird spacing rules
        {
            "node": cst.While(
                cst.Call(
                    cst.Name("iter"),
                    lpar=(cst.LeftParen(), ),
                    rpar=(cst.RightParen(), ),
                ),
                cst.SimpleStatementSuite((cst.Pass(), )),
                whitespace_after_while=cst.SimpleWhitespace(""),
            ),
            "code": "while(iter()): pass\n",
            "parser": parse_statement,
            "expected_position": CodeRange((1, 0), (1, 19)),
        },
        # Whitespace
        {
            "node": cst.While(
                cst.Call(cst.Name("iter")),
                cst.SimpleStatementSuite((cst.Pass(), )),
                whitespace_after_while=cst.SimpleWhitespace(" "),
                whitespace_before_colon=cst.SimpleWhitespace(" "),
            ),
            "code": "while iter() : pass\n",
            "parser": parse_statement,
            "expected_position": CodeRange((1, 0), (1, 21)),
        },
    ))
    def test_valid(self, **kwargs: Any) -> None:
        self.validate_node(**kwargs)

    @data_provider(({
        "get_node": lambda: cst.While(
            cst.Call(cst.Name("iter")),
            cst.SimpleStatementSuite((cst.Pass(), )),
            whitespace_after_while=cst.SimpleWhitespace(""),
        ),
        "expected_re": "Must have at least one space after 'while' keyword",
    }, ))
    def test_invalid(self, **kwargs: Any) -> None:
        # Node construction is deferred through ``get_node`` so the failure
        # happens inside assert_invalid rather than while the table is built.
        self.assert_invalid(**kwargs)
class SubscriptTest(CSTNodeTest):
    """Construction, parsing and validation cases for ``cst.Subscript``.

    Each valid row is (node, code, check_parsing[, position]). When
    ``check_parsing`` is true the row is also validated through
    ``parse_expression``; otherwise only the node and code (and optional
    position) are passed to ``self.validate_node``.
    """

    @data_provider((
        # Simple subscript expression
        (
            cst.Subscript(
                cst.Name("foo"),
                (cst.SubscriptElement(cst.Index(cst.Integer("5"))), ),
            ),
            "foo[5]",
            True,
        ),
        # Test creation of subscript with slice/extslice.
        (
            cst.Subscript(
                cst.Name("foo"),
                (cst.SubscriptElement(
                    cst.Slice(
                        lower=cst.Integer("1"),
                        upper=cst.Integer("2"),
                        step=cst.Integer("3"),
                    )), ),
            ),
            "foo[1:2:3]",
            False,
        ),
        (
            cst.Subscript(
                cst.Name("foo"),
                (
                    cst.SubscriptElement(
                        cst.Slice(
                            lower=cst.Integer("1"),
                            upper=cst.Integer("2"),
                            step=cst.Integer("3"),
                        )),
                    cst.SubscriptElement(cst.Index(cst.Integer("5"))),
                ),
            ),
            "foo[1:2:3, 5]",
            False,
            CodeRange((1, 0), (1, 13)),
        ),
        # Test parsing of subscript with slice/extslice.
        (
            cst.Subscript(
                cst.Name("foo"),
                (cst.SubscriptElement(
                    cst.Slice(
                        lower=cst.Integer("1"),
                        first_colon=cst.Colon(),
                        upper=cst.Integer("2"),
                        second_colon=cst.Colon(),
                        step=cst.Integer("3"),
                    )), ),
            ),
            "foo[1:2:3]",
            True,
        ),
        (
            cst.Subscript(
                cst.Name("foo"),
                (
                    cst.SubscriptElement(
                        cst.Slice(
                            lower=cst.Integer("1"),
                            first_colon=cst.Colon(),
                            upper=cst.Integer("2"),
                            second_colon=cst.Colon(),
                            step=cst.Integer("3"),
                        ),
                        comma=cst.Comma(),
                    ),
                    cst.SubscriptElement(cst.Index(cst.Integer("5"))),
                ),
            ),
            "foo[1:2:3,5]",
            True,
        ),
        # Some more wild slice creations
        (
            cst.Subscript(
                cst.Name("foo"),
                (cst.SubscriptElement(
                    cst.Slice(lower=cst.Integer("1"), upper=cst.Integer("2"))), ),
            ),
            "foo[1:2]",
            True,
        ),
        (
            cst.Subscript(
                cst.Name("foo"),
                (cst.SubscriptElement(
                    cst.Slice(lower=cst.Integer("1"), upper=None)), ),
            ),
            "foo[1:]",
            True,
        ),
        (
            cst.Subscript(
                cst.Name("foo"),
                (cst.SubscriptElement(
                    cst.Slice(lower=None, upper=cst.Integer("2"))), ),
            ),
            "foo[:2]",
            True,
        ),
        (
            cst.Subscript(
                cst.Name("foo"),
                (cst.SubscriptElement(
                    cst.Slice(
                        lower=cst.Integer("1"),
                        upper=None,
                        step=cst.Integer("3"),
                    )), ),
            ),
            "foo[1::3]",
            False,
        ),
        (
            cst.Subscript(
                cst.Name("foo"),
                (cst.SubscriptElement(
                    cst.Slice(lower=None, upper=None, step=cst.Integer("3"))), ),
            ),
            "foo[::3]",
            False,
            CodeRange((1, 0), (1, 8)),
        ),
        # Some more wild slice parsings
        (
            cst.Subscript(
                cst.Name("foo"),
                (cst.SubscriptElement(
                    cst.Slice(lower=cst.Integer("1"), upper=cst.Integer("2"))), ),
            ),
            "foo[1:2]",
            True,
        ),
        (
            cst.Subscript(
                cst.Name("foo"),
                (cst.SubscriptElement(
                    cst.Slice(lower=cst.Integer("1"), upper=None)), ),
            ),
            "foo[1:]",
            True,
        ),
        (
            cst.Subscript(
                cst.Name("foo"),
                (cst.SubscriptElement(
                    cst.Slice(lower=None, upper=cst.Integer("2"))), ),
            ),
            "foo[:2]",
            True,
        ),
        (
            cst.Subscript(
                cst.Name("foo"),
                (cst.SubscriptElement(
                    cst.Slice(
                        lower=cst.Integer("1"),
                        upper=None,
                        second_colon=cst.Colon(),
                        step=cst.Integer("3"),
                    )), ),
            ),
            "foo[1::3]",
            True,
        ),
        (
            cst.Subscript(
                cst.Name("foo"),
                (cst.SubscriptElement(
                    cst.Slice(
                        lower=None,
                        upper=None,
                        second_colon=cst.Colon(),
                        step=cst.Integer("3"),
                    )), ),
            ),
            "foo[::3]",
            True,
        ),
        # Valid list clone operations rendering
        (
            cst.Subscript(
                cst.Name("foo"),
                (cst.SubscriptElement(cst.Slice(lower=None, upper=None)), ),
            ),
            "foo[:]",
            True,
        ),
        (
            cst.Subscript(
                cst.Name("foo"),
                (cst.SubscriptElement(
                    cst.Slice(
                        lower=None,
                        upper=None,
                        second_colon=cst.Colon(),
                        step=None,
                    )), ),
            ),
            "foo[::]",
            True,
        ),
        # Valid list clone operations parsing
        (
            cst.Subscript(
                cst.Name("foo"),
                (cst.SubscriptElement(cst.Slice(lower=None, upper=None)), ),
            ),
            "foo[:]",
            True,
        ),
        (
            cst.Subscript(
                cst.Name("foo"),
                (cst.SubscriptElement(
                    cst.Slice(
                        lower=None,
                        upper=None,
                        second_colon=cst.Colon(),
                        step=None,
                    )), ),
            ),
            "foo[::]",
            True,
        ),
        # In parenthesis
        (
            cst.Subscript(
                lpar=(cst.LeftParen(), ),
                value=cst.Name("foo"),
                slice=(cst.SubscriptElement(cst.Index(cst.Integer("5"))), ),
                rpar=(cst.RightParen(), ),
            ),
            "(foo[5])",
            True,
        ),
        # Verify spacing
        (
            cst.Subscript(
                lpar=(cst.LeftParen(
                    whitespace_after=cst.SimpleWhitespace(" ")), ),
                value=cst.Name("foo"),
                lbracket=cst.LeftSquareBracket(
                    whitespace_after=cst.SimpleWhitespace(" ")),
                slice=(cst.SubscriptElement(cst.Index(cst.Integer("5"))), ),
                rbracket=cst.RightSquareBracket(
                    whitespace_before=cst.SimpleWhitespace(" ")),
                rpar=(cst.RightParen(
                    whitespace_before=cst.SimpleWhitespace(" ")), ),
                whitespace_after_value=cst.SimpleWhitespace(" "),
            ),
            "( foo [ 5 ] )",
            True,
        ),
        (
            cst.Subscript(
                lpar=(cst.LeftParen(
                    whitespace_after=cst.SimpleWhitespace(" ")), ),
                value=cst.Name("foo"),
                lbracket=cst.LeftSquareBracket(
                    whitespace_after=cst.SimpleWhitespace(" ")),
                slice=(cst.SubscriptElement(
                    cst.Slice(
                        lower=cst.Integer("1"),
                        first_colon=cst.Colon(
                            whitespace_before=cst.SimpleWhitespace(" "),
                            whitespace_after=cst.SimpleWhitespace(" "),
                        ),
                        upper=cst.Integer("2"),
                        second_colon=cst.Colon(
                            whitespace_before=cst.SimpleWhitespace(" "),
                            whitespace_after=cst.SimpleWhitespace(" "),
                        ),
                        step=cst.Integer("3"),
                    )), ),
                rbracket=cst.RightSquareBracket(
                    whitespace_before=cst.SimpleWhitespace(" ")),
                rpar=(cst.RightParen(
                    whitespace_before=cst.SimpleWhitespace(" ")), ),
                whitespace_after_value=cst.SimpleWhitespace(" "),
            ),
            "( foo [ 1 : 2 : 3 ] )",
            True,
        ),
        (
            cst.Subscript(
                lpar=(cst.LeftParen(
                    whitespace_after=cst.SimpleWhitespace(" ")), ),
                value=cst.Name("foo"),
                lbracket=cst.LeftSquareBracket(
                    whitespace_after=cst.SimpleWhitespace(" ")),
                slice=(
                    cst.SubscriptElement(
                        slice=cst.Slice(
                            lower=cst.Integer("1"),
                            first_colon=cst.Colon(
                                whitespace_before=cst.SimpleWhitespace(" "),
                                whitespace_after=cst.SimpleWhitespace(" "),
                            ),
                            upper=cst.Integer("2"),
                            second_colon=cst.Colon(
                                whitespace_before=cst.SimpleWhitespace(" "),
                                whitespace_after=cst.SimpleWhitespace(" "),
                            ),
                            step=cst.Integer("3"),
                        ),
                        comma=cst.Comma(
                            whitespace_before=cst.SimpleWhitespace(" "),
                            whitespace_after=cst.SimpleWhitespace(" "),
                        ),
                    ),
                    cst.SubscriptElement(slice=cst.Index(cst.Integer("5"))),
                ),
                rbracket=cst.RightSquareBracket(
                    whitespace_before=cst.SimpleWhitespace(" ")),
                rpar=(cst.RightParen(
                    whitespace_before=cst.SimpleWhitespace(" ")), ),
                whitespace_after_value=cst.SimpleWhitespace(" "),
            ),
            "( foo [ 1 : 2 : 3 , 5 ] )",
            True,
            CodeRange((1, 2), (1, 24)),
        ),
        # Test Index, Slice, SubscriptElement
        (cst.Index(cst.Integer("5")), "5", False, CodeRange((1, 0), (1, 1))),
        (
            cst.Slice(lower=None, upper=None, second_colon=cst.Colon(), step=None),
            "::",
            False,
            CodeRange((1, 0), (1, 2)),
        ),
        (
            cst.SubscriptElement(
                slice=cst.Slice(
                    lower=cst.Integer("1"),
                    first_colon=cst.Colon(
                        whitespace_before=cst.SimpleWhitespace(" "),
                        whitespace_after=cst.SimpleWhitespace(" "),
                    ),
                    upper=cst.Integer("2"),
                    second_colon=cst.Colon(
                        whitespace_before=cst.SimpleWhitespace(" "),
                        whitespace_after=cst.SimpleWhitespace(" "),
                    ),
                    step=cst.Integer("3"),
                ),
                comma=cst.Comma(
                    whitespace_before=cst.SimpleWhitespace(" "),
                    whitespace_after=cst.SimpleWhitespace(" "),
                ),
            ),
            "1 : 2 : 3 , ",
            False,
            CodeRange((1, 0), (1, 9)),
        ),
    ))
    def test_valid(
        self,
        node: cst.CSTNode,
        code: str,
        check_parsing: bool,
        position: Optional[CodeRange] = None,
    ) -> None:
        # Rows flagged with check_parsing are additionally validated through
        # parse_expression; the others skip the parser argument.
        if check_parsing:
            self.validate_node(node,
                               code,
                               parse_expression,
                               expected_position=position)
        else:
            self.validate_node(node, code, expected_position=position)

    @data_provider((
        (
            lambda: cst.Subscript(
                cst.Name("foo"),
                (cst.SubscriptElement(cst.Index(cst.Integer("5"))), ),
                lpar=(cst.LeftParen(), ),
            ),
            "left paren without right paren",
        ),
        (
            lambda: cst.Subscript(
                cst.Name("foo"),
                (cst.SubscriptElement(cst.Index(cst.Integer("5"))), ),
                rpar=(cst.RightParen(), ),
            ),
            "right paren without left paren",
        ),
        (lambda: cst.Subscript(cst.Name("foo"), ()), "empty SubscriptElement"),
    ))
    def test_invalid(self, get_node: Callable[[], cst.CSTNode],
                     expected_re: str) -> None:
        # Node construction is deferred through a lambda so the failure happens
        # inside assert_invalid rather than while the table is built.
        self.assert_invalid(get_node, expected_re)
class BooleanOperationTest(CSTNodeTest):
    """Construction, parsing and validation cases for ``cst.BooleanOperation``.

    Rows are keyword dictionaries forwarded verbatim to
    ``self.validate_node`` (valid cases) or ``self.assert_invalid`` (invalid
    cases).
    """

    @data_provider(
        (
            # Simple boolean operations
            # pyre-fixme[6]: Incompatible parameter type
            {
                "node": cst.BooleanOperation(
                    cst.Name("foo"), cst.And(), cst.Name("bar")
                ),
                "code": "foo and bar",
                "parser": parse_expression,
                "expected_position": None,
            },
            {
                "node": cst.BooleanOperation(
                    cst.Name("foo"), cst.Or(), cst.Name("bar")
                ),
                "code": "foo or bar",
                "parser": parse_expression,
                "expected_position": None,
            },
            # Parenthesized boolean operation
            {
                "node": cst.BooleanOperation(
                    lpar=(cst.LeftParen(),),
                    left=cst.Name("foo"),
                    operator=cst.Or(),
                    right=cst.Name("bar"),
                    rpar=(cst.RightParen(),),
                ),
                "code": "(foo or bar)",
                "parser": parse_expression,
                "expected_position": None,
            },
            # Operands that are themselves parenthesized, with no whitespace
            # around the operator
            {
                "node": cst.BooleanOperation(
                    left=cst.Name(
                        "foo", lpar=(cst.LeftParen(),), rpar=(cst.RightParen(),)
                    ),
                    operator=cst.Or(
                        whitespace_before=cst.SimpleWhitespace(""),
                        whitespace_after=cst.SimpleWhitespace(""),
                    ),
                    right=cst.Name(
                        "bar", lpar=(cst.LeftParen(),), rpar=(cst.RightParen(),)
                    ),
                ),
                "code": "(foo)or(bar)",
                "parser": parse_expression,
                "expected_position": CodeRange.create((1, 0), (1, 12)),
            },
            # Make sure that spacing works
            {
                "node": cst.BooleanOperation(
                    lpar=(cst.LeftParen(whitespace_after=cst.SimpleWhitespace(" ")),),
                    left=cst.Name("foo"),
                    operator=cst.And(
                        whitespace_before=cst.SimpleWhitespace(" "),
                        whitespace_after=cst.SimpleWhitespace(" "),
                    ),
                    right=cst.Name("bar"),
                    rpar=(cst.RightParen(whitespace_before=cst.SimpleWhitespace(" ")),),
                ),
                "code": "( foo and bar )",
                "parser": parse_expression,
                "expected_position": None,
            },
        )
    )
    def test_valid(self, **kwargs: Any) -> None:
        self.validate_node(**kwargs)

    @data_provider(
        (
            # Unbalanced parentheses
            {
                "get_node": lambda: cst.BooleanOperation(
                    cst.Name("foo"), cst.And(), cst.Name("bar"), lpar=(cst.LeftParen(),)
                ),
                "expected_re": "left paren without right paren",
            },
            {
                "get_node": lambda: cst.BooleanOperation(
                    cst.Name("foo"),
                    cst.And(),
                    cst.Name("bar"),
                    rpar=(cst.RightParen(),),
                ),
                "expected_re": "right paren without left paren",
            },
            # No whitespace around the operator between two bare names
            {
                "get_node": lambda: cst.BooleanOperation(
                    left=cst.Name("foo"),
                    operator=cst.Or(
                        whitespace_before=cst.SimpleWhitespace(""),
                        whitespace_after=cst.SimpleWhitespace(""),
                    ),
                    right=cst.Name("bar"),
                ),
                "expected_re": "at least one space around boolean operator",
            },
        )
    )
    def test_invalid(self, **kwargs: Any) -> None:
        # Node construction is deferred through ``get_node`` so the failure
        # happens inside assert_invalid rather than while the table is built.
        self.assert_invalid(**kwargs)
class ComparisonTest(CSTNodeTest):
    """Exercise ``cst.Comparison`` through the valid and invalid node checks."""

    @data_provider(
        (
            # Simple comparison statements
            (
                cst.Comparison(
                    left=cst.Name("foo"),
                    comparisons=(
                        cst.ComparisonTarget(
                            operator=cst.LessThan(), comparator=cst.Integer("5")
                        ),
                    ),
                ),
                "foo < 5",
            ),
            (
                cst.Comparison(
                    left=cst.Name("foo"),
                    comparisons=(
                        cst.ComparisonTarget(
                            operator=cst.NotEqual(), comparator=cst.Integer("5")
                        ),
                    ),
                ),
                "foo != 5",
            ),
            (
                cst.Comparison(
                    left=cst.Name("foo"),
                    comparisons=(
                        cst.ComparisonTarget(
                            operator=cst.Is(), comparator=cst.Name("True")
                        ),
                    ),
                ),
                "foo is True",
            ),
            (
                cst.Comparison(
                    left=cst.Name("foo"),
                    comparisons=(
                        cst.ComparisonTarget(
                            operator=cst.IsNot(), comparator=cst.Name("False")
                        ),
                    ),
                ),
                "foo is not False",
            ),
            (
                cst.Comparison(
                    left=cst.Name("foo"),
                    comparisons=(
                        cst.ComparisonTarget(
                            operator=cst.In(), comparator=cst.Name("bar")
                        ),
                    ),
                ),
                "foo in bar",
            ),
            (
                cst.Comparison(
                    left=cst.Name("foo"),
                    comparisons=(
                        cst.ComparisonTarget(
                            operator=cst.NotIn(), comparator=cst.Name("bar")
                        ),
                    ),
                ),
                "foo not in bar",
            ),
            # Comparison with parens
            (
                cst.Comparison(
                    left=cst.Name("foo"),
                    comparisons=(
                        cst.ComparisonTarget(
                            operator=cst.NotIn(), comparator=cst.Name("bar")
                        ),
                    ),
                    lpar=(cst.LeftParen(),),
                    rpar=(cst.RightParen(),),
                ),
                "(foo not in bar)",
            ),
            (
                cst.Comparison(
                    left=cst.Name(
                        "a", lpar=(cst.LeftParen(),), rpar=(cst.RightParen(),)
                    ),
                    comparisons=(
                        cst.ComparisonTarget(
                            operator=cst.Is(
                                whitespace_before=cst.SimpleWhitespace(""),
                                whitespace_after=cst.SimpleWhitespace(""),
                            ),
                            comparator=cst.Name(
                                "b",
                                lpar=(cst.LeftParen(),),
                                rpar=(cst.RightParen(),),
                            ),
                        ),
                        cst.ComparisonTarget(
                            operator=cst.Is(
                                whitespace_before=cst.SimpleWhitespace(""),
                                whitespace_after=cst.SimpleWhitespace(""),
                            ),
                            comparator=cst.Name(
                                "c",
                                lpar=(cst.LeftParen(),),
                                rpar=(cst.RightParen(),),
                            ),
                        ),
                    ),
                ),
                "(a)is(b)is(c)",
            ),
            # Valid expressions that look like they shouldn't parse
            (
                cst.Comparison(
                    left=cst.Integer("5"),
                    comparisons=(
                        cst.ComparisonTarget(
                            operator=cst.NotIn(
                                whitespace_before=cst.SimpleWhitespace("")
                            ),
                            comparator=cst.Name("bar"),
                        ),
                    ),
                ),
                "5not in bar",
            ),
            # Validate that spacing works properly
            (
                cst.Comparison(
                    left=cst.Name("foo"),
                    comparisons=(
                        cst.ComparisonTarget(
                            operator=cst.NotIn(
                                whitespace_before=cst.SimpleWhitespace(" "),
                                whitespace_between=cst.SimpleWhitespace(" "),
                                whitespace_after=cst.SimpleWhitespace(" "),
                            ),
                            comparator=cst.Name("bar"),
                        ),
                    ),
                    lpar=(
                        cst.LeftParen(whitespace_after=cst.SimpleWhitespace(" ")),
                    ),
                    rpar=(
                        cst.RightParen(whitespace_before=cst.SimpleWhitespace(" ")),
                    ),
                ),
                "( foo not in bar )",
            ),
            # Do some complex nodes
            (
                cst.Comparison(
                    left=cst.Name("baz"),
                    comparisons=(
                        cst.ComparisonTarget(
                            operator=cst.Equal(),
                            comparator=cst.Comparison(
                                left=cst.Name("foo"),
                                comparisons=(
                                    cst.ComparisonTarget(
                                        operator=cst.NotIn(),
                                        comparator=cst.Name("bar"),
                                    ),
                                ),
                                lpar=(cst.LeftParen(),),
                                rpar=(cst.RightParen(),),
                            ),
                        ),
                    ),
                ),
                "baz == (foo not in bar)",
                CodeRange((1, 0), (1, 23)),
            ),
            (
                cst.Comparison(
                    left=cst.Name("a"),
                    comparisons=(
                        cst.ComparisonTarget(
                            operator=cst.GreaterThan(), comparator=cst.Name("b")
                        ),
                        cst.ComparisonTarget(
                            operator=cst.GreaterThan(), comparator=cst.Name("c")
                        ),
                    ),
                ),
                "a > b > c",
                CodeRange((1, 0), (1, 9)),
            ),
            # Safe to place next to word operators when its leading/trailing
            # children are parenthesized
            (
                cst.IfExp(
                    body=cst.Comparison(
                        left=cst.Name("a"),
                        comparisons=(
                            cst.ComparisonTarget(
                                operator=cst.GreaterThan(),
                                comparator=cst.Name(
                                    "b",
                                    lpar=(cst.LeftParen(),),
                                    rpar=(cst.RightParen(),),
                                ),
                            ),
                        ),
                    ),
                    test=cst.Comparison(
                        left=cst.Name(
                            "c", lpar=(cst.LeftParen(),), rpar=(cst.RightParen(),)
                        ),
                        comparisons=(
                            cst.ComparisonTarget(
                                operator=cst.GreaterThan(),
                                comparator=cst.Name("d"),
                            ),
                        ),
                    ),
                    orelse=cst.Name("e"),
                    whitespace_before_if=cst.SimpleWhitespace(""),
                    whitespace_after_if=cst.SimpleWhitespace(""),
                ),
                "a > (b)if(c) > d else e",
            ),
            # Safe to place next to word operators when the whole comparison is
            # surrounded by parentheses
            (
                cst.IfExp(
                    body=cst.Name("a"),
                    test=cst.Comparison(
                        left=cst.Name("b"),
                        comparisons=(
                            cst.ComparisonTarget(
                                operator=cst.GreaterThan(),
                                comparator=cst.Name("c"),
                            ),
                        ),
                        lpar=(cst.LeftParen(),),
                        rpar=(cst.RightParen(),),
                    ),
                    orelse=cst.Name("d"),
                    whitespace_after_if=cst.SimpleWhitespace(""),
                    whitespace_before_else=cst.SimpleWhitespace(""),
                ),
                "a if(b > c)else d",
            ),
        )
    )
    def test_valid(
        self, node: cst.CSTNode, code: str, position: Optional[CodeRange] = None
    ) -> None:
        """Validate ``node`` against ``code`` using the expression parser."""
        self.validate_node(node, code, parse_expression, expected_position=position)

    @data_provider(
        (
            # Unbalanced parentheses
            (
                lambda: cst.Comparison(
                    left=cst.Name("foo"),
                    comparisons=(
                        cst.ComparisonTarget(
                            operator=cst.LessThan(), comparator=cst.Integer("5")
                        ),
                    ),
                    lpar=(cst.LeftParen(),),
                ),
                "left paren without right paren",
            ),
            (
                lambda: cst.Comparison(
                    left=cst.Name("foo"),
                    comparisons=(
                        cst.ComparisonTarget(
                            operator=cst.LessThan(), comparator=cst.Integer("5")
                        ),
                    ),
                    rpar=(cst.RightParen(),),
                ),
                "right paren without left paren",
            ),
            # No targets at all
            (
                lambda: cst.Comparison(left=cst.Name("foo"), comparisons=()),
                "at least one ComparisonTarget",
            ),
            # Word operator touching a bare name
            (
                lambda: cst.Comparison(
                    left=cst.Name("foo"),
                    comparisons=(
                        cst.ComparisonTarget(
                            operator=cst.NotIn(
                                whitespace_before=cst.SimpleWhitespace("")
                            ),
                            comparator=cst.Name("bar"),
                        ),
                    ),
                ),
                "at least one space around comparison operator",
            ),
            (
                lambda: cst.Comparison(
                    left=cst.Name("foo"),
                    comparisons=(
                        cst.ComparisonTarget(
                            operator=cst.NotIn(
                                whitespace_after=cst.SimpleWhitespace("")
                            ),
                            comparator=cst.Name("bar"),
                        ),
                    ),
                ),
                "at least one space around comparison operator",
            ),
            # multi-target comparisons
            (
                lambda: cst.Comparison(
                    left=cst.Name("a"),
                    comparisons=(
                        cst.ComparisonTarget(
                            operator=cst.Is(), comparator=cst.Name("b")
                        ),
                        cst.ComparisonTarget(
                            operator=cst.Is(
                                whitespace_before=cst.SimpleWhitespace("")
                            ),
                            comparator=cst.Name("c"),
                        ),
                    ),
                ),
                "at least one space around comparison operator",
            ),
            (
                lambda: cst.Comparison(
                    left=cst.Name("a"),
                    comparisons=(
                        cst.ComparisonTarget(
                            operator=cst.Is(), comparator=cst.Name("b")
                        ),
                        cst.ComparisonTarget(
                            operator=cst.Is(
                                whitespace_after=cst.SimpleWhitespace("")
                            ),
                            comparator=cst.Name("c"),
                        ),
                    ),
                ),
                "at least one space around comparison operator",
            ),
            # whitespace around the comparison itself
            # a ifb > c else d
            (
                lambda: cst.IfExp(
                    body=cst.Name("a"),
                    test=cst.Comparison(
                        left=cst.Name("b"),
                        comparisons=(
                            cst.ComparisonTarget(
                                operator=cst.GreaterThan(),
                                comparator=cst.Name("c"),
                            ),
                        ),
                    ),
                    orelse=cst.Name("d"),
                    whitespace_after_if=cst.SimpleWhitespace(""),
                ),
                "Must have at least one space after 'if' keyword.",
            ),
            # a if b > celse d
            (
                lambda: cst.IfExp(
                    body=cst.Name("a"),
                    test=cst.Comparison(
                        left=cst.Name("b"),
                        comparisons=(
                            cst.ComparisonTarget(
                                operator=cst.GreaterThan(),
                                comparator=cst.Name("c"),
                            ),
                        ),
                    ),
                    orelse=cst.Name("d"),
                    whitespace_before_else=cst.SimpleWhitespace(""),
                ),
                "Must have at least one space before 'else' keyword.",
            ),
        )
    )
    def test_invalid(
        self, get_node: Callable[[], cst.CSTNode], expected_re: str
    ) -> None:
        """Assert that building the node fails with a message matching ``expected_re``."""
        self.assert_invalid(get_node, expected_re)
def _parse_expression_py37(code: str) -> cst.CSTNode:
    """Parse ``code`` as an expression with the parser pinned to Python 3.7."""
    return parse_expression(code, config=PartialParserConfig(python_version="3.7"))


def _parse_statement_py36(code: str) -> cst.CSTNode:
    """Parse ``code`` as a statement with the parser pinned to Python 3.6."""
    return parse_statement(code, config=PartialParserConfig(python_version="3.6"))


class AwaitTest(CSTNodeTest):
    """Exercise ``cst.Await`` through the valid and invalid node checks."""

    @data_provider(
        (
            # Some simple calls
            {
                "node": cst.Await(cst.Name("test")),
                "code": "await test",
                "parser": _parse_expression_py37,
                "expected_position": None,
            },
            {
                "node": cst.Await(cst.Call(cst.Name("test"))),
                "code": "await test()",
                "parser": _parse_expression_py37,
                "expected_position": None,
            },
            # Whitespace. The gap after ``await`` is two spaces, so the
            # expression spans columns 2..13 as asserted below.
            {
                "node": cst.Await(
                    cst.Name("test"),
                    whitespace_after_await=cst.SimpleWhitespace("  "),
                    lpar=(
                        cst.LeftParen(whitespace_after=cst.SimpleWhitespace(" ")),
                    ),
                    rpar=(
                        cst.RightParen(whitespace_before=cst.SimpleWhitespace(" ")),
                    ),
                ),
                "code": "( await  test )",
                "parser": _parse_expression_py37,
                "expected_position": CodeRange((1, 2), (1, 13)),
            },
        )
    )
    def test_valid_py37(self, **kwargs: Any) -> None:
        """Validate bare ``await`` expressions with the parser pinned to 3.7."""
        self.validate_node(**kwargs)

    @data_provider(
        (
            # Some simple calls. The ``IndentedBlock`` bodies below set no indent
            # of their own, so the expected code uses four spaces (assumed to be
            # the default indent LibCST generates).
            {
                "node": cst.FunctionDef(
                    cst.Name("foo"),
                    cst.Parameters(),
                    cst.IndentedBlock(
                        (
                            cst.SimpleStatementLine(
                                (cst.Expr(cst.Await(cst.Name("test"))),)
                            ),
                        )
                    ),
                    asynchronous=cst.Asynchronous(),
                ),
                "code": "async def foo():\n    await test\n",
                "parser": _parse_statement_py36,
                "expected_position": None,
            },
            {
                "node": cst.FunctionDef(
                    cst.Name("foo"),
                    cst.Parameters(),
                    cst.IndentedBlock(
                        (
                            cst.SimpleStatementLine(
                                (cst.Expr(cst.Await(cst.Call(cst.Name("test")))),)
                            ),
                        )
                    ),
                    asynchronous=cst.Asynchronous(),
                ),
                "code": "async def foo():\n    await test()\n",
                "parser": _parse_statement_py36,
                "expected_position": None,
            },
            # Whitespace (mirrors the two-space gap used in the 3.7 case)
            {
                "node": cst.FunctionDef(
                    cst.Name("foo"),
                    cst.Parameters(),
                    cst.IndentedBlock(
                        (
                            cst.SimpleStatementLine(
                                (
                                    cst.Expr(
                                        cst.Await(
                                            cst.Name("test"),
                                            whitespace_after_await=cst.SimpleWhitespace(
                                                "  "
                                            ),
                                            lpar=(
                                                cst.LeftParen(
                                                    whitespace_after=cst.SimpleWhitespace(
                                                        " "
                                                    )
                                                ),
                                            ),
                                            rpar=(
                                                cst.RightParen(
                                                    whitespace_before=cst.SimpleWhitespace(
                                                        " "
                                                    )
                                                ),
                                            ),
                                        )
                                    ),
                                )
                            ),
                        )
                    ),
                    asynchronous=cst.Asynchronous(),
                ),
                "code": "async def foo():\n    ( await  test )\n",
                "parser": _parse_statement_py36,
                "expected_position": None,
            },
        )
    )
    def test_valid_py36(self, **kwargs: Any) -> None:
        """Validate ``await`` inside ``async def`` with the parser pinned to 3.6."""
        self.validate_node(**kwargs)

    @data_provider(
        (
            # Expression wrapping parenthesis rules
            {
                "get_node": (
                    lambda: cst.Await(cst.Name("foo"), lpar=(cst.LeftParen(),))
                ),
                "expected_re": "left paren without right paren",
            },
            {
                "get_node": (
                    lambda: cst.Await(cst.Name("foo"), rpar=(cst.RightParen(),))
                ),
                "expected_re": "right paren without left paren",
            },
            # ``await`` directly followed by a bare name
            {
                "get_node": (
                    lambda: cst.Await(
                        cst.Name("foo"),
                        whitespace_after_await=cst.SimpleWhitespace(""),
                    )
                ),
                "expected_re": "at least one space after await",
            },
        )
    )
    def test_invalid(self, **kwargs: Any) -> None:
        """Hand each data-provider case to ``assert_invalid``."""
        self.assert_invalid(**kwargs)
class TryTest(CSTNodeTest):
    """Exercise ``cst.Try`` and its clauses through the valid and invalid node checks."""

    @data_provider(
        (
            # Simple try/except block
            {
                "node": cst.Try(
                    cst.SimpleStatementSuite((cst.Pass(),)),
                    handlers=(
                        cst.ExceptHandler(
                            cst.SimpleStatementSuite((cst.Pass(),)),
                            whitespace_after_except=cst.SimpleWhitespace(""),
                        ),
                    ),
                ),
                "code": "try: pass\nexcept: pass\n",
                "parser": parse_statement,
                "expected_position": CodeRange((1, 0), (2, 12)),
            },
            # Try/except with a class
            {
                "node": cst.Try(
                    cst.SimpleStatementSuite((cst.Pass(),)),
                    handlers=(
                        cst.ExceptHandler(
                            cst.SimpleStatementSuite((cst.Pass(),)),
                            type=cst.Name("Exception"),
                        ),
                    ),
                ),
                "code": "try: pass\nexcept Exception: pass\n",
                "parser": parse_statement,
            },
            # Try/except with a named class
            {
                "node": cst.Try(
                    cst.SimpleStatementSuite((cst.Pass(),)),
                    handlers=(
                        cst.ExceptHandler(
                            cst.SimpleStatementSuite((cst.Pass(),)),
                            type=cst.Name("Exception"),
                            name=cst.AsName(cst.Name("exc")),
                        ),
                    ),
                ),
                "code": "try: pass\nexcept Exception as exc: pass\n",
                "parser": parse_statement,
                "expected_position": CodeRange((1, 0), (2, 29)),
            },
            # Try/except with multiple clauses
            {
                "node": cst.Try(
                    cst.SimpleStatementSuite((cst.Pass(),)),
                    handlers=(
                        cst.ExceptHandler(
                            cst.SimpleStatementSuite((cst.Pass(),)),
                            type=cst.Name("TypeError"),
                            name=cst.AsName(cst.Name("e")),
                        ),
                        cst.ExceptHandler(
                            cst.SimpleStatementSuite((cst.Pass(),)),
                            type=cst.Name("KeyError"),
                            name=cst.AsName(cst.Name("e")),
                        ),
                        cst.ExceptHandler(
                            cst.SimpleStatementSuite((cst.Pass(),)),
                            whitespace_after_except=cst.SimpleWhitespace(""),
                        ),
                    ),
                ),
                "code": "try: pass\n"
                + "except TypeError as e: pass\n"
                + "except KeyError as e: pass\n"
                + "except: pass\n",
                "parser": parse_statement,
                "expected_position": CodeRange((1, 0), (4, 12)),
            },
            # Simple try/finally block
            {
                "node": cst.Try(
                    cst.SimpleStatementSuite((cst.Pass(),)),
                    finalbody=cst.Finally(cst.SimpleStatementSuite((cst.Pass(),))),
                ),
                "code": "try: pass\nfinally: pass\n",
                "parser": parse_statement,
                "expected_position": CodeRange((1, 0), (2, 13)),
            },
            # Simple try/except/finally block
            {
                "node": cst.Try(
                    cst.SimpleStatementSuite((cst.Pass(),)),
                    handlers=(
                        cst.ExceptHandler(
                            cst.SimpleStatementSuite((cst.Pass(),)),
                            whitespace_after_except=cst.SimpleWhitespace(""),
                        ),
                    ),
                    finalbody=cst.Finally(cst.SimpleStatementSuite((cst.Pass(),))),
                ),
                "code": "try: pass\nexcept: pass\nfinally: pass\n",
                "parser": parse_statement,
                "expected_position": CodeRange((1, 0), (3, 13)),
            },
            # Simple try/except/else block
            {
                "node": cst.Try(
                    cst.SimpleStatementSuite((cst.Pass(),)),
                    handlers=(
                        cst.ExceptHandler(
                            cst.SimpleStatementSuite((cst.Pass(),)),
                            whitespace_after_except=cst.SimpleWhitespace(""),
                        ),
                    ),
                    orelse=cst.Else(cst.SimpleStatementSuite((cst.Pass(),))),
                ),
                "code": "try: pass\nexcept: pass\nelse: pass\n",
                "parser": parse_statement,
                "expected_position": CodeRange((1, 0), (3, 10)),
            },
            # Simple try/except/else block/finally
            {
                "node": cst.Try(
                    cst.SimpleStatementSuite((cst.Pass(),)),
                    handlers=(
                        cst.ExceptHandler(
                            cst.SimpleStatementSuite((cst.Pass(),)),
                            whitespace_after_except=cst.SimpleWhitespace(""),
                        ),
                    ),
                    orelse=cst.Else(cst.SimpleStatementSuite((cst.Pass(),))),
                    finalbody=cst.Finally(cst.SimpleStatementSuite((cst.Pass(),))),
                ),
                "code": "try: pass\nexcept: pass\nelse: pass\nfinally: pass\n",
                "parser": parse_statement,
                "expected_position": CodeRange((1, 0), (4, 13)),
            },
            # Verify whitespace in various locations
            {
                "node": cst.Try(
                    leading_lines=(cst.EmptyLine(comment=cst.Comment("# 1")),),
                    body=cst.SimpleStatementSuite((cst.Pass(),)),
                    handlers=(
                        cst.ExceptHandler(
                            leading_lines=(
                                cst.EmptyLine(comment=cst.Comment("# 2")),
                            ),
                            type=cst.Name("TypeError"),
                            name=cst.AsName(
                                cst.Name("e"),
                                whitespace_before_as=cst.SimpleWhitespace(" "),
                                whitespace_after_as=cst.SimpleWhitespace(" "),
                            ),
                            whitespace_after_except=cst.SimpleWhitespace(" "),
                            whitespace_before_colon=cst.SimpleWhitespace(" "),
                            body=cst.SimpleStatementSuite((cst.Pass(),)),
                        ),
                    ),
                    orelse=cst.Else(
                        leading_lines=(cst.EmptyLine(comment=cst.Comment("# 3")),),
                        body=cst.SimpleStatementSuite((cst.Pass(),)),
                        whitespace_before_colon=cst.SimpleWhitespace(" "),
                    ),
                    finalbody=cst.Finally(
                        leading_lines=(cst.EmptyLine(comment=cst.Comment("# 4")),),
                        body=cst.SimpleStatementSuite((cst.Pass(),)),
                        whitespace_before_colon=cst.SimpleWhitespace(" "),
                    ),
                    whitespace_before_colon=cst.SimpleWhitespace(" "),
                ),
                "code": "# 1\ntry : pass\n# 2\nexcept TypeError as e : pass\n# 3\nelse : pass\n# 4\nfinally : pass\n",
                "parser": parse_statement,
                "expected_position": CodeRange((2, 0), (8, 14)),
            },
            # Please don't write code like this
            {
                "node": cst.Try(
                    cst.SimpleStatementSuite((cst.Pass(),)),
                    handlers=(
                        cst.ExceptHandler(
                            cst.SimpleStatementSuite((cst.Pass(),)),
                            type=cst.Name("TypeError"),
                            name=cst.AsName(cst.Name("e")),
                        ),
                        cst.ExceptHandler(
                            cst.SimpleStatementSuite((cst.Pass(),)),
                            type=cst.Name("KeyError"),
                            name=cst.AsName(cst.Name("e")),
                        ),
                        cst.ExceptHandler(
                            cst.SimpleStatementSuite((cst.Pass(),)),
                            whitespace_after_except=cst.SimpleWhitespace(""),
                        ),
                    ),
                    orelse=cst.Else(cst.SimpleStatementSuite((cst.Pass(),))),
                    finalbody=cst.Finally(cst.SimpleStatementSuite((cst.Pass(),))),
                ),
                "code": "try: pass\n"
                + "except TypeError as e: pass\n"
                + "except KeyError as e: pass\n"
                + "except: pass\n"
                + "else: pass\n"
                + "finally: pass\n",
                "parser": parse_statement,
                "expected_position": CodeRange((1, 0), (6, 13)),
            },
            # Verify indentation: every clause line carries the outer indent
            {
                "node": DummyIndentedBlock(
                    "    ",
                    cst.Try(
                        cst.SimpleStatementSuite((cst.Pass(),)),
                        handlers=(
                            cst.ExceptHandler(
                                cst.SimpleStatementSuite((cst.Pass(),)),
                                type=cst.Name("TypeError"),
                                name=cst.AsName(cst.Name("e")),
                            ),
                            cst.ExceptHandler(
                                cst.SimpleStatementSuite((cst.Pass(),)),
                                type=cst.Name("KeyError"),
                                name=cst.AsName(cst.Name("e")),
                            ),
                            cst.ExceptHandler(
                                cst.SimpleStatementSuite((cst.Pass(),)),
                                whitespace_after_except=cst.SimpleWhitespace(""),
                            ),
                        ),
                        orelse=cst.Else(cst.SimpleStatementSuite((cst.Pass(),))),
                        finalbody=cst.Finally(
                            cst.SimpleStatementSuite((cst.Pass(),))
                        ),
                    ),
                ),
                "code": "    try: pass\n"
                + "    except TypeError as e: pass\n"
                + "    except KeyError as e: pass\n"
                + "    except: pass\n"
                + "    else: pass\n"
                + "    finally: pass\n",
                "parser": None,
            },
            # Verify indentation in bodies: each nested ``IndentedBlock`` adds
            # one more level (assumed four-space default) on top of the outer
            # indent, so body lines sit deeper than their clause keyword.
            {
                "node": DummyIndentedBlock(
                    "    ",
                    cst.Try(
                        cst.IndentedBlock((cst.SimpleStatementLine((cst.Pass(),)),)),
                        handlers=(
                            cst.ExceptHandler(
                                cst.IndentedBlock(
                                    (cst.SimpleStatementLine((cst.Pass(),)),)
                                ),
                                whitespace_after_except=cst.SimpleWhitespace(""),
                            ),
                        ),
                        orelse=cst.Else(
                            cst.IndentedBlock(
                                (cst.SimpleStatementLine((cst.Pass(),)),)
                            )
                        ),
                        finalbody=cst.Finally(
                            cst.IndentedBlock(
                                (cst.SimpleStatementLine((cst.Pass(),)),)
                            )
                        ),
                    ),
                ),
                "code": "    try:\n"
                + "        pass\n"
                + "    except:\n"
                + "        pass\n"
                + "    else:\n"
                + "        pass\n"
                + "    finally:\n"
                + "        pass\n",
                "parser": None,
            },
            # No space when using grouping parens
            {
                "node": cst.Try(
                    cst.SimpleStatementSuite((cst.Pass(),)),
                    handlers=(
                        cst.ExceptHandler(
                            cst.SimpleStatementSuite((cst.Pass(),)),
                            whitespace_after_except=cst.SimpleWhitespace(""),
                            type=cst.Name(
                                "Exception",
                                lpar=(cst.LeftParen(),),
                                rpar=(cst.RightParen(),),
                            ),
                        ),
                    ),
                ),
                "code": "try: pass\nexcept(Exception): pass\n",
                "parser": parse_statement,
            },
            # No space when using tuple
            {
                "node": cst.Try(
                    cst.SimpleStatementSuite((cst.Pass(),)),
                    handlers=(
                        cst.ExceptHandler(
                            cst.SimpleStatementSuite((cst.Pass(),)),
                            whitespace_after_except=cst.SimpleWhitespace(""),
                            type=cst.Tuple(
                                [
                                    cst.Element(
                                        cst.Name("IOError"),
                                        comma=cst.Comma(
                                            whitespace_after=cst.SimpleWhitespace(" ")
                                        ),
                                    ),
                                    cst.Element(cst.Name("ImportError")),
                                ]
                            ),
                        ),
                    ),
                ),
                "code": "try: pass\nexcept(IOError, ImportError): pass\n",
                "parser": parse_statement,
            },
        )
    )
    def test_valid(self, **kwargs: Any) -> None:
        """Hand each data-provider case to ``validate_node``."""
        self.validate_node(**kwargs)

    @data_provider(
        (
            # ``as`` clause problems
            {
                "get_node": lambda: cst.AsName(cst.Name("")),
                "expected_re": "empty name identifier",
            },
            {
                "get_node": lambda: cst.AsName(
                    cst.Name("bla"), whitespace_after_as=cst.SimpleWhitespace("")
                ),
                "expected_re": "between 'as'",
            },
            {
                "get_node": lambda: cst.AsName(
                    cst.Name("bla"), whitespace_before_as=cst.SimpleWhitespace("")
                ),
                "expected_re": "before 'as'",
            },
            # ``except`` clause problems
            {
                "get_node": lambda: cst.ExceptHandler(
                    cst.SimpleStatementSuite((cst.Pass(),)),
                    name=cst.AsName(cst.Name("bla")),
                ),
                "expected_re": "name for an empty type",
            },
            {
                "get_node": lambda: cst.ExceptHandler(
                    cst.SimpleStatementSuite((cst.Pass(),)),
                    type=cst.Name("TypeError"),
                    whitespace_after_except=cst.SimpleWhitespace(""),
                ),
                "expected_re": "at least one space after except",
            },
            # ``try`` with missing clauses
            {
                "get_node": lambda: cst.Try(cst.SimpleStatementSuite((cst.Pass(),))),
                "expected_re": "at least one ExceptHandler or Finally",
            },
            {
                "get_node": lambda: cst.Try(
                    cst.SimpleStatementSuite((cst.Pass(),)),
                    orelse=cst.Else(cst.SimpleStatementSuite((cst.Pass(),))),
                    finalbody=cst.Finally(cst.SimpleStatementSuite((cst.Pass(),))),
                ),
                "expected_re": "at least one ExceptHandler in order to have an Else",
            },
        )
    )
    def test_invalid(self, **kwargs: Any) -> None:
        """Hand each data-provider case to ``assert_invalid``."""
        self.assert_invalid(**kwargs)
def leave_Call(  # noqa: C901
    self, original_node: cst.Call, updated_node: cst.Call
) -> cst.BaseExpression:
    """Convert a ``"...".format(...)`` call on a simple string into an f-string.

    Returns ``updated_node`` untouched when the call is not ``.format()`` on a
    simple string literal, or (after emitting a warning) when some part of the
    conversion is unsupported. Otherwise returns the equivalent
    ``cst.FormattedString``.
    """
    # Lets figure out if this is a "".format() call
    extraction = self.extract(
        updated_node,
        m.Call(
            func=m.Attribute(
                value=m.SaveMatchedNode(m.SimpleString(), "string"),
                attr=m.Name("format"),
            )
        ),
    )
    if extraction is None:
        return updated_node

    stringnode = cst.ensure_type(extraction["string"], cst.SimpleString)

    # The result is prefixed with "f" + the original prefix. Python only
    # allows "f" to be combined with a raw prefix, so a "u" or "b" prefix
    # would produce an invalid literal such as fu"...". Leave those alone.
    # NOTE(review): assumes ``prefix`` is a plain ``str`` - confirm.
    if stringnode.prefix.lower() not in ("", "r"):
        self.warn(f"Unsupported string prefix {stringnode.prefix} in format() call")
        return updated_node

    fstring: List[cst.BaseFormattedStringContent] = []
    inserted_sequence: int = 0
    tokens = _get_tokens(stringnode.raw_value)
    for literal_text, field_name, format_spec, conversion in tokens:
        if literal_text:
            fstring.append(cst.FormattedStringText(literal_text))
        if field_name is None:
            # This is not a format-specification
            continue
        if format_spec:
            # TODO: This is supportable since format specs are compatible
            # with f-string format specs, but it would require matching
            # format specifier expansions.
            self.warn(f"Unsupported format_spec {format_spec} in format() call")
            return updated_node

        # Auto-insert field sequence if it is empty
        if field_name == "":
            field_name = str(inserted_sequence)
            inserted_sequence += 1
        expr = _find_expr_from_field_name(field_name, updated_node.args)
        if expr is None:
            # Most likely they used * expansion in a format.
            self.warn(f"Unsupported field_name {field_name} in format() call")
            return updated_node

        # Verify that we don't have any comments or newlines. Comments aren't
        # allowed in f-strings, and newlines need parenthesization. We can
        # have formattedstrings inside other formattedstrings, but I chose not
        # to deal with that for now.
        if self.findall(expr, m.Comment()):
            # We could strip comments, but this is a formatting change so
            # we choose not to for now.
            self.warn("Unsupported comment in format() call")
            return updated_node
        if self.findall(expr, m.FormattedString()):
            self.warn("Unsupported f-string in format() call")
            return updated_node
        if self.findall(expr, m.Await()):
            # This is fixed in 3.7 but we don't currently have a flag
            # to enable/disable it.
            self.warn("Unsupported await in format() call")
            return updated_node

        # Stripping newlines is effectively a format-only change.
        expr = cst.ensure_type(
            expr.visit(StripNewlinesTransformer(self.context)),
            cst.BaseExpression,
        )

        # Try our best to swap quotes on any strings that won't fit
        expr = cst.ensure_type(
            expr.visit(
                SwitchStringQuotesTransformer(self.context, stringnode.quote[0])
            ),
            cst.BaseExpression,
        )

        # Verify that the resulting expression doesn't have a backslash
        # in it.
        raw_expr_string = self.module.code_for_node(expr)
        if "\\" in raw_expr_string:
            self.warn("Unsupported backslash in format expression")
            return updated_node

        # For safety sake, if this is a dict/set or dict/set comprehension,
        # wrap it in parens so that it doesn't accidentally create an
        # escape.
        touches_brace = raw_expr_string.startswith("{") or raw_expr_string.endswith(
            "}"
        )
        if touches_brace and (not expr.lpar or not expr.rpar):
            expr = expr.with_changes(
                lpar=[cst.LeftParen()], rpar=[cst.RightParen()]
            )

        # Verify that any strings we insert don't have the same quote
        quote_gatherer = StringQuoteGatherer(self.context)
        expr.visit(quote_gatherer)
        for stringend in quote_gatherer.stringends:
            if stringend in stringnode.quote:
                self.warn("Cannot embed string with same quote from format() call")
                return updated_node

        fstring.append(
            cst.FormattedStringExpression(expression=expr, conversion=conversion)
        )

    return cst.FormattedString(
        parts=fstring,
        start=f"f{stringnode.prefix}{stringnode.quote}",
        end=stringnode.quote,
    )
class TupleTest(CSTNodeTest):
    """Exercise ``cst.Tuple`` through the valid and invalid node checks."""

    @data_provider(
        [
            # zero-element tuple
            {"node": cst.Tuple([]), "code": "()", "parser": parse_expression},
            # one-element tuple, sentinel comma value
            {
                "node": cst.Tuple([cst.Element(cst.Name("single_element"))]),
                "code": "(single_element,)",
                "parser": None,
            },
            {
                "node": cst.Tuple([cst.StarredElement(cst.Name("single_element"))]),
                "code": "(*single_element,)",
                "parser": None,
            },
            # two-element tuple, sentinel comma value
            {
                "node": cst.Tuple(
                    [cst.Element(cst.Name("one")), cst.Element(cst.Name("two"))]
                ),
                "code": "(one, two)",
                "parser": None,
            },
            # remove parenthesis
            {
                "node": cst.Tuple(
                    [cst.Element(cst.Name("one")), cst.Element(cst.Name("two"))],
                    lpar=[],
                    rpar=[],
                ),
                "code": "one, two",
                "parser": None,
            },
            # add extra parenthesis
            {
                "node": cst.Tuple(
                    [cst.Element(cst.Name("one")), cst.Element(cst.Name("two"))],
                    lpar=[cst.LeftParen(), cst.LeftParen()],
                    rpar=[cst.RightParen(), cst.RightParen()],
                ),
                "code": "((one, two))",
                "parser": None,
            },
            # starred element
            {
                "node": cst.Tuple(
                    [
                        cst.StarredElement(cst.Name("one")),
                        cst.StarredElement(cst.Name("two")),
                    ]
                ),
                "code": "(*one, *two)",
                "parser": None,
            },
            # custom comma on Element
            {
                "node": cst.Tuple(
                    [
                        cst.Element(cst.Name("one"), comma=cst.Comma()),
                        cst.Element(cst.Name("two"), comma=cst.Comma()),
                    ]
                ),
                "code": "(one,two,)",
                "parser": parse_expression,
            },
            # custom comma on StarredElement
            {
                "node": cst.Tuple(
                    [
                        cst.StarredElement(cst.Name("one"), comma=cst.Comma()),
                        cst.StarredElement(cst.Name("two"), comma=cst.Comma()),
                    ]
                ),
                "code": "(*one,*two,)",
                "parser": parse_expression,
                "expected_position": CodeRange((1, 1), (1, 11)),
            },
            # custom parenthesis on StarredElement
            {
                "node": cst.Tuple(
                    [
                        cst.StarredElement(
                            cst.Name("abc"),
                            lpar=[cst.LeftParen()],
                            rpar=[cst.RightParen()],
                            comma=cst.Comma(),
                        )
                    ]
                ),
                "code": "((*abc),)",
                "parser": parse_expression,
                "expected_position": CodeRange((1, 1), (1, 8)),
            },
            # custom whitespace on StarredElement. The gap after ``*`` is two
            # spaces, which makes the code 12 columns wide as asserted below.
            {
                "node": cst.Tuple(
                    [
                        cst.Element(cst.Name("one"), comma=cst.Comma()),
                        cst.StarredElement(
                            cst.Name("two"),
                            whitespace_before_value=cst.SimpleWhitespace("  "),
                            lpar=[cst.LeftParen()],
                            rpar=[cst.RightParen()],
                        ),
                    ],
                    lpar=[],
                    rpar=[],  # rpar can't own the trailing whitespace if it's not there
                ),
                "code": "one,(*  two)",
                "parser": parse_expression,
                "expected_position": CodeRange((1, 0), (1, 12)),
            },
            # missing spaces around tuple, okay with parenthesis
            {
                "node": cst.For(
                    target=cst.Tuple(
                        [
                            cst.Element(cst.Name("k"), comma=cst.Comma()),
                            cst.Element(cst.Name("v")),
                        ]
                    ),
                    iter=cst.Name("abc"),
                    body=cst.SimpleStatementSuite([cst.Pass()]),
                    whitespace_after_for=cst.SimpleWhitespace(""),
                    whitespace_before_in=cst.SimpleWhitespace(""),
                ),
                "code": "for(k,v)in abc: pass\n",
                "parser": parse_statement,
            },
            # no spaces around tuple, but using values that are parenthesized
            {
                "node": cst.For(
                    target=cst.Tuple(
                        [
                            cst.Element(
                                cst.Name(
                                    "k",
                                    lpar=[cst.LeftParen()],
                                    rpar=[cst.RightParen()],
                                ),
                                comma=cst.Comma(),
                            ),
                            cst.Element(
                                cst.Name(
                                    "v",
                                    lpar=[cst.LeftParen()],
                                    rpar=[cst.RightParen()],
                                )
                            ),
                        ],
                        lpar=[],
                        rpar=[],
                    ),
                    iter=cst.Name("abc"),
                    body=cst.SimpleStatementSuite([cst.Pass()]),
                    whitespace_after_for=cst.SimpleWhitespace(""),
                    whitespace_before_in=cst.SimpleWhitespace(""),
                ),
                "code": "for(k),(v)in abc: pass\n",
                "parser": parse_statement,
            },
            # starred elements are safe to use without a space before them
            {
                "node": cst.For(
                    target=cst.Tuple(
                        [cst.StarredElement(cst.Name("foo"), comma=cst.Comma())],
                        lpar=[],
                        rpar=[],
                    ),
                    iter=cst.Name("bar"),
                    body=cst.SimpleStatementSuite([cst.Pass()]),
                    whitespace_after_for=cst.SimpleWhitespace(""),
                ),
                "code": "for*foo, in bar: pass\n",
                "parser": parse_statement,
            },
            # a trailing comma doesn't mess up TrailingWhitespace
            {
                "node": cst.SimpleStatementLine(
                    [
                        cst.Expr(
                            cst.Tuple(
                                [
                                    cst.Element(cst.Name("one"), comma=cst.Comma()),
                                    cst.Element(cst.Name("two"), comma=cst.Comma()),
                                ],
                                lpar=[],
                                rpar=[],
                            )
                        )
                    ],
                    trailing_whitespace=cst.TrailingWhitespace(
                        whitespace=cst.SimpleWhitespace(" "),
                        comment=cst.Comment("# comment"),
                    ),
                ),
                "code": "one,two, # comment\n",
                "parser": parse_statement,
            },
        ]
    )
    def test_valid(self, **kwargs: Any) -> None:
        """Hand each data-provider case to ``validate_node``."""
        self.validate_node(**kwargs)

    @data_provider(
        (
            (
                lambda: cst.Tuple([], lpar=[], rpar=[]),
                "A zero-length tuple must be wrapped in parentheses.",
            ),
            (
                lambda: cst.Tuple(
                    [cst.Element(cst.Name("mismatched"))],
                    lpar=[cst.LeftParen(), cst.LeftParen()],
                    rpar=[cst.RightParen()],
                ),
                "unbalanced parens",
            ),
            (
                lambda: cst.For(
                    target=cst.Tuple(
                        [cst.Element(cst.Name("el"))], lpar=[], rpar=[]
                    ),
                    iter=cst.Name("it"),
                    body=cst.SimpleStatementSuite([cst.Pass()]),
                    whitespace_after_for=cst.SimpleWhitespace(""),
                ),
                "Must have at least one space after 'for' keyword.",
            ),
            (
                lambda: cst.For(
                    target=cst.Tuple(
                        [cst.Element(cst.Name("el"))], lpar=[], rpar=[]
                    ),
                    iter=cst.Name("it"),
                    body=cst.SimpleStatementSuite([cst.Pass()]),
                    whitespace_before_in=cst.SimpleWhitespace(""),
                ),
                "Must have at least one space before 'in' keyword.",
            ),
            # an additional check for StarredElement, since it's a separate codepath
            (
                lambda: cst.For(
                    target=cst.Tuple(
                        [cst.StarredElement(cst.Name("el"))], lpar=[], rpar=[]
                    ),
                    iter=cst.Name("it"),
                    body=cst.SimpleStatementSuite([cst.Pass()]),
                    whitespace_before_in=cst.SimpleWhitespace(""),
                ),
                "Must have at least one space before 'in' keyword.",
            ),
        )
    )
    def test_invalid(
        self, get_node: Callable[[], cst.CSTNode], expected_re: str
    ) -> None:
        """Assert that building the node fails with a message matching ``expected_re``."""
        self.assert_invalid(get_node, expected_re)
class IfExpTest(CSTNodeTest):
    """Exercise ``cst.IfExp`` through the valid and invalid node checks."""

    @data_provider(
        (
            # Simple if expressions
            (
                cst.IfExp(
                    body=cst.Name("foo"),
                    test=cst.Name("bar"),
                    orelse=cst.Name("baz"),
                ),
                "foo if bar else baz",
            ),
            # Parenthesized if expressions
            (
                cst.IfExp(
                    lpar=(cst.LeftParen(),),
                    body=cst.Name("foo"),
                    test=cst.Name("bar"),
                    orelse=cst.Name("baz"),
                    rpar=(cst.RightParen(),),
                ),
                "(foo if bar else baz)",
            ),
            (
                cst.IfExp(
                    body=cst.Name(
                        "foo", lpar=(cst.LeftParen(),), rpar=(cst.RightParen(),)
                    ),
                    whitespace_before_if=cst.SimpleWhitespace(""),
                    whitespace_after_if=cst.SimpleWhitespace(""),
                    test=cst.Name(
                        "bar", lpar=(cst.LeftParen(),), rpar=(cst.RightParen(),)
                    ),
                    whitespace_before_else=cst.SimpleWhitespace(""),
                    whitespace_after_else=cst.SimpleWhitespace(""),
                    orelse=cst.Name(
                        "baz", lpar=(cst.LeftParen(),), rpar=(cst.RightParen(),)
                    ),
                ),
                "(foo)if(bar)else(baz)",
                CodeRange((1, 0), (1, 21)),
            ),
            # Make sure that spacing works. Each keyword is surrounded by two
            # spaces, so the expression spans columns 2..25 as asserted below.
            (
                cst.IfExp(
                    lpar=(
                        cst.LeftParen(whitespace_after=cst.SimpleWhitespace(" ")),
                    ),
                    body=cst.Name("foo"),
                    whitespace_before_if=cst.SimpleWhitespace("  "),
                    whitespace_after_if=cst.SimpleWhitespace("  "),
                    test=cst.Name("bar"),
                    whitespace_before_else=cst.SimpleWhitespace("  "),
                    whitespace_after_else=cst.SimpleWhitespace("  "),
                    orelse=cst.Name("baz"),
                    rpar=(
                        cst.RightParen(whitespace_before=cst.SimpleWhitespace(" ")),
                    ),
                ),
                "( foo  if  bar  else  baz )",
                CodeRange((1, 2), (1, 25)),
            ),
        )
    )
    def test_valid(
        self, node: cst.CSTNode, code: str, position: Optional[CodeRange] = None
    ) -> None:
        """Validate ``node`` against ``code`` using the expression parser."""
        self.validate_node(node, code, parse_expression, expected_position=position)

    @data_provider(
        (
            # Unbalanced parentheses
            (
                lambda: cst.IfExp(
                    cst.Name("bar"),
                    cst.Name("foo"),
                    cst.Name("baz"),
                    lpar=(cst.LeftParen(),),
                ),
                "left paren without right paren",
            ),
            (
                lambda: cst.IfExp(
                    cst.Name("bar"),
                    cst.Name("foo"),
                    cst.Name("baz"),
                    rpar=(cst.RightParen(),),
                ),
                "right paren without left paren",
            ),
        )
    )
    def test_invalid(
        self, get_node: Callable[[], cst.CSTNode], expected_re: str
    ) -> None:
        """Assert that building the node fails with a message matching ``expected_re``."""
        self.assert_invalid(get_node, expected_re)
class NumberTest(CSTNodeTest):
    """Exercise number literals and unary operators through the node checks."""

    @data_provider(
        (
            # Simple number
            (cst.Integer("5"), "5", parse_expression),
            # Negated number
            (
                cst.UnaryOperation(operator=cst.Minus(), expression=cst.Integer("5")),
                "-5",
                parse_expression,
                CodeRange((1, 0), (1, 2)),
            ),
            # In parenthesis
            (
                cst.UnaryOperation(
                    lpar=(cst.LeftParen(),),
                    operator=cst.Minus(),
                    expression=cst.Integer("5"),
                    rpar=(cst.RightParen(),),
                ),
                "(-5)",
                parse_expression,
                CodeRange((1, 1), (1, 3)),
            ),
            (
                cst.UnaryOperation(
                    lpar=(cst.LeftParen(),),
                    operator=cst.Minus(),
                    expression=cst.Integer(
                        "5", lpar=(cst.LeftParen(),), rpar=(cst.RightParen(),)
                    ),
                    rpar=(cst.RightParen(),),
                ),
                "(-(5))",
                parse_expression,
                CodeRange((1, 1), (1, 5)),
            ),
            (
                cst.UnaryOperation(
                    operator=cst.Minus(),
                    expression=cst.UnaryOperation(
                        operator=cst.Minus(), expression=cst.Integer("5")
                    ),
                ),
                "--5",
                parse_expression,
                CodeRange((1, 0), (1, 3)),
            ),
            # multiple nested parenthesis
            (
                cst.Integer(
                    "5",
                    lpar=(cst.LeftParen(), cst.LeftParen()),
                    rpar=(cst.RightParen(), cst.RightParen()),
                ),
                "((5))",
                parse_expression,
                CodeRange((1, 2), (1, 3)),
            ),
            (
                cst.UnaryOperation(
                    lpar=(cst.LeftParen(),),
                    operator=cst.Plus(),
                    expression=cst.Integer(
                        "5",
                        lpar=(cst.LeftParen(), cst.LeftParen()),
                        rpar=(cst.RightParen(), cst.RightParen()),
                    ),
                    rpar=(cst.RightParen(),),
                ),
                "(+((5)))",
                parse_expression,
                CodeRange((1, 1), (1, 7)),
            ),
        )
    )
    def test_valid(
        self,
        node: cst.CSTNode,
        code: str,
        parser: Optional[Callable[[str], cst.CSTNode]],
        position: Optional[CodeRange] = None,
    ) -> None:
        """Validate ``node`` against ``code`` with the given parser and position."""
        self.validate_node(node, code, parser, expected_position=position)

    @data_provider(
        (
            (
                lambda: cst.Integer("5", lpar=(cst.LeftParen(),)),
                "left paren without right paren",
            ),
            (
                lambda: cst.Integer("5", rpar=(cst.RightParen(),)),
                "right paren without left paren",
            ),
            (
                lambda: cst.Float("5.5", lpar=(cst.LeftParen(),)),
                "left paren without right paren",
            ),
            (
                lambda: cst.Float("5.5", rpar=(cst.RightParen(),)),
                "right paren without left paren",
            ),
            # Python spells imaginary literals with a "j" suffix. Using a valid
            # literal keeps the unbalanced paren as the only thing wrong with
            # these nodes.
            (
                lambda: cst.Imaginary("5j", lpar=(cst.LeftParen(),)),
                "left paren without right paren",
            ),
            (
                lambda: cst.Imaginary("5j", rpar=(cst.RightParen(),)),
                "right paren without left paren",
            ),
        )
    )
    def test_invalid(
        self, get_node: Callable[[], cst.CSTNode], expected_re: str
    ) -> None:
        """Assert that building the node fails with a message matching ``expected_re``."""
        self.assert_invalid(get_node, expected_re)
class NamedExprTest(CSTNodeTest):
    """Tests for the walrus operator node (``cst.NamedExpr``).

    Each case is a dict whose keys are passed as keyword arguments to the
    inherited ``validate_node`` / ``assert_invalid`` helpers.
    """

    @data_provider(
        (
            # Simple named expression
            {
                "node": cst.NamedExpr(cst.Name("x"), cst.Float("5.5")),
                "code": "x := 5.5",
                "parser": None,  # Walrus operator is illegal as top-level statement
                "expected_position": None,
            },
            # Parenthesized named expression
            {
                "node": cst.NamedExpr(
                    lpar=(cst.LeftParen(),),
                    target=cst.Name("foo"),
                    value=cst.Integer("5"),
                    rpar=(cst.RightParen(),),
                ),
                "code": "(foo := 5)",
                "parser": _parse_expression_force_38,
                "expected_position": CodeRange((1, 1), (1, 9)),
            },
            # Make sure that spacing works.  Two spaces are used around the
            # walrus because a single space is what the node renders by
            # default (see the first case), so it would not exercise custom
            # whitespace.  With two spaces "foo" starts at column 2 and "bar"
            # ends at column 14, matching the expected position.
            {
                "node": cst.NamedExpr(
                    lpar=(cst.LeftParen(whitespace_after=cst.SimpleWhitespace(" ")),),
                    target=cst.Name("foo"),
                    whitespace_before_walrus=cst.SimpleWhitespace("  "),
                    whitespace_after_walrus=cst.SimpleWhitespace("  "),
                    value=cst.Name("bar"),
                    rpar=(cst.RightParen(whitespace_before=cst.SimpleWhitespace(" ")),),
                ),
                "code": "( foo  :=  bar )",
                "parser": _parse_expression_force_38,
                "expected_position": CodeRange((1, 2), (1, 14)),
            },
            # Make sure we can use these where allowed in if/while statements
            {
                "node": cst.While(
                    test=cst.NamedExpr(
                        target=cst.Name(value="x"),
                        value=cst.Call(func=cst.Name(value="some_input")),
                    ),
                    body=cst.SimpleStatementSuite(body=[cst.Pass()]),
                ),
                "code": "while x := some_input(): pass\n",
                "parser": _parse_statement_force_38,
                "expected_position": None,
            },
            {
                "node": cst.If(
                    test=cst.NamedExpr(
                        target=cst.Name(value="x"),
                        value=cst.Call(func=cst.Name(value="some_input")),
                    ),
                    body=cst.SimpleStatementSuite(body=[cst.Pass()]),
                ),
                "code": "if x := some_input(): pass\n",
                "parser": _parse_statement_force_38,
                "expected_position": None,
            },
            # No whitespace at all around the walrus
            {
                "node": cst.If(
                    test=cst.NamedExpr(
                        target=cst.Name(value="x"),
                        value=cst.Integer(value="1"),
                        whitespace_before_walrus=cst.SimpleWhitespace(""),
                        whitespace_after_walrus=cst.SimpleWhitespace(""),
                    ),
                    body=cst.SimpleStatementSuite(body=[cst.Pass()]),
                ),
                "code": "if x:=1: pass\n",
                "parser": _parse_statement_force_38,
                "expected_position": None,
            },
            # Function args
            {
                "node": cst.Call(
                    func=cst.Name(value="f"),
                    args=[
                        cst.Arg(
                            value=cst.NamedExpr(
                                target=cst.Name(value="y"),
                                value=cst.Integer(value="1"),
                                whitespace_before_walrus=cst.SimpleWhitespace(""),
                                whitespace_after_walrus=cst.SimpleWhitespace(""),
                            )
                        ),
                    ],
                ),
                "code": "f(y:=1)",
                "parser": _parse_expression_force_38,
                "expected_position": None,
            },
            # Whitespace handling on args is fragile
            {
                "node": cst.Call(
                    func=cst.Name(value="f"),
                    args=[
                        cst.Arg(
                            value=cst.Name(value="x"),
                            comma=cst.Comma(whitespace_after=cst.SimpleWhitespace(" ")),
                        ),
                        cst.Arg(
                            value=cst.NamedExpr(
                                target=cst.Name(value="y"),
                                value=cst.Integer(value="1"),
                                whitespace_before_walrus=cst.SimpleWhitespace(" "),
                                whitespace_after_walrus=cst.SimpleWhitespace(" "),
                            ),
                            whitespace_after_arg=cst.SimpleWhitespace(" "),
                        ),
                    ],
                ),
                "code": "f(x, y := 1 )",
                "parser": _parse_expression_force_38,
                "expected_position": None,
            },
            {
                "node": cst.Call(
                    func=cst.Name(value="f"),
                    args=[
                        cst.Arg(
                            value=cst.NamedExpr(
                                target=cst.Name(value="y"),
                                value=cst.Integer(value="1"),
                                whitespace_before_walrus=cst.SimpleWhitespace(" "),
                                whitespace_after_walrus=cst.SimpleWhitespace(" "),
                            ),
                            whitespace_after_arg=cst.SimpleWhitespace(" "),
                        ),
                    ],
                    whitespace_before_args=cst.SimpleWhitespace(" "),
                ),
                "code": "f( y := 1 )",
                "parser": _parse_expression_force_38,
                "expected_position": None,
            },
        )
    )
    def test_valid(self, **kwargs: Any) -> None:
        self.validate_node(**kwargs)

    @data_provider(
        (
            # Unbalanced parentheses must be rejected; nodes are built lazily
            # so construction happens inside ``assert_invalid``.
            {
                "get_node": (
                    lambda: cst.NamedExpr(
                        cst.Name("foo"), cst.Name("bar"), lpar=(cst.LeftParen(),)
                    )
                ),
                "expected_re": "left paren without right paren",
            },
            {
                "get_node": (
                    lambda: cst.NamedExpr(
                        cst.Name("foo"), cst.Name("bar"), rpar=(cst.RightParen(),)
                    )
                ),
                "expected_re": "right paren without left paren",
            },
        )
    )
    def test_invalid(self, **kwargs: Any) -> None:
        self.assert_invalid(**kwargs)
def import_to_node_multi(imp: SortableImport, module: cst.Module) -> cst.BaseStatement:
    """Render ``imp`` as a parenthesized, one-item-per-line ``from`` import.

    Every imported name gets a trailing comma whose ``whitespace_after``
    carries the item's inline comment, its "following" comment lines, and the
    indentation for the next line.  Comments that belong *before* an item are
    attached to whatever precedes it: the left paren for the first item, the
    previous item's comma otherwise.

    Raises:
        ValueError: if ``imp`` has no stem, i.e. it is a plain ``import x``.
    """

    def comment_line(text: str) -> cst.EmptyLine:
        # A full-line comment inside the parentheses, indented one level.
        return cst.EmptyLine(
            indent=True,
            comment=cst.Comment(text),
            whitespace=cst.SimpleWhitespace(module.default_indent),
        )

    def trailing_comment(text: str) -> cst.TrailingWhitespace:
        # An end-of-line comment separated from the code by COMMENT_INDENT.
        return cst.TrailingWhitespace(
            whitespace=cst.SimpleWhitespace(COMMENT_INDENT),
            comment=cst.Comment(text),
        )

    aliases: List[cst.ImportAlias] = []
    paren_comment_lines: List[cst.EmptyLine] = []
    previous: Optional[cst.ImportAlias] = None
    last_position = len(imp.items) - 1

    for position, item in enumerate(imp.items):
        is_last = position == last_position
        alias_name = name_to_node(item.name)
        alias_asname = None
        if item.asname:
            alias_asname = cst.AsName(name=cst.Name(item.asname))

        # Leading comments actually have to be trailing comments on the
        # previous node; for the first item that node is the left paren.
        if item.comments.before:
            before_lines = [comment_line(text) for text in item.comments.before]
            if previous is None:
                paren_comment_lines.extend(before_lines)
            else:
                previous.comma.whitespace_after.empty_lines.extend(  # type: ignore
                    before_lines
                )

        inline_text = COMMENT_INDENT.join(item.comments.inline)
        if inline_text:
            first_line = trailing_comment(inline_text)
        else:
            first_line = cst.TrailingWhitespace()

        # The statement's final comments ride along with the last item.
        following_comments = item.comments.following
        if is_last:
            following_comments = item.comments.following + imp.comments.final

        # Every item except the last must indent the *next* line/item; the
        # last one leaves the closing paren at the start of its line.
        after_comma = cst.ParenthesizedWhitespace(
            indent=True,
            first_line=first_line,
            empty_lines=[comment_line(text) for text in following_comments],
            last_line=cst.SimpleWhitespace("" if is_last else module.default_indent),
        )

        alias = cst.ImportAlias(
            name=alias_name,
            asname=alias_asname,
            comma=cst.Comma(whitespace_after=after_comma),
        )
        aliases.append(alias)
        previous = alias

    # import foo
    if not imp.stem:
        raise ValueError("can't render basic imports on multiple lines")

    # from foo import (
    #     bar
    # )
    stem, ndots = split_relative(imp.stem)
    module_name = name_to_node(stem) if stem else None
    relative = (cst.Dot(),) * ndots

    # inline comment following lparen
    if imp.comments.first_inline:
        lpar_first_line = trailing_comment(
            COMMENT_INDENT.join(imp.comments.first_inline)
        )
    else:
        lpar_first_line = cst.TrailingWhitespace()

    import_from = cst.ImportFrom(
        module=module_name,
        names=aliases,
        relative=relative,
        lpar=cst.LeftParen(
            whitespace_after=cst.ParenthesizedWhitespace(
                indent=True,
                first_line=lpar_first_line,
                empty_lines=paren_comment_lines,
                last_line=cst.SimpleWhitespace(module.default_indent),
            ),
        ),
        rpar=cst.RightParen(),
    )

    # comment lines above import; anything not starting with "#" becomes a
    # bare, unindented empty line
    leading_lines: List[cst.EmptyLine] = []
    for line in imp.comments.before:
        if line.startswith("#"):
            leading_lines.append(cst.EmptyLine(indent=True, comment=cst.Comment(line)))
        else:
            leading_lines.append(cst.EmptyLine(indent=False))

    # inline comments following import/rparen
    if imp.comments.last_inline:
        trailing = trailing_comment(COMMENT_INDENT.join(imp.comments.last_inline))
    else:
        trailing = cst.TrailingWhitespace()

    return cst.SimpleStatementLine(
        body=[import_from],
        leading_lines=leading_lines,
        trailing_whitespace=trailing,
    )
class NonlocalConstructionTest(CSTNodeTest):
    """Tests for hand-constructed ``cst.Nonlocal`` statements.

    Each case is a dict whose keys are passed as keyword arguments to the
    inherited ``validate_node`` / ``assert_invalid`` helpers.
    """

    @data_provider(
        (
            # Single nonlocal statement
            {
                "node": cst.Nonlocal((cst.NameItem(cst.Name("a")),)),
                "code": "nonlocal a",
            },
            # Multiple entries in nonlocal statement
            {
                "node": cst.Nonlocal(
                    (cst.NameItem(cst.Name("a")), cst.NameItem(cst.Name("b")))
                ),
                "code": "nonlocal a, b",
                "expected_position": CodeRange((1, 0), (1, 13)),
            },
            # Whitespace rendering test.  Two-space runs are used everywhere
            # so the output differs from the default rendering shown in the
            # cases above; the rendered statement is 17 columns wide, matching
            # the expected position.
            {
                "node": cst.Nonlocal(
                    (
                        cst.NameItem(
                            cst.Name("a"),
                            comma=cst.Comma(
                                whitespace_before=cst.SimpleWhitespace("  "),
                                whitespace_after=cst.SimpleWhitespace("  "),
                            ),
                        ),
                        cst.NameItem(cst.Name("b")),
                    ),
                    whitespace_after_nonlocal=cst.SimpleWhitespace("  "),
                ),
                "code": "nonlocal  a  ,  b",
                "expected_position": CodeRange((1, 0), (1, 17)),
            },
        )
    )
    def test_valid(self, **kwargs: Any) -> None:
        self.validate_node(**kwargs)

    @data_provider(
        (
            # Validate construction
            {
                "get_node": lambda: cst.Nonlocal(()),
                "expected_re": "A Nonlocal statement must have at least one NameItem",
            },
            # Validate whitespace handling
            {
                "get_node": lambda: cst.Nonlocal(
                    (cst.NameItem(cst.Name("a")),),
                    whitespace_after_nonlocal=cst.SimpleWhitespace(""),
                ),
                "expected_re": "Must have at least one space after 'nonlocal' keyword",
            },
            # Validate comma handling
            {
                "get_node": lambda: cst.Nonlocal(
                    (cst.NameItem(cst.Name("a"), comma=cst.Comma()),)
                ),
                "expected_re": "The last NameItem in a Nonlocal cannot have a trailing comma",
            },
            # Validate paren handling
            {
                "get_node": lambda: cst.Nonlocal(
                    (
                        cst.NameItem(
                            cst.Name(
                                "a", lpar=(cst.LeftParen(),), rpar=(cst.RightParen(),)
                            )
                        ),
                    )
                ),
                "expected_re": "Cannot have parens around names in NameItem",
            },
        )
    )
    def test_invalid(self, **kwargs: Any) -> None:
        self.assert_invalid(**kwargs)