예제 #1
0
 def basic_parenthesize(
     node: libcst.CSTNode,
     whitespace: Optional[libcst.BaseParenthesizableWhitespace] = None,
 ) -> libcst.CSTNode:
     """Wrap ``node`` in exactly one pair of parentheses, when it supports them.

     Nodes that have no ``lpar`` attribute are returned unchanged. Any
     parentheses already present on ``node`` are replaced, not nested.

     Args:
         node: The node to parenthesize.
         whitespace: Optional whitespace placed right after ``(``; when it is
             falsy the default ``LeftParen`` is used.

     Returns:
         ``node`` itself, or a copy with a single ``(``/``)`` pair.
     """
     if not hasattr(node, "lpar"):
         return node
     opening = (
         libcst.LeftParen(whitespace_after=whitespace)
         if whitespace
         else libcst.LeftParen()
     )
     return node.with_changes(lpar=[opening], rpar=[libcst.RightParen()])
예제 #2
0
 def _multiline_rpar(self) -> cst.RightParen:
     """Return a ``cst.RightParen`` for a multi-line parenthesized construct.

     The whitespace before ``)`` is produced by
     ``ImportTransformer._multiline_parenthesized_whitespace`` from this
     object's ``self._indentation``.
     """
     leading_ws = ImportTransformer._multiline_parenthesized_whitespace(
         self._indentation
     )
     return cst.RightParen(whitespace_before=leading_ws)
예제 #3
0
    def leave_Call(self, original_node: cst.Call,
                   updated_node: cst.Call) -> cst.BaseExpression:
        """Await calls to `method` of TelegraphApi.

        The callee of ``original_node`` is walked from the outermost attribute
        inwards, collecting each name, so ``self._telegraph.method`` produces
        ``["method", "_telegraph", "self"]``. A call is awaited when its
        callee chain starts with ``self.session``, ``self._telegraph.method``
        or ``self._telegraph.upload_file``.

        Args:
            original_node: The call before transformation; its callee is
                what gets inspected.
            updated_node: The call after its children were transformed.

        Returns:
            ``updated_node`` unchanged, or ``updated_node`` wrapped in a
            parenthesized ``Await`` when it targets one of the API methods.
        """
        path = []

        callee = original_node.func
        # Name.value is a plain str, so the walk stops after the innermost Name.
        while isinstance(callee, (cst.Attribute, cst.Name)):
            if isinstance(callee, cst.Attribute):
                path.append(callee.attr.value)
            else:
                path.append(callee.value)
            callee = callee.value

        # await the call if it's an API class method
        should_await = (
            path[-2:] == ["session", "self"]
            or path[-3:] in (
                ["method", "_telegraph", "self"],
                ["upload_file", "_telegraph", "self"],
            )
        )
        if not should_await:
            return updated_node

        self.fn_should_async = self.stack  # mark current fn as async on leave
        # Parenthesized, presumably so the awaited call stays correct wherever
        # the call sat (e.g. `(await f()).attr`).
        return Await(
            updated_node,
            lpar=[cst.LeftParen()],
            rpar=[cst.RightParen()],
        )
예제 #4
0
 def get_rpar(rpar: Optional[cst.RightParen],
              location: CodeRange) -> Optional[cst.RightParen]:
     """Pick the closing parenthesis to use for the span ``location``.

     When ``rpar`` is set and ``location`` covers more than one line, a fresh
     ``cst.RightParen`` preceded by a default ``cst.ParenthesizedWhitespace``
     is returned. In every other case ``rpar`` is returned unchanged.
     """
     # `rpar` is tested first so `location` is only read when rpar is truthy.
     if rpar and location.start.line != location.end.line:
         return cst.RightParen(whitespace_before=cst.ParenthesizedWhitespace())
     return rpar
예제 #5
0
 def leave_Assign(self, original_node: libcst.Assign,
                  updated_node: libcst.Assign) -> libcst.Assign:
     """Wrap the right-hand side of an assignment in parentheses.

     If the assigned value has an ``lpar`` attribute, its parentheses are
     replaced with a single default ``(``/``)`` pair. Otherwise
     ``updated_node`` is returned untouched.
     """
     value = updated_node.value
     if not hasattr(value, "lpar"):
         return updated_node
     wrapped = value.with_changes(
         lpar=[libcst.LeftParen()], rpar=[libcst.RightParen()]
     )
     return updated_node.with_changes(value=wrapped)
예제 #6
0
def parenthesize_using_parent(node: T, parent: libcst.CSTNode) -> T:
    """Add parentheses to the given node if needed.

    The decision is made by ``_needs_parentheses_parent`` from ``node`` and
    its parent. When parentheses are needed, ``node``'s ``lpar``/``rpar`` are
    replaced with one default pair. Otherwise ``node`` is returned as-is.
    """
    if not _needs_parentheses_parent(node, parent):
        return node
    return node.with_changes(
        lpar=[libcst.LeftParen()],
        rpar=[libcst.RightParen()],
    )
예제 #7
0
def parenthesize_using_previous(node: T, previous: libcst.CSTNode) -> T:
    """Add parentheses to the given node if needed.

    The decision is made by ``_needs_parentheses_previous`` from ``node`` and
    the node it is replacing (``previous``). When parentheses are needed,
    ``node``'s ``lpar``/``rpar`` are replaced with one default pair.
    Otherwise ``node`` is returned as-is.

    Note: this function is not as precise as `parenthesize_using_parent`
    """
    if not _needs_parentheses_previous(node, previous):
        return node
    return node.with_changes(
        lpar=[libcst.LeftParen()],
        rpar=[libcst.RightParen()],
    )
예제 #8
0
    def leave_Yield(self, original_node, updated_node) -> cst.BaseExpression:
        """Replace a ``yield`` expression with a call that collects its value.

        * ``yield x`` becomes ``<ret_var>.append(x)``.
        * ``yield a, b`` becomes ``<ret_var>.append((a, b))``.
        * A bare ``yield`` (which yields ``None``) becomes
          ``<ret_var>.append(None)``.
        * ``yield from it`` becomes ``<ret_var>.extend(it)``, because it
          produces every item of ``it``.
        """
        yield_val = updated_node.value

        if isinstance(yield_val, cst.From):
            # A From node is not a valid call argument (it would render as
            # `append(from it)`); collect the delegated iterable's items instead.
            extend = parse_expr(f'{self.ret_var}.extend()')
            return extend.with_changes(args=[cst.Arg(yield_val.item)])

        if yield_val is None:
            # A bare `yield` has no value node, but cst.Arg needs an expression.
            yield_val = cst.Name("None")
        elif m.matches(yield_val, m.Tuple()):
            # If original expr was "yield a, b" then yield_val renders as
            # "a, b" (i.e. no parens) which errors if directly inserted into
            # foo.append(a, b). So we ensure that the tuple has parentheses.
            yield_val = yield_val.with_changes(lpar=[cst.LeftParen()],
                                               rpar=[cst.RightParen()])

        append = parse_expr(f'{self.ret_var}.append()')
        return append.with_changes(args=[cst.Arg(yield_val)])
예제 #9
0
 def visit_Attribute(self, node: cst.Attribute) -> None:
     """Report each attribute expression that has no parentheses.

     The rule is skipped when its entry in ``self.context.config.rule_config``
     (keyed by this class's name) is a dict whose ``"disabled"`` value is
     truthy. The suggested replacement is the same attribute wrapped in one
     default pair of parentheses.
     """
     own_config = self.context.config.rule_config.get(
         self.__class__.__name__, {})
     disabled = isinstance(own_config, dict) and own_config.get(
         "disabled", False)
     if disabled:
         return
     if node.lpar:
         return
     wrapped = node.with_changes(lpar=[cst.LeftParen()],
                                 rpar=[cst.RightParen()])
     self.report(
         node,
         "All attributes should be parenthesized.",
         replacement=wrapped,
     )
 def test_simple_expression(self) -> None:
     """A substituted node keeps its own parentheses in the rendered code."""
     product = cst.BinaryOperation(
         lpar=(cst.LeftParen(),),
         left=cst.Name("three"),
         operator=cst.Multiply(),
         right=cst.Name("four"),
         rpar=(cst.RightParen(),),
     )
     expression = parse_template_expression(
         "{a} + {b} + {c}",
         a=cst.Name("one"),
         b=cst.Name("two"),
         c=product,
     )
     self.assertEqual(self.code(expression), "one + two + (three * four)")
예제 #11
0
    def visit_ConcatenatedString(self, node: cst.ConcatenatedString) -> None:
        """Report implicit string concatenation and suggest explicit ``+``.

        Only the outermost ``ConcatenatedString`` of a chain is reported. The
        replacement is a right-nested ``BinaryOperation`` of ``+`` operators
        that keeps each link's own parentheses, wrapped in parentheses if the
        outermost link had none.
        """
        # Skip if our immediate parent is also a ConcatenatedString, since our
        # parent should've already reported this violation.
        if isinstance(self.context.node_stack[-2], cst.ConcatenatedString):
            return

        # Flatten the chain, outermost link first. Only `right` can itself be
        # a ConcatenatedString; `left` never is.
        chain: List[cst.ConcatenatedString] = []
        current = node
        while isinstance(current, cst.ConcatenatedString):
            chain.append(current)
            current = current.right

        # `current` is now the innermost link's right operand; fold outwards.
        replacement = current
        for link in chain[::-1]:
            plus = cst.Add(
                whitespace_before=link.whitespace_between,
                whitespace_after=cst.SimpleWhitespace(" "),
            )
            replacement = cst.BinaryOperation(
                left=link.left,
                operator=plus,
                right=replacement,
                lpar=link.lpar,
                rpar=link.rpar,
            )

        # `+` binds more loosely than implicit concatenation, so the result
        # must be parenthesized for the change to be safe.
        if not replacement.lpar:
            # Formatting may suffer here (e.g. children are not re-indented);
            # black is expected to tidy it up on its next run.
            replacement = replacement.with_changes(lpar=[cst.LeftParen()],
                                                   rpar=[cst.RightParen()])

        self.report(node, replacement=replacement)
예제 #12
0
 def test_adding_parens(self) -> None:
     """A ``With`` given ``lpar``/``rpar`` renders its items inside them."""
     first_item = cst.WithItem(
         cst.Call(cst.Name("foo")),
         comma=cst.Comma(whitespace_after=cst.ParenthesizedWhitespace()),
     )
     second_item = cst.WithItem(cst.Call(cst.Name("bar")), comma=cst.Comma())
     node = cst.With(
         (first_item, second_item),
         cst.SimpleStatementSuite((cst.Pass(),)),
         lpar=cst.LeftParen(whitespace_after=cst.SimpleWhitespace(" ")),
         rpar=cst.RightParen(whitespace_before=cst.SimpleWhitespace(" ")),
     )
     expected = "with ( foo(),\nbar(), ): pass\n"
     self.assertEqual(cst.Module([]).code_for_node(node), expected)
예제 #13
0
class YieldConstructionTest(CSTNodeTest):
    """Construction tests for ``cst.Yield`` and the ``cst.From`` it may wrap.

    ``test_valid`` passes hand-built nodes, their expected source text and an
    optional expected ``CodeRange`` to ``validate_node``. ``test_invalid``
    passes node factories and an expected error-message regex to
    ``assert_invalid``.
    """

    @data_provider((
        # Simple yield
        (cst.Yield(), "yield"),
        # yield expression
        (cst.Yield(cst.Name("a")), "yield a"),
        # yield from expression
        (cst.Yield(cst.From(cst.Call(cst.Name("a")))), "yield from a()"),
        # Parenthesizing tests
        (
            cst.Yield(
                lpar=(cst.LeftParen(), ),
                value=cst.Integer("5"),
                rpar=(cst.RightParen(), ),
            ),
            "(yield 5)",
        ),
        # Whitespace oddities tests
        (
            cst.Yield(
                cst.Name("a",
                         lpar=(cst.LeftParen(), ),
                         rpar=(cst.RightParen(), )),
                whitespace_after_yield=cst.SimpleWhitespace(""),
            ),
            "yield(a)",
            CodeRange((1, 0), (1, 8)),
        ),
        (
            cst.Yield(
                cst.From(
                    cst.Call(
                        cst.Name("a"),
                        lpar=(cst.LeftParen(), ),
                        rpar=(cst.RightParen(), ),
                    ),
                    whitespace_after_from=cst.SimpleWhitespace(""),
                )),
            "yield from(a())",
        ),
        # Whitespace rendering/parsing tests
        (
            cst.Yield(
                lpar=(cst.LeftParen(
                    whitespace_after=cst.SimpleWhitespace(" ")), ),
                value=cst.Integer("5"),
                whitespace_after_yield=cst.SimpleWhitespace("  "),
                rpar=(cst.RightParen(
                    whitespace_before=cst.SimpleWhitespace(" ")), ),
            ),
            "( yield  5 )",
        ),
        (
            cst.Yield(
                lpar=(cst.LeftParen(
                    whitespace_after=cst.SimpleWhitespace(" ")), ),
                value=cst.From(
                    cst.Call(cst.Name("bla")),
                    whitespace_after_from=cst.SimpleWhitespace("  "),
                ),
                whitespace_after_yield=cst.SimpleWhitespace("  "),
                rpar=(cst.RightParen(
                    whitespace_before=cst.SimpleWhitespace(" ")), ),
            ),
            "( yield  from  bla() )",
            CodeRange((1, 2), (1, 20)),
        ),
        # From expression position tests
        (
            cst.From(cst.Integer("5"),
                     whitespace_after_from=cst.SimpleWhitespace(" ")),
            "from 5",
            CodeRange((1, 0), (1, 6)),
        ),
    ))
    def test_valid(self,
                   node: cst.CSTNode,
                   code: str,
                   position: Optional[CodeRange] = None) -> None:
        """Validate ``node`` against ``code`` and, if given, ``position``."""
        self.validate_node(node, code, expected_position=position)

    @data_provider((
        # Paren validation
        (
            lambda: cst.Yield(lpar=(cst.LeftParen(), )),
            "left paren without right paren",
        ),
        (
            lambda: cst.Yield(rpar=(cst.RightParen(), )),
            "right paren without left paren",
        ),
        # Make sure we have adequate space after yield
        (
            lambda: cst.Yield(cst.Name("a"),
                              whitespace_after_yield=cst.SimpleWhitespace("")),
            "Must have at least one space after 'yield' keyword",
        ),
        (
            lambda: cst.Yield(
                cst.From(cst.Call(cst.Name("a"))),
                whitespace_after_yield=cst.SimpleWhitespace(""),
            ),
            "Must have at least one space after 'yield' keyword",
        ),
        # Make sure we have adequate space after 'from'
        (
            lambda: cst.Yield(
                cst.From(
                    cst.Call(cst.Name("a")),
                    whitespace_after_from=cst.SimpleWhitespace(""),
                )),
            "Must have at least one space after 'from' keyword",
        ),
    ))
    def test_invalid(self, get_node: Callable[[], cst.CSTNode],
                     expected_re: str) -> None:
        """Building the node from ``get_node`` must fail matching ``expected_re``."""
        self.assert_invalid(get_node, expected_re)
