Exemplo n.º 1
0
    def leave_Module(
        self,
        original_node: cst.Module,
        updated_node: cst.Module,
    ) -> cst.Module:
        """Splice pending module-level additions into the rewritten module.

        Three kinds of additions are gathered:

        * every entry of ``self.toplevel_annotations``, emitted as an annotated
          declaration with no value;
        * TypeVar statements from ``self.annotations.typevars`` whose names are
          not already in ``self.typevars``;
        * class definitions from ``self.annotations.class_definitions`` whose
          names are not in ``self.visited_classes``.

        They are placed between the two statement groups returned by
        ``self._split_module`` (the statements before and after the import
        insertion point). ``self.annotation_counts`` is updated for TypeVars
        (incremented) and classes (assigned).

        Returns ``updated_node`` untouched when there are no top-level
        annotations, no unvisited classes and no TypeVars at all.
        """
        unvisited_classes = [
            class_def
            for class_name, class_def in self.annotations.class_definitions.items()
            if class_name not in self.visited_classes
        ]

        # NOTE: The entire change will also be abandoned if
        # self.annotation_counts is all 0s, so if adding any new category make
        # sure to record it there.
        has_additions = bool(
            self.toplevel_annotations
            or unvisited_classes
            or self.annotations.typevars
        )
        if not has_additions:
            return updated_node

        # First, find the insertion point for imports.
        head, tail = self._split_module(original_node, updated_node)
        # Make sure there's at least one empty line before the first non-import.
        tail = self._insert_empty_line(tail)

        additions = [
            cst.SimpleStatementLine(
                [
                    self._apply_annotation_to_attribute_or_global(
                        name=global_name,
                        annotation=global_annotation,
                        value=None,
                    )
                ]
            )
            for global_name, global_annotation in self.toplevel_annotations.items()
        ]

        # TypeVar definitions could be scattered through the file, so do not
        # attempt to put new ones with existing ones, just add them at the top.
        missing_typevars = [
            typevar_stmt
            for typevar_name, typevar_stmt in self.annotations.typevars.items()
            if typevar_name not in self.typevars
        ]
        if missing_typevars:
            for typevar_stmt in missing_typevars:
                additions += [cst.Newline(), typevar_stmt]
                self.annotation_counts.typevars_and_generics_added += 1
            additions.append(cst.Newline())

        self.annotation_counts.classes_added = len(unvisited_classes)
        additions += unvisited_classes

        return updated_node.with_changes(body=[*head, *additions, *tail])
Exemplo n.º 2
0
 def _create_empty_line():
     """Return a new blank ``cst.EmptyLine``.

     The line has ``indent=True``, empty whitespace, no comment, and a
     ``Newline(value=None)`` (the module's default newline).
     """
     no_whitespace = cst.SimpleWhitespace(value='')
     default_newline = cst.Newline(value=None)
     return cst.EmptyLine(
         indent=True,
         whitespace=no_whitespace,
         comment=None,
         newline=default_newline,
     )
class TrailingWhitespaceTest(CSTNodeTest):
    """Check how ``cst.TrailingWhitespace`` nodes render to source text."""

    @data_provider(
        (
            # Default node: only the newline.
            (cst.TrailingWhitespace(), "\n"),
            # Each field overridden on its own.
            (
                cst.TrailingWhitespace(whitespace=cst.SimpleWhitespace("    ")),
                "    \n",
            ),
            (
                cst.TrailingWhitespace(comment=cst.Comment("# comment")),
                "# comment\n",
            ),
            (
                cst.TrailingWhitespace(newline=cst.Newline("\r\n")),
                "\r\n",
            ),
            # All fields together: whitespace, then comment, then newline.
            (
                cst.TrailingWhitespace(
                    whitespace=cst.SimpleWhitespace("    "),
                    comment=cst.Comment("# comment"),
                    newline=cst.Newline("\r\n"),
                ),
                "    # comment\r\n",
            ),
        )
    )
    def test_valid(self, node: cst.CSTNode, code: str) -> None:
        """Check ``node`` against the expected ``code`` via ``validate_node``."""
        self.validate_node(node, code)
Exemplo n.º 4
0
class EmptyLineTest(CSTNodeTest):
    """Check how ``cst.EmptyLine`` renders, alone and inside an indented block."""

    @data_provider(
        (
            # Default node: only the newline.
            (cst.EmptyLine(), "\n"),
            # Individual fields.
            (
                cst.EmptyLine(whitespace=cst.SimpleWhitespace("    ")),
                "    \n",
            ),
            (cst.EmptyLine(comment=cst.Comment("# comment")), "# comment\n"),
            (cst.EmptyLine(newline=cst.Newline("\r\n")), "\r\n"),
            # Inside a block the block's indent is emitted unless indent=False.
            (DummyIndentedBlock("  ", cst.EmptyLine()), "  \n"),
            (DummyIndentedBlock("  ", cst.EmptyLine(indent=False)), "\n"),
            # Indent, whitespace, comment and newline combined.
            (
                DummyIndentedBlock(
                    "\t",
                    cst.EmptyLine(
                        whitespace=cst.SimpleWhitespace("    "),
                        comment=cst.Comment("# comment"),
                        newline=cst.Newline("\r\n"),
                    ),
                ),
                "\t    # comment\r\n",
            ),
        )
    )
    def test_valid(self, node: cst.CSTNode, code: str) -> None:
        """Check ``node`` against the expected ``code`` via ``validate_node``."""
        self.validate_node(node, code)
Exemplo n.º 5
0
 def test_module_config_for_parsing(self) -> None:
     """A statement parsed with a module's config inherits its newline style."""
     source_module = parse_module("pass\r")
     parsed = parse_statement(
         "if True:\r    pass", config=source_module.config_for_parsing
     )
     # The header newline is expected as value=None; this would be "\r" if we
     # didn't pass the module config forward.
     expected = cst.If(
         test=cst.Name(value="True"),
         body=cst.IndentedBlock(
             body=[cst.SimpleStatementLine(body=[cst.Pass()])],
             header=cst.TrailingWhitespace(newline=cst.Newline(value=None)),
         ),
     )
     self.assertEqual(parsed, expected)
