def visit_If(self, node):
    """Converts an `if` statement into branch functions plus a conditional call.

    The body and orelse are each moved into a generated branch function
    (via `_create_cond_branch`), the test is assigned to a fresh `cond`
    variable, and `_create_cond_expr` builds the dispatching expression.

    Symbols modified in either branch are split into:
      * non-composite symbols that are live after the statement; these are
        returned from the branch functions and become the conditional's
        results;
      * composite symbols (`s.is_composite()`), which are always collected
        and passed to generated state getter/setter functions, regardless of
        liveness.

    Args:
      node: the `If` node. It must carry the static-analysis annotations read
        at the top (body/orelse scopes, DEFINED_VARS_IN, LIVE_VARS_OUT).

    Returns:
      The concatenation of: the assignments produced by
      `_create_undefined_assigns` for symbols created in only one branch, the
      composite state functions, the two branch definitions, the `cond`
      assignment and the conditional expression.
    """
    body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)
    orelse_scope = anno.getanno(node, annos.NodeAnno.ORELSE_SCOPE)
    defined_in = anno.getanno(node, anno.Static.DEFINED_VARS_IN)
    live_out = anno.getanno(node, anno.Static.LIVE_VARS_OUT)

    # Note: this information needs to be extracted before the body conversion
    # that happens in the call to generic_visit below, because the conversion
    # generates nodes that lack static analysis annotations.
    need_alias_in_body = self._determine_aliased_symbols(
        body_scope, defined_in, node.body)
    need_alias_in_orelse = self._determine_aliased_symbols(
        orelse_scope, defined_in, node.orelse)

    node = self.generic_visit(node)

    modified_in_cond = body_scope.modified | orelse_scope.modified
    returned_from_cond = set()
    composites = set()
    for s in modified_in_cond:
        if s in live_out and not s.is_composite():
            returned_from_cond.add(s)
        if s.is_composite():
            # Special treatment for compound objects, always return them.
            # This allows special handling within the if_stmt itself.
            # For example, in TensorFlow we need to restore the state of composite
            # symbols to ensure that only effects from the executed branch are seen.
            composites.add(s)

    # `-` binds tighter than `&`: modified & (returned_from_cond - defined_in),
    # i.e. returned symbols that did not exist before the `if`.
    created_in_body = body_scope.modified & returned_from_cond - defined_in
    created_in_orelse = orelse_scope.modified & returned_from_cond - defined_in

    basic_created_in_body = tuple(
        s for s in created_in_body if not s.is_composite())
    basic_created_in_orelse = tuple(
        s for s in created_in_orelse if not s.is_composite())

    # These variables are defined only in a single branch. This is fine in
    # Python so we pass them through. Another backend, e.g. Tensorflow, may need
    # to handle these cases specially or throw an Error.
    possibly_undefined = (set(basic_created_in_body) ^
                          set(basic_created_in_orelse))

    # Alias the closure variables inside the conditional functions, to allow
    # the functions access to the respective variables.
    # We will alias variables independently for body and orelse scope,
    # because different branches might write different variables.
    aliased_body_orig_names = tuple(need_alias_in_body)
    aliased_orelse_orig_names = tuple(need_alias_in_orelse)
    aliased_body_new_names = tuple(
        self.ctx.namer.new_symbol(s.ssf(), body_scope.referenced)
        for s in aliased_body_orig_names)
    aliased_orelse_new_names = tuple(
        self.ctx.namer.new_symbol(s.ssf(), orelse_scope.referenced)
        for s in aliased_orelse_orig_names)

    alias_body_map = dict(
        zip(aliased_body_orig_names, aliased_body_new_names))
    alias_orelse_map = dict(
        zip(aliased_orelse_orig_names, aliased_orelse_new_names))

    node_body = ast_util.rename_symbols(node.body, alias_body_map)
    node_orelse = ast_util.rename_symbols(node.orelse, alias_orelse_map)

    # The order of these new_symbol calls determines the generated names.
    cond_var_name = self.ctx.namer.new_symbol('cond', body_scope.referenced)
    body_name = self.ctx.namer.new_symbol('if_true', body_scope.referenced)
    orelse_name = self.ctx.namer.new_symbol('if_false', orelse_scope.referenced)
    all_referenced = body_scope.referenced | orelse_scope.referenced
    state_getter_name = self.ctx.namer.new_symbol('get_state', all_referenced)
    state_setter_name = self.ctx.namer.new_symbol('set_state', all_referenced)

    returned_from_cond = tuple(returned_from_cond)
    composites = tuple(composites)

    if returned_from_cond:
        if len(returned_from_cond) == 1:
            cond_results = returned_from_cond[0]
        else:
            cond_results = gast.Tuple(
                [s.ast() for s in returned_from_cond], None)

        # Inside each branch, aliased symbols are returned under their alias.
        returned_from_body = tuple(
            alias_body_map[s] if s in need_alias_in_body else s
            for s in returned_from_cond)
        returned_from_orelse = tuple(
            alias_orelse_map[s] if s in need_alias_in_orelse else s
            for s in returned_from_cond)

    else:
        # When the cond would return no value, we leave the cond called without
        # results. That in turn should trigger the side effect guards. The
        # branch functions will return a dummy value that ensures cond
        # actually has some return value as well.
        cond_results = None
        # TODO(mdan): Replace with None once side_effect_guards is retired.
        returned_from_body = (templates.replace_as_expression(
            'ag__.match_staging_level(1, cond_var_name)',
            cond_var_name=cond_var_name), )
        returned_from_orelse = (templates.replace_as_expression(
            'ag__.match_staging_level(1, cond_var_name)',
            cond_var_name=cond_var_name), )

    cond_assign = self.create_assignment(cond_var_name, node.test)
    body_def = self._create_cond_branch(
        body_name,
        aliased_orig_names=aliased_body_orig_names,
        aliased_new_names=aliased_body_new_names,
        body=node_body,
        returns=returned_from_body)
    orelse_def = self._create_cond_branch(
        orelse_name,
        aliased_orig_names=aliased_orelse_orig_names,
        aliased_new_names=aliased_orelse_new_names,
        body=node_orelse,
        returns=returned_from_orelse)
    undefined_assigns = self._create_undefined_assigns(possibly_undefined)
    composite_defs = self._create_state_functions(composites, [],
                                                  state_getter_name,
                                                  state_setter_name)

    # Symbol names are passed to the runtime as string constants.
    basic_symbol_names = tuple(
        gast.Constant(str(symbol), kind=None) for symbol in returned_from_cond)
    composite_symbol_names = tuple(
        gast.Constant(str(symbol), kind=None) for symbol in composites)

    cond_expr = self._create_cond_expr(cond_results, cond_var_name, body_name,
                                       orelse_name, state_getter_name,
                                       state_setter_name, basic_symbol_names,
                                       composite_symbol_names)

    if_ast = (undefined_assigns + composite_defs + body_def + orelse_def +
              cond_assign + cond_expr)
    return if_ast
