예제 #1
0
def test_imbalanced(strict, a, b, c, d):
    """Compare a generated function with its SSA rewrite over all bool pairs.

    The generated function is first probed for ``NameError`` on every
    ``(x, y)`` pair. If it can raise one and ``strict`` is set, applying the
    SSA pass must raise ``SyntaxError``. Otherwise the original and the
    rewrite must return equal values. A ``NameError`` from either call is
    accepted only if the probe saw one.
    """
    src = imbalanced_template.format(a, b, c, d)
    tree = cst.parse_statement(src)
    env = SymbolTable({}, {})
    imbalanced = exec_def_in_file(tree, env)

    def raises_name_error(x, y):
        # True when calling the generated function with (x, y) raises NameError.
        try:
            imbalanced(x, y)
        except NameError:
            return True
        return False

    # For each x, stop probing at the first y that raises (short-circuit `or`).
    can_name_error = False
    for x in (False, True):
        if raises_name_error(x, False) or raises_name_error(x, True):
            can_name_error = True

    if strict and can_name_error:
        # Strict mode is expected to reject a function that raised NameError above.
        with pytest.raises(SyntaxError):
            apply_passes([ssa(strict)])(imbalanced)
        return

    ssa_imbalanced = apply_passes([ssa(strict)])(imbalanced)
    for x in (False, True):
        for y in (False, True):
            try:
                assert imbalanced(x, y) == ssa_imbalanced(x, y)
            except NameError:
                # Tolerated only when the probe already observed a NameError.
                assert can_name_error
예제 #2
0
def test_call():
    """Check the non-strict SSA rewrite of a function containing a keyword call.

    NOTE(review): ``g`` is not defined anywhere in this view. That only
    matters if ``f1`` or ``f2`` is actually called; the visible part of the
    test only compares ``inspect.getsource(f2)`` with an expected string.
    """
    def f1(x):
        x = 2
        return g(x=x)

    f2 = apply_passes([ssa(False)])(f1)
    assert inspect.getsource(f2) == '''\
예제 #3
0
def test_nstrict():
    """Non-strict SSA on nested ifs and a name bound on only one path.

    ``z`` is assigned only in the ``elif`` branch but read in a later ``if``.
    The test pins the exact rewritten source of ``f1``, then checks that the
    rewrite returns the same value as the original for both truth values of
    ``cond``.
    """
    # This function would confuse strict ssa in so many ways
    def f1(cond):
        if cond:
            if cond:
                return 0
        elif not cond:
            z = 1

        if not cond:
            x = z
            return x

    f2 = apply_passes([ssa(False)])(f1)
    assert inspect.getsource(f2) == '''\
def f1(cond):
    _cond_2 = cond
    _cond_0 = cond
    __0_return_0 = 0
    _cond_1 = not cond
    z_0 = 1
    _cond_3 = not cond
    x_0 = z_0
    __0_return_1 = x_0
    return __0_return_0 if _cond_2 and _cond_0 else __0_return_1
