Beispiel #1
0
def test_imbalanced(strict, a, b, c, d):
    """SSA on a generated function whose branches bind different names.

    The function built from ``imbalanced_template`` is first probed on every
    boolean ``(x, y)`` input to learn whether some path raises ``NameError``.
    Under strict SSA such a function must be rejected with ``SyntaxError``;
    otherwise the SSA version must return the same values as the original,
    with ``NameError`` tolerated only if the probe saw one.
    """
    src = imbalanced_template.format(a, b, c, d)
    tree = cst.parse_statement(src)
    env = SymbolTable({}, {})
    imbalanced = exec_def_in_file(tree, env)

    def raises_name_error(x, y):
        # True when the original function hits an unbound name on (x, y).
        try:
            imbalanced(x, y)
        except NameError:
            return True
        return False

    # For each x, any() stops at the first NameError over the y values,
    # just like breaking out of an inner loop; every x is still probed.
    can_name_error = False
    for x in (False, True):
        if any(raises_name_error(x, y) for y in (False, True)):
            can_name_error = True

    if strict and can_name_error:
        with pytest.raises(SyntaxError):
            apply_passes([ssa(strict)])(imbalanced)
        return

    ssa_imbalanced = apply_passes([ssa(strict)])(imbalanced)
    inputs = [(x, y) for x in (False, True) for y in (False, True)]
    for x, y in inputs:
        try:
            assert imbalanced(x, y) == ssa_imbalanced(x, y)
        except NameError:
            assert can_name_error
Beispiel #2
0
def test_nstrict():
    """Non-strict SSA (``ssa(False)``) on a function strict SSA cannot handle.

    ``f1`` binds ``z`` only on its ``elif`` branch but reads it later, and it
    returns from inside a nested ``if``.  The transformed source is pinned
    against a golden string, and both versions must agree for each ``cond``.
    """
    # This function would confuse strict ssa in so many ways
    # (NOTE: f1's source is inspected by the pass; comments inside it would
    # change the transformed output compared below).
    def f1(cond):
        if cond:
            if cond:
                return 0
        elif not cond:
            z = 1

        if not cond:
            x = z
            return x

    f2 = apply_passes([ssa(False)])(f1)
    # Golden output: branch bodies become straight-line code, conditions are
    # saved in `_cond_N` temporaries, and the result is chosen by a single
    # conditional expression in the final return.
    assert inspect.getsource(f2) == '''\
def f1(cond):
    _cond_2 = cond
    _cond_0 = cond
    __0_return_0 = 0
    _cond_1 = not cond
    z_0 = 1
    _cond_3 = not cond
    x_0 = z_0
    __0_return_1 = x_0
    return __0_return_0 if _cond_2 and _cond_0 else __0_return_1
'''
    # Behavioral check on top of the textual one.
    for cond in [True, False]:
        assert f1(cond) == f2(cond)
Beispiel #3
0
def test_attrs_returns(strict):
    """SSA must preserve attribute writes combined with early returns.

    ``f1`` assigns ``t.x`` on both branches and may return from inside a
    nested ``if``.  The original and the SSA version are driven with the same
    random conditions on two separate ``Thing`` objects; return values and
    object states must match after every call.
    """
    def f1(t, cond1, cond2):
        if cond1:
            t.x = 1
            if cond2:
                return 0
        else:
            t.x = 0
            if cond2:
                return 1
        return -1

    f2 = apply_passes([ssa(strict)])(f1)

    t1, t2 = Thing(), Thing()
    assert t1 == t2

    for _ in range(NTEST):
        # Both conditions are drawn left to right before either call.
        c1, c2 = random.randint(0, 1), random.randint(0, 1)
        # The original runs first, then the SSA version, with identical inputs.
        assert f1(t1, c1, c2) == f2(t2, c1, c2)
        assert t1 == t2
