示例#1
0
 def test_one_node_particle_estimate_padding(self):
   """Estimates are positive on real nodes and exactly zero on padded ones."""
   padded = 64
   schema, graph = self.build_doubly_linked_list_graph(4)
   builder = automaton_builder.AutomatonBuilder(schema)
   enc_graph, enc_meta = builder.encode_graph(graph)

   # Pad every component of the encoded graph out to a static size of 64.
   enc_graph_padded = automaton_builder.EncodedGraph(
       initial_to_in_tagged=enc_graph.initial_to_in_tagged.pad_nonzeros(
           padded),
       initial_to_special=jax_util.pad_to(enc_graph.initial_to_special,
                                          padded),
       in_tagged_to_in_tagged=enc_graph.in_tagged_to_in_tagged.pad_nonzeros(
           padded),
       in_tagged_to_special=jax_util.pad_to(enc_graph.in_tagged_to_special,
                                            padded),
       in_tagged_node_indices=jax_util.pad_to(
           enc_graph.in_tagged_node_indices, padded),
   )
   enc_meta_padded = automaton_builder.EncodedGraphMetadata(
       num_nodes=padded, num_input_tagged_nodes=padded)

   # Every weight/parameter is the same constant (0.2); the assertions below
   # only look at which entries are nonzero.
   variant_weights = jnp.full((padded, 5), 0.2)
   routing_params = automaton_builder.RoutingParams(
       move=jnp.full((5, 6, 2, 2), 0.2), special=jnp.full((5, 3, 2, 3), 0.2))
   tmat = builder.build_transition_matrix(routing_params, enc_graph_padded,
                                          enc_meta_padded)

   num_real_nodes = enc_meta.num_nodes
   outs = automaton_sampling.one_node_particle_estimate(
       builder,
       tmat,
       variant_weights,
       start_machine_state=jnp.array([1., 0.]),
       node_index=0,
       steps=100,
       num_rollouts=100,
       max_possible_transitions=2,
       num_valid_nodes=num_real_nodes,
       rng=jax.random.PRNGKey(0))

   self.assertEqual(outs.shape, (padded,))
   # Real nodes must all receive some mass; padding nodes must receive none.
   self.assertTrue(jnp.all(outs[:num_real_nodes] > 0))
   self.assertTrue(jnp.all(outs[num_real_nodes:] == 0))
def zeros_like_padded_example(config):
    """Returns a GraphBundle whose arrays are all zeros, padded per `config`.

    Useful as a placeholder input, e.g. to initialize model parameters or in
    tests. Its graph metadata reports zero nodes and zero input-tagged nodes.

    Args:
      config: Configuration specifying the desired padding size.

    Returns:
      An "example" filled with zeros, with every array sized by the limits in
      `config`.
    """
    max_nodes = config.static_max_metadata.num_nodes
    max_tagged = config.static_max_metadata.num_input_tagged_nodes

    def zero_operator(num_entries, values_dtype):
        # SparseCoordOperator with `num_entries` all-zero entries: int32 index
        # arrays of width 1 (input) and 2 (output), values of `values_dtype`.
        return sparse_operator.SparseCoordOperator(
            input_indices=np.zeros((num_entries, 1), dtype=np.int32),
            output_indices=np.zeros((num_entries, 2), dtype=np.int32),
            values=np.zeros((num_entries,), dtype=values_dtype),
        )

    automaton_graph = automaton_builder.EncodedGraph(
        initial_to_in_tagged=zero_operator(config.max_initial_transitions,
                                           np.float32),
        initial_to_special=np.zeros((max_nodes,), dtype=np.int32),
        in_tagged_to_in_tagged=zero_operator(config.max_in_tagged_transitions,
                                             np.float32),
        in_tagged_to_special=np.zeros((max_tagged,), dtype=np.int32),
        in_tagged_node_indices=np.zeros((max_tagged,), dtype=np.int32),
    )
    return GraphBundle(
        automaton_graph=automaton_graph,
        graph_metadata=automaton_builder.EncodedGraphMetadata(
            num_nodes=0, num_input_tagged_nodes=0),
        node_types=np.zeros((max_nodes,), dtype=np.int32),
        # Edge values are int32, unlike the float32 transition values above.
        edges=zero_operator(config.max_edges, np.int32),
    )
    def test_automaton_layer_abstract_init(self, shared, variant_weights,
                                           use_gate, estimator_type, **kwargs):
        """FiniteStateGraphAutomaton can be initialized inside a model.

        The encoded graph and variant weights are tied to the model input so
        that they are treated as input-dependent values during `init`.
        """
        # Minimal schema: one node type "a" with one in-edge and one out-edge.
        node_schema = graph_types.NodeSchema(
            in_edges=[graph_types.InEdgeType("ai_0")],
            out_edges=[graph_types.OutEdgeType("ao_0")])
        schema = {graph_types.NodeType("a"): node_schema}
        builder = automaton_builder.AutomatonBuilder(schema)

        def zero_operator():
            # An all-zero sparse operator with 128 stored entries.
            return sparse_operator.SparseCoordOperator(
                input_indices=jnp.zeros((128, 1), dtype=jnp.int32),
                output_indices=jnp.zeros((128, 2), dtype=jnp.int32),
                values=jnp.zeros((128,), dtype=jnp.float32),
            )

        encoded_graph = automaton_builder.EncodedGraph(
            initial_to_in_tagged=zero_operator(),
            initial_to_special=jnp.zeros((32,), dtype=jnp.int32),
            in_tagged_to_in_tagged=zero_operator(),
            in_tagged_to_special=jnp.zeros((64,), dtype=jnp.int32),
            in_tagged_node_indices=jnp.zeros((64,), dtype=jnp.int32),
        )

        class TestModel(flax.deprecated.nn.Module):

            def apply(self, dummy_ignored):

                def tie(leaf):
                    # Make each leaf depend on the model input.
                    return jax.lax.tie_in(dummy_ignored, leaf)

                abstract_encoded_graph = jax.tree_util.tree_map(
                    tie, encoded_graph)
                abstract_variant_weights = jax.tree_util.tree_map(
                    tie, variant_weights())
                return automaton_layer.FiniteStateGraphAutomaton(
                    encoded_graph=abstract_encoded_graph,
                    variant_weights=abstract_variant_weights,
                    dynamic_metadata=automaton_builder.EncodedGraphMetadata(
                        num_nodes=32, num_input_tagged_nodes=64),
                    static_metadata=automaton_builder.EncodedGraphMetadata(
                        num_nodes=32, num_input_tagged_nodes=64),
                    builder=builder,
                    num_out_edges=3,
                    num_intermediate_states=4,
                    share_states_across_edges=shared,
                    use_gate_parameterization=use_gate,
                    estimator_type=estimator_type,
                    name="the_layer",
                    **kwargs)

        with side_outputs.collect_side_outputs() as side, \
                flax.deprecated.nn.stochastic(jax.random.PRNGKey(0)):
            # A full `init` on a concrete scalar is used instead of
            # init_by_shape, which appeared to break the custom_vjp.
            abstract_out, unused_params = TestModel.init(
                jax.random.PRNGKey(1234), jnp.zeros((), jnp.float32))

        del unused_params
        self.assertEqual(abstract_out.shape, (3, 32, 32))

        if estimator_type != "one_sample":
            return
        log_prob_key = "/the_layer/one_sample_log_prob_per_edge_per_node"
        self.assertIn(log_prob_key, side)
        self.assertEqual(side[log_prob_key].shape, (3, 32))