Exemplo n.º 6
0
    def test_with_changes(self) -> None:
        """``with_changes`` returns an updated copy and leaves the source node intact."""
        original = cst.TrailingWhitespace(
            whitespace=cst.SimpleWhitespace("  \\\n  "),
            comment=cst.Comment("# initial"),
            newline=cst.Newline("\r\n"),
        )
        updated = original.with_changes(comment=cst.Comment("# new comment"))

        # The replaced field carries the new value...
        self.assertEqual(none_throws(updated.comment).value, "# new comment")
        # ...while the untouched fields are carried over unchanged.
        self.assertEqual(updated.whitespace.value, "  \\\n  ")
        self.assertEqual(updated.newline.value, "\r\n")

        # The source node must not have been mutated.
        self.assertEqual(none_throws(original.comment).value, "# initial")
Exemplo n.º 7
0
 def leave_SimpleWhitespace(
     self,
     original_node: libcst.SimpleWhitespace,
     updated_node: libcst.SimpleWhitespace,
 ) -> Union[libcst.SimpleWhitespace, libcst.ParenthesizedWhitespace]:
     """Rewrite backslash-continued whitespace as ``ParenthesizedWhitespace``.

     All backslashes are removed from the whitespace text. If the result has no
     newline, ``updated_node`` is returned unchanged. Otherwise a
     ``ParenthesizedWhitespace`` is built with these parts:

     * ``first_line``: the text before the first newline, right-stripped;
     * ``empty_lines``: one blank ``EmptyLine`` (``indent=False``) per
       intermediate line;
     * ``last_line``: the text after the *last* newline;
     * ``indent=True``.
     """
     whitespace = original_node.value.replace("\\", "")
     if "\n" not in whitespace:
         return updated_node

     # One entry per physical line. A "\r" left over from "\r\n" endings sits
     # at the end of every non-final entry and is removed by rstrip() below.
     lines = whitespace.split("\n")
     first_line = libcst.TrailingWhitespace(
         whitespace=libcst.SimpleWhitespace(value=lines[0].rstrip()),
         comment=None,
         newline=libcst.Newline(),
     )
     # With several continuations, keep the intermediate lines instead of
     # dropping them, and take the final indentation from the last line rather
     # than from whichever line happened to be second.
     empty_lines = [
         libcst.EmptyLine(
             indent=False,
             whitespace=libcst.SimpleWhitespace(value=middle.rstrip()),
             comment=None,
             newline=libcst.Newline(),
         )
         for middle in lines[1:-1]
     ]
     last_line = libcst.SimpleWhitespace(value=lines[-1])
     return libcst.ParenthesizedWhitespace(
         first_line=first_line,
         empty_lines=empty_lines,
         indent=True,
         last_line=last_line,
     )
Exemplo n.º 8
0
class NewlineTest(CSTNodeTest):
    """Check which values ``cst.Newline`` accepts and how it renders."""

    @data_provider(
        (
            # Each of the three line endings renders as itself.
            (cst.Newline("\r\n"), "\r\n"),
            (cst.Newline("\r"), "\r"),
            (cst.Newline("\n"), "\n"),
        )
    )
    def test_valid(self, node: cst.CSTNode, code: str) -> None:
        """Check ``node`` against the expected ``code`` via ``validate_node``."""
        self.validate_node(node, code)

    @data_provider(
        (
            # Anything other than a single line ending is rejected. The nodes
            # are wrapped in lambdas, presumably so that assert_invalid can
            # catch the error raised at construction time.
            (lambda: cst.Newline("bad input"), "invalid value"),
            (lambda: cst.Newline("\nbad input\n"), "invalid value"),
            (lambda: cst.Newline("\n\n"), "invalid value"),
        )
    )
    def test_invalid(
        self, get_node: Callable[[], cst.CSTNode], expected_re: str
    ) -> None:
        """Building the node via ``get_node`` must fail matching ``expected_re``."""
        self.assert_invalid(get_node, expected_re)
