def test_cost_and_precedence(self):
  """Check that precedence picks the template unless cost rules it out.

  The precedence-1 template reopens another "thing" hole each time it is used,
  so it keeps being chosen until the remaining budget only allows the
  precedence-0 leaf template. This holds even though the leaf template has a
  vastly larger weight.
  """
  tdr = top_down_refinement

  class BigTemplate(tdr.HoleFillerTemplate):
    fills_type = "thing"
    required_cost = 2

    def fill(self, hole, rng):
      reopened = [tdr.Hole("thing", None)]
      return tdr.ThingWithHoles(1, reopened, lambda t: "a" + t)

  class SmallTemplate(tdr.HoleFillerTemplate):
    fills_type = "thing"
    required_cost = 1

    def fill(self, hole, rng):
      return tdr.ThingWithHoles(1, [], lambda: "b")

  root = tdr.ThingWithHoles(0, [tdr.Hole("thing", None)], lambda t: t)
  distribution = tdr.RefinementDistribution(
      weighted_templates=[
          tdr.WeightedTemplate(BigTemplate(), 1, precedence=1),
          tdr.WeightedTemplate(SmallTemplate(), 10000, precedence=0),
      ],
      hole_selection_weights={"thing": 1})

  constructed = tdr.top_down_construct(
      root_object=root,
      target_cost=10,
      refinement_distribution=distribution)
  self.assertEqual(constructed, "aaaaaaaaab")
def test_construct_simple(self):
  """Fill a "foo" hole whose template opens a "bar" hole, then fill that one.

  Checks that holes of different types are each routed to the template that
  declares the matching `fills_type`.
  """
  tdr = top_down_refinement

  class FooTemplate(tdr.HoleFillerTemplate):
    fills_type = "foo"
    required_cost = 2

    def fill(self, hole, rng):
      bar_hole = tdr.Hole("bar", None)
      return tdr.ThingWithHoles(1, [bar_hole], lambda bar: "foo" + bar)

  class BarTemplate(tdr.HoleFillerTemplate):
    fills_type = "bar"
    required_cost = 1

    def fill(self, hole, rng):
      return tdr.ThingWithHoles(1, [], lambda: "bar")

  root = tdr.ThingWithHoles(0, [tdr.Hole("foo", None)], lambda foo: foo)
  distribution = tdr.RefinementDistribution(
      weighted_templates=[
          tdr.WeightedTemplate(template, 1)
          for template in (FooTemplate(), BarTemplate())
      ],
      hole_selection_weights={"foo": 1, "bar": 1})

  constructed = tdr.top_down_construct(
      root_object=root,
      target_cost=2,
      refinement_distribution=distribution)
  self.assertEqual(constructed, "foobar")
def test_random_sampling(self):
  """Test that holes and templates are chosen proportional to weights.

  The root has one "a" hole (selection weight 3) and one "b" hole (weight 1),
  with a total target cost of 3. The expected outcome probabilities asserted
  below model the following behavior:
    * The hole filled first can afford a cost-2 template. It uses those in
      preference to the precedence-0 fallbacks.
    * The remaining hole is then left with only its cost-1 fallback.
  Under that model exactly three outcomes are possible.
  """

  class A1Template(top_down_refinement.HoleFillerTemplate):
    fills_type = "a"
    required_cost = 2

    def fill(self, hole, rng):
      return top_down_refinement.ThingWithHoles(2, [], lambda: "a1")

  class A2Template(top_down_refinement.HoleFillerTemplate):
    fills_type = "a"
    required_cost = 2

    def fill(self, hole, rng):
      return top_down_refinement.ThingWithHoles(2, [], lambda: "a2")

  class AFallbackTemplate(top_down_refinement.HoleFillerTemplate):
    fills_type = "a"
    required_cost = 1

    def fill(self, hole, rng):
      # The produced cost matches required_cost, like every other fixture.
      # A cost of 2 here would overshoot target_cost in the ("af", "b1")
      # branch.
      return top_down_refinement.ThingWithHoles(1, [], lambda: "af")

  class B1Template(top_down_refinement.HoleFillerTemplate):
    fills_type = "b"
    required_cost = 2

    def fill(self, hole, rng):
      return top_down_refinement.ThingWithHoles(2, [], lambda: "b1")

  class BFallbackTemplate(top_down_refinement.HoleFillerTemplate):
    fills_type = "b"
    required_cost = 1

    def fill(self, hole, rng):
      return top_down_refinement.ThingWithHoles(1, [], lambda: "bf")

  counts = collections.Counter()
  rng = np.random.RandomState(1234)
  trials = 10000
  for _ in range(trials):
    result = top_down_refinement.top_down_construct(
        root_object=top_down_refinement.ThingWithHoles(
            0, [
                top_down_refinement.Hole("a", None),
                top_down_refinement.Hole("b", None)
            ], lambda a, b: (a, b)),
        target_cost=3,
        refinement_distribution=top_down_refinement.RefinementDistribution(
            weighted_templates=[
                top_down_refinement.WeightedTemplate(A1Template(), 1),
                top_down_refinement.WeightedTemplate(A2Template(), 2),
                top_down_refinement.WeightedTemplate(
                    AFallbackTemplate(), 1, precedence=0),
                top_down_refinement.WeightedTemplate(B1Template(), 1),
                top_down_refinement.WeightedTemplate(
                    BFallbackTemplate(), 1, precedence=0),
            ],
            hole_selection_weights={
                "a": 3,
                "b": 1
            }),
        rng=rng)
    counts[result] += 1

  # P(hole filled first) * P(template | that hole). The probabilities sum to 1.
  expected_probs = {
      ("a1", "bf"): (3 / 4) * (1 / 3),
      ("a2", "bf"): (3 / 4) * (2 / 3),
      ("af", "b1"): 1 / 4,
  }
  # No sample may fall outside the modeled outcomes.
  self.assertEqual(set(counts), set(expected_probs))

  # Assert that counts are within one standard deviation of the mean (which is
  # sufficient for the fixed seed above).
  for outcome, p in expected_probs.items():
    np.testing.assert_allclose(
        counts[outcome], trials * p, atol=np.sqrt(trials * p * (1 - p)))
