def fill(self, hole, rng):
    """Return a filler that has one nested "thing" hole.

    Neither `hole` nor `rng` is consulted; the same structure is produced on
    every call.

    Returns:
        A `ThingWithHoles` constructed with 1 as its first argument
        (presumably its cost -- confirm against `ThingWithHoles`), a single
        `Hole` of type "thing" with no metadata, and a builder that prefixes
        the nested value with "a".
    """
    nested_holes = [top_down_refinement.Hole("thing", None)]

    def build(t):
        # The value filled into the "thing" hole is appended after "a".
        return "a" + t

    return top_down_refinement.ThingWithHoles(1, nested_holes, build)
def test_buildable(self, template):
    """Test that each template can be built when given acceptable arguments.

    The test fills a permissive hole with `template`, plugs trivial AST
    values into the filler's own holes, and then checks three things: the
    built value has the AST type implied by `template.fills_type`, the
    declared costs match the number of AST nodes, and rebuilding with the
    same seed gives an identical result.

    Args:
        template: The hole-filler template under test (presumably supplied by
            a parameterized-test decorator outside this view -- confirm).
    """
    hole_types = python_numbers_control_flow.ASTHoleType
    seed_rng = np.random.RandomState(1234)

    # Construct a hole that this template can always fill.
    metadata = python_numbers_control_flow.ASTHoleMetadata(
        names_in_scope=frozenset({"a"}),
        inside_function=True,
        inside_loop=True,
        op_depth=0)
    hole = top_down_refinement.Hole(template.fills_type, metadata)
    self.assertTrue(template.can_fill(hole))

    # Make sure we can build this object with no errors. Each hole type maps
    # to a zero-argument factory producing a minimal value of that type.
    filler = template.fill(hole, seed_rng)
    dummy_values = {
        hole_types.NUMBER: (lambda: gast.Constant(value=1, kind=None)),
        hole_types.BOOL: (lambda: gast.Constant(value=True, kind=None)),
        hole_types.STMT: gast.Pass,
        hole_types.STMTS: (lambda: []),
        hole_types.STMTS_NONEMPTY: (lambda: [gast.Pass()]),
        hole_types.BLOCK: (lambda: [gast.Pass()]),
    }
    hole_values = [
        dummy_values[nested.hole_type]() for nested in filler.holes
    ]
    value = filler.build(*hole_values)

    # Check the type of the value that was built.
    must_be_nonempty = template.fills_type in (hole_types.STMTS_NONEMPTY,
                                               hole_types.BLOCK)
    if must_be_nonempty or template.fills_type == hole_types.STMTS:
        if must_be_nonempty:
            self.assertTrue(value)
        for item in value:
            self.assertIsInstance(item, gast.stmt)
    elif template.fills_type == hole_types.STMT:
        self.assertIsInstance(value, gast.stmt)
    elif template.fills_type in (hole_types.NUMBER, hole_types.BOOL):
        self.assertIsInstance(value, gast.expr)
    else:
        raise NotImplementedError(
            f"Unexpected fill type {template.fills_type}; "
            "please update this test.")

    # Check that cost reflects number of AST nodes. A single node is treated
    # as a one-element collection so both shapes are counted the same way.
    roots = [value] if isinstance(value, gast.AST) else value
    total_cost = sum(1 for root in roots for _ in gast.walk(root))
    self.assertEqual(template.required_cost, total_cost)

    hole_cost = sum(
        python_numbers_control_flow.ALL_COSTS[nested.hole_type]
        for nested in filler.holes)
    cost_without_holes = total_cost - hole_cost
    self.assertEqual(filler.cost, cost_without_holes)

    # Check determinism: a fresh RNG with the same seed must rebuild the
    # same AST every time.
    value_is_list = isinstance(value, list)
    for _ in range(20):
        fresh_rng = np.random.RandomState(1234)
        redo_value = template.fill(hole, fresh_rng).build(*hole_values)
        if value_is_list:
            self.assertEqual([gast.dump(v) for v in value],
                             [gast.dump(v) for v in redo_value])
        else:
            self.assertEqual(gast.dump(value), gast.dump(redo_value))
