class ExperimentalReentrantCodegenProviderTest(UnitTest):
    """Tests for the partials resolved from ExperimentalReentrantCodegenProvider."""

    @data_provider(
        {
            "simple_top_level_statement": {
                "old_module": """\
                    import math
                    c = math.sqrt(a*a + b*b)
                    """,
                "new_module": """\
                    import math
                    c = math.hypot(a, b)
                    """,
                "old_node": lambda mod: mod.body[1],
                "new_node": cst.parse_statement("c = math.hypot(a, b)"),
            },
            "replacement_inside_block": {
                "old_module": """\
                    import math
                    def do_math(a, b):
                        c = math.sqrt(a*a + b*b)
                        return c
                    """,
                "new_module": """\
                    import math
                    def do_math(a, b):
                        c = math.hypot(a, b)
                        return c
                    """,
                "old_node": lambda mod: mod.body[1].body.body[0],
                "new_node": cst.parse_statement("c = math.hypot(a, b)"),
            },
            "missing_trailing_newline": {
                # Neither module text ends with a newline.
                "old_module": "old_fn()",
                "new_module": "new_fn()",
                "old_node": lambda mod: mod.body[0],
                "new_node": cst.parse_statement("new_fn()\n"),
            },
            "nested_blocks_with_missing_trailing_newline": {
                # Neither module text ends with a newline.
                "old_module": """\
                    if outer:
                        if inner:
                            old_fn()""",
                "new_module": """\
                    if outer:
                        if inner:
                            new_fn()""",
                "old_node": lambda mod: mod.body[0].body.body[0].body.body[0],
                "new_node": cst.parse_statement("new_fn()\n"),
            },
        }
    )
    def test_provider(
        self,
        old_module: str,
        new_module: str,
        old_node: Callable[[cst.Module], cst.CSTNode],
        new_node: cst.BaseStatement,
    ) -> None:
        """Replace one node through its codegen partial and compare module text.

        ``old_module`` and ``new_module`` are dedented before use. ``old_node``
        picks the node to replace out of the parsed module, and ``new_node`` is
        the statement handed to ``get_modified_module_code``.
        """
        expected_original = dedent(old_module)
        expected_modified = dedent(new_module)

        wrapper = MetadataWrapper(cst.parse_module(expected_original))
        partials_by_node = wrapper.resolve(ExperimentalReentrantCodegenProvider)
        partial = partials_by_node[old_node(wrapper.module)]

        # The partial must reproduce the untouched module text verbatim ...
        self.assertEqual(partial.get_original_module_code(), expected_original)
        # ... and produce the expected text once the node is swapped out.
        self.assertEqual(
            partial.get_modified_module_code(new_node), expected_modified
        )

    def test_byte_conversion(self) -> None:
        """The ``*_bytes`` accessors must match text encoded with the parse encoding."""
        encoding = "utf-16"
        original_source = "fn()\n"
        replacement_source = "fn2()\n"
        expected_original_bytes = original_source.encode(encoding)

        parsed = cst.parse_module(
            original_source, cst.PartialParserConfig(encoding=encoding)
        )
        wrapper = MetadataWrapper(parsed)
        partials_by_node = wrapper.resolve(ExperimentalReentrantCodegenProvider)
        partial = partials_by_node[wrapper.module.body[0]]

        self.assertEqual(
            partial.get_original_module_bytes(), expected_original_bytes
        )
        self.assertEqual(
            partial.get_modified_module_bytes(
                cst.parse_statement(replacement_source)
            ),
            replacement_source.encode(encoding),
        )
示例#2
0
    def leave_Module(self, original_node: libcst.Module,
                     updated_node: libcst.Module) -> libcst.Module:
        """Return ``updated_node`` with every requested import added to its body.

        The new body is assembled in this order:

        1. ``from __future__ import ...`` statements,
        2. the statements ``_split_module`` returns as "before imports",
        3. ``import x`` statements for ``self.module_imports`` (sorted),
        4. ``import x as y`` statements for ``self.module_aliases``,
        5. the remaining ``from x import ...`` statements,
        6. the statements ``_split_module`` returns as "after imports", passed
           through ``_insert_empty_line`` first.

        Args:
            original_node: Forwarded to ``_split_module`` together with
                ``updated_node``.
            updated_node: The module whose body receives the new imports.

        Returns:
            ``updated_node`` itself when nothing was requested, otherwise the
            result of ``updated_node.with_changes(body=...)``.
        """
        # Don't try to modify if we have nothing to do
        if not (self.module_imports or self.module_mapping
                or self.module_aliases or self.alias_mapping):
            return updated_node

        # First, find the insertion point for imports
        statements_before_imports, statements_after_imports = self._split_module(
            original_node, updated_node)

        # Make sure there's at least one empty line before the first non-import
        statements_after_imports = self._insert_empty_line(
            statements_after_imports)

        # Per module, collect (name, alias) pairs; alias is None for names that
        # are imported without an "as" clause.
        names_by_module = defaultdict(list)
        for module, aliases in self.alias_mapping.items():
            names_by_module[module].extend(aliases)
        for module, names in self.module_mapping.items():
            names_by_module[module].extend((name, None) for name in names)

        parse_config = updated_node.config_for_parsing

        def from_import(module, aliases):
            """Build one ``from <module> import a, b as c`` statement."""
            # Sort with an explicit key: comparing the raw tuples raises
            # TypeError when the same name is requested both without an alias
            # (None) and with one (str). The unaliased form sorts first.
            ordered = sorted(aliases, key=lambda pair: (pair[0], pair[1] or ""))
            clause = ", ".join(
                name if alias is None else f"{name} as {alias}"
                for name, alias in ordered
            )
            return parse_statement(f"from {module} import " + clause,
                                   config=parse_config)

        # Now, add all of the imports we need!
        return updated_node.with_changes(body=(
            # __future__ imports are placed ahead of everything else.
            *[
                from_import(module, aliases)
                for module, aliases in names_by_module.items()
                if module == "__future__"
            ],
            *statements_before_imports,
            *[
                parse_statement(f"import {module}", config=parse_config)
                for module in sorted(self.module_imports)
            ],
            *[
                parse_statement(f"import {module} as {asname}",
                                config=parse_config)
                for module, asname in self.module_aliases.items()
            ],
            *[
                from_import(module, aliases)
                for module, aliases in names_by_module.items()
                if module != "__future__"
            ],
            *statements_after_imports,
        ))
