Esempio n. 1
0
    def leave_ImportFrom(self, original_node: libcst.ImportFrom,
                         updated_node: libcst.ImportFrom) -> libcst.ImportFrom:
        """Prepend pending names to an existing ``from <module> import ...``.

        The absolute module name of the statement is looked up in
        ``self.module_mapping`` (module -> names to import) and
        ``self.alias_mapping`` (module -> ``(name, alias)`` pairs).  Any
        entries found are removed from those mappings and emitted, sorted,
        in front of the names the statement already imports.

        Star imports, statements whose module cannot be resolved, and
        statements with nothing pending are returned unchanged.
        """
        # A star import has no individual names to extend.
        if isinstance(updated_node.names, libcst.ImportStar):
            return updated_node

        module = get_absolute_module_for_import(self.context.full_module_name,
                                                updated_node)
        if module is None:
            return updated_node

        has_plain = module in self.module_mapping
        has_aliased = module in self.alias_mapping
        if not (has_plain or has_aliased):
            return updated_node

        # Popping the entries marks this module as handled, so a later
        # statement importing from the same module is left alone.
        plain_names = self.module_mapping.pop(module, [])
        aliased_names = self.alias_mapping.pop(module, [])

        plain_aliases = [
            libcst.ImportAlias(name=libcst.Name(plain))
            for plain in sorted(plain_names)
        ]
        renamed_aliases = [
            libcst.ImportAlias(
                name=libcst.Name(target),
                asname=libcst.AsName(name=libcst.Name(bound_as)),
            ) for (target, bound_as) in sorted(aliased_names)
        ]

        # New names go first; the statement's existing names keep their order.
        return updated_node.with_changes(
            names=[*plain_aliases, *renamed_aliases, *updated_node.names])
Esempio n. 2
0
def import_to_node_single(imp: SortableImport,
                          module: cst.Module) -> cst.BaseStatement:
    """Render ``imp`` as a single-line import statement.

    Entries of ``imp.comments.before`` become the statement's leading lines:
    entries starting with ``#`` become comment lines, anything else becomes
    an empty line.  Every other comment attached to the import or to one of
    its items is merged, in order, into one trailing comment joined with
    ``COMMENT_INDENT``.

    A ``from ... import ...`` statement is produced when ``imp.stem`` is
    set, otherwise a plain ``import ...``.  ``module`` is accepted but not
    used here.
    """
    lines_above: List[cst.EmptyLine] = []
    for text in imp.comments.before:
        if text.startswith("#"):
            lines_above.append(
                cst.EmptyLine(indent=True, comment=cst.Comment(text)))
        else:
            lines_above.append(cst.EmptyLine(indent=False))

    # Comments are gathered in source order: the import's first inline
    # comment, then each item's comments, then the import's closing ones.
    inline_comments = list(imp.comments.first_inline)
    aliases: List[cst.ImportAlias] = []
    for item in imp.items:
        alias_name = name_to_node(item.name)
        alias_asname = None
        if item.asname:
            alias_asname = cst.AsName(name=cst.Name(item.asname))
        aliases.append(cst.ImportAlias(name=alias_name, asname=alias_asname))
        inline_comments.extend(item.comments.before)
        inline_comments.extend(item.comments.inline)
        inline_comments.extend(item.comments.following)
    inline_comments.extend(imp.comments.final)
    inline_comments.extend(imp.comments.last_inline)

    if inline_comments:
        merged = COMMENT_INDENT.join(inline_comments)
        tail = cst.TrailingWhitespace(
            whitespace=cst.SimpleWhitespace(COMMENT_INDENT),
            comment=cst.Comment(merged))
    else:
        tail = cst.TrailingWhitespace()

    if imp.stem:
        stem, ndots = split_relative(imp.stem)
        # A stem made only of dots (``from . import x``) has no module part.
        module_node = name_to_node(stem) if stem else None
        dots = (cst.Dot(), ) * ndots
        statement = cst.ImportFrom(module=module_node,
                                   names=aliases,
                                   relative=dots)
    else:
        statement = cst.Import(names=aliases)

    return cst.SimpleStatementLine(
        body=[statement],
        leading_lines=lines_above,
        trailing_whitespace=tail,
    )
Esempio n. 3
0
    def leave_ImportFrom(self, original_node: libcst.ImportFrom,
                         updated_node: libcst.ImportFrom) -> libcst.ImportFrom:
        """Prepend pending names to an existing ``from <module> import ...``.

        The statement's module name is looked up in ``self.module_mapping``
        (module -> names to import) and ``self.alias_mapping``
        (module -> ``(name, alias)`` pairs).  Any entries found are removed
        from those mappings and emitted in front of the names the statement
        already imports.

        Relative imports, star imports, and statements with nothing pending
        are returned unchanged.
        """
        if len(updated_node.relative) > 0 or updated_node.module is None:
            # Don't support relative-only imports at the moment.
            return updated_node
        # A star import is represented by an ImportStar node, never by the
        # string "*".  It is not iterable, so it must be filtered out before
        # the names are unpacked below.
        if isinstance(updated_node.names, libcst.ImportStar):
            # There's nothing to do here!
            return updated_node

        # Get the module we're importing as a string, see if we have work to do
        module = self._get_string_name(updated_node.module)
        if module not in self.module_mapping and module not in self.alias_mapping:
            return updated_node

        # We have work to do; popping the entries marks this module as
        # handled so it is not modified again.
        imports_to_add = self.module_mapping.pop(module, [])
        aliases_to_add = self.alias_mapping.pop(module, [])

        # Now, do the actual update: new names first, existing names after.
        return updated_node.with_changes(names=(
            *[
                libcst.ImportAlias(name=libcst.Name(imp))
                for imp in imports_to_add
            ],
            *[
                libcst.ImportAlias(
                    name=libcst.Name(imp),
                    asname=libcst.AsName(name=libcst.Name(alias)),
                ) for (imp, alias) in aliases_to_add
            ],
            *updated_node.names,
        ))
Esempio n. 4
0
def import_to_node_multi(imp: SortableImport,
                         module: cst.Module) -> cst.BaseStatement:
    """Render ``imp`` as a parenthesized ``from`` import, one item per line.

    Each item gets a trailing comma followed by parenthesized whitespace
    that carries the item's inline comment, any comment lines that follow
    it, and the indentation of the next item.  Comment lines that precede
    an item are attached to whatever comes before it: the left parenthesis
    for the first item, otherwise the previous item's comma.

    ``module.default_indent`` is used for all indentation inside the
    parentheses.

    Raises:
        ValueError: if ``imp.stem`` is empty, i.e. this is a plain
            ``import`` statement.
    """
    aliases: List[cst.ImportAlias] = []
    previous: Optional[cst.ImportAlias] = None
    lpar_comment_lines: List[cst.EmptyLine] = []
    final_index = len(imp.items) - 1

    for position, item in enumerate(imp.items):
        alias_name = name_to_node(item.name)
        alias_asname = None
        if item.asname:
            alias_asname = cst.AsName(name=cst.Name(item.asname))

        # Leading comments have to live in the whitespace of the node that
        # precedes this item: the lpar for the first item, otherwise the
        # comma of the previous item (mutated in place).
        if item.comments.before:
            comment_lines = []
            for text in item.comments.before:
                comment_lines.append(
                    cst.EmptyLine(
                        indent=True,
                        comment=cst.Comment(text),
                        whitespace=cst.SimpleWhitespace(
                            module.default_indent),
                    ))
            if previous is None:
                lpar_comment_lines.extend(comment_lines)
            else:
                previous.comma.whitespace_after.empty_lines.extend(
                    comment_lines)  # type: ignore

        is_last = position == final_index

        item_inline = COMMENT_INDENT.join(item.comments.inline)
        if item_inline:
            same_line = cst.TrailingWhitespace(
                whitespace=cst.SimpleWhitespace(COMMENT_INDENT),
                comment=cst.Comment(item_inline),
            )
        else:
            same_line = cst.TrailingWhitespace()

        # The import's final comments are rendered after the last item.
        trailing_texts = item.comments.following
        if is_last:
            trailing_texts = trailing_texts + imp.comments.final

        # Must be a real list: the next iteration may extend it in place.
        lines_after = [
            cst.EmptyLine(
                indent=True,
                comment=cst.Comment(text),
                whitespace=cst.SimpleWhitespace(module.default_indent),
            ) for text in trailing_texts
        ]

        # Every item but the last indents the line of the *next* item.
        next_indent = "" if is_last else module.default_indent
        after_comma = cst.ParenthesizedWhitespace(
            indent=True,
            first_line=same_line,
            empty_lines=lines_after,
            last_line=cst.SimpleWhitespace(next_indent),
        )

        alias = cst.ImportAlias(
            name=alias_name,
            asname=alias_asname,
            comma=cst.Comma(whitespace_after=after_comma),
        )
        aliases.append(alias)
        previous = alias

    # Only ``from foo import (...)`` can be split over several lines.
    if not imp.stem:
        raise ValueError("can't render basic imports on multiple lines")

    stem, ndots = split_relative(imp.stem)
    module_node = name_to_node(stem) if stem else None
    dots = (cst.Dot(), ) * ndots

    # Inline comment directly after the opening parenthesis.
    if imp.comments.first_inline:
        after_lpar = cst.TrailingWhitespace(
            whitespace=cst.SimpleWhitespace(COMMENT_INDENT),
            comment=cst.Comment(
                COMMENT_INDENT.join(imp.comments.first_inline)),
        )
    else:
        after_lpar = cst.TrailingWhitespace()

    import_from = cst.ImportFrom(
        module=module_node,
        names=aliases,
        relative=dots,
        lpar=cst.LeftParen(
            whitespace_after=cst.ParenthesizedWhitespace(
                indent=True,
                first_line=after_lpar,
                empty_lines=lpar_comment_lines,
                last_line=cst.SimpleWhitespace(module.default_indent),
            ), ),
        rpar=cst.RightParen(),
    )

    # Comment lines (or blank lines) above the whole statement.
    lines_above = []
    for text in imp.comments.before:
        if text.startswith("#"):
            lines_above.append(
                cst.EmptyLine(indent=True, comment=cst.Comment(text)))
        else:
            lines_above.append(cst.EmptyLine(indent=False))

    # Inline comment after the closing parenthesis.
    if imp.comments.last_inline:
        after_rpar = cst.TrailingWhitespace(
            whitespace=cst.SimpleWhitespace(COMMENT_INDENT),
            comment=cst.Comment(
                COMMENT_INDENT.join(imp.comments.last_inline)))
    else:
        after_rpar = cst.TrailingWhitespace()

    return cst.SimpleStatementLine(
        body=[import_from],
        leading_lines=lines_above,
        trailing_whitespace=after_rpar,
    )