def make_dataflow_fns_distribution(rng,
                                   weights_temperature=0,
                                   max_depth_expected=3,
                                   max_depth_maximum=3):
  """Randomly sample a refinement distribution.

  Args:
    rng: Random number generator to use, or None.
      * With a generator, the expression depth limit and (optionally) the
        per-group template weights are sampled from it. All randomness comes
        from this generator, so results are reproducible from its seed.
        The generator must provide `binomial` and `dirichlet`, for example
        `np.random.RandomState`.
      * With None, the fixed default distribution is returned.
    weights_temperature: Dirichlet temperature to use when adjusting weights.
      A value of 0 keeps the base weights unchanged. Must be 0 when `rng` is
      None.
    max_depth_expected: Expected value of maximum expression nesting depth.
      Must equal `max_depth_maximum` when `rng` is None.
    max_depth_maximum: Maximum value of maximum expression nesting depth.

  Returns:
    A refinement distribution for examples.
  """
  if rng is not None:
    # Binomial(n=max_depth_maximum, p=expected/maximum) has mean
    # max_depth_expected and never exceeds max_depth_maximum.
    max_depth = rng.binomial(max_depth_maximum,
                             max_depth_expected / max_depth_maximum)
  else:
    assert weights_temperature == 0
    assert max_depth_expected == max_depth_maximum
    max_depth = max_depth_maximum

  # Templates are grouped by the hole type they fill. Weights are normalized
  # (and optionally perturbed) within each group independently.
  groups = [
      [
          # Numbers
          top_down_refinement.WeightedTemplate(
              NameReferenceTemplate(), weight=10),
          top_down_refinement.WeightedTemplate(ConstIntTemplate(), weight=2),
          top_down_refinement.WeightedTemplate(
              BinOpTemplate(max_depth=max_depth), weight=5),
          top_down_refinement.WeightedTemplate(
              FunctionCallTemplate(
                  num_args=1, names=["foo_1", "bar_1"], max_depth=max_depth),
              weight=3),
          top_down_refinement.WeightedTemplate(
              FunctionCallTemplate(
                  num_args=2, names=["foo_2", "bar_2"], max_depth=max_depth),
              weight=2),
          top_down_refinement.WeightedTemplate(
              FunctionCallTemplate(
                  num_args=4, names=["foo_4", "bar_4"], max_depth=max_depth),
              weight=1),
      ],
      [
          # Bools
          top_down_refinement.WeightedTemplate(CompareTemplate(), weight=10),
          top_down_refinement.WeightedTemplate(
              BoolOpTemplate(max_depth=max_depth), weight=3),
          top_down_refinement.WeightedTemplate(ConstBoolTemplate(), weight=2),
      ],
      [
          # Statements
          top_down_refinement.WeightedTemplate(
              AssignExistingTemplate(), weight=20),
          top_down_refinement.WeightedTemplate(PassTemplate(), weight=1),
          top_down_refinement.WeightedTemplate(PrintNumberTemplate(), weight=5),
          top_down_refinement.WeightedTemplate(IfBlockTemplate(), weight=2),
          top_down_refinement.WeightedTemplate(IfElseBlockTemplate(), weight=2),
          top_down_refinement.WeightedTemplate(
              ForRangeBlockTemplate(), weight=2),
          top_down_refinement.WeightedTemplate(WhileBlockTemplate(), weight=2),
      ],
      [
          # Blocks
          top_down_refinement.WeightedTemplate(
              ReturnNothingTemplate(), weight=3),
          top_down_refinement.WeightedTemplate(
              ReturnNumberTemplate(), weight=3),
          top_down_refinement.WeightedTemplate(BreakTemplate(), weight=10),
          top_down_refinement.WeightedTemplate(ContinueTemplate(), weight=6),
          top_down_refinement.WeightedTemplate(
              FallthroughTemplate(), weight=40),
      ],
      [
          # Nonempty statement sequences
          top_down_refinement.WeightedTemplate(NewAssignTemplate(), weight=5),
          top_down_refinement.WeightedTemplate(
              NormalStatementTemplate(), weight=15),
      ]
  ]

  weighted_templates = [
      # Possibly empty statement sequences
      top_down_refinement.WeightedTemplate(SomeStatementsTemplate(), weight=1),
      top_down_refinement.WeightedTemplate(
          NoMoreStatementsTemplate(), weight=1, precedence=0),
  ]

  for group in groups:
    weights = np.array([template.weight for template in group])
    weights = weights / np.sum(weights)
    if rng is not None and weights_temperature > 0:
      # Draw from the caller's generator, not the global np.random state.
      # This keeps a seeded rng fully reproducible and leaves global state
      # untouched.
      weights = rng.dirichlet(weights / weights_temperature)
    weighted_templates.extend(
        dataclasses.replace(template, weight=weight)
        for template, weight in zip(group, weights))

  return top_down_refinement.RefinementDistribution(
      hole_selection_weights={
          ASTHoleType.NUMBER: 3,
          ASTHoleType.BOOL: 10,
          ASTHoleType.STMT: 100,
          ASTHoleType.BLOCK: 10,
          ASTHoleType.STMTS: 1,
          ASTHoleType.STMTS_NONEMPTY: 100,
      },
      weighted_templates=weighted_templates,
  )