示例#3
0
class ParseErrorsTest(UnitTest):
    """Checks the text rendered for ``cst.ParserSyntaxError`` on invalid input."""

    @data_provider(
        {
            # Errors raised by _wrapped_tokenize.
            "wrapped_tokenize__invalid_token": (
                lambda: cst.parse_module("'"),
                dedent(
                    """
                    Syntax Error @ 1:1.
                    "'" is not a valid token.

                    '
                    ^
                    """
                ).strip(),
            ),
            "wrapped_tokenize__expected_dedent": (
                lambda: cst.parse_module("if False:\n    pass\n  pass"),
                dedent(
                    """
                    Syntax Error @ 3:1.
                    Inconsistent indentation. Expected a dedent.

                      pass
                    ^
                    """
                ).strip(),
            ),
            "wrapped_tokenize__mismatched_braces": (
                lambda: cst.parse_module("abcd)"),
                dedent(
                    """
                    Syntax Error @ 1:5.
                    Encountered a closing brace without a matching opening brace.

                    abcd)
                        ^
                    """
                ).strip(),
            ),
            # Errors raised by _base_parser.
            "base_parser__unexpected_indent": (
                lambda: cst.parse_module("    abcd"),
                dedent(
                    """
                    Syntax Error @ 1:5.
                    Incomplete input. Unexpectedly encountered an indent.

                        abcd
                        ^
                    """
                ).strip(),
            ),
            "base_parser__unexpected_dedent": (
                lambda: cst.parse_module("if False:\n    (el for el\n"),
                dedent(
                    """
                    Syntax Error @ 3:1.
                    Incomplete input. Encountered a dedent, but expected 'in'.

                        (el for el
                                  ^
                    """
                ).strip(),
            ),
            "base_parser__multiple_possibilities": (
                lambda: cst.parse_module("try: pass"),
                dedent(
                    """
                    Syntax Error @ 2:1.
                    Incomplete input. Encountered end of file (EOF), but expected 'except', or 'finally'.

                    try: pass
                             ^
                    """
                ).strip(),
            ),
            # Errors raised by the conversion functions; `_base_parser` is
            # responsible for attaching the location information to them.
            "convert_nonterminal__dict_unpacking": (
                lambda: cst.parse_expression("{**el for el in []}"),
                dedent(
                    """
                    Syntax Error @ 1:19.
                    dict unpacking cannot be used in dict comprehension

                    {**el for el in []}
                                      ^
                    """
                ).strip(),
            ),
            "convert_nonterminal__arglist_non_default_after_default": (
                lambda: cst.parse_statement("def fn(first=None, second): ..."),
                dedent(
                    """
                    Syntax Error @ 1:26.
                    Cannot have a non-default argument following a default argument.

                    def fn(first=None, second): ...
                                             ^
                    """
                ).strip(),
            ),
            "convert_nonterminal__arglist_trailing_param_star_without_comma": (
                lambda: cst.parse_statement("def fn(abc, *): ..."),
                dedent(
                    """
                    Syntax Error @ 1:14.
                    Named (keyword) arguments must follow a bare *.

                    def fn(abc, *): ...
                                 ^
                    """
                ).strip(),
            ),
            "convert_nonterminal__arglist_trailing_param_star_with_comma": (
                lambda: cst.parse_statement("def fn(abc, *,): ..."),
                dedent(
                    """
                    Syntax Error @ 1:15.
                    Named (keyword) arguments must follow a bare *.

                    def fn(abc, *,): ...
                                  ^
                    """
                ).strip(),
            ),
            "convert_nonterminal__class_arg_positional_after_keyword": (
                lambda: cst.parse_statement("class Cls(first=None, second): ..."),
                dedent(
                    """
                    Syntax Error @ 2:1.
                    Positional argument follows keyword argument.

                    class Cls(first=None, second): ...
                                                      ^
                    """
                ).strip(),
            ),
            "convert_nonterminal__class_arg_positional_expansion_after_keyword": (
                lambda: cst.parse_statement("class Cls(first=None, *second): ..."),
                dedent(
                    """
                    Syntax Error @ 2:1.
                    Positional argument follows keyword argument.

                    class Cls(first=None, *second): ...
                                                       ^
                    """
                ).strip(),
            ),
        }
    )
    def test_parser_syntax_error_str(
        self, parse_fn: Callable[[], object], expected: str
    ) -> None:
        """``parse_fn`` must raise ParserSyntaxError; its text is checked when not native."""
        with self.assertRaises(cst.ParserSyntaxError) as raised:
            parse_fn()
        # The rendered message is only compared when is_native() is false.
        if is_native():
            return
        self.assertEqual(str(raised.exception), expected)

    def test_native_fallible_into_py(self) -> None:
        """A validation failure during parsing must surface as a syntax error type."""
        with patch("libcst._nodes.expression.Name._validate") as validate_mock:
            # Make every Name validation fail while the module is parsed.
            validate_mock.side_effect = CSTValidationError("validate is broken")
            with self.assertRaises(Exception) as raised:
                cst.parse_module("foo")
            accepted_types = (SyntaxError, cst.ParserSyntaxError)
            self.assertIsInstance(raised.exception, accepted_types)
