def test_variants_from_edges(self):
    """Checks the variant one-hot weights produced by `variants_from_edges`.

    Per the expected values below, node pairs carrying an edge of type 2
    select variant 1, pairs carrying an edge of type 0 select variant 2, and
    every other pair (type-1 edges, zero-valued padding slots, no edge)
    selects variant 0.
    """
    padding = graph_bundle.PaddingConfig(
        static_max_metadata=automaton_builder.EncodedGraphMetadata(
            num_nodes=5, num_input_tagged_nodes=0),
        max_initial_transitions=0,
        max_in_tagged_transitions=0,
        max_edges=8)
    # Edge type per slot (input_indices), node pair per slot
    # (output_indices), and slot values; the last two slots hold value 0.
    edge_types = jnp.array([[0], [0], [0], [1], [1], [2], [0], [0]])
    node_pairs = jnp.array([[1, 2], [2, 3], [3, 0], [2, 0], [0, 2],
                            [0, 3], [0, 0], [0, 0]])
    slot_values = jnp.array([1, 1, 1, 1, 1, 1, 0, 0])
    example = dataclasses.replace(
        graph_bundle.zeros_like_padded_example(padding),
        graph_metadata=automaton_builder.EncodedGraphMetadata(
            num_nodes=4, num_input_tagged_nodes=0),
        edges=sparse_operator.SparseCoordOperator(
            input_indices=edge_types,
            output_indices=node_pairs,
            values=slot_values))

    static_metadata = automaton_builder.EncodedGraphMetadata(
        num_nodes=5, num_input_tagged_nodes=0)
    weights = edge_supervision_models.variants_from_edges(
        example,
        static_metadata,
        variant_edge_type_indices=[2, 0],
        num_edge_types=3)

    other = [1, 0, 0]
    from_type_2 = [0, 1, 0]
    from_type_0 = [0, 0, 1]
    expected = np.array([
        [other, other, other, from_type_2],
        [other, other, from_type_0, other],
        [other, other, other, from_type_0],
        [from_type_0, other, other, other],
    ], np.float32)
    # Rows/columns past the 4 real nodes are padding and are not checked.
    np.testing.assert_allclose(weights[:4, :4], expected)