Beispiel #4
0
def test_call():
    """Non-strict SSA on a function that rebinds its argument and then passes
    it to ``g`` by keyword; the transformed source is compared with a golden
    string that continues beyond this excerpt.
    """
    # `g` is not defined in this excerpt; presumably it comes from the
    # enclosing module.
    def f1(x):
        x = 2
        return g(x=x)

    f2 = apply_passes([ssa(False)])(f1)
    assert inspect.getsource(f2) == '''\
Beispiel #5
0
def test_call_in_annotations(strict, x, y):
    """Apply SSA to a function whose signature is built from ``x`` and ``y``.

    NOTE(review): the visible portion ends after ``f2`` is built and asserts
    nothing about it; confirm whether more checks follow in the full file.
    """
    # Fall back to 'int' when a fixture value is falsy; how these two values
    # are used depends on sig_template (not visible here).
    r_x = x if x else 'int'
    r_y = y if y else 'int'
    # A truthy value becomes the suffix ': <value>'; a falsy value is passed
    # through unchanged.
    x = f': {x}' if x else x
    y = f': {y}' if y else y
    src = sig_template.format(x, y, r_x, r_y)
    tree = cst.parse_statement(src)
    # locals() is captured here (strict, x, y, r_x, r_y, src, tree), so
    # renaming or adding locals above this line changes what the generated
    # function's environment contains.
    env = SymbolTable(locals(), globals())
    f1 = exec_def_in_file(tree, env)
    f2 = apply_passes([ssa(strict)])(f1)
Beispiel #6
0
def test_basic(phi_args, expected_name):
    """``if_to_phi`` on a single conditional expression.

    ``phi_args`` are forwarded to ``if_to_phi``.  ``expected_name`` is
    presumably interpolated into the golden f-string below, whose body lies
    beyond this excerpt.
    """
    def basic(s):
        return 0 if s else 1

    phi_basic = apply_passes([if_to_phi(*phi_args)])(basic)

    # The rewritten function must behave the same for both truth values.
    for s in (True, False):
        assert basic(s) == phi_basic(s)

    assert inspect.getsource(phi_basic) == f'''\
Beispiel #7
0
def test_nested(phi_args, expected_name):
    """``if_to_phi`` on a chained (nested) conditional expression.

    ``phi_args`` are forwarded to ``if_to_phi``.  ``expected_name`` is
    presumably interpolated into the golden f-string below, whose body lies
    beyond this excerpt.
    """
    def nested(s, t):
        return 0 if s else 1 if t else 2

    phi_nested = apply_passes([if_to_phi(*phi_args)])(nested)

    # Every combination of the two conditions must give the same result.
    for s in (True, False):
        for t in (True, False):
            assert nested(s, t) == phi_nested(s, t)

    assert inspect.getsource(phi_nested) == f'''\
Beispiel #8
0
def test_basic_if(strict, a, b):
    """SSA of a single if/else generated from ``basic_template``.

    The template's last slot receives ``''`` when ``a`` and ``b`` are both
    ``'return'`` and ``'return r'`` otherwise.  The SSA version must return
    the same value as the original for both boolean arguments.
    """
    final = '' if a == b == 'return' else 'return r'

    src = basic_template.format(a, b, final)
    tree = cst.parse_statement(src)
    env = SymbolTable({}, {})
    basic = exec_def_in_file(tree, env)
    ssa_basic = apply_passes([ssa(strict)])(basic)

    for arg in (False, True):
        # Original first, then the SSA version.
        expected = basic(arg)
        assert expected == ssa_basic(arg)
Beispiel #9
0
def test_basic(cond):
    """``if_inline`` folds an ``if inline(cond)`` into the selected branch.

    The golden source shows the whole if/else collapsed to a single
    ``return`` of the branch picked by ``cond``; the result must not change.
    """
    # `basic` closes over `cond`; its source is inspected by the pass, so it
    # must stay exactly as written for the golden comparison below.
    def basic():
        if inline(cond):
            return 0
        else:
            return 1

    inlined = apply_passes([if_inline()])(basic)
    inlined_src = inspect.getsource(inlined)
    assert inlined_src == f'''\
def basic():
    return {0 if cond else 1}
'''
    # This calls the untransformed `basic`, so `inline(cond)` is evaluated at
    # runtime here as well.
    assert basic() == inlined()
Beispiel #10
0
def test_attr():
    """Non-strict SSA when a conditionally rebound name gets an attribute
    assignment; the transformed source is compared with a golden string that
    continues beyond this excerpt.
    """
    bar = namedtuple('bar', ['x', 'y'])

    # `f1` refers to the local `bar`, which presumably must be visible to the
    # pass through the caller's environment.
    # NOTE(review): namedtuple fields are read-only, so `a.x = 3` raises
    # AttributeError whenever `a` is the `bar` instance (truthy `x`); the
    # visible part only checks the transformed source, not a call.
    def f1(x, y):
        z = bar(1, 0)
        if x:
            a = z
        else:
            a = y
        a.x = 3
        return a

    f2 = apply_passes([ssa(False)])(f1)
    assert inspect.getsource(f2) == '''\