示例#4
0
class ForTest(CSTNodeTest):
    """Round-trip and validation cases for ``cst.For`` nodes."""

    @data_provider(
        (
            # Plain one-line for loop.
            {
                "node": cst.For(
                    cst.Name("target"),
                    cst.Call(cst.Name("iter")),
                    cst.SimpleStatementSuite((cst.Pass(),)),
                ),
                "code": "for target in iter(): pass\n",
                "parser": parse_statement,
            },
            # One-line async for loop, parsed as Python 3.7.
            {
                "node": cst.For(
                    cst.Name("target"),
                    cst.Call(cst.Name("iter")),
                    cst.SimpleStatementSuite((cst.Pass(),)),
                    asynchronous=cst.Asynchronous(),
                ),
                "code": "async for target in iter(): pass\n",
                "parser": lambda code: parse_statement(
                    code, config=PartialParserConfig(python_version="3.7")
                ),
            },
            # Async for loop inside an async function, parsed as Python 3.6.
            {
                "node": cst.FunctionDef(
                    cst.Name("foo"),
                    cst.Parameters(),
                    cst.IndentedBlock(
                        (
                            cst.For(
                                cst.Name("target"),
                                cst.Call(cst.Name("iter")),
                                cst.SimpleStatementSuite((cst.Pass(),)),
                                asynchronous=cst.Asynchronous(),
                            ),
                        )
                    ),
                    asynchronous=cst.Asynchronous(),
                ),
                "code": "async def foo():\n    async for target in iter(): pass\n",
                "parser": lambda code: parse_statement(
                    code, config=PartialParserConfig(python_version="3.6")
                ),
            },
            # For loop followed by an else clause.
            {
                "node": cst.For(
                    cst.Name("target"),
                    cst.Call(cst.Name("iter")),
                    cst.SimpleStatementSuite((cst.Pass(),)),
                    cst.Else(cst.SimpleStatementSuite((cst.Pass(),))),
                ),
                "code": "for target in iter(): pass\nelse: pass\n",
                "parser": parse_statement,
            },
            # Loop nested in an indented block.
            {
                "node": DummyIndentedBlock(
                    "    ",
                    cst.For(
                        cst.Name("target"),
                        cst.Call(cst.Name("iter")),
                        cst.SimpleStatementSuite((cst.Pass(),)),
                    ),
                ),
                "code": "    for target in iter(): pass\n",
                "parser": None,
            },
            # Loop nested in an indented block, with an indented body of its own.
            {
                "node": DummyIndentedBlock(
                    "    ",
                    cst.For(
                        cst.Name("target"),
                        cst.Call(cst.Name("iter")),
                        cst.IndentedBlock(
                            (cst.SimpleStatementLine((cst.Pass(),)),)
                        ),
                    ),
                ),
                "code": "    for target in iter():\n        pass\n",
                "parser": None,
                "expected_position": CodeRange((1, 4), (2, 12)),
            },
            # Comment lines ahead of both the loop and its else clause.
            {
                "node": cst.For(
                    cst.Name("target"),
                    cst.Call(cst.Name("iter")),
                    cst.IndentedBlock((cst.SimpleStatementLine((cst.Pass(),)),)),
                    cst.Else(
                        cst.IndentedBlock(
                            (cst.SimpleStatementLine((cst.Pass(),)),)
                        ),
                        leading_lines=(
                            cst.EmptyLine(comment=cst.Comment("# else comment")),
                        ),
                    ),
                    leading_lines=(
                        cst.EmptyLine(comment=cst.Comment("# leading comment")),
                    ),
                ),
                "code": "# leading comment\nfor target in iter():\n    pass\n# else comment\nelse:\n    pass\n",
                "parser": None,
                "expected_position": CodeRange((2, 0), (6, 8)),
            },
            # Parenthesized target and iterable with no whitespace around keywords.
            {
                "node": cst.For(
                    cst.Name(
                        "target",
                        lpar=(cst.LeftParen(),),
                        rpar=(cst.RightParen(),),
                    ),
                    cst.Call(
                        cst.Name("iter"),
                        lpar=(cst.LeftParen(),),
                        rpar=(cst.RightParen(),),
                    ),
                    cst.SimpleStatementSuite((cst.Pass(),)),
                    whitespace_after_for=cst.SimpleWhitespace(""),
                    whitespace_before_in=cst.SimpleWhitespace(""),
                    whitespace_after_in=cst.SimpleWhitespace(""),
                ),
                "code": "for(target)in(iter()): pass\n",
                "parser": parse_statement,
                "expected_position": CodeRange((1, 0), (1, 27)),
            },
            # Two spaces in every configurable whitespace slot.
            {
                "node": cst.For(
                    cst.Name("target"),
                    cst.Call(cst.Name("iter")),
                    cst.SimpleStatementSuite((cst.Pass(),)),
                    whitespace_after_for=cst.SimpleWhitespace("  "),
                    whitespace_before_in=cst.SimpleWhitespace("  "),
                    whitespace_after_in=cst.SimpleWhitespace("  "),
                    whitespace_before_colon=cst.SimpleWhitespace("  "),
                ),
                "code": "for  target  in  iter()  : pass\n",
                "parser": parse_statement,
                "expected_position": CodeRange((1, 0), (1, 31)),
            },
        )
    )
    def test_valid(self, **kwargs: Any) -> None:
        """Forward each valid case to ``validate_node``."""
        self.validate_node(**kwargs)

    @data_provider(
        (
            # Missing whitespace after 'for'.
            {
                "get_node": lambda: cst.For(
                    cst.Name("target"),
                    cst.Call(cst.Name("iter")),
                    cst.SimpleStatementSuite((cst.Pass(),)),
                    whitespace_after_for=cst.SimpleWhitespace(""),
                ),
                "expected_re": "Must have at least one space after 'for' keyword",
            },
            # Missing whitespace before 'in'.
            {
                "get_node": lambda: cst.For(
                    cst.Name("target"),
                    cst.Call(cst.Name("iter")),
                    cst.SimpleStatementSuite((cst.Pass(),)),
                    whitespace_before_in=cst.SimpleWhitespace(""),
                ),
                "expected_re": "Must have at least one space before 'in' keyword",
            },
            # Missing whitespace after 'in'.
            {
                "get_node": lambda: cst.For(
                    cst.Name("target"),
                    cst.Call(cst.Name("iter")),
                    cst.SimpleStatementSuite((cst.Pass(),)),
                    whitespace_after_in=cst.SimpleWhitespace(""),
                ),
                "expected_re": "Must have at least one space after 'in' keyword",
            },
        )
    )
    def test_invalid(self, **kwargs: Any) -> None:
        """Forward each invalid case to ``assert_invalid``."""
        self.assert_invalid(**kwargs)
