Beispiel #1
0
  def test_cost_and_precedence(self):
    """Check that precedence wins over weight until the budget forbids it.

    BigTemplate has precedence 1 and a tiny weight; SmallTemplate has
    precedence 0 and an enormous weight. With a budget of 10, BigTemplate is
    expected every time it is affordable (nine times), after which only
    SmallTemplate fits.
    """
    tdr = top_down_refinement

    class BigTemplate(tdr.HoleFillerTemplate):
      fills_type = "thing"
      required_cost = 2

      def fill(self, hole, rng):
        # Costs 1 itself and leaves another "thing" hole to be filled.
        child_holes = [tdr.Hole("thing", None)]
        return tdr.ThingWithHoles(1, child_holes, lambda t: "a" + t)

    class SmallTemplate(tdr.HoleFillerTemplate):
      fills_type = "thing"
      required_cost = 1

      def fill(self, hole, rng):
        # Leaf filler: no further holes.
        return tdr.ThingWithHoles(1, [], lambda: "b")

    root = tdr.ThingWithHoles(0, [tdr.Hole("thing", None)], lambda t: t)
    templates = [
        tdr.WeightedTemplate(BigTemplate(), 1, precedence=1),
        tdr.WeightedTemplate(SmallTemplate(), 10000, precedence=0),
    ]
    distribution = tdr.RefinementDistribution(
        weighted_templates=templates,
        hole_selection_weights={"thing": 1})

    result = tdr.top_down_construct(
        root_object=root,
        target_cost=10,
        refinement_distribution=distribution)

    self.assertEqual(result, "aaaaaaaaab")
Beispiel #2
0
  def test_construct_simple(self):
    """Check that a "foo" hole and the "bar" hole it opens both get filled.

    The root has a single "foo" hole. FooTemplate fills it and leaves a "bar"
    hole behind, which BarTemplate then fills. The two fillers cost 1 each,
    so a budget of 2 is exactly enough.
    """
    tdr = top_down_refinement

    class FooTemplate(tdr.HoleFillerTemplate):
      fills_type = "foo"
      required_cost = 2

      def fill(self, hole, rng):
        pending = [tdr.Hole("bar", None)]
        return tdr.ThingWithHoles(1, pending, lambda bar: "foo" + bar)

    class BarTemplate(tdr.HoleFillerTemplate):
      fills_type = "bar"
      required_cost = 1

      def fill(self, hole, rng):
        return tdr.ThingWithHoles(1, [], lambda: "bar")

    root = tdr.ThingWithHoles(0, [tdr.Hole("foo", None)], lambda foo: foo)
    templates = [
        tdr.WeightedTemplate(FooTemplate(), 1),
        tdr.WeightedTemplate(BarTemplate(), 1),
    ]
    distribution = tdr.RefinementDistribution(
        weighted_templates=templates,
        hole_selection_weights={"foo": 1, "bar": 1})

    result = tdr.top_down_construct(
        root_object=root,
        target_cost=2,
        refinement_distribution=distribution)

    self.assertEqual(result, "foobar")
