def test_ast_equal():
    """
    Check ``ast_equal`` on a reference tree and three near-miss variants.

    The reference must compare equal both to itself and to an independently
    parsed copy. It must compare unequal to trees that differ in node type,
    in a node's value, or in the length of a statement body.
    """
    src = """
        def sample_fn(x, y, foo='bar', **kw):
            if (foo == 'bar'):
                return (x + y)
            else:
                return kw['zzz']
        """

    # Different node type (`-` instead of `+`)
    different_node_src = """
        def sample_fn(x, y, foo='bar', **kw):
            if (foo == 'bar'):
                return (x - y)
            else:
                return kw['zzz']
        """

    # Different value in a node ('zzy' instead of 'zzz')
    different_value_src = """
        def sample_fn(x, y, foo='bar', **kw):
            if (foo == 'bar'):
                return (x + y)
            else:
                return kw['zzy']
        """

    # Additional element in a body
    different_length_src = """
        def sample_fn(x, y, foo='bar', **kw):
            if (foo == 'bar'):
                return (x + y)
                return 1
            else:
                return kw['zzz']
        """

    tree = ast.parse(unindent(src))
    # Parsed separately, so it shares no node objects with `tree`.
    same_tree = ast.parse(unindent(src))
    different_node = ast.parse(unindent(different_node_src))
    different_value = ast.parse(unindent(different_value_src))
    different_length = ast.parse(unindent(different_length_src))

    assert ast_equal(tree, tree)
    # Comparing `tree` with itself would pass even for an identity check.
    # This assertion requires real structural equality.
    assert ast_equal(tree, same_tree)
    assert not ast_equal(tree, different_node)
    assert not ast_equal(tree, different_value)
    assert not ast_equal(tree, different_length)
def prune_cfg(node, bindings):
    """
    Apply the control-flow pruning passes to ``node`` until they stop changing it.

    Each round runs ``remove_unreachable_statements``, ``simplify_loops`` and
    ``remove_unreachable_branches`` in that order. Each pass gets a freshly
    built ``ctx`` dict whose ``'bindings'`` entry is ``bindings``. The loop
    ends when a full round produces a tree that ``ast_equal`` considers equal
    to the round's input.

    Returns:
        ``(pruned_node, bindings)``. ``bindings`` is the same object that was
        passed in.
    """
    passes = (remove_unreachable_statements, simplify_loops, remove_unreachable_branches)
    current = node
    while True:
        candidate = current
        for run_pass in passes:
            candidate = run_pass(candidate, ctx={'bindings': bindings})
        if ast_equal(candidate, current):
            return candidate, bindings
        current = candidate
def optimized_ast(tree, constants):
    """
    Run the optimization pipeline on ``(tree, constants)`` until a fixed point.

    Each round threads the pair through ``inline_functions``, ``fold``,
    ``prune_cfg`` and ``prune_assignments`` in that order. Each stage takes
    and returns a ``(tree, constants)`` pair. Iteration stops when a round's
    output tree is ``ast_equal`` to its input and its output constants
    compare ``==`` to its input constants.

    Returns:
        The final ``(tree, constants)`` pair.
    """
    pipeline = (inline_functions, fold, prune_cfg, prune_assignments)
    while True:
        out_tree, out_constants = tree, constants
        for optimize in pipeline:
            out_tree, out_constants = optimize(out_tree, out_constants)
        if ast_equal(out_tree, tree) and out_constants == constants:
            return out_tree, out_constants
        tree, constants = out_tree, out_constants
def assert_ast_equal(test_ast, expected_ast, print_ast=True):
    """
    Assert that ``test_ast`` equals ``expected_ast`` according to ``ast_equal``.

    On a mismatch, diagnostic diffs are printed before the assertion fails:
    first a diff of the ``ast_to_string`` renderings (only when ``print_ast``
    is true), then a diff of the ``ast_to_source`` renderings.

    NOTE(review): another ``assert_ast_equal`` definition appears further down
    in this source. If both live in one module, the later one replaces this
    one when the module is imported. Confirm which is intended.
    """
    is_equal = ast_equal(test_ast, expected_ast)
    if not is_equal:
        renderers = (ast_to_string, ast_to_source) if print_ast else (ast_to_source,)
        for render in renderers:
            # The expected tree is rendered before the tree under test.
            expected_text = render(expected_ast)
            actual_text = render(test_ast)
            print_diff(actual_text, expected_text)
    assert is_equal
