예제 #1
0
 def test_has_side_effects(
     self,
     get_import_from_path,
     get_import_path,
     safe_read,
     parse_ast,
     visit,
     init,
     has_side_effects,
     get_import_return,
     safe_read_return,
     safe_read_raise,
     parse_ast_return,
     parse_ast_raise,
     has_side_effects_return,
     has_side_effects_raise,
 ):
     """Check `Refactor._has_side_effects` with every collaborator mocked.

     The import-path resolvers, `safe_read`, `parse_ast` and the analyzer's
     `has_side_effects` are stubbed; each `*_raise` parameter is installed
     as the stub's `side_effect` so error paths can be exercised. The result
     is compared with the parametrized `has_side_effects_return`.
     """
     # Skip the analyzer's real constructor.
     init.return_value = None
     # Both resolvers report the same (possibly missing) module path.
     for resolver in (get_import_from_path, get_import_path):
         resolver.return_value = get_import_return
     stubs = (
         (safe_read, safe_read_return, safe_read_raise),
         (parse_ast, parse_ast_return, parse_ast_raise),
         (has_side_effects, has_side_effects_return, has_side_effects_raise),
     )
     for stub, result, error in stubs:
         stub.return_value = result
         stub.side_effect = error
     empty_import = Import(NodeLocation((1, 0), 1), [])
     with sysu.std_redirect(sysu.STD.ERR):
         verdict = self.session_maker._has_side_effects("", empty_import)
         assert verdict == has_side_effects_return
예제 #2
0
 def _assert_import_equal(
     self, impt_stmnt: str, endlineno: int, used_names: set, expec_impt: str
 ):
     """Assert that transforming `impt_stmnt` produces `expec_impt`.

     The statement is parsed with libcst and visited by an
     `ImportTransformer` built from `used_names` and a location starting at
     line 1, column 0 and ending at line `endlineno`.
     """
     transformer = transform.ImportTransformer(
         used_names, NodeLocation((1, 0), endlineno)
     )
     rebuilt = cst.parse_module(impt_stmnt).visit(transformer)
     assert rebuilt.code == expec_impt
예제 #3
0
 def test_refactor(
     self,
     skip_import,
     _expand_import_star,
     _get_used_names,
     _transform,
     skip_import_return,
     _expand_import_star_return,
     _get_used_names_return,
     _transform_return,
     expand_stars,
     mode,
     original_lines,
     expec_fixed_lines,
 ):
     """Run `Refactor._refactor` with its helpers mocked.

     `mode` names a config flag that is switched on for the run. The result
     must equal the concatenation of `expec_fixed_lines`.
     """
     skip_import.return_value = skip_import_return
     _expand_import_star.return_value = _expand_import_star_return
     _get_used_names.return_value = _get_used_names_return
     _transform.return_value = _transform_return
     self.configs.expand_stars = expand_stars
     setattr(self.configs, mode, True)
     alias = ast.alias(name="x", asname=None)
     node = Import(NodeLocation((1, 0), 1), [alias])
     self.session_maker._import_stats = ImportStats({node}, set())
     # Silence both output streams while refactoring.
     with sysu.std_redirect(sysu.STD.OUT), sysu.std_redirect(sysu.STD.ERR):
         fixed_code = self.session_maker._refactor(original_lines)
         assert fixed_code == "".join(expec_fixed_lines)
예제 #4
0
 def test_rebuild_import_invalid_syntax(self, init):
     """Call `rebuild_import` on a statement that is not valid syntax.

     The test makes no assertion: it passes as long as the call returns
     without raising.
     """
     init.return_value = None
     invalid_stmnt = "@invalid_syntax"
     used_names = {""}
     source_path = Path(__file__)
     location = NodeLocation((1, 0), 0)
     transform.rebuild_import(invalid_stmnt, used_names, source_path,
                              location)
예제 #5
0
 def test_stylize(self, code, endlineno, ismultiline):
     """Check the node returned by `ImportTransformer._stylize`.

     The last alias must carry no explicit trailing comma
     (`MaybeSentinel.DEFAULT`). When the styled node has parentheses and
     the case is multi-line, both parentheses must differ from the input's.
     """
     transformer = transform.ImportTransformer(
         {""}, NodeLocation((1, 0), endlineno)
     )
     import_node = cst.parse_module(code).body[0].body[0]
     styled = transformer._stylize(import_node, import_node.names, False)
     if getattr(styled, "rpar", None):
         if ismultiline:
             assert styled.rpar != import_node.rpar
             assert styled.lpar != import_node.lpar
     last_alias = styled.names[-1]
     assert last_alias.comma == cst.MaybeSentinel.DEFAULT
예제 #6
0
 def test_expand_import_star(
     self,
     expand_import_star,
     expand_import_star_raise,
     name,
     expec_is_star,
 ):
     """Check the `(node, is_star)` pair from `Refactor._expand_import_star`.

     The mocked expander either returns the same node or raises
     `expand_import_star_raise`; `expec_is_star` is the flag expected back.
     """
     alias = ast.alias(name=name)
     import_from = ImportFrom(NodeLocation((1, 0), 1), [alias], "xxx", 0)
     expand_import_star.return_value = import_from
     expand_import_star.side_effect = expand_import_star_raise
     expanded, is_star = self.session_maker._expand_import_star(import_from)
     assert (expanded, is_star) == (import_from, expec_is_star)
예제 #7
0
 def test_refactor_skipping(
     self,
     _get_used_names,
     endline_no,
     original_lines,
     expec_fixed_lines,
 ):
     """Run `_refactor` on a lone `import x` that ends at line `endline_no`.

     No names are reported as used and star expansion is off. The joined
     result must equal `expec_fixed_lines`.
     """
     _get_used_names.return_value = set()
     self.configs.expand_stars = False
     location = NodeLocation((1, 0), endline_no)
     node = Import(location, [ast.alias(name="x", asname=None)])
     self.session_maker._import_stats = ImportStats({node}, set())
     expected = "".join(expec_fixed_lines)
     assert self.session_maker._refactor(original_lines) == expected
예제 #8
0
 def test_should_remove(
     self,
     _has_used,
     _has_side_effects,
     _has_used_return,
     _has_side_effects_return,
     all_,
     name,
     expec_val,
 ):
     """Check the `Refactor._should_remove` verdict for a single alias.

     Usage and side-effect analysis are mocked; `all_` sets the matching
     config flag.
     """
     _has_used.return_value = _has_used_return
     _has_side_effects.return_value = _has_side_effects_return
     self.configs.all_ = all_
     alias = ast.alias(name=name, asname=None)
     import_node = Import(NodeLocation((1, 0), 1), [alias])
     verdict = self.session_maker._should_remove(import_node, alias, False)
     assert verdict == expec_val
