Exemple #1
0
    def pad_nonzeros(self, nonzeros_axis_size):
        """Return an equivalent operator whose `nonzero` axis has been padded.

        Each of `input_indices`, `output_indices` and `values` is passed
        through `jax_util.pad_to` with the same target size. Appending entries
        whose value is 0 leaves the operator's meaning unchanged, which makes
        it possible to give differently-sized examples a common shape (e.g.
        for batching).

        Args:
          nonzeros_axis_size: Size of the nonzero axis after padding.

        Returns:
          A new `SparseCoordOperator` that is equivalent to `self` (in the
          sense that `apply_add` behaves identically) but whose first axis has
          size `nonzeros_axis_size`.

        Raises:
          ValueError: If this operator has too many nonzero entries to fit in
            the requested size (raised by the padding helper, not by this
            method directly).
        """
        # All three arrays share the nonzero axis, so they are padded in
        # lockstep to the same size.
        padded_fields = {
            field_name: jax_util.pad_to(getattr(self, field_name),
                                        nonzeros_axis_size)
            for field_name in ("input_indices", "output_indices", "values")
        }
        return SparseCoordOperator(**padded_fields)
Exemple #2
0
 def test_one_node_particle_estimate_padding(self):
   """Checks the particle estimate when the encoded graph has been padded.

   A 4-element doubly linked list graph is encoded, every component of the
   encoding is padded to size 64, and the estimate is computed with
   `num_valid_nodes` set to the unpadded node count. Entries for real nodes
   must be strictly positive and entries for padding nodes must be zero.
   """
   padded_size = 64
   schema, graph = self.build_doubly_linked_list_graph(4)
   builder = automaton_builder.AutomatonBuilder(schema)
   encoded, metadata = builder.encode_graph(graph)
   num_real_nodes = metadata.num_nodes

   # Pad every piece of the encoded graph out to the same static size.
   padded_encoded = automaton_builder.EncodedGraph(
       initial_to_in_tagged=encoded.initial_to_in_tagged.pad_nonzeros(
           padded_size),
       initial_to_special=jax_util.pad_to(encoded.initial_to_special,
                                          padded_size),
       in_tagged_to_in_tagged=encoded.in_tagged_to_in_tagged.pad_nonzeros(
           padded_size),
       in_tagged_to_special=jax_util.pad_to(encoded.in_tagged_to_special,
                                            padded_size),
       in_tagged_node_indices=jax_util.pad_to(encoded.in_tagged_node_indices,
                                              padded_size))
   padded_metadata = automaton_builder.EncodedGraphMetadata(
       num_nodes=padded_size, num_input_tagged_nodes=padded_size)

   variant_weights = jnp.full([padded_size, 5], 0.2)
   move_params = jnp.full([5, 6, 2, 2], 0.2)
   special_params = jnp.full([5, 3, 2, 3], 0.2)
   routing_params = automaton_builder.RoutingParams(
       move=move_params, special=special_params)
   transition_matrix = builder.build_transition_matrix(
       routing_params, padded_encoded, padded_metadata)

   estimate = automaton_sampling.one_node_particle_estimate(
       builder,
       transition_matrix,
       variant_weights,
       start_machine_state=jnp.array([1., 0.]),
       node_index=0,
       steps=100,
       num_rollouts=100,
       max_possible_transitions=2,
       num_valid_nodes=num_real_nodes,
       rng=jax.random.PRNGKey(0))

   self.assertEqual(estimate.shape, (padded_size,))
   # Real nodes receive mass; padding nodes receive none.
   self.assertTrue(jnp.all(estimate[:num_real_nodes] > 0))
   self.assertTrue(jnp.all(estimate[num_real_nodes:] == 0))
 def test_pad_to(self):
     arr = np.arange(15).reshape((3, 5))
     padded = jax_util.pad_to(arr, 7, 1)
     expected = np.array([
         [0, 1, 2, 3, 4, 0, 0],
         [5, 6, 7, 8, 9, 0, 0],
         [10, 11, 12, 13, 14, 0, 0],
     ])
     np.testing.assert_equal(padded, expected)