Esempio n. 5
0
class TryTest(CSTNodeTest):
    """Data-driven tests for ``cst.Try`` and its ``ExceptHandler``/``Else``/``Finally`` parts.

    ``test_valid`` pairs hand-built nodes with the source text they
    correspond to; ``test_invalid`` lists node constructions that are
    expected to be rejected with a message matching ``expected_re``.
    """

    # Each dict below is forwarded as keyword arguments to
    # ``self.validate_node``.  Keys used here:
    #   "node"              - the hand-built CST node
    #   "code"              - the source text that node corresponds to
    #   "parser"            - parse function for "code", or None
    #   "expected_position" - optional CodeRange for the node
    # NOTE(review): presumably @data_provider runs the test once per entry and
    # a None parser skips the parse round-trip -- confirm against the helpers.
    @data_provider(
        (
            # Simple try/except block
            {
                "node": cst.Try(
                    cst.SimpleStatementSuite((cst.Pass(),)),
                    handlers=(
                        cst.ExceptHandler(
                            cst.SimpleStatementSuite((cst.Pass(),)),
                            whitespace_after_except=cst.SimpleWhitespace(""),
                        ),
                    ),
                ),
                "code": "try: pass\nexcept: pass\n",
                "parser": parse_statement,
                "expected_position": CodeRange((1, 0), (2, 12)),
            },
            # Try/except with a class
            {
                "node": cst.Try(
                    cst.SimpleStatementSuite((cst.Pass(),)),
                    handlers=(
                        cst.ExceptHandler(
                            cst.SimpleStatementSuite((cst.Pass(),)),
                            type=cst.Name("Exception"),
                        ),
                    ),
                ),
                "code": "try: pass\nexcept Exception: pass\n",
                "parser": parse_statement,
            },
            # Try/except with a named class
            {
                "node": cst.Try(
                    cst.SimpleStatementSuite((cst.Pass(),)),
                    handlers=(
                        cst.ExceptHandler(
                            cst.SimpleStatementSuite((cst.Pass(),)),
                            type=cst.Name("Exception"),
                            name=cst.AsName(cst.Name("exc")),
                        ),
                    ),
                ),
                "code": "try: pass\nexcept Exception as exc: pass\n",
                "parser": parse_statement,
                "expected_position": CodeRange((1, 0), (2, 29)),
            },
            # Try/except with multiple clauses
            {
                "node": cst.Try(
                    cst.SimpleStatementSuite((cst.Pass(),)),
                    handlers=(
                        cst.ExceptHandler(
                            cst.SimpleStatementSuite((cst.Pass(),)),
                            type=cst.Name("TypeError"),
                            name=cst.AsName(cst.Name("e")),
                        ),
                        cst.ExceptHandler(
                            cst.SimpleStatementSuite((cst.Pass(),)),
                            type=cst.Name("KeyError"),
                            name=cst.AsName(cst.Name("e")),
                        ),
                        cst.ExceptHandler(
                            cst.SimpleStatementSuite((cst.Pass(),)),
                            whitespace_after_except=cst.SimpleWhitespace(""),
                        ),
                    ),
                ),
                "code": "try: pass\n"
                + "except TypeError as e: pass\n"
                + "except KeyError as e: pass\n"
                + "except: pass\n",
                "parser": parse_statement,
                "expected_position": CodeRange((1, 0), (4, 12)),
            },
            # Simple try/finally block
            {
                "node": cst.Try(
                    cst.SimpleStatementSuite((cst.Pass(),)),
                    finalbody=cst.Finally(cst.SimpleStatementSuite((cst.Pass(),))),
                ),
                "code": "try: pass\nfinally: pass\n",
                "parser": parse_statement,
                "expected_position": CodeRange((1, 0), (2, 13)),
            },
            # Simple try/except/finally block
            {
                "node": cst.Try(
                    cst.SimpleStatementSuite((cst.Pass(),)),
                    handlers=(
                        cst.ExceptHandler(
                            cst.SimpleStatementSuite((cst.Pass(),)),
                            whitespace_after_except=cst.SimpleWhitespace(""),
                        ),
                    ),
                    finalbody=cst.Finally(cst.SimpleStatementSuite((cst.Pass(),))),
                ),
                "code": "try: pass\nexcept: pass\nfinally: pass\n",
                "parser": parse_statement,
                "expected_position": CodeRange((1, 0), (3, 13)),
            },
            # Simple try/except/else block
            {
                "node": cst.Try(
                    cst.SimpleStatementSuite((cst.Pass(),)),
                    handlers=(
                        cst.ExceptHandler(
                            cst.SimpleStatementSuite((cst.Pass(),)),
                            whitespace_after_except=cst.SimpleWhitespace(""),
                        ),
                    ),
                    orelse=cst.Else(cst.SimpleStatementSuite((cst.Pass(),))),
                ),
                "code": "try: pass\nexcept: pass\nelse: pass\n",
                "parser": parse_statement,
                "expected_position": CodeRange((1, 0), (3, 10)),
            },
            # Simple try/except/else block/finally
            {
                "node": cst.Try(
                    cst.SimpleStatementSuite((cst.Pass(),)),
                    handlers=(
                        cst.ExceptHandler(
                            cst.SimpleStatementSuite((cst.Pass(),)),
                            whitespace_after_except=cst.SimpleWhitespace(""),
                        ),
                    ),
                    orelse=cst.Else(cst.SimpleStatementSuite((cst.Pass(),))),
                    finalbody=cst.Finally(cst.SimpleStatementSuite((cst.Pass(),))),
                ),
                "code": "try: pass\nexcept: pass\nelse: pass\nfinally: pass\n",
                "parser": parse_statement,
                "expected_position": CodeRange((1, 0), (4, 13)),
            },
            # Verify whitespace in various locations
            # (the expected position starts on line 2, after the "# 1" leading
            # comment line)
            {
                "node": cst.Try(
                    leading_lines=(cst.EmptyLine(comment=cst.Comment("# 1")),),
                    body=cst.SimpleStatementSuite((cst.Pass(),)),
                    handlers=(
                        cst.ExceptHandler(
                            leading_lines=(cst.EmptyLine(comment=cst.Comment("# 2")),),
                            type=cst.Name("TypeError"),
                            name=cst.AsName(
                                cst.Name("e"),
                                whitespace_before_as=cst.SimpleWhitespace("  "),
                                whitespace_after_as=cst.SimpleWhitespace("  "),
                            ),
                            whitespace_after_except=cst.SimpleWhitespace("  "),
                            whitespace_before_colon=cst.SimpleWhitespace(" "),
                            body=cst.SimpleStatementSuite((cst.Pass(),)),
                        ),
                    ),
                    orelse=cst.Else(
                        leading_lines=(cst.EmptyLine(comment=cst.Comment("# 3")),),
                        body=cst.SimpleStatementSuite((cst.Pass(),)),
                        whitespace_before_colon=cst.SimpleWhitespace(" "),
                    ),
                    finalbody=cst.Finally(
                        leading_lines=(cst.EmptyLine(comment=cst.Comment("# 4")),),
                        body=cst.SimpleStatementSuite((cst.Pass(),)),
                        whitespace_before_colon=cst.SimpleWhitespace(" "),
                    ),
                    whitespace_before_colon=cst.SimpleWhitespace(" "),
                ),
                "code": "# 1\ntry : pass\n# 2\nexcept  TypeError  as  e : pass\n# 3\nelse : pass\n# 4\nfinally : pass\n",
                "parser": parse_statement,
                "expected_position": CodeRange((2, 0), (8, 14)),
            },
            # Please don't write code like this
            {
                "node": cst.Try(
                    cst.SimpleStatementSuite((cst.Pass(),)),
                    handlers=(
                        cst.ExceptHandler(
                            cst.SimpleStatementSuite((cst.Pass(),)),
                            type=cst.Name("TypeError"),
                            name=cst.AsName(cst.Name("e")),
                        ),
                        cst.ExceptHandler(
                            cst.SimpleStatementSuite((cst.Pass(),)),
                            type=cst.Name("KeyError"),
                            name=cst.AsName(cst.Name("e")),
                        ),
                        cst.ExceptHandler(
                            cst.SimpleStatementSuite((cst.Pass(),)),
                            whitespace_after_except=cst.SimpleWhitespace(""),
                        ),
                    ),
                    orelse=cst.Else(cst.SimpleStatementSuite((cst.Pass(),))),
                    finalbody=cst.Finally(cst.SimpleStatementSuite((cst.Pass(),))),
                ),
                "code": "try: pass\n"
                + "except TypeError as e: pass\n"
                + "except KeyError as e: pass\n"
                + "except: pass\n"
                + "else: pass\n"
                + "finally: pass\n",
                "parser": parse_statement,
                "expected_position": CodeRange((1, 0), (6, 13)),
            },
            # Verify indentation
            {
                "node": DummyIndentedBlock(
                    "    ",
                    cst.Try(
                        cst.SimpleStatementSuite((cst.Pass(),)),
                        handlers=(
                            cst.ExceptHandler(
                                cst.SimpleStatementSuite((cst.Pass(),)),
                                type=cst.Name("TypeError"),
                                name=cst.AsName(cst.Name("e")),
                            ),
                            cst.ExceptHandler(
                                cst.SimpleStatementSuite((cst.Pass(),)),
                                type=cst.Name("KeyError"),
                                name=cst.AsName(cst.Name("e")),
                            ),
                            cst.ExceptHandler(
                                cst.SimpleStatementSuite((cst.Pass(),)),
                                whitespace_after_except=cst.SimpleWhitespace(""),
                            ),
                        ),
                        orelse=cst.Else(cst.SimpleStatementSuite((cst.Pass(),))),
                        finalbody=cst.Finally(cst.SimpleStatementSuite((cst.Pass(),))),
                    ),
                ),
                "code": "    try: pass\n"
                + "    except TypeError as e: pass\n"
                + "    except KeyError as e: pass\n"
                + "    except: pass\n"
                + "    else: pass\n"
                + "    finally: pass\n",
                "parser": None,
            },
            # Verify indentation in bodies
            {
                "node": DummyIndentedBlock(
                    "    ",
                    cst.Try(
                        cst.IndentedBlock((cst.SimpleStatementLine((cst.Pass(),)),)),
                        handlers=(
                            cst.ExceptHandler(
                                cst.IndentedBlock(
                                    (cst.SimpleStatementLine((cst.Pass(),)),)
                                ),
                                whitespace_after_except=cst.SimpleWhitespace(""),
                            ),
                        ),
                        orelse=cst.Else(
                            cst.IndentedBlock((cst.SimpleStatementLine((cst.Pass(),)),))
                        ),
                        finalbody=cst.Finally(
                            cst.IndentedBlock((cst.SimpleStatementLine((cst.Pass(),)),))
                        ),
                    ),
                ),
                "code": "    try:\n"
                + "        pass\n"
                + "    except:\n"
                + "        pass\n"
                + "    else:\n"
                + "        pass\n"
                + "    finally:\n"
                + "        pass\n",
                "parser": None,
            },
            # No space when using grouping parens
            {
                "node": cst.Try(
                    cst.SimpleStatementSuite((cst.Pass(),)),
                    handlers=(
                        cst.ExceptHandler(
                            cst.SimpleStatementSuite((cst.Pass(),)),
                            whitespace_after_except=cst.SimpleWhitespace(""),
                            type=cst.Name(
                                "Exception",
                                lpar=(cst.LeftParen(),),
                                rpar=(cst.RightParen(),),
                            ),
                        ),
                    ),
                ),
                "code": "try: pass\nexcept(Exception): pass\n",
                "parser": parse_statement,
            },
            # No space when using tuple
            {
                "node": cst.Try(
                    cst.SimpleStatementSuite((cst.Pass(),)),
                    handlers=(
                        cst.ExceptHandler(
                            cst.SimpleStatementSuite((cst.Pass(),)),
                            whitespace_after_except=cst.SimpleWhitespace(""),
                            type=cst.Tuple(
                                [
                                    cst.Element(
                                        cst.Name("IOError"),
                                        comma=cst.Comma(
                                            whitespace_after=cst.SimpleWhitespace(" ")
                                        ),
                                    ),
                                    cst.Element(cst.Name("ImportError")),
                                ]
                            ),
                        ),
                    ),
                ),
                "code": "try: pass\nexcept(IOError, ImportError): pass\n",
                "parser": parse_statement,
            },
        )
    )
    def test_valid(self, **kwargs: Any) -> None:
        self.validate_node(**kwargs)

    # Each dict below is forwarded as keyword arguments to
    # ``self.assert_invalid``: "get_node" is a zero-argument callable that
    # builds the node, "expected_re" is the message it is expected to be
    # rejected with.
    # NOTE(review): presumably "expected_re" is matched as a regular
    # expression against the raised error -- confirm against assert_invalid.
    @data_provider(
        (
            # AsName wrapping an empty identifier
            {
                "get_node": lambda: cst.AsName(cst.Name("")),
                "expected_re": "empty name identifier",
            },
            # AsName with no whitespace after the 'as' keyword
            {
                "get_node": lambda: cst.AsName(
                    cst.Name("bla"), whitespace_after_as=cst.SimpleWhitespace("")
                ),
                "expected_re": "between 'as'",
            },
            # AsName with no whitespace before the 'as' keyword
            {
                "get_node": lambda: cst.AsName(
                    cst.Name("bla"), whitespace_before_as=cst.SimpleWhitespace("")
                ),
                "expected_re": "before 'as'",
            },
            # ExceptHandler with a name but no exception type
            {
                "get_node": lambda: cst.ExceptHandler(
                    cst.SimpleStatementSuite((cst.Pass(),)),
                    name=cst.AsName(cst.Name("bla")),
                ),
                "expected_re": "name for an empty type",
            },
            # ExceptHandler with a plain-name type and no space after 'except'
            {
                "get_node": lambda: cst.ExceptHandler(
                    cst.SimpleStatementSuite((cst.Pass(),)),
                    type=cst.Name("TypeError"),
                    whitespace_after_except=cst.SimpleWhitespace(""),
                ),
                "expected_re": "at least one space after except",
            },
            # Try with neither handlers nor a finally clause
            {
                "get_node": lambda: cst.Try(cst.SimpleStatementSuite((cst.Pass(),))),
                "expected_re": "at least one ExceptHandler or Finally",
            },
            # Try with an else clause but no handlers
            {
                "get_node": lambda: cst.Try(
                    cst.SimpleStatementSuite((cst.Pass(),)),
                    orelse=cst.Else(cst.SimpleStatementSuite((cst.Pass(),))),
                    finalbody=cst.Finally(cst.SimpleStatementSuite((cst.Pass(),))),
                ),
                "expected_re": "at least one ExceptHandler in order to have an Else",
            },
        )
    )
    def test_invalid(self, **kwargs: Any) -> None:
        self.assert_invalid(**kwargs)
