Example #1
0
    def test_deep_replace_complex(self) -> None:
        # Swap the innermost of three nested functions for a one-line
        # ``def d(): break`` and compare the regenerated source text.
        source_before = """
            def a():
                def b():
                    def c():
                        pass
        """
        source_after = """
            def a():
                def b():
                    def d(): break
        """

        tree = cst.parse_module(dedent(source_before))

        # Descend one nesting level at a time: a -> b -> c.
        func_a = cst.ensure_type(tree.body[0], cst.FunctionDef)
        block_a = cst.ensure_type(func_a.body, cst.IndentedBlock)
        func_b = cst.ensure_type(block_a.body[0], cst.FunctionDef)
        block_b = cst.ensure_type(func_b.body, cst.IndentedBlock)
        func_c = cst.ensure_type(block_b.body[0], cst.FunctionDef)

        # The replacement uses a suite body, hence the single-line form
        # expected in ``source_after``.
        replacement = cst.FunctionDef(
            name=cst.Name("d"),
            params=cst.Parameters(),
            body=cst.SimpleStatementSuite(body=(cst.Break(),)),
        )
        rewritten = cst.ensure_type(
            tree.deep_replace(func_c, replacement),
            cst.Module,
        )
        self.assertEqual(rewritten.code, dedent(source_after))
Example #2
0
    def test_deep_replace_simple(self) -> None:
        # Replace the only small statement in the module (``pass``) with
        # ``break`` and compare the regenerated source text.
        source_before = """
            pass
        """
        source_after = """
            break
        """

        tree = cst.parse_module(dedent(source_before))
        first_line = cst.ensure_type(tree.body[0], cst.SimpleStatementLine)
        target = first_line.body[0]
        replaced = tree.deep_replace(target, cst.Break())
        rewritten = cst.ensure_type(replaced, cst.Module)
        self.assertEqual(rewritten.code, dedent(source_after))
Example #3
0
    def test_deep_replace_identity(self) -> None:
        # The node handed to deep_replace is the module itself; the result's
        # code is expected to match the substitute module's source.
        source_before = """
            pass
        """
        source_after = """
            break
        """

        tree = cst.parse_module(dedent(source_before))
        # One empty header line plus a ``break`` line, mirroring the leading
        # newline present in ``source_after`` once dedented.
        substitute = cst.Module(
            header=(cst.EmptyLine(),),
            body=(cst.SimpleStatementLine(body=(cst.Break(),)),),
        )
        result = tree.deep_replace(tree, substitute)
        self.assertEqual(result.code, dedent(source_after))
class LeafSmallStatementsTest(CSTNodeTest):
    """Pairs each bare keyword statement node with its source text."""

    @data_provider(
        (
            (cst.Pass(), "pass"),
            (cst.Break(), "break"),
            (cst.Continue(), "continue"),
        )
    )
    def test_valid(self, node: cst.CSTNode, code: str) -> None:
        # Each provider row supplies one (node, code) pair.
        self.validate_node(node, code)