Пример #2
0
 def test_one_node_particle_estimate_padding(self):
   """Particle estimates are nonzero on real nodes and zero on padding."""
   padded_size = 64
   schema, graph = self.build_doubly_linked_list_graph(4)
   builder = automaton_builder.AutomatonBuilder(schema)
   enc_graph, enc_meta = builder.encode_graph(graph)
   real_nodes = enc_meta.num_nodes

   # Pad every component of the encoded graph to `padded_size` entries.
   enc_graph_padded = automaton_builder.EncodedGraph(
       initial_to_in_tagged=enc_graph.initial_to_in_tagged.pad_nonzeros(
           padded_size),
       initial_to_special=jax_util.pad_to(enc_graph.initial_to_special,
                                          padded_size),
       in_tagged_to_in_tagged=enc_graph.in_tagged_to_in_tagged.pad_nonzeros(
           padded_size),
       in_tagged_to_special=jax_util.pad_to(enc_graph.in_tagged_to_special,
                                            padded_size),
       in_tagged_node_indices=jax_util.pad_to(
           enc_graph.in_tagged_node_indices, padded_size))
   enc_meta_padded = automaton_builder.EncodedGraphMetadata(
       num_nodes=padded_size, num_input_tagged_nodes=padded_size)

   variant_weights = jnp.full([padded_size, 5], 0.2)
   routing_params = automaton_builder.RoutingParams(
       move=jnp.full([5, 6, 2, 2], 0.2), special=jnp.full([5, 3, 2, 3], 0.2))
   tmat = builder.build_transition_matrix(routing_params, enc_graph_padded,
                                          enc_meta_padded)

   outs = automaton_sampling.one_node_particle_estimate(
       builder,
       tmat,
       variant_weights,
       start_machine_state=jnp.array([1., 0.]),
       node_index=0,
       steps=100,
       num_rollouts=100,
       max_possible_transitions=2,
       num_valid_nodes=real_nodes,
       rng=jax.random.PRNGKey(0))

   self.assertEqual(outs.shape, (padded_size,))
   # Only the first `real_nodes` entries may be nonzero.
   self.assertTrue(jnp.all(outs[:real_nodes] > 0))
   self.assertTrue(jnp.all(outs[real_nodes:] == 0))
    def test_calibrate_padding(self):
        """Smoke test: padding calibration completes without raising.

        This keeps `calibrate_padding` usable when run interactively.
        """

        def build_example(size):
            # A module whose body holds `size` gast.Constant nodes.
            tree = gast.Module(
                body=[gast.Constant(value=i, kind=None) for i in range(size)],
                type_ignores=[])
            py_graph, ast_to_node_id = py_ast_graphs.py_ast_to_graph(tree)

            def node_id(index):
                return ast_to_node_id[id(tree.body[index])]

            # Link every odd-indexed body element to its predecessor, using
            # edge type 1.
            edges = [(node_id(j), node_id(j - 1), 1)
                     for j in range(1, size, 2)]
            return graph_bundle.convert_graph_with_edges(
                py_graph, edges, py_ast_graphs.BUILDER)

        desired = graph_bundle.PaddingConfig(
            static_max_metadata=automaton_builder.EncodedGraphMetadata(
                num_nodes=64, num_input_tagged_nodes=64),
            max_initial_transitions=128,
            max_in_tagged_transitions=256,
            max_edges=64,
        )
        padding_calibration.calibrate_padding(
            example_builder=build_example,
            desired_sizes=desired,
            samples=50,
            optimization_max_steps=500,
            round_to_powers_of_two=True)
 def _make_example(self):
   """Returns a 4-node example padded to 5 nodes and 8 edge slots.

   The final two edge slots carry value 0.
   """
   padding = graph_bundle.PaddingConfig(
       static_max_metadata=automaton_builder.EncodedGraphMetadata(
           num_nodes=5, num_input_tagged_nodes=0),
       max_initial_transitions=0,
       max_in_tagged_transitions=0,
       max_edges=8)
   edge_operator = sparse_operator.SparseCoordOperator(
       input_indices=jnp.array([[0], [0], [0], [0], [1], [2], [2], [0]]),
       output_indices=jnp.array([[1, 2], [2, 3], [2, 2], [3, 0], [0, 2],
                                 [0, 3], [0, 0], [0, 0]]),
       values=jnp.array([1, 1, 1, 1, 1, 1, 0, 0]))
   return dataclasses.replace(
       graph_bundle.zeros_like_padded_example(padding),
       graph_metadata=automaton_builder.EncodedGraphMetadata(
           num_nodes=4, num_input_tagged_nodes=0),
       edges=edge_operator)