예제 #14
0
class YieldParsingTest(CSTNodeTest):
    """Parsing tests for ``yield`` / ``yield from`` expressions.

    Each row pairs the node the parser is expected to produce (with every
    whitespace field spelled out) with the source text that is parsed.
    """

    @data_provider((
        # Simple yield
        (cst.Yield(), "yield"),
        # yield expression
        (
            cst.Yield(cst.Name("a"),
                      whitespace_after_yield=cst.SimpleWhitespace(" ")),
            "yield a",
        ),
        # yield from expression
        (
            cst.Yield(
                cst.From(
                    cst.Call(cst.Name("a")),
                    whitespace_after_from=cst.SimpleWhitespace(" "),
                ),
                whitespace_after_yield=cst.SimpleWhitespace(" "),
            ),
            "yield from a()",
        ),
        # Parenthesizing tests
        (
            cst.Yield(
                lpar=(cst.LeftParen(), ),
                whitespace_after_yield=cst.SimpleWhitespace(" "),
                value=cst.Integer("5"),
                rpar=(cst.RightParen(), ),
            ),
            "(yield 5)",
        ),
        # Whitespace oddities tests
        (
            cst.Yield(
                cst.Name("a",
                         lpar=(cst.LeftParen(), ),
                         rpar=(cst.RightParen(), )),
                whitespace_after_yield=cst.SimpleWhitespace(""),
            ),
            "yield(a)",
        ),
        (
            cst.Yield(
                cst.From(
                    cst.Call(
                        cst.Name("a"),
                        lpar=(cst.LeftParen(), ),
                        rpar=(cst.RightParen(), ),
                    ),
                    whitespace_after_from=cst.SimpleWhitespace(""),
                ),
                whitespace_after_yield=cst.SimpleWhitespace(" "),
            ),
            "yield from(a())",
        ),
        # Whitespace rendering/parsing tests
        (
            cst.Yield(
                lpar=(cst.LeftParen(
                    whitespace_after=cst.SimpleWhitespace(" ")), ),
                value=cst.Integer("5"),
                whitespace_after_yield=cst.SimpleWhitespace("  "),
                rpar=(cst.RightParen(
                    whitespace_before=cst.SimpleWhitespace(" ")), ),
            ),
            "( yield  5 )",
        ),
        (
            cst.Yield(
                lpar=(cst.LeftParen(
                    whitespace_after=cst.SimpleWhitespace(" ")), ),
                value=cst.From(
                    cst.Call(cst.Name("bla")),
                    whitespace_after_from=cst.SimpleWhitespace("  "),
                ),
                whitespace_after_yield=cst.SimpleWhitespace("  "),
                rpar=(cst.RightParen(
                    whitespace_before=cst.SimpleWhitespace(" ")), ),
            ),
            "( yield  from  bla() )",
        ),
    ))
    def test_valid(self,
                   node: cst.CSTNode,
                   code: str,
                   position: Optional[CodeRange] = None) -> None:
        """Parse ``code`` as a statement and compare its expression to ``node``.

        The parser handed to ``validate_node`` parses ``code`` with
        ``parse_statement``, expects a ``SimpleStatementLine`` whose first
        body element is a ``cst.Expr``, and returns that ``Expr``'s value.
        NOTE(review): ``position`` is accepted but not used here.
        """
        self.validate_node(
            node,
            code,
            lambda code: ensure_type(
                ensure_type(parse_statement(code), cst.SimpleStatementLine).
                body[0],
                cst.Expr,
            ).value,
        )

    # `yield from` (PEP 380) must parse for python_version 3.3 and be rejected
    # for 3.1.
    @data_provider((
        {
            "code": "yield from x",
            "parser": parse_statement_as(python_version="3.3"),
            "expect_success": True,
        },
        {
            "code": "yield from x",
            "parser": parse_statement_as(python_version="3.1"),
            "expect_success": False,
        },
    ))
    def test_versions(self, **kwargs: Any) -> None:
        """Check version-gated parsing via ``assert_parses``."""
        self.assert_parses(**kwargs)
예제 #15
0
class RaiseConstructionTest(CSTNodeTest):
    """Construction tests for ``cst.Raise`` (with optional ``cst.From`` cause).

    ``test_valid`` passes each row's ``node``/``code``/``expected_position``
    keywords to ``validate_node``. ``test_invalid`` passes ``get_node`` and
    ``expected_re`` (an error-message regex) to ``assert_invalid``.
    """

    @data_provider((
        # Simple raise
        {
            "node": cst.Raise(),
            "code": "raise"
        },
        # Raise exception
        {
            "node": cst.Raise(cst.Call(cst.Name("Exception"))),
            "code": "raise Exception()",
            "expected_position": CodeRange((1, 0), (1, 17)),
        },
        # Raise exception from cause
        {
            "node":
            cst.Raise(cst.Call(cst.Name("Exception")),
                      cst.From(cst.Name("cause"))),
            "code":
            "raise Exception() from cause",
        },
        # Whitespace oddities test
        {
            "node":
            cst.Raise(
                cst.Call(
                    cst.Name("Exception"),
                    lpar=(cst.LeftParen(), ),
                    rpar=(cst.RightParen(), ),
                ),
                cst.From(
                    cst.Name("cause",
                             lpar=(cst.LeftParen(), ),
                             rpar=(cst.RightParen(), )),
                    whitespace_before_from=cst.SimpleWhitespace(""),
                    whitespace_after_from=cst.SimpleWhitespace(""),
                ),
                whitespace_after_raise=cst.SimpleWhitespace(""),
            ),
            "code":
            "raise(Exception())from(cause)",
            "expected_position":
            CodeRange((1, 0), (1, 29)),
        },
        {
            "node":
            cst.Raise(
                cst.Call(cst.Name("Exception")),
                cst.From(
                    cst.Name("cause"),
                    whitespace_before_from=cst.SimpleWhitespace(""),
                ),
            ),
            "code":
            "raise Exception()from cause",
            "expected_position":
            CodeRange((1, 0), (1, 27)),
        },
        # Whitespace rendering test
        {
            "node":
            cst.Raise(
                exc=cst.Call(cst.Name("Exception")),
                cause=cst.From(
                    cst.Name("cause"),
                    whitespace_before_from=cst.SimpleWhitespace("  "),
                    whitespace_after_from=cst.SimpleWhitespace("  "),
                ),
                whitespace_after_raise=cst.SimpleWhitespace("  "),
            ),
            "code":
            "raise  Exception()  from  cause",
            "expected_position":
            CodeRange((1, 0), (1, 31)),
        },
    ))
    def test_valid(self, **kwargs: Any) -> None:
        """Validate a well-formed ``Raise`` row via ``validate_node``."""
        self.validate_node(**kwargs)

    @data_provider((
        # Validate construction
        {
            "get_node": lambda: cst.Raise(cause=cst.From(cst.Name("cause"))),
            "expected_re":
            "Must have an 'exc' when specifying 'clause'. on Raise",
        },
        # Validate whitespace handling
        {
            "get_node":
            lambda: cst.Raise(
                cst.Call(cst.Name("Exception")),
                whitespace_after_raise=cst.SimpleWhitespace(""),
            ),
            "expected_re":
            "Must have at least one space after 'raise'",
        },
        {
            "get_node":
            lambda: cst.Raise(
                cst.Name("exc"),
                cst.From(
                    cst.Name("cause"),
                    whitespace_before_from=cst.SimpleWhitespace(""),
                ),
            ),
            "expected_re":
            "Must have at least one space before 'from'",
        },
        {
            "get_node":
            lambda: cst.Raise(
                cst.Name("exc"),
                cst.From(
                    cst.Name("cause"),
                    whitespace_after_from=cst.SimpleWhitespace(""),
                ),
            ),
            "expected_re":
            "Must have at least one space after 'from'",
        },
    ))
    def test_invalid(self, **kwargs: Any) -> None:
        """Building ``get_node()`` must fail with a message matching ``expected_re``."""
        self.assert_invalid(**kwargs)
예제 #16
0
class RaiseParsingTest(CSTNodeTest):
    """Parsing tests for ``raise`` statements.

    Each row pairs the node the parser is expected to produce (with every
    whitespace field spelled out) with the source text that is parsed.
    """

    @data_provider((
        # Simple raise
        {
            "node": cst.Raise(),
            "code": "raise"
        },
        # Raise exception
        {
            "node":
            cst.Raise(
                cst.Call(cst.Name("Exception")),
                whitespace_after_raise=cst.SimpleWhitespace(" "),
            ),
            "code":
            "raise Exception()",
        },
        # Raise exception from cause
        {
            "node":
            cst.Raise(
                cst.Call(cst.Name("Exception")),
                cst.From(
                    cst.Name("cause"),
                    whitespace_before_from=cst.SimpleWhitespace(" "),
                    whitespace_after_from=cst.SimpleWhitespace(" "),
                ),
                whitespace_after_raise=cst.SimpleWhitespace(" "),
            ),
            "code":
            "raise Exception() from cause",
        },
        # Whitespace oddities test
        {
            "node":
            cst.Raise(
                cst.Call(
                    cst.Name("Exception"),
                    lpar=(cst.LeftParen(), ),
                    rpar=(cst.RightParen(), ),
                ),
                cst.From(
                    cst.Name("cause",
                             lpar=(cst.LeftParen(), ),
                             rpar=(cst.RightParen(), )),
                    whitespace_before_from=cst.SimpleWhitespace(""),
                    whitespace_after_from=cst.SimpleWhitespace(""),
                ),
                whitespace_after_raise=cst.SimpleWhitespace(""),
            ),
            "code":
            "raise(Exception())from(cause)",
        },
        {
            "node":
            cst.Raise(
                cst.Call(cst.Name("Exception")),
                cst.From(
                    cst.Name("cause"),
                    whitespace_before_from=cst.SimpleWhitespace(""),
                    whitespace_after_from=cst.SimpleWhitespace(" "),
                ),
                whitespace_after_raise=cst.SimpleWhitespace(" "),
            ),
            "code":
            "raise Exception()from cause",
        },
        # Whitespace rendering test
        {
            "node":
            cst.Raise(
                exc=cst.Call(cst.Name("Exception")),
                cause=cst.From(
                    cst.Name("cause"),
                    whitespace_before_from=cst.SimpleWhitespace("  "),
                    whitespace_after_from=cst.SimpleWhitespace("  "),
                ),
                whitespace_after_raise=cst.SimpleWhitespace("  "),
            ),
            "code":
            "raise  Exception()  from  cause",
        },
    ))
    def test_valid(self, **kwargs: Any) -> None:
        """Parse ``code`` as a statement and compare it to ``node``.

        The parser expects a ``SimpleStatementLine`` and returns its first
        body element, i.e. the ``Raise`` small statement.
        """
        self.validate_node(
            parser=lambda code: ensure_type(parse_statement(code), cst.
                                            SimpleStatementLine).body[0],
            **kwargs,
        )
예제 #17
0
class ListTest(CSTNodeTest):
    """Tests for ``cst.Set`` nodes.

    NOTE(review): despite the class name, every case below builds or parses a
    ``cst.Set``. Renaming the class would change the generated test IDs, so
    only this docstring flags the mismatch.
    """

    # A lot of Element/StarredElement tests are provided by the tests for Tuple, so
    # we don't need to duplicate them here.
    @data_provider([
        # one-element set, sentinel comma value
        {
            "node": cst.Set([cst.Element(cst.Name("single_element"))]),
            "code": "{single_element}",
            "parser": parse_expression,
        },
        # custom whitespace between brackets
        {
            "node":
            cst.Set(
                [cst.Element(cst.Name("single_element"))],
                lbrace=cst.LeftCurlyBrace(
                    whitespace_after=cst.SimpleWhitespace("\t")),
                rbrace=cst.RightCurlyBrace(
                    whitespace_before=cst.SimpleWhitespace("    ")),
            ),
            "code":
            "{\tsingle_element    }",
            "parser":
            parse_expression,
        },
        # two-element set, sentinel comma value
        {
            "node":
            cst.Set(
                [cst.Element(cst.Name("one")),
                 cst.Element(cst.Name("two"))]),
            "code":
            "{one, two}",
            "parser":
            None,
        },
        # with parenthesis
        {
            "node":
            cst.Set(
                [cst.Element(cst.Name("one"))],
                lpar=[cst.LeftParen()],
                rpar=[cst.RightParen()],
            ),
            "code":
            "({one})",
            "parser":
            None,
        },
        # starred element
        {
            "node":
            cst.Set([
                cst.StarredElement(cst.Name("one")),
                cst.StarredElement(cst.Name("two")),
            ]),
            "code":
            "{*one, *two}",
            "parser":
            None,
        },
        # missing spaces around set, always okay
        {
            "node":
            cst.GeneratorExp(
                cst.Name("elt"),
                cst.CompFor(
                    target=cst.Name("elt"),
                    iter=cst.Set([
                        cst.Element(
                            cst.Name("one"),
                            cst.Comma(
                                whitespace_after=cst.SimpleWhitespace(" ")),
                        ),
                        cst.Element(cst.Name("two")),
                    ]),
                    ifs=[
                        cst.CompIf(
                            cst.Name("test"),
                            whitespace_before=cst.SimpleWhitespace(""),
                        )
                    ],
                    whitespace_after_in=cst.SimpleWhitespace(""),
                ),
            ),
            "code":
            "(elt for elt in{one, two}if test)",
            "parser":
            parse_expression,
        },
    ])
    def test_valid(self, **kwargs: Any) -> None:
        """Validate a ``Set`` row; rows with ``parser=None`` skip the parse step."""
        self.validate_node(**kwargs)

    @data_provider((
        (
            lambda: cst.Set(
                [cst.Element(cst.Name("mismatched"))],
                lpar=[cst.LeftParen(), cst.LeftParen()],
                rpar=[cst.RightParen()],
            ),
            "unbalanced parens",
        ),
        (lambda: cst.Set([]), "at least one element"),
    ))
    def test_invalid(self, get_node: Callable[[], cst.CSTNode],
                     expected_re: str) -> None:
        """Building the node from ``get_node`` must fail matching ``expected_re``."""
        self.assert_invalid(get_node, expected_re)

    # Starred elements inside a set display (`{*x, 2}`) must parse for
    # python_version 3.5 and be rejected for 3.3.
    @data_provider((
        {
            "code": "{*x, 2}",
            "parser": parse_expression_as(python_version="3.5"),
            "expect_success": True,
        },
        {
            "code": "{*x, 2}",
            "parser": parse_expression_as(python_version="3.3"),
            "expect_success": False,
        },
    ))
    def test_versions(self, **kwargs: Any) -> None:
        """Check version-gated parsing; expected failures are skipped natively."""
        if is_native() and not kwargs.get("expect_success", True):
            self.skipTest("parse errors are disabled for native parser")
        self.assert_parses(**kwargs)
