예제 #1
0
 def fill(self, hole, rng):
   """Return a filler with one nested "thing" hole, prefixing its value with "a".

   Neither `hole` nor `rng` is consulted; the result is always the same shape.
   """
   nested_hole = top_down_refinement.Hole("thing", None)

   def build(t):
     # Prepend the marker to whatever was produced for the nested hole.
     return "a" + t

   # The leading 1 is the first positional argument of ThingWithHoles
   # (presumably this filler's own cost -- confirm against its definition).
   return top_down_refinement.ThingWithHoles(1, [nested_hole], build)
예제 #2
0
  def test_buildable(self, template):
    """Check that `template` can fill a permissive hole and build a valid AST.

    The test fills a hole of the template's own type, plugs placeholder values
    into every child hole, and then verifies three things: the built value has
    the expected gast type, the declared costs equal the number of AST nodes,
    and repeating the fill with the same seed yields an identical result.

    Args:
      template: The hole-filler template under test.
    """
    hole_types = python_numbers_control_flow.ASTHoleType
    rng = np.random.RandomState(1234)

    # A hole of exactly the template's type, with metadata that the template
    # is expected to always accept.
    metadata = python_numbers_control_flow.ASTHoleMetadata(
        names_in_scope=frozenset({"a"}),
        inside_function=True,
        inside_loop=True,
        op_depth=0)
    hole = top_down_refinement.Hole(template.fills_type, metadata)
    self.assertTrue(template.can_fill(hole))

    # Filling and building must succeed without raising.
    filler = template.fill(hole, rng)

    def make_number():
      return gast.Constant(value=1, kind=None)

    def make_bool():
      return gast.Constant(value=True, kind=None)

    def make_no_statements():
      return []

    def make_single_pass():
      return [gast.Pass()]

    # Zero-argument factories producing a fresh placeholder per child hole.
    placeholder_factories = {
        hole_types.NUMBER: make_number,
        hole_types.BOOL: make_bool,
        hole_types.STMT: gast.Pass,
        hole_types.STMTS: make_no_statements,
        hole_types.STMTS_NONEMPTY: make_single_pass,
        hole_types.BLOCK: make_single_pass,
    }
    hole_values = [
        placeholder_factories[child.hole_type]() for child in filler.holes
    ]
    value = filler.build(*hole_values)

    # The built value must have the type implied by the template's fill type.
    fills_type = template.fills_type
    nonempty_list_types = (hole_types.STMTS_NONEMPTY, hole_types.BLOCK)
    if fills_type in nonempty_list_types or fills_type == hole_types.STMTS:
      if fills_type in nonempty_list_types:
        # These fill types must not produce an empty statement list.
        self.assertTrue(value)
      for statement in value:
        self.assertIsInstance(statement, gast.stmt)
    elif fills_type == hole_types.STMT:
      self.assertIsInstance(value, gast.stmt)
    elif fills_type in (hole_types.NUMBER, hole_types.BOOL):
      self.assertIsInstance(value, gast.expr)
    else:
      raise NotImplementedError(f"Unexpected fill type {template.fills_type}; "
                                "please update this test.")

    # Cost must equal the number of AST nodes; a single node is treated as a
    # one-element collection of roots.
    roots = [value] if isinstance(value, gast.AST) else value
    total_cost = sum(1 for root in roots for _ in gast.walk(root))
    self.assertEqual(template.required_cost, total_cost)

    # The filler's own cost excludes whatever its child holes account for.
    child_hole_cost = sum(
        python_numbers_control_flow.ALL_COSTS[child.hole_type]
        for child in filler.holes)
    cost_without_holes = total_cost - child_hole_cost
    self.assertEqual(filler.cost, cost_without_holes)

    # Determinism: refilling with a freshly seeded RNG must rebuild the same
    # tree. The list/non-list decision is made once, from the first value.
    value_is_list = isinstance(value, list)

    def snapshot(built):
      if value_is_list:
        return [gast.dump(node) for node in built]
      return gast.dump(built)

    for _ in range(20):
      fresh_rng = np.random.RandomState(1234)
      redo_value = template.fill(hole, fresh_rng).build(*hole_values)
      self.assertEqual(snapshot(value), snapshot(redo_value))