Beispiel #11
0
def test_nested(strict, a, b, c, d):
    """SSA of nested if/else blocks generated from ``nested_template``.

    The template's last slot receives ``''`` only when all of ``a`` through
    ``d`` are ``'return'``, and ``'return r'`` otherwise.  The SSA version must
    agree with the original on every combination of the two boolean inputs.
    """
    final = '' if a == b == c == d == 'return' else 'return r'

    src = nested_template.format(a, b, c, d, final)
    tree = cst.parse_statement(src)
    env = SymbolTable({}, {})
    nested = exec_def_in_file(tree, env)
    ssa_nested = apply_passes([ssa(strict)])(nested)

    # Same visiting order as two nested loops over (False, True).
    cases = [(p, q) for p in (False, True) for q in (False, True)]
    for p, q in cases:
        expected = nested(p, q)
        assert expected == ssa_nested(p, q)
Beispiel #12
0
def test_double_nested_function_call():
    """SSA of two consecutive if/else blocks that each rebind an argument.

    Checks the transformed source of ``foo`` against a golden string and
    the per-line symbol table recorded in ``deco.metadata['SYMBOL-TABLE']``.
    """
    def bar(x):
        return x

    # `baz` is never referenced by `foo` below.
    def baz(x):
        return x + 1

    deco = apply_passes([ssa()])
    # Register `deco` by name in the pass's environment; presumably needed so
    # the `@deco` line that is part of foo's source resolves when the pass
    # re-executes it.
    deco.env.locals['deco'] = deco

    # The trailing `# N` markers number foo's source lines (decorator = 1).
    # Lines 2..11 are the keys of the symbol table checked below, and the
    # markers are also preserved in the golden output.
    @deco               # 1
    def foo(a, b, c):   # 2
        if b:           # 3
            a = bar(a)  # 4
        else:           # 5
            a = bar(a)  # 6
        if c:           # 7
            b = bar(b)  # 8
        else:           # 9
            b = bar(b)  # 10
        return a, b     # 11
    assert inspect.getsource(foo) == '''\
def foo(a, b, c):   # 2
    _cond_0 = b
    a_0 = bar(a)  # 4
    a_1 = bar(a)  # 6
    a_2 = a_0 if _cond_0 else a_1
    _cond_1 = c
    b_0 = bar(b)  # 8
    b_1 = bar(b)  # 10
    b_2 = b_0 if _cond_1 else b_1
    __0_return_0 = a_2, b_2     # 11
    return __0_return_0
'''

    # Exactly one (pass, table) entry is recorded, for the single ssa pass.
    symbol_tables = deco.metadata['SYMBOL-TABLE']
    assert len(symbol_tables) == 1
    assert symbol_tables[0][0] == ssa
    symbol_table = symbol_tables[0][1]
    # Expected SSA name per source name and line: `a` becomes a_0 at line 4,
    # a_1 at line 6, and the merged a_2 from line 7; `b` likewise at lines 8,
    # 10 and 11; `c` is never rebound.
    gold_table = {i: {
        'a': 'a' if i < 4 else 'a_0' if i <  6 else 'a_1' if i <  7 else 'a_2',
        'b': 'b' if i < 8 else 'b_0' if i < 10 else 'b_1' if i < 11 else 'b_2',
        'c': 'c',
        } for i in range(2, 12)}
    assert symbol_table == gold_table
Beispiel #13
0
def test_nested(cond_0, cond_1):
    """``if_inline`` folds two levels of ``if inline(...)`` down to one return.

    The golden source uses the original function's own result as the
    returned literal, so text and behavior are both pinned.
    """
    # `nested` closes over `cond_0` and `cond_1`; its source is inspected by
    # the pass, so it must stay exactly as written.
    def nested():
        if inline(cond_0):
            if inline(cond_1):
                return 3
            else:
                return 2
        else:
            if inline(cond_1):
                return 1
            else:
                return 0

    inlined = apply_passes([if_inline()])(nested)
    assert inspect.getsource(inlined) == f'''\
def nested():
    return {nested()}
'''
    assert nested() == inlined()