'''
    # Behavioral check on top of the textual comparison above.
    for cond in [True, False]:
        assert f1(cond) == f2(cond)
예제 #4
0
def test_attrs_returns(strict):
    """Compare ``f1`` with its SSA rewrite on random condition pairs.

    ``f1`` writes ``t.x`` before it may return early, so each trial checks
    both the return values and the state of the ``Thing`` each version
    mutated.
    """
    def f1(t, cond1, cond2):
        if cond1:
            t.x = 1
            if cond2:
                return 0
        else:
            t.x = 0
            if cond2:
                return 1
        return -1

    f2 = apply_passes([ssa(strict)])(f1)

    ref_thing, ssa_thing = Thing(), Thing()
    assert ref_thing == ssa_thing

    for _trial in range(NTEST):
        flag1 = random.randint(0, 1)
        flag2 = random.randint(0, 1)
        ref_out = f1(ref_thing, flag1, flag2)
        ssa_out = f2(ssa_thing, flag1, flag2)
        assert ref_out == ssa_out
        assert ref_thing == ssa_thing
예제 #5
0
def _do_ssa(func, strict, **kwargs):
    """Return *func* after running it through the rewrite pipeline.

    Decorators are applied in this order: ``begin_rewrite()``,
    ``debug(**kwargs)``, ``ssa(strict)``, ``debug(**kwargs)`` and
    ``end_rewrite()``. All of them are constructed before any is applied.
    """
    pipeline = [
        begin_rewrite(),
        debug(**kwargs),
        ssa(strict),
        debug(**kwargs),
        end_rewrite(),
    ]
    result = func
    for decorator in pipeline:
        result = decorator(result)
    return result
예제 #6
0
def test_call_in_annotations(strict, x, y):
    """Apply SSA to a generated function whose signature text comes from x/y.

    This is a smoke test. No assertion follows the pass, so the test fails
    only if building the function or applying ``ssa(strict)`` raises.
    """
    # r_x / r_y fall back to the text 'int' when the argument is empty.
    r_x = x or 'int'
    r_y = y or 'int'
    # Non-empty x / y become ': <text>'; empty ones are passed through as-is.
    if x:
        x = f': {x}'
    if y:
        y = f': {y}'
    src = sig_template.format(x, y, r_x, r_y)
    tree = cst.parse_statement(src)
    env = SymbolTable(locals(), globals())
    f1 = exec_def_in_file(tree, env)
    f2 = apply_passes([ssa(strict)])(f1)
예제 #7
0
def test_basic_if(strict, a, b):
    """Check that SSA preserves the results of a generated single-``if`` function.

    ``final`` is ``'return r'`` unless both ``a`` and ``b`` are ``'return'``,
    in which case it is empty.
    """
    final = '' if a == b == 'return' else 'return r'
    basic = exec_def_in_file(
        cst.parse_statement(basic_template.format(a, b, final)),
        SymbolTable({}, {}),
    )
    ssa_basic = apply_passes([ssa(strict)])(basic)

    for flag in (False, True):
        assert basic(flag) == ssa_basic(flag)
예제 #8
0
def test_attr():
    """Check the non-strict SSA rewrite of an attribute store on a merged name.

    ``a`` is bound in both branches of an ``if`` and then ``a.x`` is
    assigned.

    NOTE(review): ``bar`` is a namedtuple, and namedtuple fields are
    read-only. Calling ``f1`` with a truthy ``x`` would therefore raise
    ``AttributeError`` at ``a.x = 3``. The visible part of the test only
    inspects source.
    """
    bar = namedtuple('bar', ['x', 'y'])

    def f1(x, y):
        z = bar(1, 0)
        if x:
            a = z
        else:
            a = y
        a.x = 3
        return a

    f2 = apply_passes([ssa(False)])(f1)
    assert inspect.getsource(f2) == '''\
예제 #9
0
def test_nested(strict, a, b, c, d):
    """Check that SSA preserves the results of a generated nested-``if`` function.

    ``final`` is ``'return r'`` unless all of ``a``-``d`` are ``'return'``,
    in which case it is empty.
    """
    final = '' if a == b == c == d == 'return' else 'return r'
    nested = exec_def_in_file(
        cst.parse_statement(nested_template.format(a, b, c, d, final)),
        SymbolTable({}, {}),
    )
    ssa_nested = apply_passes([ssa(strict)])(nested)

    bools = (False, True)
    for p in bools:
        for q in bools:
            assert nested(p, q) == ssa_nested(p, q)
예제 #10
0
def test_double_nested_function_call():
    """Pin the SSA output and the recorded symbol table for two if/else blocks.

    ``foo`` is decorated directly with the SSA pipeline. The trailing
    ``# N`` comments number its lines, with the decorator as line 1, and
    the symbol-table keys checked at the end use that same numbering.
    Neither the body of ``foo`` nor its comments may change without
    updating the expected strings and table.

    NOTE(review): ``baz`` is never used.
    """
    def bar(x):
        return x

    def baz(x):
        return x + 1

    deco = apply_passes([ssa()])
    # NOTE(review): presumably this lets the pass resolve the ``@deco`` name
    # when it re-processes foo's source; confirm against apply_passes.
    deco.env.locals['deco'] = deco

    @deco               # 1
    def foo(a, b, c):   # 2
        if b:           # 3
            a = bar(a)  # 4
        else:           # 5
            a = bar(a)  # 6
        if c:           # 7
            b = bar(b)  # 8
        else:           # 9
            b = bar(b)  # 10
        return a, b     # 11
    assert inspect.getsource(foo) == '''\
def foo(a, b, c):   # 2
    _cond_0 = b
    a_0 = bar(a)  # 4
    a_1 = bar(a)  # 6
    a_2 = a_0 if _cond_0 else a_1
    _cond_1 = c
    b_0 = bar(b)  # 8
    b_1 = bar(b)  # 10
    b_2 = b_0 if _cond_1 else b_1
    __0_return_0 = a_2, b_2     # 11
    return __0_return_0
'''

    # Exactly one pass recorded a symbol table, and it is the ssa pass.
    symbol_tables = deco.metadata['SYMBOL-TABLE']
    assert len(symbol_tables) == 1
    assert symbol_tables[0][0] == ssa
    symbol_table = symbol_tables[0][1]
    # gold_table[i][name] is the SSA name expected for `name` at line i of
    # foo (lines 2..11, numbered by the `# N` comments).
    gold_table = {i: {
        'a': 'a' if i < 4 else 'a_0' if i <  6 else 'a_1' if i <  7 else 'a_2',
        'b': 'b' if i < 8 else 'b_0' if i < 10 else 'b_1' if i < 11 else 'b_2',
        'c': 'c',
        } for i in range(2, 12)}
    assert symbol_table == gold_table