예제 #3
0
 def fill(self, hole, rng):
   """Return a filler with one nested "bar" hole, prefixing its value with "foo".

   Neither `hole` nor `rng` is consulted; the result is always the same shape.
   """
   nested_hole = top_down_refinement.Hole("bar", None)

   def build(bar):
     # Prepend the marker to whatever was produced for the nested hole.
     return "foo" + bar

   # The leading 1 is the first positional argument of ThingWithHoles
   # (presumably this filler's own cost -- confirm against its definition).
   return top_down_refinement.ThingWithHoles(1, [nested_hole], build)
예제 #4
0
    def test_random_sampling(self):
        """Check that sampled outcomes follow the hole and template weights.

        Runs 10000 seeded constructions of an object with one "a" hole and one
        "b" hole, tallies the resulting (a, b) pairs, and compares three of the
        tallies with their expected means. The expected probabilities are
        computed from the hole selection weights (a: 3, b: 1) and the template
        weights (A1: 1, A2: 2).
        """
        tdr = top_down_refinement

        # Two templates for "a" holes with required cost 2.
        class A1Template(tdr.HoleFillerTemplate):
            fills_type = "a"
            required_cost = 2

            def fill(self, hole, rng):
                return tdr.ThingWithHoles(2, [], lambda: "a1")

        class A2Template(tdr.HoleFillerTemplate):
            fills_type = "a"
            required_cost = 2

            def fill(self, hole, rng):
                return tdr.ThingWithHoles(2, [], lambda: "a2")

        # Template for "a" holes with required cost 1, registered with
        # precedence 0 below.
        class AFallbackTemplate(tdr.HoleFillerTemplate):
            fills_type = "a"
            required_cost = 1

            def fill(self, hole, rng):
                # NOTE(review): this passes 2 although required_cost is 1,
                # unlike BFallbackTemplate -- confirm this is intended.
                return tdr.ThingWithHoles(2, [], lambda: "af")

        # Template for "b" holes with required cost 2.
        class B1Template(tdr.HoleFillerTemplate):
            fills_type = "b"
            required_cost = 2

            def fill(self, hole, rng):
                return tdr.ThingWithHoles(2, [], lambda: "b1")

        # Template for "b" holes with required cost 1, registered with
        # precedence 0 below.
        class BFallbackTemplate(tdr.HoleFillerTemplate):
            fills_type = "b"
            required_cost = 1

            def fill(self, hole, rng):
                return tdr.ThingWithHoles(1, [], lambda: "bf")

        rng = np.random.RandomState(1234)
        trials = 10000

        def sample_once():
            # The root object and the distribution are rebuilt for every
            # trial; only the random state is shared across trials.
            root = tdr.ThingWithHoles(
                0,
                [tdr.Hole("a", None), tdr.Hole("b", None)],
                lambda a, b: (a, b))
            weighted = [
                tdr.WeightedTemplate(A1Template(), 1),
                tdr.WeightedTemplate(A2Template(), 2),
                tdr.WeightedTemplate(AFallbackTemplate(), 1, precedence=0),
                tdr.WeightedTemplate(B1Template(), 1),
                tdr.WeightedTemplate(BFallbackTemplate(), 1, precedence=0),
            ]
            distribution = tdr.RefinementDistribution(
                weighted_templates=weighted,
                hole_selection_weights={"a": 3, "b": 1})
            return tdr.top_down_construct(
                root_object=root,
                target_cost=3,
                refinement_distribution=distribution,
                rng=rng)

        counts = collections.Counter(sample_once() for _ in range(trials))

        # Each count must lie within one binomial standard deviation of its
        # mean, which is sufficient for the fixed seed above.
        expected_probabilities = {
            ("a1", "bf"): (3 / 4) * (1 / 3),
            ("a2", "bf"): (3 / 4) * (2 / 3),
            ("af", "b1"): 1 / 4,
        }
        for outcome, probability in expected_probabilities.items():
            one_std = np.sqrt(trials * probability * (1 - probability))
            np.testing.assert_allclose(counts[outcome],
                                       trials * probability,
                                       atol=one_std)