Exemplo n.º 9
0
class ModuleTest(CSTNodeTest):
    """Tests for ``cst.Module``: code generation, parsing and position metadata."""

    @data_provider((
        # simplest possible program
        (cst.Module((cst.SimpleStatementLine((cst.Pass(), )), )), "pass\n"),
        # test default_newline
        (
            cst.Module((cst.SimpleStatementLine((cst.Pass(), )), ),
                       default_newline="\r"),
            "pass\r",
        ),
        # test header/footer
        (
            cst.Module(
                (cst.SimpleStatementLine((cst.Pass(), )), ),
                header=(cst.EmptyLine(comment=cst.Comment("# header")), ),
                footer=(cst.EmptyLine(comment=cst.Comment("# footer")), ),
            ),
            "# header\npass\n# footer\n",
        ),
        # test has_trailing_newline
        (
            cst.Module(
                (cst.SimpleStatementLine((cst.Pass(), )), ),
                has_trailing_newline=False,
            ),
            "pass",
        ),
        # an empty file
        (cst.Module((), has_trailing_newline=False), ""),
        # a file with only comments
        (
            cst.Module(
                (),
                header=(cst.EmptyLine(
                    comment=cst.Comment("# nothing to see here")), ),
            ),
            "# nothing to see here\n",
        ),
        # TODO: test default_indent
    ))
    def test_code_and_bytes_properties(self, module: cst.Module,
                                       expected: str) -> None:
        """``module.code`` equals ``expected``; ``module.bytes`` is its UTF-8 encoding."""
        self.assertEqual(module.code, expected)
        self.assertEqual(module.bytes, expected.encode("utf-8"))

    @data_provider((
        (cst.Module(()), cst.Newline(), "\n"),
        (cst.Module((), default_newline="\r\n"), cst.Newline(), "\r\n"),
        # has_trailing_newline has no effect on code_for_node
        (cst.Module((), has_trailing_newline=False), cst.Newline(), "\n"),
        # TODO: test default_indent
    ))
    def test_code_for_node(self, module: cst.Module, node: cst.CSTNode,
                           expected: str) -> None:
        """Rendering ``node`` through ``module.code_for_node`` yields ``expected``.

        A bare ``cst.Newline()`` picks up the module's ``default_newline``.
        """
        self.assertEqual(module.code_for_node(node), expected)

    @data_provider({
        "empty_program": {
            "code": "",
            "expected": cst.Module([], has_trailing_newline=False),
        },
        "empty_program_with_newline": {
            "code": "\n",
            "expected": cst.Module([], has_trailing_newline=True),
            "enabled_for_native": False,
        },
        "empty_program_with_comments": {
            "code":
            "# some comment\n",
            "expected":
            cst.Module(
                [],
                header=[cst.EmptyLine(comment=cst.Comment("# some comment"))]),
        },
        "simple_pass": {
            "code": "pass\n",
            "expected": cst.Module([cst.SimpleStatementLine([cst.Pass()])]),
        },
        "simple_pass_with_header_footer": {
            "code":
            "# header\npass # trailing\n# footer\n",
            "expected":
            cst.Module(
                [
                    cst.SimpleStatementLine(
                        [cst.Pass()],
                        trailing_whitespace=cst.TrailingWhitespace(
                            whitespace=cst.SimpleWhitespace(" "),
                            comment=cst.Comment("# trailing"),
                        ),
                    )
                ],
                header=[cst.EmptyLine(comment=cst.Comment("# header"))],
                footer=[cst.EmptyLine(comment=cst.Comment("# footer"))],
            ),
        },
    })
    def test_parser(self,
                    *,
                    code: str,
                    expected: cst.Module,
                    enabled_for_native: bool = True) -> None:
        """``parse_module(code)`` equals ``expected``.

        Cases marked ``enabled_for_native=False`` are skipped when
        ``is_native()`` reports the native parser.
        """
        if is_native() and not enabled_for_native:
            self.skipTest("Disabled for native parser")
        self.assertEqual(parse_module(code), expected)

    @data_provider({
        "empty": {
            "code": "",
            "expected": CodeRange((1, 0), (1, 0))
        },
        "empty_with_newline": {
            "code": "\n",
            "expected": CodeRange((1, 0), (2, 0))
        },
        "empty_program_with_comments": {
            "code": "# 2345",
            "expected": CodeRange((1, 0), (2, 0)),
        },
        "simple_pass": {
            "code": "pass\n",
            "expected": CodeRange((1, 0), (2, 0))
        },
        "simple_pass_with_header_footer": {
            "code": "# header\npass # trailing\n# footer\n",
            "expected": CodeRange((1, 0), (4, 0)),
        },
    })
    def test_module_position(self, *, code: str, expected: CodeRange) -> None:
        """The ``PositionProvider`` range of the whole parsed module is ``expected``."""
        wrapper = MetadataWrapper(parse_module(code))
        positions = wrapper.resolve(PositionProvider)

        self.assertEqual(positions[wrapper.module], expected)

    def cmp_position(self, actual: CodeRange, start: Tuple[int, int],
                     end: Tuple[int, int]) -> None:
        """Assert ``actual`` spans ``start`` to ``end``, each a (line, column) pair."""
        self.assertEqual(actual, CodeRange(start, end))

    def test_function_position(self) -> None:
        """A function body statement is positioned on its own indented line."""
        wrapper = MetadataWrapper(parse_module("def foo():\n    pass"))
        module = wrapper.module
        positions = wrapper.resolve(PositionProvider)

        fn = cast(cst.FunctionDef, module.body[0])
        stmt = cast(cst.SimpleStatementLine, fn.body.body[0])
        pass_stmt = cast(cst.Pass, stmt.body[0])
        self.cmp_position(positions[stmt], (2, 4), (2, 8))
        self.cmp_position(positions[pass_stmt], (2, 4), (2, 8))

    def test_nested_indent_position(self) -> None:
        """Positions of nested ``if``/``else`` blocks account for each indent level."""
        wrapper = MetadataWrapper(
            parse_module(
                "if True:\n    if False:\n        x = 1\nelse:\n    return"))
        module = wrapper.module
        positions = wrapper.resolve(PositionProvider)

        outer_if = cast(cst.If, module.body[0])
        inner_if = cast(cst.If, outer_if.body.body[0])
        assign = cast(cst.SimpleStatementLine, inner_if.body.body[0]).body[0]

        outer_else = cast(cst.Else, outer_if.orelse)
        return_stmt = cast(cst.SimpleStatementLine,
                           outer_else.body.body[0]).body[0]

        self.cmp_position(positions[outer_if], (1, 0), (5, 10))
        self.cmp_position(positions[inner_if], (2, 4), (3, 13))
        self.cmp_position(positions[assign], (3, 8), (3, 13))
        self.cmp_position(positions[outer_else], (4, 0), (5, 10))
        self.cmp_position(positions[return_stmt], (5, 4), (5, 10))

    def test_multiline_string_position(self) -> None:
        """A backslash-continued string expression spans both physical lines."""
        wrapper = MetadataWrapper(parse_module('"abc"\\\n"def"'))
        module = wrapper.module
        positions = wrapper.resolve(PositionProvider)

        stmt = cast(cst.SimpleStatementLine, module.body[0])
        expr = cast(cst.Expr, stmt.body[0])
        string = expr.value

        self.cmp_position(positions[stmt], (1, 0), (2, 5))
        self.cmp_position(positions[expr], (1, 0), (2, 5))
        self.cmp_position(positions[string], (1, 0), (2, 5))

    def test_module_config_for_parsing(self) -> None:
        """``config_for_parsing`` carries the module's "\\r" newline into statement parsing."""
        module = parse_module("pass\r")
        statement = parse_statement("if True:\r    pass",
                                    config=module.config_for_parsing)
        self.assertEqual(
            statement,
            cst.If(
                test=cst.Name(value="True"),
                body=cst.IndentedBlock(
                    body=[cst.SimpleStatementLine(body=[cst.Pass()])],
                    header=cst.TrailingWhitespace(newline=cst.Newline(
                        # This would be "\r" if we didn't pass the module config forward.
                        value=None)),
                ),
            ),
        )