예제 #11
0
def test_reassign_arg():
    """Pin the SSA output and symbol table when an argument is reassigned in an if.

    The keys of the expected symbol table are line numbers of the decorated
    ``foo``: 2 is the ``def`` line and 5 is the ``return``. ``foo`` must
    therefore keep its exact layout.

    NOTE(review): ``bar`` is never used; ``foo`` calls ``len``.
    """
    def bar(x):
        return x

    deco = apply_passes([ssa()])
    # NOTE(review): presumably this lets the pass resolve the ``@deco`` name
    # when it re-processes foo's source; confirm against apply_passes.
    deco.env.locals['deco'] = deco

    @deco
    def foo(a, b):
        if b:
            a = len(a)
        return a
    assert inspect.getsource(foo) == '''\
def foo(a, b):
    _cond_0 = b
    a_0 = len(a)
    a_1 = a_0 if _cond_0 else a
    __0_return_0 = a_1
    return __0_return_0
'''
    # Exactly one pass recorded a symbol table, and it is the ssa pass.
    symbol_tables = deco.metadata['SYMBOL-TABLE']
    assert len(symbol_tables) == 1
    assert symbol_tables[0][0] == ssa
    symbol_table = symbol_tables[0][1]
    # Line number -> {original name: SSA name}.
    assert symbol_table == {
        2: {
            'a': 'a',
            'b': 'b',
        },
        3: {
            'a': 'a',
            'b': 'b',
        },
        4: {
            'a': 'a_0',
            'b': 'b',
        },
        5: {
            'a': 'a_1',
            'b': 'b',
        },
    }
예제 #12
0
def test_attrs_basic(strict):
    """Check that SSA preserves attribute writes in both branches of an if/else.

    For each condition, the original ``f1`` runs on ``t1`` and the SSA
    rewrite ``f2`` runs on ``t2``. The objects must differ after the first
    call and agree, with the expected value, after the second. Both calls
    must also return the same value of ``t.x``, read before the write. The
    original test discarded those return values, so a rewrite that read
    ``t.x`` after the store would have gone unnoticed.
    """
    def f1(t, cond):
        old = t.x
        if cond:
            t.x = 1
        else:
            t.x = 0
        return old

    f2 = apply_passes([ssa(strict)])(f1)

    t1 = Thing()
    t2 = Thing()

    # `t2` is a Thing instance, so `is None` would be false. Thing is
    # compared against plain values here, which makes `== None` deliberate.
    assert t1 == t2 == None
    for cond, expected in ((True, 1), (False, 0)):
        old1 = f1(t1, cond)
        assert t1 != t2
        old2 = f2(t2, cond)
        assert t1 == t2 == expected
        # The rewrite must return the value read *before* the store.
        assert old1 == old2
예제 #13
0
 class Counter2:
     # Mirrors Counter1. __init__ and get_step are the very same function
     # objects, while __call__ is Counter1.__call__ passed through
     # apply_passes([ssa(strict)]). Presumably this lets the test compare
     # the two classes' behavior; confirm in the enclosing test.
     __init__ = Counter1.__init__
     __call__ = apply_passes([ssa(strict)])(Counter1.__call__)
     get_step = Counter1.get_step