Esempio n. 6
0
class WithTest(CSTNodeTest):
    """Data-driven tests for ``cst.With`` and ``cst.WithItem``.

    ``test_valid`` pairs hand-built nodes with the source text they
    correspond to, ``test_invalid`` lists node constructions that are
    expected to be rejected, and ``test_versions`` checks which
    ``python_version`` settings accept a multi-item ``with`` statement.
    """

    # Each dict below is forwarded as keyword arguments to
    # ``self.validate_node``.  Keys used here:
    #   "node"              - the hand-built CST node
    #   "code"              - the source text that node corresponds to
    #   "parser"            - parse function for "code", or None
    #   "expected_position" - optional CodeRange for the node
    # NOTE(review): presumably @data_provider runs the test once per entry and
    # a None parser skips the parse round-trip -- confirm against the helpers.
    @data_provider((
        # Simple with block
        {
            "node":
            cst.With(
                (cst.WithItem(cst.Call(cst.Name("context_mgr"))), ),
                cst.SimpleStatementSuite((cst.Pass(), )),
            ),
            "code":
            "with context_mgr(): pass\n",
            "parser":
            parse_statement,
            "expected_position":
            CodeRange((1, 0), (1, 24)),
        },
        # Simple async with block
        {
            "node":
            cst.With(
                (cst.WithItem(cst.Call(cst.Name("context_mgr"))), ),
                cst.SimpleStatementSuite((cst.Pass(), )),
                asynchronous=cst.Asynchronous(),
            ),
            "code":
            "async with context_mgr(): pass\n",
            "parser":
            lambda code: parse_statement(
                code, config=PartialParserConfig(python_version="3.7")),
        },
        # Python 3.6 async with block
        {
            "node":
            cst.FunctionDef(
                cst.Name("foo"),
                cst.Parameters(),
                cst.IndentedBlock((cst.With(
                    (cst.WithItem(cst.Call(cst.Name("context_mgr"))), ),
                    cst.SimpleStatementSuite((cst.Pass(), )),
                    asynchronous=cst.Asynchronous(),
                ), )),
                asynchronous=cst.Asynchronous(),
            ),
            "code":
            "async def foo():\n    async with context_mgr(): pass\n",
            "parser":
            lambda code: parse_statement(
                code, config=PartialParserConfig(python_version="3.6")),
        },
        # Multiple context managers
        # (no explicit comma on the first item, and no parser)
        {
            "node":
            cst.With(
                (
                    cst.WithItem(cst.Call(cst.Name("foo"))),
                    cst.WithItem(cst.Call(cst.Name("bar"))),
                ),
                cst.SimpleStatementSuite((cst.Pass(), )),
            ),
            "code":
            "with foo(), bar(): pass\n",
            "parser":
            None,
        },
        # Multiple context managers, with the comma after the first item
        # spelled out explicitly; this variant is also checked with the parser
        {
            "node":
            cst.With(
                (
                    cst.WithItem(
                        cst.Call(cst.Name("foo")),
                        comma=cst.Comma(
                            whitespace_after=cst.SimpleWhitespace(" ")),
                    ),
                    cst.WithItem(cst.Call(cst.Name("bar"))),
                ),
                cst.SimpleStatementSuite((cst.Pass(), )),
            ),
            "code":
            "with foo(), bar(): pass\n",
            "parser":
            parse_statement,
        },
        # With block containing variable for context manager.
        {
            "node":
            cst.With(
                (cst.WithItem(
                    cst.Call(cst.Name("context_mgr")),
                    cst.AsName(cst.Name("ctx")),
                ), ),
                cst.SimpleStatementSuite((cst.Pass(), )),
            ),
            "code":
            "with context_mgr() as ctx: pass\n",
            "parser":
            parse_statement,
        },
        # indentation
        {
            "node":
            DummyIndentedBlock(
                "    ",
                cst.With(
                    (cst.WithItem(cst.Call(cst.Name("context_mgr"))), ),
                    cst.SimpleStatementSuite((cst.Pass(), )),
                ),
            ),
            "code":
            "    with context_mgr(): pass\n",
            "parser":
            None,
            "expected_position":
            CodeRange((1, 4), (1, 28)),
        },
        # with an indented body
        {
            "node":
            DummyIndentedBlock(
                "    ",
                cst.With(
                    (cst.WithItem(cst.Call(cst.Name("context_mgr"))), ),
                    cst.IndentedBlock((cst.SimpleStatementLine(
                        (cst.Pass(), )), )),
                ),
            ),
            "code":
            "    with context_mgr():\n        pass\n",
            "parser":
            None,
            "expected_position":
            CodeRange((1, 4), (2, 12)),
        },
        # leading_lines
        # (the expected position starts on line 2, after the comment line)
        {
            "node":
            cst.With(
                (cst.WithItem(cst.Call(cst.Name("context_mgr"))), ),
                cst.SimpleStatementSuite((cst.Pass(), )),
                leading_lines=(cst.EmptyLine(
                    comment=cst.Comment("# leading comment")), ),
            ),
            "code":
            "# leading comment\nwith context_mgr(): pass\n",
            "parser":
            parse_statement,
            "expected_position":
            CodeRange((2, 0), (2, 24)),
        },
        # Weird spacing rules
        # (no space after 'with' when the item is parenthesized)
        {
            "node":
            cst.With(
                (cst.WithItem(
                    cst.Call(
                        cst.Name("context_mgr"),
                        lpar=(cst.LeftParen(), ),
                        rpar=(cst.RightParen(), ),
                    )), ),
                cst.SimpleStatementSuite((cst.Pass(), )),
                whitespace_after_with=cst.SimpleWhitespace(""),
            ),
            "code":
            "with(context_mgr()): pass\n",
            "parser":
            parse_statement,
            "expected_position":
            CodeRange((1, 0), (1, 25)),
        },
        # Whitespace
        {
            "node":
            cst.With(
                (cst.WithItem(
                    cst.Call(cst.Name("context_mgr")),
                    cst.AsName(
                        cst.Name("ctx"),
                        whitespace_before_as=cst.SimpleWhitespace("  "),
                        whitespace_after_as=cst.SimpleWhitespace("  "),
                    ),
                ), ),
                cst.SimpleStatementSuite((cst.Pass(), )),
                whitespace_after_with=cst.SimpleWhitespace("  "),
                whitespace_before_colon=cst.SimpleWhitespace("  "),
            ),
            "code":
            "with  context_mgr()  as  ctx  : pass\n",
            "parser":
            parse_statement,
            "expected_position":
            CodeRange((1, 0), (1, 36)),
        },
    ))
    def test_valid(self, **kwargs: Any) -> None:
        self.validate_node(**kwargs)

    # Each dict below is forwarded as keyword arguments to
    # ``self.assert_invalid``: "get_node" is a zero-argument callable that
    # builds the node, "expected_re" is the message it is expected to be
    # rejected with.
    # NOTE(review): presumably "expected_re" is matched as a regular
    # expression against the raised error -- confirm against assert_invalid.
    @data_provider((
        # With statement that has no items at all
        {
            "get_node":
            lambda: cst.With((),
                             cst.IndentedBlock((cst.SimpleStatementLine(
                                 (cst.Pass(), )), ))),
            "expected_re":
            "A With statement must have at least one WithItem",
        },
        # Trailing comma on the last (here: only) item
        {
            "get_node":
            lambda: cst.With(
                (cst.WithItem(
                    cst.Call(cst.Name("foo")),
                    comma=cst.Comma(whitespace_after=cst.SimpleWhitespace(" ")
                                    ),
                ), ),
                cst.IndentedBlock((cst.SimpleStatementLine((cst.Pass(), )), )),
            ),
            "expected_re":
            "The last WithItem in a With cannot have a trailing comma",
        },
        # No space after 'with' in front of an unparenthesized item
        {
            "get_node":
            lambda: cst.With(
                (cst.WithItem(cst.Call(cst.Name("context_mgr"))), ),
                cst.SimpleStatementSuite((cst.Pass(), )),
                whitespace_after_with=cst.SimpleWhitespace(""),
            ),
            "expected_re":
            "Must have at least one space after with keyword",
        },
    ))
    def test_invalid(self, **kwargs: Any) -> None:
        self.assert_invalid(**kwargs)

    # Each dict below is forwarded as keyword arguments to
    # ``self.assert_parses``.  The same two-item ``with`` statement is
    # expected to parse with python_version="3.1" and to be rejected with
    # python_version="3.0".
    @data_provider((
        {
            "code": "with a, b: pass",
            "parser": parse_statement_as(python_version="3.1"),
            "expect_success": True,
        },
        {
            "code": "with a, b: pass",
            "parser": parse_statement_as(python_version="3.0"),
            "expect_success": False,
        },
    ))
    def test_versions(self, **kwargs: Any) -> None:
        self.assert_parses(**kwargs)