class Square(Transformation):

    """
    Replaces **2 by a call to numpy.square.

    Other positive integer literal powers ``x ** n`` are expanded into a
    square-and-multiply chain of numpy.square calls and multiplications.

    >>> import gast as ast
    >>> from pythran import passmanager, backend
    >>> node = ast.parse('a**2')
    >>> pm = passmanager.PassManager("test")
    >>> _, node = pm.apply(Square, node)
    >>> print(pm.dump(backend.Python, node))
    import numpy as __pythran_import_numpy
    __pythran_import_numpy.square(a)
    >>> node = ast.parse('__pythran_import_numpy.power(a,2)')
    >>> pm = passmanager.PassManager("test")
    >>> _, node = pm.apply(Square, node)
    >>> print(pm.dump(backend.Python, node))
    import numpy as __pythran_import_numpy
    __pythran_import_numpy.square(a)
    """

    # Matches ``<anything> ** 2``.
    POW_PATTERN = ast.BinOp(AST_any(), ast.Pow(), ast.Constant(2, None))
    # Matches ``numpy.power(<anything>, 2)`` on the mangled numpy name.
    POWER_PATTERN = ast.Call(
        ast.Attribute(ast.Name(mangle('numpy'), ast.Load(), None, None),
                      'power',
                      ast.Load()),
        [AST_any(), ast.Constant(2, None)],
        [])

    def __init__(self):
        Transformation.__init__(self)

    def replace(self, value):
        """Return ``numpy.square(value)``; flags the tree as updated and
        records that a numpy import is needed."""
        self.update = self.need_import = True
        module_name = ast.Name(mangle('numpy'), ast.Load(), None, None)
        return ast.Call(ast.Attribute(module_name, 'square', ast.Load()),
                        [value], [])

    def visit_Module(self, node):
        """Transform the module and prepend ``import numpy as <mangled>``
        when any rewrite used numpy.square."""
        self.need_import = False
        self.generic_visit(node)
        if self.need_import:
            import_alias = ast.alias(name='numpy', asname=mangle('numpy'))
            importIt = ast.Import(names=[import_alias])
            node.body.insert(0, importIt)
        return node

    def expand_pow(self, node, n):
        """
        Build an AST computing ``node ** n`` for a non-negative integer ``n``.

        Uses square-and-multiply: ``node ** n`` is ``square(node) ** (n // 2)``,
        times one extra ``node`` when ``n`` is odd.
        """
        if n == 0:
            return ast.Constant(1, None)
        elif n == 1:
            return node
        else:
            node_square = self.replace(node)
            node_pow = self.expand_pow(node_square, n >> 1)
            if n & 1:
                # NOTE(review): the deep copy makes `node` appear twice in the
                # output, so it is evaluated twice; this assumes the base
                # expression is side-effect free -- confirm.
                return ast.BinOp(node_pow, ast.Mult(), copy.deepcopy(node))
            else:
                return node_pow

    def visit_BinOp(self, node):
        """
        Rewrite ``x ** 2`` into ``numpy.square(x)`` and expand ``x ** n`` for
        positive integer literals ``n``. Other nodes are returned unchanged.
        """
        self.generic_visit(node)
        if ASTMatcher(Square.POW_PATTERN).search(node):
            return self.replace(node.left)
        if isinstance(node.op, ast.Pow) and isnum(node.right):
            n = node.right.value
            # Only genuine integer exponents are expanded. A float literal such
            # as ``3.0`` used to pass ``int(n) == n`` and then crash in
            # expand_pow on ``n >> 1``; ``int(n)`` itself raises for complex,
            # nan or inf. Expanding a float exponent would also drop the float
            # result type for an integer base.
            if isinstance(n, int) and n > 0:
                # ``x ** 1`` becomes ``x`` without going through replace(), so
                # record the modification explicitly.
                self.update = True
                return self.expand_pow(node.left, n)
        return node

    def visit_Call(self, node):
        """Rewrite ``numpy.power(x, 2)`` into ``numpy.square(x)``."""
        self.generic_visit(node)
        if ASTMatcher(Square.POWER_PATTERN).search(node):
            return self.replace(node.args[0])
        else:
            return node
def sub():
    """Return the replacement template ``<placeholder 0> ** 2``.

    ``Placeholder(0)`` presumably stands for the first sub-expression captured
    by the matching pattern -- confirm against the pattern engine.
    """
    base = Placeholder(0)
    exponent = ast.Constant(2, None)
    return ast.BinOp(base, ast.Pow(), exponent)