Пример #5
0
def zeros_like_padded_example(config):
    """Return a GraphBundle whose arrays are all zeros, sized by `config`.

    Useful for initializing model parameters or for tests, where only the
    shapes and dtypes of an example matter.

    Args:
      config: Padding configuration that fixes the size of every array.

    Returns:
      A zero-filled GraphBundle whose `graph_metadata` reports 0 nodes and
      0 input-tagged nodes.
    """
    num_nodes = config.static_max_metadata.num_nodes
    num_in_tagged = config.static_max_metadata.num_input_tagged_nodes

    def zero_sparse(num_entries, values_dtype):
        # Sparse operator with `num_entries` all-zero index/value slots.
        return sparse_operator.SparseCoordOperator(
            input_indices=np.zeros((num_entries, 1), dtype=np.int32),
            output_indices=np.zeros((num_entries, 2), dtype=np.int32),
            values=np.zeros((num_entries,), dtype=values_dtype),
        )

    def zero_int_vector(length):
        return np.zeros((length,), dtype=np.int32)

    automaton_graph = automaton_builder.EncodedGraph(
        initial_to_in_tagged=zero_sparse(config.max_initial_transitions,
                                         np.float32),
        initial_to_special=zero_int_vector(num_nodes),
        in_tagged_to_in_tagged=zero_sparse(config.max_in_tagged_transitions,
                                           np.float32),
        in_tagged_to_special=zero_int_vector(num_in_tagged),
        in_tagged_node_indices=zero_int_vector(num_in_tagged),
    )
    return GraphBundle(
        automaton_graph=automaton_graph,
        graph_metadata=automaton_builder.EncodedGraphMetadata(
            num_nodes=0, num_input_tagged_nodes=0),
        node_types=zero_int_vector(num_nodes),
        # Edge values are integers, unlike the float32 transition values.
        edges=zero_sparse(config.max_edges, np.int32),
    )
 def apply(self, dummy_ignored):
   """Builds a FiniteStateGraphAutomaton from tied copies of its inputs.

   NOTE(review): `encoded_graph`, `variant_weights`, `builder`, `shared`,
   `use_gate`, `estimator_type` and `kwargs` are not defined here; presumably
   they are captured from an enclosing scope — confirm.

   Args:
     dummy_ignored: Value each input leaf is tied to via `jax.lax.tie_in`;
       not otherwise used.

   Returns:
     The result of `automaton_layer.FiniteStateGraphAutomaton(...)`.
   """
   def tie(leaf):
     return jax.lax.tie_in(dummy_ignored, leaf)

   tied_graph = jax.tree_map(tie, encoded_graph)
   tied_weights = jax.tree_map(tie, variant_weights())
   return automaton_layer.FiniteStateGraphAutomaton(
       encoded_graph=tied_graph,
       variant_weights=tied_weights,
       dynamic_metadata=automaton_builder.EncodedGraphMetadata(
           num_nodes=32, num_input_tagged_nodes=64),
       static_metadata=automaton_builder.EncodedGraphMetadata(
           num_nodes=32, num_input_tagged_nodes=64),
       builder=builder,
       num_out_edges=3,
       num_intermediate_states=4,
       share_states_across_edges=shared,
       use_gate_parameterization=use_gate,
       estimator_type=estimator_type,
       name="the_layer",
       **kwargs)
Пример #7
0
def build_padding_config(log2_num_nodes, log2_num_input_tagged_nodes,
                         log2_max_initial_transitions,
                         log2_max_in_tagged_transitions, log2_max_edges,
                         log2_max_tokens):
    """Make a padding config whose sizes are all powers of two.

    Each argument is a base-2 exponent; the matching size in the returned
    config is `2**argument`.
    """
    static_metadata = automaton_builder.EncodedGraphMetadata(
        num_nodes=2**log2_num_nodes,
        num_input_tagged_nodes=2**log2_num_input_tagged_nodes)
    bundle_padding = graph_bundle.PaddingConfig(
        static_max_metadata=static_metadata,
        max_initial_transitions=2**log2_max_initial_transitions,
        max_in_tagged_transitions=2**log2_max_in_tagged_transitions,
        max_edges=2**log2_max_edges)
    return example_definition.GraphBundleWithTokensPaddingConfig(
        bundle_padding=bundle_padding,
        max_tokens=2**log2_max_tokens)
Пример #8
0
  def test_component_shapes(self,
                            component,
                            embed_edges,
                            expected_dims,
                            extra_config=None):
    """Checks the node/edge output shapes of an end-to-end stack component.

    Args:
      component: Key into `end_to_end_stack.ALL_COMPONENTS`.
      embed_edges: Forwarded as `edges_are_embedded` in the graph context.
      expected_dims: Mapping with the expected "node" and "edge" feature
        sizes.
      extra_config: Optional gin config string parsed after `CONFIG`.
    """
    gin.clear_config()
    gin.parse_config(CONFIG)
    if extra_config:
      gin.parse_config(extra_config)

    # Keep the padded node count and the placeholder embedding shapes in sync.
    num_nodes = 16

    # Run the computation with placeholder inputs.
    (node_out, edge_out), _ = end_to_end_stack.ALL_COMPONENTS[component].init(
        jax.random.PRNGKey(0),
        graph_context=end_to_end_stack.SharedGraphContext(
            bundle=graph_bundle.zeros_like_padded_example(
                graph_bundle.PaddingConfig(
                    static_max_metadata=automaton_builder.EncodedGraphMetadata(
                        num_nodes=num_nodes, num_input_tagged_nodes=32),
                    max_initial_transitions=11,
                    max_in_tagged_transitions=12,
                    max_edges=13)),
            static_metadata=automaton_builder.EncodedGraphMetadata(
                num_nodes=num_nodes, num_input_tagged_nodes=32),
            edge_types_to_indices={"foo": 0},
            builder=automaton_builder.AutomatonBuilder({
                graph_types.NodeType("node"):
                    graph_types.NodeSchema(
                        in_edges=[graph_types.InEdgeType("in")],
                        # Outgoing edges take the OutEdgeType wrapper, not
                        # InEdgeType.
                        out_edges=[graph_types.OutEdgeType("out")])
            }),
            edges_are_embedded=embed_edges),
        node_embeddings=jnp.zeros((num_nodes, NODE_DIM)),
        edge_embeddings=jnp.zeros((num_nodes, num_nodes, EDGE_DIM)))

    self.assertEqual(node_out.shape, (num_nodes, expected_dims["node"]))
    self.assertEqual(edge_out.shape,
                     (num_nodes, num_nodes, expected_dims["edge"]))