Exemplo n.º 10
0
class ModuleTest(CSTNodeTest):
    """Tests for ``cst.Module`` using ``SyntacticPositionProvider`` for positions."""

    @data_provider((
        # simplest possible program
        (cst.Module((cst.SimpleStatementLine((cst.Pass(), )), )), "pass\n"),
        # test default_newline
        (
            cst.Module((cst.SimpleStatementLine((cst.Pass(), )), ),
                       default_newline="\r"),
            "pass\r",
        ),
        # test header/footer
        (
            cst.Module(
                (cst.SimpleStatementLine((cst.Pass(), )), ),
                header=(cst.EmptyLine(comment=cst.Comment("# header")), ),
                footer=(cst.EmptyLine(comment=cst.Comment("# footer")), ),
            ),
            "# header\npass\n# footer\n",
        ),
        # test has_trailing_newline
        (
            cst.Module(
                (cst.SimpleStatementLine((cst.Pass(), )), ),
                has_trailing_newline=False,
            ),
            "pass",
        ),
        # an empty file
        (cst.Module((), has_trailing_newline=False), ""),
        # a file with only comments
        (
            cst.Module(
                (),
                header=(cst.EmptyLine(
                    comment=cst.Comment("# nothing to see here")), ),
            ),
            "# nothing to see here\n",
        ),
        # TODO: test default_indent
    ))
    def test_code_and_bytes_properties(self, module: cst.Module,
                                       expected: str) -> None:
        """``module.code`` equals ``expected``; ``module.bytes`` is its UTF-8 encoding."""
        self.assertEqual(module.code, expected)
        self.assertEqual(module.bytes, expected.encode("utf-8"))

    @data_provider((
        (cst.Module(()), cst.Newline(), "\n"),
        (cst.Module((), default_newline="\r\n"), cst.Newline(), "\r\n"),
        # has_trailing_newline has no effect on code_for_node
        (cst.Module((), has_trailing_newline=False), cst.Newline(), "\n"),
        # TODO: test default_indent
    ))
    def test_code_for_node(self, module: cst.Module, node: cst.CSTNode,
                           expected: str) -> None:
        """Rendering ``node`` through ``module.code_for_node`` yields ``expected``."""
        self.assertEqual(module.code_for_node(node), expected)

    @data_provider({
        "empty_program": {
            "code": "",
            "expected": cst.Module([], has_trailing_newline=False),
        },
        "empty_program_with_newline": {
            "code": "\n",
            "expected": cst.Module([], has_trailing_newline=True),
        },
        "empty_program_with_comments": {
            "code":
            "# some comment\n",
            "expected":
            cst.Module(
                [],
                header=[cst.EmptyLine(comment=cst.Comment("# some comment"))]),
        },
        "simple_pass": {
            "code": "pass\n",
            "expected": cst.Module([cst.SimpleStatementLine([cst.Pass()])]),
        },
        "simple_pass_with_header_footer": {
            "code":
            "# header\npass # trailing\n# footer\n",
            "expected":
            cst.Module(
                [
                    cst.SimpleStatementLine(
                        [cst.Pass()],
                        trailing_whitespace=cst.TrailingWhitespace(
                            whitespace=cst.SimpleWhitespace(" "),
                            comment=cst.Comment("# trailing"),
                        ),
                    )
                ],
                header=[cst.EmptyLine(comment=cst.Comment("# header"))],
                footer=[cst.EmptyLine(comment=cst.Comment("# footer"))],
            ),
        },
    })
    def test_parser(self, *, code: str, expected: cst.Module) -> None:
        """``parse_module(code)`` equals ``expected``."""
        self.assertEqual(parse_module(code), expected)

    @data_provider({
        "empty": {
            "code": "",
            "expected": CodeRange.create((1, 0), (1, 0))
        },
        "empty_with_newline": {
            "code": "\n",
            "expected": CodeRange.create((1, 0), (2, 0)),
        },
        "empty_program_with_comments": {
            "code": "# 2345",
            "expected": CodeRange.create((1, 0), (2, 0)),
        },
        "simple_pass": {
            "code": "pass\n",
            "expected": CodeRange.create((1, 0), (2, 0)),
        },
        "simple_pass_with_header_footer": {
            "code": "# header\npass # trailing\n# footer\n",
            "expected": CodeRange.create((1, 0), (4, 0)),
        },
    })
    def test_module_position(self, *, code: str, expected: CodeRange) -> None:
        """The computed range of the whole parsed module is ``expected``."""
        module = parse_module(code)
        provider = SyntacticPositionProvider()
        # Presumably rendering with the provider is what fills in
        # `provider._computed`, which is read below.
        module.code_for_node(module, provider)

        self.assertEqual(provider._computed[module], expected)

        # TODO: remove this
        self.assertEqual(module._metadata[SyntacticPositionProvider], expected)

    def cmp_position(self, actual: CodeRange, start: Tuple[int, int],
                     end: Tuple[int, int]) -> None:
        """Assert ``actual`` spans ``start`` to ``end``, each a (line, column) pair."""
        self.assertEqual(actual, CodeRange.create(start, end))

    def test_function_position(self) -> None:
        """A function body statement is positioned on its own indented line."""
        module = parse_module("def foo():\n    pass")
        provider = SyntacticPositionProvider()
        module.code_for_node(module, provider)

        fn = cast(cst.FunctionDef, module.body[0])
        stmt = cast(cst.SimpleStatementLine, fn.body.body[0])
        pass_stmt = cast(cst.Pass, stmt.body[0])
        self.cmp_position(provider._computed[stmt], (2, 4), (2, 8))
        self.cmp_position(provider._computed[pass_stmt], (2, 4), (2, 8))

    def test_nested_indent_position(self) -> None:
        """Positions of nested ``if``/``else`` blocks account for each indent level."""
        module = parse_module(
            "if True:\n    if False:\n        x = 1\nelse:\n    return")
        provider = SyntacticPositionProvider()
        module.code_for_node(module, provider)

        outer_if = cast(cst.If, module.body[0])
        inner_if = cast(cst.If, outer_if.body.body[0])
        assign = cast(cst.SimpleStatementLine, inner_if.body.body[0]).body[0]

        outer_else = cast(cst.Else, outer_if.orelse)
        return_stmt = cast(cst.SimpleStatementLine,
                           outer_else.body.body[0]).body[0]

        self.cmp_position(provider._computed[outer_if], (1, 0), (5, 10))
        self.cmp_position(provider._computed[inner_if], (2, 4), (3, 13))
        self.cmp_position(provider._computed[assign], (3, 8), (3, 13))
        self.cmp_position(provider._computed[outer_else], (4, 0), (5, 10))
        self.cmp_position(provider._computed[return_stmt], (5, 4), (5, 10))

    def test_multiline_string_position(self) -> None:
        """A backslash-continued string expression spans both physical lines."""
        module = parse_module('"abc"\\\n"def"')
        provider = SyntacticPositionProvider()
        module.code_for_node(module, provider)

        stmt = cast(cst.SimpleStatementLine, module.body[0])
        expr = cast(cst.Expr, stmt.body[0])
        string = expr.value

        self.cmp_position(provider._computed[stmt], (1, 0), (2, 5))
        self.cmp_position(provider._computed[expr], (1, 0), (2, 5))
        self.cmp_position(provider._computed[string], (1, 0), (2, 5))