예제 #18
0
class WhileTest(CSTNodeTest):
    """Construction/parsing tests for ``cst.While`` statements.

    Rows with ``"parser": None`` are only checked for rendering. Rows wrapped
    in ``DummyIndentedBlock`` exercise rendering at a non-zero indent.
    """

    @data_provider((
        # Simple while block
        # pyre-fixme[6]: Incompatible parameter type
        {
            "node":
            cst.While(cst.Call(cst.Name("iter")),
                      cst.SimpleStatementSuite((cst.Pass(), ))),
            "code":
            "while iter(): pass\n",
            "parser":
            parse_statement,
        },
        # While block with else
        {
            "node":
            cst.While(
                cst.Call(cst.Name("iter")),
                cst.SimpleStatementSuite((cst.Pass(), )),
                cst.Else(cst.SimpleStatementSuite((cst.Pass(), ))),
            ),
            "code":
            "while iter(): pass\nelse: pass\n",
            "parser":
            parse_statement,
        },
        # indentation
        {
            "node":
            DummyIndentedBlock(
                "    ",
                cst.While(
                    cst.Call(cst.Name("iter")),
                    cst.SimpleStatementSuite((cst.Pass(), )),
                ),
            ),
            "code":
            "    while iter(): pass\n",
            "parser":
            None,
            "expected_position":
            CodeRange((1, 4), (1, 22)),
        },
        # with an indented body
        {
            "node":
            DummyIndentedBlock(
                "    ",
                cst.While(
                    cst.Call(cst.Name("iter")),
                    cst.IndentedBlock((cst.SimpleStatementLine(
                        (cst.Pass(), )), )),
                ),
            ),
            "code":
            "    while iter():\n        pass\n",
            "parser":
            None,
            "expected_position":
            CodeRange((1, 4), (2, 12)),
        },
        # leading_lines
        {
            "node":
            cst.While(
                cst.Call(cst.Name("iter")),
                cst.IndentedBlock((cst.SimpleStatementLine((cst.Pass(), )), )),
                leading_lines=(cst.EmptyLine(
                    comment=cst.Comment("# leading comment")), ),
            ),
            "code":
            "# leading comment\nwhile iter():\n    pass\n",
            "parser":
            parse_statement,
            "expected_position":
            CodeRange((2, 0), (3, 8)),
        },
        {
            "node":
            cst.While(
                cst.Call(cst.Name("iter")),
                cst.IndentedBlock((cst.SimpleStatementLine((cst.Pass(), )), )),
                cst.Else(
                    cst.IndentedBlock((cst.SimpleStatementLine(
                        (cst.Pass(), )), )),
                    leading_lines=(cst.EmptyLine(
                        comment=cst.Comment("# else comment")), ),
                ),
                leading_lines=(cst.EmptyLine(
                    comment=cst.Comment("# leading comment")), ),
            ),
            "code":
            "# leading comment\nwhile iter():\n    pass\n# else comment\nelse:\n    pass\n",
            "parser":
            None,
            "expected_position":
            CodeRange((2, 0), (6, 8)),
        },
        # Weird spacing rules
        {
            "node":
            cst.While(
                cst.Call(
                    cst.Name("iter"),
                    lpar=(cst.LeftParen(), ),
                    rpar=(cst.RightParen(), ),
                ),
                cst.SimpleStatementSuite((cst.Pass(), )),
                whitespace_after_while=cst.SimpleWhitespace(""),
            ),
            "code":
            "while(iter()): pass\n",
            "parser":
            parse_statement,
            "expected_position":
            CodeRange((1, 0), (1, 19)),
        },
        # Whitespace
        {
            "node":
            cst.While(
                cst.Call(cst.Name("iter")),
                cst.SimpleStatementSuite((cst.Pass(), )),
                whitespace_after_while=cst.SimpleWhitespace("  "),
                whitespace_before_colon=cst.SimpleWhitespace("  "),
            ),
            "code":
            "while  iter()  : pass\n",
            "parser":
            parse_statement,
            "expected_position":
            CodeRange((1, 0), (1, 21)),
        },
    ))
    def test_valid(self, **kwargs: Any) -> None:
        """Validate a ``While`` row via ``validate_node``."""
        self.validate_node(**kwargs)

    # An unparenthesized test expression needs whitespace after `while`.
    @data_provider(({
        "get_node":
        lambda: cst.While(
            cst.Call(cst.Name("iter")),
            cst.SimpleStatementSuite((cst.Pass(), )),
            whitespace_after_while=cst.SimpleWhitespace(""),
        ),
        "expected_re":
        "Must have at least one space after 'while' keyword",
    }, ))
    def test_invalid(self, **kwargs: Any) -> None:
        """Building ``get_node()`` must fail with a message matching ``expected_re``."""
        self.assert_invalid(**kwargs)
예제 #19
0
class SubscriptTest(CSTNodeTest):
    """Round-trip, position and validation tests for ``cst.Subscript``.

    The last few ``test_valid`` cases also exercise the ``cst.Index``,
    ``cst.Slice`` and ``cst.SubscriptElement`` child nodes on their own.

    ``test_valid`` cases are ``(node, code, check_parsing[, position])``
    tuples; ``test_invalid`` cases are ``(get_node, expected_re)`` tuples.
    """

    @data_provider((
        # Simple subscript expression
        (
            cst.Subscript(
                cst.Name("foo"),
                (cst.SubscriptElement(cst.Index(cst.Integer("5"))), ),
            ),
            "foo[5]",
            True,
        ),
        # Test creation of subscript with slice/extslice.
        # These build slices without explicit Colon/Comma nodes and are not
        # checked through the parser (check_parsing=False); the parsing section
        # below builds equivalent slices with the separators spelled out.
        (
            cst.Subscript(
                cst.Name("foo"),
                (cst.SubscriptElement(
                    cst.Slice(
                        lower=cst.Integer("1"),
                        upper=cst.Integer("2"),
                        step=cst.Integer("3"),
                    )), ),
            ),
            "foo[1:2:3]",
            False,
        ),
        (
            cst.Subscript(
                cst.Name("foo"),
                (
                    cst.SubscriptElement(
                        cst.Slice(
                            lower=cst.Integer("1"),
                            upper=cst.Integer("2"),
                            step=cst.Integer("3"),
                        )),
                    cst.SubscriptElement(cst.Index(cst.Integer("5"))),
                ),
            ),
            "foo[1:2:3, 5]",
            False,
            CodeRange((1, 0), (1, 13)),
        ),
        # Test parsing of subscript with slice/extslice.
        (
            cst.Subscript(
                cst.Name("foo"),
                (cst.SubscriptElement(
                    cst.Slice(
                        lower=cst.Integer("1"),
                        first_colon=cst.Colon(),
                        upper=cst.Integer("2"),
                        second_colon=cst.Colon(),
                        step=cst.Integer("3"),
                    )), ),
            ),
            "foo[1:2:3]",
            True,
        ),
        (
            cst.Subscript(
                cst.Name("foo"),
                (
                    cst.SubscriptElement(
                        cst.Slice(
                            lower=cst.Integer("1"),
                            first_colon=cst.Colon(),
                            upper=cst.Integer("2"),
                            second_colon=cst.Colon(),
                            step=cst.Integer("3"),
                        ),
                        comma=cst.Comma(),
                    ),
                    cst.SubscriptElement(cst.Index(cst.Integer("5"))),
                ),
            ),
            "foo[1:2:3,5]",
            True,
        ),
        # Some more wild slice creations
        # Entries marked False leave ``second_colon`` at its default; the
        # matching parse-checked entries in the next section pass
        # ``second_colon=cst.Colon()`` explicitly.
        (
            cst.Subscript(
                cst.Name("foo"),
                (cst.SubscriptElement(
                    cst.Slice(lower=cst.Integer("1"),
                              upper=cst.Integer("2"))), ),
            ),
            "foo[1:2]",
            True,
        ),
        (
            cst.Subscript(
                cst.Name("foo"),
                (cst.SubscriptElement(
                    cst.Slice(lower=cst.Integer("1"), upper=None)), ),
            ),
            "foo[1:]",
            True,
        ),
        (
            cst.Subscript(
                cst.Name("foo"),
                (cst.SubscriptElement(
                    cst.Slice(lower=None, upper=cst.Integer("2"))), ),
            ),
            "foo[:2]",
            True,
        ),
        (
            cst.Subscript(
                cst.Name("foo"),
                (cst.SubscriptElement(
                    cst.Slice(
                        lower=cst.Integer("1"),
                        upper=None,
                        step=cst.Integer("3"),
                    )), ),
            ),
            "foo[1::3]",
            False,
        ),
        (
            cst.Subscript(
                cst.Name("foo"),
                (cst.SubscriptElement(
                    cst.Slice(lower=None, upper=None,
                              step=cst.Integer("3"))), ),
            ),
            "foo[::3]",
            False,
            CodeRange((1, 0), (1, 8)),
        ),
        # Some more wild slice parsings
        # NOTE(review): the first three entries here are verbatim duplicates of
        # the first three "creations" entries above.
        (
            cst.Subscript(
                cst.Name("foo"),
                (cst.SubscriptElement(
                    cst.Slice(lower=cst.Integer("1"),
                              upper=cst.Integer("2"))), ),
            ),
            "foo[1:2]",
            True,
        ),
        (
            cst.Subscript(
                cst.Name("foo"),
                (cst.SubscriptElement(
                    cst.Slice(lower=cst.Integer("1"), upper=None)), ),
            ),
            "foo[1:]",
            True,
        ),
        (
            cst.Subscript(
                cst.Name("foo"),
                (cst.SubscriptElement(
                    cst.Slice(lower=None, upper=cst.Integer("2"))), ),
            ),
            "foo[:2]",
            True,
        ),
        (
            cst.Subscript(
                cst.Name("foo"),
                (cst.SubscriptElement(
                    cst.Slice(
                        lower=cst.Integer("1"),
                        upper=None,
                        second_colon=cst.Colon(),
                        step=cst.Integer("3"),
                    )), ),
            ),
            "foo[1::3]",
            True,
        ),
        (
            cst.Subscript(
                cst.Name("foo"),
                (cst.SubscriptElement(
                    cst.Slice(
                        lower=None,
                        upper=None,
                        second_colon=cst.Colon(),
                        step=cst.Integer("3"),
                    )), ),
            ),
            "foo[::3]",
            True,
        ),
        # Valid list clone operations rendering
        (
            cst.Subscript(
                cst.Name("foo"),
                (cst.SubscriptElement(cst.Slice(lower=None, upper=None)), ),
            ),
            "foo[:]",
            True,
        ),
        (
            cst.Subscript(
                cst.Name("foo"),
                (cst.SubscriptElement(
                    cst.Slice(
                        lower=None,
                        upper=None,
                        second_colon=cst.Colon(),
                        step=None,
                    )), ),
            ),
            "foo[::]",
            True,
        ),
        # Valid list clone operations parsing
        # NOTE(review): both entries are identical to the "rendering" pair above.
        (
            cst.Subscript(
                cst.Name("foo"),
                (cst.SubscriptElement(cst.Slice(lower=None, upper=None)), ),
            ),
            "foo[:]",
            True,
        ),
        (
            cst.Subscript(
                cst.Name("foo"),
                (cst.SubscriptElement(
                    cst.Slice(
                        lower=None,
                        upper=None,
                        second_colon=cst.Colon(),
                        step=None,
                    )), ),
            ),
            "foo[::]",
            True,
        ),
        # In parenthesis
        (
            cst.Subscript(
                lpar=(cst.LeftParen(), ),
                value=cst.Name("foo"),
                slice=(cst.SubscriptElement(cst.Index(cst.Integer("5"))), ),
                rpar=(cst.RightParen(), ),
            ),
            "(foo[5])",
            True,
        ),
        # Verify spacing
        (
            cst.Subscript(
                lpar=(cst.LeftParen(
                    whitespace_after=cst.SimpleWhitespace(" ")), ),
                value=cst.Name("foo"),
                lbracket=cst.LeftSquareBracket(
                    whitespace_after=cst.SimpleWhitespace(" ")),
                slice=(cst.SubscriptElement(cst.Index(cst.Integer("5"))), ),
                rbracket=cst.RightSquareBracket(
                    whitespace_before=cst.SimpleWhitespace(" ")),
                rpar=(cst.RightParen(
                    whitespace_before=cst.SimpleWhitespace(" ")), ),
                whitespace_after_value=cst.SimpleWhitespace(" "),
            ),
            "( foo [ 5 ] )",
            True,
        ),
        (
            cst.Subscript(
                lpar=(cst.LeftParen(
                    whitespace_after=cst.SimpleWhitespace(" ")), ),
                value=cst.Name("foo"),
                lbracket=cst.LeftSquareBracket(
                    whitespace_after=cst.SimpleWhitespace(" ")),
                slice=(cst.SubscriptElement(
                    cst.Slice(
                        lower=cst.Integer("1"),
                        first_colon=cst.Colon(
                            whitespace_before=cst.SimpleWhitespace(" "),
                            whitespace_after=cst.SimpleWhitespace(" "),
                        ),
                        upper=cst.Integer("2"),
                        second_colon=cst.Colon(
                            whitespace_before=cst.SimpleWhitespace(" "),
                            whitespace_after=cst.SimpleWhitespace(" "),
                        ),
                        step=cst.Integer("3"),
                    )), ),
                rbracket=cst.RightSquareBracket(
                    whitespace_before=cst.SimpleWhitespace(" ")),
                rpar=(cst.RightParen(
                    whitespace_before=cst.SimpleWhitespace(" ")), ),
                whitespace_after_value=cst.SimpleWhitespace(" "),
            ),
            "( foo [ 1 : 2 : 3 ] )",
            True,
        ),
        (
            cst.Subscript(
                lpar=(cst.LeftParen(
                    whitespace_after=cst.SimpleWhitespace(" ")), ),
                value=cst.Name("foo"),
                lbracket=cst.LeftSquareBracket(
                    whitespace_after=cst.SimpleWhitespace(" ")),
                slice=(
                    cst.SubscriptElement(
                        slice=cst.Slice(
                            lower=cst.Integer("1"),
                            first_colon=cst.Colon(
                                whitespace_before=cst.SimpleWhitespace(" "),
                                whitespace_after=cst.SimpleWhitespace(" "),
                            ),
                            upper=cst.Integer("2"),
                            second_colon=cst.Colon(
                                whitespace_before=cst.SimpleWhitespace(" "),
                                whitespace_after=cst.SimpleWhitespace(" "),
                            ),
                            step=cst.Integer("3"),
                        ),
                        comma=cst.Comma(
                            whitespace_before=cst.SimpleWhitespace(" "),
                            whitespace_after=cst.SimpleWhitespace("  "),
                        ),
                    ),
                    cst.SubscriptElement(slice=cst.Index(cst.Integer("5"))),
                ),
                rbracket=cst.RightSquareBracket(
                    whitespace_before=cst.SimpleWhitespace(" ")),
                rpar=(cst.RightParen(
                    whitespace_before=cst.SimpleWhitespace(" ")), ),
                whitespace_after_value=cst.SimpleWhitespace(" "),
            ),
            "( foo [ 1 : 2 : 3 ,  5 ] )",
            True,
            # Columns 2..24: the expected range excludes the outer "( " / " )".
            CodeRange((1, 2), (1, 24)),
        ),
        # Test Index, Slice, SubscriptElement
        # These child nodes are checked for rendering/position only
        # (check_parsing=False).
        (cst.Index(cst.Integer("5")), "5", False, CodeRange((1, 0), (1, 1))),
        (
            cst.Slice(lower=None,
                      upper=None,
                      second_colon=cst.Colon(),
                      step=None),
            "::",
            False,
            CodeRange((1, 0), (1, 2)),
        ),
        (
            cst.SubscriptElement(
                slice=cst.Slice(
                    lower=cst.Integer("1"),
                    first_colon=cst.Colon(
                        whitespace_before=cst.SimpleWhitespace(" "),
                        whitespace_after=cst.SimpleWhitespace(" "),
                    ),
                    upper=cst.Integer("2"),
                    second_colon=cst.Colon(
                        whitespace_before=cst.SimpleWhitespace(" "),
                        whitespace_after=cst.SimpleWhitespace(" "),
                    ),
                    step=cst.Integer("3"),
                ),
                comma=cst.Comma(
                    whitespace_before=cst.SimpleWhitespace(" "),
                    whitespace_after=cst.SimpleWhitespace("  "),
                ),
            ),
            "1 : 2 : 3 ,  ",
            False,
            # The expected range ends at column 9 ("1 : 2 : 3"), i.e. before
            # the trailing comma and its surrounding whitespace.
            CodeRange((1, 0), (1, 9)),
        ),
    ))
    def test_valid(
        self,
        node: cst.CSTNode,
        code: str,
        check_parsing: bool,
        position: Optional[CodeRange] = None,
    ) -> None:
        """Validate ``node`` against its expected source via ``validate_node``.

        Args:
            node: The node under test.
            code: The exact source text ``node`` is expected to correspond to.
            check_parsing: When True, ``parse_expression`` is passed to
                ``validate_node`` as the parser; when False no parser is given.
            position: Expected ``CodeRange`` of the node, forwarded as
                ``expected_position`` (``None`` when the case omits it).
        """
        if check_parsing:
            self.validate_node(node,
                               code,
                               parse_expression,
                               expected_position=position)
        else:
            self.validate_node(node, code, expected_position=position)

    @data_provider((
        (
            lambda: cst.Subscript(
                cst.Name("foo"),
                (cst.SubscriptElement(cst.Index(cst.Integer("5"))), ),
                lpar=(cst.LeftParen(), ),
            ),
            "left paren without right paren",
        ),
        (
            lambda: cst.Subscript(
                cst.Name("foo"),
                (cst.SubscriptElement(cst.Index(cst.Integer("5"))), ),
                rpar=(cst.RightParen(), ),
            ),
            "right paren without left paren",
        ),
        (lambda: cst.Subscript(cst.Name("foo"), ()), "empty SubscriptElement"),
    ))
    def test_invalid(self, get_node: Callable[[], cst.CSTNode],
                     expected_re: str) -> None:
        """Check that each malformed ``Subscript`` is rejected.

        Args:
            get_node: Zero-argument factory building the invalid node (a
                factory, presumably so construction happens inside the check).
            expected_re: Expected error pattern, forwarded to ``assert_invalid``.
        """
        self.assert_invalid(get_node, expected_re)