class PyASTGraphsTest(parameterized.TestCase):
    """Tests for converting Python ASTs into graphs with `py_ast_graphs`."""

    @staticmethod
    def _tagged(node_id, in_edge):
        """Build the InputTaggedNode expected at the end of an outgoing edge."""
        return graph_types.InputTaggedNode(
            node_id=graph_types.NodeId(node_id),
            in_edge=graph_types.InEdgeType(in_edge))

    def test_schema(self):
        """Check that the generated schema is reasonable."""
        # Building should succeed
        schema = py_ast_graphs.SCHEMA

        # One entry per AST node spec, followed by the sorted sequence helpers.
        seq_helpers = [
            "BoolOp_values-seq-helper",
            "Call_args-seq-helper",
            "For_body-seq-helper",
            "FunctionDef_body-seq-helper",
            "If_body-seq-helper",
            "If_orelse-seq-helper",
            "Module_body-seq-helper",
            "While_body-seq-helper",
            "arguments_args-seq-helper",
        ]
        self.assertEqual(
            list(schema.keys()),
            list(py_ast_graphs.PY_AST_SPECS.keys()) + seq_helpers)

        # Expected (in_edges, out_edges) for a sample of node types.
        expected_edges = {
            "FunctionDef": (
                {"parent_in", "returns_in", "returns_missing", "body_in",
                 "args_in"},
                {"parent_out", "returns_out", "body_out_all",
                 "body_out_first", "body_out_last", "args_out"},
            ),
            "Expr": ({"parent_in", "value_in"}, {"parent_out", "value_out"}),
        }
        seq_helper_edges = (
            {"parent_in", "item_in", "next_in", "next_missing", "prev_in",
             "prev_missing"},
            {"parent_out", "item_out", "next_out", "prev_out"},
        )
        for seq_helper in seq_helpers:
            expected_edges[seq_helper] = seq_helper_edges

        for node_type, (in_edges, out_edges) in expected_edges.items():
            self.assertEqual(set(schema[node_type].in_edges), in_edges)
            self.assertEqual(set(schema[node_type].out_edges), out_edges)

    def test_ast_graph_conforms_to_schema(self):
        # Example code touching many syntactic constructs, so that a large part
        # of the schema is exercised.
        root = gast.parse(
            textwrap.dedent("""\
                def foo(n):
                  if n <= 1:
                    return 1
                  else:
                    return foo(n-1) + foo(n-2)

                def bar(m, n) -> int:
                  x = n
                  for i in range(m):
                    if False:
                      continue
                    x = x + i
                  while True:
                    break
                  return x

                x0 = 1 + 2 - 3 * 4 / 5
                x1 = (1 == 2) and (3 < 4) and (5 > 6)
                x2 = (7 <= 8) and (9 >= 10) or (11 != 12)
                x2 = bar(13, 14 + 15)
                """))

        graph, _ = py_ast_graphs.py_ast_to_graph(root)

        # Graph should match the schema
        schema_util.assert_conforms_to_schema(graph, py_ast_graphs.SCHEMA)

    def test_ast_graph_nodes(self):
        """Check node IDs, node types, and forward mapping."""
        root = gast.parse(
            textwrap.dedent("""\
                pass
                def foo(n):
                  if n <= 1:
                    return 1
                """))

        graph, forward_map = py_ast_graphs.py_ast_to_graph(root)

        # pytype: disable=attribute-error
        if_node = root.body[1].body[0]
        # (node id, expected node type, AST node that must map to it or None)
        expectations = [
            ("root__Module", "Module", root),
            ("root_body_1__Module_body-seq-helper", "Module_body-seq-helper",
             None),
            ("root_body_1_item_body_0_item__If", "If", if_node),
            ("root_body_1_item_body_0_item_test_left__Name", "Name",
             if_node.test.left),
        ]
        # pytype: enable=attribute-error
        for node_id, node_type, ast_node in expectations:
            self.assertIn(node_id, graph)
            self.assertEqual(graph[node_id].node_type, node_type)
            if ast_node is not None:
                self.assertEqual(forward_map[id(ast_node)], node_id)

    def test_ast_graph_unique_field_edges(self):
        """Test that edges for unique fields are correct."""
        root = gast.parse("print(1)")
        graph, _ = py_ast_graphs.py_ast_to_graph(root)

        expr_node = graph["root_body_0_item__Expr"]
        self.assertEqual(
            expr_node.out_edges["value_out"],
            [self._tagged("root_body_0_item_value__Call", "parent_in")])

        call_node = graph["root_body_0_item_value__Call"]
        self.assertEqual(
            call_node.out_edges["parent_out"],
            [self._tagged("root_body_0_item__Expr", "value_in")])

    def test_ast_graph_optional_field_edges(self):
        """Test that edges for optional fields are correct."""
        root = gast.parse("return 1\nreturn")
        graph, _ = py_ast_graphs.py_ast_to_graph(root)

        # (source node, out edge, expected target node, expected in edge)
        expectations = [
            ("root_body_0_item__Return", "value_out",
             "root_body_0_item_value__Constant", "parent_in"),
            ("root_body_0_item_value__Constant", "parent_out",
             "root_body_0_item__Return", "value_in"),
            # A missing optional child is represented by a self-edge.
            ("root_body_1_item__Return", "value_out",
             "root_body_1_item__Return", "value_missing"),
        ]
        for source, out_edge, target, in_edge in expectations:
            self.assertEqual(graph[source].out_edges[out_edge],
                             [self._tagged(target, in_edge)])

    def test_ast_graph_sequence_field_edges(self):
        """Test that edges for sequence fields are correct.

        Sequence fields connect three nodes: the parent, the helper node, and
        the child.
        """
        root = gast.parse(
            textwrap.dedent("""\
                print(1)
                print(2)
                print(3)
                print(4)
                print(5)
                print(6)
                """))

        graph, _ = py_ast_graphs.py_ast_to_graph(root)
        first_helper = "root_body_0__Module_body-seq-helper"

        # Child edges from the parent node
        module = graph["root__Module"]
        self.assertLen(module.out_edges["body_out_all"], 6)
        self.assertEqual(module.out_edges["body_out_first"],
                         [self._tagged(first_helper, "parent_in")])
        self.assertEqual(
            module.out_edges["body_out_last"],
            [self._tagged("root_body_5__Module_body-seq-helper", "parent_in")])

        # Edges from the sequence helper
        helper = graph[first_helper]
        helper_expectations = [
            ("parent_out", "root__Module", "body_in"),
            ("item_out", "root_body_0_item__Expr", "parent_in"),
            # The first helper has no predecessor, so prev is a self-edge.
            ("prev_out", first_helper, "prev_missing"),
            ("next_out", "root_body_1__Module_body-seq-helper", "prev_in"),
        ]
        for out_edge, target, in_edge in helper_expectations:
            self.assertEqual(helper.out_edges[out_edge],
                             [self._tagged(target, in_edge)])

        # Parent edge of the item
        item = graph["root_body_0_item__Expr"]
        self.assertEqual(item.out_edges["parent_out"],
                         [self._tagged(first_helper, "item_in")])

    @parameterized.named_parameters(
        {
            "testcase_name": "unexpected_type",
            "ast": gast.Subscript(value=None, slice=None, ctx=None),
            "expected_error": "Unknown AST node type 'Subscript'",
        }, {
            "testcase_name": "too_many_unique",
            "ast": gast.Assign(
                targets=[
                    gast.Name("foo", gast.Store(), None, None),
                    gast.Name("bar", gast.Store(), None, None)
                ],
                value=gast.Constant(True, None)),
            "expected_error":
                "Expected 1 child for field 'targets' of node .*; got 2",
        }, {
            "testcase_name": "missing_unique",
            "ast": gast.Assign(targets=[], value=gast.Constant(True, None)),
            "expected_error":
                "Expected 1 child for field 'targets' of node .*; got 0",
        }, {
            "testcase_name": "too_many_optional",
            "ast": gast.Return(value=[
                gast.Name("foo", gast.Load(), None, None),
                gast.Name("bar", gast.Load(), None, None)
            ]),
            "expected_error":
                "Expected at most 1 child for field 'value' of node .*; got 2",
        })
    def test_invalid_graphs(self, ast, expected_error):
        """Malformed ASTs are rejected with a descriptive ValueError."""
        with self.assertRaisesRegex(ValueError, expected_error):
            py_ast_graphs.py_ast_to_graph(ast)
