예제 #1
0
 def test_one_node_particle_estimate_padding(self):
   """Check one_node_particle_estimate on an encoded graph padded to 64.

   The graph from `build_doubly_linked_list_graph(4)` is encoded, every field
   of the encoding is padded to size 64, and the estimate is run with
   `num_valid_nodes` set to the unpadded node count. The result must have
   shape (64,), be strictly positive on the valid nodes and exactly zero on
   the padding entries.
   """
   padded_size = 64
   schema, graph = self.build_doubly_linked_list_graph(4)
   builder = automaton_builder.AutomatonBuilder(schema)
   encoded, metadata = builder.encode_graph(graph)

   def pad_array(array):
     # Fields padded through the jax_util helper.
     return jax_util.pad_to(array, padded_size)

   def pad_entries(operator):
     # Fields padded through their own `pad_nonzeros` method.
     return operator.pad_nonzeros(padded_size)

   padded_graph = automaton_builder.EncodedGraph(
       initial_to_in_tagged=pad_entries(encoded.initial_to_in_tagged),
       initial_to_special=pad_array(encoded.initial_to_special),
       in_tagged_to_in_tagged=pad_entries(encoded.in_tagged_to_in_tagged),
       in_tagged_to_special=pad_array(encoded.in_tagged_to_special),
       in_tagged_node_indices=pad_array(encoded.in_tagged_node_indices))
   padded_metadata = automaton_builder.EncodedGraphMetadata(
       num_nodes=padded_size, num_input_tagged_nodes=padded_size)

   # Uniform weights everywhere: only the padding behaviour is under test.
   variant_weights = jnp.full([padded_size, 5], 0.2)
   routing_params = automaton_builder.RoutingParams(
       move=jnp.full([5, 6, 2, 2], 0.2),
       special=jnp.full([5, 3, 2, 3], 0.2))
   transition_matrix = builder.build_transition_matrix(
       routing_params, padded_graph, padded_metadata)

   valid_count = metadata.num_nodes
   estimate = automaton_sampling.one_node_particle_estimate(
       builder,
       transition_matrix,
       variant_weights,
       start_machine_state=jnp.array([1., 0.]),
       node_index=0,
       steps=100,
       num_rollouts=100,
       max_possible_transitions=2,
       num_valid_nodes=valid_count,
       rng=jax.random.PRNGKey(0))

   self.assertEqual(estimate.shape, (padded_size,))
   self.assertTrue(jnp.all(estimate[:valid_count] > 0))
   self.assertTrue(jnp.all(estimate[valid_count:] == 0))
 def go(variant_weights, routing_params, eps):
     """Smooth and renormalize the parameters, then run the particle estimate.

     Reads `builder`, `enc_graph` and `enc_meta` from the enclosing scope.

     Args:
       variant_weights: Variant weights; mixed with a constant 1/5 term and
         renormalized over the last axis.
       routing_params: `automaton_builder.RoutingParams`; both fields are
         mixed with a constant 1/5 term and then divided by
         `builder.routing_reduce(..., "sum")`.
       eps: Mixing coefficient; 0 leaves the inputs unchanged before
         renormalization.

     Returns:
       The result of `automaton_sampling.all_nodes_particle_estimate`.
     """
     variant_weights = ((1 - eps) * variant_weights +
                        eps * jnp.ones_like(variant_weights) / 5)
     routing_params = automaton_builder.RoutingParams(
         move=((1 - eps) * routing_params.move +
               eps * jnp.ones_like(routing_params.move) / 5),
         special=((1 - eps) * routing_params.special +
                  eps * jnp.ones_like(routing_params.special) / 5),
     )
     variant_weights = variant_weights / jnp.sum(
         variant_weights, axis=-1, keepdims=True)
     routing_params_sum = builder.routing_reduce(routing_params, "sum")
     # jax.tree_multimap was removed from JAX; tree_map accepts several trees
     # and applies the function leaf-wise across them.
     routing_params = jax.tree_util.tree_map(jax.lax.div, routing_params,
                                             routing_params_sum)
     tmat = builder.build_transition_matrix(routing_params, enc_graph,
                                            enc_meta)
     return automaton_sampling.all_nodes_particle_estimate(
         builder,
         tmat,
         variant_weights,
         jnp.pad(jnp.ones([4, 1]), [(0, 0), (0, 1)]),
         steps=100,
         rng=jax.random.PRNGKey(1),
         num_rollouts=10000,
         max_possible_transitions=2,  # only two edges to leave each node
         num_valid_nodes=enc_meta.num_nodes,
     )
