def test_one_node_particle_estimate_padding(self):
  """Padding slots beyond the real nodes must get exactly zero estimate."""
  padded_size = 64
  schema, graph = self.build_doubly_linked_list_graph(4)
  builder = automaton_builder.AutomatonBuilder(schema)
  enc_graph, enc_meta = builder.encode_graph(graph)

  # Pad every component of the encoded graph out to `padded_size`.
  padded_graph = automaton_builder.EncodedGraph(
      initial_to_in_tagged=enc_graph.initial_to_in_tagged.pad_nonzeros(
          padded_size),
      initial_to_special=jax_util.pad_to(enc_graph.initial_to_special,
                                         padded_size),
      in_tagged_to_in_tagged=enc_graph.in_tagged_to_in_tagged.pad_nonzeros(
          padded_size),
      in_tagged_to_special=jax_util.pad_to(enc_graph.in_tagged_to_special,
                                           padded_size),
      in_tagged_node_indices=jax_util.pad_to(enc_graph.in_tagged_node_indices,
                                             padded_size),
  )
  padded_meta = automaton_builder.EncodedGraphMetadata(
      num_nodes=padded_size, num_input_tagged_nodes=padded_size)

  # Constant weights and routing parameters (all 0.2).
  variant_weights = jnp.full([padded_size, 5], 0.2)
  routing_params = automaton_builder.RoutingParams(
      move=jnp.full([5, 6, 2, 2], 0.2),
      special=jnp.full([5, 3, 2, 3], 0.2))
  tmat = builder.build_transition_matrix(routing_params, padded_graph,
                                         padded_meta)

  # Only the first `enc_meta.num_nodes` entries are real; the rest is padding.
  num_real = enc_meta.num_nodes
  outs = automaton_sampling.one_node_particle_estimate(
      builder,
      tmat,
      variant_weights,
      start_machine_state=jnp.array([1., 0.]),
      node_index=0,
      steps=100,
      num_rollouts=100,
      max_possible_transitions=2,
      num_valid_nodes=num_real,
      rng=jax.random.PRNGKey(0))

  self.assertEqual(outs.shape, (padded_size,))
  # Real nodes receive strictly positive estimates; padding receives none.
  self.assertTrue(jnp.all(outs[:num_real] > 0))
  self.assertTrue(jnp.all(outs[num_real:] == 0))
def zeros_like_padded_example(config):
  """Build a GraphBundle containing only zeros.

  This can be useful to initialize model parameters, or do tests.

  Args:
    config: Configuration specifying the desired padding size. Reads
      `static_max_metadata.num_nodes`,
      `static_max_metadata.num_input_tagged_nodes`,
      `max_initial_transitions`, `max_in_tagged_transitions` and `max_edges`.

  Returns:
    An "example" filled with zeros of the given size. Its graph_metadata
    reports zero nodes and zero input-tagged nodes.
  """
  num_nodes = config.static_max_metadata.num_nodes
  num_tagged = config.static_max_metadata.num_input_tagged_nodes

  def zero_sparse(num_nonzeros, values_dtype):
    # All-zero SparseCoordOperator with room for `num_nonzeros` entries.
    return sparse_operator.SparseCoordOperator(
        input_indices=np.zeros(shape=(num_nonzeros, 1), dtype=np.int32),
        output_indices=np.zeros(shape=(num_nonzeros, 2), dtype=np.int32),
        values=np.zeros(shape=(num_nonzeros,), dtype=values_dtype),
    )

  def zero_ints(length):
    # All-zero int32 vector of the given length.
    return np.zeros(shape=(length,), dtype=np.int32)

  return GraphBundle(
      automaton_graph=automaton_builder.EncodedGraph(
          initial_to_in_tagged=zero_sparse(config.max_initial_transitions,
                                           np.float32),
          initial_to_special=zero_ints(num_nodes),
          in_tagged_to_in_tagged=zero_sparse(config.max_in_tagged_transitions,
                                             np.float32),
          in_tagged_to_special=zero_ints(num_tagged),
          in_tagged_node_indices=zero_ints(num_tagged),
      ),
      graph_metadata=automaton_builder.EncodedGraphMetadata(
          num_nodes=0, num_input_tagged_nodes=0),
      node_types=zero_ints(num_nodes),
      # Edge values are int32, unlike the float32 automaton transition values.
      edges=zero_sparse(config.max_edges, np.int32),
  )