Esempio n. 7
0
class StatementTest(UnitTest):
    """Tests for the absolute-module helpers and the ``ImportAlias`` name properties."""

    @data_provider(
        (
            # Each row is (current module, import statement, expected result).
            #
            # Imports that are written absolutely.
            (None, "from a.b import c", "a.b"),
            ("x.y.z", "from a.b import c", "a.b"),
            # Relative import with no current module to resolve against.
            (None, "from ..w import c", None),
            # Relative imports that climb past the module level.
            ("x", "from ...y import z", None),
            ("x.y.z", "from .....w import c", None),
            ("x.y.z", "from ... import c", None),
            # Relative imports that resolve to an absolute module.
            ("x.y.z", "from . import c", "x.y"),
            ("x.y.z", "from .. import c", "x"),
            ("x.y.z", "from .w import c", "x.y.w"),
            ("x.y.z", "from ..w import c", "x.w"),
            ("x.y.z", "from ...w import c", "w"),
        )
    )
    def test_get_absolute_module(
        self, module: Optional[str], importfrom: str, output: Optional[str],
    ) -> None:
        """Check both resolver variants against one table row.

        ``importfrom`` is parsed into an ``ImportFrom`` node. The plain helper
        must return ``output``; the ``_or_raise`` helper must return ``output``
        too, or raise when ``output`` is ``None``.
        """
        statement = ensure_type(
            cst.parse_statement(importfrom), cst.SimpleStatementLine
        )
        assert len(statement.body) == 1, "Unexpected number of statements!"
        import_from = ensure_type(statement.body[0], cst.ImportFrom)

        self.assertEqual(get_absolute_module_for_import(module, import_from), output)

        if output is not None:
            self.assertEqual(
                get_absolute_module_for_import_or_raise(module, import_from), output
            )
            return

        # Nothing to resolve to: the raising variant has to raise.
        with self.assertRaises(Exception):
            get_absolute_module_for_import_or_raise(module, import_from)

    @data_provider(
        (
            # Each row is (alias node, expected evaluated_name, expected evaluated_alias).
            #
            # No "as" clause.
            (cst.ImportAlias(name=cst.Name("foo")), "foo", None),
            (
                cst.ImportAlias(
                    name=cst.Attribute(value=cst.Name("foo"), attr=cst.Name("bar"))
                ),
                "foo.bar",
                None,
            ),
            # With an "as" clause.
            (
                cst.ImportAlias(
                    name=cst.Name("foo"),
                    asname=cst.AsName(name=cst.Name("baz")),
                ),
                "foo",
                "baz",
            ),
            (
                cst.ImportAlias(
                    name=cst.Attribute(value=cst.Name("foo"), attr=cst.Name("bar")),
                    asname=cst.AsName(name=cst.Name("baz")),
                ),
                "foo.bar",
                "baz",
            ),
        )
    )
    def test_importalias_helpers(
        self, alias_node: cst.ImportAlias, full_name: str, alias: Optional[str]
    ) -> None:
        """Compare ``evaluated_name`` and ``evaluated_alias`` with the expected values."""
        resolved_name = alias_node.evaluated_name
        self.assertEqual(resolved_name, full_name)
        resolved_alias = alias_node.evaluated_alias
        self.assertEqual(resolved_alias, alias)