示例#5
0
 def test_valid(self, **kwargs: Any) -> None:
     """Run ``validate_node`` with a parser that yields the first small statement."""

     def parse_first_small_statement(code: str):
         # Parse as a statement, require a SimpleStatementLine, and hand back
         # the first entry of its body.
         statement_line = ensure_type(parse_statement(code),
                                      cst.SimpleStatementLine)
         return statement_line.body[0]

     self.validate_node(parser=parse_first_small_statement, **kwargs)
示例#6
0
 def inner(code: str) -> cst.BaseStatement:
     """Parse ``code`` as a statement using ``config`` from the enclosing scope."""
     # ``config`` is expanded as keyword arguments to PartialParserConfig.
     parser_config = cst.PartialParserConfig(**config)
     return cst.parse_statement(code, config=parser_config)
示例#7
0
def _assert_parser(code: str) -> cst.Assert:
    """Parse ``code`` and return the first small statement, checked to be ``cst.Assert``."""
    # The parsed statement must be a SimpleStatementLine; its first body
    # element is then type-checked as an Assert node.
    statement_line = ensure_type(parse_statement(code), cst.SimpleStatementLine)
    first_small_statement = statement_line.body[0]
    return ensure_type(first_small_statement, cst.Assert)
示例#8
0
class SimpleCompTest(CSTNodeTest):
    @data_provider([
        # simple GeneratorExp
        {
            "node":
            cst.GeneratorExp(
                cst.Name("a"),
                cst.CompFor(target=cst.Name("b"), iter=cst.Name("c"))),
            "code":
            "(a for b in c)",
            "parser":
            parse_expression,
            "expected_position":
            CodeRange((1, 1), (1, 13)),
        },
        # simple ListComp
        {
            "node":
            cst.ListComp(cst.Name("a"),
                         cst.CompFor(target=cst.Name("b"),
                                     iter=cst.Name("c"))),
            "code":
            "[a for b in c]",
            "parser":
            parse_expression,
            "expected_position":
            CodeRange((1, 0), (1, 14)),
        },
        # simple SetComp
        {
            "node":
            cst.SetComp(cst.Name("a"),
                        cst.CompFor(target=cst.Name("b"), iter=cst.Name("c"))),
            "code":
            "{a for b in c}",
            "parser":
            parse_expression,
        },
        # async GeneratorExp
        {
            "node":
            cst.GeneratorExp(
                cst.Name("a"),
                cst.CompFor(
                    target=cst.Name("b"),
                    iter=cst.Name("c"),
                    asynchronous=cst.Asynchronous(),
                ),
            ),
            "code":
            "(a async for b in c)",
            "parser":
            lambda code: parse_expression(
                code, config=PartialParserConfig(python_version="3.7")),
        },
        # Python 3.6 async GeneratorExp
        {
            "node":
            cst.FunctionDef(
                cst.Name("foo"),
                cst.Parameters(),
                cst.IndentedBlock((cst.SimpleStatementLine((cst.Expr(
                    cst.GeneratorExp(
                        cst.Name("a"),
                        cst.CompFor(
                            target=cst.Name("b"),
                            iter=cst.Name("c"),
                            asynchronous=cst.Asynchronous(),
                        ),
                    )), )), )),
                asynchronous=cst.Asynchronous(),
            ),
            "code":
            "async def foo():\n    (a async for b in c)\n",
            "parser":
            lambda code: parse_statement(
                code, config=PartialParserConfig(python_version="3.6")),
        },
        # a generator doesn't have to own it's own parenthesis
        {
            "node":
            cst.Call(
                cst.Name("func"),
                [
                    cst.Arg(
                        cst.GeneratorExp(
                            cst.Name("a"),
                            cst.CompFor(target=cst.Name("b"),
                                        iter=cst.Name("c")),
                            lpar=[],
                            rpar=[],
                        ))
                ],
            ),
            "code":
            "func(a for b in c)",
            "parser":
            parse_expression,
        },
        # add a few 'if' clauses
        {
            "node":
            cst.GeneratorExp(
                cst.Name("a"),
                cst.CompFor(
                    target=cst.Name("b"),
                    iter=cst.Name("c"),
                    ifs=[
                        cst.CompIf(cst.Name("d")),
                        cst.CompIf(cst.Name("e")),
                        cst.CompIf(cst.Name("f")),
                    ],
                ),
            ),
            "code":
            "(a for b in c if d if e if f)",
            "parser":
            parse_expression,
            "expected_position":
            CodeRange((1, 1), (1, 28)),
        },
        # nested/inner for-in clause
        {
            "node":
            cst.GeneratorExp(
                cst.Name("a"),
                cst.CompFor(
                    target=cst.Name("b"),
                    iter=cst.Name("c"),
                    inner_for_in=cst.CompFor(target=cst.Name("d"),
                                             iter=cst.Name("e")),
                ),
            ),
            "code":
            "(a for b in c for d in e)",
            "parser":
            parse_expression,
        },
        # nested/inner for-in clause with an 'if' clause
        {
            "node":
            cst.GeneratorExp(
                cst.Name("a"),
                cst.CompFor(
                    target=cst.Name("b"),
                    iter=cst.Name("c"),
                    ifs=[cst.CompIf(cst.Name("d"))],
                    inner_for_in=cst.CompFor(target=cst.Name("e"),
                                             iter=cst.Name("f")),
                ),
            ),
            "code":
            "(a for b in c if d for e in f)",
            "parser":
            parse_expression,
        },
        # custom whitespace
        {
            "node":
            cst.GeneratorExp(
                cst.Name("a"),
                cst.CompFor(
                    target=cst.Name("b"),
                    iter=cst.Name("c"),
                    ifs=[
                        cst.CompIf(
                            cst.Name("d"),
                            whitespace_before=cst.SimpleWhitespace("\t"),
                            whitespace_before_test=cst.SimpleWhitespace(
                                "\t\t"),
                        )
                    ],
                    whitespace_before=cst.SimpleWhitespace("  "),
                    whitespace_after_for=cst.SimpleWhitespace("   "),
                    whitespace_before_in=cst.SimpleWhitespace("    "),
                    whitespace_after_in=cst.SimpleWhitespace("     "),
                ),
                lpar=[
                    cst.LeftParen(whitespace_after=cst.SimpleWhitespace("\f"))
                ],
                rpar=[
                    cst.RightParen(
                        whitespace_before=cst.SimpleWhitespace("\f\f"))
                ],
            ),
            "code":
            "(\fa  for   b    in     c\tif\t\td\f\f)",
            "parser":
            parse_expression,
            "expected_position":
            CodeRange((1, 2), (1, 30)),
        },
        # custom whitespace around ListComp's brackets
        {
            "node":
            cst.ListComp(
                cst.Name("a"),
                cst.CompFor(target=cst.Name("b"), iter=cst.Name("c")),
                lbracket=cst.LeftSquareBracket(
                    whitespace_after=cst.SimpleWhitespace("\t")),
                rbracket=cst.RightSquareBracket(
                    whitespace_before=cst.SimpleWhitespace("\t\t")),
                lpar=[
                    cst.LeftParen(whitespace_after=cst.SimpleWhitespace("\f"))
                ],
                rpar=[
                    cst.RightParen(
                        whitespace_before=cst.SimpleWhitespace("\f\f"))
                ],
            ),
            "code":
            "(\f[\ta for b in c\t\t]\f\f)",
            "parser":
            parse_expression,
            "expected_position":
            CodeRange((1, 2), (1, 19)),
        },
        # custom whitespace around SetComp's braces
        {
            "node":
            cst.SetComp(
                cst.Name("a"),
                cst.CompFor(target=cst.Name("b"), iter=cst.Name("c")),
                lbrace=cst.LeftCurlyBrace(
                    whitespace_after=cst.SimpleWhitespace("\t")),
                rbrace=cst.RightCurlyBrace(
                    whitespace_before=cst.SimpleWhitespace("\t\t")),
                lpar=[
                    cst.LeftParen(whitespace_after=cst.SimpleWhitespace("\f"))
                ],
                rpar=[
                    cst.RightParen(
                        whitespace_before=cst.SimpleWhitespace("\f\f"))
                ],
            ),
            "code":
            "(\f{\ta for b in c\t\t}\f\f)",
            "parser":
            parse_expression,
        },
        # no whitespace between elements
        {
            "node":
            cst.GeneratorExp(
                cst.Name("a", lpar=[cst.LeftParen()], rpar=[cst.RightParen()]),
                cst.CompFor(
                    target=cst.Name("b",
                                    lpar=[cst.LeftParen()],
                                    rpar=[cst.RightParen()]),
                    iter=cst.Name("c",
                                  lpar=[cst.LeftParen()],
                                  rpar=[cst.RightParen()]),
                    ifs=[
                        cst.CompIf(
                            cst.Name("d",
                                     lpar=[cst.LeftParen()],
                                     rpar=[cst.RightParen()]),
                            whitespace_before=cst.SimpleWhitespace(""),
                            whitespace_before_test=cst.SimpleWhitespace(""),
                        )
                    ],
                    inner_for_in=cst.CompFor(
                        target=cst.Name("e",
                                        lpar=[cst.LeftParen()],
                                        rpar=[cst.RightParen()]),
                        iter=cst.Name("f",
                                      lpar=[cst.LeftParen()],
                                      rpar=[cst.RightParen()]),
                        whitespace_before=cst.SimpleWhitespace(""),
                        whitespace_after_for=cst.SimpleWhitespace(""),
                        whitespace_before_in=cst.SimpleWhitespace(""),
                        whitespace_after_in=cst.SimpleWhitespace(""),
                    ),
                    whitespace_before=cst.SimpleWhitespace(""),
                    whitespace_after_for=cst.SimpleWhitespace(""),
                    whitespace_before_in=cst.SimpleWhitespace(""),
                    whitespace_after_in=cst.SimpleWhitespace(""),
                ),
                lpar=[cst.LeftParen()],
                rpar=[cst.RightParen()],
            ),
            "code":
            "((a)for(b)in(c)if(d)for(e)in(f))",
            "parser":
            parse_expression,
            "expected_position":
            CodeRange((1, 1), (1, 31)),
        },
        # no whitespace before/after GeneratorExp is valid
        {
            "node":
            cst.Comparison(
                cst.GeneratorExp(
                    cst.Name("a"),
                    cst.CompFor(target=cst.Name("b"), iter=cst.Name("c")),
                ),
                [
                    cst.ComparisonTarget(
                        cst.Is(
                            whitespace_before=cst.SimpleWhitespace(""),
                            whitespace_after=cst.SimpleWhitespace(""),
                        ),
                        cst.GeneratorExp(
                            cst.Name("d"),
                            cst.CompFor(target=cst.Name("e"),
                                        iter=cst.Name("f")),
                        ),
                    )
                ],
            ),
            "code":
            "(a for b in c)is(d for e in f)",
            "parser":
            parse_expression,
        },
        # no whitespace before/after ListComp is valid
        {
            "node":
            cst.Comparison(
                cst.ListComp(
                    cst.Name("a"),
                    cst.CompFor(target=cst.Name("b"), iter=cst.Name("c")),
                ),
                [
                    cst.ComparisonTarget(
                        cst.Is(
                            whitespace_before=cst.SimpleWhitespace(""),
                            whitespace_after=cst.SimpleWhitespace(""),
                        ),
                        cst.ListComp(
                            cst.Name("d"),
                            cst.CompFor(target=cst.Name("e"),
                                        iter=cst.Name("f")),
                        ),
                    )
                ],
            ),
            "code":
            "[a for b in c]is[d for e in f]",
            "parser":
            parse_expression,
        },
        # no whitespace before/after SetComp is valid
        {
            "node":
            cst.Comparison(
                cst.SetComp(
                    cst.Name("a"),
                    cst.CompFor(target=cst.Name("b"), iter=cst.Name("c")),
                ),
                [
                    cst.ComparisonTarget(
                        cst.Is(
                            whitespace_before=cst.SimpleWhitespace(""),
                            whitespace_after=cst.SimpleWhitespace(""),
                        ),
                        cst.SetComp(
                            cst.Name("d"),
                            cst.CompFor(target=cst.Name("e"),
                                        iter=cst.Name("f")),
                        ),
                    )
                ],
            ),
            "code":
            "{a for b in c}is{d for e in f}",
            "parser":
            parse_expression,
        },
    ])
    def test_valid(self, **kwargs: Any) -> None:
        # Each data-provider row arrives as keyword arguments and is passed
        # through to validate_node unchanged.
        self.validate_node(**kwargs)

    @data_provider((
        # Each row is (zero-argument node factory, expected message). The
        # factories are lambdas so the node is only constructed when the test
        # calls them.
        #
        # Two left parens against a single right paren on a GeneratorExp.
        (
            lambda: cst.GeneratorExp(
                cst.Name("a"),
                cst.CompFor(target=cst.Name("b"), iter=cst.Name("c")),
                lpar=[cst.LeftParen(), cst.LeftParen()],
                rpar=[cst.RightParen()],
            ),
            "unbalanced parens",
        ),
        # Same paren mismatch on a ListComp.
        (
            lambda: cst.ListComp(
                cst.Name("a"),
                cst.CompFor(target=cst.Name("b"), iter=cst.Name("c")),
                lpar=[cst.LeftParen(), cst.LeftParen()],
                rpar=[cst.RightParen()],
            ),
            "unbalanced parens",
        ),
        # Same paren mismatch on a SetComp.
        (
            lambda: cst.SetComp(
                cst.Name("a"),
                cst.CompFor(target=cst.Name("b"), iter=cst.Name("c")),
                lpar=[cst.LeftParen(), cst.LeftParen()],
                rpar=[cst.RightParen()],
            ),
            "unbalanced parens",
        ),
        # Empty whitespace_before on the CompFor.
        (
            lambda: cst.GeneratorExp(
                cst.Name("a"),
                cst.CompFor(
                    target=cst.Name("b"),
                    iter=cst.Name("c"),
                    whitespace_before=cst.SimpleWhitespace(""),
                ),
            ),
            "Must have at least one space before 'for' keyword.",
        ),
        # Empty whitespace_before on an asynchronous CompFor; the expected
        # message names 'async' rather than 'for'.
        (
            lambda: cst.GeneratorExp(
                cst.Name("a"),
                cst.CompFor(
                    target=cst.Name("b"),
                    iter=cst.Name("c"),
                    asynchronous=cst.Asynchronous(),
                    whitespace_before=cst.SimpleWhitespace(""),
                ),
            ),
            "Must have at least one space before 'async' keyword.",
        ),
        # Empty whitespace_after_for.
        (
            lambda: cst.GeneratorExp(
                cst.Name("a"),
                cst.CompFor(
                    target=cst.Name("b"),
                    iter=cst.Name("c"),
                    whitespace_after_for=cst.SimpleWhitespace(""),
                ),
            ),
            "Must have at least one space after 'for' keyword.",
        ),
        # Empty whitespace_before_in.
        (
            lambda: cst.GeneratorExp(
                cst.Name("a"),
                cst.CompFor(
                    target=cst.Name("b"),
                    iter=cst.Name("c"),
                    whitespace_before_in=cst.SimpleWhitespace(""),
                ),
            ),
            "Must have at least one space before 'in' keyword.",
        ),
        # Empty whitespace_after_in.
        (
            lambda: cst.GeneratorExp(
                cst.Name("a"),
                cst.CompFor(
                    target=cst.Name("b"),
                    iter=cst.Name("c"),
                    whitespace_after_in=cst.SimpleWhitespace(""),
                ),
            ),
            "Must have at least one space after 'in' keyword.",
        ),
        # Empty whitespace_before on a CompIf inside the CompFor's ifs.
        (
            lambda: cst.GeneratorExp(
                cst.Name("a"),
                cst.CompFor(
                    target=cst.Name("b"),
                    iter=cst.Name("c"),
                    ifs=[
                        cst.CompIf(
                            cst.Name("d"),
                            whitespace_before=cst.SimpleWhitespace(""),
                        )
                    ],
                ),
            ),
            "Must have at least one space before 'if' keyword.",
        ),
        # Empty whitespace_before_test on the CompIf.
        (
            lambda: cst.GeneratorExp(
                cst.Name("a"),
                cst.CompFor(
                    target=cst.Name("b"),
                    iter=cst.Name("c"),
                    ifs=[
                        cst.CompIf(
                            cst.Name("d"),
                            whitespace_before_test=cst.SimpleWhitespace(""),
                        )
                    ],
                ),
            ),
            "Must have at least one space after 'if' keyword.",
        ),
        # Empty whitespace_before on the nested CompFor given as inner_for_in.
        (
            lambda: cst.GeneratorExp(
                cst.Name("a"),
                cst.CompFor(
                    target=cst.Name("b"),
                    iter=cst.Name("c"),
                    inner_for_in=cst.CompFor(
                        target=cst.Name("d"),
                        iter=cst.Name("e"),
                        whitespace_before=cst.SimpleWhitespace(""),
                    ),
                ),
            ),
            "Must have at least one space before 'for' keyword.",
        ),
        # Same as the previous row, but the nested CompFor is asynchronous.
        (
            lambda: cst.GeneratorExp(
                cst.Name("a"),
                cst.CompFor(
                    target=cst.Name("b"),
                    iter=cst.Name("c"),
                    inner_for_in=cst.CompFor(
                        target=cst.Name("d"),
                        iter=cst.Name("e"),
                        asynchronous=cst.Asynchronous(),
                        whitespace_before=cst.SimpleWhitespace(""),
                    ),
                ),
            ),
            "Must have at least one space before 'async' keyword.",
        ),
    ))
    def test_invalid(self, get_node: Callable[[], cst.CSTNode],
                     expected_re: str) -> None:
        # get_node is the row's factory and expected_re its message; both are
        # handed to assert_invalid as-is.
        self.assert_invalid(get_node, expected_re)
