예제 #1
0
    def test_compile_to_ast(self):
        """Compiling with ``PyCF_ONLY_AST`` yields an ``ast.Module`` that ``compile`` accepts."""
        import ast

        tree = Source("x = 4").compile(flag=ast.PyCF_ONLY_AST)
        assert isinstance(tree, ast.Module)
        # The returned tree must itself be compilable by the builtin.
        compile(tree, "<filename>", "exec")
예제 #2
0
class TestSourceParsingAndCompiling:
    """Tests for statement lookup and compilation of ``Source`` objects."""

    def setup_class(self) -> None:
        """Build the shared fixture: a function with a three-line ``assert``."""
        self.source = Source("""\
            def f(x):
                assert (x ==
                        3 +
                        4)
        """).strip()

    def test_compile(self) -> None:
        """Executing compiled code binds its names in the given namespace."""
        co = _pytest._code.compile("x=3")
        d = {}  # type: Dict[str, Any]
        exec(co, d)
        assert d["x"] == 3

    def test_compile_and_getsource_simple(self) -> None:
        """A ``Source`` built from a compiled code object round-trips its text."""
        co = _pytest._code.compile("x=3")
        exec(co)
        source = _pytest._code.Source(co)
        assert str(source) == "x=3"

    def test_compile_and_getsource_through_same_function(self) -> None:
        """Two code objects compiled from the same call site keep distinct sources."""
        def gensource(source):
            return _pytest._code.compile(source)

        co1 = gensource("""
            def f():
                raise KeyError()
        """)
        co2 = gensource("""
            def f():
                raise ValueError()
        """)
        source1 = inspect.getsource(co1)
        assert "KeyError" in source1
        source2 = inspect.getsource(co2)
        assert "ValueError" in source2

    def test_getstatement(self) -> None:
        """Lines 1-3 of the fixture all resolve to the same ``assert`` statement."""
        expected = str(self.source[1:])
        for lineno in range(1, 4):
            statement = self.source.getstatement(lineno)
            assert str(statement) == expected

    def test_getstatementrange_triple_quoted(self) -> None:
        """Both lines of a call spanning a triple-quoted string yield the whole source."""
        source = Source("""hello('''
        ''')""")
        s = source.getstatement(0)
        assert s == str(source)
        s = source.getstatement(1)
        assert s == str(source)

    def test_getstatementrange_within_constructs(self) -> None:
        """Statement ranges inside nested ``try`` blocks cover single lines."""
        source = Source("""\
            try:
                try:
                    raise ValueError
                except SomeThing:
                    pass
            finally:
                42
        """)
        assert len(source) == 7
        # check all lineno's that could occur in a traceback
        # NOTE(review): the disabled assertions below target the ``try:`` /
        # ``finally:`` header lines; presumably those ranges are not
        # supported -- confirm before re-enabling or deleting them.
        # assert source.getstatementrange(0) == (0, 7)
        # assert source.getstatementrange(1) == (1, 5)
        assert source.getstatementrange(2) == (2, 3)
        assert source.getstatementrange(3) == (3, 4)
        assert source.getstatementrange(4) == (4, 5)
        # assert source.getstatementrange(5) == (0, 7)
        assert source.getstatementrange(6) == (6, 7)

    def test_getstatementrange_bug(self) -> None:
        """A continuation line inside ``try`` maps to its parenthesized assignment."""
        source = Source("""\
            try:
                x = (
                   y +
                   z)
            except:
                pass
        """)
        assert len(source) == 6
        assert source.getstatementrange(2) == (1, 4)

    def test_getstatementrange_bug2(self) -> None:
        """A line deep inside a nested expression maps to the whole ``assert``."""
        source = Source("""\
            assert (
                33
                ==
                [
                  X(3,
                      b=1, c=2
                   ),
                ]
              )
        """)
        assert len(source) == 9
        assert source.getstatementrange(5) == (0, 9)

    def test_getstatementrange_ast_issue58(self) -> None:
        """Each line of a ``for`` header split over two lines is its own statement."""
        source = Source("""\

            def test_some():
                for a in [a for a in
                    CAUSE_ERROR]: pass

            x = 3
        """)
        assert getstatement(2, source).lines == source.lines[2:3]
        assert getstatement(3, source).lines == source.lines[3:4]

    def test_getstatementrange_out_of_bounds_py3(self) -> None:
        """The last line of an ``if`` body yields a range ending at the source end."""
        source = Source("if xxx:\n   from .collections import something")
        r = source.getstatementrange(1)
        assert r == (1, 2)

    def test_getstatementrange_with_syntaxerror_issue7(self) -> None:
        """Looking up a statement in unparsable source raises ``SyntaxError``."""
        source = Source(":")
        with pytest.raises(SyntaxError):
            source.getstatementrange(0)

    def test_compile_to_ast(self) -> None:
        """Compiling with ``PyCF_ONLY_AST`` yields an ``ast.Module`` that ``compile`` accepts."""
        source = Source("x = 4")
        mod = source.compile(flag=ast.PyCF_ONLY_AST)
        assert isinstance(mod, ast.Module)
        compile(mod, "<filename>", "exec")

    def test_compile_and_getsource(self) -> None:
        """A failing ``assert`` in compiled code can be traced back to its statement."""
        co = self.source.compile()
        # Execute into a private namespace so the compiled ``f`` does not
        # leak into this test module's globals.
        namespace = {}  # type: Dict[str, Any]
        exec(co, namespace)
        f = namespace["f"]
        f(7)
        excinfo = pytest.raises(AssertionError, f, 6)
        assert excinfo is not None
        frame = excinfo.traceback[-1].frame
        assert isinstance(frame.code.fullsource, Source)
        stmt = frame.code.fullsource.getstatement(frame.lineno)
        assert str(stmt).strip().startswith("assert")

    @pytest.mark.parametrize("name", ["", None, "my"])
    def test_compilefuncs_and_path_sanity(self, name: Optional[str]) -> None:
        def check(comp, name) -> None:
            co = comp(self.source, name)
            # NOTE(review): ``mylineno + 2 + 2`` appears to designate the
            # ``comp(...)`` call line above, counted from the decorator line.
            # Inserting lines (docstrings, comments) between the decorator and
            # that call would presumably break the expectation -- confirm.
            if not name:
                expected = "codegen %s:%d>" % (mypath, mylineno + 2 + 2)  # type: ignore
            else:
                expected = "codegen %r %s:%d>" % (name, mypath, mylineno + 2 + 2)  # type: ignore
            fn = co.co_filename
            assert fn.endswith(expected)

        # Location of this test method, used to build the expected filename.
        mycode = _pytest._code.Code(self.test_compilefuncs_and_path_sanity)
        mylineno = mycode.firstlineno
        mypath = mycode.path

        # Both compile entry points must produce the same filename format.
        for comp in (_pytest._code.compile, _pytest._code.Source.compile):
            check(comp, name)

    def test_offsetless_synerr(self) -> None:
        """Compiling a lambda with duplicate argument names raises ``SyntaxError``."""
        with pytest.raises(SyntaxError):
            _pytest._code.compile("lambda a,a: 0", mode="eval")