예제 #9
0
 def test_rebuild_import(
     self,
     parse_module,
     init,
     import_stmnt,
     col_offset,
     used_names,
     expec_fixed_code,
     expec_fixed_lines,
     expec_err,
 ):
     """Check the lines that `transform.rebuild_import` returns.

     The libcst round trip is mocked so that visiting yields
     `expec_fixed_code`. `expec_err` is the exception the guarded block must
     end with. On success the test raises `sysu.Pass` itself, so the happy
     path also satisfies `pytest.raises`.
     """
     # Plain mock attribute wiring; it cannot raise.
     init.return_value = None
     parse_module.return_value.visit.return_value.code = expec_fixed_code
     with pytest.raises(expec_err):
         location = NodeLocation((1, col_offset), 0)
         fixed_lines = transform.rebuild_import(
             import_stmnt, used_names, Path(__file__), location
         )
         assert fixed_lines == expec_fixed_lines
         raise sysu.Pass()
예제 #10
0
    def test_refactor_init_without_all(
        self,
        _get_used_names,
        _expand_import_star,
        _get_used_names_return,  # we only care about the len.
        is_star_return,
        _is_init_without_all,
        is_undecidable,
    ):
        """Check how `_refactor` reports the undecidable case.

        `_is_init_without_all` is set on the session before refactoring a
        fake `import x`. The reporter's `_undecidable_case` counter must be
        exactly 1 when `is_undecidable` is set, and falsy otherwise.
        """
        alias = ast.alias(name="x", asname=None)
        node = Import(NodeLocation((1, 0), 1), [alias])
        _get_used_names.return_value = _get_used_names_return
        _expand_import_star.return_value = (node, is_star_return)
        session = self.session_maker
        session._import_stats = ImportStats({node}, set())
        session._is_init_without_all = _is_init_without_all

        with sysu.std_redirect(sysu.STD.ERR):
            session._refactor(["import x"])  # Fake code

        undecidable_count = session.reporter._undecidable_case
        if is_undecidable:
            assert undecidable_count == 1
        else:
            assert not undecidable_count