# Example #9
0
def _parse_expression_py37(code):
    """Parse ``code`` as an expression with ``python_version="3.7"``."""
    py37 = PartialParserConfig(python_version="3.7")
    return parse_expression(code, config=py37)


def _parse_statement_py36(code):
    """Parse ``code`` as a statement with ``python_version="3.6"``."""
    py36 = PartialParserConfig(python_version="3.6")
    return parse_statement(code, config=py36)


def _spaced_parenthesized_await():
    """Build the node for ``( await  test )``.

    The ``await`` is followed by two spaces and the surrounding parentheses
    each carry a single space on their inner side.
    """
    opening = cst.LeftParen(whitespace_after=cst.SimpleWhitespace(" "))
    closing = cst.RightParen(whitespace_before=cst.SimpleWhitespace(" "))
    return cst.Await(
        cst.Name("test"),
        whitespace_after_await=cst.SimpleWhitespace("  "),
        lpar=(opening, ),
        rpar=(closing, ),
    )


def _async_foo_containing(awaited):
    """Build ``async def foo():`` whose body is the single expression ``awaited``."""
    only_line = cst.SimpleStatementLine((cst.Expr(awaited), ))
    return cst.FunctionDef(
        cst.Name("foo"),
        cst.Parameters(),
        cst.IndentedBlock((only_line, )),
        asynchronous=cst.Asynchronous(),
    )


