def make_ast(
    target_ast_node_count,
    rng=None,
    distribution=python_numbers_control_flow.DATAFLOW_FNS_DISTRIBUTION,
):
  """Generates an AST for this task.

  A root template with a single non-empty statement-list hole is refined with
  `top_down_refinement.top_down_construct`. The statements that fill the hole
  become the body of a function `random_function(a, b)` inside a module.

  Args:
    target_ast_node_count: How many nodes to put in the AST. This is passed to
      the top-down construction as its target cost.
    rng: Random state to use. When None, a fresh unseeded
      `np.random.RandomState` is created.
    distribution: Sampling distribution to use when building the AST. May also
      be a callable that produces a distribution given a random state. In that
      case it is called with the resolved `rng`.

  Returns:
    AST of a generated program, as produced by `gast.parse` on the unparsed
    source of the generated tree.
  """

  def _wrap_in_module(body):
    """Places the statements `body` inside `random_function(a, b)` in a module."""
    function_def = gast.FunctionDef(
        name="random_function",
        args=_make_arguments(
            python_numbers_control_flow.make_name("a"),
            python_numbers_control_flow.make_name("b"),
        ),
        body=body,
        decorator_list=[],
        returns=None,
        type_comment=None,
    )
    return gast.Module(body=[function_def], type_ignores=[])

  # Positional metadata fields. They are presumably names_in_scope,
  # inside_function, inside_loop and op_depth; confirm against
  # ASTHoleMetadata.
  top_level_hole = top_down_refinement.Hole(
      python_numbers_control_flow.ASTHoleType.STMTS_NONEMPTY,
      python_numbers_control_flow.ASTHoleMetadata(("a", "b"), True, False, 0),
  )
  # NOTE(review): 5 is presumably the node cost of the module/function
  # wrapper itself; confirm.
  root_template = python_numbers_control_flow.ASTWithHoles(
      cost=5,
      holes=[top_level_hole],
      build=_wrap_in_module,
  )

  rng = np.random.RandomState() if rng is None else rng
  if callable(distribution):
    distribution = distribution(rng)

  generated = top_down_refinement.top_down_construct(
      root_object=root_template,
      target_cost=target_ast_node_count,
      refinement_distribution=distribution,
      rng=rng,
  )

  # Re-parse the tree so that it is valid. This is required for program
  # graph analysis to work.
  python_ast = gast.gast_to_ast(generated)
  source_text = astunparse.unparse(python_ast)
  return gast.parse(source_text)
def test_buildable(self, template):
  """Test that each template can be built when given acceptable arguments.

  The template is filled for a hole of its own `fills_type`, with `a` in
  scope, inside a function and inside a loop. Every resulting sub-hole is then
  populated with a minimal dummy AST. The built value is checked for:

  * the node type(s) expected for `fills_type`;
  * a total node count equal to `template.required_cost`;
  * a `filler.cost` equal to that count minus the `ALL_COSTS` entries of the
    filler's holes;
  * identical hole types and identical output when refilled and rebuilt from
    the same seed.
  """
  seed = 1234
  rng = np.random.RandomState(seed)

  # Construct a hole that this template can always fill.
  hole = top_down_refinement.Hole(
      template.fills_type,
      python_numbers_control_flow.ASTHoleMetadata(
          names_in_scope=frozenset({"a"}),
          inside_function=True,
          inside_loop=True,
          op_depth=0))
  self.assertTrue(template.can_fill(hole))

  # Make sure we can build this object with no errors.
  filler = template.fill(hole, rng)
  dummy_values = {
      python_numbers_control_flow.ASTHoleType.NUMBER:
          (lambda: gast.Constant(value=1, kind=None)),
      python_numbers_control_flow.ASTHoleType.BOOL:
          (lambda: gast.Constant(value=True, kind=None)),
      python_numbers_control_flow.ASTHoleType.STMT:
          gast.Pass,
      python_numbers_control_flow.ASTHoleType.STMTS:
          (lambda: []),
      python_numbers_control_flow.ASTHoleType.STMTS_NONEMPTY:
          (lambda: [gast.Pass()]),
      python_numbers_control_flow.ASTHoleType.BLOCK:
          (lambda: [gast.Pass()]),
  }

  def make_hole_values(holes):
    # Fresh dummy nodes for every build, so that no AST node object ends up
    # shared between two built trees.
    return [dummy_values[h.hole_type]() for h in holes]

  def as_node_list(built):
    # Single nodes and statement sequences (list, tuple, ...) are normalized
    # the same way for both the cost and the determinism checks.
    return [built] if isinstance(built, gast.AST) else list(built)

  value = filler.build(*make_hole_values(filler.holes))

  # Check the type of the value that was built.
  if template.fills_type in (
      python_numbers_control_flow.ASTHoleType.STMTS_NONEMPTY,
      python_numbers_control_flow.ASTHoleType.BLOCK):
    self.assertTrue(value)
    for item in value:
      self.assertIsInstance(item, gast.stmt)
  elif template.fills_type == python_numbers_control_flow.ASTHoleType.STMTS:
    for item in value:
      self.assertIsInstance(item, gast.stmt)
  elif template.fills_type == python_numbers_control_flow.ASTHoleType.STMT:
    self.assertIsInstance(value, gast.stmt)
  elif template.fills_type in (python_numbers_control_flow.ASTHoleType.NUMBER,
                               python_numbers_control_flow.ASTHoleType.BOOL):
    self.assertIsInstance(value, gast.expr)
  else:
    raise NotImplementedError(f"Unexpected fill type {template.fills_type}; "
                              "please update this test.")

  # Check that cost reflects number of AST nodes.
  nodes = as_node_list(value)
  total_cost = sum(1 for node in nodes for _ in gast.walk(node))
  self.assertEqual(template.required_cost, total_cost)

  cost_without_holes = total_cost - sum(
      python_numbers_control_flow.ALL_COSTS[h.hole_type] for h in filler.holes)
  self.assertEqual(filler.cost, cost_without_holes)

  # Check determinism: the same seed must yield the same holes and the same
  # built AST.
  expected_hole_types = [h.hole_type for h in filler.holes]
  expected_dumps = [gast.dump(node) for node in nodes]
  for _ in range(20):
    redo_filler = template.fill(hole, np.random.RandomState(seed))
    self.assertEqual(expected_hole_types,
                     [h.hole_type for h in redo_filler.holes])
    redo_value = redo_filler.build(*make_hole_values(redo_filler.holes))
    self.assertEqual(expected_dumps,
                     [gast.dump(node) for node in as_node_list(redo_value)])