class SmallStatementTest(CSTNodeTest):
    """Data-driven cases for small statements with and without semicolons.

    Every case is a dict of keyword arguments that ``test_valid`` forwards
    unchanged to ``validate_node``.
    """

    @data_provider(
        (
            # --- pass ---
            # pyre-fixme[6]: Incompatible parameter type
            {"node": cst.Pass(), "code": "pass"},
            {"node": cst.Pass(semicolon=cst.Semicolon()), "code": "pass;"},
            {
                "node": cst.Pass(
                    semicolon=cst.Semicolon(
                        whitespace_before=cst.SimpleWhitespace("  "),
                        whitespace_after=cst.SimpleWhitespace("    "),
                    )
                ),
                "code": "pass  ;    ",
                "expected_position": CodeRange((1, 0), (1, 4)),
            },
            # --- continue ---
            {"node": cst.Continue(), "code": "continue"},
            {
                "node": cst.Continue(semicolon=cst.Semicolon()),
                "code": "continue;",
            },
            {
                "node": cst.Continue(
                    semicolon=cst.Semicolon(
                        whitespace_before=cst.SimpleWhitespace("  "),
                        whitespace_after=cst.SimpleWhitespace("    "),
                    )
                ),
                "code": "continue  ;    ",
                "expected_position": CodeRange((1, 0), (1, 8)),
            },
            # --- break ---
            {"node": cst.Break(), "code": "break"},
            {"node": cst.Break(semicolon=cst.Semicolon()), "code": "break;"},
            {
                "node": cst.Break(
                    semicolon=cst.Semicolon(
                        whitespace_before=cst.SimpleWhitespace("  "),
                        whitespace_after=cst.SimpleWhitespace("    "),
                    )
                ),
                "code": "break  ;    ",
                "expected_position": CodeRange((1, 0), (1, 5)),
            },
            # --- expression statement ``x + y`` ---
            {
                "node": cst.Expr(
                    cst.BinaryOperation(
                        cst.Name("x"), cst.Add(), cst.Name("y")
                    )
                ),
                "code": "x + y",
            },
            {
                "node": cst.Expr(
                    cst.BinaryOperation(
                        cst.Name("x"), cst.Add(), cst.Name("y")
                    ),
                    semicolon=cst.Semicolon(),
                ),
                "code": "x + y;",
            },
            {
                "node": cst.Expr(
                    cst.BinaryOperation(
                        cst.Name("x"), cst.Add(), cst.Name("y")
                    ),
                    semicolon=cst.Semicolon(
                        whitespace_before=cst.SimpleWhitespace("  "),
                        whitespace_after=cst.SimpleWhitespace("    "),
                    ),
                ),
                "code": "x + y  ;    ",
                "expected_position": CodeRange((1, 0), (1, 5)),
            },
        )
    )
    def test_valid(self, **kwargs: Any) -> None:
        # The provider dict's keys become validate_node's keyword arguments.
        self.validate_node(**kwargs)