Exemplo n.º 11
0
 def test_repr(self) -> None:
     """Check the multi-line ``repr`` of a ``SimpleStatementLine``.

     The node is built with a tuple for ``leading_lines``; the expected repr
     shows it as a list, lists every field (including defaults such as
     ``semicolon=MaybeSentinel.DEFAULT``) and ends each field with a comma.
     """
     self.assertEqual(
         repr(
             cst.SimpleStatementLine(
                 body=[cst.Pass()],
                 # tuple with multiple items
                 leading_lines=(
                     cst.EmptyLine(
                         indent=True,
                         whitespace=cst.SimpleWhitespace(""),
                         comment=None,
                         newline=cst.Newline(),
                     ),
                     cst.EmptyLine(
                         indent=True,
                         whitespace=cst.SimpleWhitespace(""),
                         comment=None,
                         newline=cst.Newline(),
                     ),
                 ),
                 trailing_whitespace=cst.TrailingWhitespace(
                     whitespace=cst.SimpleWhitespace(" "),
                     comment=cst.Comment("# comment"),
                     newline=cst.Newline(),
                 ),
             )
         ),
         # Expected repr; dedent() + strip() remove the source indentation and
         # the surrounding blank lines, so the layout below is significant.
         dedent(
             """
             SimpleStatementLine(
                 body=[
                     Pass(
                         semicolon=MaybeSentinel.DEFAULT,
                     ),
                 ],
                 leading_lines=[
                     EmptyLine(
                         indent=True,
                         whitespace=SimpleWhitespace(
                             value='',
                         ),
                         comment=None,
                         newline=Newline(
                             value=None,
                         ),
                     ),
                     EmptyLine(
                         indent=True,
                         whitespace=SimpleWhitespace(
                             value='',
                         ),
                         comment=None,
                         newline=Newline(
                             value=None,
                         ),
                     ),
                 ],
                 trailing_whitespace=TrailingWhitespace(
                     whitespace=SimpleWhitespace(
                         value=' ',
                     ),
                     comment=Comment(
                         value='# comment',
                     ),
                     newline=Newline(
                         value=None,
                     ),
                 ),
             )
             """
         ).strip(),
     )