def pad_example(example,
                config,
                allow_failure=False):
  """Pads a variable-misuse example to the static shape given by `config`.

  The token operator is padded to `config.max_tokens`, the underlying graph
  bundle is padded by `graph_bundle.pad_example` using
  `config.bundle_padding`, and the node masks and the unique-candidate
  operator are padded to the maximum node count from
  `config.bundle_padding.static_max_metadata`. `bug_node_index` and
  `repair_id` are carried over unchanged.

  Args:
    example: The example to pad.
    config: Configuration specifying the desired padding size.
    allow_failure: If True, returns None instead of failing if example is too
      large.

  Returns:
    A padded example with static shape, or None when `allow_failure` is True
    and either the token count exceeds `config.max_tokens` or the bundle
    padding step itself returned None.

  Raises:
    ValueError: If the graph is too big to pad to this size.
  """
  token_count = example.input_graph.tokens.values.shape[0]
  if token_count > config.max_tokens:
    if not allow_failure:
      raise ValueError("Example has too many tokens.")
    return None

  padded_bundle = graph_bundle.pad_example(example.input_graph.bundle,
                                           config.bundle_padding,
                                           allow_failure)
  if padded_bundle is None:
    # The bundle padder signalled that the graph does not fit.
    return None

  padded_graph = GraphBundleWithTokens(
      bundle=padded_bundle,
      tokens=example.input_graph.tokens.pad_nonzeros(config.max_tokens),
  )
  # Every per-node field below is padded to the same static node count.
  node_limit = config.bundle_padding.static_max_metadata.num_nodes
  padded_repair_mask = jax_util.pad_to(example.repair_node_mask, node_limit)
  padded_candidate_mask = jax_util.pad_to(example.candidate_node_mask,
                                          node_limit)
  padded_unique_candidates = example.unique_candidate_operator.pad_nonzeros(
      node_limit)
  return VarMisuseExample(
      input_graph=padded_graph,
      bug_node_index=example.bug_node_index,
      repair_node_mask=padded_repair_mask,
      candidate_node_mask=padded_candidate_mask,
      unique_candidate_operator=padded_unique_candidates,
      repair_id=example.repair_id)
 def _batch_and_pad_elts(*args):
     """Stacks `args`, pads the stack, and splits its leading axis.

     The arrays are stacked with `np.stack`, padded with `jax_util.pad_to` to
     `batch_sizes[key]`, and then reshaped so that the leading axis is
     replaced by `batch_dim_sizes[key]`. `batch_sizes`, `batch_dim_sizes` and
     `key` come from the enclosing scope, which is not shown here.
     """
     # NOTE(review): presumably pad_to pads the leading (stacked) axis by
     # default -- confirm against jax_util.pad_to.
     padded = jax_util.pad_to(np.stack(args), batch_sizes[key])
     per_element_shape = padded.shape[1:]
     return padded.reshape(batch_dim_sizes[key] + per_element_shape)
Exemple #6
0
def pad_example(example,
                config,
                allow_failure=False):
  """Pads a graph bundle to the static shape described by `config`.

  After padding, the shape of every array in the result depends only on the
  config. The `graph_metadata` field is passed through unchanged: it is
  already of static shape, and its values can be used to tell real elements
  of the other fields apart from padding.

  Size checks run in this order: node count, input-tagged node count,
  initial transitions, in-tagged transitions, edges. The first check that
  fails decides the outcome.

  Args:
    example: The example to pad.
    config: Configuration specifying the desired padding size.
    allow_failure: If True, returns None instead of failing if example is too
      large.

  Returns:
    A padded example with static shape, or None if the example is too large
    and `allow_failure` is True.

  Raises:
    ValueError: If the graph is too big to pad to this size.
  """
  limits = config.static_max_metadata

  # --- Validate that the example fits within every configured limit. ---
  if example.graph_metadata.num_nodes > limits.num_nodes:
    if not allow_failure:
      raise ValueError("Example has too many nodes")
    return None

  if (example.graph_metadata.num_input_tagged_nodes >
      limits.num_input_tagged_nodes):
    if not allow_failure:
      raise ValueError("Example has too many input-tagged nodes")
    return None

  source_graph = example.automaton_graph

  initial_count = source_graph.initial_to_in_tagged.values.shape[0]
  if initial_count > config.max_initial_transitions:
    if not allow_failure:
      raise ValueError("Example has too many initial transitions")
    return None

  in_tagged_count = source_graph.in_tagged_to_in_tagged.values.shape[0]
  if in_tagged_count > config.max_in_tagged_transitions:
    if not allow_failure:
      raise ValueError("Example has too many in-tagged transitions")
    return None

  if example.edges.values.shape[0] > config.max_edges:
    if not allow_failure:
      raise ValueError("Example has too many edges")
    return None

  # --- Everything fits: pad each component to its static size. ---
  padded_automaton_graph = automaton_builder.EncodedGraph(
      initial_to_in_tagged=source_graph.initial_to_in_tagged.pad_nonzeros(
          config.max_initial_transitions),
      initial_to_special=jax_util.pad_to(source_graph.initial_to_special,
                                         limits.num_nodes),
      in_tagged_to_in_tagged=source_graph.in_tagged_to_in_tagged.pad_nonzeros(
          config.max_in_tagged_transitions),
      in_tagged_to_special=jax_util.pad_to(source_graph.in_tagged_to_special,
                                           limits.num_input_tagged_nodes),
      in_tagged_node_indices=jax_util.pad_to(
          source_graph.in_tagged_node_indices,
          limits.num_input_tagged_nodes),
  )
  return GraphBundle(
      automaton_graph=padded_automaton_graph,
      graph_metadata=example.graph_metadata,
      node_types=jax_util.pad_to(example.node_types, limits.num_nodes),
      edges=example.edges.pad_nonzeros(config.max_edges),
  )