예제 #20
0
class BooleanOperationTest(CSTNodeTest):
    """Round-trip and validation tests for ``cst.BooleanOperation``.

    Covers ``and`` / ``or`` rendering, parentheses around the whole operation
    and around each operand, whitespace around the operator, and the errors
    raised for unbalanced parentheses or an operator with no surrounding space.
    """

    @data_provider(
        (
            # Plain ``and`` / ``or`` between two names.
            # pyre-fixme[6]: Incompatible parameter type
            dict(
                node=cst.BooleanOperation(
                    left=cst.Name("foo"),
                    operator=cst.And(),
                    right=cst.Name("bar"),
                ),
                code="foo and bar",
                parser=parse_expression,
                expected_position=None,
            ),
            dict(
                node=cst.BooleanOperation(
                    left=cst.Name("foo"),
                    operator=cst.Or(),
                    right=cst.Name("bar"),
                ),
                code="foo or bar",
                parser=parse_expression,
                expected_position=None,
            ),
            # The whole operation wrapped in one pair of parentheses.
            dict(
                node=cst.BooleanOperation(
                    left=cst.Name("foo"),
                    operator=cst.Or(),
                    right=cst.Name("bar"),
                    lpar=(cst.LeftParen(),),
                    rpar=(cst.RightParen(),),
                ),
                code="(foo or bar)",
                parser=parse_expression,
                expected_position=None,
            ),
            # Parenthesized operands allow the operator to touch them directly.
            dict(
                node=cst.BooleanOperation(
                    left=cst.Name(
                        "foo",
                        lpar=(cst.LeftParen(),),
                        rpar=(cst.RightParen(),),
                    ),
                    operator=cst.Or(
                        whitespace_before=cst.SimpleWhitespace(""),
                        whitespace_after=cst.SimpleWhitespace(""),
                    ),
                    right=cst.Name(
                        "bar",
                        lpar=(cst.LeftParen(),),
                        rpar=(cst.RightParen(),),
                    ),
                ),
                code="(foo)or(bar)",
                parser=parse_expression,
                expected_position=CodeRange.create((1, 0), (1, 12)),
            ),
            # Custom whitespace inside the parentheses and around the operator.
            dict(
                node=cst.BooleanOperation(
                    left=cst.Name("foo"),
                    operator=cst.And(
                        whitespace_before=cst.SimpleWhitespace("  "),
                        whitespace_after=cst.SimpleWhitespace("  "),
                    ),
                    right=cst.Name("bar"),
                    lpar=(
                        cst.LeftParen(whitespace_after=cst.SimpleWhitespace(" ")),
                    ),
                    rpar=(
                        cst.RightParen(whitespace_before=cst.SimpleWhitespace(" ")),
                    ),
                ),
                code="( foo  and  bar )",
                parser=parse_expression,
                expected_position=None,
            ),
        )
    )
    def test_valid(self, **kwargs: Any) -> None:
        """Forward each case (node, code, parser, expected_position) to ``validate_node``."""
        self.validate_node(**kwargs)

    @data_provider(
        (
            # Unbalanced parentheses.
            dict(
                get_node=lambda: cst.BooleanOperation(
                    left=cst.Name("foo"),
                    operator=cst.And(),
                    right=cst.Name("bar"),
                    lpar=(cst.LeftParen(),),
                ),
                expected_re="left paren without right paren",
            ),
            dict(
                get_node=lambda: cst.BooleanOperation(
                    left=cst.Name("foo"),
                    operator=cst.And(),
                    right=cst.Name("bar"),
                    rpar=(cst.RightParen(),),
                ),
                expected_re="right paren without left paren",
            ),
            # A word operator glued to unparenthesized operands ("fooorbar").
            dict(
                get_node=lambda: cst.BooleanOperation(
                    left=cst.Name("foo"),
                    operator=cst.Or(
                        whitespace_before=cst.SimpleWhitespace(""),
                        whitespace_after=cst.SimpleWhitespace(""),
                    ),
                    right=cst.Name("bar"),
                ),
                expected_re="at least one space around boolean operator",
            ),
        )
    )
    def test_invalid(self, **kwargs: Any) -> None:
        """Forward each case (get_node, expected_re) to ``assert_invalid``."""
        self.assert_invalid(**kwargs)
예제 #21
0
class ComparisonTest(CSTNodeTest):
    """Round-trip, position and validation tests for ``cst.Comparison``.

    ``test_valid`` cases are ``(node, code[, position])`` tuples, all checked
    through ``parse_expression``; ``test_invalid`` cases are
    ``(get_node, expected_re)`` tuples.
    """

    @data_provider((
        # Simple comparison statements
        (
            cst.Comparison(
                cst.Name("foo"),
                (cst.ComparisonTarget(cst.LessThan(), cst.Integer("5")), ),
            ),
            "foo < 5",
        ),
        (
            cst.Comparison(
                cst.Name("foo"),
                (cst.ComparisonTarget(cst.NotEqual(), cst.Integer("5")), ),
            ),
            "foo != 5",
        ),
        (
            cst.Comparison(
                cst.Name("foo"),
                (cst.ComparisonTarget(cst.Is(), cst.Name("True")), )),
            "foo is True",
        ),
        (
            cst.Comparison(
                cst.Name("foo"),
                (cst.ComparisonTarget(cst.IsNot(), cst.Name("False")), ),
            ),
            "foo is not False",
        ),
        (
            cst.Comparison(
                cst.Name("foo"),
                (cst.ComparisonTarget(cst.In(), cst.Name("bar")), )),
            "foo in bar",
        ),
        (
            cst.Comparison(
                cst.Name("foo"),
                (cst.ComparisonTarget(cst.NotIn(), cst.Name("bar")), ),
            ),
            "foo not in bar",
        ),
        # Comparison with parens
        (
            cst.Comparison(
                lpar=(cst.LeftParen(), ),
                left=cst.Name("foo"),
                comparisons=(cst.ComparisonTarget(
                    operator=cst.NotIn(), comparator=cst.Name("bar")), ),
                rpar=(cst.RightParen(), ),
            ),
            "(foo not in bar)",
        ),
        # Parenthesized operands let word operators drop their whitespace.
        (
            cst.Comparison(
                left=cst.Name("a",
                              lpar=(cst.LeftParen(), ),
                              rpar=(cst.RightParen(), )),
                comparisons=(
                    cst.ComparisonTarget(
                        operator=cst.Is(
                            whitespace_before=cst.SimpleWhitespace(""),
                            whitespace_after=cst.SimpleWhitespace(""),
                        ),
                        comparator=cst.Name("b",
                                            lpar=(cst.LeftParen(), ),
                                            rpar=(cst.RightParen(), )),
                    ),
                    cst.ComparisonTarget(
                        operator=cst.Is(
                            whitespace_before=cst.SimpleWhitespace(""),
                            whitespace_after=cst.SimpleWhitespace(""),
                        ),
                        comparator=cst.Name("c",
                                            lpar=(cst.LeftParen(), ),
                                            rpar=(cst.RightParen(), )),
                    ),
                ),
            ),
            "(a)is(b)is(c)",
        ),
        # Valid expressions that look like they shouldn't parse
        (
            cst.Comparison(
                left=cst.Integer("5"),
                comparisons=(cst.ComparisonTarget(
                    operator=cst.NotIn(
                        whitespace_before=cst.SimpleWhitespace("")),
                    comparator=cst.Name("bar"),
                ), ),
            ),
            "5not in bar",
        ),
        # Validate that spacing works properly
        (
            cst.Comparison(
                lpar=(cst.LeftParen(
                    whitespace_after=cst.SimpleWhitespace(" ")), ),
                left=cst.Name("foo"),
                comparisons=(cst.ComparisonTarget(
                    operator=cst.NotIn(
                        whitespace_before=cst.SimpleWhitespace("  "),
                        whitespace_between=cst.SimpleWhitespace("  "),
                        whitespace_after=cst.SimpleWhitespace("  "),
                    ),
                    comparator=cst.Name("bar"),
                ), ),
                rpar=(cst.RightParen(
                    whitespace_before=cst.SimpleWhitespace(" ")), ),
            ),
            "( foo  not  in  bar )",
        ),
        # Do some complex nodes
        (
            cst.Comparison(
                left=cst.Name("baz"),
                comparisons=(cst.ComparisonTarget(
                    operator=cst.Equal(),
                    comparator=cst.Comparison(
                        lpar=(cst.LeftParen(), ),
                        left=cst.Name("foo"),
                        comparisons=(cst.ComparisonTarget(
                            operator=cst.NotIn(),
                            comparator=cst.Name("bar")), ),
                        rpar=(cst.RightParen(), ),
                    ),
                ), ),
            ),
            "baz == (foo not in bar)",
            CodeRange((1, 0), (1, 23)),
        ),
        # Chained comparison: one Comparison with two targets.
        (
            cst.Comparison(
                left=cst.Name("a"),
                comparisons=(
                    cst.ComparisonTarget(operator=cst.GreaterThan(),
                                         comparator=cst.Name("b")),
                    cst.ComparisonTarget(operator=cst.GreaterThan(),
                                         comparator=cst.Name("c")),
                ),
            ),
            "a > b > c",
            CodeRange((1, 0), (1, 9)),
        ),
        # Is safe to use with word operators if its leading/trailing children
        # are parenthesized.
        (
            cst.IfExp(
                body=cst.Comparison(
                    left=cst.Name("a"),
                    comparisons=(cst.ComparisonTarget(
                        operator=cst.GreaterThan(),
                        comparator=cst.Name(
                            "b",
                            lpar=(cst.LeftParen(), ),
                            rpar=(cst.RightParen(), ),
                        ),
                    ), ),
                ),
                test=cst.Comparison(
                    left=cst.Name("c",
                                  lpar=(cst.LeftParen(), ),
                                  rpar=(cst.RightParen(), )),
                    comparisons=(cst.ComparisonTarget(
                        operator=cst.GreaterThan(),
                        comparator=cst.Name("d")), ),
                ),
                orelse=cst.Name("e"),
                whitespace_before_if=cst.SimpleWhitespace(""),
                whitespace_after_if=cst.SimpleWhitespace(""),
            ),
            "a > (b)if(c) > d else e",
        ),
        # Is safe to use with word operators if entirely surrounded in parentheses.
        (
            cst.IfExp(
                body=cst.Name("a"),
                test=cst.Comparison(
                    left=cst.Name("b"),
                    comparisons=(cst.ComparisonTarget(
                        operator=cst.GreaterThan(),
                        comparator=cst.Name("c")), ),
                    lpar=(cst.LeftParen(), ),
                    rpar=(cst.RightParen(), ),
                ),
                orelse=cst.Name("d"),
                whitespace_after_if=cst.SimpleWhitespace(""),
                whitespace_before_else=cst.SimpleWhitespace(""),
            ),
            "a if(b > c)else d",
        ),
    ))
    def test_valid(self,
                   node: cst.CSTNode,
                   code: str,
                   position: Optional[CodeRange] = None) -> None:
        """Validate ``node`` against ``code`` with ``parse_expression`` as parser.

        Args:
            node: The node under test.
            code: The exact source text ``node`` is expected to correspond to.
            position: Expected ``CodeRange``, forwarded as ``expected_position``
                (``None`` when the case omits it).
        """
        self.validate_node(node,
                           code,
                           parse_expression,
                           expected_position=position)

    @data_provider((
        # Unbalanced parentheses.
        (
            lambda: cst.Comparison(
                cst.Name("foo"),
                (cst.ComparisonTarget(cst.LessThan(), cst.Integer("5")), ),
                lpar=(cst.LeftParen(), ),
            ),
            "left paren without right paren",
        ),
        (
            lambda: cst.Comparison(
                cst.Name("foo"),
                (cst.ComparisonTarget(cst.LessThan(), cst.Integer("5")), ),
                rpar=(cst.RightParen(), ),
            ),
            "right paren without left paren",
        ),
        # A comparison needs at least one target.
        (
            lambda: cst.Comparison(cst.Name("foo"), ()),
            "at least one ComparisonTarget",
        ),
        # Word operator glued to an unparenthesized operand on either side.
        (
            lambda: cst.Comparison(
                left=cst.Name("foo"),
                comparisons=(cst.ComparisonTarget(
                    operator=cst.NotIn(whitespace_before=cst.SimpleWhitespace(
                        "")),
                    comparator=cst.Name("bar"),
                ), ),
            ),
            "at least one space around comparison operator",
        ),
        (
            lambda: cst.Comparison(
                left=cst.Name("foo"),
                comparisons=(cst.ComparisonTarget(
                    operator=cst.NotIn(whitespace_after=cst.SimpleWhitespace(
                        "")),
                    comparator=cst.Name("bar"),
                ), ),
            ),
            "at least one space around comparison operator",
        ),
        # multi-target comparisons
        (
            lambda: cst.Comparison(
                left=cst.Name("a"),
                comparisons=(
                    cst.ComparisonTarget(operator=cst.Is(),
                                         comparator=cst.Name("b")),
                    cst.ComparisonTarget(
                        operator=cst.Is(whitespace_before=cst.SimpleWhitespace(
                            "")),
                        comparator=cst.Name("c"),
                    ),
                ),
            ),
            "at least one space around comparison operator",
        ),
        (
            lambda: cst.Comparison(
                left=cst.Name("a"),
                comparisons=(
                    cst.ComparisonTarget(operator=cst.Is(),
                                         comparator=cst.Name("b")),
                    cst.ComparisonTarget(
                        operator=cst.Is(whitespace_after=cst.SimpleWhitespace(
                            "")),
                        comparator=cst.Name("c"),
                    ),
                ),
            ),
            "at least one space around comparison operator",
        ),
        # whitespace around the comparison itself
        # a ifb > c else d
        (
            lambda: cst.IfExp(
                body=cst.Name("a"),
                test=cst.Comparison(
                    left=cst.Name("b"),
                    comparisons=(cst.
                                 ComparisonTarget(operator=cst.GreaterThan(),
                                                  comparator=cst.Name("c")), ),
                ),
                orelse=cst.Name("d"),
                whitespace_after_if=cst.SimpleWhitespace(""),
            ),
            "Must have at least one space after 'if' keyword.",
        ),
        # a if b > celse d
        (
            lambda: cst.IfExp(
                body=cst.Name("a"),
                test=cst.Comparison(
                    left=cst.Name("b"),
                    comparisons=(cst.
                                 ComparisonTarget(operator=cst.GreaterThan(),
                                                  comparator=cst.Name("c")), ),
                ),
                orelse=cst.Name("d"),
                whitespace_before_else=cst.SimpleWhitespace(""),
            ),
            "Must have at least one space before 'else' keyword.",
        ),
    ))
    def test_invalid(self, get_node: Callable[[], cst.CSTNode],
                     expected_re: str) -> None:
        """Check that each malformed node is rejected.

        Args:
            get_node: Zero-argument factory building the invalid node.
            expected_re: Expected error pattern, forwarded to ``assert_invalid``.
        """
        self.assert_invalid(get_node, expected_re)