def visit_BinOp(self, node):
    """
    Replace ``i % n`` inside a range loop by an incrementally updated counter.

    The rewrite applies when all of the following hold:
    * ``i`` is the target of an enclosing ``for i in range(...)`` loop with
      unit step;
    * ``n`` is a name with a single definition outside that loop;
    * the value range of ``n`` has a non-negative lower bound.

    In that case a fresh name ``<i>_m`` replaces the modulo, and a
    ``(header, step)`` pair is appended to ``self.loops_mod[loop]``:

    * ``header``: ``<i>_m = (lower - 1) % n``
    * ``step``:   ``<i>_m = 0 if <i>_m + 1 == n else <i>_m + 1``

    Where these statements get inserted is decided elsewhere in the pass.
    In every other case the node is visited generically.
    """
    if not isinstance(node.op, ast.Mod):
        return self.generic_visit(node)

    # check that right is a name defined once outside of loop
    # TODO: handle expression instead of names
    if not isinstance(node.right, ast.Name):
        return self.generic_visit(node)

    right_def = self.single_def(node.right)
    if not right_def:
        return self.generic_visit(node)

    if self.range_values[node.right.id].low < 0:
        return self.generic_visit(node)

    # same for lhs
    if not isinstance(node.left, ast.Name):
        return self.generic_visit(node)

    head = self.single_def(node.left)
    if not head:
        return self.generic_visit(node)

    # check lhs is the actual index of a loop
    loop = self.ancestors[head][-1]

    if not isinstance(loop, ast.For):
        return self.generic_visit(node)

    if not isinstance(loop.iter, ast.Call):
        return self.generic_visit(node)

    # make sure rhs is defined out of the loop
    if loop in self.ancestors[right_def]:
        return self.generic_visit(node)

    # The iterated callee must resolve to the builtin range and to nothing
    # else. Stopping at the first matching alias could let a mixed alias set
    # through, depending on iteration order.
    range_intrinsic = MODULES['builtins']['range']
    func_aliases = self.aliases[loop.iter.func]
    if not func_aliases or any(alias is not range_intrinsic
                               for alias in func_aliases):
        return self.generic_visit(node)

    # Bounds must come from this call's actual arguments. The range intrinsic
    # is shared by every range call, so its own attributes cannot describe
    # this particular call.
    rargs = loop.iter.args
    if (not rargs or loop.iter.keywords or
            any(isinstance(arg, ast.Starred) for arg in rargs)):
        return self.generic_visit(node)
    # The generated `step` only ever adds 1, so only unit-step ranges qualify.
    if len(rargs) > 2:
        range_step = rargs[2]
        if not (isinstance(range_step, ast.Constant) and
                range_step.value == 1):
            return self.generic_visit(node)

    # everything is setup for the transformation!
    new_id = node.left.id + '_m'
    i = 0
    while new_id in self.identifiers:
        new_id = '{}_m{}'.format(node.left.id, i)
        i += 1
    # NOTE(review): new_id is not added to self.identifiers, so a second
    # modulo on the same loop index in the same scope would reuse it -- confirm
    # whether that case can occur.

    lower = rargs[0] if len(rargs) > 1 else ast.Constant(0, None)
    header = ast.Assign([ast.Name(new_id, ast.Store(), None, None)],
                        ast.BinOp(
                            ast.BinOp(deepcopy(lower),
                                      ast.Sub(),
                                      ast.Constant(1, None)),
                            ast.Mod(),
                            deepcopy(node.right)),
                        None)
    incr = ast.BinOp(ast.Name(new_id, ast.Load(), None, None),
                     ast.Add(),
                     ast.Constant(1, None))
    step = ast.Assign([ast.Name(new_id, ast.Store(), None, None)],
                      ast.IfExp(
                          ast.Compare(incr,
                                      [ast.Eq()], [deepcopy(node.right)]),
                          ast.Constant(0, None),
                          deepcopy(incr)),
                      None)

    self.loops_mod.setdefault(loop, []).append((header, step))

    self.update = True
    return ast.Name(new_id, ast.Load(), None, None)