Beispiel #3
0
  def test_random_sampling(self):
    """Test that holes and templates are chosen proportional to weights.

    The root has an "a" hole and a "b" hole and a total budget of 3. Whichever
    hole is picked first (3:1 in favour of "a") can afford a cost-2 template,
    chosen by weight among its non-fallback templates. The other hole is then
    left with a budget of 1 and must use its precedence-0 fallback.
    """

    class A1Template(top_down_refinement.HoleFillerTemplate):
      fills_type = "a"
      required_cost = 2

      def fill(self, hole, rng):
        return top_down_refinement.ThingWithHoles(2, [], lambda: "a1")

    class A2Template(top_down_refinement.HoleFillerTemplate):
      fills_type = "a"
      required_cost = 2

      def fill(self, hole, rng):
        return top_down_refinement.ThingWithHoles(2, [], lambda: "a2")

    class AFallbackTemplate(top_down_refinement.HoleFillerTemplate):
      fills_type = "a"
      required_cost = 1

      def fill(self, hole, rng):
        # The fallback is only reached with 1 unit of budget left (after B1
        # spent 2 of the 3), so its cost must not exceed its required_cost of
        # 1 (matching BFallbackTemplate); a cost of 2 would overshoot
        # target_cost.
        return top_down_refinement.ThingWithHoles(1, [], lambda: "af")

    class B1Template(top_down_refinement.HoleFillerTemplate):
      fills_type = "b"
      required_cost = 2

      def fill(self, hole, rng):
        return top_down_refinement.ThingWithHoles(2, [], lambda: "b1")

    class BFallbackTemplate(top_down_refinement.HoleFillerTemplate):
      fills_type = "b"
      required_cost = 1

      def fill(self, hole, rng):
        return top_down_refinement.ThingWithHoles(1, [], lambda: "bf")

    # The distribution does not depend on the trial, so build it once.
    distribution = top_down_refinement.RefinementDistribution(
        weighted_templates=[
            top_down_refinement.WeightedTemplate(A1Template(), 1),
            top_down_refinement.WeightedTemplate(A2Template(), 2),
            top_down_refinement.WeightedTemplate(
                AFallbackTemplate(), 1, precedence=0),
            top_down_refinement.WeightedTemplate(B1Template(), 1),
            top_down_refinement.WeightedTemplate(
                BFallbackTemplate(), 1, precedence=0),
        ],
        hole_selection_weights={
            "a": 3,
            "b": 1
        })

    counts = collections.Counter()
    rng = np.random.RandomState(1234)
    trials = 10000
    for _ in range(trials):
      # A fresh root per trial, so no trial shares a tree with another.
      root = top_down_refinement.ThingWithHoles(
          0, [
              top_down_refinement.Hole("a", None),
              top_down_refinement.Hole("b", None)
          ], lambda a, b: (a, b))
      result = top_down_refinement.top_down_construct(
          root_object=root,
          target_cost=3,
          refinement_distribution=distribution,
          rng=rng)
      counts[result] += 1

    # Assert that counts are within one standard deviation of the mean (which is
    # sufficient for the fixed seed above).
    expected_probabilities = {
        # "a" picked first (3/4), then A1 by weight (1/3); "b" gets fallback.
        ("a1", "bf"): (3 / 4) * (1 / 3),
        # "a" picked first (3/4), then A2 by weight (2/3); "b" gets fallback.
        ("a2", "bf"): (3 / 4) * (2 / 3),
        # "b" picked first (1/4), B1 is its only non-fallback; "a" falls back.
        ("af", "b1"): 1 / 4,
    }
    for outcome, p in expected_probabilities.items():
      np.testing.assert_allclose(
          counts[outcome],
          trials * p,
          atol=np.sqrt(trials * p * (1 - p)))
