def test_imbalanced(strict, a, b, c, d):
    """Compare a function built from ``imbalanced_template`` with its SSA form.

    The template is filled with ``a``-``d``, parsed with ``cst`` and executed.
    The resulting function is first probed on boolean input pairs to see
    whether it can raise ``NameError``.  If it can and ``strict`` is truthy,
    applying the ``ssa`` pass must raise ``SyntaxError``.  Otherwise the
    transformed function must return the same value as the original for every
    input pair; a ``NameError`` during that comparison is only accepted when
    the probe already observed one.
    """
    bools = (False, True)
    source = imbalanced_template.format(a, b, c, d)
    imbalanced = exec_def_in_file(cst.parse_statement(source), SymbolTable({}, {}))

    may_raise_name_error = False
    for first in bools:
        for second in bools:
            try:
                imbalanced(first, second)
            except NameError:
                may_raise_name_error = True
                # Leaves the inner loop only; the outer loop keeps probing.
                break

    if not (may_raise_name_error and strict):
        transformed = apply_passes([ssa(strict)])(imbalanced)
        for first in bools:
            for second in bools:
                try:
                    assert imbalanced(first, second) == transformed(first, second)
                except NameError:
                    assert may_raise_name_error
        return

    with pytest.raises(SyntaxError):
        apply_passes([ssa(strict)])(imbalanced)
def rewrite(self,
            tree: ast.AST,
            env: SymbolTable,
            metadata: tp.MutableMapping) -> tp.Union[tp.Callable, type]:
    """Drop the first rewrite group from ``tree``'s decorators and exec it.

    The decorator list is walked in reverse order.  Each decorator must be a
    bare name or a call on a bare name; that name is looked up in ``env``.
    A decorator resolving to a subclass of ``begin_rewrite`` opens the group,
    and every decorator up to and including the next one resolving to a
    subclass of ``end_rewrite`` is dropped.  Decorators seen before the group
    opens, and all decorators after it closes, are kept in their original
    relative order.

    Args:
        tree: definition node whose ``decorator_list`` is filtered in place.
        env: mapping from decorator names to the objects they refer to.
        metadata: not used by this method.

    Returns:
        Whatever ``exec_def_in_file`` returns for the filtered tree.

    Raises:
        AssertionError: if a decorator is neither ``ast.Call`` nor ``ast.Name``.
    """
    kept = []
    first_group = True
    in_group = False
    # filter passes from the decorator list
    for node in reversed(tree.decorator_list):
        if not first_group:
            # The first group is already closed: keep everything else as is.
            kept.append(node)
            continue

        if isinstance(node, ast.Call):
            name = node.func.id
        else:
            assert isinstance(node, ast.Name)
            name = node.id

        deco = env[name]
        if in_group:
            if _issubclass(deco, end_rewrite):
                in_group = False
                first_group = False
        elif _issubclass(deco, begin_rewrite):
            in_group = True
        else:
            kept.append(node)

    # Store a real list, not a one-shot ``reversed`` iterator: AST helpers such
    # as ``ast.fix_missing_locations`` only descend into list-valued fields,
    # and an iterator would be exhausted after a single traversal.
    kept.reverse()
    tree.decorator_list = kept
    tree = ast.fix_missing_locations(tree)
    return exec_def_in_file(tree, env, **self.kwargs)
def test_call_in_annotations(strict, x, y):
    """Run the ``ssa`` pass on a function built from ``sig_template``.

    ``x`` and ``y`` are formatted into the template twice: once prefixed with
    ``': '`` (left unchanged when falsy), and once as ``r_x``/``r_y`` with
    ``'int'`` as the fallback for a falsy value.  The test passes if building
    and transforming the function does not raise.
    """
    # Local names are kept stable on purpose: ``locals()`` is handed to
    # ``SymbolTable`` below, so they are part of that environment.
    r_x = x or 'int'
    r_y = y or 'int'
    if x:
        x = f': {x}'
    if y:
        y = f': {y}'

    src = sig_template.format(x, y, r_x, r_y)
    tree = cst.parse_statement(src)
    env = SymbolTable(locals(), globals())

    f1 = exec_def_in_file(tree, env)
    f2 = apply_passes([ssa(strict)])(f1)
def rewrite(self,
            tree: ast.AST,
            env: SymbolTable,
            metadata: tp.MutableMapping) -> tp.Union[tp.Callable, type]:
    """Produce the tree to execute and the tree to serialize, then exec.

    ``tree`` is passed through ``_ASTStripper.strip`` twice: with
    ``(begin_rewrite, None)`` for the tree that is executed, and with
    ``(begin_rewrite, end_rewrite)`` for the tree forwarded as
    ``serialized_tree``.  Both results go through
    ``ast.fix_missing_locations``.  ``metadata`` is not used here.

    Returns:
        Whatever ``exec_def_in_file`` returns when called with the two trees,
        ``env`` and ``self.kwargs``.
    """
    # tree to exec
    exec_tree = ast.fix_missing_locations(
        _ASTStripper.strip(tree, env, begin_rewrite, None)
    )
    # tree to serialize
    serial_tree = ast.fix_missing_locations(
        _ASTStripper.strip(tree, env, begin_rewrite, end_rewrite)
    )
    return exec_def_in_file(
        exec_tree,
        env,
        serialized_tree=serial_tree,
        **self.kwargs,
    )