Пример #9
0
    def test_zeros_like_padded_example(self):
        """zeros_like_padded_example matches pad_example's shapes and dtypes."""
        tree = gast.parse("pass")
        py_graph, _ = py_ast_graphs.py_ast_to_graph(tree)
        example = graph_bundle.convert_graph_with_edges(
            py_graph, [], builder=py_ast_graphs.BUILDER)

        padding_config = graph_bundle.PaddingConfig(
            static_max_metadata=automaton_builder.EncodedGraphMetadata(
                num_nodes=16, num_input_tagged_nodes=34),
            max_initial_transitions=64,
            max_in_tagged_transitions=128,
            max_edges=4)

        padded_example = graph_bundle.pad_example(example, padding_config)
        generated = graph_bundle.zeros_like_padded_example(padding_config)

        def _check(x, y):
            x = np.asarray(x)
            y = np.asarray(y)
            self.assertEqual(x.shape, y.shape)
            self.assertEqual(x.dtype, y.dtype)

        # `jax.tree_multimap` has been removed from JAX. Flatten `generated`,
        # then flatten `padded_example` up to the same tree structure (this
        # raises if the structures disagree) and compare the leaf pairs.
        # Both calls are available in old and new JAX releases.
        generated_leaves, treedef = jax.tree_util.tree_flatten(generated)
        padded_leaves = treedef.flatten_up_to(padded_example)
        for x, y in zip(generated_leaves, padded_leaves):
            _check(x, y)
      bundle_padding=graph_bundle.PaddingConfig(
          static_max_metadata=automaton_builder.EncodedGraphMetadata(
              num_nodes=2**log2_num_nodes,
              num_input_tagged_nodes=2**log2_num_input_tagged_nodes),
          max_initial_transitions=2**log2_max_initial_transitions,
          max_in_tagged_transitions=2**log2_max_in_tagged_transitions,
          max_edges=2**log2_max_edges),
      max_tokens=2**log2_max_tokens)


# List of (padding config, batch size) pairs.
PaddingAndBatchSizes = List[
    Tuple[example_definition.GraphBundleWithTokensPaddingConfig, int]]

# Padding config in which every size is set to 1.
TINY_PADDING_CONFIG = example_definition.GraphBundleWithTokensPaddingConfig(
    bundle_padding=graph_bundle.PaddingConfig(
        static_max_metadata=automaton_builder.EncodedGraphMetadata(
            num_nodes=1,
            num_input_tagged_nodes=1,
        ),
        max_initial_transitions=1,
        max_in_tagged_transitions=1,
        max_edges=1,
    ),
    max_tokens=1,
)