예제 #3
0
    def test_transition_sentinel_integers(self):
        """Test that the transition matrix puts each element in the right place.

        Builds the transition matrix for the graph from `build_loop_graph`
        using routing parameters whose entries are distinct sentinel integers,
        so every entry of the concatenated result identifies the routing
        parameter(s) it was assembled from.
        """
        schema, test_graph = self.build_loop_graph()
        builder = automaton_builder.AutomatonBuilder(schema)
        encoded_graph, graph_meta = builder.encode_graph(test_graph,
                                                         as_jax=False)

        # Apply to some sentinel integers to check correct indexing (with only one
        # variant and state, since indexing doesn't use those)
        # Each integer is of the form XYZ where
        #   X = {a:1, b:2}[node_type]
        #   Y = {initial:9, i0:0, i1:1}[in_edge]
        #   Z = {o0:0, o1:1, o2:2, finish:3, backtrack:4, fail:5}[action]
        sentinel_routing_params = automaton_builder.RoutingParams(
            move=jnp.array([[100, 101, 110, 111, 190, 191],
                            [200, 201, 202, 290, 291, 292]]).reshape(
                                (1, 12, 1, 1)),
            special=jnp.array([
                [103, 104, 105],
                [113, 114, 115],
                [193, 194, 195],
                [203, 204, 205],
                [293, 294, 295],
            ]).reshape((1, 5, 1, 3)))

        range_transition_matrix = builder.build_transition_matrix(
            sentinel_routing_params,
            encoded_graph,
            graph_meta,
        ).concatenated_transitions()
        # 1 variant; 4 initial-node rows + 6 in-tagged rows; 1 state;
        # 6 in-tagged columns (x 1 state) + 3 special-action columns.
        self.assertEqual(range_transition_matrix.shape,
                         (1, 4 + 6, 1, 6 * 1 + 3))

        # Rows a0..b1 start from the initial in-edge (Y = 9 sentinels); the
        # remaining rows start from an in-tagged node. Entries such as
        # `292 / 2` are a single routing parameter split across two columns.
        # pyformat: disable
        # pylint: disable=bad-continuation,bad-whitespace,g-inline-comment-too-close
        expected = np.array([
            #  a0i0 a0i1 a1i0 a1i1       b0i0       b1i0   specials    < next
            [  #   |    |    |    |          |          |    |---------|   current V
                [[0, 0, 191, 0, 0, 190, 193, 194, 195]],  # ┬ a0
                [[0, 190, 0, 0, 191, 0, 193, 194, 195]],  # | a1
                [[0, 0, 0, 290, 292 / 2, 291 + 292 / 2, 293, 294,
                  295]],  # | b0
                [[291, 0, 0, 0, 290 + 292 / 2, 292 / 2, 293, 294,
                  295]],  # ┴ b1
                [[0, 0, 101, 0, 0, 100, 103, 104, 105]],  # ┬ a0i0
                [[0, 0, 111, 0, 0, 110, 113, 114, 115]],  # | a0i1
                [[0, 100, 0, 0, 101, 0, 103, 104, 105]],  # | a1i0
                [[0, 110, 0, 0, 111, 0, 113, 114, 115]],  # | a1i1
                [[0, 0, 0, 200, 202 / 2, 201 + 202 / 2, 203, 204,
                  205]],  # | b0i0
                [[201, 0, 0, 0, 200 + 202 / 2, 202 / 2, 203, 204,
                  205]],  # ┴ b1i0
            ]
        ])
        # pyformat: enable
        # pylint: enable=bad-continuation,bad-whitespace,g-inline-comment-too-close
        np.testing.assert_allclose(range_transition_matrix, expected)
  def test_transition_all_ones(self):
    """Test the transition matrix of an all-ones routing parameter vector.

    Uses the graph from `build_loop_graph` with 3 variants and 2 states and
    compares the concatenated transition matrix against a hand-written table
    that is repeated once per variant.
    """
    schema, test_graph = self.build_loop_graph()
    builder = automaton_builder.AutomatonBuilder(schema)
    encoded_graph, graph_meta = builder.encode_graph(test_graph, as_jax=False)

    # The transition matrix for an all-ones routing params should be a
    # (weighted) directed adjacency matrix. We use 3 variants, 2 states.
    ones_routing_params = automaton_builder.RoutingParams(
        move=jnp.ones([3, 12, 2, 2]), special=jnp.ones([3, 5, 2, 3]))

    ones_transition_matrix = builder.build_transition_matrix(
        ones_routing_params,
        encoded_graph,
        graph_meta,
    ).concatenated_transitions()
    # 3 variants; 4 initial-node rows + 6 in-tagged rows; 2 states;
    # 6 in-tagged columns x 2 states + 3 special-action columns.
    self.assertEqual(ones_transition_matrix.shape, (3, 4 + 6, 2, 6 * 2 + 3))

    # Each labelled row group holds one row per current state; each in-tagged
    # column label spans two adjacent columns (one per next state).
    # pyformat: disable
    # pylint: disable=bad-continuation,bad-whitespace,g-inline-comment-too-close
    expected = np.array([
        #  a0i0    a0i1    a1i0    a1i1    b0i0    b1i0    specials  < next
      [ #|------|-------|-------|-------|-------|-------| |--------| current V
        [[ 0,  0,  0,  0,  1,  1,  0,  0,  0,  0,  1,  1,  1,  1,  1],  # ┬ a0
         [ 0,  0,  0,  0,  1,  1,  0,  0,  0,  0,  1,  1,  1,  1,  1]], # |
        [[ 0,  0,  1,  1,  0,  0,  0,  0,  1,  1,  0,  0,  1,  1,  1],  # | a1
         [ 0,  0,  1,  1,  0,  0,  0,  0,  1,  1,  0,  0,  1,  1,  1]], # |
        [[ 0,  0,  0,  0,  0,  0,  1,  1,0.5,0.5,1.5,1.5,  1,  1,  1],  # | b0
         [ 0,  0,  0,  0,  0,  0,  1,  1,0.5,0.5,1.5,1.5,  1,  1,  1]], # |
        [[ 1,  1,  0,  0,  0,  0,  0,  0,1.5,1.5,0.5,0.5,  1,  1,  1],  # | b1
         [ 1,  1,  0,  0,  0,  0,  0,  0,1.5,1.5,0.5,0.5,  1,  1,  1]], # ┴
        [[ 0,  0,  0,  0,  1,  1,  0,  0,  0,  0,  1,  1,  1,  1,  1],  # ┬ a0i0
         [ 0,  0,  0,  0,  1,  1,  0,  0,  0,  0,  1,  1,  1,  1,  1]], # |
        [[ 0,  0,  0,  0,  1,  1,  0,  0,  0,  0,  1,  1,  1,  1,  1],  # | a0i1
         [ 0,  0,  0,  0,  1,  1,  0,  0,  0,  0,  1,  1,  1,  1,  1]], # |
        [[ 0,  0,  1,  1,  0,  0,  0,  0,  1,  1,  0,  0,  1,  1,  1],  # | a1i0
         [ 0,  0,  1,  1,  0,  0,  0,  0,  1,  1,  0,  0,  1,  1,  1]], # |
        [[ 0,  0,  1,  1,  0,  0,  0,  0,  1,  1,  0,  0,  1,  1,  1],  # | a1i1
         [ 0,  0,  1,  1,  0,  0,  0,  0,  1,  1,  0,  0,  1,  1,  1]], # |
        [[ 0,  0,  0,  0,  0,  0,  1,  1,0.5,0.5,1.5,1.5,  1,  1,  1],  # | b0i0
         [ 0,  0,  0,  0,  0,  0,  1,  1,0.5,0.5,1.5,1.5,  1,  1,  1]], # |
        [[ 1,  1,  0,  0,  0,  0,  0,  0,1.5,1.5,0.5,0.5,  1,  1,  1],  # | b1i0
         [ 1,  1,  0,  0,  0,  0,  0,  0,1.5,1.5,0.5,0.5,  1,  1,  1]], # ┴
      ]
    ] * 3)
    # pyformat: enable
    # pylint: enable=bad-continuation,bad-whitespace,g-inline-comment-too-close
    np.testing.assert_allclose(ones_transition_matrix, expected)
    def test_all_nodes_particle_estimate(self):
        """Compare the all-nodes particle estimate with hand-computed probs.

        Builds a 5-variant, 2-state automaton over the graph from
        `build_doubly_linked_list_graph(4)`, runs 10000 rollouts per start
        node and checks the estimated absorbing probabilities against the
        expected table within a two-sigma sampling tolerance.
        """
        schema, graph = self.build_doubly_linked_list_graph(4)
        builder = automaton_builder.AutomatonBuilder(schema)
        enc_graph, enc_meta = builder.encode_graph(graph)

        # We set up the automaton with 5 variants and 2 states, but only use the
        # first state, to make sure that variants and states are interleaved
        # correctly.

        # Variant 0: move forward
        # Variant 1: move backward
        # Variant 2: finish
        # Variant 3: restart
        # Variant 4: fail
        variant_weights = jnp.array([
            # From node 0, go forward.
            [[1, 0, 0, 0, 0], [1, 0, 0, 0, 0], [.7, 0, .3, 0, 0],
             [0, 0, 1, 0, 0]],
            # From node 1, go backward with small failure probabilities.
            [[0, 0.9, 0, 0, 0.1], [0, 0.9, 0, 0, 0.1], [.7, 0, .3, 0, 0],
             [0, 0, 1, 0, 0]],
            # Node 2 bounces around and ultimately accepts on node 0.
            [[0.9, 0, 0.1, 0, 0], [0, 1, 0, 0, 0], [0.5, 0.5, 0, 0, 0],
             [0, 1, 0, 0, 0]],
            # Node 3 immediately accepts, or restarts after 0 or 1 steps.
            [[0, 0, 1, 0, 0], [1, 0, 0, 0, 0], [0, 0, 0, 1, 0],
             [0, 0.1, 0.8, 0.1, 0]],
        ])
        routing_params = automaton_builder.RoutingParams(
            move=jnp.broadcast_to(
                jnp.pad(
                    jnp.array([
                        [1., 0., 1., 0., 1., 0.],
                        [0., 1., 0., 1., 0., 1.],
                        [0., 0., 0., 0., 0., 0.],
                        [0., 0., 0., 0., 0., 0.],
                        [0., 0., 0., 0., 0., 0.],
                    ]).reshape([5, 6, 1, 1]), [(0, 0), (0, 0), (0, 0),
                                               (0, 1)]), [5, 6, 2, 2]),
            special=jnp.broadcast_to(
                jnp.array([
                    [[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]],
                    [[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]],
                    [[1., 0., 0.], [1., 0., 0.], [1., 0., 0.]],
                    [[0., 1., 0.], [0., 1., 0.], [0., 1., 0.]],
                    [[0., 0., 1.], [0., 0., 1.], [0., 0., 1.]],
                ]).reshape([5, 3, 1, 3]), [5, 3, 2, 3]))

        @jax.jit
        def go(variant_weights, routing_params, eps):
            """Smooth by `eps`, renormalize, and run the particle estimate."""
            variant_weights = ((1 - eps) * variant_weights +
                               eps * jnp.ones_like(variant_weights) / 5)
            routing_params = automaton_builder.RoutingParams(
                move=((1 - eps) * routing_params.move +
                      eps * jnp.ones_like(routing_params.move) / 5),
                special=((1 - eps) * routing_params.special +
                         eps * jnp.ones_like(routing_params.special) / 5),
            )
            variant_weights = variant_weights / jnp.sum(
                variant_weights, axis=-1, keepdims=True)
            routing_params_sum = builder.routing_reduce(routing_params, "sum")
            # jax.tree_multimap was removed from JAX; tree_map accepts
            # several trees and applies the function leaf-wise across them.
            routing_params = jax.tree_util.tree_map(jax.lax.div,
                                                    routing_params,
                                                    routing_params_sum)
            tmat = builder.build_transition_matrix(routing_params, enc_graph,
                                                   enc_meta)
            return automaton_sampling.all_nodes_particle_estimate(
                builder,
                tmat,
                variant_weights,
                jnp.pad(jnp.ones([4, 1]), [(0, 0), (0, 1)]),
                steps=100,
                rng=jax.random.PRNGKey(1),
                num_rollouts=10000,
                max_possible_transitions=2,  # only two edges to leave each node
                num_valid_nodes=enc_meta.num_nodes,
            )

        # Absorbing probs follow the paths described above.
        expected_absorbing_probs = jnp.array([
            [0, 0, 0.3, 0.7],
            [0, 0, 0, 0.81],
            [1, 0, 0, 0],
            [0, 0, 0, 1],
        ])

        particle_probs = go(variant_weights, routing_params, eps=0)

        # With 10000 rollouts we expect a standard deviation of up to
        # sqrt(0.3*0.7/10000) ~= 5e-3. Check that we are within 2 sigma.
        np.testing.assert_allclose(particle_probs,
                                   expected_absorbing_probs,
                                   atol=2 * 5e-3)