def make_dataflow_fns_distribution(rng,
                                   weights_temperature=0,
                                   max_depth_expected=3,
                                   max_depth_maximum=3):
    """Randomly sample a refinement distribution.

    Args:
      rng: Random number generator to use, or None for the fixed default
        distribution. It must provide ``binomial`` and ``dirichlet`` (for
        example a ``np.random.RandomState``). All randomness in this function,
        including the Dirichlet re-weighting, is drawn from it, so a seeded
        ``rng`` gives reproducible results.
      weights_temperature: Dirichlet temperature to use when adjusting weights.
        0 (the default) keeps each group's normalized weights unchanged.
      max_depth_expected: Expected value of maximum expression nesting depth.
      max_depth_maximum: Maximum value of maximum expression nesting depth.

    Returns:
      A refinement distribution for examples.

    Raises:
      AssertionError: If ``rng`` is None but ``weights_temperature`` is nonzero
        or ``max_depth_expected != max_depth_maximum`` (both require
        randomness).
    """
    if rng is not None:
        # Binomial on [0, max_depth_maximum] whose mean is max_depth_expected.
        max_depth = rng.binomial(max_depth_maximum,
                                 max_depth_expected / max_depth_maximum)
    else:
        assert weights_temperature == 0, (
            "nonzero weights_temperature requires an rng")
        assert max_depth_expected == max_depth_maximum, (
            "a random max depth requires an rng")
        max_depth = max_depth_maximum

    groups = [
        [  # Numbers
            top_down_refinement.WeightedTemplate(NameReferenceTemplate(),
                                                 weight=10),
            top_down_refinement.WeightedTemplate(ConstIntTemplate(), weight=2),
            top_down_refinement.WeightedTemplate(
                BinOpTemplate(max_depth=max_depth), weight=5),
            top_down_refinement.WeightedTemplate(FunctionCallTemplate(
                num_args=1, names=["foo_1", "bar_1"], max_depth=max_depth),
                                                 weight=3),
            top_down_refinement.WeightedTemplate(FunctionCallTemplate(
                num_args=2, names=["foo_2", "bar_2"], max_depth=max_depth),
                                                 weight=2),
            top_down_refinement.WeightedTemplate(FunctionCallTemplate(
                num_args=4, names=["foo_4", "bar_4"], max_depth=max_depth),
                                                 weight=1),
        ],
        [  # Bools
            top_down_refinement.WeightedTemplate(CompareTemplate(), weight=10),
            top_down_refinement.WeightedTemplate(
                BoolOpTemplate(max_depth=max_depth), weight=3),
            top_down_refinement.WeightedTemplate(ConstBoolTemplate(),
                                                 weight=2),
        ],
        [  # Statements
            top_down_refinement.WeightedTemplate(AssignExistingTemplate(),
                                                 weight=20),
            top_down_refinement.WeightedTemplate(PassTemplate(), weight=1),
            top_down_refinement.WeightedTemplate(PrintNumberTemplate(),
                                                 weight=5),
            top_down_refinement.WeightedTemplate(IfBlockTemplate(), weight=2),
            top_down_refinement.WeightedTemplate(IfElseBlockTemplate(),
                                                 weight=2),
            top_down_refinement.WeightedTemplate(ForRangeBlockTemplate(),
                                                 weight=2),
            top_down_refinement.WeightedTemplate(WhileBlockTemplate(),
                                                 weight=2),
        ],
        [  # Blocks
            top_down_refinement.WeightedTemplate(ReturnNothingTemplate(),
                                                 weight=3),
            top_down_refinement.WeightedTemplate(ReturnNumberTemplate(),
                                                 weight=3),
            top_down_refinement.WeightedTemplate(BreakTemplate(), weight=10),
            top_down_refinement.WeightedTemplate(ContinueTemplate(), weight=6),
            top_down_refinement.WeightedTemplate(FallthroughTemplate(),
                                                 weight=40),
        ],
        [  # Nonempty statement sequences
            top_down_refinement.WeightedTemplate(NewAssignTemplate(),
                                                 weight=5),
            top_down_refinement.WeightedTemplate(NormalStatementTemplate(),
                                                 weight=15),
        ]
    ]
    weighted_templates = [
        # Possibly empty statement sequences (not re-weighted below).
        top_down_refinement.WeightedTemplate(SomeStatementsTemplate(),
                                             weight=1),
        top_down_refinement.WeightedTemplate(NoMoreStatementsTemplate(),
                                             weight=1,
                                             precedence=0),
    ]
    for group in groups:
        weights = np.array([template.weight for template in group])
        # Normalize so each group's weights sum to 1.
        weights = weights / np.sum(weights)
        if rng is not None and weights_temperature > 0:
            # Dirichlet with concentration weights / T has mean equal to the
            # normalized weights; smaller T keeps samples closer to them.
            # Draw from the supplied rng (not the global np.random state) so
            # a seeded rng reproduces the same distribution.
            weights = rng.dirichlet(weights / weights_temperature)
        weighted_templates.extend(
            dataclasses.replace(template, weight=weight)
            for template, weight in zip(group, weights))

    return top_down_refinement.RefinementDistribution(
        hole_selection_weights={
            ASTHoleType.NUMBER: 3,
            ASTHoleType.BOOL: 10,
            ASTHoleType.STMT: 100,
            ASTHoleType.BLOCK: 10,
            ASTHoleType.STMTS: 1,
            ASTHoleType.STMTS_NONEMPTY: 100,
        },
        weighted_templates=weighted_templates,
    )
# Fixed refinement distribution (no random re-weighting). Every template is
# listed as a (class, weight) pair and wrapped with the default precedence;
# the empty-sequence terminator is appended separately because it is the only
# entry with an explicit precedence (0).
CFG_DISTRIBUTION = top_down_refinement.RefinementDistribution(
    hole_selection_weights={
        ASTHoleType.NUMBER: 3,
        ASTHoleType.BOOL: 10,
        ASTHoleType.STMT: 100,
        ASTHoleType.BLOCK: 10,
        ASTHoleType.STMTS: 1,
        ASTHoleType.STMTS_NONEMPTY: 100,
    },
    weighted_templates=[
        top_down_refinement.WeightedTemplate(template_cls(), weight=weight)
        for template_cls, weight in (
            # Numbers
            (NameReferenceTemplate, 10),
            (ConstIntTemplate, 10),
            (BinOpTemplate, 10),
            # Bools
            (CompareTemplate, 10),
            (BoolOpTemplate, 3),
            (ConstBoolTemplate, 2),
            # Statements
            (AssignExistingTemplate, 10),
            (PassTemplate, 1),
            (PrintNumberTemplate, 10),
            (IfBlockTemplate, 5),
            (IfElseBlockTemplate, 5),
            (ForRangeBlockTemplate, 5),
            (WhileBlockTemplate, 3),
            # Blocks
            (ReturnNothingTemplate, 5),
            (ReturnNumberTemplate, 5),
            (BreakTemplate, 10),
            (ContinueTemplate, 10),
            (FallthroughTemplate, 30),
            # Nonempty statement sequences
            (NewAssignTemplate, 5),
            (NormalStatementTemplate, 15),
            # Possibly empty statement sequences
            (SomeStatementsTemplate, 1),
        )
    ] + [
        top_down_refinement.WeightedTemplate(NoMoreStatementsTemplate(),
                                             weight=1,
                                             precedence=0),
    ],
)