# Refinement distribution with hand-picked literal weights. Every template is
# constructed with no arguments. The list is assembled group by group, in the
# order of the hole type each group fills.
CFG_DISTRIBUTION = top_down_refinement.RefinementDistribution(
    hole_selection_weights={
        ASTHoleType.NUMBER: 3,
        ASTHoleType.BOOL: 10,
        ASTHoleType.STMT: 100,
        ASTHoleType.BLOCK: 10,
        ASTHoleType.STMTS: 1,
        ASTHoleType.STMTS_NONEMPTY: 100,
    },
    weighted_templates=(
        # Numbers
        [
            top_down_refinement.WeightedTemplate(
                NameReferenceTemplate(), weight=10),
            top_down_refinement.WeightedTemplate(
                ConstIntTemplate(), weight=10),
            top_down_refinement.WeightedTemplate(BinOpTemplate(), weight=10),
        ]
        # Bools
        + [
            top_down_refinement.WeightedTemplate(CompareTemplate(), weight=10),
            top_down_refinement.WeightedTemplate(BoolOpTemplate(), weight=3),
            top_down_refinement.WeightedTemplate(
                ConstBoolTemplate(), weight=2),
        ]
        # Statements
        + [
            top_down_refinement.WeightedTemplate(
                AssignExistingTemplate(), weight=10),
            top_down_refinement.WeightedTemplate(PassTemplate(), weight=1),
            top_down_refinement.WeightedTemplate(
                PrintNumberTemplate(), weight=10),
            top_down_refinement.WeightedTemplate(IfBlockTemplate(), weight=5),
            top_down_refinement.WeightedTemplate(
                IfElseBlockTemplate(), weight=5),
            top_down_refinement.WeightedTemplate(
                ForRangeBlockTemplate(), weight=5),
            top_down_refinement.WeightedTemplate(
                WhileBlockTemplate(), weight=3),
        ]
        # Blocks
        + [
            top_down_refinement.WeightedTemplate(
                ReturnNothingTemplate(), weight=5),
            top_down_refinement.WeightedTemplate(
                ReturnNumberTemplate(), weight=5),
            top_down_refinement.WeightedTemplate(BreakTemplate(), weight=10),
            top_down_refinement.WeightedTemplate(
                ContinueTemplate(), weight=10),
            top_down_refinement.WeightedTemplate(
                FallthroughTemplate(), weight=30),
        ]
        # Nonempty statement sequences
        + [
            top_down_refinement.WeightedTemplate(
                NewAssignTemplate(), weight=5),
            top_down_refinement.WeightedTemplate(
                NormalStatementTemplate(), weight=15),
        ]
        # Possibly empty statement sequences
        + [
            top_down_refinement.WeightedTemplate(
                SomeStatementsTemplate(), weight=1),
            top_down_refinement.WeightedTemplate(
                NoMoreStatementsTemplate(), weight=1, precedence=0),
        ]
    ),
)