示例#4
0
def pad_example(example, config, allow_failure=False):
  """Pads `example` so every array has the static shape given by `config`.

  Each array-valued field is padded up to the size configured for it. The
  `graph_metadata` field is passed through unchanged, because it already has a
  static shape. Its counts can be used to tell real entries from padding.

  Args:
    example: The example (GraphBundle-like object) to pad.
    config: Configuration specifying the desired padding size.
    allow_failure: If True, return None instead of raising when `example`
      exceeds one of the configured limits.

  Returns:
    A padded example with static shape, or None if `allow_failure` is True
    and `example` is too large.

  Raises:
    ValueError: If `example` exceeds a configured limit and `allow_failure`
      is False.
  """
  limits = config.static_max_metadata

  def first_violation():
    # Message for the first limit (checked in a fixed order) that `example`
    # exceeds, or None if it fits. Later checks are skipped once one fails.
    if example.graph_metadata.num_nodes > limits.num_nodes:
      return "Example has too many nodes"
    if (example.graph_metadata.num_input_tagged_nodes >
        limits.num_input_tagged_nodes):
      return "Example has too many input-tagged nodes"
    if (example.automaton_graph.initial_to_in_tagged.values.shape[0] >
        config.max_initial_transitions):
      return "Example has too many initial transitions"
    if (example.automaton_graph.in_tagged_to_in_tagged.values.shape[0] >
        config.max_in_tagged_transitions):
      return "Example has too many in-tagged transitions"
    if example.edges.values.shape[0] > config.max_edges:
      return "Example has too many edges"
    return None

  violation = first_violation()
  if violation is not None:
    if allow_failure:
      return None
    raise ValueError(violation)

  graph = example.automaton_graph
  num_nodes = limits.num_nodes
  num_tagged = limits.num_input_tagged_nodes
  padded_graph = automaton_builder.EncodedGraph(
      initial_to_in_tagged=graph.initial_to_in_tagged.pad_nonzeros(
          config.max_initial_transitions),
      initial_to_special=jax_util.pad_to(graph.initial_to_special, num_nodes),
      in_tagged_to_in_tagged=graph.in_tagged_to_in_tagged.pad_nonzeros(
          config.max_in_tagged_transitions),
      in_tagged_to_special=jax_util.pad_to(graph.in_tagged_to_special,
                                           num_tagged),
      in_tagged_node_indices=jax_util.pad_to(graph.in_tagged_node_indices,
                                             num_tagged),
  )
  return GraphBundle(
      automaton_graph=padded_graph,
      graph_metadata=example.graph_metadata,
      node_types=jax_util.pad_to(example.node_types, num_nodes),
      edges=example.edges.pad_nonzeros(config.max_edges),
  )