Example #6
0
class SimpleStatementTest(CSTNodeTest):
    """Data-driven cases for statement lines and suites of small statements.

    Every case is a dict of keyword arguments that ``test_valid`` forwards
    unchanged to ``validate_node``.
    """

    @data_provider(
        (
            # One statement on the line.
            # pyre-fixme[6]: Incompatible parameter type
            {
                "node": cst.SimpleStatementLine(body=(cst.Pass(),)),
                "code": "pass\n",
                "parser": parse_statement,
            },
            # Two statements joined by an explicit semicolon.
            {
                "node": cst.SimpleStatementLine(
                    body=(
                        cst.Pass(semicolon=cst.Semicolon()),
                        cst.Continue(),
                    )
                ),
                "code": "pass;continue\n",
                "parser": parse_statement,
            },
            # Two statements, with whitespace on both sides of the semicolon.
            {
                "node": cst.SimpleStatementLine(
                    body=(
                        cst.Pass(
                            semicolon=cst.Semicolon(
                                whitespace_before=cst.SimpleWhitespace(" "),
                                whitespace_after=cst.SimpleWhitespace("  "),
                            )
                        ),
                        cst.Continue(),
                    )
                ),
                "code": "pass ;  continue\n",
                "parser": parse_statement,
            },
            # Three statements on one line.
            {
                "node": cst.SimpleStatementLine(
                    body=(
                        cst.Pass(semicolon=cst.Semicolon()),
                        cst.Continue(semicolon=cst.Semicolon()),
                        cst.Break(),
                    )
                ),
                "code": "pass;continue;break\n",
                "parser": parse_statement,
                "expected_position": CodeRange.create((1, 0), (1, 19)),
            },
            # Three statements whose semicolons are left unspecified.
            {
                "node": cst.SimpleStatementLine(
                    body=(cst.Pass(), cst.Continue(), cst.Break())
                ),
                "code": "pass; continue; break\n",
                # Parsing is not exercised here because the node relies on
                # sentinel values.
                "parser": None,
            },
            # Expression statements: constants and the ellipsis.
            {
                "node": cst.SimpleStatementLine(
                    body=(cst.Expr(cst.Name("None")),)
                ),
                "code": "None\n",
                "parser": parse_statement,
            },
            {
                "node": cst.SimpleStatementLine(
                    body=(cst.Expr(cst.Name("True")),)
                ),
                "code": "True\n",
                "parser": parse_statement,
            },
            {
                "node": cst.SimpleStatementLine(
                    body=(cst.Expr(cst.Name("False")),)
                ),
                "code": "False\n",
                "parser": parse_statement,
            },
            {
                "node": cst.SimpleStatementLine(
                    body=(cst.Expr(cst.Ellipsis()),)
                ),
                "code": "...\n",
                "parser": parse_statement,
            },
            # Numeric literals.
            {
                "node": cst.SimpleStatementLine(
                    body=(cst.Expr(cst.Integer("5")),)
                ),
                "code": "5\n",
                "parser": parse_statement,
            },
            {
                "node": cst.SimpleStatementLine(
                    body=(cst.Expr(cst.Float("5.5")),)
                ),
                "code": "5.5\n",
                "parser": parse_statement,
            },
            {
                "node": cst.SimpleStatementLine(
                    body=(cst.Expr(cst.Imaginary("5j")),)
                ),
                "code": "5j\n",
                "parser": parse_statement,
            },
            # Numeric literals wrapped in one pair of parentheses.
            {
                "node": cst.SimpleStatementLine(
                    body=(
                        cst.Expr(
                            cst.Integer(
                                "5",
                                lpar=(cst.LeftParen(),),
                                rpar=(cst.RightParen(),),
                            )
                        ),
                    )
                ),
                "code": "(5)\n",
                "parser": parse_statement,
                "expected_position": CodeRange.create((1, 0), (1, 3)),
            },
            {
                "node": cst.SimpleStatementLine(
                    body=(
                        cst.Expr(
                            cst.Float(
                                "5.5",
                                lpar=(cst.LeftParen(),),
                                rpar=(cst.RightParen(),),
                            )
                        ),
                    )
                ),
                "code": "(5.5)\n",
                "parser": parse_statement,
            },
            {
                "node": cst.SimpleStatementLine(
                    body=(
                        cst.Expr(
                            cst.Imaginary(
                                "5j",
                                lpar=(cst.LeftParen(),),
                                rpar=(cst.RightParen(),),
                            )
                        ),
                    )
                ),
                "code": "(5j)\n",
                "parser": parse_statement,
            },
            # String literals, alone and concatenated.
            {
                "node": cst.SimpleStatementLine(
                    body=(cst.Expr(cst.SimpleString('"abc"')),)
                ),
                "code": '"abc"\n',
                "parser": parse_statement,
            },
            {
                "node": cst.SimpleStatementLine(
                    body=(
                        cst.Expr(
                            cst.ConcatenatedString(
                                cst.SimpleString('"abc"'),
                                cst.SimpleString('"def"'),
                            )
                        ),
                    )
                ),
                "code": '"abc""def"\n',
                "parser": parse_statement,
            },
            {
                "node": cst.SimpleStatementLine(
                    body=(
                        cst.Expr(
                            cst.ConcatenatedString(
                                left=cst.SimpleString('"abc"'),
                                whitespace_between=cst.SimpleWhitespace(" "),
                                right=cst.ConcatenatedString(
                                    left=cst.SimpleString('"def"'),
                                    whitespace_between=cst.SimpleWhitespace(
                                        " "
                                    ),
                                    right=cst.SimpleString('"ghi"'),
                                ),
                            )
                        ),
                    )
                ),
                "code": '"abc" "def" "ghi"\n',
                "parser": parse_statement,
                "expected_position": CodeRange.create((1, 0), (1, 17)),
            },
            # A parenthesized ellipsis.
            {
                "node": cst.SimpleStatementLine(
                    body=(
                        cst.Expr(
                            cst.Ellipsis(
                                lpar=(cst.LeftParen(),),
                                rpar=(cst.RightParen(),),
                            )
                        ),
                    )
                ),
                "code": "(...)\n",
                "parser": parse_statement,
            },
            # Whitespace inside the parentheses is carried by the paren nodes.
            {
                "node": cst.SimpleStatementLine(
                    body=(
                        cst.Expr(
                            cst.Ellipsis(
                                lpar=(
                                    cst.LeftParen(
                                        whitespace_after=cst.SimpleWhitespace(
                                            " "
                                        )
                                    ),
                                ),
                                rpar=(
                                    cst.RightParen(
                                        whitespace_before=cst.SimpleWhitespace(
                                            " "
                                        )
                                    ),
                                ),
                            )
                        ),
                    )
                ),
                "code": "( ... )\n",
                "parser": parse_statement,
            },
            # Three nested pairs, each with a different amount of padding.
            {
                "node": cst.SimpleStatementLine(
                    body=(
                        cst.Expr(
                            cst.Ellipsis(
                                lpar=(
                                    cst.LeftParen(
                                        whitespace_after=cst.SimpleWhitespace(
                                            " "
                                        )
                                    ),
                                    cst.LeftParen(
                                        whitespace_after=cst.SimpleWhitespace(
                                            "  "
                                        )
                                    ),
                                    cst.LeftParen(
                                        whitespace_after=cst.SimpleWhitespace(
                                            "   "
                                        )
                                    ),
                                ),
                                rpar=(
                                    cst.RightParen(
                                        whitespace_before=cst.SimpleWhitespace(
                                            "   "
                                        )
                                    ),
                                    cst.RightParen(
                                        whitespace_before=cst.SimpleWhitespace(
                                            "  "
                                        )
                                    ),
                                    cst.RightParen(
                                        whitespace_before=cst.SimpleWhitespace(
                                            " "
                                        )
                                    ),
                                ),
                            )
                        ),
                    )
                ),
                "code": "( (  (   ...   )  ) )\n",
                "parser": parse_statement,
                "expected_position": CodeRange.create((1, 0), (1, 21)),
            },
            # Parentheses whose inner whitespace spans several lines and
            # holds a comment.
            {
                "node": cst.SimpleStatementLine(
                    body=(
                        cst.Expr(
                            cst.Ellipsis(
                                lpar=(
                                    cst.LeftParen(
                                        whitespace_after=cst.ParenthesizedWhitespace(
                                            first_line=cst.TrailingWhitespace(),
                                            empty_lines=(
                                                cst.EmptyLine(
                                                    comment=cst.Comment(
                                                        "# Wow, a comment!"
                                                    )
                                                ),
                                            ),
                                            indent=True,
                                            last_line=cst.SimpleWhitespace(
                                                "    "
                                            ),
                                        )
                                    ),
                                ),
                                rpar=(
                                    cst.RightParen(
                                        whitespace_before=cst.ParenthesizedWhitespace(
                                            first_line=cst.TrailingWhitespace(),
                                            empty_lines=(),
                                            indent=True,
                                            last_line=cst.SimpleWhitespace(""),
                                        )
                                    ),
                                ),
                            )
                        ),
                    )
                ),
                "code": "(\n# Wow, a comment!\n    ...\n)\n",
                "parser": parse_statement,
                "expected_position": CodeRange.create((1, 0), (4, 1)),
            },
            # A trailing comment after the statement.
            {
                "node": cst.SimpleStatementLine(
                    body=(cst.Pass(),),
                    trailing_whitespace=cst.TrailingWhitespace(
                        whitespace=cst.SimpleWhitespace("  "),
                        comment=cst.Comment("# trailing comment"),
                    ),
                ),
                "code": "pass  # trailing comment\n",
                "parser": parse_statement,
                "expected_position": CodeRange.create((1, 0), (1, 4)),
            },
            # A comment line ahead of the statement.
            {
                "node": cst.SimpleStatementLine(
                    body=(cst.Pass(),),
                    leading_lines=(
                        cst.EmptyLine(comment=cst.Comment("# comment")),
                    ),
                ),
                "code": "# comment\npass\n",
                "parser": parse_statement,
                "expected_position": CodeRange.create((2, 0), (2, 4)),
            },
            # The same statement wrapped in a four-space DummyIndentedBlock.
            {
                "node": DummyIndentedBlock(
                    "    ",
                    cst.SimpleStatementLine(
                        body=(cst.Pass(),),
                        leading_lines=(
                            cst.EmptyLine(comment=cst.Comment("# comment")),
                        ),
                    ),
                ),
                "code": "    # comment\n    pass\n",
                "expected_position": CodeRange.create((2, 4), (2, 8)),
            },
            # SimpleStatementSuite: default leading whitespace, then empty.
            {
                "node": cst.SimpleStatementSuite(body=(cst.Pass(),)),
                "code": " pass\n",
                "expected_position": CodeRange.create((1, 1), (1, 5)),
            },
            {
                "node": cst.SimpleStatementSuite(
                    body=(cst.Pass(),),
                    leading_whitespace=cst.SimpleWhitespace(""),
                ),
                "code": "pass\n",
                "expected_position": CodeRange.create((1, 0), (1, 4)),
            },
        )
    )
    def test_valid(self, **kwargs: Any) -> None:
        # The provider dict's keys become validate_node's keyword arguments.
        self.validate_node(**kwargs)