def test_automaton_layer_abstract_init(self, shared, variant_weights, use_gate,
                                       estimator_type, **kwargs):
  """Checks FiniteStateGraphAutomaton can be initialized inside a flax model.

  Args:
    shared: Forwarded as `share_states_across_edges`.
    variant_weights: Zero-argument callable producing the variant weights; it
      is invoked inside the model's `apply`.
    use_gate: Forwarded as `use_gate_parameterization`.
    estimator_type: Forwarded as `estimator_type`; when it equals
      "one_sample", the per-edge log-prob side output is also checked.
    **kwargs: Extra keyword arguments forwarded to the layer.
  """
  # Create a simple schema and empty encoded graph.
  schema = {
      graph_types.NodeType("a"):
          graph_types.NodeSchema(
              in_edges=[graph_types.InEdgeType("ai_0")],
              out_edges=[graph_types.OutEdgeType("ao_0")]),
  }
  builder = automaton_builder.AutomatonBuilder(schema)

  def empty_transitions(num_nonzeros):
    # All-zero sparse operator; a fresh instance is built for each field.
    return sparse_operator.SparseCoordOperator(
        input_indices=jnp.zeros((num_nonzeros, 1), dtype=jnp.int32),
        output_indices=jnp.zeros((num_nonzeros, 2), dtype=jnp.int32),
        values=jnp.zeros((num_nonzeros,), dtype=jnp.float32),
    )

  encoded_graph = automaton_builder.EncodedGraph(
      initial_to_in_tagged=empty_transitions(128),
      initial_to_special=jnp.zeros((32,), dtype=jnp.int32),
      in_tagged_to_in_tagged=empty_transitions(128),
      in_tagged_to_special=jnp.zeros((64,), dtype=jnp.int32),
      in_tagged_node_indices=jnp.zeros((64,), dtype=jnp.int32),
  )

  # Make sure the layer can be initialized and applied within a model.
  # This model is fairly simple; it just pretends that the encoded graph and
  # variants depend on the input.
  class TestModel(flax.deprecated.nn.Module):

    def apply(self, dummy_ignored):

      def depend_on_input(tree):
        # Tie every leaf to `dummy_ignored` so it looks input-dependent.
        # `jax.tree_util.tree_map` replaces the deprecated `jax.tree_map`.
        return jax.tree_util.tree_map(
            lambda leaf: jax.lax.tie_in(dummy_ignored, leaf), tree)

      abstract_encoded_graph = depend_on_input(encoded_graph)
      abstract_variant_weights = depend_on_input(variant_weights())
      return automaton_layer.FiniteStateGraphAutomaton(
          encoded_graph=abstract_encoded_graph,
          variant_weights=abstract_variant_weights,
          dynamic_metadata=automaton_builder.EncodedGraphMetadata(
              num_nodes=32, num_input_tagged_nodes=64),
          static_metadata=automaton_builder.EncodedGraphMetadata(
              num_nodes=32, num_input_tagged_nodes=64),
          builder=builder,
          num_out_edges=3,
          num_intermediate_states=4,
          share_states_across_edges=shared,
          use_gate_parameterization=use_gate,
          estimator_type=estimator_type,
          name="the_layer",
          **kwargs)

  with side_outputs.collect_side_outputs() as side:
    with flax.deprecated.nn.stochastic(jax.random.PRNGKey(0)):
      # For some reason init_by_shape breaks the custom_vjp?
      abstract_out, unused_params = TestModel.init(
          jax.random.PRNGKey(1234), jnp.zeros((), jnp.float32))

  del unused_params
  self.assertEqual(abstract_out.shape, (3, 32, 32))

  if estimator_type == "one_sample":
    log_prob_key = "/the_layer/one_sample_log_prob_per_edge_per_node"
    self.assertIn(log_prob_key, side)
    self.assertEqual(side[log_prob_key].shape, (3, 32))
def pad_example(example, config, allow_failure=False):
  """Pad an example so that it has a static shape determined by the config.

  The shapes of all NDArrays in the output will be fully determined by the
  config. Note that we do not pad the graph_metadata field, since it is
  already of static shape; the values in that field can be used to determine
  which elements of the other fields are padding and which elements are not.

  Args:
    example: The example to pad.
    config: Configuration specifying the desired padding size.
    allow_failure: If True, returns None instead of failing if example is too
      large.

  Returns:
    A padded example with static shape, or None if the example exceeds one of
    the configured limits and `allow_failure` is True.

  Raises:
    ValueError: If the graph is too big to pad to this size and
      `allow_failure` is False.
  """
  graph = example.automaton_graph
  limits = config.static_max_metadata

  # (actual size, maximum size, error message). Checked in this order, so the
  # first exceeded limit determines the reported error.
  size_checks = [
      (example.graph_metadata.num_nodes, limits.num_nodes,
       "Example has too many nodes"),
      (example.graph_metadata.num_input_tagged_nodes,
       limits.num_input_tagged_nodes,
       "Example has too many input-tagged nodes"),
      (graph.initial_to_in_tagged.values.shape[0],
       config.max_initial_transitions,
       "Example has too many initial transitions"),
      (graph.in_tagged_to_in_tagged.values.shape[0],
       config.max_in_tagged_transitions,
       "Example has too many in-tagged transitions"),
      (example.edges.values.shape[0], config.max_edges,
       "Example has too many edges"),
  ]
  for actual, maximum, message in size_checks:
    if actual > maximum:
      if allow_failure:
        return None
      raise ValueError(message)

  # Pad it out.
  return GraphBundle(
      automaton_graph=automaton_builder.EncodedGraph(
          initial_to_in_tagged=graph.initial_to_in_tagged.pad_nonzeros(
              config.max_initial_transitions),
          initial_to_special=jax_util.pad_to(graph.initial_to_special,
                                             limits.num_nodes),
          in_tagged_to_in_tagged=graph.in_tagged_to_in_tagged.pad_nonzeros(
              config.max_in_tagged_transitions),
          in_tagged_to_special=jax_util.pad_to(graph.in_tagged_to_special,
                                               limits.num_input_tagged_nodes),
          in_tagged_node_indices=jax_util.pad_to(
              graph.in_tagged_node_indices, limits.num_input_tagged_nodes),
      ),
      graph_metadata=example.graph_metadata,
      node_types=jax_util.pad_to(example.node_types, limits.num_nodes),
      edges=example.edges.pad_nonzeros(config.max_edges),
  )