예제 #1
0
def test_source_strips():
    """An empty Source equals a default one, renders as '' and survives strip()."""
    empty = Source("")
    default = Source()
    assert empty == default
    assert "" == str(empty)
    stripped = empty.strip()
    assert stripped == empty
예제 #2
0
def test_source_strip_multiline():
    """strip() drops blank edge lines but keeps the inner line's leading space."""
    src = Source()
    src.lines = ["", " hello", "  "]
    expected = [" hello"]
    assert src.strip().lines == expected
예제 #3
0
 def test_getstatementrange_with_syntaxerror_issue7(self):
     """A source that does not parse makes getstatementrange raise SyntaxError."""
     broken = Source(":")
     with pytest.raises(SyntaxError):
         broken.getstatementrange(0)
예제 #4
0
 def test_compile_to_ast(self):
     """Compiling with PyCF_ONLY_AST yields an ast.Module that compile() accepts."""
     import ast

     tree = Source("x = 4").compile(flag=ast.PyCF_ONLY_AST)
     assert isinstance(tree, ast.Module)
     # The returned tree must be usable by the builtin compiler as-is.
     compile(tree, "<filename>", "exec")
예제 #5
0
class TestSourceParsingAndCompiling(object):
    """Statement-range detection and compilation of ``Source`` objects."""

    # Shared fixture: ``def f(x):`` followed by a single assert statement
    # that spans lines 1-3 of the stripped source.
    source = Source("""\
        def f(x):
            assert (x ==
                    3 +
                    4)
    """).strip()

    def test_compile(self):
        """Compiled code runs and binds its names in the supplied namespace."""
        co = _pytest._code.compile("x=3")
        d = {}
        exec(co, d)
        assert d['x'] == 3

    def test_compile_and_getsource_simple(self):
        """Source() recovers the exact text a code object was compiled from."""
        co = _pytest._code.compile("x=3")
        exec(co)
        source = _pytest._code.Source(co)
        assert str(source) == "x=3"

    def test_compile_and_getsource_through_same_function(self):
        """Two compilations from the same call site keep separate sources."""
        # Plain stdlib ``inspect`` instead of the deprecated ``py.std`` proxy.
        import inspect

        def gensource(source):
            return _pytest._code.compile(source)

        co1 = gensource("""
            def f():
                raise KeyError()
        """)
        co2 = gensource("""
            def f():
                raise ValueError()
        """)
        source1 = inspect.getsource(co1)
        assert 'KeyError' in source1
        source2 = inspect.getsource(co2)
        assert 'ValueError' in source2

    def test_getstatement(self):
        """Each line of the multi-line assert maps back to the whole statement."""
        expected = str(self.source[1:])
        for i in range(1, 4):
            s = self.source.getstatement(i)
            assert str(s) == expected

    def test_getstatementrange_triple_quoted(self):
        """A call holding a triple-quoted string is returned whole from either line."""
        source = Source("""hello('''
        ''')""")
        s = source.getstatement(0)
        assert s == str(source)
        s = source.getstatement(1)
        assert s == str(source)

    @astonly
    def test_getstatementrange_within_constructs(self):
        """Lines nested in try/except/finally map to their own statement range."""
        source = Source("""\
            try:
                try:
                    raise ValueError
                except SomeThing:
                    pass
            finally:
                42
        """)
        assert len(source) == 7
        # check all lineno's that could occur in a traceback
        # NOTE(review): the expectations for lines 0, 1 and 5 are disabled;
        # the reason is not recorded here.
        # assert source.getstatementrange(0) == (0, 7)
        # assert source.getstatementrange(1) == (1, 5)
        assert source.getstatementrange(2) == (2, 3)
        assert source.getstatementrange(3) == (3, 4)
        assert source.getstatementrange(4) == (4, 5)
        # assert source.getstatementrange(5) == (0, 7)
        assert source.getstatementrange(6) == (6, 7)

    def test_getstatementrange_bug(self):
        """A line inside a parenthesised continuation maps to the full assignment."""
        source = Source("""\
            try:
                x = (
                   y +
                   z)
            except:
                pass
        """)
        assert len(source) == 6
        assert source.getstatementrange(2) == (1, 4)

    def test_getstatementrange_bug2(self):
        """A deeply nested line inside an assert maps to the whole assert."""
        source = Source("""\
            assert (
                33
                ==
                [
                  X(3,
                      b=1, c=2
                   ),
                ]
              )
        """)
        assert len(source) == 9
        assert source.getstatementrange(5) == (0, 9)

    def test_getstatementrange_ast_issue58(self):
        """Lines of a for-header whose comprehension spans two lines stay separate."""
        source = Source("""\

            def test_some():
                for a in [a for a in
                    CAUSE_ERROR]: pass

            x = 3
        """)
        assert getstatement(2, source).lines == source.lines[2:3]
        assert getstatement(3, source).lines == source.lines[3:4]

    @pytest.mark.skipif("sys.version_info < (2,6)")
    def test_getstatementrange_out_of_bounds_py3(self):
        """The last line of a two-line ``if`` block is its own range (1, 2)."""
        source = Source("if xxx:\n   from .collections import something")
        r = source.getstatementrange(1)
        assert r == (1, 2)

    def test_getstatementrange_with_syntaxerror_issue7(self):
        """Unparseable source makes getstatementrange raise SyntaxError."""
        source = Source(":")
        with pytest.raises(SyntaxError):
            source.getstatementrange(0)

    @pytest.mark.skipif("sys.version_info < (2,6)")
    def test_compile_to_ast(self):
        """PyCF_ONLY_AST compilation yields an ast.Module that compile() accepts."""
        import ast
        source = Source("x = 4")
        mod = source.compile(flag=ast.PyCF_ONLY_AST)
        assert isinstance(mod, ast.Module)
        compile(mod, "<filename>", "exec")

    def test_compile_and_getsource(self):
        """A failing assert in compiled code can be traced back to its source."""
        co = self.source.compile()
        # Defines ``f`` in this module's globals (builtin exec, as in test_compile).
        exec(co, globals())
        f(7)
        # Context-manager form: the string form ``raises(E, "f(6)")`` was
        # removed in pytest 5.0.
        with pytest.raises(AssertionError) as excinfo:
            f(6)
        frame = excinfo.traceback[-1].frame
        stmt = frame.code.fullsource.getstatement(frame.lineno)
        assert str(stmt).strip().startswith('assert')

    @pytest.mark.parametrize('name', ['', None, 'my'])
    def test_compilefuncs_and_path_sanity(self, name):
        def check(comp, name):
            co = comp(self.source, name)
            # NOTE: the ``+ 2 + 2`` offset appears to locate the
            # ``co = comp(...)`` line relative to ``mylineno``; do not insert
            # lines above that call without re-checking the offset.
            if not name:
                expected = "codegen %s:%d>" % (mypath, mylineno + 2 + 2)
            else:
                expected = "codegen %r %s:%d>" % (name, mypath,
                                                  mylineno + 2 + 2)
            fn = co.co_filename
            assert fn.endswith(expected)

        mycode = _pytest._code.Code(self.test_compilefuncs_and_path_sanity)
        mylineno = mycode.firstlineno
        mypath = mycode.path

        # Both the module-level helper and the Source method must agree.
        for comp in _pytest._code.compile, _pytest._code.Source.compile:
            check(comp, name)

    def test_offsetless_synerr(self):
        """Duplicate lambda argument names raise SyntaxError in eval mode."""
        with pytest.raises(SyntaxError):
            _pytest._code.compile("lambda a,a: 0", mode='eval')