예제 #6
0
    def test_unroll_and_aggregate(self):
        """Check unroll_chain_steps and aggregate_unrolled_per_node outputs.

        Unrolls six steps from node 0 (starting in state 0) on the graph from
        `build_doubly_linked_list_graph(4)` and compares the
        "initial_special", "in_tagged_states" and "in_tagged_special" entries,
        and then their per-node aggregation, against hand-computed tables.
        """
        schema, graph = self.build_doubly_linked_list_graph(4)
        builder = automaton_builder.AutomatonBuilder(schema)
        enc_graph, enc_meta = builder.encode_graph(graph)

        # Three variants, but all the same, just to check shape consistency
        variant_weights = jnp.broadcast_to(jnp.array([0.7, 0.3, 0.]), [4, 3])

        # In state 0: keep moving in the current direction with prob 0.1, swap
        # directions and states with prob 0.9 (except init, which is a special case)
        # In state 1: take a special action
        routing_params = automaton_builder.RoutingParams(
            move=jnp.broadcast_to(
                jnp.array([
                    # from next, to next
                    [[0.0, 0.9], [0.0, 0.0]],
                    # from next, to prev
                    [[0.1, 0.0], [0.0, 0.0]],
                    # from prev, to next
                    [[0.1, 0.0], [0.0, 0.0]],
                    # from prev, to prev
                    [[0.0, 0.9], [0.0, 0.0]],
                    # from init, to next
                    [[0.1, 0.0], [0.0, 0.0]],
                    # from init, to prev
                    [[0.1, 0.0], [0.0, 0.0]],
                ]),
                [3, 6, 2, 2]),
            special=jnp.broadcast_to(
                jnp.array([
                    # from next
                    [[0, 0, 0], [0.2, 0.3, 0.5]],
                    # from prev
                    [[0, 0, 0], [0.5, 0.2, 0.3]],
                    # from init
                    [[0.1, 0.3, 0.4], [0.0, 0.0, 1.0]],
                ]),
                [3, 3, 2, 3]))
        tmat = builder.build_transition_matrix(routing_params, enc_graph,
                                               enc_meta)

        unrolled = automaton_builder.unroll_chain_steps(builder,
                                                        tmat,
                                                        variant_weights,
                                                        jnp.array([1., 0.]),
                                                        node_index=0,
                                                        steps=6)

        # Matches the state-0 row of the "from init" special entry above.
        expected_initial_special = np.array([0.1, 0.3, 0.4])
        np.testing.assert_allclose(unrolled["initial_special"],
                                   expected_initial_special)

        # pyformat: disable
        # pylint: disable=bad-whitespace
        expected_in_tagged_states = np.array([
            # In-tagged node key:
            #  0 from 1,     0 from 3,     1 from 2,     1 from 0,
            #  2 from 3,     2 from 1,     3 from 0,     3 from 2
            # State key: [prob of being in state 0, prob of being in state 1]
            # -- First step from initial --
            [[0, 0], [0, 0], [0, 0], [1e-1, 0], [0, 0], [0, 0], [1e-1, 0],
             [0, 0]],
            # -- Second step --
            [[0, 9e-2], [0, 9e-2], [0, 0], [0, 0], [1e-2, 0], [1e-2, 0],
             [0, 0], [0, 0]],
            # -----------------
            [[0, 0], [0, 0], [1e-3, 9e-3], [0, 0], [0, 0], [0, 0], [0, 0],
             [1e-3, 9e-3]],
            # -----------------
            [[1e-4, 0], [1e-4, 0], [0, 0], [0, 0], [0, 9e-4], [0, 9e-4],
             [0, 0], [0, 0]],
            # -----------------
            [[0, 0], [0, 0], [0, 0], [1e-5, 9e-5], [0, 0], [0, 0],
             [1e-5, 9e-5], [0, 0]],
            # -----------------
            [[0, 9e-6], [0, 9e-6], [0, 0], [0, 0], [1e-6, 0], [1e-6, 0],
             [0, 0], [0, 0]],
        ])
        np.testing.assert_allclose(unrolled["in_tagged_states"],
                                   expected_in_tagged_states,
                                   atol=1e-8)

        expected_in_tagged_special = np.array([
            # In-tagged node key:
            #  0 from 1,                 0 from 3
            #  1 from 2,                 1 from 0,
            #  2 from 3,                 2 from 1
            #  3 from 0,                 3 from 2
            # Action key: [finish, backtrack, fail] (cumulative)
            # -- First step from initial --
            # (no special actions because we just left the initial node)
            [[0, 0, 0]] * 8,
            # -- Second step --
            # (no special actions yet because everything was in state 0)
            [[0, 0, 0]] * 8,
            # -----------------
            [[1.8e-2, 2.7e-2, 4.5e-2], [4.5e-2, 1.8e-2, 2.7e-2], [0, 0, 0],
             [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0]],
            # -----------------
            [[1.8e-2, 2.7e-2, 4.5e-2], [4.5e-2, 1.8e-2, 2.7e-2],
             [1.8e-3, 2.7e-3, 4.5e-3], [0, 0, 0], [0, 0, 0], [0, 0, 0],
             [0, 0, 0], [4.5e-3, 1.8e-3, 2.7e-3]],
            # -----------------
            [[1.8e-2, 2.7e-2, 4.5e-2], [4.5e-2, 1.8e-2, 2.7e-2],
             [1.8e-3, 2.7e-3, 4.5e-3], [0, 0, 0], [1.8e-4, 2.7e-4, 4.5e-4],
             [4.5e-4, 1.8e-4, 2.7e-4], [0, 0, 0], [4.5e-3, 1.8e-3, 2.7e-3]],
            # -----------------
            [[1.8e-2, 2.7e-2, 4.5e-2], [4.5e-2, 1.8e-2, 2.7e-2],
             [1.8e-3, 2.7e-3, 4.5e-3], [4.5e-5, 1.8e-5, 2.7e-5],
             [1.8e-4, 2.7e-4, 4.5e-4], [4.5e-4, 1.8e-4, 2.7e-4],
             [1.8e-5, 2.7e-5, 4.5e-5], [4.5e-3, 1.8e-3, 2.7e-3]],
        ])
        np.testing.assert_allclose(unrolled["in_tagged_special"],
                                   expected_in_tagged_special,
                                   atol=1e-8)
        # pyformat: enable
        # pylint: enable=bad-whitespace

        unrolled_combined = automaton_builder.aggregate_unrolled_per_node(
            unrolled, 0, 0, tmat, enc_meta)

        # One block per step (including the "zeroth"), one row per node; each
        # row has five entries: two state probabilities followed by the three
        # special-action totals.
        expected_unrolled = np.array([
            # "Zeroth" step: at initial node, no specials have happened yet
            [[1, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0],
             [0, 0, 0, 0, 0]],
            # The other steps are either the states from above, or sums of entries
            # in initial_special and in_tagged_special
            [[0, 0, 0.1, 0.3, 0.4], [0.1, 0, 0, 0, 0], [0, 0, 0, 0, 0],
             [0.1, 0, 0, 0, 0]],
            # -----------------
            [[0, 0.18, 0.1, 0.3, 0.4], [0, 0, 0, 0, 0], [0.02, 0, 0, 0, 0],
             [0, 0, 0, 0, 0]],
            # -----------------
            [[0, 0, 0.163, 0.34500003, 0.472], [0.001, 0.009, 0, 0, 0],
             [0, 0, 0, 0, 0], [0.001, 0.009, 0, 0, 0]],
            # -----------------
            [[0.0002, 0, 0.163, 0.34500003, 0.472],
             [0, 0, 0.0018, 0.0027, 0.0045], [0, 0.0018, 0, 0, 0],
             [0, 0, 0.0045, 0.0018, 0.0027]],
            # -----------------
            [[0, 0, 0.163, 0.34500003, 0.472],
             [1e-05, 9e-05, 0.0018, 0.0027, 0.0045],
             [0, 0, 0.00063, 0.00045, 0.00072],
             [1e-05, 9e-05, 0.0045, 0.0018, 0.0027]],
            # -----------------
            [[0, 1.8e-05, 0.163, 0.34500003, 0.472],
             [0, 0, 0.001845, 0.002718, 0.004527],
             [2e-06, 0, 0.00063, 0.00045, 0.00072],
             [0, 0, 0.004518, 0.001827, 0.002745]],
        ])

        np.testing.assert_allclose(unrolled_combined,
                                   expected_unrolled,
                                   atol=1e-8)