class AwaitTest(CSTNodeTest):
    """Data-driven cases for ``cst.Await``.

    The valid cases pair a hand-built node with its source text and a parser
    pinned to a Python version; the invalid cases pair a node factory with the
    message expected from ``assert_invalid``.
    """

    @data_provider((
        # Some simple calls
        dict(
            node=cst.Await(cst.Name("test")),
            code="await test",
            parser=_parse_expression_py37,
            expected_position=None,
        ),
        dict(
            node=cst.Await(cst.Call(cst.Name("test"))),
            code="await test()",
            parser=_parse_expression_py37,
            expected_position=None,
        ),
        # Whitespace
        dict(
            node=_spaced_parenthesized_await(),
            code="( await  test )",
            parser=_parse_expression_py37,
            expected_position=CodeRange((1, 2), (1, 13)),
        ),
    ))
    def test_valid_py37(self, **kwargs: Any) -> None:
        # We don't have sentinel nodes for atoms, so we know that 100% of atoms
        # can be parsed identically to their creation.
        self.validate_node(**kwargs)

    @data_provider((
        # Some simple calls
        dict(
            node=_async_foo_containing(cst.Await(cst.Name("test"))),
            code="async def foo():\n    await test\n",
            parser=_parse_statement_py36,
            expected_position=None,
        ),
        dict(
            node=_async_foo_containing(
                cst.Await(cst.Call(cst.Name("test")))),
            code="async def foo():\n    await test()\n",
            parser=_parse_statement_py36,
            expected_position=None,
        ),
        # Whitespace
        dict(
            node=_async_foo_containing(_spaced_parenthesized_await()),
            code="async def foo():\n    ( await  test )\n",
            parser=_parse_statement_py36,
            expected_position=None,
        ),
    ))
    def test_valid_py36(self, **kwargs: Any) -> None:
        # We don't have sentinel nodes for atoms, so we know that 100% of atoms
        # can be parsed identically to their creation.
        self.validate_node(**kwargs)

    @data_provider((
        # Expression wrapping parenthesis rules
        dict(
            get_node=lambda: cst.Await(
                cst.Name("foo"),
                lpar=(cst.LeftParen(), ),
            ),
            expected_re="left paren without right paren",
        ),
        dict(
            get_node=lambda: cst.Await(
                cst.Name("foo"),
                rpar=(cst.RightParen(), ),
            ),
            expected_re="right paren without left paren",
        ),
        # No whitespace between ``await`` and its operand
        dict(
            get_node=lambda: cst.Await(
                cst.Name("foo"),
                whitespace_after_await=cst.SimpleWhitespace(""),
            ),
            expected_re="at least one space after await",
        ),
    ))
    def test_invalid(self, **kwargs: Any) -> None:
        # Nodes are built lazily inside the factories so that construction
        # happens under assert_invalid rather than at class-definition time.
        self.assert_invalid(**kwargs)
# Example #10
0
 def runtest(self):
     """Assert that the spec's query matches (or fails to match) its code.

     ``self.spec`` unpacks into three values: the source text, the query
     handed to ``create_matcher``, and the expected result of ``m.matches``.
     """
     source, pattern, want = self.spec
     statement = libcst.parse_statement(source)
     # Only the first element of the parsed statement's body is matched.
     target = statement.body[0]
     matcher = create_matcher(pattern)
     outcome = m.matches(target, matcher)
     assert outcome == want