예제 #22
0
class AwaitTest(CSTNodeTest):
    """Round-trip and validation tests for ``cst.Await``.

    Parsing is exercised two ways: as a bare expression with a Python 3.7
    parser config, and inside an ``async def`` body with a Python 3.6 parser
    config.
    """

    @data_provider((
        # Awaiting a plain name and a call.
        dict(
            node=cst.Await(cst.Name("test")),
            code="await test",
            parser=lambda code: parse_expression(
                code, config=PartialParserConfig(python_version="3.7")),
            expected_position=None,
        ),
        dict(
            node=cst.Await(cst.Call(cst.Name("test"))),
            code="await test()",
            parser=lambda code: parse_expression(
                code, config=PartialParserConfig(python_version="3.7")),
            expected_position=None,
        ),
        # Doubled space after ``await`` plus padded parentheses.
        dict(
            node=cst.Await(
                cst.Name("test"),
                whitespace_after_await=cst.SimpleWhitespace("  "),
                lpar=(
                    cst.LeftParen(whitespace_after=cst.SimpleWhitespace(" ")),
                ),
                rpar=(
                    cst.RightParen(whitespace_before=cst.SimpleWhitespace(" ")),
                ),
            ),
            code="( await  test )",
            parser=lambda code: parse_expression(
                code, config=PartialParserConfig(python_version="3.7")),
            expected_position=CodeRange((1, 2), (1, 13)),
        ),
    ))
    def test_valid_py37(self, **kwargs: Any) -> None:
        """Validate each ``Await`` expression with a 3.7 expression parser.

        Each case's ``node``, ``code``, ``parser`` and ``expected_position``
        are forwarded as keyword arguments to ``validate_node``.
        """
        self.validate_node(**kwargs)

    @data_provider((
        # The 3.6 cases wrap the await in an ``async def`` body and parse it as
        # a statement (presumably because 3.6 only accepts ``await`` there).
        dict(
            node=cst.FunctionDef(
                cst.Name("foo"),
                cst.Parameters(),
                cst.IndentedBlock(
                    (
                        cst.SimpleStatementLine(
                            (cst.Expr(cst.Await(cst.Name("test"))),)
                        ),
                    )
                ),
                asynchronous=cst.Asynchronous(),
            ),
            code="async def foo():\n    await test\n",
            parser=lambda code: parse_statement(
                code, config=PartialParserConfig(python_version="3.6")),
            expected_position=None,
        ),
        dict(
            node=cst.FunctionDef(
                cst.Name("foo"),
                cst.Parameters(),
                cst.IndentedBlock(
                    (
                        cst.SimpleStatementLine(
                            (cst.Expr(cst.Await(cst.Call(cst.Name("test")))),)
                        ),
                    )
                ),
                asynchronous=cst.Asynchronous(),
            ),
            code="async def foo():\n    await test()\n",
            parser=lambda code: parse_statement(
                code, config=PartialParserConfig(python_version="3.6")),
            expected_position=None,
        ),
        # Same whitespace variant as the 3.7 case, inside the function body.
        dict(
            node=cst.FunctionDef(
                cst.Name("foo"),
                cst.Parameters(),
                cst.IndentedBlock(
                    (
                        cst.SimpleStatementLine(
                            (
                                cst.Expr(
                                    cst.Await(
                                        cst.Name("test"),
                                        whitespace_after_await=cst.SimpleWhitespace("  "),
                                        lpar=(
                                            cst.LeftParen(
                                                whitespace_after=cst.SimpleWhitespace(" ")
                                            ),
                                        ),
                                        rpar=(
                                            cst.RightParen(
                                                whitespace_before=cst.SimpleWhitespace(" ")
                                            ),
                                        ),
                                    )
                                ),
                            )
                        ),
                    )
                ),
                asynchronous=cst.Asynchronous(),
            ),
            code="async def foo():\n    ( await  test )\n",
            parser=lambda code: parse_statement(
                code, config=PartialParserConfig(python_version="3.6")),
            expected_position=None,
        ),
    ))
    def test_valid_py36(self, **kwargs: Any) -> None:
        """Validate ``await`` inside ``async def`` with a 3.6 statement parser.

        Each case's ``node``, ``code``, ``parser`` and ``expected_position``
        are forwarded as keyword arguments to ``validate_node``.
        """
        self.validate_node(**kwargs)

    @data_provider((
        # Unbalanced parentheses around the await expression.
        dict(
            get_node=lambda: cst.Await(cst.Name("foo"), lpar=(cst.LeftParen(),)),
            expected_re="left paren without right paren",
        ),
        dict(
            get_node=lambda: cst.Await(cst.Name("foo"), rpar=(cst.RightParen(),)),
            expected_re="right paren without left paren",
        ),
        # ``await`` glued to its operand ("awaitfoo").
        dict(
            get_node=lambda: cst.Await(
                cst.Name("foo"),
                whitespace_after_await=cst.SimpleWhitespace(""),
            ),
            expected_re="at least one space after await",
        ),
    ))
    def test_invalid(self, **kwargs: Any) -> None:
        """Forward each case (get_node, expected_re) to ``assert_invalid``."""
        self.assert_invalid(**kwargs)