예제 #7
0
    def test_all_nodes_absorbing_solve(self):
        """Compare all_nodes_absorbing_solve with hand-computed probabilities.

        Builds a 5-variant, 2-state automaton over the graph from
        `build_doubly_linked_list_graph(4)` and checks the absorbing
        probabilities returned by the solver (with `backtrack_fails_prob=0.01`)
        against the expected table.
        """
        schema, graph = self.build_doubly_linked_list_graph(4)
        builder = automaton_builder.AutomatonBuilder(schema)
        enc_graph, enc_meta = builder.encode_graph(graph)

        # We set up the automaton with 5 variants and 2 states, but only use the
        # first state, to make sure that variants and states are interleaved
        # correctly.

        # Variant 0: move forward
        # Variant 1: move backward
        # Variant 2: finish
        # Variant 3: restart
        # Variant 4: fail
        variant_weights = jnp.array([
            # From node 0, go forward.
            [[1, 0, 0, 0, 0], [1, 0, 0, 0, 0], [.7, 0, .3, 0, 0],
             [0, 0, 1, 0, 0]],
            # From node 1, go backward with small failure probabilities.
            [[0, 0.9, 0, 0, 0.1], [0, 0.9, 0, 0, 0.1], [.7, 0, .3, 0, 0],
             [0, 0, 1, 0, 0]],
            # Node 2 bounces around and ultimately accepts on node 0.
            [[0.9, 0, 0.1, 0, 0], [0, 1, 0, 0, 0], [0.5, 0.5, 0, 0, 0],
             [0, 1, 0, 0, 0]],
            # Node 3 immediately accepts, or restarts after 0 or 1 steps.
            [[0, 0, 1, 0, 0], [1, 0, 0, 0, 0], [0, 0, 0, 1, 0],
             [0, 0.1, 0.8, 0.1, 0]],
        ])
        routing_params = automaton_builder.RoutingParams(
            move=jnp.pad(
                jnp.array([
                    [1., 0., 1., 0., 1., 0.],
                    [0., 1., 0., 1., 0., 1.],
                    [0., 0., 0., 0., 0., 0.],
                    [0., 0., 0., 0., 0., 0.],
                    [0., 0., 0., 0., 0., 0.],
                ]).reshape([5, 6, 1, 1]), [(0, 0), (0, 0), (0, 1), (0, 1)]),
            special=jnp.pad(
                jnp.array([
                    [[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]],
                    [[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]],
                    [[1., 0., 0.], [1., 0., 0.], [1., 0., 0.]],
                    [[0., 1., 0.], [0., 1., 0.], [0., 1., 0.]],
                    [[0., 0., 1.], [0., 0., 1.], [0., 0., 1.]],
                ]).reshape([5, 3, 1, 3]), [(0, 0), (0, 0), (0, 1), (0, 0)]))
        tmat = builder.build_transition_matrix(routing_params, enc_graph,
                                               enc_meta)

        # Absorbing probs follow the paths described above.
        # Note that when starting at node 3, with probability 0.2 the automaton
        # tries to backtrack, but with probability 0.2 * 0.01 backtracking fails
        # (as specified by backtrack_fails_prob) and thus the total absorbing
        # probability is 0.8 / (0.8 + 0.2 * 0.01) = 0.997506
        expected_absorbing_probs = jnp.array([
            [0, 0, 0.3, 0.7],
            [0, 0, 0, 0.81],
            [1, 0, 0, 0],
            [0, 0, 0, 0.997506],
        ])
        absorbing_probs = automaton_builder.all_nodes_absorbing_solve(
            builder,
            tmat,
            variant_weights,
            jnp.pad(jnp.ones([4, 1]), [(0, 0), (0, 1)]),
            steps=1000,
            backtrack_fails_prob=0.01)

        # Use NumPy's comparison rather than jax.test_util.check_close, which
        # is not part of the public jax.test_util API in recent JAX releases.
        # The tolerance leaves room for floating-point round-off and for the
        # six-digit rounding of the expected 0.997506 entry.
        # NOTE(review): confirm 2e-5 is acceptable for this solver.
        np.testing.assert_allclose(absorbing_probs,
                                   expected_absorbing_probs,
                                   rtol=2e-5,
                                   atol=2e-5)