Exemplo n.º 12
0
class WhitespaceParserTest(UnitTest):
    """Table-driven tests for the low-level whitespace parsers.

    Each case names a parser (``parse_simple_whitespace``, ``parse_empty_lines``
    or ``parse_trailing_whitespace``), the source ``Config``, the ``State`` to
    start from, the ``State`` expected after parsing, and the expected node(s).
    """

    @data_provider(
        {
            "simple_whitespace_empty": {
                "parser": parse_simple_whitespace,
                "config": Config(
                    lines=["not whitespace\n", " another line\n"], default_newline="\n"
                ),
                "start_state": State(
                    line=1, column=0, absolute_indent="", is_parenthesized=False
                ),
                "end_state": State(
                    line=1, column=0, absolute_indent="", is_parenthesized=False
                ),
                "expected_node": cst.SimpleWhitespace(""),
            },
            "simple_whitespace_start_of_line": {
                "parser": parse_simple_whitespace,
                "config": Config(
                    lines=["\t  <-- There's some whitespace there\n"],
                    default_newline="\n",
                ),
                "start_state": State(
                    line=1, column=0, absolute_indent="", is_parenthesized=False
                ),
                "end_state": State(
                    line=1, column=3, absolute_indent="", is_parenthesized=False
                ),
                "expected_node": cst.SimpleWhitespace("\t  "),
            },
            "simple_whitespace_end_of_line": {
                "parser": parse_simple_whitespace,
                "config": Config(lines=["prefix   "], default_newline="\n"),
                "start_state": State(
                    line=1, column=6, absolute_indent="", is_parenthesized=False
                ),
                "end_state": State(
                    line=1, column=9, absolute_indent="", is_parenthesized=False
                ),
                "expected_node": cst.SimpleWhitespace("   "),
            },
            # Backslash continuations are consumed as part of the whitespace.
            "simple_whitespace_line_continuation": {
                "parser": parse_simple_whitespace,
                "config": Config(
                    lines=["prefix \\\n", "    \\\n", "    # suffix\n"],
                    default_newline="\n",
                ),
                "start_state": State(
                    line=1, column=6, absolute_indent="", is_parenthesized=False
                ),
                "end_state": State(
                    line=3, column=4, absolute_indent="", is_parenthesized=False
                ),
                "expected_node": cst.SimpleWhitespace(" \\\n    \\\n    "),
            },
            "empty_lines_empty_list": {
                "parser": parse_empty_lines,
                "config": Config(
                    lines=["this is not an empty line"], default_newline="\n"
                ),
                "start_state": State(
                    line=1, column=0, absolute_indent="", is_parenthesized=False
                ),
                "end_state": State(
                    line=1, column=0, absolute_indent="", is_parenthesized=False
                ),
                "expected_node": [],
            },
            "empty_lines_single_line": {
                "parser": parse_empty_lines,
                "config": Config(
                    lines=["    # comment\n", "this is not an empty line\n"],
                    default_newline="\n",
                ),
                "start_state": State(
                    line=1, column=0, absolute_indent="    ", is_parenthesized=False
                ),
                "end_state": State(
                    line=2, column=0, absolute_indent="    ", is_parenthesized=False
                ),
                "expected_node": [
                    cst.EmptyLine(
                        indent=True,
                        whitespace=cst.SimpleWhitespace(""),
                        comment=cst.Comment("# comment"),
                        newline=cst.Newline(),
                    )
                ],
            },
            # Lines are expected with indent=True only when they start with the
            # full absolute_indent ("    "); any extra leading whitespace goes
            # into the whitespace field.
            "empty_lines_multiple": {
                "parser": parse_empty_lines,
                "config": Config(
                    lines=[
                        "\n",
                        "    \n",
                        "     # comment with indent and whitespace\n",
                        "# comment without indent\n",
                        "  # comment with no indent but some whitespace\n",
                    ],
                    default_newline="\n",
                ),
                "start_state": State(
                    line=1, column=0, absolute_indent="    ", is_parenthesized=False
                ),
                "end_state": State(
                    line=5, column=47, absolute_indent="    ", is_parenthesized=False
                ),
                "expected_node": [
                    cst.EmptyLine(
                        indent=False,
                        whitespace=cst.SimpleWhitespace(""),
                        comment=None,
                        newline=cst.Newline(),
                    ),
                    cst.EmptyLine(
                        indent=True,
                        whitespace=cst.SimpleWhitespace(""),
                        comment=None,
                        newline=cst.Newline(),
                    ),
                    cst.EmptyLine(
                        indent=True,
                        whitespace=cst.SimpleWhitespace(" "),
                        comment=cst.Comment("# comment with indent and whitespace"),
                        newline=cst.Newline(),
                    ),
                    cst.EmptyLine(
                        indent=False,
                        whitespace=cst.SimpleWhitespace(""),
                        comment=cst.Comment("# comment without indent"),
                        newline=cst.Newline(),
                    ),
                    cst.EmptyLine(
                        indent=False,
                        whitespace=cst.SimpleWhitespace("  "),
                        comment=cst.Comment(
                            "# comment with no indent but some whitespace"
                        ),
                        newline=cst.Newline(),
                    ),
                ],
            },
            "empty_lines_non_default_newline": {
                "parser": parse_empty_lines,
                "config": Config(lines=["\n", "\r\n", "\r"], default_newline="\n"),
                "start_state": State(
                    line=1, column=0, absolute_indent="", is_parenthesized=False
                ),
                "end_state": State(
                    line=3, column=1, absolute_indent="", is_parenthesized=False
                ),
                "expected_node": [
                    cst.EmptyLine(
                        indent=True,
                        whitespace=cst.SimpleWhitespace(""),
                        comment=None,
                        newline=cst.Newline(None),  # default newline
                    ),
                    cst.EmptyLine(
                        indent=True,
                        whitespace=cst.SimpleWhitespace(""),
                        comment=None,
                        newline=cst.Newline("\r\n"),  # non-default
                    ),
                    cst.EmptyLine(
                        indent=True,
                        whitespace=cst.SimpleWhitespace(""),
                        comment=None,
                        newline=cst.Newline("\r"),  # non-default
                    ),
                ],
            },
            "trailing_whitespace": {
                "parser": parse_trailing_whitespace,
                "config": Config(
                    lines=["some code  # comment\n"], default_newline="\n"
                ),
                "start_state": State(
                    line=1, column=9, absolute_indent="", is_parenthesized=False
                ),
                "end_state": State(
                    line=1, column=21, absolute_indent="", is_parenthesized=False
                ),
                "expected_node": cst.TrailingWhitespace(
                    whitespace=cst.SimpleWhitespace("  "),
                    comment=cst.Comment("# comment"),
                    newline=cst.Newline(),
                ),
            },
        }
    )
    def test_parsers(
        self,
        parser: Callable[[Config, State], _T],
        config: Config,
        start_state: State,
        end_state: State,
        expected_node: _T,
    ) -> None:
        """Run ``parser`` from ``start_state`` and check its result and final state."""
        # Uses internal `deep_equals` function instead of `CSTNode.deep_equals`, because
        # we need to compare sequences of nodes, and this is the easiest way. :/
        parsed_node = parser(config, start_state)
        self.assertTrue(
            deep_equals(parsed_node, expected_node),
            msg=f"\n{parsed_node!r}\nis not deeply equal to \n{expected_node!r}",
        )
        # start_state is compared to end_state after parsing, so this expects the
        # parser to have advanced `start_state` in place.
        self.assertEqual(start_state, end_state)
Exemplo n.º 13
0
from itertools import takewhile, dropwhile
from typing import Union, Dict, List, Optional, Tuple

import libcst as cst
from libcst import matchers as m
from libcst.helpers import parse_template_statement

from pybetter.transformers.base import NoqaAwareTransformer

# Statement template that assigns `{init}` to `{arg}` when `{arg}` is None.
# The `{arg}` / `{init}` placeholders are filled in where it is rendered
# (presumably via `parse_template_statement` -- confirm at the call site).
DEFAULT_INIT_TEMPLATE = (
    "if {arg} is None:\n"
    "    {arg} = {init}\n"
)
# A truly blank line. `indent` has to be turned off explicitly: with its
# default, libcst renders the current indentation before the newline, so the
# "empty" line would still contain indent whitespace.
EMPTY_LINE = cst.EmptyLine(newline=cst.Newline(), indent=False)


def is_docstring(node):
    """Tell whether ``node`` has the shape of a docstring statement.

    Matches a ``SimpleStatementLine`` whose body is a single ``Expr`` wrapping
    a ``SimpleString`` literal. Implicitly concatenated strings and f-strings
    are different libcst node types, so they do not match.
    """
    string_expression = m.Expr(value=m.SimpleString())
    docstring_shape = m.SimpleStatementLine(body=[string_expression])
    return m.matches(node, docstring_shape)