Esempio n. 8
0
class TryStarTest(CSTNodeTest):
    """``TryStar`` nodes paired with their ``except*`` source text."""

    @data_provider(
        (
            # One handler naming an exception class.
            {
                "node": cst.TryStar(
                    body=cst.SimpleStatementSuite(body=(cst.Pass(),)),
                    handlers=(
                        cst.ExceptStarHandler(
                            body=cst.SimpleStatementSuite(body=(cst.Pass(),)),
                            type=cst.Name("Exception"),
                        ),
                    ),
                ),
                "code": "try: pass\nexcept* Exception: pass\n",
                "parser": native_parse_statement,
            },
            # One handler binding the exception with "as".
            {
                "node": cst.TryStar(
                    body=cst.SimpleStatementSuite(body=(cst.Pass(),)),
                    handlers=(
                        cst.ExceptStarHandler(
                            body=cst.SimpleStatementSuite(body=(cst.Pass(),)),
                            type=cst.Name("Exception"),
                            name=cst.AsName(name=cst.Name("exc")),
                        ),
                    ),
                ),
                "code": "try: pass\nexcept* Exception as exc: pass\n",
                "parser": native_parse_statement,
                "expected_position": CodeRange((1, 0), (2, 30)),
            },
            # Two handlers.
            {
                "node": cst.TryStar(
                    body=cst.SimpleStatementSuite(body=(cst.Pass(),)),
                    handlers=(
                        cst.ExceptStarHandler(
                            body=cst.SimpleStatementSuite(body=(cst.Pass(),)),
                            type=cst.Name("TypeError"),
                            name=cst.AsName(name=cst.Name("e")),
                        ),
                        cst.ExceptStarHandler(
                            body=cst.SimpleStatementSuite(body=(cst.Pass(),)),
                            type=cst.Name("KeyError"),
                            name=cst.AsName(name=cst.Name("e")),
                        ),
                    ),
                ),
                "code": (
                    "try: pass\n"
                    "except* TypeError as e: pass\n"
                    "except* KeyError as e: pass\n"
                ),
                "parser": native_parse_statement,
                "expected_position": CodeRange((1, 0), (3, 27)),
            },
            # Handler followed by a finally clause.
            {
                "node": cst.TryStar(
                    body=cst.SimpleStatementSuite(body=(cst.Pass(),)),
                    handlers=(
                        cst.ExceptStarHandler(
                            body=cst.SimpleStatementSuite(body=(cst.Pass(),)),
                            type=cst.Name("KeyError"),
                            whitespace_after_except=cst.SimpleWhitespace(""),
                        ),
                    ),
                    finalbody=cst.Finally(
                        body=cst.SimpleStatementSuite(body=(cst.Pass(),))
                    ),
                ),
                "code": "try: pass\nexcept* KeyError: pass\nfinally: pass\n",
                "parser": native_parse_statement,
                "expected_position": CodeRange((1, 0), (3, 13)),
            },
            # Handler followed by an else clause.
            {
                "node": cst.TryStar(
                    body=cst.SimpleStatementSuite(body=(cst.Pass(),)),
                    handlers=(
                        cst.ExceptStarHandler(
                            body=cst.SimpleStatementSuite(body=(cst.Pass(),)),
                            type=cst.Name("KeyError"),
                            whitespace_after_except=cst.SimpleWhitespace(""),
                        ),
                    ),
                    orelse=cst.Else(
                        body=cst.SimpleStatementSuite(body=(cst.Pass(),))
                    ),
                ),
                "code": "try: pass\nexcept* KeyError: pass\nelse: pass\n",
                "parser": native_parse_statement,
                "expected_position": CodeRange((1, 0), (3, 10)),
            },
            # Non-default whitespace and a leading comment on every clause.
            {
                "node": cst.TryStar(
                    leading_lines=(cst.EmptyLine(comment=cst.Comment("# 1")),),
                    body=cst.SimpleStatementSuite(body=(cst.Pass(),)),
                    whitespace_before_colon=cst.SimpleWhitespace(" "),
                    handlers=(
                        cst.ExceptStarHandler(
                            leading_lines=(
                                cst.EmptyLine(comment=cst.Comment("# 2")),
                            ),
                            body=cst.SimpleStatementSuite(body=(cst.Pass(),)),
                            type=cst.Name("TypeError"),
                            name=cst.AsName(
                                name=cst.Name("e"),
                                whitespace_before_as=cst.SimpleWhitespace("  "),
                                whitespace_after_as=cst.SimpleWhitespace("  "),
                            ),
                            whitespace_after_except=cst.SimpleWhitespace("  "),
                            whitespace_after_star=cst.SimpleWhitespace(""),
                            whitespace_before_colon=cst.SimpleWhitespace(" "),
                        ),
                    ),
                    orelse=cst.Else(
                        leading_lines=(cst.EmptyLine(comment=cst.Comment("# 3")),),
                        body=cst.SimpleStatementSuite(body=(cst.Pass(),)),
                        whitespace_before_colon=cst.SimpleWhitespace(" "),
                    ),
                    finalbody=cst.Finally(
                        leading_lines=(cst.EmptyLine(comment=cst.Comment("# 4")),),
                        body=cst.SimpleStatementSuite(body=(cst.Pass(),)),
                        whitespace_before_colon=cst.SimpleWhitespace(" "),
                    ),
                ),
                "code": (
                    "# 1\n"
                    "try : pass\n"
                    "# 2\n"
                    "except  *TypeError  as  e : pass\n"
                    "# 3\n"
                    "else : pass\n"
                    "# 4\n"
                    "finally : pass\n"
                ),
                "parser": native_parse_statement,
                "expected_position": CodeRange((2, 0), (8, 14)),
            },
            # Two handlers plus else and finally.
            {
                "node": cst.TryStar(
                    body=cst.SimpleStatementSuite(body=(cst.Pass(),)),
                    handlers=(
                        cst.ExceptStarHandler(
                            body=cst.SimpleStatementSuite(body=(cst.Pass(),)),
                            type=cst.Name("TypeError"),
                            name=cst.AsName(name=cst.Name("e")),
                        ),
                        cst.ExceptStarHandler(
                            body=cst.SimpleStatementSuite(body=(cst.Pass(),)),
                            type=cst.Name("KeyError"),
                            name=cst.AsName(name=cst.Name("e")),
                        ),
                    ),
                    orelse=cst.Else(
                        body=cst.SimpleStatementSuite(body=(cst.Pass(),))
                    ),
                    finalbody=cst.Finally(
                        body=cst.SimpleStatementSuite(body=(cst.Pass(),))
                    ),
                ),
                "code": (
                    "try: pass\n"
                    "except* TypeError as e: pass\n"
                    "except* KeyError as e: pass\n"
                    "else: pass\n"
                    "finally: pass\n"
                ),
                "parser": native_parse_statement,
                "expected_position": CodeRange((1, 0), (5, 13)),
            },
        )
    )
    def test_valid(self, **kwargs: Any) -> None:
        """Hand each node / code / parser (and optional position) case to ``validate_node``."""
        self.validate_node(**kwargs)
Esempio n. 9
0
class ImportFromParseTest(CSTNodeTest):
    """``ImportFrom`` nodes paired with source text, validated with a parser attached."""

    @data_provider(
        (
            # Single name.
            {
                "node": cst.ImportFrom(
                    module=cst.Name("foo"),
                    names=(cst.ImportAlias(name=cst.Name("bar")),),
                ),
                "code": "from foo import bar",
            },
            # Single name with an alias.
            {
                "node": cst.ImportFrom(
                    module=cst.Name("foo"),
                    names=(
                        cst.ImportAlias(
                            name=cst.Name("bar"),
                            asname=cst.AsName(name=cst.Name("baz")),
                        ),
                    ),
                ),
                "code": "from foo import bar as baz",
            },
            # Two names.
            {
                "node": cst.ImportFrom(
                    module=cst.Name("foo"),
                    names=(
                        cst.ImportAlias(
                            name=cst.Name("bar"),
                            comma=cst.Comma(
                                whitespace_after=cst.SimpleWhitespace(" ")
                            ),
                        ),
                        cst.ImportAlias(name=cst.Name("baz")),
                    ),
                ),
                "code": "from foo import bar, baz",
            },
            # Two names, the last one followed by a comma.
            {
                "node": cst.ImportFrom(
                    module=cst.Name("foo"),
                    names=(
                        cst.ImportAlias(
                            name=cst.Name("bar"),
                            comma=cst.Comma(
                                whitespace_after=cst.SimpleWhitespace(" ")
                            ),
                        ),
                        cst.ImportAlias(name=cst.Name("baz"), comma=cst.Comma()),
                    ),
                ),
                "code": "from foo import bar, baz,",
            },
            # Star import.
            {
                "node": cst.ImportFrom(
                    module=cst.Name("foo"), names=cst.ImportStar()
                ),
                "code": "from foo import *",
            },
            # Relative import, one dot.
            {
                "node": cst.ImportFrom(
                    relative=(cst.Dot(),),
                    module=cst.Name("foo"),
                    names=(cst.ImportAlias(name=cst.Name("bar")),),
                ),
                "code": "from .foo import bar",
            },
            # Relative import, two dots.
            {
                "node": cst.ImportFrom(
                    relative=(cst.Dot(), cst.Dot()),
                    module=cst.Name("foo"),
                    names=(cst.ImportAlias(name=cst.Name("bar")),),
                ),
                "code": "from ..foo import bar",
            },
            # Relative import with no module name.
            {
                "node": cst.ImportFrom(
                    relative=(cst.Dot(), cst.Dot()),
                    module=None,
                    names=(cst.ImportAlias(name=cst.Name("bar")),),
                ),
                "code": "from .. import bar",
            },
            # Parenthesized name list.
            {
                "node": cst.ImportFrom(
                    module=cst.Name("foo"),
                    lpar=cst.LeftParen(),
                    names=(
                        cst.ImportAlias(
                            name=cst.Name("bar"),
                            asname=cst.AsName(name=cst.Name("baz")),
                        ),
                    ),
                    rpar=cst.RightParen(),
                ),
                "code": "from foo import (bar as baz)",
            },
            # Non-default whitespace in every slot.
            {
                "node": cst.ImportFrom(
                    whitespace_after_from=cst.SimpleWhitespace("   "),
                    relative=(
                        cst.Dot(
                            whitespace_before=cst.SimpleWhitespace(""),
                            whitespace_after=cst.SimpleWhitespace("  "),
                        ),
                        cst.Dot(
                            whitespace_before=cst.SimpleWhitespace(""),
                            whitespace_after=cst.SimpleWhitespace(" "),
                        ),
                    ),
                    module=cst.Name("foo"),
                    whitespace_before_import=cst.SimpleWhitespace("  "),
                    whitespace_after_import=cst.SimpleWhitespace("  "),
                    lpar=cst.LeftParen(whitespace_after=cst.SimpleWhitespace(" ")),
                    names=(
                        cst.ImportAlias(
                            name=cst.Name("bar"),
                            asname=cst.AsName(
                                name=cst.Name("baz"),
                                whitespace_before_as=cst.SimpleWhitespace("  "),
                                whitespace_after_as=cst.SimpleWhitespace("  "),
                            ),
                            comma=cst.Comma(
                                whitespace_before=cst.SimpleWhitespace(" "),
                                whitespace_after=cst.SimpleWhitespace("  "),
                            ),
                        ),
                        cst.ImportAlias(
                            name=cst.Name("unittest"),
                            asname=cst.AsName(
                                name=cst.Name("ut"),
                                whitespace_before_as=cst.SimpleWhitespace("  "),
                                whitespace_after_as=cst.SimpleWhitespace("  "),
                            ),
                        ),
                    ),
                    rpar=cst.RightParen(whitespace_before=cst.SimpleWhitespace(" ")),
                ),
                "code": "from   .  . foo  import  ( bar  as  baz ,  unittest  as  ut )",
            },
        )
    )
    def test_valid(self, **kwargs: Any) -> None:
        """Run ``validate_node`` on each case with a ``parse_statement``-based parser."""

        def parse_first_small_statement(code):
            # The cases hold the bare import node, so take it out of the
            # SimpleStatementLine that parse_statement returns.
            line = ensure_type(parse_statement(code), cst.SimpleStatementLine)
            return line.body[0]

        self.validate_node(parser=parse_first_small_statement, **kwargs)