예제 #11
0
class TestRefactor:
    """`Refactor` methods test case."""
    def setup_method(self, method):
        """Build a fresh `Config`, `Report` and `Refactor` for every test."""
        self.configs = config.Config(paths=[Path("")])
        self.reporter = report.Report(self.configs)
        self.session_maker = refactor.Refactor(self.configs, self.reporter)

    @pytest.mark.parametrize(
        "source_lines, expec_lines",
        [
            pytest.param(
                [
                    "try:\n",
                    "    pass\n",
                    "except:\n",
                    "    import y\n",
                ],
                [
                    "try:\n",
                    "    pass\n",
                    "except:\n",
                    "    import y\n",
                ],
                id="useful",
            ),
            pytest.param(
                [
                    "try:\n",
                    "    import x\n",
                    "    pass\n",
                    "except:\n",
                    "    import y\n",
                ],
                [
                    "try:\n",
                    "    import x\n",
                    "except:\n",
                    "    import y\n",
                ],
                id="single-useless",
            ),
            pytest.param(
                [
                    "try:\n",
                    "    import x\n",
                    "    pass\n",
                    "    pass\n",
                    "    pass\n",
                    "    pass\n",
                    "    pass\n",
                    "    pass\n",
                    "    pass\n",
                    "except:\n",
                    "    import y\n",
                ],
                [
                    "try:\n",
                    "    import x\n",
                    "except:\n",
                    "    import y\n",
                ],
                id="multi-useless0",
            ),
            pytest.param(
                [
                    "try:\n",
                    "    pass\n",
                    "    pass\n",
                    "    pass\n",
                    "    pass\n",
                    "    pass\n",
                    "    pass\n",
                    "    pass\n",
                    "    pass\n",
                    "except:\n",
                    "    import y\n",
                ],
                [
                    "try:\n",
                    "    pass\n",
                    "except:\n",
                    "    import y\n",
                ],
                id="multi-useless1",
            ),
            pytest.param(
                [
                    "def foo():\n",
                    "    '''docs'''\n",
                    "    pass\n",
                ],
                [
                    "def foo():\n",
                    "    '''docs'''\n",
                ],
                id="useless with docs",
            ),
            pytest.param(
                [
                    "x = i if i else y\n",
                ],
                [
                    "x = i if i else y\n",
                ],
                id="TypeError",
            ),
        ],
    )
    def test_remove_useless_passes(self, source_lines, expec_lines):
        """Compare `remove_useless_passes` output with `expec_lines`."""
        fixed_lines = refactor.Refactor.remove_useless_passes(source_lines)
        assert fixed_lines == expec_lines

    @pytest.mark.parametrize(
        "safe_read_raise, _code_session_raise",
        [
            pytest.param(None, None, id="without errors"),
            pytest.param(ReadPermissionError(13, "", Path("")),
                         None,
                         id="ReadPermissionError"),
            pytest.param(WritePermissionError(13, "", Path("")),
                         None,
                         id="WritePermissionError"),
            pytest.param(None,
                         UnparsableFile(Path(""), SyntaxError("")),
                         id="UnparsableFile"),
        ],
    )
    @mock.patch(MOCK % "Refactor._output")
    @mock.patch(MOCK % "Refactor._code_session")
    @mock.patch(MOCK % "iou.safe_read")
    def test_session(self, safe_read, _code_session, _output, safe_read_raise,
                     _code_session_raise):
        """Check `session` against read and parse errors.

        None of the parametrized errors may propagate out of `session`, and
        afterwards `_path` must be `Path("")` again.
        """
        safe_read.return_value = ("code...\ncode...\n", "utf-8", "\n")
        safe_read.side_effect = safe_read_raise
        _code_session.return_value = "code...\ncode...\n"
        _code_session.side_effect = _code_session_raise
        self.session_maker.session(Path("modified"))
        assert self.session_maker._path == Path("")

    @pytest.mark.parametrize(
        "skip_file_return, _analyze_return, expec_fixed_code",
        [
            pytest.param(True, None, "original.code", id="file skip"),
            pytest.param(False, None, "original.code", id="no stats"),
            pytest.param(False, ("s", "i"), "fixed.code", id="refactored"),
        ],
    )
    @mock.patch(MOCK % "Refactor._refactor")
    @mock.patch(MOCK % "Refactor._analyze")
    @mock.patch(MOCK % "scan.parse_ast")
    @mock.patch(MOCK % "regexu.skip_file")
    def test_code_session(
        self,
        skip_file,
        parse_ast,
        _analyze,
        _refactor,
        skip_file_return,
        _analyze_return,
        expec_fixed_code,
    ):
        """Check which code `_code_session` returns.

        The original code comes back when the file is skipped or `_analyze`
        returns nothing. Otherwise the result is the (mocked) `_refactor`
        output.
        """
        skip_file.return_value = skip_file_return
        _analyze.return_value = _analyze_return
        _refactor.return_value = "fixed.code"
        with sysu.std_redirect(sysu.STD.ERR):
            fixed_code = self.session_maker._code_session("original.code")
            assert fixed_code == expec_fixed_code

    @pytest.mark.parametrize(
        "fixed_lines, original_lines, mode, expec_output",
        [
            pytest.param(["code..\n"], ["code..\n"],
                         "verbose",
                         "looks good!",
                         id="unchanged"),
            pytest.param(
                ["fixed.code..\n"],
                ["original.code..\n"],
                "check",
                "🚀",
                id="changed-check",
            ),
            pytest.param(
                ["import x\n"],
                ["import x, y\n"],
                "diff",
                "-import x, y\n+import x\n",
                id="changed-diff",
            ),
        ],
    )
    @mock.patch(MOCK % "Refactor.remove_useless_passes")
    def test_output(self, remove_useless_passes, fixed_lines, original_lines,
                    mode, expec_output):
        """Check that `_output` prints `expec_output` to stdout for `mode`."""
        remove_useless_passes.return_value = fixed_lines
        setattr(self.configs, mode, True)
        with sysu.std_redirect(sysu.STD.OUT) as stdout:
            self.session_maker._output(fixed_lines, original_lines, "utf-8",
                                       "\n")
            assert expec_output in stdout.getvalue()

    @mock.patch(MOCK % "Refactor.remove_useless_passes")
    def test_output_write(self, remove_useless_passes):
        """Check that `_output` rewrites the file at `_path` with the fixed lines.

        No mode flag is set on the config for this test.
        """
        fixed_lines, original_lines = ["import x\n"], ["import x, y\n"]
        remove_useless_passes.return_value = fixed_lines
        with sysu.reopenable_temp_file("".join(original_lines)) as tmp_path:
            with open(tmp_path) as tmp:
                self.session_maker._path = tmp_path
                self.session_maker._output(fixed_lines, original_lines,
                                           "utf-8", "\n")
                assert tmp.readlines() == fixed_lines

    @pytest.mark.parametrize(
        "get_stats_raise, expec_val",
        [
            pytest.param(None, ("", ""), id="normal"),
            pytest.param(Exception(""), None, id="error"),
        ],
    )
    @mock.patch(MOCK % "scan.SourceAnalyzer.get_stats")
    def test_analyze(self, get_stats, get_stats_raise, expec_val):
        """Check that `_analyze` returns `get_stats`'s result, or `None` if it raises."""
        get_stats.return_value = ("", "")
        get_stats.side_effect = get_stats_raise
        with sysu.std_redirect(sysu.STD.ERR):
            val = self.session_maker._analyze(ast.parse(""), [""])
            assert val == expec_val

    @pytest.mark.parametrize(
        ("skip_import_return, _expand_import_star_return, _get_used_names_return,"
         "_transform_return, expand_stars, mode, original_lines, expec_fixed_lines"
         ),
        [
            pytest.param(
                True,
                None,
                None,
                ["import x # nopycln: import"],
                False,
                "not-matter",
                ["import x, y # nopycln: import"],
                ["import x, y # nopycln: import"],
                id="nopycln",
            ),
            pytest.param(
                False,
                (None, None),
                None,
                ["import x, y"],
                False,
                "not-matter",
                ["import *"],
                ["import *"],
                id="unexpandable star",
            ),
            pytest.param(
                False,
                (None, True),
                {"x", "y"},
                ["import x, y"],
                False,
                "not-matter",
                ["import *"],
                ["import *"],
                id="star, used, no -x",
            ),
            pytest.param(
                False,
                (
                    Import(NodeLocation(
                        (1, 0), 1), [ast.alias(name="x", asname=None)]),
                    True,
                ),
                {"x", "y"},
                ["import x, y"],
                True,
                "not-matter",
                ["import *"],
                ["import x, y"],
                id="star, used, -x",
            ),
            pytest.param(
                False,
                (
                    Import(NodeLocation(
                        (1, 0), 1), [ast.alias(name="x", asname=None)]),
                    True,
                ),
                set(),
                [""],
                None,
                "not-matter",
                ["import *"],
                [""],
                id="star, not used",
            ),
            pytest.param(
                False,
                (
                    Import(NodeLocation(
                        (1, 0), 1), [ast.alias(name="x", asname=None)]),
                    False,
                ),
                set("x"),
                None,
                False,
                "not-matter",
                ["import x"],
                ["import x"],
                id="all used, no -x",
            ),
            pytest.param(
                False,
                (
                    Import(NodeLocation(
                        (1, 0), 1), [ast.alias(name="x", asname=None)]),
                    True,
                ),
                set("x"),
                ["import x, y"],
                True,
                "not-matter",
                ["import x"],
                ["import x, y"],
                id="all used, -x",
            ),
            pytest.param(
                False,
                (
                    Import(NodeLocation(
                        (1, 0), 1), [ast.alias(name="x", asname=None)]),
                    False,
                ),
                set("x"),
                ["import x"],
                True,
                "check",
                ["import x, y"],
                ["import x, y\n_CHANGED_"],
                id="check",
            ),
        ],
    )
    @mock.patch(MOCK % "Refactor._transform")
    @mock.patch(MOCK % "Refactor._get_used_names")
    @mock.patch(MOCK % "Refactor._expand_import_star")
    @mock.patch(MOCK % "regexu.skip_import")
    def test_refactor(
        self,
        skip_import,
        _expand_import_star,
        _get_used_names,
        _transform,
        skip_import_return,
        _expand_import_star_return,
        _get_used_names_return,
        _transform_return,
        expand_stars,
        mode,
        original_lines,
        expec_fixed_lines,
    ):
        """Run `_refactor` with its helpers mocked.

        `mode` names a config flag that is switched on for the run. The result
        must equal the concatenation of `expec_fixed_lines`.
        """
        skip_import.return_value = skip_import_return
        _expand_import_star.return_value = _expand_import_star_return
        _get_used_names.return_value = _get_used_names_return
        _transform.return_value = _transform_return
        self.configs.expand_stars = expand_stars
        setattr(self.configs, mode, True)
        node = Import(NodeLocation((1, 0), 1),
                      [ast.alias(name="x", asname=None)])
        self.session_maker._import_stats = ImportStats({node}, set())
        with sysu.std_redirect(sysu.STD.OUT), sysu.std_redirect(sysu.STD.ERR):
            fixed_code = self.session_maker._refactor(original_lines)
            assert fixed_code == "".join(expec_fixed_lines)

    @pytest.mark.parametrize(
        "_should_remove_return, node, is_star, expec_names",
        [
            pytest.param(
                False,
                Import(
                    NodeLocation((1, 0), 1),
                    [
                        ast.alias(name="x", asname=None),
                        ast.alias(name="y", asname=None),
                    ],
                ),
                False,
                {"x", "y"},
                id="used",
            ),
            pytest.param(
                True,
                Import(
                    NodeLocation((1, 0), 1),
                    [
                        ast.alias(name="x", asname=None),
                        ast.alias(name="y", asname=None),
                    ],
                ),
                False,
                set(),
                id="not-used",
            ),
            pytest.param(
                True,
                Import(
                    NodeLocation((1, 0), 1),
                    [
                        ast.alias(name="x", asname=None),
                        ast.alias(name="y", asname=None),
                    ],
                ),
                True,
                set(),
                id="not-used, star",
            ),
        ],
    )
    @mock.patch(MOCK % "Refactor._should_remove")
    def test_get_used_names(self, _should_remove, _should_remove_return, node,
                            is_star, expec_names):
        """Check that `_get_used_names` keeps exactly the aliases `_should_remove` rejects."""
        _should_remove.return_value = _should_remove_return
        with sysu.std_redirect(sysu.STD.OUT):
            used_names = self.session_maker._get_used_names(node, is_star)
            assert used_names == expec_names

    @pytest.mark.parametrize(
        ("rebuild_import_return, rebuild_import_raise, "
         "location, original_lines, updated_lines"),
        [
            pytest.param(
                "import x\n",
                None,
                NodeLocation((1, 0), 1),
                ["import x, i\n", "import y\n"],
                ["import x\n", "import y\n"],
                id="normal",
            ),
            pytest.param(
                "import x\n",
                UnsupportedCase(Path(""), NodeLocation((1, 0), 1), ""),
                NodeLocation((1, 0), 1),
                ["import x, i\n", "import y\n"],
                ["import x\n", "import y\n"],
                id="UnsupportedCase",
            ),
            pytest.param(
                "import x\n",
                ParserSyntaxError("", lines=[""], raw_line=1, raw_column=0),
                NodeLocation((1, 0), 1),
                ["import x; import y\n"],
                ["import x; import y\n"],
                id="libcst.ParserSyntaxError",
            ),
        ],
    )
    @mock.patch(MOCK % "Refactor._insert")
    @mock.patch(MOCK % "transform.rebuild_import")
    def test_transform(
        self,
        rebuild_import,
        _insert,
        rebuild_import_return,
        rebuild_import_raise,
        location,
        original_lines,
        updated_lines,
    ):
        """Check that `_transform` yields `updated_lines`.

        This must hold whether `rebuild_import` succeeds or raises one of the
        parametrized errors.
        """
        rebuild_import.side_effect = rebuild_import_raise
        rebuild_import.return_value = rebuild_import_return
        _insert.return_value = updated_lines
        with sysu.std_redirect(sysu.STD.ERR):
            fixed_lines = self.session_maker._transform(
                location, set(), original_lines, updated_lines)
            assert fixed_lines == updated_lines

    @pytest.mark.parametrize(
        "expand_import_star_raise, name, expec_is_star",
        [
            pytest.param(None, "*", True, id="star"),
            pytest.param(None, "!*", False, id="not-star"),
            pytest.param(
                UnexpandableImportStar(Path(""), NodeLocation((1, 0), 1), ""),
                "*",
                None,
                id="unexpandable star",
            ),
        ],
    )
    @mock.patch(MOCK % "scan.expand_import_star")
    def test_expand_import_star(
        self,
        expand_import_star,
        expand_import_star_raise,
        name,
        expec_is_star,
    ):
        """Check the `(node, is_star)` pair from `_expand_import_star`.

        `is_star` must be `None` when the expansion fails.
        """
        node = ImportFrom(NodeLocation((1, 0), 1), [ast.alias(name=name)],
                          "xxx", 0)
        expand_import_star.return_value = node
        expand_import_star.side_effect = expand_import_star_raise
        enode, is_star = self.session_maker._expand_import_star(node)
        assert (enode, is_star) == (node, expec_is_star)

    @pytest.mark.parametrize(
        "_has_used_return, name, asname, expec_val",
        [
            pytest.param(True, "os.path.join", None, True, id="used"),
            pytest.param(False, "os.path.join", None, False, id="unused"),
            pytest.param(None, "os.path.join", "asname", False, id="as alias"),
            pytest.param(None, "os", None, False, id="single name"),
        ],
    )
    @mock.patch(MOCK % "Refactor._has_used")
    def test_is_partially_used(self, _has_used, _has_used_return, name, asname,
                               expec_val):
        """Check that `_is_partially_used` is true only for a used, dotted, un-aliased name."""
        _has_used.return_value = _has_used_return
        alias = ast.alias(name=name, asname=asname)
        val = self.session_maker._is_partially_used(alias, False)
        assert val == expec_val

    @pytest.mark.parametrize(
        "_has_used_return, _has_side_effects_return, all_, name, expec_val",
        [
            pytest.param(True, None, None, "not-matter", False, id="used"),
            pytest.param(
                False, None, None, "this", False, id="known side effects"),
            pytest.param(
                False, None, True, "not-matter", True, id="--all option"),
            pytest.param(False, None, False, "os", True, id="standard lib"),
            pytest.param(
                False,
                HasSideEffects.NO,
                False,
                "not-matter",
                True,
                id="no side-effects",
            ),
            pytest.param(
                False,
                HasSideEffects.YES,
                False,
                "not-matter",
                False,
                id="no all",
            ),
        ],
    )
    @mock.patch(MOCK % "Refactor._has_side_effects")
    @mock.patch(MOCK % "Refactor._has_used")
    def test_should_remove(
        self,
        _has_used,
        _has_side_effects,
        _has_used_return,
        _has_side_effects_return,
        all_,
        name,
        expec_val,
    ):
        """Check the `_should_remove` verdict for a single `import <name>` alias."""
        _has_used.return_value = _has_used_return
        _has_side_effects.return_value = _has_side_effects_return
        self.configs.all_ = all_
        alias = ast.alias(name=name, asname=None)
        node = Import(NodeLocation((1, 0), 1), [alias])
        val = self.session_maker._should_remove(node, alias, False)
        assert val == expec_val

    @pytest.mark.parametrize(
        "name, is_star, expec_val",
        [
            pytest.param("x", False, True, id="used name"),
            pytest.param("y", False, False, id="not-used name"),
            pytest.param("x.i", False, True, id="used attr"),
            pytest.param("x.j", False, False, id="not-used attr"),
            pytest.param("x", True, True, id="used name, star"),
            pytest.param("__future__", True, False, id="skip name, star"),
        ],
    )
    def test_has_used(self, name, is_star, expec_val):
        """Check `_has_used` against source stats that record name `x` and attribute `i`."""
        self.session_maker._source_stats = SourceStats({"x"}, {"i"},
                                                       "__future__")
        val = self.session_maker._has_used(name, is_star)
        assert val == expec_val

    @pytest.mark.parametrize(
        ("get_import_return, safe_read_return, safe_read_raise,"
         "parse_ast_return, parse_ast_raise,"
         "has_side_effects_return, has_side_effects_raise,"),
        [
            pytest.param(
                None,
                ("", "", ""),
                None,
                None,
                None,
                HasSideEffects.NOT_MODULE,
                None,
                id="no module path",
            ),
            pytest.param(
                Path(""),
                ("", "", ""),
                ReadPermissionError(13, "", Path("")),
                None,
                None,
                HasSideEffects.NOT_KNOWN,
                None,
                id="no read permission",
            ),
            pytest.param(
                Path(""),
                ("", "", ""),
                None,
                ast.Module(),
                UnparsableFile(Path(""), SyntaxError("")),
                HasSideEffects.NOT_KNOWN,
                None,
                id="Unparsable File",
            ),
            pytest.param(
                Path(""),
                ("", "", ""),
                None,
                ast.Module(),
                None,
                HasSideEffects.NOT_KNOWN,
                Exception("err"),
                id="generic err",
            ),
            pytest.param(
                Path(""),
                ("", "", ""),
                None,
                ast.Module(),
                None,
                HasSideEffects.YES,
                None,
                id="success",
            ),
        ],
    )
    @mock.patch(MOCK % "scan.SideEffectsAnalyzer.has_side_effects")
    @mock.patch(MOCK % "scan.SideEffectsAnalyzer.__init__")
    @mock.patch(MOCK % "scan.SideEffectsAnalyzer.visit")
    @mock.patch(MOCK % "scan.parse_ast")
    @mock.patch(MOCK % "iou.safe_read")
    @mock.patch(MOCK % "pathu.get_import_path")
    @mock.patch(MOCK % "pathu.get_import_from_path")
    def test_has_side_effects(
        self,
        get_import_from_path,
        get_import_path,
        safe_read,
        parse_ast,
        visit,
        init,
        has_side_effects,
        get_import_return,
        safe_read_return,
        safe_read_raise,
        parse_ast_return,
        parse_ast_raise,
        has_side_effects_return,
        has_side_effects_raise,
    ):
        """Check `_has_side_effects` with every collaborator mocked.

        The expected value is the parametrized `has_side_effects_return`.
        """
        init.return_value = None
        get_import_from_path.return_value = get_import_return
        get_import_path.return_value = get_import_return
        safe_read.return_value = safe_read_return
        safe_read.side_effect = safe_read_raise
        parse_ast.return_value = parse_ast_return
        parse_ast.side_effect = parse_ast_raise
        has_side_effects.return_value = has_side_effects_return
        has_side_effects.side_effect = has_side_effects_raise
        with sysu.std_redirect(sysu.STD.ERR):
            node = Import(NodeLocation((1, 0), 1), [])
            val = self.session_maker._has_side_effects("", node)
            assert val == has_side_effects_return

    @pytest.mark.parametrize(
        "rebuilt_import, updated_lines, location, expec_updated_lines",
        [
            pytest.param(
                ["import x\n"],
                [
                    "import z\n",
                    "import x, i\n",
                    "import y\n",
                ],
                NodeLocation((2, 0), 2),
                [
                    "import z\n",
                    "import x\n",
                    "import y\n",
                ],
                id="single:replace",
            ),
            pytest.param(
                "",
                [
                    "import z\n",
                    "import x, i\n",
                    "import y\n",
                ],
                NodeLocation((2, 0), 2),
                [
                    "import z\n",
                    "",
                    "import y\n",
                ],
                id="single:remove",
            ),
            pytest.param(
                [
                    "from xxx import (\n",
                    "    x\n",
                    ")\n",
                ],
                [
                    "import z\n",
                    "from xxx import (\n",
                    "    x, y\n",
                    ")\n",
                    "import y\n",
                ],
                NodeLocation((2, 0), 4),
                [
                    "import z\n",
                    "from xxx import (\n",
                    "    x\n",
                    ")\n",
                    "import y\n",
                ],
                id="multi:replace",
            ),
            pytest.param(
                [
                    "from xxx import (\n",
                    "    x\n",
                    ")\n",
                ],
                [
                    "import z\n",
                    "from xxx import (\n",
                    "    x,\n",
                    "    y\n",
                    ")\n",
                    "import y\n",
                ],
                NodeLocation((2, 0), 5),
                [
                    "import z\n",
                    "from xxx import (\n",
                    "    x\n",
                    ")\n",
                    "",
                    "import y\n",
                ],
                id="multi:remove:part",
            ),
            pytest.param(
                [""],
                [
                    "import z\n",
                    "from xxx import (\n",
                    "    x,\n",
                    "    y\n",
                    ")\n",
                    "import y\n",
                ],
                NodeLocation((2, 0), 5),
                [
                    "import z\n",
                    "",
                    "",
                    "",
                    "",
                    "import y\n",
                ],
                id="multi:remove:all",
            ),
            pytest.param(
                [
                    "from xxx import (\n",
                    "    x,\n",
                    "    y\n",
                    ")\n",
                ],
                [
                    "import z\n",
                    "from xxx import *\n",
                    "import y\n",
                ],
                NodeLocation((2, 0), 2),
                [
                    "import z\n",
                    "from xxx import (\n    x,\n    y\n)\n",
                    "import y\n",
                ],
                id="multi:add",
            ),
        ],
    )
    def test_insert(self, rebuilt_import, updated_lines, location,
                    expec_updated_lines):
        """Check how `Refactor._insert` splices `rebuilt_import` in at `location`.

        Lines the old statement occupied beyond the rebuilt one become "".
        """
        fixed = refactor.Refactor._insert(rebuilt_import, updated_lines,
                                          location)
        assert fixed == expec_updated_lines