def test_iter_child_nodes(self): tree = gast.UnaryOp(gast.USub(), gast.Constant(value=1, kind=None)) self.assertEqual(len(list(gast.iter_fields(tree))), 2)
def visit_If(self, node):
    """
    Outline both branches of a static ``if`` into separate functions.

    Only applies when ``node.test`` is in ``self.static_expressions``;
    otherwise the node is visited generically.

    Each branch becomes a new function built by ``outline`` and appended to
    ``self.new_functions``. The ``if`` is replaced by the call built by
    ``make_dispatcher``, and the shape of the replacement depends on the
    control flow found in the branches:

    * a ``return`` in either branch: the call result is unpacked into
      (status, return value, continuation) names. The statement returns
      directly when every entry of ``cfg[node]`` is a Return/Yield and both
      branches return; otherwise the status is compared to ``EARLY_RET``.
    * ``break``/``continue`` only: the call result is unpacked into
      (status, continuation) and followed by the statements from
      ``make_control_flow_handlers``.
    * otherwise: the variables assigned in the branches are tuple-assigned
      from the call, or the call becomes a bare expression statement.
    """
    if node.test not in self.static_expressions:
        return self.generic_visit(node)

    imported_ids = self.gather(ImportedIds, node)

    assigned_ids_left = self.escaping_ids(node, node.body)
    assigned_ids_right = self.escaping_ids(node, node.orelse)
    assigned_ids_both = assigned_ids_left.union(assigned_ids_right)

    # A variable assigned in only one branch must also be passed in, so the
    # other branch can hand back its incoming value.
    imported_ids.update(i for i in assigned_ids_left
                        if i not in assigned_ids_right)
    imported_ids.update(i for i in assigned_ids_right
                        if i not in assigned_ids_left)
    imported_ids = sorted(imported_ids)

    assigned_ids = sorted(assigned_ids_both)

    # Wrap each branch in a fake statement so the analyses can run on it.
    fbody = self.make_fake(node.body)
    true_has_return = self.gather(HasReturn, fbody)
    true_has_break = self.gather(HasBreak, fbody)
    true_has_cont = self.gather(HasContinue, fbody)

    felse = self.make_fake(node.orelse)
    false_has_return = self.gather(HasReturn, felse)
    false_has_break = self.gather(HasBreak, felse)
    false_has_cont = self.gather(HasContinue, felse)

    has_return = true_has_return or false_has_return
    has_break = true_has_break or false_has_break
    has_cont = true_has_cont or false_has_cont

    self.generic_visit(node)

    func_true = outline(self.true_name(), imported_ids, assigned_ids,
                        node.body, has_return, has_break, has_cont)
    func_false = outline(self.false_name(), imported_ids, assigned_ids,
                         node.orelse, has_return, has_break, has_cont)
    self.new_functions.extend((func_true, func_false))

    actual_call = self.make_dispatcher(node.test,
                                       func_true, func_false, imported_ids)

    # variable modified within the static_if
    expected_return = [ast.Name(ii, ast.Store(), None, None)
                       for ii in assigned_ids]

    self.update = True

    # name for various variables resulting from the static_if
    # (the '$' prefix cannot clash with a user-written Python identifier)
    n = len(self.new_functions)
    status_n = "$status{}".format(n)
    return_n = "$return{}".format(n)
    cont_n = "$cont{}".format(n)

    if has_return:
        cfg = self.cfgs[-1]
        # presumably cfg[node] yields the CFG successors of node -- confirm
        always_return = all(isinstance(x, (ast.Return, ast.Yield))
                            for x in cfg[node])
        always_return &= true_has_return and false_has_return

        fast_return = [ast.Name(status_n, ast.Store(), None, None),
                       ast.Name(return_n, ast.Store(), None, None),
                       ast.Name(cont_n, ast.Store(), None, None)]

        if always_return:
            return [ast.Assign([ast.Tuple(fast_return, ast.Store())],
                               actual_call),
                    ast.Return(ast.Name(return_n, ast.Load(), None, None))]
        else:
            cont_ass = self.make_control_flow_handlers(cont_n, status_n,
                                                       expected_return,
                                                       has_cont, has_break)

            cmpr = ast.Compare(ast.Name(status_n, ast.Load(), None, None),
                               [ast.Eq()], [ast.Constant(EARLY_RET, None)])
            return [ast.Assign([ast.Tuple(fast_return, ast.Store())],
                               actual_call),
                    ast.If(cmpr,
                           [ast.Return(ast.Name(return_n, ast.Load(),
                                                None, None))],
                           cont_ass)]
    elif has_break or has_cont:
        cont_ass = self.make_control_flow_handlers(cont_n, status_n,
                                                   expected_return,
                                                   has_cont, has_break)

        fast_return = [ast.Name(status_n, ast.Store(), None, None),
                       ast.Name(cont_n, ast.Store(), None, None)]
        return [ast.Assign([ast.Tuple(fast_return, ast.Store())],
                           actual_call)] + cont_ass
    elif expected_return:
        return ast.Assign([ast.Tuple(expected_return, ast.Store())],
                          actual_call)
    else:
        return ast.Expr(actual_call)