Esempio n. 10
0
class ImportFromCreateTest(CSTNodeTest):
    """Hand-built ``ImportFrom`` nodes: valid ones with their code, invalid ones with their error."""

    @data_provider(
        (
            # Single name.
            {
                "node": cst.ImportFrom(
                    module=cst.Name("foo"),
                    names=(cst.ImportAlias(name=cst.Name("bar")),),
                ),
                "code": "from foo import bar",
            },
            # Single name with an alias.
            {
                "node": cst.ImportFrom(
                    module=cst.Name("foo"),
                    names=(
                        cst.ImportAlias(
                            name=cst.Name("bar"),
                            asname=cst.AsName(name=cst.Name("baz")),
                        ),
                    ),
                ),
                "code": "from foo import bar as baz",
            },
            # Two names, no explicit commas on the nodes.
            {
                "node": cst.ImportFrom(
                    module=cst.Name("foo"),
                    names=(
                        cst.ImportAlias(name=cst.Name("bar")),
                        cst.ImportAlias(name=cst.Name("baz")),
                    ),
                ),
                "code": "from foo import bar, baz",
            },
            # Two names, each with an explicit default Comma.
            {
                "node": cst.ImportFrom(
                    module=cst.Name("foo"),
                    names=(
                        cst.ImportAlias(name=cst.Name("bar"), comma=cst.Comma()),
                        cst.ImportAlias(name=cst.Name("baz"), comma=cst.Comma()),
                    ),
                ),
                "code": "from foo import bar,baz,",
                "expected_position": CodeRange((1, 0), (1, 23)),
            },
            # Star import.
            {
                "node": cst.ImportFrom(
                    module=cst.Name("foo"), names=cst.ImportStar()
                ),
                "code": "from foo import *",
                "expected_position": CodeRange((1, 0), (1, 17)),
            },
            # Relative import, one dot.
            {
                "node": cst.ImportFrom(
                    relative=(cst.Dot(),),
                    module=cst.Name("foo"),
                    names=(cst.ImportAlias(name=cst.Name("bar")),),
                ),
                "code": "from .foo import bar",
            },
            # Relative import, two dots.
            {
                "node": cst.ImportFrom(
                    relative=(cst.Dot(), cst.Dot()),
                    module=cst.Name("foo"),
                    names=(cst.ImportAlias(name=cst.Name("bar")),),
                ),
                "code": "from ..foo import bar",
            },
            # Relative import with no module name.
            {
                "node": cst.ImportFrom(
                    relative=(cst.Dot(), cst.Dot()),
                    module=None,
                    names=(cst.ImportAlias(name=cst.Name("bar")),),
                ),
                "code": "from .. import bar",
            },
            # Parenthesized name list.
            {
                "node": cst.ImportFrom(
                    module=cst.Name("foo"),
                    lpar=cst.LeftParen(),
                    names=(
                        cst.ImportAlias(
                            name=cst.Name("bar"),
                            asname=cst.AsName(name=cst.Name("baz")),
                        ),
                    ),
                    rpar=cst.RightParen(),
                ),
                "code": "from foo import (bar as baz)",
                "expected_position": CodeRange((1, 0), (1, 28)),
            },
            # Non-default whitespace in every slot.
            {
                "node": cst.ImportFrom(
                    whitespace_after_from=cst.SimpleWhitespace("  "),
                    relative=(
                        cst.Dot(
                            whitespace_before=cst.SimpleWhitespace(" "),
                            whitespace_after=cst.SimpleWhitespace(" "),
                        ),
                        cst.Dot(
                            whitespace_before=cst.SimpleWhitespace(" "),
                            whitespace_after=cst.SimpleWhitespace(" "),
                        ),
                    ),
                    module=cst.Name("foo"),
                    whitespace_before_import=cst.SimpleWhitespace("  "),
                    whitespace_after_import=cst.SimpleWhitespace("  "),
                    lpar=cst.LeftParen(whitespace_after=cst.SimpleWhitespace(" ")),
                    names=(
                        cst.ImportAlias(
                            name=cst.Name("bar"),
                            asname=cst.AsName(
                                name=cst.Name("baz"),
                                whitespace_before_as=cst.SimpleWhitespace("  "),
                                whitespace_after_as=cst.SimpleWhitespace("  "),
                            ),
                            comma=cst.Comma(
                                whitespace_before=cst.SimpleWhitespace(" "),
                                whitespace_after=cst.SimpleWhitespace("  "),
                            ),
                        ),
                        cst.ImportAlias(
                            name=cst.Name("unittest"),
                            asname=cst.AsName(
                                name=cst.Name("ut"),
                                whitespace_before_as=cst.SimpleWhitespace("  "),
                                whitespace_after_as=cst.SimpleWhitespace("  "),
                            ),
                        ),
                    ),
                    rpar=cst.RightParen(whitespace_before=cst.SimpleWhitespace(" ")),
                ),
                "code": "from   .  . foo  import  ( bar  as  baz ,  unittest  as  ut )",
                "expected_position": CodeRange((1, 0), (1, 61)),
            },
        )
    )
    def test_valid(self, **kwargs: Any) -> None:
        """Hand each node / code (and optional position) case to ``validate_node``."""
        self.validate_node(**kwargs)

    @data_provider(
        (
            # Neither a module nor relative dots.
            {
                "get_node": lambda: cst.ImportFrom(
                    module=None,
                    names=(cst.ImportAlias(name=cst.Name("bar")),),
                ),
                "expected_re": "Must have a module specified",
            },
            # Empty name list.
            {
                "get_node": lambda: cst.ImportFrom(
                    module=cst.Name("foo"), names=()
                ),
                "expected_re": "at least one ImportAlias",
            },
            # Opening parenthesis only.
            {
                "get_node": lambda: cst.ImportFrom(
                    module=cst.Name("foo"),
                    names=(cst.ImportAlias(name=cst.Name("bar")),),
                    lpar=cst.LeftParen(),
                ),
                "expected_re": "left paren without right paren",
            },
            # Closing parenthesis only.
            {
                "get_node": lambda: cst.ImportFrom(
                    module=cst.Name("foo"),
                    names=(cst.ImportAlias(name=cst.Name("bar")),),
                    rpar=cst.RightParen(),
                ),
                "expected_re": "right paren without left paren",
            },
            # Star import with an opening parenthesis.
            {
                "get_node": lambda: cst.ImportFrom(
                    module=cst.Name("foo"),
                    names=cst.ImportStar(),
                    lpar=cst.LeftParen(),
                ),
                "expected_re": "cannot have parens",
            },
            # Star import with a closing parenthesis.
            {
                "get_node": lambda: cst.ImportFrom(
                    module=cst.Name("foo"),
                    names=cst.ImportStar(),
                    rpar=cst.RightParen(),
                ),
                "expected_re": "cannot have parens",
            },
            # Empty whitespace after "from".
            {
                "get_node": lambda: cst.ImportFrom(
                    module=cst.Name("foo"),
                    names=(cst.ImportAlias(name=cst.Name("bar")),),
                    whitespace_after_from=cst.SimpleWhitespace(""),
                ),
                "expected_re": "one space after from",
            },
            # Empty whitespace before "import".
            {
                "get_node": lambda: cst.ImportFrom(
                    module=cst.Name("foo"),
                    names=(cst.ImportAlias(name=cst.Name("bar")),),
                    whitespace_before_import=cst.SimpleWhitespace(""),
                ),
                "expected_re": "one space before import",
            },
            # Empty whitespace after "import".
            {
                "get_node": lambda: cst.ImportFrom(
                    module=cst.Name("foo"),
                    names=(cst.ImportAlias(name=cst.Name("bar")),),
                    whitespace_after_import=cst.SimpleWhitespace(""),
                ),
                "expected_re": "one space after import",
            },
        )
    )
    def test_invalid(self, **kwargs: Any) -> None:
        """Hand each ``get_node`` / ``expected_re`` case to ``assert_invalid``."""
        self.assert_invalid(**kwargs)
Esempio n. 11
0
class ImportParseTest(CSTNodeTest):
    """``Import`` nodes paired with source text, validated with a parser attached."""

    @data_provider(
        (
            # Plain module name.
            {
                "node": cst.Import(names=(cst.ImportAlias(name=cst.Name("foo")),)),
                "code": "import foo",
            },
            # Dotted module name. (This case used to be listed twice, verbatim;
            # the second copy exercised nothing new and has been dropped.)
            {
                "node": cst.Import(
                    names=(
                        cst.ImportAlias(
                            name=cst.Attribute(
                                value=cst.Name("foo"), attr=cst.Name("bar")
                            )
                        ),
                    )
                ),
                "code": "import foo.bar",
            },
            # Comma-separated list of imports
            {
                "node": cst.Import(
                    names=(
                        cst.ImportAlias(
                            name=cst.Attribute(
                                value=cst.Name("foo"), attr=cst.Name("bar")
                            ),
                            comma=cst.Comma(
                                whitespace_after=cst.SimpleWhitespace(" ")
                            ),
                        ),
                        cst.ImportAlias(
                            name=cst.Attribute(
                                value=cst.Name("foo"), attr=cst.Name("baz")
                            )
                        ),
                    )
                ),
                "code": "import foo.bar, foo.baz",
            },
            # Import with an alias
            {
                "node": cst.Import(
                    names=(
                        cst.ImportAlias(
                            name=cst.Attribute(
                                value=cst.Name("foo"), attr=cst.Name("bar")
                            ),
                            asname=cst.AsName(name=cst.Name("baz")),
                        ),
                    )
                ),
                "code": "import foo.bar as baz",
            },
            # Import with an alias, comma separated
            {
                "node": cst.Import(
                    names=(
                        cst.ImportAlias(
                            name=cst.Attribute(
                                value=cst.Name("foo"), attr=cst.Name("bar")
                            ),
                            asname=cst.AsName(name=cst.Name("baz")),
                            comma=cst.Comma(
                                whitespace_after=cst.SimpleWhitespace(" ")
                            ),
                        ),
                        cst.ImportAlias(
                            name=cst.Attribute(
                                value=cst.Name("foo"), attr=cst.Name("baz")
                            ),
                            asname=cst.AsName(name=cst.Name("bar")),
                        ),
                    )
                ),
                "code": "import foo.bar as baz, foo.baz as bar",
            },
            # Mix of aliased and unaliased, dotted and plain names.
            {
                "node": cst.Import(
                    names=(
                        cst.ImportAlias(
                            name=cst.Attribute(
                                value=cst.Name("foo"), attr=cst.Name("bar")
                            ),
                            asname=cst.AsName(name=cst.Name("baz")),
                            comma=cst.Comma(
                                whitespace_after=cst.SimpleWhitespace(" ")
                            ),
                        ),
                        cst.ImportAlias(
                            name=cst.Attribute(
                                value=cst.Name("insta"), attr=cst.Name("gram")
                            ),
                            comma=cst.Comma(
                                whitespace_after=cst.SimpleWhitespace(" ")
                            ),
                        ),
                        cst.ImportAlias(
                            name=cst.Attribute(
                                value=cst.Name("foo"), attr=cst.Name("baz")
                            ),
                            comma=cst.Comma(
                                whitespace_after=cst.SimpleWhitespace(" ")
                            ),
                        ),
                        cst.ImportAlias(
                            name=cst.Name("unittest"),
                            asname=cst.AsName(name=cst.Name("ut")),
                        ),
                    )
                ),
                "code": "import foo.bar as baz, insta.gram, foo.baz, unittest as ut",
            },
            # Verify whitespace works everywhere.
            {
                "node": cst.Import(
                    whitespace_after_import=cst.SimpleWhitespace("  "),
                    names=(
                        cst.ImportAlias(
                            name=cst.Attribute(
                                value=cst.Name("foo"),
                                attr=cst.Name("bar"),
                                dot=cst.Dot(
                                    whitespace_before=cst.SimpleWhitespace(" "),
                                    whitespace_after=cst.SimpleWhitespace(" "),
                                ),
                            ),
                            asname=cst.AsName(
                                name=cst.Name("baz"),
                                whitespace_before_as=cst.SimpleWhitespace("  "),
                                whitespace_after_as=cst.SimpleWhitespace("  "),
                            ),
                            comma=cst.Comma(
                                whitespace_before=cst.SimpleWhitespace(" "),
                                whitespace_after=cst.SimpleWhitespace("  "),
                            ),
                        ),
                        cst.ImportAlias(
                            name=cst.Name("unittest"),
                            asname=cst.AsName(
                                name=cst.Name("ut"),
                                whitespace_before_as=cst.SimpleWhitespace("  "),
                                whitespace_after_as=cst.SimpleWhitespace("  "),
                            ),
                        ),
                    ),
                ),
                "code": "import  foo . bar  as  baz ,  unittest  as  ut",
            },
        )
    )
    def test_valid(self, **kwargs: Any) -> None:
        """Run ``validate_node`` on each case with a ``parse_statement``-based parser."""

        def parse_first_small_statement(code):
            # The cases hold the bare import node, so take it out of the
            # SimpleStatementLine that parse_statement returns.
            line = ensure_type(parse_statement(code), cst.SimpleStatementLine)
            return line.body[0]

        self.validate_node(parser=parse_first_small_statement, **kwargs)