예제 #23
0
class TryTest(CSTNodeTest):
    """Codegen, parsing, position and validation tests for ``cst.Try``.

    Each ``test_valid`` case renders ``node`` and compares it with ``code``. When
    ``parser`` is given, the case also round-trips ``code`` through it. When
    ``expected_position`` is given, the case checks the node's recorded source
    range.
    """

    @data_provider(
        (
            # Simple try/except block
            {
                "node": cst.Try(
                    cst.SimpleStatementSuite((cst.Pass(),)),
                    handlers=(
                        cst.ExceptHandler(
                            cst.SimpleStatementSuite((cst.Pass(),)),
                            whitespace_after_except=cst.SimpleWhitespace(""),
                        ),
                    ),
                ),
                "code": "try: pass\nexcept: pass\n",
                "parser": parse_statement,
                "expected_position": CodeRange((1, 0), (2, 12)),
            },
            # Try/except with a class
            {
                "node": cst.Try(
                    cst.SimpleStatementSuite((cst.Pass(),)),
                    handlers=(
                        cst.ExceptHandler(
                            cst.SimpleStatementSuite((cst.Pass(),)),
                            type=cst.Name("Exception"),
                        ),
                    ),
                ),
                "code": "try: pass\nexcept Exception: pass\n",
                "parser": parse_statement,
                # "except Exception: pass" is 22 columns wide.
                "expected_position": CodeRange((1, 0), (2, 22)),
            },
            # Try/except with a named class
            {
                "node": cst.Try(
                    cst.SimpleStatementSuite((cst.Pass(),)),
                    handlers=(
                        cst.ExceptHandler(
                            cst.SimpleStatementSuite((cst.Pass(),)),
                            type=cst.Name("Exception"),
                            name=cst.AsName(cst.Name("exc")),
                        ),
                    ),
                ),
                "code": "try: pass\nexcept Exception as exc: pass\n",
                "parser": parse_statement,
                "expected_position": CodeRange((1, 0), (2, 29)),
            },
            # Try/except with multiple clauses
            {
                "node": cst.Try(
                    cst.SimpleStatementSuite((cst.Pass(),)),
                    handlers=(
                        cst.ExceptHandler(
                            cst.SimpleStatementSuite((cst.Pass(),)),
                            type=cst.Name("TypeError"),
                            name=cst.AsName(cst.Name("e")),
                        ),
                        cst.ExceptHandler(
                            cst.SimpleStatementSuite((cst.Pass(),)),
                            type=cst.Name("KeyError"),
                            name=cst.AsName(cst.Name("e")),
                        ),
                        cst.ExceptHandler(
                            cst.SimpleStatementSuite((cst.Pass(),)),
                            whitespace_after_except=cst.SimpleWhitespace(""),
                        ),
                    ),
                ),
                "code": "try: pass\n"
                + "except TypeError as e: pass\n"
                + "except KeyError as e: pass\n"
                + "except: pass\n",
                "parser": parse_statement,
                "expected_position": CodeRange((1, 0), (4, 12)),
            },
            # Simple try/finally block
            {
                "node": cst.Try(
                    cst.SimpleStatementSuite((cst.Pass(),)),
                    finalbody=cst.Finally(cst.SimpleStatementSuite((cst.Pass(),))),
                ),
                "code": "try: pass\nfinally: pass\n",
                "parser": parse_statement,
                "expected_position": CodeRange((1, 0), (2, 13)),
            },
            # Simple try/except/finally block
            {
                "node": cst.Try(
                    cst.SimpleStatementSuite((cst.Pass(),)),
                    handlers=(
                        cst.ExceptHandler(
                            cst.SimpleStatementSuite((cst.Pass(),)),
                            whitespace_after_except=cst.SimpleWhitespace(""),
                        ),
                    ),
                    finalbody=cst.Finally(cst.SimpleStatementSuite((cst.Pass(),))),
                ),
                "code": "try: pass\nexcept: pass\nfinally: pass\n",
                "parser": parse_statement,
                "expected_position": CodeRange((1, 0), (3, 13)),
            },
            # Simple try/except/else block
            {
                "node": cst.Try(
                    cst.SimpleStatementSuite((cst.Pass(),)),
                    handlers=(
                        cst.ExceptHandler(
                            cst.SimpleStatementSuite((cst.Pass(),)),
                            whitespace_after_except=cst.SimpleWhitespace(""),
                        ),
                    ),
                    orelse=cst.Else(cst.SimpleStatementSuite((cst.Pass(),))),
                ),
                "code": "try: pass\nexcept: pass\nelse: pass\n",
                "parser": parse_statement,
                "expected_position": CodeRange((1, 0), (3, 10)),
            },
            # Simple try/except/else/finally block
            {
                "node": cst.Try(
                    cst.SimpleStatementSuite((cst.Pass(),)),
                    handlers=(
                        cst.ExceptHandler(
                            cst.SimpleStatementSuite((cst.Pass(),)),
                            whitespace_after_except=cst.SimpleWhitespace(""),
                        ),
                    ),
                    orelse=cst.Else(cst.SimpleStatementSuite((cst.Pass(),))),
                    finalbody=cst.Finally(cst.SimpleStatementSuite((cst.Pass(),))),
                ),
                "code": "try: pass\nexcept: pass\nelse: pass\nfinally: pass\n",
                "parser": parse_statement,
                "expected_position": CodeRange((1, 0), (4, 13)),
            },
            # Verify whitespace in various locations
            {
                "node": cst.Try(
                    leading_lines=(cst.EmptyLine(comment=cst.Comment("# 1")),),
                    body=cst.SimpleStatementSuite((cst.Pass(),)),
                    handlers=(
                        cst.ExceptHandler(
                            leading_lines=(cst.EmptyLine(comment=cst.Comment("# 2")),),
                            type=cst.Name("TypeError"),
                            name=cst.AsName(
                                cst.Name("e"),
                                whitespace_before_as=cst.SimpleWhitespace("  "),
                                whitespace_after_as=cst.SimpleWhitespace("  "),
                            ),
                            whitespace_after_except=cst.SimpleWhitespace("  "),
                            whitespace_before_colon=cst.SimpleWhitespace(" "),
                            body=cst.SimpleStatementSuite((cst.Pass(),)),
                        ),
                    ),
                    orelse=cst.Else(
                        leading_lines=(cst.EmptyLine(comment=cst.Comment("# 3")),),
                        body=cst.SimpleStatementSuite((cst.Pass(),)),
                        whitespace_before_colon=cst.SimpleWhitespace(" "),
                    ),
                    finalbody=cst.Finally(
                        leading_lines=(cst.EmptyLine(comment=cst.Comment("# 4")),),
                        body=cst.SimpleStatementSuite((cst.Pass(),)),
                        whitespace_before_colon=cst.SimpleWhitespace(" "),
                    ),
                    whitespace_before_colon=cst.SimpleWhitespace(" "),
                ),
                "code": "# 1\ntry : pass\n# 2\nexcept  TypeError  as  e : pass\n# 3\nelse : pass\n# 4\nfinally : pass\n",
                "parser": parse_statement,
                # Leading comment lines are not part of the node's position.
                "expected_position": CodeRange((2, 0), (8, 14)),
            },
            # Every clause kind at once (please don't write code like this)
            {
                "node": cst.Try(
                    cst.SimpleStatementSuite((cst.Pass(),)),
                    handlers=(
                        cst.ExceptHandler(
                            cst.SimpleStatementSuite((cst.Pass(),)),
                            type=cst.Name("TypeError"),
                            name=cst.AsName(cst.Name("e")),
                        ),
                        cst.ExceptHandler(
                            cst.SimpleStatementSuite((cst.Pass(),)),
                            type=cst.Name("KeyError"),
                            name=cst.AsName(cst.Name("e")),
                        ),
                        cst.ExceptHandler(
                            cst.SimpleStatementSuite((cst.Pass(),)),
                            whitespace_after_except=cst.SimpleWhitespace(""),
                        ),
                    ),
                    orelse=cst.Else(cst.SimpleStatementSuite((cst.Pass(),))),
                    finalbody=cst.Finally(cst.SimpleStatementSuite((cst.Pass(),))),
                ),
                "code": "try: pass\n"
                + "except TypeError as e: pass\n"
                + "except KeyError as e: pass\n"
                + "except: pass\n"
                + "else: pass\n"
                + "finally: pass\n",
                "parser": parse_statement,
                "expected_position": CodeRange((1, 0), (6, 13)),
            },
            # Verify indentation
            {
                "node": DummyIndentedBlock(
                    "    ",
                    cst.Try(
                        cst.SimpleStatementSuite((cst.Pass(),)),
                        handlers=(
                            cst.ExceptHandler(
                                cst.SimpleStatementSuite((cst.Pass(),)),
                                type=cst.Name("TypeError"),
                                name=cst.AsName(cst.Name("e")),
                            ),
                            cst.ExceptHandler(
                                cst.SimpleStatementSuite((cst.Pass(),)),
                                type=cst.Name("KeyError"),
                                name=cst.AsName(cst.Name("e")),
                            ),
                            cst.ExceptHandler(
                                cst.SimpleStatementSuite((cst.Pass(),)),
                                whitespace_after_except=cst.SimpleWhitespace(""),
                            ),
                        ),
                        orelse=cst.Else(cst.SimpleStatementSuite((cst.Pass(),))),
                        finalbody=cst.Finally(cst.SimpleStatementSuite((cst.Pass(),))),
                    ),
                ),
                "code": "    try: pass\n"
                + "    except TypeError as e: pass\n"
                + "    except KeyError as e: pass\n"
                + "    except: pass\n"
                + "    else: pass\n"
                + "    finally: pass\n",
                "parser": None,
            },
            # Verify indentation in bodies
            {
                "node": DummyIndentedBlock(
                    "    ",
                    cst.Try(
                        cst.IndentedBlock((cst.SimpleStatementLine((cst.Pass(),)),)),
                        handlers=(
                            cst.ExceptHandler(
                                cst.IndentedBlock(
                                    (cst.SimpleStatementLine((cst.Pass(),)),)
                                ),
                                whitespace_after_except=cst.SimpleWhitespace(""),
                            ),
                        ),
                        orelse=cst.Else(
                            cst.IndentedBlock((cst.SimpleStatementLine((cst.Pass(),)),))
                        ),
                        finalbody=cst.Finally(
                            cst.IndentedBlock((cst.SimpleStatementLine((cst.Pass(),)),))
                        ),
                    ),
                ),
                "code": "    try:\n"
                + "        pass\n"
                + "    except:\n"
                + "        pass\n"
                + "    else:\n"
                + "        pass\n"
                + "    finally:\n"
                + "        pass\n",
                "parser": None,
            },
            # No space when using grouping parens
            {
                "node": cst.Try(
                    cst.SimpleStatementSuite((cst.Pass(),)),
                    handlers=(
                        cst.ExceptHandler(
                            cst.SimpleStatementSuite((cst.Pass(),)),
                            whitespace_after_except=cst.SimpleWhitespace(""),
                            type=cst.Name(
                                "Exception",
                                lpar=(cst.LeftParen(),),
                                rpar=(cst.RightParen(),),
                            ),
                        ),
                    ),
                ),
                "code": "try: pass\nexcept(Exception): pass\n",
                "parser": parse_statement,
                # "except(Exception): pass" is 23 columns wide.
                "expected_position": CodeRange((1, 0), (2, 23)),
            },
            # No space when using tuple
            {
                "node": cst.Try(
                    cst.SimpleStatementSuite((cst.Pass(),)),
                    handlers=(
                        cst.ExceptHandler(
                            cst.SimpleStatementSuite((cst.Pass(),)),
                            whitespace_after_except=cst.SimpleWhitespace(""),
                            type=cst.Tuple(
                                [
                                    cst.Element(
                                        cst.Name("IOError"),
                                        comma=cst.Comma(
                                            whitespace_after=cst.SimpleWhitespace(" ")
                                        ),
                                    ),
                                    cst.Element(cst.Name("ImportError")),
                                ]
                            ),
                        ),
                    ),
                ),
                "code": "try: pass\nexcept(IOError, ImportError): pass\n",
                "parser": parse_statement,
                # "except(IOError, ImportError): pass" is 34 columns wide.
                "expected_position": CodeRange((1, 0), (2, 34)),
            },
        )
    )
    def test_valid(self, **kwargs: Any) -> None:
        self.validate_node(**kwargs)

    @data_provider(
        (
            {
                "get_node": lambda: cst.AsName(cst.Name("")),
                "expected_re": "empty name identifier",
            },
            {
                "get_node": lambda: cst.AsName(
                    cst.Name("bla"), whitespace_after_as=cst.SimpleWhitespace("")
                ),
                "expected_re": "between 'as'",
            },
            {
                "get_node": lambda: cst.AsName(
                    cst.Name("bla"), whitespace_before_as=cst.SimpleWhitespace("")
                ),
                "expected_re": "before 'as'",
            },
            {
                "get_node": lambda: cst.ExceptHandler(
                    cst.SimpleStatementSuite((cst.Pass(),)),
                    name=cst.AsName(cst.Name("bla")),
                ),
                "expected_re": "name for an empty type",
            },
            {
                "get_node": lambda: cst.ExceptHandler(
                    cst.SimpleStatementSuite((cst.Pass(),)),
                    type=cst.Name("TypeError"),
                    whitespace_after_except=cst.SimpleWhitespace(""),
                ),
                "expected_re": "at least one space after except",
            },
            {
                "get_node": lambda: cst.Try(cst.SimpleStatementSuite((cst.Pass(),))),
                "expected_re": "at least one ExceptHandler or Finally",
            },
            {
                "get_node": lambda: cst.Try(
                    cst.SimpleStatementSuite((cst.Pass(),)),
                    orelse=cst.Else(cst.SimpleStatementSuite((cst.Pass(),))),
                    finalbody=cst.Finally(cst.SimpleStatementSuite((cst.Pass(),))),
                ),
                "expected_re": "at least one ExceptHandler in order to have an Else",
            },
        )
    )
    def test_invalid(self, **kwargs: Any) -> None:
        self.assert_invalid(**kwargs)
예제 #24
0
    def leave_Call(  # noqa: C901
        self, original_node: cst.Call, updated_node: cst.Call
    ) -> cst.BaseExpression:
        """Rewrite a ``"...".format(...)`` call into an equivalent f-string.

        Only calls whose receiver is a plain string literal (``SimpleString``)
        and whose attribute is ``format`` are considered. Whenever a faithful
        f-string cannot be produced, the method emits a warning via
        ``self.warn`` and returns ``updated_node`` unchanged. That happens for
        a bytes literal, a non-empty format spec, an unresolvable field name,
        or an argument containing a comment, an f-string, an ``await``, a
        backslash, or a string using the same quote.

        Returns:
            A ``cst.FormattedString`` on success, otherwise ``updated_node``.
        """
        # Lets figure out if this is a "".format() call
        extraction = self.extract(
            updated_node,
            m.Call(
                func=m.Attribute(
                    value=m.SaveMatchedNode(m.SimpleString(), "string"),
                    attr=m.Name("format"),
                )
            ),
        )
        if extraction is None:
            return updated_node

        stringnode = cst.ensure_type(extraction["string"], cst.SimpleString)

        # f-strings can only carry the "f"/"r" prefixes. A bytes literal has no
        # f-string equivalent ("fb" is invalid), so leave it alone. "u" is a
        # no-op in Python 3 but "fu" is invalid, so it is simply dropped.
        prefix = stringnode.prefix
        if "b" in prefix.lower():
            self.warn("Unsupported bytes string in format() call")
            return updated_node
        prefix = "".join(c for c in prefix if c not in "uU")

        fstring: List[cst.BaseFormattedStringContent] = []
        inserted_sequence: int = 0
        tokens = _get_tokens(stringnode.raw_value)
        for literal_text, field_name, format_spec, conversion in tokens:
            if literal_text:
                fstring.append(cst.FormattedStringText(literal_text))
            if field_name is None:
                # This is not a format-specification
                continue
            if format_spec:
                # TODO: This is supportable since format specs are compatible
                # with f-string format specs, but it would require matching
                # format specifier expansions.
                self.warn(f"Unsupported format_spec {format_spec} in format() call")
                return updated_node

            # Auto-insert field sequence if it is empty
            if field_name == "":
                field_name = str(inserted_sequence)
                inserted_sequence += 1
            expr = _find_expr_from_field_name(field_name, updated_node.args)
            if expr is None:
                # Most likely they used * expansion in a format.
                self.warn(f"Unsupported field_name {field_name} in format() call")
                return updated_node

            # Verify that we don't have any comments or newlines. Comments aren't
            # allowed in f-strings, and newlines need parenthesization. We can
            # have formattedstrings inside other formattedstrings, but I chose not
            # to deal with that for now.
            if self.findall(expr, m.Comment()):
                # We could strip comments, but this is a formatting change so
                # we choose not to for now.
                self.warn("Unsupported comment in format() call")
                return updated_node
            if self.findall(expr, m.FormattedString()):
                self.warn("Unsupported f-string in format() call")
                return updated_node
            if self.findall(expr, m.Await()):
                # This is fixed in 3.7 but we don't currently have a flag
                # to enable/disable it.
                self.warn("Unsupported await in format() call")
                return updated_node

            # Stripping newlines is effectively a format-only change.
            expr = cst.ensure_type(
                expr.visit(StripNewlinesTransformer(self.context)),
                cst.BaseExpression,
            )

            # Try our best to swap quotes on any strings that won't fit
            expr = cst.ensure_type(
                expr.visit(
                    SwitchStringQuotesTransformer(self.context, stringnode.quote[0])
                ),
                cst.BaseExpression,
            )

            # Verify that the resulting expression doesn't have a backslash
            # in it.
            raw_expr_string = self.module.code_for_node(expr)
            if "\\" in raw_expr_string:
                self.warn("Unsupported backslash in format expression")
                return updated_node

            # For safety sake, if this is a dict/set or dict/set comprehension,
            # wrap it in parens so that it doesn't accidentally create an
            # escape.
            if (
                raw_expr_string.startswith("{") or raw_expr_string.endswith("}")
            ) and (not expr.lpar or not expr.rpar):
                expr = expr.with_changes(lpar=[cst.LeftParen()], rpar=[cst.RightParen()])

            # Verify that any strings we insert don't have the same quote
            quote_gatherer = StringQuoteGatherer(self.context)
            expr.visit(quote_gatherer)
            for stringend in quote_gatherer.stringends:
                if stringend in stringnode.quote:
                    self.warn(
                        "Cannot embed string with same quote from format() call"
                    )
                    return updated_node

            fstring.append(
                cst.FormattedStringExpression(expression=expr, conversion=conversion)
            )
        return cst.FormattedString(
            parts=fstring,
            start=f"f{prefix}{stringnode.quote}",
            end=stringnode.quote,
        )