def visit_While(self, node):
    """Converts a `while` loop into a call to `ag__.while_stmt`.

    The loop test and body are moved into generated functions named through
    the namer ('loop_test', 'loop_body').
    * Basic loop variables are threaded through those functions' parameters
      and return values.
    * Composite loop variables are handled by generated state getter/setter
      functions.

    Returns:
      `undefined_assigns + node`: the statements from
      `_create_undefined_assigns` for possibly-undefined symbols, followed by
      the statements produced by `templates.replace` (presumably a list).
    """
    node = self.generic_visit(node)

    (basic_loop_vars, composite_loop_vars, reserved_symbols,
     possibly_undefs) = self._get_loop_vars(
         node,
         anno.getanno(node, annos.NodeAnno.BODY_SCOPE).modified)
    loop_vars, loop_vars_ast_tuple = self._loop_var_constructs(
        basic_loop_vars)

    state_getter_name = self.ctx.namer.new_symbol('get_state', reserved_symbols)
    state_setter_name = self.ctx.namer.new_symbol('set_state', reserved_symbols)
    state_functions = self._create_state_functions(composite_loop_vars,
                                                   state_getter_name,
                                                   state_setter_name)

    # Loop variable names are passed to the runtime as string constants.
    basic_symbol_names = tuple(
        gast.Constant(str(symbol), kind=None) for symbol in basic_loop_vars)
    composite_symbol_names = tuple(
        gast.Constant(str(symbol), kind=None) for symbol in composite_loop_vars)

    opts = self._create_loop_options(node)

    # TODO(mdan): Use a single template.
    # If the body and test functions took a single tuple for loop_vars, instead
    # of *loop_vars, then a single template could be used.
    if loop_vars:
        template = """
          state_functions
          def body_name(loop_vars):
            body
            return loop_vars,
          def test_name(loop_vars):
            return test
          loop_vars_ast_tuple = ag__.while_stmt(
              test_name,
              body_name,
              state_getter_name,
              state_setter_name,
              (loop_vars,),
              (basic_symbol_names,),
              (composite_symbol_names,),
              opts)
        """
        node = templates.replace(
            template,
            loop_vars=loop_vars,
            loop_vars_ast_tuple=loop_vars_ast_tuple,
            test_name=self.ctx.namer.new_symbol('loop_test', reserved_symbols),
            test=node.test,
            body_name=self.ctx.namer.new_symbol('loop_body', reserved_symbols),
            body=node.body,
            state_functions=state_functions,
            state_getter_name=state_getter_name,
            state_setter_name=state_setter_name,
            basic_symbol_names=basic_symbol_names,
            composite_symbol_names=composite_symbol_names,
            opts=opts)
    else:
        # No basic loop variables: the generated functions take no arguments
        # and empty tuples are passed for the loop-variable slots.
        template = """
          state_functions
          def body_name():
            body
            return ()
          def test_name():
            return test
          ag__.while_stmt(
              test_name,
              body_name,
              state_getter_name,
              state_setter_name,
              (),
              (),
              (composite_symbol_names,),
              opts)
        """
        node = templates.replace(
            template,
            test_name=self.ctx.namer.new_symbol('loop_test', reserved_symbols),
            test=node.test,
            body_name=self.ctx.namer.new_symbol('loop_body', reserved_symbols),
            body=node.body,
            state_functions=state_functions,
            state_getter_name=state_getter_name,
            state_setter_name=state_setter_name,
            composite_symbol_names=composite_symbol_names,
            opts=opts)

    undefined_assigns = self._create_undefined_assigns(possibly_undefs)
    return undefined_assigns + node
def visit_For(self, node):
    """Converts a `for` loop into a call to `ag__.for_stmt`.

    * The body becomes a generated function that receives the current iterate
      as a single argument (`iterates_var_name`) and unpacks it into the
      original loop target.
    * Basic loop variables are threaded through that function's parameters
      and return values.
    * Composite loop variables go through generated state getter/setter
      functions.
    * If the node has an 'extra_test' annotation, it is wrapped in a generated
      function and passed as the extra test; otherwise the expression `None`
      is passed.

    Returns:
      The statements produced by `templates.replace`. They start with the
      assignments from `_create_undefined_assigns`.
    """
    node = self.generic_visit(node)

    (basic_loop_vars, composite_loop_vars,
     reserved_symbols, possibly_undefs) = self._get_loop_vars(
         node, (anno.getanno(node, annos.NodeAnno.BODY_SCOPE).modified
                | anno.getanno(node, annos.NodeAnno.ITERATE_SCOPE).modified))
    loop_vars, loop_vars_ast_tuple = self._loop_var_constructs(
        basic_loop_vars)
    body_name = self.ctx.namer.new_symbol('loop_body', reserved_symbols)

    state_getter_name = self.ctx.namer.new_symbol('get_state', reserved_symbols)
    state_setter_name = self.ctx.namer.new_symbol('set_state', reserved_symbols)
    state_functions = self._create_state_functions(composite_loop_vars,
                                                   state_getter_name,
                                                   state_setter_name)

    if anno.hasanno(node, 'extra_test'):
        extra_test = anno.getanno(node, 'extra_test')
        extra_test_name = self.ctx.namer.new_symbol(
            'extra_test', reserved_symbols)
        template = """
          def extra_test_name(loop_vars):
            return extra_test_expr
        """
        extra_test_function = templates.replace(
            template,
            extra_test_name=extra_test_name,
            loop_vars=loop_vars,
            extra_test_expr=extra_test)
    else:
        extra_test_name = parser.parse_expression('None')
        extra_test_function = []

    # Workaround for PEP-3113
    # iterates_var holds a single variable with the iterates, which may be a
    # tuple.
    iterates_var_name = self.ctx.namer.new_symbol('iterates', reserved_symbols)
    template = """
      iterates = iterates_var_name
    """
    iterate_expansion = templates.replace(
        template, iterates=node.target, iterates_var_name=iterates_var_name)

    undefined_assigns = self._create_undefined_assigns(possibly_undefs)

    # Loop variable names are passed to the runtime as string constants.
    basic_symbol_names = tuple(
        gast.Constant(str(symbol), kind=None) for symbol in basic_loop_vars)
    composite_symbol_names = tuple(
        gast.Constant(str(symbol), kind=None) for symbol in composite_loop_vars)

    opts = self._create_loop_options(node)

    # TODO(mdan): Use a single template.
    # If the body and test functions took a single tuple for loop_vars, instead
    # of *loop_vars, then a single template could be used.
    if loop_vars:
        template = """
          undefined_assigns
          state_functions
          def body_name(iterates_var_name, loop_vars):
            iterate_expansion
            body
            return loop_vars,
          extra_test_function
          loop_vars_ast_tuple = ag__.for_stmt(
              iter_,
              extra_test_name,
              body_name,
              state_getter_name,
              state_setter_name,
              (loop_vars,),
              (basic_symbol_names,),
              (composite_symbol_names,),
              opts)
        """
        return templates.replace(
            template,
            undefined_assigns=undefined_assigns,
            loop_vars=loop_vars,
            loop_vars_ast_tuple=loop_vars_ast_tuple,
            iter_=node.iter,
            iterate_expansion=iterate_expansion,
            iterates_var_name=iterates_var_name,
            extra_test_name=extra_test_name,
            extra_test_function=extra_test_function,
            body_name=body_name,
            body=node.body,
            state_functions=state_functions,
            state_getter_name=state_getter_name,
            state_setter_name=state_setter_name,
            basic_symbol_names=basic_symbol_names,
            composite_symbol_names=composite_symbol_names,
            opts=opts)
    else:
        # No basic loop variables: the body function only receives the
        # iterate, and empty tuples fill the loop-variable slots.
        template = """
          undefined_assigns
          state_functions
          def body_name(iterates_var_name):
            iterate_expansion
            body
            return ()
          extra_test_function
          ag__.for_stmt(
              iter_,
              extra_test_name,
              body_name,
              state_getter_name,
              state_setter_name,
              (),
              (),
              (composite_symbol_names,),
              opts)
        """
        return templates.replace(
            template,
            undefined_assigns=undefined_assigns,
            iter_=node.iter,
            iterate_expansion=iterate_expansion,
            iterates_var_name=iterates_var_name,
            extra_test_name=extra_test_name,
            extra_test_function=extra_test_function,
            body_name=body_name,
            body=node.body,
            state_functions=state_functions,
            state_getter_name=state_getter_name,
            state_setter_name=state_setter_name,
            composite_symbol_names=composite_symbol_names,
            opts=opts)