def pad_and_batch_with_rng(
    it, num_devices,
    padding_and_batch_sizes,
    base_rng):
  """Pad and batch according to a collection of sizes.

  Args:
    it: Iterable over individual examples.
    num_devices: Number of devices; determines constant leading batch dimension.
Пример #11
0
    def test_convert_example(self):
        """convert_graph_with_edges encodes nodes, edges and automaton graph."""
        tree = gast.parse(
            textwrap.dedent("""\
          def foo():
            x = 5
            return x
          """))

        py_graph, ast_to_node_id = py_ast_graphs.py_ast_to_graph(tree)
        func_def = tree.body[0]
        assign_stmt = func_def.body[0]
        return_stmt = func_def.body[1]
        # (source AST node, dest AST node, edge type) triples.
        ast_edges = [
            (return_stmt, func_def, 1),
            (return_stmt.value, assign_stmt.targets[0], 2),
        ]
        converted_edges = [
            (ast_to_node_id[id(src)], ast_to_node_id[id(dst)], edge_type)
            for src, dst, edge_type in ast_edges
        ]
        example = graph_bundle.convert_graph_with_edges(
            py_graph, converted_edges, builder=py_ast_graphs.BUILDER)

        expected_node_ids = [
            "root__Module",
            "root_body_0__Module_body-seq-helper",
            "root_body_0_item__FunctionDef",
            "root_body_0_item_args__arguments",
            "root_body_0_item_body_0__FunctionDef_body-seq-helper",
            "root_body_0_item_body_0_item__Assign",
            "root_body_0_item_body_0_item_targets__Name",
            "root_body_0_item_body_0_item_value__Constant",
            "root_body_0_item_body_1__FunctionDef_body-seq-helper",
            "root_body_0_item_body_1_item__Return",
            "root_body_0_item_body_1_item_value__Name",
        ]
        self.assertEqual(list(py_graph), expected_node_ids)

        self.assertEqual(
            example.graph_metadata,
            automaton_builder.EncodedGraphMetadata(num_nodes=11,
                                                   num_input_tagged_nodes=27))
        self.assertEqual(example.node_types.shape, (11,))

        edges = example.edges
        np.testing.assert_array_equal(edges.input_indices, [[1], [2]])
        np.testing.assert_array_equal(edges.output_indices,
                                      [[9, 2], [10, 6]])
        np.testing.assert_array_equal(edges.values, [1, 1])

        graph = example.automaton_graph
        self.assertEqual(graph.initial_to_in_tagged.values.shape, (34,))
        self.assertEqual(graph.initial_to_special.shape, (11,))
        self.assertEqual(graph.in_tagged_to_in_tagged.values.shape, (103,))
        self.assertEqual(graph.in_tagged_to_special.shape, (27,))

        # The transition matrix can be built, with sizes taken from the
        # (unpadded) graph metadata.
        routing_params = py_ast_graphs.BUILDER.initialize_routing_params(
            None, 1, 1, noise_factor=0)
        tmat = py_ast_graphs.BUILDER.build_transition_matrix(
            routing_params, graph, example.graph_metadata)

        self.assertEqual(tmat.initial_to_in_tagged.shape, (1, 11, 1, 27, 1))
        self.assertEqual(tmat.initial_to_special.shape, (1, 11, 1, 3))
        self.assertEqual(tmat.in_tagged_to_in_tagged.shape, (1, 27, 1, 27, 1))
        self.assertEqual(tmat.in_tagged_to_special.shape, (1, 27, 1, 3))
        self.assertEqual(tmat.in_tagged_node_indices.shape, (27,))
Пример #12
0
    def test_pad_example(self):
        """pad_example pads every array but leaves graph metadata unchanged."""
        tree = gast.parse(
            textwrap.dedent("""\
          def foo():
            x = 5
            return x
          """))

        py_graph, ast_to_node_id = py_ast_graphs.py_ast_to_graph(tree)
        func_def = tree.body[0]
        assign_stmt = func_def.body[0]
        return_stmt = func_def.body[1]
        # (source AST node, dest AST node, edge type) triples.
        ast_edges = [
            (return_stmt, func_def, 1),
            (return_stmt.value, assign_stmt.targets[0], 2),
        ]
        converted_edges = [
            (ast_to_node_id[id(src)], ast_to_node_id[id(dst)], edge_type)
            for src, dst, edge_type in ast_edges
        ]
        example = graph_bundle.convert_graph_with_edges(
            py_graph, converted_edges, builder=py_ast_graphs.BUILDER)

        padding_config = graph_bundle.PaddingConfig(
            static_max_metadata=automaton_builder.EncodedGraphMetadata(
                num_nodes=16, num_input_tagged_nodes=34),
            max_initial_transitions=64,
            max_in_tagged_transitions=128,
            max_edges=4)

        padded = graph_bundle.pad_example(example, padding_config)

        # Metadata is not affected by padding.
        self.assertEqual(
            padded.graph_metadata,
            automaton_builder.EncodedGraphMetadata(num_nodes=11,
                                                   num_input_tagged_nodes=27))

        # Everything else is padded; edge padding slots are zero-filled.
        self.assertEqual(padded.node_types.shape, (16,))

        edges = padded.edges
        np.testing.assert_array_equal(edges.input_indices,
                                      [[1], [2], [0], [0]])
        np.testing.assert_array_equal(edges.output_indices,
                                      [[9, 2], [10, 6], [0, 0], [0, 0]])
        np.testing.assert_array_equal(edges.values, [1, 1, 0, 0])

        graph = padded.automaton_graph
        self.assertEqual(graph.initial_to_in_tagged.values.shape, (64,))
        self.assertEqual(graph.initial_to_special.shape, (16,))
        self.assertEqual(graph.in_tagged_to_in_tagged.values.shape, (128,))
        self.assertEqual(graph.in_tagged_to_special.shape, (34,))

        # The transition matrix is padded too once built. The padded static
        # metadata is passed because the encoded graph has been padded.
        routing_params = py_ast_graphs.BUILDER.initialize_routing_params(
            None, 1, 1, noise_factor=0)
        tmat = py_ast_graphs.BUILDER.build_transition_matrix(
            routing_params, graph, padding_config.static_max_metadata)

        self.assertEqual(tmat.initial_to_in_tagged.shape, (1, 16, 1, 34, 1))
        self.assertEqual(tmat.initial_to_special.shape, (1, 16, 1, 3))
        self.assertEqual(tmat.in_tagged_to_in_tagged.shape, (1, 34, 1, 34, 1))
        self.assertEqual(tmat.in_tagged_to_special.shape, (1, 34, 1, 3))
        self.assertEqual(tmat.in_tagged_node_indices.shape, (34,))
Пример #13
0
    refinement_distribution: Distribution to sample from.
    padding_config: How to pad the generated examples.
  """
    target_ast_size: int
    refinement_distribution: RefinementDistnOrMetaDistn
    padding_config: graph_bundle.PaddingConfig