def test_basic_if(strict, a, b):
    """Check that the ``ssa`` pass preserves a function built from ``basic_template``.

    The template is filled with ``a``, ``b`` and a trailing statement, parsed
    with ``cst`` and executed.  The original and the transformed function must
    return equal values for both boolean inputs.
    """
    # The trailing ``return r`` is omitted only when a and b are both 'return'.
    tail = '' if a == b == 'return' else 'return r'
    source = basic_template.format(a, b, tail)
    basic = exec_def_in_file(cst.parse_statement(source), SymbolTable({}, {}))
    transformed = apply_passes([ssa(strict)])(basic)
    for flag in (False, True):
        assert basic(flag) == transformed(flag)
def test_basic_if(strict, a, b):
    """Check that ``_do_ssa`` preserves a function built from ``basic_template``.

    The template is filled with ``a``, ``b`` and a trailing statement, parsed
    with ``ast`` and executed.  The original and the transformed function must
    return equal values for both boolean inputs.
    """
    # The trailing ``return r`` is omitted only when a and b are both 'return'.
    tail = '' if a == b == 'return' else 'return r'
    source = basic_template.format(a, b, tail)
    basic = exec_def_in_file(ast.parse(source).body[0], SymbolTable({}, {}))
    transformed = _do_ssa(basic, strict, dump_src=True)
    for flag in (False, True):
        assert basic(flag) == transformed(flag)
def test_nested(strict, a, b, c, d):
    """Check that the ``ssa`` pass preserves a function built from ``nested_template``.

    The template is filled with ``a``-``d`` and a trailing statement, parsed
    with ``cst`` and executed.  The original and the transformed function must
    return equal values for every pair of boolean inputs.
    """
    bools = (False, True)
    # The trailing ``return r`` is omitted only when a, b, c and d are all 'return'.
    tail = '' if a == b == c == d == 'return' else 'return r'
    source = nested_template.format(a, b, c, d, tail)
    nested = exec_def_in_file(cst.parse_statement(source), SymbolTable({}, {}))
    transformed = apply_passes([ssa(strict)])(nested)
    for first in bools:
        for second in bools:
            assert nested(first, second) == transformed(first, second)
def test_nested(strict, a, b, c, d):
    """Check that ``_do_ssa`` preserves a function built from ``nested_template``.

    The template is filled with ``a``-``d`` and a trailing statement, parsed
    with ``ast`` and executed.  The original and the transformed function must
    return equal values for every pair of boolean inputs.
    """
    bools = (False, True)
    # The trailing ``return r`` is omitted only when a, b, c and d are all 'return'.
    tail = '' if a == b == c == d == 'return' else 'return r'
    source = nested_template.format(a, b, c, d, tail)
    nested = exec_def_in_file(ast.parse(source).body[0], SymbolTable({}, {}))
    transformed = _do_ssa(nested, strict, dump_src=True)
    for first in bools:
        for second in bools:
            assert nested(first, second) == transformed(first, second)
def test_imbalanced(strict, a, b, c, d):
    """Compare a function built from ``imbalanced_template`` with its SSA form.

    The template is filled with ``a``-``d``, parsed with ``ast`` and executed.
    The resulting function is first probed on boolean input pairs to see
    whether it can raise ``NameError``.  If it can and ``strict`` is truthy,
    ``_do_ssa`` must raise ``SyntaxError``.  If it cannot, the transformed
    function must return the same value as the original for every input pair.
    """
    bools = (False, True)
    source = imbalanced_template.format(a, b, c, d)
    imbalanced = exec_def_in_file(ast.parse(source).body[0], SymbolTable({}, {}))

    may_raise_name_error = False
    for first in bools:
        for second in bools:
            try:
                imbalanced(first, second)
            except NameError:
                may_raise_name_error = True
                # Leaves the inner loop only; the outer loop keeps probing.
                break

    if may_raise_name_error:
        if strict:
            with pytest.raises(SyntaxError):
                _do_ssa(imbalanced, strict, dump_src=True)
        # With a possible NameError and strict falsy, nothing further is checked.
        return

    transformed = _do_ssa(imbalanced, strict, dump_src=True)
    for first in bools:
        for second in bools:
            assert imbalanced(first, second) == transformed(first, second)
def exec(self,
         etree: tp.Union[cst.ClassDef, cst.FunctionDef],
         stree: tp.Union[cst.ClassDef, cst.FunctionDef],
         env: SymbolTable,
         metadata: tp.MutableMapping):
    """Execute ``etree`` through ``exec_def_in_file`` and return its result.

    ``self.path``, ``self.file_name`` and ``stree`` are forwarded as the
    third, fourth and fifth positional arguments.  ``metadata`` is not used
    here.
    """
    path, file_name = self.path, self.file_name
    return exec_def_in_file(etree, env, path, file_name, stree)
def exec(self,
         etree: ast.AST,
         stree: ast.AST,
         env: SymbolTable,
         metadata: tp.MutableMapping):
    """Fix locations on both trees, then execute ``etree`` via ``exec_def_in_file``.

    Both ``etree`` and ``stree`` go through ``ast.fix_missing_locations``.
    ``self.path``, ``self.file_name`` and the fixed ``stree`` are forwarded as
    the third, fourth and fifth positional arguments.  ``metadata`` is not
    used here.
    """
    exec_tree = ast.fix_missing_locations(etree)
    serial_tree = ast.fix_missing_locations(stree)
    return exec_def_in_file(
        exec_tree,
        env,
        self.path,
        self.file_name,
        serial_tree,
    )