class ArgEmptyInitTransformer(NoqaAwareTransformer):
    """Transformer that records the parser configuration of the module it visits.

    ``module_config`` is ``None`` until ``visit_Module`` runs, after which it
    holds that module's ``config_for_parsing``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Set by visit_Module; remains None before any module is visited.
        self.module_config = None

    def visit_Module(self, node: cst.Module) -> Optional[bool]:
        # Keep the module's parser config (presumably so that code generated
        # later matches the module's settings -- confirm with the callers).
        parsing_config = node.config_for_parsing
        self.module_config = parsing_config
        # Returning True lets libcst continue into the module's children.
        return True
Exemplo n.º 14
0
class WithTest(CSTNodeTest):
    """Tests for ``cst.With``: round-trips, validation errors, version gating,
    and rendering of a parenthesized form."""

    # Raise unittest's diff-truncation limit so that long node reprs are still
    # shown when an assertion comparing them fails.
    maxDiff: int = 2000

    # Round-trip cases. Each dict gives the expected "node" and its source
    # "code". "parser" is used to parse "code" back; None presumably means
    # only codegen is checked (the handling lives in
    # CSTNodeTest.validate_node). The optional "expected_position" is the
    # CodeRange expected for the node.
    @data_provider((
        # Simple with block
        {
            "node":
            cst.With(
                (cst.WithItem(cst.Call(cst.Name("context_mgr"))), ),
                cst.SimpleStatementSuite((cst.Pass(), )),
            ),
            "code":
            "with context_mgr(): pass\n",
            "parser":
            parse_statement,
            "expected_position":
            CodeRange((1, 0), (1, 24)),
        },
        # Simple async with block
        # (parsed with the Python 3.7 grammar)
        {
            "node":
            cst.With(
                (cst.WithItem(cst.Call(cst.Name("context_mgr"))), ),
                cst.SimpleStatementSuite((cst.Pass(), )),
                asynchronous=cst.Asynchronous(),
            ),
            "code":
            "async with context_mgr(): pass\n",
            "parser":
            lambda code: parse_statement(
                code, config=PartialParserConfig(python_version="3.7")),
        },
        # Python 3.6 async with block
        # Wrapped in an `async def`, presumably because the 3.6 grammar only
        # recognizes `async with` inside async functions.
        {
            "node":
            cst.FunctionDef(
                cst.Name("foo"),
                cst.Parameters(),
                cst.IndentedBlock((cst.With(
                    (cst.WithItem(cst.Call(cst.Name("context_mgr"))), ),
                    cst.SimpleStatementSuite((cst.Pass(), )),
                    asynchronous=cst.Asynchronous(),
                ), )),
                asynchronous=cst.Asynchronous(),
            ),
            "code":
            "async def foo():\n    async with context_mgr(): pass\n",
            "parser":
            lambda code: parse_statement(
                code, config=PartialParserConfig(python_version="3.6")),
        },
        # Multiple context managers
        # Default comma with no parser; the next case spells the comma out
        # explicitly and is parsed.
        {
            "node":
            cst.With(
                (
                    cst.WithItem(cst.Call(cst.Name("foo"))),
                    cst.WithItem(cst.Call(cst.Name("bar"))),
                ),
                cst.SimpleStatementSuite((cst.Pass(), )),
            ),
            "code":
            "with foo(), bar(): pass\n",
            "parser":
            None,
        },
        # Same source with an explicit Comma, so the node round-trips.
        {
            "node":
            cst.With(
                (
                    cst.WithItem(
                        cst.Call(cst.Name("foo")),
                        comma=cst.Comma(
                            whitespace_after=cst.SimpleWhitespace(" ")),
                    ),
                    cst.WithItem(cst.Call(cst.Name("bar"))),
                ),
                cst.SimpleStatementSuite((cst.Pass(), )),
            ),
            "code":
            "with foo(), bar(): pass\n",
            "parser":
            parse_statement,
        },
        # With block containing variable for context manager.
        {
            "node":
            cst.With(
                (cst.WithItem(
                    cst.Call(cst.Name("context_mgr")),
                    cst.AsName(cst.Name("ctx")),
                ), ),
                cst.SimpleStatementSuite((cst.Pass(), )),
            ),
            "code":
            "with context_mgr() as ctx: pass\n",
            "parser":
            parse_statement,
        },
        # indentation
        # (DummyIndentedBlock renders the With at a 4-space indent.)
        {
            "node":
            DummyIndentedBlock(
                "    ",
                cst.With(
                    (cst.WithItem(cst.Call(cst.Name("context_mgr"))), ),
                    cst.SimpleStatementSuite((cst.Pass(), )),
                ),
            ),
            "code":
            "    with context_mgr(): pass\n",
            "parser":
            None,
            "expected_position":
            CodeRange((1, 4), (1, 28)),
        },
        # with an indented body
        {
            "node":
            DummyIndentedBlock(
                "    ",
                cst.With(
                    (cst.WithItem(cst.Call(cst.Name("context_mgr"))), ),
                    cst.IndentedBlock((cst.SimpleStatementLine(
                        (cst.Pass(), )), )),
                ),
            ),
            "code":
            "    with context_mgr():\n        pass\n",
            "parser":
            None,
            "expected_position":
            CodeRange((1, 4), (2, 12)),
        },
        # leading_lines
        # (The expected position starts on the `with` line, not the comment.)
        {
            "node":
            cst.With(
                (cst.WithItem(cst.Call(cst.Name("context_mgr"))), ),
                cst.SimpleStatementSuite((cst.Pass(), )),
                leading_lines=(cst.EmptyLine(
                    comment=cst.Comment("# leading comment")), ),
            ),
            "code":
            "# leading comment\nwith context_mgr(): pass\n",
            "parser":
            parse_statement,
            "expected_position":
            CodeRange((2, 0), (2, 24)),
        },
        # Whitespace
        {
            "node":
            cst.With(
                (cst.WithItem(
                    cst.Call(cst.Name("context_mgr")),
                    cst.AsName(
                        cst.Name("ctx"),
                        whitespace_before_as=cst.SimpleWhitespace("  "),
                        whitespace_after_as=cst.SimpleWhitespace("  "),
                    ),
                ), ),
                cst.SimpleStatementSuite((cst.Pass(), )),
                whitespace_after_with=cst.SimpleWhitespace("  "),
                whitespace_before_colon=cst.SimpleWhitespace("  "),
            ),
            "code":
            "with  context_mgr()  as  ctx  : pass\n",
            "parser":
            parse_statement,
            "expected_position":
            CodeRange((1, 0), (1, 36)),
        },
        # Weird spacing rules, that parse differently depending on whether
        # we are using a grammar that included parenthesized with statements.
        # Under the native parser the parentheses belong to the With itself
        # (lpar/rpar); otherwise they are parentheses around the Call.
        {
            "node":
            cst.With(
                (cst.WithItem(
                    cst.Call(
                        cst.Name("context_mgr"),
                        lpar=() if is_native() else (cst.LeftParen(), ),
                        rpar=() if is_native() else (cst.RightParen(), ),
                    )), ),
                cst.SimpleStatementSuite((cst.Pass(), )),
                lpar=(cst.LeftParen()
                      if is_native() else MaybeSentinel.DEFAULT),
                rpar=(cst.RightParen()
                      if is_native() else MaybeSentinel.DEFAULT),
                whitespace_after_with=cst.SimpleWhitespace(""),
            ),
            "code":
            "with(context_mgr()): pass\n",
            "parser":
            parse_statement,
            "expected_position":
            CodeRange((1, 0), (1, 25)),
        },
        # Multi-line parenthesized with.
        # Parsed only when is_native(); otherwise "parser" is None.
        {
            "node":
            cst.With(
                (
                    cst.WithItem(
                        cst.Call(cst.Name("foo")),
                        comma=cst.
                        Comma(whitespace_after=cst.ParenthesizedWhitespace(
                            first_line=cst.TrailingWhitespace(
                                whitespace=cst.SimpleWhitespace(value="", ),
                                comment=None,
                                newline=cst.Newline(value=None, ),
                            ),
                            empty_lines=[],
                            indent=True,
                            last_line=cst.SimpleWhitespace(value="       ", ),
                        )),
                    ),
                    cst.WithItem(cst.Call(cst.Name("bar")), comma=cst.Comma()),
                ),
                cst.SimpleStatementSuite((cst.Pass(), )),
                lpar=cst.LeftParen(whitespace_after=cst.SimpleWhitespace(" ")),
                rpar=cst.RightParen(
                    whitespace_before=cst.SimpleWhitespace(" ")),
            ),
            "code": ("with ( foo(),\n"
                     "       bar(), ): pass\n"),  # noqa
            "parser":
            parse_statement if is_native() else None,
            "expected_position":
            CodeRange((1, 0), (2, 21)),
        },
    ))
    def test_valid(self, **kwargs: Any) -> None:
        self.validate_node(**kwargs)

    # Invalid nodes. "get_node" is a lambda so that construction is deferred
    # to assert_invalid. "expected_re" is the expected error message pattern.
    @data_provider((
        # No WithItem at all.
        {
            "get_node":
            lambda: cst.With((),
                             cst.IndentedBlock((cst.SimpleStatementLine(
                                 (cst.Pass(), )), ))),
            "expected_re":
            "A With statement must have at least one WithItem",
        },
        # Trailing comma on the last item of an unparenthesized With.
        {
            "get_node":
            lambda: cst.With(
                (cst.WithItem(
                    cst.Call(cst.Name("foo")),
                    comma=cst.Comma(whitespace_after=cst.SimpleWhitespace(" ")
                                    ),
                ), ),
                cst.IndentedBlock((cst.SimpleStatementLine((cst.Pass(), )), )),
            ),
            "expected_re":
            "The last WithItem in an unparenthesized With cannot " +
            "have a trailing comma.",
        },
        # No whitespace after `with` and no parentheses.
        {
            "get_node":
            lambda: cst.With(
                (cst.WithItem(cst.Call(cst.Name("context_mgr"))), ),
                cst.SimpleStatementSuite((cst.Pass(), )),
                whitespace_after_with=cst.SimpleWhitespace(""),
            ),
            "expected_re":
            "Must have at least one space after with keyword",
        },
        # Only lpar given; rpar is left at its default.
        {
            "get_node":
            lambda: cst.With(
                (cst.WithItem(cst.Call(cst.Name("context_mgr"))), ),
                cst.SimpleStatementSuite((cst.Pass(), )),
                whitespace_after_with=cst.SimpleWhitespace(""),
                lpar=cst.LeftParen(),
            ),
            "expected_re":
            "Do not mix concrete LeftParen/RightParen with " + "MaybeSentinel",
        },
        # Only rpar given; lpar is left at its default.
        {
            "get_node":
            lambda: cst.With(
                (cst.WithItem(cst.Call(cst.Name("context_mgr"))), ),
                cst.SimpleStatementSuite((cst.Pass(), )),
                whitespace_after_with=cst.SimpleWhitespace(""),
                rpar=cst.RightParen(),
            ),
            "expected_re":
            "Do not mix concrete LeftParen/RightParen with " + "MaybeSentinel",
        },
    ))
    def test_invalid(self, **kwargs: Any) -> None:
        self.assert_invalid(**kwargs)

    # Multiple context managers in a single `with` are expected to parse with
    # the 3.1 grammar and to fail with the 3.0 grammar.
    @data_provider((
        {
            "code": "with a, b: pass",
            "parser": parse_statement_as(python_version="3.1"),
            "expect_success": True,
        },
        {
            "code": "with a, b: pass",
            "parser": parse_statement_as(python_version="3.0"),
            "expect_success": False,
        },
    ))
    def test_versions(self, **kwargs: Any) -> None:
        # Expected-failure cases cannot be checked with the native parser,
        # which has parse errors disabled (per the skip message), so skip them.
        if is_native() and not kwargs.get("expect_success", True):
            self.skipTest("parse errors are disabled for native parser")
        self.assert_parses(**kwargs)

    def test_adding_parens(self) -> None:
        # Codegen check for a parenthesized With whose first comma uses a
        # default ParenthesizedWhitespace. The expected output breaks the line
        # after that comma, and the continuation line is not indented.
        node = cst.With(
            (
                cst.WithItem(
                    cst.Call(cst.Name("foo")),
                    comma=cst.Comma(
                        whitespace_after=cst.ParenthesizedWhitespace(), ),
                ),
                cst.WithItem(cst.Call(cst.Name("bar")), comma=cst.Comma()),
            ),
            cst.SimpleStatementSuite((cst.Pass(), )),
            lpar=cst.LeftParen(whitespace_after=cst.SimpleWhitespace(" ")),
            rpar=cst.RightParen(whitespace_before=cst.SimpleWhitespace(" ")),
        )
        module = cst.Module([])
        self.assertEqual(
            module.code_for_node(node),
            ("with ( foo(),\n"
             "bar(), ): pass\n")  # noqa
        )