예제 #12
0
 def test_get_location(self):
     """Check that `Report.get_location` renders `<path>:<start line>:<start column>`."""
     location = NodeLocation((2, 0), 4)
     rendered = report.Report.get_location(Path("file_path"), location)
     assert rendered == "file_path:2:0"
예제 #13
0
 def setup_method(self, method):
     """Create a fresh config/report pair and a sample `ImportFrom` node.

     The node imports alias `x` from module `xx` with level 1.
     """
     self.configs = config.Config(paths=[Path("")])
     self.reporter = report.Report(self.configs)
     # Shared fixtures that several tests reuse.
     alias = ast.alias(name="x", asname=None)
     self.alias = alias
     self.impt = ImportFrom(NodeLocation((1, 0), 1), [alias], "xx", 1)
예제 #14
0
class TestImportTransformer:
    """`ImportTransformer` methods test case."""
    def _assert_import_equal(self, impt_stmnt: str, endlineno: int,
                             used_names: set, expec_impt: str):
        """Run `ImportTransformer` over `impt_stmnt` and compare the result.

        The statement is located at line 1, column 0 and ends on
        `endlineno`; the regenerated source must equal `expec_impt`.
        """
        transformer = transform.ImportTransformer(
            used_names, NodeLocation((1, 0), endlineno))
        refactored = cst.parse_module(impt_stmnt).visit(transformer)
        assert refactored.code == expec_impt

    @pytest.mark.parametrize(
        "used_names, location, expec_err",
        [
            pytest.param(
                {"x", "y", "z"},
                NodeLocation((1, 4), 0),
                sysu.Pass,
                id="pass used_names",
            ),
            pytest.param(
                set(),
                None,
                ValueError,
                id="pass no used_names",
            ),
        ],
    )
    def test_init(self, used_names, location, expec_err):
        """Construction succeeds with used names and raises `ValueError`
        (for the parametrized case) when none are given.

        A successful construction is signalled by raising the `sysu.Pass`
        sentinel, so both outcomes share one `pytest.raises` check.
        """
        expectation = pytest.raises(expec_err)
        with expectation:
            transform.ImportTransformer(used_names, location)
            raise sysu.Pass()

    @pytest.mark.parametrize(
        "impt_stmnt, endlineno, used_names, expec_impt",
        [
            pytest.param(
                "import x, y, z",
                1,
                ("x", "y", "z"),
                "import x, y, z",
                id="single, no-unused",
            ),
            pytest.param(
                "import x, y, z",
                1,
                ("x", "z"),
                "import x, z",
                id="single, some-unused",
            ),
            pytest.param(
                "import xx as x, yy as y, zz as z",
                1,
                ("xx", "yy", "zz"),
                "import xx as x, yy as y, zz as z",
                id="single, no-unused, as",
            ),
            pytest.param(
                "import xx as x, yy as y, zz as z",
                1,
                ("xx", "zz"),
                "import xx as x, zz as z",
                id="single, some-unused, as",
            ),
            pytest.param(
                ("import \\\n"
                 "    x, y, z"),
                2,
                ("x", "y", "z"),
                ("import \\\n"
                 "    x, y, z"),
                id="multi, no-unused",
            ),
            pytest.param(
                ("import \\\n"
                 "    x, y, z"),
                2,
                ("x", "z"),
                ("import \\\n"
                 "    x, z"),
                id="multi, some-unused",
            ),
            pytest.param(
                ("import \\\n"
                 "    xx as x, yy as y, zz as z"),
                2,
                ("xx", "yy", "zz"),
                ("import \\\n"
                 "    xx as x, yy as y, zz as z"),
                id="multi, no-unused, as",
            ),
            pytest.param(
                ("import \\\n"
                 "    xx as x, yy as y, zz as z"),
                2,
                ("xx", "zz"),
                ("import \\\n"
                 "    xx as x, zz as z"),
                id="multi, some-unused, as",
            ),
        ],
    )
    def test_leave_Import(self, impt_stmnt, endlineno, used_names, expec_impt):
        """Plain `import a, b` statements keep only the names that are used.

        `leave_Import` delegates to `refactor_import`, so a failure here most
        likely points at `refactor_import`.
        """
        self._assert_import_equal(
            impt_stmnt=impt_stmnt,
            endlineno=endlineno,
            used_names=used_names,
            expec_impt=expec_impt,
        )

    @pytest.mark.parametrize(
        "impt_stmnt, endlineno, used_names, expec_impt",
        [
            pytest.param(
                "from x import *",
                1,
                ("x", "y", "z"),
                "from x import x, y, z",
                id="single, star",
            ),
            pytest.param(
                "from x import *",
                1,
                ("x", "y", "z", "i", "j"),
                ("from x import (\n"
                 "    x,\n"
                 "    y,\n"
                 "    z,\n"
                 "    i,\n"
                 "    j\n"
                 ")"),
                id="multi, star",
            ),
            pytest.param(
                "from x import *",
                1,
                ("x", "y", "z", "i", "j.j"),
                ("from x import (\n"
                 "    x,\n"
                 "    y,\n"
                 "    z,\n"
                 "    i\n"
                 ")"),
                id="multi, star, names collision",
            ),
            pytest.param(
                "from xxx import x, y, z",
                1,
                ("x", "y", "z"),
                "from xxx import x, y, z",
                id="single, no-unused",
            ),
            pytest.param(
                "from xxx import x, y, z",
                1,
                ("x", "z"),
                "from xxx import x, z",
                id="single, some-unused",
            ),
            pytest.param(
                "from xxx import (x, y, z)",
                1,
                ("x", "y", "z"),
                "from xxx import (x, y, z)",
                id="single, parentheses, no-end-comma, no-unused",
            ),
            pytest.param(
                "from xxx import (x, y, z)",
                1,
                ("x", "z"),
                "from xxx import (x, z)",
                id="single, parentheses, no-end-comma, some-unused",
            ),
            # NOTE: this case previously used "(x, y, z)" (no trailing
            # comma), duplicating the "no-end-comma, no-unused" case above;
            # it now really exercises the trailing-comma path.
            pytest.param(
                "from xxx import (x, y, z,)",
                1,
                ("x", "y", "z"),
                "from xxx import (x, y, z,)",
                id="single, parentheses, end-comma, no-unused",
            ),
            pytest.param(
                "from xxx import (x, y, z,)",
                1,
                ("x", "z"),
                "from xxx import (x, z,)",
                id="single, parentheses, end-comma, some-unused",
            ),
            pytest.param(
                "from xxx import xx as x, yy as y, zz as z",
                1,
                ("xx", "yy", "zz"),
                "from xxx import xx as x, yy as y, zz as z",
                id="single, no-unused, as",
            ),
            pytest.param(
                "from xxx import xx as x, yy as y, zz as z",
                1,
                ("xx", "zz"),
                "from xxx import xx as x, zz as z",
                id="single, some-unused, as",
            ),
            pytest.param(
                ("from xxx import x,\\\n"
                 "    y, \\\n"
                 "    z"),
                3,
                ("x", "y", "z"),
                ("from xxx import x,\\\n"
                 "    y, \\\n"
                 "    z"),
                id="multi, slash, no-unused",
            ),
            pytest.param(
                ("from xxx import x,\\\n"
                 "    y, \\\n"
                 "    z"),
                3,
                ("x", "z"),
                ("from xxx import x,\\\n"
                 "    z"),
                id="multi, slash, some-unused",
            ),
            pytest.param(
                ("from xxx import x,\\\n"
                 "    y, z"),
                2,
                ("x", "z"),
                ("from xxx import x,\\\n"
                 "    z"),
                id="multi, slash, double, some-unused",
            ),
            pytest.param(
                ("from xxx import xx as x,\\\n"
                 "    yy as y, \\\n"
                 "    zz as z"),
                3,
                ("xx", "yy", "zz"),
                ("from xxx import xx as x,\\\n"
                 "    yy as y, \\\n"
                 "    zz as z"),
                id="multi, slash, no-unused, as",
            ),
            pytest.param(
                ("from xxx import xx as x,\\\n"
                 "    yy as y, \\\n"
                 "    zz as z"),
                3,
                ("xx", "zz"),
                ("from xxx import xx as x,\\\n"
                 "    zz as z"),
                id="multi, slash, some-unused, as",
            ),
            pytest.param(
                ("from xxx import xx as x,\\\n"
                 "    yy as y, zz as z"),
                2,
                ("xx", "zz"),
                ("from xxx import xx as x,\\\n"
                 "    zz as z"),
                id="multi, slash, double, some-unused, as",
            ),
            pytest.param(
                ("from xxx import (\n"
                 "    x,\n"
                 "    y,\n"
                 "    z\n"
                 ")"),
                5,
                ("x", "y", "z"),
                ("from xxx import (\n"
                 "    x,\n"
                 "    y,\n"
                 "    z\n"
                 ")"),
                id="multi, parentheses, no-end-comma, no-unused",
            ),
            pytest.param(
                ("from xxx import (\n"
                 "    x,\n"
                 "    y,\n"
                 "    z,\n"
                 ")"),
                5,
                ("x", "y", "z"),
                ("from xxx import (\n"
                 "    x,\n"
                 "    y,\n"
                 "    z,\n"
                 ")"),
                id="multi, parentheses, end-comma, no-unused",
            ),
            pytest.param(
                ("from xxx import (\n"
                 "    x,\n"
                 "    y,\n"
                 "    z\n"
                 ")"),
                5,
                ("x", "z"),
                ("from xxx import (\n"
                 "    x,\n"
                 "    z\n"
                 ")"),
                id="multi, parentheses, no-end-comma, some-unused",
            ),
            pytest.param(
                ("from xxx import (\n"
                 "    x,\n"
                 "    y,\n"
                 "    z,\n"
                 ")"),
                5,
                ("x", "z"),
                ("from xxx import (\n"
                 "    x,\n"
                 "    z,\n"
                 ")"),
                id="multi, parentheses, end-comma, some-unused",
            ),
            pytest.param(
                ("from xxx import (\n"
                 "    x,\n"
                 "    y, z\n"
                 ")"),
                4,
                ("x", "z"),
                ("from xxx import (\n"
                 "    x,\n"
                 "    z\n"
                 ")"),
                id="multi, parentheses, double, no-end-comma, some-unused",
            ),
            pytest.param(
                ("from xxx import (\n"
                 "    x,\n"
                 "    y, z,\n"
                 ")"),
                4,
                ("x", "z"),
                ("from xxx import (\n"
                 "    x,\n"
                 "    z,\n"
                 ")"),
                id="multi, parentheses, double, end-comma, some-unused",
            ),
            pytest.param(
                ("from xxx import (\n"
                 "    xx as x,\n"
                 "    yy as y,\n"
                 "    zz as z\n"
                 ")"),
                5,
                ("xx", "yy", "zz"),
                ("from xxx import (\n"
                 "    xx as x,\n"
                 "    yy as y,\n"
                 "    zz as z\n"
                 ")"),
                id="multi, parentheses, no-end-comma, no-unused, as",
            ),
            pytest.param(
                ("from xxx import (\n"
                 "    xx as x,\n"
                 "    yy as y,\n"
                 "    zz as z,\n"
                 ")"),
                5,
                ("xx", "yy", "zz"),
                ("from xxx import (\n"
                 "    xx as x,\n"
                 "    yy as y,\n"
                 "    zz as z,\n"
                 ")"),
                id="multi, parentheses, end-comma, no-unused, as",
            ),
            pytest.param(
                ("from xxx import (\n"
                 "    xx as x,\n"
                 "    yy as y,\n"
                 "    zz as z\n"
                 ")"),
                5,
                ("xx", "zz"),
                ("from xxx import (\n"
                 "    xx as x,\n"
                 "    zz as z\n"
                 ")"),
                id="multi, parentheses, no-end-comma, some-unused, as",
            ),
            pytest.param(
                ("from xxx import (\n"
                 "    xx as x,\n"
                 "    yy as y,\n"
                 "    zz as z,\n"
                 ")"),
                5,
                ("xx", "zz"),
                ("from xxx import (\n"
                 "    xx as x,\n"
                 "    zz as z,\n"
                 ")"),
                id="multi, parentheses, end-comma, some-unused, as",
            ),
            pytest.param(
                ("from xxx import (\n"
                 "    xx as x,\n"
                 "    yy as y, zz as z\n"
                 ")"),
                4,
                ("xx", "zz"),
                ("from xxx import (\n"
                 "    xx as x,\n"
                 "    zz as z\n"
                 ")"),
                id="multi, parentheses, double, no-end-comma, some-unused, as",
            ),
            pytest.param(
                ("from xxx import (\n"
                 "    xx as x,\n"
                 "    yy as y, zz as z,\n"
                 ")"),
                4,
                ("xx", "zz"),
                ("from xxx import (\n"
                 "    xx as x,\n"
                 "    zz as z,\n"
                 ")"),
                id="multi, parentheses, double, end-comma, some-unused, as",
            ),
            pytest.param(
                ("from xxx import (x,\n"
                 "    y,\n"
                 "    z,)"),
                3,
                ("x", "z"),
                ("from xxx import (x,\n"
                 "    z,)"),
                id="multi, parentheses(no-new-lines), end-comma, some-unused",
            ),
        ],
    )
    def test_leave_ImportFrom(self, impt_stmnt, endlineno, used_names,
                              expec_impt):
        """`from m import ...` statements keep only the names that are used.

        Star imports are expanded to the used names (dotted names such as
        "j.j" are excluded); all other forms have their unused aliases
        removed while preserving slashes, parentheses and trailing commas.
        """
        #: `leave_ImportFrom` returns:
        #:      - `refactor_import_star` when * import passed.
        #:      - otherwise `refactor_import`.
        #: Debug `refactor_import_star` or `refactor_import`.
        self._assert_import_equal(impt_stmnt, endlineno, used_names,
                                  expec_impt)

    @pytest.mark.parametrize("name", ["x", "x.y", "x.y.z"])
    @mock.patch(MOCK % "ImportTransformer.__init__")
    def test_get_alias_name(self, init, name):
        """`_get_alias_name` turns a (possibly dotted) CST name back into text."""
        # `__init__` is mocked so the transformer can be built without args.
        init.return_value = None

        def get_name_node(name: str) -> Union[cst.Name, cst.Attribute]:
            # Inverse of `_get_alias_name`: build a left-nested chain, e.g.
            # "x.y.z" -> Attribute(Attribute(Name(x), Name(y)), Name(z)).
            head, *tail = name.split(".")
            built: Union[cst.Name, cst.Attribute] = cst.Name(head)
            for part in tail:
                built = cst.Attribute(value=built, attr=cst.Name(part))
            return built

        name_node = get_name_node(name)
        transformer = transform.ImportTransformer(None, None)
        assert transformer._get_alias_name(name_node) == name

    @pytest.mark.parametrize("indent", [" " * n for n in (0, 2, 4, 8)])
    @mock.patch(MOCK % "ImportTransformer.__init__")
    def test_multiline_parenthesized_whitespace(self, init, indent):
        """The generated whitespace ends with a line indented by `indent`."""
        # `__init__` is mocked so the transformer can be built without args.
        init.return_value = None
        whitespace = transform.ImportTransformer(
            None, None)._multiline_parenthesized_whitespace(indent)
        assert whitespace.last_line.value == indent

    @pytest.mark.parametrize("indent", [" " * n for n in (0, 2, 4, 8)])
    @mock.patch(MOCK % "ImportTransformer.__init__")
    def test_multiline_alias(self, init, indent):
        """The comma after an alias is followed by a line indented 4 spaces
        deeper than the transformer's `_indentation`."""
        init.return_value = None
        transformer = transform.ImportTransformer(None, None)
        transformer._indentation = indent
        source_alias = cst.ImportAlias(name=cst.Name("x"))
        comma = transformer._multiline_alias(source_alias).comma
        expected = indent + " " * 4
        assert comma.whitespace_after.last_line.value == expected

    @pytest.mark.parametrize("indent", [" " * n for n in (0, 2, 4, 8)])
    @mock.patch(MOCK % "ImportTransformer.__init__")
    def test_multiline_lpar(self, init, indent):
        """The whitespace after `(` ends with a line indented 4 spaces deeper
        than the transformer's `_indentation`."""
        init.return_value = None
        transformer = transform.ImportTransformer(None, None)
        transformer._indentation = indent
        whitespace = transformer._multiline_lpar().whitespace_after
        assert whitespace.last_line.value == indent + " " * 4

    @pytest.mark.parametrize("indent", [" " * n for n in (0, 2, 4, 8)])
    @mock.patch(MOCK % "ImportTransformer.__init__")
    def test_multiline_rpar(self, init, indent):
        """The whitespace before `)` ends with a line indented exactly by the
        transformer's `_indentation`."""
        init.return_value = None
        transformer = transform.ImportTransformer(None, None)
        transformer._indentation = indent
        whitespace = transformer._multiline_rpar().whitespace_before
        assert whitespace.last_line.value == indent

    @pytest.mark.parametrize(
        "code, endlineno, ismultiline",
        [
            pytest.param("import x, y, z", 1, False, id="single, import"),
            pytest.param(
                ("import \\\n"
                 "    x, y, z"), 2, True, id="multi, import"),
            pytest.param("from xxx import (x, y, z,)",
                         1,
                         False,
                         id="single, from, parentheses"),
            pytest.param(
                ("from xxx import (\n"
                 "    x,\n"
                 "    y,\n"
                 "    z,\n"
                 ")"),
                5,
                True,
                id="multi, from, parentheses",
            ),
            pytest.param(
                ("from xxx import (\n"
                 "    x,\n"
                 "    y,\n"
                 "    z,)"),
                4,
                True,
                id="multi, from, parentheses, distorted-end",
            ),
            pytest.param(
                ("from xxx import x,\\\n"
                 "    y,\\\n"
                 "    z"),
                3,
                True,
                id="multi, from, slash",
            ),
        ],
    )
    def test_stylize(self, code, endlineno, ismultiline):
        """Check the shape of the node returned by `_stylize`.

        For multi-line cases whose restyled node has a truthy `rpar`, both
        `lpar` and `rpar` must differ from the original node's. In every case
        the last alias must end up with no explicit comma
        (`cst.MaybeSentinel.DEFAULT`).
        """
        location = NodeLocation((1, 0), endlineno)
        # First small statement on the first line, i.e. the import node.
        node = cst.parse_module(code).body[0].body[0]
        transformer = transform.ImportTransformer({""}, location)
        # NOTE(review): the meaning of the third argument (`False`) is not
        # visible here; presumably an "is star import" flag — confirm.
        new_node = transformer._stylize(node, node.names, False)
        # `getattr` guards nodes that may lack an `rpar` attribute
        # (presumably plain `Import`); slash-continued imports have no parens.
        if getattr(new_node, "rpar", None) and ismultiline:
            assert new_node.rpar != node.rpar
            assert new_node.lpar != node.lpar
        assert new_node.names[-1].comma == cst.MaybeSentinel.DEFAULT