def peval_compare(state, ctx, node):
    """
    Partially evaluate an ``ast.Compare`` node, which may be a chained comparison.

    A single comparison (one operator) is delegated directly to
    ``peval_single_compare``. A chain ``a op1 b op2 c ...`` is split into the
    adjacent pairs ``a op1 b``, ``b op2 c``, and so on. Each pair is evaluated
    with ``peval_single_compare``, and the results are combined with
    ``peval_boolop`` using ``ast.And()``.

    If the combined result is a known value, or anything other than an
    ``ast.BoolOp``, it is returned as is. Otherwise, consecutive ``ast.Compare``
    operands that share their boundary operand are glued back into a single
    chained comparison.

    Args:
        state: evaluation state, threaded through every helper call.
        ctx: evaluation context passed to the helpers.
        node: the ``ast.Compare`` node to evaluate.

    Returns:
        ``(state, result)``, where ``result`` is whatever the helpers produced.
        This is a known value or an AST node.
    """
    if len(node.ops) == 1:
        return peval_single_compare(state, ctx, node.ops[0], node.left, node.comparators[0])

    operands = [node.left] + node.comparators

    # Evaluate every operand once, threading `state` through.
    # The evaluated values were never used, so only the updated state is kept.
    # NOTE(review): peval_single_compare below evaluates the operands again.
    # Confirm whether this pre-pass is needed for anything beyond `state`.
    for value_node in operands:
        state, _ = _peval_expression(state, value_node, ctx)

    # `a op1 b op2 c` -> pairwise comparisons `a op1 b`, `b op2 c`.
    pair_values = []
    for left, op, right in zip(operands[:-1], node.ops, operands[1:]):
        state, pair_value = peval_single_compare(state, ctx, op, left, right)
        pair_values.append(pair_value)

    state, result = peval_boolop(state, ctx, ast.And(), pair_values)

    if is_known_value(result):
        return state, result

    if type(result) is not ast.BoolOp:
        return state, result

    # Gluing non-evaluated comparisons back together. When consecutive
    # Compare nodes share a boundary operand (e.g. `a < b` and `b < c`),
    # they are merged into one chain (`a < b < c`).
    nodes = [result.values[0]]
    for value in result.values[1:]:
        last_node = nodes[-1]
        if (type(last_node) is ast.Compare
                and type(value) is ast.Compare
                and ast_equal(last_node.comparators[-1], value.left)):
            nodes[-1] = ast.Compare(
                left=last_node.left,
                ops=last_node.ops + value.ops,
                comparators=last_node.comparators + value.comparators)
        else:
            nodes.append(value)

    if len(nodes) == 1:
        return state, nodes[0]
    return state, ast.BoolOp(op=ast.And(), values=nodes)
def assert_ast_equal(test_ast, expected_ast, print_ast=True):
    """
    Assert that ``test_ast`` equals ``expected_ast`` according to ``ast_equal``.

    On a mismatch, diagnostic diffs are printed before the assertion fails:
    first a diff of the ``astunparse.dump`` output (only when ``print_ast`` is
    true), then a diff of the unparsed source after ``normalize_source``.

    NOTE(review): an earlier ``assert_ast_equal`` definition appears in this
    source. If both live in one module, this later definition is the one
    that takes effect. Confirm which is intended.
    """
    matches = ast_equal(test_ast, expected_ast)
    if not matches:
        def to_source(tree):
            # Unparse `tree`, then pass the text through normalize_source.
            return normalize_source(unparse(tree))

        views = (astunparse.dump, to_source) if print_ast else (to_source,)
        for view in views:
            # The expected tree is rendered before the tree under test.
            expected_text = view(expected_ast)
            actual_text = view(test_ast)
            print_diff(actual_text, expected_text)
    assert matches