def test_imbalanced(strict, a, b, c, d):
    """Check SSA on a function whose branches may leave names unbound.

    ``imbalanced_template`` is filled with the four snippets ``a``-``d``.
    The original function is first probed on every ``(x, y)`` input pair to
    learn whether some path raises ``NameError``.  In strict mode such a
    function must be rejected by the SSA pass with ``SyntaxError``;
    otherwise the rewritten function must agree with the original on every
    input, with a ``NameError`` tolerated only if the probe observed one.
    """
    inputs = (False, True)
    tree = cst.parse_statement(imbalanced_template.format(a, b, c, d))
    imbalanced = exec_def_in_file(tree, SymbolTable({}, {}))

    def row_raises_name_error(x):
        # Stop at the first NameError within one row of inputs.
        for y in inputs:
            try:
                imbalanced(x, y)
            except NameError:
                return True
        return False

    # Built as a list (not a generator) so every row is probed: the probe
    # only abandons the remainder of a row, never the remaining rows.
    can_name_error = any([row_raises_name_error(x) for x in inputs])

    if strict and can_name_error:
        with pytest.raises(SyntaxError):
            apply_passes([ssa(strict)])(imbalanced)
        return

    ssa_imbalanced = apply_passes([ssa(strict)])(imbalanced)
    for x in inputs:
        for y in inputs:
            try:
                assert imbalanced(x, y) == ssa_imbalanced(x, y)
            except NameError:
                assert can_name_error
def test_call(): def f1(x): x = 2 return g(x=x) f2 = apply_passes([ssa(False)])(f1) assert inspect.getsource(f2) == '''\
def test_nstrict():
    """Non-strict SSA on a function that strict SSA cannot handle.

    In ``f1``, ``z`` is assigned only on the ``elif`` path but read under a
    separate ``if``, and every ``return`` sits inside a conditional.  With
    ``ssa(False)`` the pass must produce exactly the source below, and the
    rewritten function must agree with the original for both values of
    ``cond``.
    """
    # This function would confuse strict ssa in so many ways
    def f1(cond):
        if cond:
            if cond:
                return 0
        elif not cond:
            z = 1
        if not cond:
            x = z
            return x
    f2 = apply_passes([ssa(False)])(f1)
    # Exact text comparison: any change to f1's layout changes this string.
    assert inspect.getsource(f2) == '''\
def f1(cond):
    _cond_2 = cond
    _cond_0 = cond
    __0_return_0 = 0
    _cond_1 = not cond
    z_0 = 1
    _cond_3 = not cond
    x_0 = z_0
    __0_return_1 = x_0
    return __0_return_0 if _cond_2 and _cond_0 else __0_return_1
