def go(routing_params, variant_weights, start_states, explicit_conv=True):
    """Build a transition matrix and run the all-nodes absorbing solve.

    The names `builder`, `enc_graph`, `enc_meta` and `automaton_builder` are
    not defined here; they are resolved from the surrounding scope.

    Args:
        routing_params: Routing parameters handed to
            `builder.build_transition_matrix`.
        variant_weights: Variant weights forwarded to the solver.
        start_states: Starting machine states forwarded to the solver.
        explicit_conv: Forwarded unchanged as the solver's `explicit_conv`
            keyword argument.

    Returns:
        Whatever `automaton_builder.all_nodes_absorbing_solve` returns.
    """
    transition_matrix = builder.build_transition_matrix(
        routing_params, enc_graph, enc_meta)

    # Solver settings are fixed for this helper; only `explicit_conv` varies.
    solver_options = dict(
        steps=1000,
        backtrack_fails_prob=0.01,
        explicit_conv=explicit_conv,
    )
    return automaton_builder.all_nodes_absorbing_solve(
        builder,
        transition_matrix,
        variant_weights,
        start_states,
        **solver_options,
    )
def solve_one(one_edge_routing_params, one_edge_variant_weights):
    """Solve a single edge-specific automaton on the provided graph.

    Relies on `builder`, `encoded_graph`, `static_metadata`, `num_nodes`,
    `num_fsm_states`, `steps` and `backtrack_fails_prob` from the surrounding
    scope.

    Args:
        one_edge_routing_params: Routing parameters for this one automaton.
        one_edge_variant_weights: Variant weights for this one automaton.

    Returns:
        The result of `automaton_builder.all_nodes_absorbing_solve` for this
        automaton.
    """
    tmat = builder.build_transition_matrix(
        one_edge_routing_params, encoded_graph, static_metadata)

    # Every node starts in FSM state 0: build a single boolean one-hot row
    # over the states, then repeat it for each node.
    state_zero_row = (jnp.arange(num_fsm_states) == 0)[None, :]
    initial_states = jnp.broadcast_to(
        state_zero_row, (num_nodes, num_fsm_states))

    return automaton_builder.all_nodes_absorbing_solve(
        builder=builder,
        transition_matrix=tmat,
        variant_weights=one_edge_variant_weights,
        start_machine_states=initial_states,
        steps=steps,
        backtrack_fails_prob=backtrack_fails_prob,
    )
def test_all_nodes_absorbing_solve(self):
    """Check the absorbing solve against hand-computed probabilities.

    Uses the graph from `build_doubly_linked_list_graph(4)` and an automaton
    with 5 variants and 2 states of which only the first state is used, so
    that the interleaving of variants and states is exercised.
    """
    schema, graph = self.build_doubly_linked_list_graph(4)
    builder = automaton_builder.AutomatonBuilder(schema)
    encoded, metadata = builder.encode_graph(graph)

    # Variant meanings:
    #   0: move forward
    #   1: move backward
    #   2: finish
    #   3: restart
    #   4: fail
    variant_weights = jnp.array([
        # Starting from node 0: walk forward.
        [[1, 0, 0, 0, 0],
         [1, 0, 0, 0, 0],
         [.7, 0, .3, 0, 0],
         [0, 0, 1, 0, 0]],
        # Starting from node 1: walk backward, with a small chance of failure.
        [[0, 0.9, 0, 0, 0.1],
         [0, 0.9, 0, 0, 0.1],
         [.7, 0, .3, 0, 0],
         [0, 0, 1, 0, 0]],
        # Starting from node 2: bounce around, eventually accepting on node 0.
        [[0.9, 0, 0.1, 0, 0],
         [0, 1, 0, 0, 0],
         [0.5, 0.5, 0, 0, 0],
         [0, 1, 0, 0, 0]],
        # Starting from node 3: accept immediately, or restart after 0 or 1
        # steps.
        [[0, 0, 1, 0, 0],
         [1, 0, 0, 0, 0],
         [0, 0, 0, 1, 0],
         [0, 0.1, 0.8, 0.1, 0]],
    ])

    # Routing tables are written for a single state and then zero-padded so
    # that the automaton has a second, unused state.
    move_table = jnp.array([
        [1., 0., 1., 0., 1., 0.],
        [0., 1., 0., 1., 0., 1.],
        [0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0.],
    ]).reshape([5, 6, 1, 1])
    special_table = jnp.array([
        [[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]],
        [[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]],
        [[1., 0., 0.], [1., 0., 0.], [1., 0., 0.]],
        [[0., 1., 0.], [0., 1., 0.], [0., 1., 0.]],
        [[0., 0., 1.], [0., 0., 1.], [0., 0., 1.]],
    ]).reshape([5, 3, 1, 3])
    routing_params = automaton_builder.RoutingParams(
        move=jnp.pad(move_table, [(0, 0), (0, 0), (0, 1), (0, 1)]),
        special=jnp.pad(special_table, [(0, 0), (0, 0), (0, 1), (0, 0)]))

    transition_matrix = builder.build_transition_matrix(
        routing_params, encoded, metadata)

    # The expected absorbing probabilities follow the paths sketched above.
    # For a start at node 3: with probability 0.2 the automaton tries to
    # backtrack, but with probability 0.2 * 0.01 the backtrack fails (as set
    # by backtrack_fails_prob), so the total absorbing probability is
    # 0.8 / (0.8 + 0.2 * 0.01) = 0.997506.
    expected = jnp.array([
        [0, 0, 0.3, 0.7],
        [0, 0, 0, 0.81],
        [1, 0, 0, 0],
        [0, 0, 0, 0.997506],
    ])

    # All probability mass starts in the first state; the padded second
    # state column is zero.
    start_states = jnp.pad(jnp.ones([4, 1]), [(0, 0), (0, 1)])

    actual = automaton_builder.all_nodes_absorbing_solve(
        builder,
        transition_matrix,
        variant_weights,
        start_states,
        steps=1000,
        backtrack_fails_prob=0.01)

    jax.test_util.check_close(actual, expected)