def make_ast(
    target_ast_node_count,
    rng=None,
    distribution=python_numbers_control_flow.DATAFLOW_FNS_DISTRIBUTION,
):
  """Builds a random program AST for this task under a node budget.

  The result is a module holding a single function ``random_function(a, b)``
  whose non-empty body is produced by top-down refinement.

  Args:
    target_ast_node_count: Target size, passed to the top-down construction as
      its ``target_cost``.
    rng: Random state to draw from. If None, a new unseeded
      ``np.random.RandomState`` is created.
    distribution: Refinement distribution used while growing the AST, or a
      callable that takes the random state and returns such a distribution.

  Returns:
    A ``gast`` module, obtained by unparsing the constructed tree to source and
    parsing that source again.
  """
  # Resolve the random state first, then the (possibly lazy) distribution,
  # which may need that random state.
  if rng is None:
    rng = np.random.RandomState()
  if callable(distribution):
    distribution = distribution(rng)

  def wrap_in_function(statements):
    """Places the generated statements inside `random_function(a, b)`."""
    params = _make_arguments(
        python_numbers_control_flow.make_name("a"),
        python_numbers_control_flow.make_name("b"))
    function_def = gast.FunctionDef(
        name="random_function",
        args=params,
        body=statements,
        decorator_list=[],
        returns=None,
        type_comment=None)
    return gast.Module(body=[function_def], type_ignores=[])

  # Metadata for the function-body hole. The positional fields are presumably
  # (names_in_scope, inside_function, inside_loop, op_depth), matching the
  # keyword names used in this file's test -- confirm against ASTHoleMetadata.
  body_metadata = python_numbers_control_flow.ASTHoleMetadata(
      ("a", "b"), True, False, 0)
  body_hole = top_down_refinement.Hole(
      python_numbers_control_flow.ASTHoleType.STMTS_NONEMPTY, body_metadata)

  # NOTE(review): cost=5 presumably accounts for the wrapper nodes built by
  # `wrap_in_function`; verify against how ASTWithHoles costs are counted.
  root_template = python_numbers_control_flow.ASTWithHoles(
      cost=5, holes=[body_hole], build=wrap_in_function)

  generated = top_down_refinement.top_down_construct(
      root_object=root_template,
      target_cost=target_ast_node_count,
      refinement_distribution=distribution,
      rng=rng)

  # Round-trip through source text so the returned tree is a freshly parsed,
  # valid AST; program graph analysis requires this.
  source_text = astunparse.unparse(gast.gast_to_ast(generated))
  return gast.parse(source_text)
Example #2
0
  def test_buildable(self, template):
    """Test that each template can be built when given acceptable arguments.

    Fills a hole of `template.fills_type` whose metadata the template should
    always accept, builds the result using placeholder values for every
    remaining hole, and then checks that:

    * the built value has the node type expected for the filled hole type;
    * `template.required_cost` equals the number of AST nodes in the value;
    * `filler.cost` equals that node count minus the `ALL_COSTS` entries of
      the filler's holes;
    * re-filling from the same seed yields the same holes and an identical
      tree.

    Args:
      template: The template under test (presumably supplied by a
        parameterization decorator outside this view).
    """
    seed = 1234
    hole_types = python_numbers_control_flow.ASTHoleType

    # Construct a hole that this template can always fill.
    hole = top_down_refinement.Hole(
        template.fills_type,
        python_numbers_control_flow.ASTHoleMetadata(
            names_in_scope=frozenset({"a"}),
            inside_function=True,
            inside_loop=True,
            op_depth=0))
    self.assertTrue(template.can_fill(hole))

    # Factories producing minimal placeholder values for each hole type.
    dummy_values = {
        hole_types.NUMBER: (lambda: gast.Constant(value=1, kind=None)),
        hole_types.BOOL: (lambda: gast.Constant(value=True, kind=None)),
        hole_types.STMT: gast.Pass,
        hole_types.STMTS: (lambda: []),
        hole_types.STMTS_NONEMPTY: (lambda: [gast.Pass()]),
        hole_types.BLOCK: (lambda: [gast.Pass()]),
    }

    def make_hole_values(filler):
      # Fresh placeholder nodes on every call, so separate builds never share
      # (and can never observe mutations of) the same AST objects.
      return [dummy_values[h.hole_type]() for h in filler.holes]

    # Make sure we can build this object with no errors.
    filler = template.fill(hole, np.random.RandomState(seed))
    value = filler.build(*make_hole_values(filler))

    # Check the type of the value that was built.
    fills_type = template.fills_type
    if fills_type in (hole_types.STMTS_NONEMPTY, hole_types.BLOCK):
      self.assertTrue(value)
      for item in value:
        self.assertIsInstance(item, gast.stmt)
    elif fills_type == hole_types.STMTS:
      for item in value:
        self.assertIsInstance(item, gast.stmt)
    elif fills_type == hole_types.STMT:
      self.assertIsInstance(value, gast.stmt)
    elif fills_type in (hole_types.NUMBER, hole_types.BOOL):
      self.assertIsInstance(value, gast.expr)
    else:
      raise NotImplementedError(f"Unexpected fill type {template.fills_type}; "
                                "please update this test.")

    # Check that cost reflects number of AST nodes. A single node is walked
    # directly; a statement list has each of its statements walked.
    roots = [value] if isinstance(value, gast.AST) else value
    total_cost = sum(1 for root in roots for _ in gast.walk(root))

    self.assertEqual(template.required_cost, total_cost)

    cost_without_holes = total_cost - sum(
        python_numbers_control_flow.ALL_COSTS[h.hole_type]
        for h in filler.holes)

    self.assertEqual(filler.cost, cost_without_holes)

    # Check determinism: re-filling from the same seed must produce the same
    # holes and, given equivalent placeholders, a structurally identical tree.
    def dump(tree):
      if isinstance(tree, list):
        return [gast.dump(item) for item in tree]
      return gast.dump(tree)

    expected_hole_types = [h.hole_type for h in filler.holes]
    expected_dump = dump(value)
    for _ in range(20):
      redo_filler = template.fill(hole, np.random.RandomState(seed))
      self.assertEqual(expected_hole_types,
                       [h.hole_type for h in redo_filler.holes])
      redo_value = redo_filler.build(*make_hole_values(redo_filler))
      self.assertEqual(expected_dump, dump(redo_value))