def go(routing_params, variant_weights, start_states, explicit_conv=True):
    """Solve for absorbing results under the given routing parameters.

    Builds a transition matrix from `routing_params` using the `builder`,
    `enc_graph` and `enc_meta` names from the surrounding scope, then runs
    `automaton_builder.all_nodes_absorbing_solve` with 1000 steps and a
    backtrack failure probability of 0.01.

    Args:
      routing_params: Passed to `builder.build_transition_matrix`.
      variant_weights: Forwarded unchanged to the solver.
      start_states: Forwarded unchanged to the solver.
      explicit_conv: Forwarded as the solver's `explicit_conv` keyword.

    Returns:
      Whatever `automaton_builder.all_nodes_absorbing_solve` returns.
    """
    transition_matrix = builder.build_transition_matrix(
        routing_params, enc_graph, enc_meta)
    # Fixed solver settings, plus the one option the caller may toggle.
    solver_options = dict(
        steps=1000, backtrack_fails_prob=0.01, explicit_conv=explicit_conv)
    return automaton_builder.all_nodes_absorbing_solve(
        builder, transition_matrix, variant_weights, start_states,
        **solver_options)
Example #2
0
            def solve_one(one_edge_routing_params, one_edge_variant_weights):
                """Solve a single edge-specific automaton.

                Args:
                  one_edge_routing_params: Routing parameters used to build
                    this automaton's transition matrix.
                  one_edge_variant_weights: Variant weights forwarded to the
                    solver.

                Returns:
                  The result of `automaton_builder.all_nodes_absorbing_solve`
                  for this automaton.
                """
                # Every node starts in FSM state 0: a mask that is True only
                # at index 0, repeated once per node.
                state_zero_mask = jnp.arange(num_fsm_states) == 0
                initial_states = jnp.broadcast_to(
                    state_zero_mask[None, :], (num_nodes, num_fsm_states))

                # Instantiate the automaton on the already-encoded graph.
                edge_transition_matrix = builder.build_transition_matrix(
                    one_edge_routing_params, encoded_graph, static_metadata)

                return automaton_builder.all_nodes_absorbing_solve(
                    builder=builder,
                    transition_matrix=edge_transition_matrix,
                    variant_weights=one_edge_variant_weights,
                    start_machine_states=initial_states,
                    steps=steps,
                    backtrack_fails_prob=backtrack_fails_prob)
Example #3
0
    def test_all_nodes_absorbing_solve(self):
        """Checks absorbing probabilities on a 4-node doubly linked list."""
        schema, graph = self.build_doubly_linked_list_graph(4)
        builder = automaton_builder.AutomatonBuilder(schema)
        encoded_graph, encoded_metadata = builder.encode_graph(graph)

        # The automaton is configured with 5 variants and 2 states, yet only
        # the first state is ever used; this exercises the interleaving of
        # variants and states.
        #
        # Variant 0: move forward
        # Variant 1: move backward
        # Variant 2: finish
        # Variant 3: restart
        # Variant 4: fail

        # From node 0, go forward.
        from_node_0 = [[1, 0, 0, 0, 0], [1, 0, 0, 0, 0], [.7, 0, .3, 0, 0],
                       [0, 0, 1, 0, 0]]
        # From node 1, go backward with small failure probabilities.
        from_node_1 = [[0, 0.9, 0, 0, 0.1], [0, 0.9, 0, 0, 0.1],
                       [.7, 0, .3, 0, 0], [0, 0, 1, 0, 0]]
        # Node 2 bounces around and ultimately accepts on node 0.
        from_node_2 = [[0.9, 0, 0.1, 0, 0], [0, 1, 0, 0, 0],
                       [0.5, 0.5, 0, 0, 0], [0, 1, 0, 0, 0]]
        # Node 3 immediately accepts, or restarts after 0 or 1 steps.
        from_node_3 = [[0, 0, 1, 0, 0], [1, 0, 0, 0, 0], [0, 0, 0, 1, 0],
                       [0, 0.1, 0.8, 0.1, 0]]
        variant_weights = jnp.array(
            [from_node_0, from_node_1, from_node_2, from_node_3])

        # Movement table: one row per variant. Only variants 0 and 1 move.
        move_table = jnp.array([
            [1., 0., 1., 0., 1., 0.],
            [0., 1., 0., 1., 0., 1.],
            [0., 0., 0., 0., 0., 0.],
            [0., 0., 0., 0., 0., 0.],
            [0., 0., 0., 0., 0., 0.],
        ])
        # Special-action table: one row per variant. Variants 2, 3 and 4 each
        # select a distinct special action; variants 0 and 1 select none.
        no_special = [[0., 0., 0.]] * 3
        special_table = jnp.array([
            no_special,
            no_special,
            [[1., 0., 0.]] * 3,
            [[0., 1., 0.]] * 3,
            [[0., 0., 1.]] * 3,
        ])
        # Zero-pad the size-1 axes to size 2, leaving the extra slots unused.
        padded_move = jnp.pad(
            move_table.reshape([5, 6, 1, 1]),
            [(0, 0), (0, 0), (0, 1), (0, 1)])
        padded_special = jnp.pad(
            special_table.reshape([5, 3, 1, 3]),
            [(0, 0), (0, 0), (0, 1), (0, 0)])
        routing_params = automaton_builder.RoutingParams(
            move=padded_move, special=padded_special)
        transition_matrix = builder.build_transition_matrix(
            routing_params, encoded_graph, encoded_metadata)

        # All start mass goes in the first state column; the padded second
        # column is zero.
        start_states = jnp.pad(jnp.ones([4, 1]), [(0, 0), (0, 1)])
        absorbing_probs = automaton_builder.all_nodes_absorbing_solve(
            builder,
            transition_matrix,
            variant_weights,
            start_states,
            steps=1000,
            backtrack_fails_prob=0.01)

        # The absorbing probabilities follow the paths described above.
        # When starting at node 3, the automaton attempts to backtrack with
        # probability 0.2, and backtracking fails with probability 0.2 * 0.01
        # (per backtrack_fails_prob), so the total absorbing probability is
        # 0.8 / (0.8 + 0.2 * 0.01) = 0.997506.
        expected_absorbing_probs = jnp.array([
            [0, 0, 0.3, 0.7],
            [0, 0, 0, 0.81],
            [1, 0, 0, 0],
            [0, 0, 0, 0.997506],
        ])

        jax.test_util.check_close(absorbing_probs, expected_absorbing_probs)