def fill(self, hole, rng):
    """Fill `hole` with a random integer literal.

    The value is drawn once with `rng.randint(0, 100)`; whether 100 itself can
    be drawn depends on the rng type. The returned ASTWithHoles is built with
    arguments (1, [], builder), presumably cost 1 (a single Constant node) and
    no sub-holes -- confirm against ASTWithHoles.
    """
    literal = rng.randint(0, 100)

    def build_constant():
        return gast.Constant(value=literal, kind=None)

    return ASTWithHoles(1, [], build_constant)
def fill(self, hole, rng):
    """Fill `hole` with a randomly chosen boolean literal.

    The returned ASTWithHoles is built with arguments (1, [], builder),
    presumably cost 1 (a single Constant node) and no sub-holes -- confirm
    against ASTWithHoles. The builder emits `Constant(value=True|False)`.
    """
    # numpy's RandomState.choice returns numpy.bool_, which is not a subclass
    # of bool; CPython's compiler rejects such values inside ast.Constant.
    # Coerce to a real Python bool (a no-op for random.Random.choice).
    value = bool(rng.choice([True, False]))
    return ASTWithHoles(1, [], lambda: gast.Constant(value=value, kind=None))
def negate(node):
    """
    Return an expression equivalent to ``not node``.

    When given a comparison operator instead of an expression, return its
    complement (e.g. ``Lt`` -> ``GtE``); the ``Compare`` case relies on this
    for each operator of a comparison chain.

    Raises UnsupportedExpression when no negation can be built, e.g. for a
    bare name (no type information), ``is``/``is not`` operators, or any
    other unhandled node.
    """
    if isinstance(node, ast.Name):
        # Not type info, could be anything :(
        raise UnsupportedExpression()

    if isinstance(node, ast.UnaryOp):
        # !~x <> ~x == 0 <> x == ~0 <> x == -1
        if isinstance(node.op, ast.Invert):
            return ast.Compare(node.operand,
                               [ast.Eq()],
                               [ast.Constant(-1, None)])
        # !!x <> x
        if isinstance(node.op, ast.Not):
            return node.operand
        # !+x <> +x == 0 <> x == 0 <> !x
        # The negation of +x is the negation of x, not x itself.
        if isinstance(node.op, ast.UAdd):
            return negate(node.operand)
        # !-x <> -x == 0 <> x == 0 <> !x
        if isinstance(node.op, ast.USub):
            return negate(node.operand)

    if isinstance(node, ast.BoolOp):
        new_values = [ast.UnaryOp(ast.Not(), v) for v in node.values]
        # !(x or y) <> !x and !y
        if isinstance(node.op, ast.Or):
            return ast.BoolOp(ast.And(), new_values)
        # !(x and y) <> !x or !y
        if isinstance(node.op, ast.And):
            return ast.BoolOp(ast.Or(), new_values)

    if isinstance(node, ast.Compare):
        # a < b < c means (a < b) and (b < c); its negation is the Or of each
        # pairwise comparison with the complemented operator.
        cmps = [ast.Compare(x, [negate(o)], [y])
                for x, o, y
                in zip([node.left] + node.comparators[:-1], node.ops,
                       node.comparators)]
        if len(cmps) == 1:
            return cmps[0]
        return ast.BoolOp(ast.Or(), cmps)

    # Comparison operators map to their complements.
    if isinstance(node, ast.Eq):
        return ast.NotEq()
    if isinstance(node, ast.NotEq):
        return ast.Eq()
    if isinstance(node, ast.Gt):
        return ast.LtE()
    if isinstance(node, ast.GtE):
        return ast.Lt()
    if isinstance(node, ast.Lt):
        return ast.GtE()
    if isinstance(node, ast.LtE):
        return ast.Gt()
    if isinstance(node, ast.In):
        return ast.NotIn()
    if isinstance(node, ast.NotIn):
        return ast.In()

    if isinstance(node, ast.Attribute):
        if node.attr == 'False':
            return ast.Constant(True, None)
        if node.attr == 'True':
            return ast.Constant(False, None)

    raise UnsupportedExpression()
def visit_Bytes(self, node):
    """Replace a legacy ``Bytes`` node with an equivalent ``Constant``.

    The payload is taken from ``node.s`` and ``kind`` is set to None. The
    source location is copied from the original node.
    """
    replacement = gast.Constant(node.s, None)
    gast.copy_location(replacement, node)
    return replacement