Esempio n. 12
0
class ImportCreateTest(CSTNodeTest):
    """Table-driven checks for ``cst.Import`` nodes that are built by hand.

    ``test_valid`` forwards each ``node``/``code`` pair (plus an optional
    ``expected_position``) to ``validate_node``; ``test_invalid`` forwards
    each ``get_node`` factory and ``expected_re`` pattern to
    ``assert_invalid``.
    """

    @data_provider(
        (
            # Simple import statement
            {
                "node": cst.Import(names=(cst.ImportAlias(cst.Name("foo")),)),
                "code": "import foo",
            },
            # Dotted module name
            {
                "node": cst.Import(
                    names=(
                        cst.ImportAlias(
                            cst.Attribute(cst.Name("foo"), cst.Name("bar"))
                        ),
                    )
                ),
                "code": "import foo.bar",
            },
            # Comma-separated list of imports
            {
                "node": cst.Import(
                    names=(
                        cst.ImportAlias(
                            cst.Attribute(cst.Name("foo"), cst.Name("bar"))
                        ),
                        cst.ImportAlias(
                            cst.Attribute(cst.Name("foo"), cst.Name("baz"))
                        ),
                    )
                ),
                "code": "import foo.bar, foo.baz",
                "expected_position": CodeRange((1, 0), (1, 23)),
            },
            # Import with an alias
            {
                "node": cst.Import(
                    names=(
                        cst.ImportAlias(
                            cst.Attribute(cst.Name("foo"), cst.Name("bar")),
                            asname=cst.AsName(cst.Name("baz")),
                        ),
                    )
                ),
                "code": "import foo.bar as baz",
            },
            # Import with an alias, comma separated
            {
                "node": cst.Import(
                    names=(
                        cst.ImportAlias(
                            cst.Attribute(cst.Name("foo"), cst.Name("bar")),
                            asname=cst.AsName(cst.Name("baz")),
                        ),
                        cst.ImportAlias(
                            cst.Attribute(cst.Name("foo"), cst.Name("baz")),
                            asname=cst.AsName(cst.Name("bar")),
                        ),
                    )
                ),
                "code": "import foo.bar as baz, foo.baz as bar",
            },
            # Combine for fun and profit
            {
                "node": cst.Import(
                    names=(
                        cst.ImportAlias(
                            cst.Attribute(cst.Name("foo"), cst.Name("bar")),
                            asname=cst.AsName(cst.Name("baz")),
                        ),
                        cst.ImportAlias(
                            cst.Attribute(cst.Name("insta"), cst.Name("gram"))
                        ),
                        cst.ImportAlias(
                            cst.Attribute(cst.Name("foo"), cst.Name("baz"))
                        ),
                        cst.ImportAlias(
                            cst.Name("unittest"), asname=cst.AsName(cst.Name("ut"))
                        ),
                    )
                ),
                "code": "import foo.bar as baz, insta.gram, foo.baz, unittest as ut",
            },
            # Verify whitespace works everywhere.
            {
                "node": cst.Import(
                    names=(
                        cst.ImportAlias(
                            cst.Attribute(
                                cst.Name("foo"),
                                cst.Name("bar"),
                                dot=cst.Dot(
                                    whitespace_before=cst.SimpleWhitespace(" "),
                                    whitespace_after=cst.SimpleWhitespace(" "),
                                ),
                            ),
                            asname=cst.AsName(
                                cst.Name("baz"),
                                whitespace_before_as=cst.SimpleWhitespace("  "),
                                whitespace_after_as=cst.SimpleWhitespace("  "),
                            ),
                            comma=cst.Comma(
                                whitespace_before=cst.SimpleWhitespace(" "),
                                whitespace_after=cst.SimpleWhitespace("  "),
                            ),
                        ),
                        cst.ImportAlias(
                            cst.Name("unittest"),
                            asname=cst.AsName(
                                cst.Name("ut"),
                                whitespace_before_as=cst.SimpleWhitespace("  "),
                                whitespace_after_as=cst.SimpleWhitespace("  "),
                            ),
                        ),
                    ),
                    whitespace_after_import=cst.SimpleWhitespace("  "),
                ),
                "code": "import  foo . bar  as  baz ,  unittest  as  ut",
                "expected_position": CodeRange((1, 0), (1, 46)),
            },
        )
    )
    def test_valid(self, **kwargs: Any) -> None:
        """Run one valid-construction case through ``validate_node``."""
        self.validate_node(**kwargs)

    @data_provider(
        (
            # No names at all
            {
                "get_node": lambda: cst.Import(names=()),
                "expected_re": "at least one ImportAlias",
            },
            # Empty identifier used as the whole module name
            {
                "get_node": lambda: cst.Import(names=(cst.ImportAlias(cst.Name("")),)),
                "expected_re": "empty name identifier",
            },
            # Empty identifier on the left side of a dotted name
            {
                "get_node": lambda: cst.Import(
                    names=(
                        cst.ImportAlias(cst.Attribute(cst.Name(""), cst.Name("bla"))),
                    )
                ),
                "expected_re": "empty name identifier",
            },
            # Empty identifier on the right side of a dotted name
            {
                "get_node": lambda: cst.Import(
                    names=(
                        cst.ImportAlias(cst.Attribute(cst.Name("bla"), cst.Name(""))),
                    )
                ),
                "expected_re": "empty name identifier",
            },
            # Comma attached to the only alias
            {
                "get_node": lambda: cst.Import(
                    names=(
                        cst.ImportAlias(
                            cst.Attribute(cst.Name("foo"), cst.Name("bar")),
                            comma=cst.Comma(),
                        ),
                    )
                ),
                "expected_re": "trailing comma",
            },
            # Empty whitespace after the import keyword
            {
                "get_node": lambda: cst.Import(
                    names=(
                        cst.ImportAlias(
                            cst.Attribute(cst.Name("foo"), cst.Name("bar"))
                        ),
                    ),
                    whitespace_after_import=cst.SimpleWhitespace(""),
                ),
                "expected_re": "at least one space",
            },
        )
    )
    def test_invalid(self, **kwargs: Any) -> None:
        """Run one invalid-construction case through ``assert_invalid``."""
        self.assert_invalid(**kwargs)