# Constants calibrated using `padding_calibration.calibrate_padding`.
DISTRIBUTIONS = {
    "control_flow":
    TaskExampleDistribution(
        target_ast_size=159,
        refinement_distribution=(python_numbers_control_flow.CFG_DISTRIBUTION),
        padding_config=graph_bundle.PaddingConfig(
            static_max_metadata=automaton_builder.EncodedGraphMetadata(
                num_nodes=256, num_input_tagged_nodes=512),
            max_initial_transitions=1024,
            max_in_tagged_transitions=2048,
            max_edges=2048)),
    "data_flow":
    TaskExampleDistribution(
        target_ast_size=172,
        refinement_distribution=(
            python_numbers_control_flow.DATAFLOW_DISTRIBUTION),
        padding_config=graph_bundle.PaddingConfig(
            static_max_metadata=automaton_builder.EncodedGraphMetadata(
                num_nodes=256, num_input_tagged_nodes=512),
            max_initial_transitions=1024,
            max_in_tagged_transitions=2048,
            max_edges=4096)),
    "data_flow_fns":
Пример #14
0
def calibrate_padding(
    example_builder,
    desired_sizes,
    success_probability=0.9,
    samples=1000,
    min_object_size=16,
    max_object_size=512,
    optimization_max_steps=10000,
    log_likelihood_epsilon=1e-2,
    round_to_powers_of_two=False,
    progress=None,
):
    """Determine an example size that fits in a given padding config.

    (Note: This function is designed to be used interactively as part of
    setting up a new dataset.)

    The sizes of examples as generated by example generators don't always
    match the actual size of the example once it has been encoded. To enable
    running with static batch sizes, we want to determine the sizes of the
    actual inputs to the model. The inputs, however, depend on the schema
    encoding of the specific nodes that are chosen, and (for instance) not all
    AST nodes are the same size in the graph representation.

    This method takes a desired padding configuration, which specifies a bound
    on the sizes of the inputs that we want to get close to. It then generates
    a bunch of data, fits a simple model to the sizes of the generated data,
    and then uses that model to determine what initial object size we should
    generate to make sure that we fit within the desired padding with
    probability at least `success_probability`.

    Since some parts of the input are likely smaller than others, we also
    return a smaller padding config that still suffices to hold the generated
    examples.

    (More specifically, we model the size of each example as a draw from a
    normal distribution, where the mean and variance have a constant term and
    a term proportional to the initial object size. This is inspired by the
    central limit theorem, which states that the sum of IID random variables
    approaches a normal distribution with mean and variance proportional to
    the number of variables in the sum. The theorem doesn't perfectly hold in
    this case, since there are some dependencies between AST nodes, but it
    does seem to be a good approximation.)

    Args:
      example_builder: Function that, when called with a target size, returns
        a random example whose size is (roughly) proportional to that size.
      desired_sizes: Padding that specifies the max sizes for each dimension.
        We will attempt to get close to this without exceeding it.
      success_probability: Proportion of generated examples we want to be able
        to keep, i.e. the proportion that should be smaller than the padding
        size. The root selection below assumes this is greater than 0.5.
      samples: How many random examples to generate.
      min_object_size: Minimum (inclusive) size of AST to generate while
        fitting.
      max_object_size: Upper bound (exclusive, as with `np.random.randint`) on
        the size of AST to generate while fitting.
      optimization_max_steps: Max iterations to fit each size model.
      log_likelihood_epsilon: Stop optimizing when the loss changes by less
        than this amount in an iteration.
      round_to_powers_of_two: Whether to automatically round up the returned
        padding sizes to powers of two.
      progress: Wrapper around an iterable to use as a progress bar, to show
        progress during training (such as tqdm).

    Returns:
      A tuple `(ast_target_count, padding_config)`:
      - Calibrated target number of nodes to use as a generation target.
      - Padding configuration to use.
    """
    if progress is None:
        progress = lambda x: x

    # Added to every modeled variance so it stays strictly positive. The same
    # value must be used when fitting, solving and computing quantiles.
    var_eps = 1e-3

    # Collect samples.
    print("Generating data...")
    object_sizes = np.empty([samples], dtype="int")
    data = {
        "graph_nodes": np.empty([samples], dtype="int"),
        "graph_in_tagged_nodes": np.empty([samples], dtype="int"),
        "initial_transitions": np.empty([samples], dtype="int"),
        "in_tagged_transitions": np.empty([samples], dtype="int"),
        "edges": np.empty([samples], dtype="int"),
    }
    for i in progress(range(samples)):
        target_size = np.random.randint(min_object_size, max_object_size)
        example = example_builder(target_size)

        object_sizes[i] = target_size
        data["graph_nodes"][i] = example.graph_metadata.num_nodes
        data["graph_in_tagged_nodes"][i] = (
            example.graph_metadata.num_input_tagged_nodes)
        data["initial_transitions"][i] = (
            example.automaton_graph.initial_to_in_tagged.values.shape[0])
        data["in_tagged_transitions"][i] = (
            example.automaton_graph.in_tagged_to_in_tagged.values.shape[0])
        data["edges"][i] = example.edges.values.shape[0]

    # Fit models.
    print("Fitting a size model...")

    def single_log_likelihood(params, n, x):
        """Log likelihood of a single point under the size model.

        The size `x` of an example built with target size `n` is modeled as
        Normal(base_mu + n * prop_mu, base_std**2 + n * prop_std**2 + var_eps).
        """
        base_mu, base_std, prop_mu, prop_std = params
        mu = base_mu + n * prop_mu
        var = base_std**2 + n * prop_std**2 + var_eps
        # Gaussian log density: -0.5 * (log(2 * pi * var) + (x - mu)^2 / var).
        return -0.5 * (jnp.log(2 * jnp.pi * var) + (x - mu)**2 / var)

    def compute_loss(params, ns, xs):
        """Negative log likelihood of all (size, value) samples."""
        return -jnp.sum(
            jax.vmap(single_log_likelihood, in_axes=(None, 0, 0))(params, ns,
                                                                  xs))

    compute_loss_and_grads = jax.jit(jax.value_and_grad(compute_loss))
    opt_init, opt_update = optax.adam(0.1)

    model_params = {}
    for size_key, values in progress(data.items()):
        params = jnp.array([0., 1., 0., 1.])
        opt_state = opt_init(params)
        last_loss = None
        # Initialized so the summary print works even with zero steps.
        loss = None
        steps_run = 0
        for step in progress(range(optimization_max_steps)):
            steps_run = step + 1
            loss, grads = compute_loss_and_grads(params, object_sizes, values)
            if (last_loss is not None
                    and np.abs(last_loss - loss) < log_likelihood_epsilon):
                break
            last_loss = loss
            updates, opt_state = opt_update(grads, opt_state)
            params = optax.apply_updates(params, updates)
        print(f"Fit model for {size_key} after {steps_run} iterations, "
              f"loss was {loss}")
        model_params[size_key] = params

    # Figure out which of the desired sizes is the most constraining.
    print("Solving for padding sizes...")
    targets = {
        "graph_nodes": desired_sizes.static_max_metadata.num_nodes,
        "graph_in_tagged_nodes":
        desired_sizes.static_max_metadata.num_input_tagged_nodes,
        "initial_transitions": desired_sizes.max_initial_transitions,
        "in_tagged_transitions": desired_sizes.max_in_tagged_transitions,
        "edges": desired_sizes.max_edges,
    }
    p = success_probability
    # The p-th quantile of Normal(mu, var) is mu + sqrt(2 * var) * z.
    z = scipy.special.erfinv(2 * p - 1)
    erfval = z**2
    ast_constraints = []
    for size_key, target in targets.items():
        # Solve for the n such that the `p`th quantile of the distribution is
        # the target value. Squaring (target - mu(n))^2 = 2 * var(n) * z^2
        # gives a quadratic a*n^2 + b*n + c = 0. The smaller root has mu(n)
        # below the target, which is the relevant one when p > 0.5.
        base_mu, base_std, prop_mu, prop_std = model_params[size_key]
        a = prop_mu**2
        b = 2 * prop_mu * (base_mu - target) - 2 * prop_std**2 * erfval
        c = (target - base_mu)**2 - 2 * (base_std**2 + var_eps) * erfval
        constraint = (-b - np.sqrt(b**2 - 4 * a * c)) / (2 * a)
        ast_constraints.append(constraint)
        print(size_key, "with slope", prop_mu, "constrains to", constraint)

    ast_target_count = int(min(ast_constraints))
    print(f"Target size {ast_target_count} satisfies all constraints with "
          "high probability")

    # Compute the `p`th quantile of each size, maybe rounding up.
    quantiles = {}
    for size_key, params in model_params.items():
        base_mu, base_std, prop_mu, prop_std = params
        mu = base_mu + ast_target_count * prop_mu
        var = base_std**2 + ast_target_count * prop_std**2 + var_eps
        quantile = mu + np.sqrt(2 * var) * z
        if round_to_powers_of_two:
            # jump down by 0.5 to avoid off-by-one errors with the maximum value
            rounded_quantile = int(np.exp2(np.ceil(np.log2(quantile - 0.5))))
        else:
            rounded_quantile = int(np.ceil(quantile))

        print(f"{size_key}: Rounded {quantile} to {rounded_quantile}")
        quantiles[size_key] = rounded_quantile

    return ast_target_count, graph_bundle.PaddingConfig(
        static_max_metadata=automaton_builder.EncodedGraphMetadata(
            num_nodes=quantiles["graph_nodes"],
            num_input_tagged_nodes=quantiles["graph_in_tagged_nodes"]),
        max_initial_transitions=quantiles["initial_transitions"],
        max_in_tagged_transitions=quantiles["in_tagged_transitions"],
        max_edges=quantiles["edges"],
    )