'''
    for cond in [True, False]:
        assert f1(cond) == f2(cond)
def test_attrs_returns(strict):
    """Attribute writes that precede early returns must survive SSA.

    ``f1`` stores to ``t.x`` on both branches of ``cond1`` and may return
    from either branch.  The original and the SSA'd copy are driven with the
    same random ``(cond1, cond2)`` stream, each on its own ``Thing``.  After
    every call, both the return values and the two objects must match.
    """
    def f1(t, cond1, cond2):
        if cond1:
            t.x = 1
            if cond2:
                return 0
        else:
            t.x = 0
            if cond2:
                return 1
        return -1
    f2 = apply_passes([ssa(strict)])(f1)
    reference, rewritten = Thing(), Thing()
    assert reference == rewritten
    for _ in range(NTEST):
        c1 = random.randint(0, 1)
        c2 = random.randint(0, 1)
        expected = f1(reference, c1, c2)
        actual = f2(rewritten, c1, c2)
        assert expected == actual
        assert reference == rewritten
def _do_ssa(func, strict, **kwargs):
    """Run ``func`` through the rewrite pipeline and return the result.

    The stages are applied innermost-first in this order:
    ``begin_rewrite`` -> ``debug`` -> ``ssa(strict)`` -> ``debug`` ->
    ``end_rewrite``.  ``kwargs`` are forwarded to both ``debug`` stages.
    """
    # All stages are instantiated up front, in pipeline order.
    pipeline = [
        begin_rewrite(),
        debug(**kwargs),
        ssa(strict),
        debug(**kwargs),
        end_rewrite(),
    ]
    result = func
    for stage in pipeline:
        result = stage(result)
    return result
def test_call_in_annotations(strict, x, y):
    """The SSA pass must accept a generated function built from ``sig_template``.

    ``x`` and ``y`` are either empty or an expression that is rendered as
    ``': <expr>'`` (presumably a parameter annotation; confirm against
    ``sig_template``).  ``r_x`` and ``r_y`` fall back to ``'int'`` when the
    corresponding input is empty.

    NOTE(review): there is no assertion; the test only checks that the pass
    runs without raising.
    """
    r_x = x if x else 'int'
    r_y = y if y else 'int'
    x = f': {x}' if x else x
    y = f': {y}' if y else y
    src = sig_template.format(x, y, r_x, r_y)
    tree = cst.parse_statement(src)
    # locals() here is {strict, x, y, r_x, r_y, src, tree}.  It is handed to
    # the symbol table, presumably so names in the generated source can
    # resolve.  Renaming these locals changes that environment.
    env = SymbolTable(locals(), globals())
    f1 = exec_def_in_file(tree, env)
    f2 = apply_passes([ssa(strict)])(f1)
def test_basic_if(strict, a, b):
    """SSA of a single if/else must not change the function's results.

    ``a`` and ``b`` are snippets substituted into ``basic_template``
    (presumably the two branch bodies; confirm against the template).  The
    trailing snippet is empty when both are ``'return'`` and ``'return r'``
    otherwise.  The original and SSA'd functions are compared on both truth
    values of their argument.
    """
    both_return = a == b == 'return'
    final = '' if both_return else 'return r'
    tree = cst.parse_statement(basic_template.format(a, b, final))
    basic = exec_def_in_file(tree, SymbolTable({}, {}))
    ssa_basic = apply_passes([ssa(strict)])(basic)
    for flag in (False, True):
        assert basic(flag) == ssa_basic(flag)
def test_attr(): bar = namedtuple('bar', ['x', 'y']) def f1(x, y): z = bar(1, 0) if x: a = z else: a = y a.x = 3 return a f2 = apply_passes([ssa(False)])(f1) assert inspect.getsource(f2) == '''\
def test_nested(strict, a, b, c, d):
    """SSA of nested if/else blocks must not change the function's results.

    ``a``-``d`` are snippets substituted into ``nested_template`` together
    with a fifth, trailing snippet.  The trailing snippet is empty when all
    four are ``'return'`` and ``'return r'`` otherwise.  The original and
    SSA'd functions are compared on every pair of boolean inputs.
    """
    every_branch_returns = a == b == c == d == 'return'
    final = '' if every_branch_returns else 'return r'
    tree = cst.parse_statement(nested_template.format(a, b, c, d, final))
    nested = exec_def_in_file(tree, SymbolTable({}, {}))
    ssa_nested = apply_passes([ssa(strict)])(nested)
    grid = [(p, q) for p in (False, True) for q in (False, True)]
    for p, q in grid:
        assert nested(p, q) == ssa_nested(p, q)
def test_double_nested_function_call():
    """SSA of two consecutive if/else blocks whose branches call a helper.

    Checks the exact rewritten source of ``foo`` and the ``SYMBOL-TABLE``
    metadata.  That metadata maps each line of the original ``foo`` to the
    SSA name every original name has there.  Lines are numbered by the
    ``# N`` comments, with the ``@deco`` decorator as line 1.
    """
    def bar(x):
        return x
    # NOTE(review): baz is never used in this test.
    def baz(x):
        return x + 1
    deco = apply_passes([ssa()])
    # Register `deco` in the pass's environment by hand, because it is bound
    # only after apply_passes(...) returns.  Presumably this is needed so the
    # `@deco` line in foo's source resolves; confirm.
    deco.env.locals['deco'] = deco

    @deco # 1
    def foo(a, b, c): # 2
        if b: # 3
            a = bar(a) # 4
        else: # 5
            a = bar(a) # 6
        if c: # 7
            b = bar(b) # 8
        else: # 9
            b = bar(b) # 10
        return a, b # 11
    # Exact text comparison: the trailing `# N` comments are carried into the
    # rewritten source.
    assert inspect.getsource(foo) == '''\
def foo(a, b, c): # 2
    _cond_0 = b
    a_0 = bar(a) # 4
    a_1 = bar(a) # 6
    a_2 = a_0 if _cond_0 else a_1
    _cond_1 = c
    b_0 = bar(b) # 8
    b_1 = bar(b) # 10
    b_2 = b_0 if _cond_1 else b_1
    __0_return_0 = a_2, b_2 # 11
    return __0_return_0
'''
    symbol_tables = deco.metadata['SYMBOL-TABLE']
    # Only ssa is applied, so exactly one (pass, table) entry is expected.
    assert len(symbol_tables) == 1
    assert symbol_tables[0][0] == ssa
    symbol_table = symbol_tables[0][1]
    # For lines 2..11: `a` is a_0 from line 4, a_1 on line 6, and the merged
    # a_2 from line 7 on.  `b` is b_0 from line 8, b_1 on line 10, and b_2 on
    # line 11.  `c` is never renamed.
    gold_table = {i: {
        'a': 'a' if i < 4 else 'a_0' if i < 6 else 'a_1' if i < 7 else 'a_2',
        'b': 'b' if i < 8 else 'b_0' if i < 10 else 'b_1' if i < 11 else 'b_2',
        'c': 'c',
    } for i in range(2, 12)}
    assert symbol_table == gold_table