Esempio n. 13
0
class WithTest(CSTNodeTest):
    maxDiff: int = 2000

    # Each case below is a dict of keyword arguments for ``validate_node``:
    # the ``node`` under test, the ``code`` it is paired with, a ``parser``
    # (a callable, or None when no parser is supplied for the case) and,
    # for some cases, an ``expected_position`` CodeRange.
    @data_provider((
        # Simple with block
        {
            "node":
            cst.With(
                (cst.WithItem(cst.Call(cst.Name("context_mgr"))), ),
                cst.SimpleStatementSuite((cst.Pass(), )),
            ),
            "code":
            "with context_mgr(): pass\n",
            "parser":
            parse_statement,
            "expected_position":
            CodeRange((1, 0), (1, 24)),
        },
        # Simple async with block
        {
            "node":
            cst.With(
                (cst.WithItem(cst.Call(cst.Name("context_mgr"))), ),
                cst.SimpleStatementSuite((cst.Pass(), )),
                asynchronous=cst.Asynchronous(),
            ),
            "code":
            "async with context_mgr(): pass\n",
            "parser":
            lambda code: parse_statement(
                code, config=PartialParserConfig(python_version="3.7")),
        },
        # Python 3.6 async with block (here nested inside an async function
        # definition)
        {
            "node":
            cst.FunctionDef(
                cst.Name("foo"),
                cst.Parameters(),
                cst.IndentedBlock((cst.With(
                    (cst.WithItem(cst.Call(cst.Name("context_mgr"))), ),
                    cst.SimpleStatementSuite((cst.Pass(), )),
                    asynchronous=cst.Asynchronous(),
                ), )),
                asynchronous=cst.Asynchronous(),
            ),
            "code":
            "async def foo():\n    async with context_mgr(): pass\n",
            "parser":
            lambda code: parse_statement(
                code, config=PartialParserConfig(python_version="3.6")),
        },
        # Multiple context managers
        {
            "node":
            cst.With(
                (
                    cst.WithItem(cst.Call(cst.Name("foo"))),
                    cst.WithItem(cst.Call(cst.Name("bar"))),
                ),
                cst.SimpleStatementSuite((cst.Pass(), )),
            ),
            "code":
            "with foo(), bar(): pass\n",
            "parser":
            None,
        },
        # Multiple context managers again, this time with the first item's
        # comma spelled out explicitly and with parse_statement as parser.
        {
            "node":
            cst.With(
                (
                    cst.WithItem(
                        cst.Call(cst.Name("foo")),
                        comma=cst.Comma(
                            whitespace_after=cst.SimpleWhitespace(" ")),
                    ),
                    cst.WithItem(cst.Call(cst.Name("bar"))),
                ),
                cst.SimpleStatementSuite((cst.Pass(), )),
            ),
            "code":
            "with foo(), bar(): pass\n",
            "parser":
            parse_statement,
        },
        # With block containing variable for context manager.
        {
            "node":
            cst.With(
                (cst.WithItem(
                    cst.Call(cst.Name("context_mgr")),
                    cst.AsName(cst.Name("ctx")),
                ), ),
                cst.SimpleStatementSuite((cst.Pass(), )),
            ),
            "code":
            "with context_mgr() as ctx: pass\n",
            "parser":
            parse_statement,
        },
        # indentation
        {
            "node":
            DummyIndentedBlock(
                "    ",
                cst.With(
                    (cst.WithItem(cst.Call(cst.Name("context_mgr"))), ),
                    cst.SimpleStatementSuite((cst.Pass(), )),
                ),
            ),
            "code":
            "    with context_mgr(): pass\n",
            "parser":
            None,
            "expected_position":
            CodeRange((1, 4), (1, 28)),
        },
        # with an indented body
        {
            "node":
            DummyIndentedBlock(
                "    ",
                cst.With(
                    (cst.WithItem(cst.Call(cst.Name("context_mgr"))), ),
                    cst.IndentedBlock((cst.SimpleStatementLine(
                        (cst.Pass(), )), )),
                ),
            ),
            "code":
            "    with context_mgr():\n        pass\n",
            "parser":
            None,
            "expected_position":
            CodeRange((1, 4), (2, 12)),
        },
        # leading_lines
        {
            "node":
            cst.With(
                (cst.WithItem(cst.Call(cst.Name("context_mgr"))), ),
                cst.SimpleStatementSuite((cst.Pass(), )),
                leading_lines=(cst.EmptyLine(
                    comment=cst.Comment("# leading comment")), ),
            ),
            "code":
            "# leading comment\nwith context_mgr(): pass\n",
            "parser":
            parse_statement,
            "expected_position":
            CodeRange((2, 0), (2, 24)),
        },
        # Whitespace
        {
            "node":
            cst.With(
                (cst.WithItem(
                    cst.Call(cst.Name("context_mgr")),
                    cst.AsName(
                        cst.Name("ctx"),
                        whitespace_before_as=cst.SimpleWhitespace("  "),
                        whitespace_after_as=cst.SimpleWhitespace("  "),
                    ),
                ), ),
                cst.SimpleStatementSuite((cst.Pass(), )),
                whitespace_after_with=cst.SimpleWhitespace("  "),
                whitespace_before_colon=cst.SimpleWhitespace("  "),
            ),
            "code":
            "with  context_mgr()  as  ctx  : pass\n",
            "parser":
            parse_statement,
            "expected_position":
            CodeRange((1, 0), (1, 36)),
        },
        # Weird spacing rules, that parse differently depending on whether
        # we are using a grammar that included parenthesized with statements.
        # When is_native() is true the parentheses are given to the With node
        # itself (lpar/rpar); otherwise they are attached to the Call
        # expression and the With keeps MaybeSentinel.DEFAULT.
        {
            "node":
            cst.With(
                (cst.WithItem(
                    cst.Call(
                        cst.Name("context_mgr"),
                        lpar=() if is_native() else (cst.LeftParen(), ),
                        rpar=() if is_native() else (cst.RightParen(), ),
                    )), ),
                cst.SimpleStatementSuite((cst.Pass(), )),
                lpar=(cst.LeftParen()
                      if is_native() else MaybeSentinel.DEFAULT),
                rpar=(cst.RightParen()
                      if is_native() else MaybeSentinel.DEFAULT),
                whitespace_after_with=cst.SimpleWhitespace(""),
            ),
            "code":
            "with(context_mgr()): pass\n",
            "parser":
            parse_statement,
            "expected_position":
            CodeRange((1, 0), (1, 25)),
        },
        # Multi-line parenthesized with.
        # parse_statement is only supplied as parser when is_native() is
        # true; otherwise the case carries no parser.
        {
            "node":
            cst.With(
                (
                    cst.WithItem(
                        cst.Call(cst.Name("foo")),
                        comma=cst.
                        Comma(whitespace_after=cst.ParenthesizedWhitespace(
                            first_line=cst.TrailingWhitespace(
                                whitespace=cst.SimpleWhitespace(value="", ),
                                comment=None,
                                newline=cst.Newline(value=None, ),
                            ),
                            empty_lines=[],
                            indent=True,
                            last_line=cst.SimpleWhitespace(value="       ", ),
                        )),
                    ),
                    cst.WithItem(cst.Call(cst.Name("bar")), comma=cst.Comma()),
                ),
                cst.SimpleStatementSuite((cst.Pass(), )),
                lpar=cst.LeftParen(whitespace_after=cst.SimpleWhitespace(" ")),
                rpar=cst.RightParen(
                    whitespace_before=cst.SimpleWhitespace(" ")),
            ),
            "code": ("with ( foo(),\n"
                     "       bar(), ): pass\n"),  # noqa
            "parser":
            parse_statement if is_native() else None,
            "expected_position":
            CodeRange((1, 0), (2, 21)),
        },
    ))
    def test_valid(self, **kwargs: Any) -> None:
        """Forward one case from the table above to ``validate_node``."""
        self.validate_node(**kwargs)

    @data_provider(
        (
            # A With constructed with an empty tuple of items.
            {
                "get_node": lambda: cst.With(
                    (),
                    cst.IndentedBlock((cst.SimpleStatementLine((cst.Pass(),)),)),
                ),
                "expected_re": "A With statement must have at least one WithItem",
            },
            # A comma on the only item while no parentheses are supplied.
            {
                "get_node": lambda: cst.With(
                    (
                        cst.WithItem(
                            cst.Call(cst.Name("foo")),
                            comma=cst.Comma(
                                whitespace_after=cst.SimpleWhitespace(" ")
                            ),
                        ),
                    ),
                    cst.IndentedBlock((cst.SimpleStatementLine((cst.Pass(),)),)),
                ),
                "expected_re": (
                    "The last WithItem in an unparenthesized With cannot "
                    "have a trailing comma."
                ),
            },
            # Empty whitespace after the keyword, no parentheses.
            {
                "get_node": lambda: cst.With(
                    (cst.WithItem(cst.Call(cst.Name("context_mgr"))),),
                    cst.SimpleStatementSuite((cst.Pass(),)),
                    whitespace_after_with=cst.SimpleWhitespace(""),
                ),
                "expected_re": "Must have at least one space after with keyword",
            },
            # A concrete left parenthesis without a matching rpar argument.
            {
                "get_node": lambda: cst.With(
                    (cst.WithItem(cst.Call(cst.Name("context_mgr"))),),
                    cst.SimpleStatementSuite((cst.Pass(),)),
                    whitespace_after_with=cst.SimpleWhitespace(""),
                    lpar=cst.LeftParen(),
                ),
                "expected_re": (
                    "Do not mix concrete LeftParen/RightParen with MaybeSentinel"
                ),
            },
            # A concrete right parenthesis without a matching lpar argument.
            {
                "get_node": lambda: cst.With(
                    (cst.WithItem(cst.Call(cst.Name("context_mgr"))),),
                    cst.SimpleStatementSuite((cst.Pass(),)),
                    whitespace_after_with=cst.SimpleWhitespace(""),
                    rpar=cst.RightParen(),
                ),
                "expected_re": (
                    "Do not mix concrete LeftParen/RightParen with MaybeSentinel"
                ),
            },
        )
    )
    def test_invalid(self, **kwargs: Any) -> None:
        """Forward one ``get_node``/``expected_re`` case to ``assert_invalid``."""
        self.assert_invalid(**kwargs)

    # The same two-item ``with`` statement is paired with a parser built for
    # each listed python_version and the outcome expected for that version.
    @data_provider(
        tuple(
            {
                "code": "with a, b: pass",
                "parser": parse_statement_as(python_version=version),
                "expect_success": should_parse,
            }
            for version, should_parse in (("3.1", True), ("3.0", False))
        )
    )
    def test_versions(self, **kwargs: Any) -> None:
        """Forward one version case to ``assert_parses``.

        When ``is_native()`` is true, cases whose ``expect_success`` is falsy
        are skipped instead of being run.
        """
        if is_native():
            parse_should_succeed = kwargs.get("expect_success", True)
            if not parse_should_succeed:
                self.skipTest("parse errors are disabled for native parser")
        self.assert_parses(**kwargs)

    def test_adding_parens(self) -> None:
        """Check the code generated for a hand-built parenthesized ``With``.

        The first item's comma is followed by a default
        ``ParenthesizedWhitespace``; the expected string pins the output as
        two lines with the second item starting at column zero.
        """
        first_item = cst.WithItem(
            cst.Call(cst.Name("foo")),
            comma=cst.Comma(whitespace_after=cst.ParenthesizedWhitespace()),
        )
        second_item = cst.WithItem(cst.Call(cst.Name("bar")), comma=cst.Comma())
        with_statement = cst.With(
            (first_item, second_item),
            cst.SimpleStatementSuite((cst.Pass(),)),
            lpar=cst.LeftParen(whitespace_after=cst.SimpleWhitespace(" ")),
            rpar=cst.RightParen(whitespace_before=cst.SimpleWhitespace(" ")),
        )
        generated = cst.Module([]).code_for_node(with_statement)
        expected = (
            "with ( foo(),\n"
            "bar(), ): pass\n"
        )
        self.assertEqual(generated, expected)