예제 #8
0
    def test_routing_reduce_correct(self, reduction):
        """Compare JAX implementations to a (slow but correct) iterative one.

        Args:
            reduction: Which aggregate to check: "sum", "max" or "softmax".

        Raises:
            ValueError: If `reduction` is not one of the values above.
        """
        n_variants = 2
        n_states = 4

        def make_range_shaped(shape):
            return np.arange(np.prod(shape)).reshape(shape).astype("float32")

        schema = self.build_simple_schema()
        builder = automaton_builder.AutomatonBuilder(schema)
        n_special = len(builder.special_actions)
        routing_params = automaton_builder.RoutingParams(
            move=make_range_shaped([
                n_variants,
                len(builder.in_out_route_types),
                n_states,
                n_states,
            ]),
            special=make_range_shaped([
                n_variants,
                len(builder.in_route_types),
                n_states,
                n_special,
            ]),
        )

        # Compute aggregates with JAX
        if reduction == "softmax":
            routing_aggregates = builder.routing_softmax(routing_params)
        else:
            routing_aggregates = builder.routing_reduce(routing_params,
                                                        reduction=reduction)
            # Broadcast each aggregate to the shape of its parameter array so
            # both branches can be indexed the same way below.
            # (jax.tree_multimap was removed from JAX; tree_map accepts
            # several trees.)
            routing_aggregates = jax.tree_util.tree_map(
                lambda s, p: np.array(jnp.broadcast_to(s, p.shape)),
                routing_aggregates, routing_params)

        # Manual looping aggregates
        for variant in range(n_variants):
            for current_state in range(n_states):
                for in_route_type in builder.in_route_types:
                    iroute_idx = builder.in_route_type_to_index[in_route_type]
                    # Move-route indices for every out edge of this in-route,
                    # in schema order; shared by the gather and the check.
                    ioroute_indices = [
                        builder.in_out_route_type_to_index[
                            automaton_builder.InOutRouteType(
                                in_route_type.node_type, in_route_type.in_edge,
                                out_edge_type)]
                        for out_edge_type in schema[
                            in_route_type.node_type].out_edges
                    ]

                    # Compute aggregates: all move entries first, then the
                    # special-action entries.
                    distn_vals = [
                        routing_params.move[variant, ioroute_idx,
                                            current_state, next_state]
                        for ioroute_idx in ioroute_indices
                        for next_state in range(n_states)
                    ]
                    distn_vals.extend(
                        routing_params.special[variant, iroute_idx,
                                               current_state, action_idx]
                        for action_idx in range(n_special))

                    if reduction == "sum":
                        distn_aggregate = [sum(distn_vals)] * len(distn_vals)
                    elif reduction == "max":
                        distn_aggregate = [max(distn_vals)] * len(distn_vals)
                    elif reduction == "softmax":
                        distn_aggregate = list(
                            jax.nn.softmax(jnp.array(distn_vals)))
                    else:
                        raise ValueError(f"Invalid reduction {reduction}")

                    # Check them with the JAX version, gathered in the same
                    # order as distn_vals.
                    jax_vals = [
                        routing_aggregates.move[variant, ioroute_idx,
                                                current_state, next_state]
                        for ioroute_idx in ioroute_indices
                        for next_state in range(n_states)
                    ]
                    jax_vals.extend(
                        routing_aggregates.special[variant, iroute_idx,
                                                   current_state, action_idx]
                        for action_idx in range(n_special))

                    for actual, desired in zip(jax_vals, distn_aggregate):
                        np.testing.assert_allclose(actual,
                                                   desired,
                                                   rtol=1e-6)