예제 #25
0
class TupleTest(CSTNodeTest):
    """Rendering, parsing, position and validation checks for ``cst.Tuple``."""

    @data_provider(
        (
            # The empty tuple is nothing but its parentheses.
            {"node": cst.Tuple([]), "code": "()", "parser": parse_expression},
            # With a single Element, the default comma still renders so the
            # result remains a tuple rather than a parenthesized name.
            {
                "node": cst.Tuple([cst.Element(cst.Name("single_element"))]),
                "code": "(single_element,)",
                "parser": None,
            },
            # The same holds for a single StarredElement.
            {
                "node": cst.Tuple([cst.StarredElement(cst.Name("single_element"))]),
                "code": "(*single_element,)",
                "parser": None,
            },
            # Two Elements with default commas: a separator, no trailing comma.
            {
                "node": cst.Tuple(
                    [cst.Element(cst.Name("one")), cst.Element(cst.Name("two"))]
                ),
                "code": "(one, two)",
                "parser": None,
            },
            # Empty lpar/rpar drop the surrounding parentheses entirely.
            {
                "node": cst.Tuple(
                    [cst.Element(cst.Name("one")), cst.Element(cst.Name("two"))],
                    lpar=[],
                    rpar=[],
                ),
                "code": "one, two",
                "parser": None,
            },
            # Doubled lpar/rpar nest the parentheses.
            {
                "node": cst.Tuple(
                    [cst.Element(cst.Name("one")), cst.Element(cst.Name("two"))],
                    lpar=[cst.LeftParen(), cst.LeftParen()],
                    rpar=[cst.RightParen(), cst.RightParen()],
                ),
                "code": "((one, two))",
                "parser": None,
            },
            # Every member is a StarredElement.
            {
                "node": cst.Tuple(
                    [
                        cst.StarredElement(cst.Name("one")),
                        cst.StarredElement(cst.Name("two")),
                    ]
                ),
                "code": "(*one, *two)",
                "parser": None,
            },
            # Explicit whitespace-free commas on plain Elements.
            {
                "node": cst.Tuple(
                    [
                        cst.Element(cst.Name("one"), comma=cst.Comma()),
                        cst.Element(cst.Name("two"), comma=cst.Comma()),
                    ]
                ),
                "code": "(one,two,)",
                "parser": parse_expression,
            },
            # Explicit whitespace-free commas on StarredElements. The position
            # excludes the tuple's own parentheses.
            {
                "node": cst.Tuple(
                    [
                        cst.StarredElement(cst.Name("one"), comma=cst.Comma()),
                        cst.StarredElement(cst.Name("two"), comma=cst.Comma()),
                    ]
                ),
                "code": "(*one,*two,)",
                "parser": parse_expression,
                "expected_position": CodeRange((1, 1), (1, 11)),
            },
            # A StarredElement that carries its own parentheses.
            {
                "node": cst.Tuple(
                    [
                        cst.StarredElement(
                            cst.Name("abc"),
                            lpar=[cst.LeftParen()],
                            rpar=[cst.RightParen()],
                            comma=cst.Comma(),
                        )
                    ]
                ),
                "code": "((*abc),)",
                "parser": parse_expression,
                "expected_position": CodeRange((1, 1), (1, 8)),
            },
            # Extra whitespace between `*` and the value of a StarredElement.
            {
                "node": cst.Tuple(
                    [
                        cst.Element(cst.Name("one"), comma=cst.Comma()),
                        cst.StarredElement(
                            cst.Name("two"),
                            whitespace_before_value=cst.SimpleWhitespace("  "),
                            lpar=[cst.LeftParen()],
                            rpar=[cst.RightParen()],
                        ),
                    ],
                    lpar=[],
                    rpar=[],  # rpar can't own the trailing whitespace if it's not there
                ),
                "code": "one,(*  two)",
                "parser": parse_expression,
                "expected_position": CodeRange((1, 0), (1, 12)),
            },
            # The tuple's parentheses make missing spaces around it legal.
            {
                "node": cst.For(
                    target=cst.Tuple(
                        [
                            cst.Element(cst.Name("k"), comma=cst.Comma()),
                            cst.Element(cst.Name("v")),
                        ]
                    ),
                    iter=cst.Name("abc"),
                    body=cst.SimpleStatementSuite([cst.Pass()]),
                    whitespace_after_for=cst.SimpleWhitespace(""),
                    whitespace_before_in=cst.SimpleWhitespace(""),
                ),
                "code": "for(k,v)in abc: pass\n",
                "parser": parse_statement,
            },
            # An unparenthesized tuple may also omit the spaces when its first
            # and last members are themselves parenthesized.
            {
                "node": cst.For(
                    target=cst.Tuple(
                        [
                            cst.Element(
                                cst.Name(
                                    "k", lpar=[cst.LeftParen()], rpar=[cst.RightParen()]
                                ),
                                comma=cst.Comma(),
                            ),
                            cst.Element(
                                cst.Name(
                                    "v", lpar=[cst.LeftParen()], rpar=[cst.RightParen()]
                                )
                            ),
                        ],
                        lpar=[],
                        rpar=[],
                    ),
                    iter=cst.Name("abc"),
                    body=cst.SimpleStatementSuite([cst.Pass()]),
                    whitespace_after_for=cst.SimpleWhitespace(""),
                    whitespace_before_in=cst.SimpleWhitespace(""),
                ),
                "code": "for(k),(v)in abc: pass\n",
                "parser": parse_statement,
            },
            # A leading `*` cannot fuse with the keyword, so no space is needed.
            {
                "node": cst.For(
                    target=cst.Tuple(
                        [cst.StarredElement(cst.Name("foo"), comma=cst.Comma())],
                        lpar=[],
                        rpar=[],
                    ),
                    iter=cst.Name("bar"),
                    body=cst.SimpleStatementSuite([cst.Pass()]),
                    whitespace_after_for=cst.SimpleWhitespace(""),
                ),
                "code": "for*foo, in bar: pass\n",
                "parser": parse_statement,
            },
            # A trailing comma is compatible with a statement's TrailingWhitespace.
            {
                "node": cst.SimpleStatementLine(
                    [
                        cst.Expr(
                            cst.Tuple(
                                [
                                    cst.Element(cst.Name("one"), comma=cst.Comma()),
                                    cst.Element(cst.Name("two"), comma=cst.Comma()),
                                ],
                                lpar=[],
                                rpar=[],
                            )
                        )
                    ],
                    trailing_whitespace=cst.TrailingWhitespace(
                        whitespace=cst.SimpleWhitespace("  "),
                        comment=cst.Comment("# comment"),
                    ),
                ),
                "code": "one,two,  # comment\n",
                "parser": parse_statement,
            },
        )
    )
    def test_valid(self, **kwargs: Any) -> None:
        self.validate_node(**kwargs)

    @data_provider(
        (
            {
                "get_node": lambda: cst.Tuple([], lpar=[], rpar=[]),
                "expected_re": "A zero-length tuple must be wrapped in parentheses.",
            },
            {
                "get_node": lambda: cst.Tuple(
                    [cst.Element(cst.Name("mismatched"))],
                    lpar=[cst.LeftParen(), cst.LeftParen()],
                    rpar=[cst.RightParen()],
                ),
                "expected_re": "unbalanced parens",
            },
            {
                "get_node": lambda: cst.For(
                    target=cst.Tuple([cst.Element(cst.Name("el"))], lpar=[], rpar=[]),
                    iter=cst.Name("it"),
                    body=cst.SimpleStatementSuite([cst.Pass()]),
                    whitespace_after_for=cst.SimpleWhitespace(""),
                ),
                "expected_re": "Must have at least one space after 'for' keyword.",
            },
            {
                "get_node": lambda: cst.For(
                    target=cst.Tuple([cst.Element(cst.Name("el"))], lpar=[], rpar=[]),
                    iter=cst.Name("it"),
                    body=cst.SimpleStatementSuite([cst.Pass()]),
                    whitespace_before_in=cst.SimpleWhitespace(""),
                ),
                "expected_re": "Must have at least one space before 'in' keyword.",
            },
            # StarredElement goes through its own validation path, so check it too.
            {
                "get_node": lambda: cst.For(
                    target=cst.Tuple(
                        [cst.StarredElement(cst.Name("el"))], lpar=[], rpar=[]
                    ),
                    iter=cst.Name("it"),
                    body=cst.SimpleStatementSuite([cst.Pass()]),
                    whitespace_before_in=cst.SimpleWhitespace(""),
                ),
                "expected_re": "Must have at least one space before 'in' keyword.",
            },
        )
    )
    def test_invalid(
        self, get_node: Callable[[], cst.CSTNode], expected_re: str
    ) -> None:
        # Build the malformed node lazily and require a validation error whose
        # message matches `expected_re`.
        self.assert_invalid(get_node=get_node, expected_re=expected_re)
예제 #26
0
class IfExpTest(CSTNodeTest):
    """Tests for ``cst.IfExp`` (``body if test else orelse``)."""

    @data_provider((
        # Simple if expressions
        (
            cst.IfExp(body=cst.Name("foo"),
                      test=cst.Name("bar"),
                      orelse=cst.Name("baz")),
            "foo if bar else baz",
        ),
        # Parenthesized if expressions
        (
            cst.IfExp(
                lpar=(cst.LeftParen(), ),
                body=cst.Name("foo"),
                test=cst.Name("bar"),
                orelse=cst.Name("baz"),
                rpar=(cst.RightParen(), ),
            ),
            "(foo if bar else baz)",
        ),
        # Parenthesized operands with no whitespace around 'if' / 'else'
        (
            cst.IfExp(
                body=cst.Name("foo",
                              lpar=(cst.LeftParen(), ),
                              rpar=(cst.RightParen(), )),
                whitespace_before_if=cst.SimpleWhitespace(""),
                whitespace_after_if=cst.SimpleWhitespace(""),
                test=cst.Name("bar",
                              lpar=(cst.LeftParen(), ),
                              rpar=(cst.RightParen(), )),
                whitespace_before_else=cst.SimpleWhitespace(""),
                whitespace_after_else=cst.SimpleWhitespace(""),
                orelse=cst.Name("baz",
                                lpar=(cst.LeftParen(), ),
                                rpar=(cst.RightParen(), )),
            ),
            "(foo)if(bar)else(baz)",
            CodeRange((1, 0), (1, 21)),
        ),
        # Make sure that spacing works
        (
            cst.IfExp(
                lpar=(cst.LeftParen(
                    whitespace_after=cst.SimpleWhitespace(" ")), ),
                body=cst.Name("foo"),
                whitespace_before_if=cst.SimpleWhitespace("  "),
                whitespace_after_if=cst.SimpleWhitespace("  "),
                test=cst.Name("bar"),
                whitespace_before_else=cst.SimpleWhitespace("  "),
                whitespace_after_else=cst.SimpleWhitespace("  "),
                orelse=cst.Name("baz"),
                rpar=(cst.RightParen(
                    whitespace_before=cst.SimpleWhitespace(" ")), ),
            ),
            "( foo  if  bar  else  baz )",
            # the node's own parentheses and inner padding are excluded
            CodeRange((1, 2), (1, 25)),
        ),
    ))
    def test_valid(self,
                   node: cst.CSTNode,
                   code: str,
                   position: Optional[CodeRange] = None) -> None:
        """Validate ``node`` against ``code`` with ``validate_node``, using
        ``parse_expression`` as the parser and, when given, ``position`` as
        the expected source range."""
        self.validate_node(node,
                           code,
                           parse_expression,
                           expected_position=position)

    @data_provider((
        # Unbalanced parentheses around the whole expression
        (
            lambda: cst.IfExp(
                cst.Name("bar"),
                cst.Name("foo"),
                cst.Name("baz"),
                lpar=(cst.LeftParen(), ),
            ),
            "left paren without right paren",
        ),
        (
            lambda: cst.IfExp(
                cst.Name("bar"),
                cst.Name("foo"),
                cst.Name("baz"),
                rpar=(cst.RightParen(), ),
            ),
            "right paren without left paren",
        ),
    ))
    def test_invalid(self, get_node: Callable[[], cst.CSTNode],
                     expected_re: str) -> None:
        """Delegate to ``assert_invalid`` with the node factory ``get_node``
        and the error pattern ``expected_re``."""
        self.assert_invalid(get_node, expected_re)
예제 #27
0
class NumberTest(CSTNodeTest):
    """Tests for numeric literal nodes (``Integer``, ``Float``,
    ``Imaginary``), on their own and inside unary operations / parens."""

    @data_provider(
        (
            # Simple numbers
            (cst.Integer("5"), "5", parse_expression),
            (cst.Float("5.5"), "5.5", parse_expression),
            (cst.Imaginary("5j"), "5j", parse_expression),
            # Negated number
            (
                cst.UnaryOperation(operator=cst.Minus(), expression=cst.Integer("5")),
                "-5",
                parse_expression,
                CodeRange((1, 0), (1, 2)),
            ),
            # In parentheses; the expected position excludes the node's own
            # parentheses
            (
                cst.UnaryOperation(
                    lpar=(cst.LeftParen(),),
                    operator=cst.Minus(),
                    expression=cst.Integer("5"),
                    rpar=(cst.RightParen(),),
                ),
                "(-5)",
                parse_expression,
                CodeRange((1, 1), (1, 3)),
            ),
            (
                cst.UnaryOperation(
                    lpar=(cst.LeftParen(),),
                    operator=cst.Minus(),
                    expression=cst.Integer(
                        "5", lpar=(cst.LeftParen(),), rpar=(cst.RightParen(),)
                    ),
                    rpar=(cst.RightParen(),),
                ),
                "(-(5))",
                parse_expression,
                CodeRange((1, 1), (1, 5)),
            ),
            # Nested unary operations
            (
                cst.UnaryOperation(
                    operator=cst.Minus(),
                    expression=cst.UnaryOperation(
                        operator=cst.Minus(), expression=cst.Integer("5")
                    ),
                ),
                "--5",
                parse_expression,
                CodeRange((1, 0), (1, 3)),
            ),
            # Multiple nested parentheses
            (
                cst.Integer(
                    "5",
                    lpar=(cst.LeftParen(), cst.LeftParen()),
                    rpar=(cst.RightParen(), cst.RightParen()),
                ),
                "((5))",
                parse_expression,
                CodeRange((1, 2), (1, 3)),
            ),
            (
                cst.UnaryOperation(
                    lpar=(cst.LeftParen(),),
                    operator=cst.Plus(),
                    expression=cst.Integer(
                        "5",
                        lpar=(cst.LeftParen(), cst.LeftParen()),
                        rpar=(cst.RightParen(), cst.RightParen()),
                    ),
                    rpar=(cst.RightParen(),),
                ),
                "(+((5)))",
                parse_expression,
                CodeRange((1, 1), (1, 7)),
            ),
        )
    )
    def test_valid(
        self,
        node: cst.CSTNode,
        code: str,
        parser: Optional[Callable[[str], cst.CSTNode]],
        position: Optional[CodeRange] = None,
    ) -> None:
        """Validate ``node`` against ``code`` with ``validate_node`` using
        ``parser`` and, when given, the expected source range ``position``."""
        self.validate_node(node, code, parser, expected_position=position)

    @data_provider(
        (
            (
                lambda: cst.Integer("5", lpar=(cst.LeftParen(),)),
                "left paren without right paren",
            ),
            (
                lambda: cst.Integer("5", rpar=(cst.RightParen(),)),
                "right paren without left paren",
            ),
            (
                lambda: cst.Float("5.5", lpar=(cst.LeftParen(),)),
                "left paren without right paren",
            ),
            (
                lambda: cst.Float("5.5", rpar=(cst.RightParen(),)),
                "right paren without left paren",
            ),
            # "5j" is a valid imaginary literal ("5i" is not), so the
            # unbalanced parentheses are the only thing wrong with these nodes.
            (
                lambda: cst.Imaginary("5j", lpar=(cst.LeftParen(),)),
                "left paren without right paren",
            ),
            (
                lambda: cst.Imaginary("5j", rpar=(cst.RightParen(),)),
                "right paren without left paren",
            ),
        )
    )
    def test_invalid(
        self, get_node: Callable[[], cst.CSTNode], expected_re: str
    ) -> None:
        """Delegate to ``assert_invalid`` with the node factory ``get_node``
        and the error pattern ``expected_re``."""
        self.assert_invalid(get_node, expected_re)