예제 #6
0
 def test_getstatementrange_out_of_bounds_py3(self):
     """The last line of a two-line ``if`` block is its own range (1, 2)."""
     text = "if xxx:\n   from .collections import something"
     assert Source(text).getstatementrange(1) == (1, 2)
예제 #7
0
def test_unicode():
    """A non-ASCII literal survives compile() and evaluates to ``str``."""
    four = Source("4")
    assert "4" == str(four)
    code = _pytest._code.compile('"å"', mode="eval")
    assert isinstance(eval(code), str)
예제 #8
0
 def _getlines(self, lines2):
     """Return ``lines2`` as plain lines where possible.

     A ``str`` is first wrapped in ``Source``. A ``Source`` is stripped and its
     ``lines`` are returned. Any other value is handed back unchanged.
     """
     candidate = Source(lines2) if isinstance(lines2, str) else lines2
     if not isinstance(candidate, Source):
         return candidate
     return candidate.strip().lines
예제 #9
0
def test_source_fallback() -> None:
    """Source(x) renders the full text of ``x`` (a function defined elsewhere in this file)."""
    rendered = str(Source(x))
    assert rendered == """def x():
    pass"""
예제 #10
0
def test_source_strips() -> None:
    """An empty Source equals a default one, renders as "" and survives strip()."""
    empty = Source("")
    default = Source()
    assert empty == default
    assert "" == str(empty)
    stripped = empty.strip()
    assert stripped == empty
예제 #11
0
def test_unicode():
    """A unicode literal compiled in eval mode evaluates to a text object."""
    four = Source(u"4")
    assert "4" == str(four)
    compiled = _pytest._code.compile(u'u"å"', mode="eval")
    value = eval(compiled)
    assert isinstance(value, six.text_type)
예제 #12
0
def test_source_from_inner_function() -> None:
    """Source() accepts a nested function and starts with its ``def`` line."""

    def f():
        raise NotImplementedError()

    text = str(Source(f))
    assert text.startswith("def f():")
예제 #13
0
def test_source_from_lines() -> None:
    """A list of lines loses its newlines but keeps other trailing whitespace."""
    raw = ["a \n", "b\n", "c"]
    assert Source(raw).lines == ["a ", "b", "c"]
예제 #14
0
def getstatement(lineno: int, source) -> Source:
    """Return the slice of ``Source(source)`` covering the statement at ``lineno``.

    The ``(start, end)`` bounds come from ``getstatementrange_ast``. Its first
    return value is not needed here and is discarded.
    """
    from _pytest._code.source import getstatementrange_ast

    src = Source(source)
    # Discard the first result rather than binding it to ``ast``, which
    # would shadow the stdlib module name for no benefit.
    _, start, end = getstatementrange_ast(lineno, src)
    return src[start:end]
예제 #15
0
def test_source_from_function() -> None:
    """Source(func) starts with the function's own ``def`` line, annotations included."""
    text = str(Source(test_source_str_function))
    assert text.startswith("def test_source_str_function() -> None:")