def fill(self, hole, rng):
    """Return a filler that has one nested "bar" hole.

    Neither `hole` nor `rng` is consulted; the same structure is produced on
    every call.

    Returns:
        A `ThingWithHoles` constructed with 1 as its first argument
        (presumably its cost -- confirm against `ThingWithHoles`), a single
        `Hole` of type "bar" with no metadata, and a builder that prefixes
        the nested value with "foo".
    """
    nested_holes = [top_down_refinement.Hole("bar", None)]

    def build(bar):
        # The value filled into the "bar" hole is appended after "foo".
        return "foo" + bar

    return top_down_refinement.ThingWithHoles(1, nested_holes, build)
def test_random_sampling(self):
    """Test that holes and templates are chosen proportional to weights.

    A root object with one "a" hole and one "b" hole is refined 10000 times
    with a fixed seed and a target cost of 3. The observed frequency of each
    (a, b) outcome is compared with the probability implied by the hole
    selection weights and the template weights.

    Every template below returns a filler whose cost equals its declared
    `required_cost` (none of them has nested holes), so each expected
    outcome adds up to exactly the target cost.
    """

    class A1Template(top_down_refinement.HoleFillerTemplate):
        fills_type = "a"
        required_cost = 2

        def fill(self, hole, rng):
            return top_down_refinement.ThingWithHoles(2, [], lambda: "a1")

    class A2Template(top_down_refinement.HoleFillerTemplate):
        fills_type = "a"
        required_cost = 2

        def fill(self, hole, rng):
            return top_down_refinement.ThingWithHoles(2, [], lambda: "a2")

    class AFallbackTemplate(top_down_refinement.HoleFillerTemplate):
        fills_type = "a"
        required_cost = 1

        def fill(self, hole, rng):
            # The filler cost must agree with `required_cost` above; a cost
            # of 2 here would push the ("af", "b1") outcome past the target.
            return top_down_refinement.ThingWithHoles(1, [], lambda: "af")

    class B1Template(top_down_refinement.HoleFillerTemplate):
        fills_type = "b"
        required_cost = 2

        def fill(self, hole, rng):
            return top_down_refinement.ThingWithHoles(2, [], lambda: "b1")

    class BFallbackTemplate(top_down_refinement.HoleFillerTemplate):
        fills_type = "b"
        required_cost = 1

        def fill(self, hole, rng):
            return top_down_refinement.ThingWithHoles(1, [], lambda: "bf")

    # NOTE(review): the fallback templates are registered with precedence=0;
    # presumably they are only used when no other template fits -- confirm
    # against RefinementDistribution.
    distribution = top_down_refinement.RefinementDistribution(
        weighted_templates=[
            top_down_refinement.WeightedTemplate(A1Template(), 1),
            top_down_refinement.WeightedTemplate(A2Template(), 2),
            top_down_refinement.WeightedTemplate(
                AFallbackTemplate(), 1, precedence=0),
            top_down_refinement.WeightedTemplate(B1Template(), 1),
            top_down_refinement.WeightedTemplate(
                BFallbackTemplate(), 1, precedence=0),
        ],
        hole_selection_weights={
            "a": 3,
            "b": 1
        })

    counts = collections.Counter()
    rng = np.random.RandomState(1234)
    trials = 10000
    for _ in range(trials):
        # A fresh root is built for every trial; the RNG is shared so the
        # whole run is reproducible from the single seed above.
        result = top_down_refinement.top_down_construct(
            root_object=top_down_refinement.ThingWithHoles(
                0, [
                    top_down_refinement.Hole("a", None),
                    top_down_refinement.Hole("b", None)
                ], lambda a, b: (a, b)),
            target_cost=3,
            refinement_distribution=distribution,
            rng=rng)
        counts[result] += 1

    # Hole "a" carries weight 3 of 4 and hole "b" weight 1 of 4. When "a" is
    # expanded first, A1 and A2 split it 1:2 and "b" ends up as "bf"; when
    # "b" is expanded first it becomes "b1" and "a" ends up as "af".
    expected_probabilities = {
        ("a1", "bf"): (3 / 4) * (1 / 3),
        ("a2", "bf"): (3 / 4) * (2 / 3),
        ("af", "b1"): 1 / 4,
    }

    # Assert that counts are within one standard deviation of the mean (which
    # is sufficient for the fixed seed above).
    for outcome, probability in expected_probabilities.items():
        np.testing.assert_allclose(
            counts[outcome],
            trials * probability,
            atol=np.sqrt(trials * probability * (1 - probability)))