Beispiel #14
0
def test_reassign_arg():
    """SSA of an argument rebound inside an ``if`` that has no ``else``.

    Checks the transformed source of ``foo`` against a golden string and
    the per-line symbol table recorded in ``deco.metadata['SYMBOL-TABLE']``.
    """
    # `bar` is not referenced by `foo` below (it calls `len`).
    def bar(x):
        return x

    deco = apply_passes([ssa()])
    # Register `deco` by name in the pass's environment; presumably needed so
    # the `@deco` line that is part of foo's source resolves when the pass
    # re-executes it.
    deco.env.locals['deco'] = deco

    @deco
    def foo(a, b):
        if b:
            a = len(a)
        return a
    assert inspect.getsource(foo) == '''\
def foo(a, b):
    _cond_0 = b
    a_0 = len(a)
    a_1 = a_0 if _cond_0 else a
    __0_return_0 = a_1
    return __0_return_0
'''
    # Exactly one (pass, table) entry is recorded, for the single ssa pass.
    symbol_tables = deco.metadata['SYMBOL-TABLE']
    assert len(symbol_tables) == 1
    assert symbol_tables[0][0] == ssa
    symbol_table = symbol_tables[0][1]
    # Keys appear to be line numbers within foo's source, counting `@deco` as
    # line 1: line 4 (`a = len(a)`) binds a_0, and line 5 (`return a`) sees
    # the merged a_1 from the golden output.
    assert symbol_table == {
        2: {
            'a': 'a',
            'b': 'b',
        },
        3: {
            'a': 'a',
            'b': 'b',
        },
        4: {
            'a': 'a_0',
            'b': 'b',
        },
        5: {
            'a': 'a_1',
            'b': 'b',
        },
    }
Beispiel #15
0
def test_outer_inline(cond_0, cond_1):
    """``if_inline`` folds only the outer ``if inline(cond_0)``.

    The inner runtime ``if cond`` must survive, with each branch's return
    value picked according to ``cond_0``.  Behavior is then compared on
    ``cond_1``.
    """
    # `nested` closes over `cond_0`; its source is inspected by the pass, so
    # it must stay exactly as written.
    def nested(cond):
        if inline(cond_0):
            if cond:
                return 3
            else:
                return 2
        else:
            if cond:
                return 1
            else:
                return 0

    inlined = apply_passes([if_inline()])(nested)
    assert inspect.getsource(inlined) == f'''\
def nested(cond):
    if cond:
        return {3 if cond_0 else 1}
    else:
        return {2 if cond_0 else 0}
'''
    # `cond_1` feeds the runtime `cond` parameter, not an inline() call.
    assert nested(cond_1) == inlined(cond_1)
Beispiel #16
0
def test_attrs_basic(strict):
    """SSA must preserve an attribute write made on either branch of an if.

    ``f1`` reads ``t.x`` before overwriting it on both branches and returns
    the value it read.  Each round calls the original on ``t1`` and then the
    SSA version on ``t2``.  The objects must differ between the two calls and
    agree afterwards, and both calls must return the same pre-write value.
    """
    def f1(t, cond):
        old = t.x
        if cond:
            t.x = 1
        else:
            t.x = 0
        return old

    f2 = apply_passes([ssa(strict)])(f1)

    t1 = Thing()
    t2 = Thing()

    # `Thing` (defined elsewhere) is compared directly with None/1/0 here,
    # so `==` is intentional rather than an identity check.
    assert t1 == t2 == None
    for cond, expected in ((True, 1), (False, 0)):
        old1 = f1(t1, cond)
        assert t1 != t2
        old2 = f2(t2, cond)
        # The value read before the write was previously never compared, so a
        # wrong SSA return would have gone unnoticed.
        assert old1 == old2
        assert t1 == t2 == expected
Beispiel #17
0
 # Counter1 with only its __call__ passed through the SSA pass; __init__ and
 # get_step are reused unchanged.  `strict` and `Counter1` come from an
 # enclosing scope not visible in this excerpt.
 class Counter2:
     __init__ = Counter1.__init__
     # Only this method is transformed; apply_passes receives the plain
     # function object taken from the Counter1 class.
     __call__ = apply_passes([ssa(strict)])(Counter1.__call__)
     get_step = Counter1.get_step