예제 #28
0
class NamedExprTest(CSTNodeTest):
    """Tests for ``cst.NamedExpr`` (the ``:=`` assignment expression)."""

    @data_provider(
        (
            # Simple named expression
            {
                "node": cst.NamedExpr(cst.Name("x"), cst.Float("5.5")),
                "code": "x := 5.5",
                # Walrus operator is illegal as top-level statement
                "parser": None,
                "expected_position": None,
            },
            # Parenthesized named expression
            {
                "node": cst.NamedExpr(
                    lpar=(cst.LeftParen(),),
                    target=cst.Name("foo"),
                    value=cst.Integer("5"),
                    rpar=(cst.RightParen(),),
                ),
                "code": "(foo := 5)",
                "parser": _parse_expression_force_38,
                "expected_position": CodeRange((1, 1), (1, 9)),
            },
            # Make sure that spacing works
            {
                "node": cst.NamedExpr(
                    lpar=(
                        cst.LeftParen(whitespace_after=cst.SimpleWhitespace(" ")),
                    ),
                    target=cst.Name("foo"),
                    whitespace_before_walrus=cst.SimpleWhitespace("  "),
                    whitespace_after_walrus=cst.SimpleWhitespace("  "),
                    value=cst.Name("bar"),
                    rpar=(
                        cst.RightParen(whitespace_before=cst.SimpleWhitespace(" ")),
                    ),
                ),
                "code": "( foo  :=  bar )",
                "parser": _parse_expression_force_38,
                "expected_position": CodeRange((1, 2), (1, 14)),
            },
            # Make sure we can use these where allowed in if/while statements
            {
                "node": cst.While(
                    test=cst.NamedExpr(
                        target=cst.Name(value="x"),
                        value=cst.Call(func=cst.Name(value="some_input")),
                    ),
                    body=cst.SimpleStatementSuite(body=[cst.Pass()]),
                ),
                "code": "while x := some_input(): pass\n",
                "parser": _parse_statement_force_38,
                "expected_position": None,
            },
            {
                "node": cst.If(
                    test=cst.NamedExpr(
                        target=cst.Name(value="x"),
                        value=cst.Call(func=cst.Name(value="some_input")),
                    ),
                    body=cst.SimpleStatementSuite(body=[cst.Pass()]),
                ),
                "code": "if x := some_input(): pass\n",
                "parser": _parse_statement_force_38,
                "expected_position": None,
            },
            {
                "node": cst.If(
                    test=cst.NamedExpr(
                        target=cst.Name(value="x"),
                        value=cst.Integer(value="1"),
                        whitespace_before_walrus=cst.SimpleWhitespace(""),
                        whitespace_after_walrus=cst.SimpleWhitespace(""),
                    ),
                    body=cst.SimpleStatementSuite(body=[cst.Pass()]),
                ),
                "code": "if x:=1: pass\n",
                "parser": _parse_statement_force_38,
                "expected_position": None,
            },
            # Function args
            {
                "node": cst.Call(
                    func=cst.Name(value="f"),
                    args=[
                        cst.Arg(
                            value=cst.NamedExpr(
                                target=cst.Name(value="y"),
                                value=cst.Integer(value="1"),
                                whitespace_before_walrus=cst.SimpleWhitespace(""),
                                whitespace_after_walrus=cst.SimpleWhitespace(""),
                            )
                        ),
                    ],
                ),
                "code": "f(y:=1)",
                "parser": _parse_expression_force_38,
                "expected_position": None,
            },
            # Whitespace handling on args is fragile
            {
                "node": cst.Call(
                    func=cst.Name(value="f"),
                    args=[
                        cst.Arg(
                            value=cst.Name(value="x"),
                            comma=cst.Comma(
                                whitespace_after=cst.SimpleWhitespace("  ")
                            ),
                        ),
                        cst.Arg(
                            value=cst.NamedExpr(
                                target=cst.Name(value="y"),
                                value=cst.Integer(value="1"),
                                whitespace_before_walrus=cst.SimpleWhitespace("   "),
                                whitespace_after_walrus=cst.SimpleWhitespace("    "),
                            ),
                            whitespace_after_arg=cst.SimpleWhitespace("     "),
                        ),
                    ],
                ),
                "code": "f(x,  y   :=    1     )",
                "parser": _parse_expression_force_38,
                "expected_position": None,
            },
            {
                "node": cst.Call(
                    func=cst.Name(value="f"),
                    args=[
                        cst.Arg(
                            value=cst.NamedExpr(
                                target=cst.Name(value="y"),
                                value=cst.Integer(value="1"),
                                whitespace_before_walrus=cst.SimpleWhitespace("   "),
                                whitespace_after_walrus=cst.SimpleWhitespace("    "),
                            ),
                            whitespace_after_arg=cst.SimpleWhitespace("     "),
                        ),
                    ],
                    whitespace_before_args=cst.SimpleWhitespace("  "),
                ),
                "code": "f(  y   :=    1     )",
                "parser": _parse_expression_force_38,
                "expected_position": None,
            },
        )
    )
    def test_valid(self, **kwargs: Any) -> None:
        """Forward each case's keyword arguments (node, code, parser,
        expected_position) to ``validate_node``."""
        self.validate_node(**kwargs)

    @data_provider(
        (
            {
                "get_node": (
                    lambda: cst.NamedExpr(
                        cst.Name("foo"), cst.Name("bar"), lpar=(cst.LeftParen(),)
                    )
                ),
                "expected_re": "left paren without right paren",
            },
            {
                "get_node": (
                    lambda: cst.NamedExpr(
                        cst.Name("foo"), cst.Name("bar"), rpar=(cst.RightParen(),)
                    )
                ),
                "expected_re": "right paren without left paren",
            },
        )
    )
    def test_invalid(self, **kwargs: Any) -> None:
        """Forward each case's keyword arguments (get_node, expected_re) to
        ``assert_invalid``."""
        self.assert_invalid(**kwargs)
예제 #29
0
def import_to_node_multi(imp: SortableImport,
                         module: cst.Module) -> cst.BaseStatement:
    """Render ``imp`` as a parenthesized ``from ... import (...)`` statement
    with one imported name per line.

    Comments attached to ``imp`` and its items are placed as follows:

    * an item's ``before`` comments become own-line comments after whatever
      precedes the item: the opening paren for the first item, the previous
      item's comma otherwise;
    * an item's ``inline`` comments trail that item's comma;
    * an item's ``following`` comments (plus ``imp.comments.final`` for the
      last item) become own-line comments after that item's comma;
    * ``imp.comments.first_inline`` trails the opening paren and
      ``imp.comments.last_inline`` trails the statement line;
    * ``imp.comments.before`` entries starting with ``#`` become comment lines
      above the statement; any other entry becomes a blank line.

    Every imported name, including the last one, gets a trailing comma.

    Args:
        imp: the import to render.
        module: the module being rendered; its ``default_indent`` indents the
            imported names and the own-line comments.

    Returns:
        A ``cst.SimpleStatementLine`` holding a single ``cst.ImportFrom``.

    Raises:
        ValueError: if ``imp`` has no stem (a basic ``import foo``), which
            cannot be rendered over multiple lines.
    """
    # import foo
    if not imp.stem:
        raise ValueError("can't render basic imports on multiple lines")

    indent_ws = module.default_indent

    def comment_lines(comments: List[str]) -> List[cst.EmptyLine]:
        # Own-line comments, indented to the level of the imported names.
        return [
            cst.EmptyLine(
                indent=True,
                comment=cst.Comment(c),
                whitespace=cst.SimpleWhitespace(indent_ws),
            ) for c in comments
        ]

    def trailing(comments: List[str]) -> cst.TrailingWhitespace:
        # End-of-line inline comment(s); a plain line ending when there are none.
        inline = COMMENT_INDENT.join(comments)
        if not inline:
            return cst.TrailingWhitespace()
        return cst.TrailingWhitespace(
            whitespace=cst.SimpleWhitespace(COMMENT_INDENT),
            comment=cst.Comment(inline),
        )

    items = imp.items
    last_idx = len(items) - 1

    names: List[cst.ImportAlias] = []
    for idx, item in enumerate(items):
        is_last = idx == last_idx
        asname = cst.AsName(
            name=cst.Name(item.asname)) if item.asname else None

        # Own-line comments after this item's comma. Leading comments of the
        # next item have to be trailing comments on this one, so they are
        # collected here rather than patched into this node afterwards.
        if is_last:
            after_lines = comment_lines(item.comments.following +
                                        imp.comments.final)
        else:
            after_lines = (comment_lines(item.comments.following) +
                           comment_lines(items[idx + 1].comments.before))

        after = cst.ParenthesizedWhitespace(
            indent=True,
            first_line=trailing(item.comments.inline),
            empty_lines=after_lines,
            # every item except the last indents the *next* line/item; the
            # last one is followed directly by the closing paren
            last_line=cst.SimpleWhitespace("" if is_last else indent_ws),
        )
        names.append(
            cst.ImportAlias(
                name=name_to_node(item.name),
                asname=asname,
                comma=cst.Comma(whitespace_after=after),
            ))

    # Leading comments of the first item trail the opening paren.
    lpar_lines = comment_lines(items[0].comments.before) if items else []

    stem, ndots = split_relative(imp.stem)
    module_name = name_to_node(stem) if stem else None
    # one Dot node per dot, so no node object appears twice in the tree
    relative = tuple(cst.Dot() for _ in range(ndots))

    # from foo import (
    #     bar
    # )
    import_from = cst.ImportFrom(
        module=module_name,
        names=names,
        relative=relative,
        lpar=cst.LeftParen(
            whitespace_after=cst.ParenthesizedWhitespace(
                indent=True,
                first_line=trailing(imp.comments.first_inline),
                empty_lines=lpar_lines,
                last_line=cst.SimpleWhitespace(indent_ws),
            ), ),
        rpar=cst.RightParen(),
    )

    # comment lines above import; non-comment entries become blank lines
    leading_lines = [
        cst.EmptyLine(indent=True, comment=cst.Comment(line))
        if line.startswith("#") else cst.EmptyLine(indent=False)
        for line in imp.comments.before
    ]

    return cst.SimpleStatementLine(
        body=[import_from],
        leading_lines=leading_lines,
        # inline comments following the closing paren
        trailing_whitespace=trailing(imp.comments.last_inline),
    )
예제 #30
0
class NonlocalConstructionTest(CSTNodeTest):
    """Construction, rendering and validation tests for ``cst.Nonlocal``."""

    @data_provider(
        (
            # Single nonlocal statement
            {
                "node": cst.Nonlocal((cst.NameItem(cst.Name("a")),)),
                "code": "nonlocal a",
            },
            # Multiple entries in nonlocal statement
            {
                "node": cst.Nonlocal(
                    (cst.NameItem(cst.Name("a")), cst.NameItem(cst.Name("b")))
                ),
                "code": "nonlocal a, b",
                "expected_position": CodeRange((1, 0), (1, 13)),
            },
            # Whitespace rendering test
            {
                "node": cst.Nonlocal(
                    (
                        cst.NameItem(
                            cst.Name("a"),
                            comma=cst.Comma(
                                whitespace_before=cst.SimpleWhitespace("  "),
                                whitespace_after=cst.SimpleWhitespace("  "),
                            ),
                        ),
                        cst.NameItem(cst.Name("b")),
                    ),
                    whitespace_after_nonlocal=cst.SimpleWhitespace("  "),
                ),
                "code": "nonlocal  a  ,  b",
                "expected_position": CodeRange((1, 0), (1, 17)),
            },
        )
    )
    def test_valid(self, **kwargs: Any) -> None:
        """Forward each case's keyword arguments (node, code and optionally
        expected_position) to ``validate_node``."""
        self.validate_node(**kwargs)

    @data_provider(
        (
            # Validate construction
            {
                "get_node": lambda: cst.Nonlocal(()),
                "expected_re": "A Nonlocal statement must have at least one NameItem",
            },
            # Validate whitespace handling
            {
                "get_node": lambda: cst.Nonlocal(
                    (cst.NameItem(cst.Name("a")),),
                    whitespace_after_nonlocal=cst.SimpleWhitespace(""),
                ),
                "expected_re": "Must have at least one space after 'nonlocal' keyword",
            },
            # Validate comma handling
            {
                "get_node": lambda: cst.Nonlocal(
                    (cst.NameItem(cst.Name("a"), comma=cst.Comma()),)
                ),
                "expected_re": "The last NameItem in a Nonlocal cannot have a trailing comma",
            },
            # Validate paren handling
            {
                "get_node": lambda: cst.Nonlocal(
                    (
                        cst.NameItem(
                            cst.Name(
                                "a",
                                lpar=(cst.LeftParen(),),
                                rpar=(cst.RightParen(),),
                            )
                        ),
                    )
                ),
                "expected_re": "Cannot have parens around names in NameItem",
            },
        )
    )
    def test_invalid(self, **kwargs: Any) -> None:
        """Forward each case's keyword arguments (get_node, expected_re) to
        ``assert_invalid``."""
        self.assert_invalid(**kwargs)