def test_increment_lineno(self): tree = gast.Constant(value=1, kind=None) tree.lineno = 1 gast.increment_lineno(tree) self.assertEqual(tree.lineno, 2)
def test_buildable(self, template):
    """Test that each template can be built when given acceptable arguments."""
    hole_types = python_numbers_control_flow.ASTHoleType
    rng = np.random.RandomState(1234)

    # Construct a hole that this template can always fill.
    metadata = python_numbers_control_flow.ASTHoleMetadata(
        names_in_scope=("a",),
        inside_function=True,
        inside_loop=True,
        op_depth=0)
    hole = top_down_refinement.Hole(template.fills_type, metadata)
    self.assertTrue(template.can_fill(hole))

    # Make sure we can build this object with no errors.
    filler = template.fill(hole, rng)
    dummy_values = {
        hole_types.NUMBER: lambda: gast.Constant(value=1, kind=None),
        hole_types.BOOL: lambda: gast.Constant(value=True, kind=None),
        hole_types.STMT: gast.Pass,
        hole_types.STMTS: lambda: [],
        hole_types.STMTS_NONEMPTY: lambda: [gast.Pass()],
        hole_types.BLOCK: lambda: [gast.Pass()],
    }
    hole_values = [dummy_values[h.hole_type]() for h in filler.holes]
    value = filler.build(*hole_values)

    # Check the type of the value that was built.
    fills_type = template.fills_type
    if fills_type in (hole_types.STMTS_NONEMPTY, hole_types.BLOCK):
        self.assertTrue(value)
        for item in value:
            self.assertIsInstance(item, gast.stmt)
    elif fills_type == hole_types.STMTS:
        for item in value:
            self.assertIsInstance(item, gast.stmt)
    elif fills_type == hole_types.STMT:
        self.assertIsInstance(value, gast.stmt)
    elif fills_type in (hole_types.NUMBER, hole_types.BOOL):
        self.assertIsInstance(value, gast.expr)
    else:
        raise NotImplementedError(
            f"Unexpected fill type {template.fills_type}; "
            "please update this test.")

    # Check that cost reflects number of AST nodes.
    roots = [value] if isinstance(value, gast.AST) else value
    total_cost = sum(1 for root in roots for _ in gast.walk(root))
    self.assertEqual(template.required_cost, total_cost)

    hole_cost = sum(python_numbers_control_flow.ALL_COSTS[h.hole_type]
                    for h in filler.holes)
    self.assertEqual(filler.cost, total_cost - hole_cost)

    # Check determinism: re-seeding must rebuild an identical tree.
    built_list = isinstance(value, list)
    for _ in range(20):
        redo_value = template.fill(
            hole, np.random.RandomState(1234)).build(*hole_values)
        if built_list:
            self.assertEqual([gast.dump(v) for v in value],
                             [gast.dump(v) for v in redo_value])
        else:
            self.assertEqual(gast.dump(value), gast.dump(redo_value))
def make_fake(stmts):
    """Wrap `stmts` as the body of an ``if 0:`` statement with an empty else.

    The original statement list is reused as the body (not copied).
    """
    never_taken = ast.Constant(0, None)
    return ast.If(test=never_taken, body=stmts, orelse=[])
def visit_For(self, node):
    """Converts a `for` loop into a call to `ag__.for_stmt`.

    * The body becomes a generated function that takes the current iterate
      as one argument (`iterate_arg_name`), unpacks it into the original loop
      target, and declares the loop variables via
      `_create_nonlocal_declarations`.
    * Loop variables are exposed to the runtime through generated state
      getter/setter functions.
    * The target's source text is added to the loop options under
      'iterate_names'.
    * An `EXTRA_LOOP_TEST` annotation, if present, is wrapped in a generated
      function; otherwise the expression `None` is passed.

    Returns:
      The statements produced by `templates.replace`.
    """
    node = self.generic_visit(node)
    body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)
    iter_scope = anno.getanno(node, annos.NodeAnno.ITERATE_SCOPE)

    loop_vars, undefined, _ = self._get_block_vars(
        node, body_scope.modified | iter_scope.modified)

    undefined_assigns = self._create_undefined_assigns(undefined)

    nonlocal_declarations = self._create_nonlocal_declarations(loop_vars)

    reserved = body_scope.referenced | iter_scope.referenced
    state_getter_name = self.ctx.namer.new_symbol('get_state', reserved)
    state_setter_name = self.ctx.namer.new_symbol('set_state', reserved)
    state_functions = self._create_state_functions(
        loop_vars, nonlocal_declarations, state_getter_name, state_setter_name)

    # `opts` is extended in place through its `keys`/`values` lists
    # (presumably a gast.Dict -- confirm in _create_loop_options).
    opts = self._create_loop_options(node)
    opts.keys.append(gast.Constant('iterate_names', kind=None))
    opts.values.append(gast.Constant(
        parser.unparse(node.target, include_encoding_marker=False), kind=None))

    if anno.hasanno(node, anno.Basic.EXTRA_LOOP_TEST):
        extra_test = anno.getanno(node, anno.Basic.EXTRA_LOOP_TEST)
        extra_test_name = self.ctx.namer.new_symbol(
            'extra_test', reserved)
        # NOTE(review): `loop_vars` is passed below but does not appear in this
        # template.
        template = """
          def extra_test_name():
            nonlocal_declarations
            return extra_test_expr
        """
        extra_test_function = templates.replace(
            template,
            extra_test_expr=extra_test,
            extra_test_name=extra_test_name,
            loop_vars=loop_vars,
            nonlocal_declarations=nonlocal_declarations)
    else:
        extra_test_name = parser.parse_expression('None')
        extra_test_function = []

    # iterate_arg_name holds a single arg with the iterates, which may be a
    # tuple.
    iterate_arg_name = self.ctx.namer.new_symbol('itr', reserved)
    template = """
      iterates = iterate_arg_name
    """
    iterate_expansion = templates.replace(
        template,
        iterate_arg_name=iterate_arg_name,
        iterates=node.target)

    template = """
      state_functions
      def body_name(iterate_arg_name):
        nonlocal_declarations
        iterate_expansion
        body
      extra_test_function
      undefined_assigns
      ag__.for_stmt(
          iterated,
          extra_test_name,
          body_name,
          state_getter_name,
          state_setter_name,
          (symbol_names,),
          opts)
    """
    return templates.replace(
        template,
        body=node.body,
        body_name=self.ctx.namer.new_symbol('loop_body', reserved),
        extra_test_function=extra_test_function,
        extra_test_name=extra_test_name,
        iterate_arg_name=iterate_arg_name,
        iterate_expansion=iterate_expansion,
        iterated=node.iter,
        nonlocal_declarations=nonlocal_declarations,
        opts=opts,
        symbol_names=tuple(gast.Constant(str(s), kind=None) for s in loop_vars),
        state_functions=state_functions,
        state_getter_name=state_getter_name,
        state_setter_name=state_setter_name,
        undefined_assigns=undefined_assigns)
def test_iter_fields(self): tree = gast.Constant(value=1, kind=None) self.assertEqual({name for name, _ in gast.iter_fields(tree)}, {'value', 'kind'})