def test_reassign_arg():
    """SSA of a function argument that is conditionally reassigned.

    Checks the exact rewritten source of ``foo`` and the ``SYMBOL-TABLE``
    metadata.  That metadata maps each line of ``foo`` (with the ``@deco``
    decorator as line 1) to the SSA name every original name has there.
    """
    # NOTE(review): bar is never used in this test.
    def bar(x):
        return x
    deco = apply_passes([ssa()])
    # Register `deco` in the pass's environment by hand, because it is bound
    # only after apply_passes(...) returns.  Presumably this is needed so the
    # `@deco` line in foo's source resolves; confirm.
    deco.env.locals['deco'] = deco

    @deco
    def foo(a, b):
        if b:
            a = len(a)
        return a
    # The untouched argument `a` is the else-value of the merge.
    assert inspect.getsource(foo) == '''\
def foo(a, b):
    _cond_0 = b
    a_0 = len(a)
    a_1 = a_0 if _cond_0 else a
    __0_return_0 = a_1
    return __0_return_0
'''
    symbol_tables = deco.metadata['SYMBOL-TABLE']
    # Only ssa is applied, so exactly one (pass, table) entry is expected.
    assert len(symbol_tables) == 1
    assert symbol_tables[0][0] == ssa
    symbol_table = symbol_tables[0][1]
    # Line 2 is `def foo`, 3 is `if b:`, 4 is the reassignment (a -> a_0),
    # and 5 is the return, where the merged a_1 is in effect.
    assert symbol_table == {
        2: {
            'a': 'a',
            'b': 'b',
        },
        3: {
            'a': 'a',
            'b': 'b',
        },
        4: {
            'a': 'a_0',
            'b': 'b',
        },
        5: {
            'a': 'a_1',
            'b': 'b',
        },
    }
def test_attrs_basic(strict):
    """SSA must preserve attribute writes made on either branch.

    The original ``f1`` and its SSA'd copy ``f2`` are run in lockstep on two
    separate ``Thing`` objects, first with ``cond=True`` and then with
    ``cond=False``.  After the original runs, the objects must differ.  Once
    the rewrite has also run, they must agree on the value that branch wrote.
    """
    def f1(t, cond):
        old = t.x
        if cond:
            t.x = 1
        else:
            t.x = 0
        return old
    f2 = apply_passes([ssa(strict)])(f1)
    t1 = Thing()
    t2 = Thing()
    # `== None` (not `is None`) is deliberate: Thing objects are compared
    # directly with None and ints here, so Thing presumably defines __eq__.
    assert t1 == t2 == None
    for cond, expected in ((True, 1), (False, 0)):
        f1(t1, cond)
        assert t1 != t2  # only the original has run for this cond so far
        f2(t2, cond)
        assert t1 == t2 == expected
class Counter2:
    # Copies Counter1's __init__ and get_step as-is and takes __call__ from
    # the SSA-rewritten Counter1.__call__.
    # NOTE(review): `strict` is a free variable here; presumably this class is
    # defined inside a test that receives `strict`.  Confirm against the
    # enclosing scope.  Keep the statement order: apply_passes may capture the
    # surrounding environment when it is constructed.
    __init__ = Counter1.__init__
    __call__ = apply_passes([ssa(strict)])(Counter1.__call__)
    get_step = Counter1.get_step