def test_one_node_particle_estimate_padding(self):
  """Particle estimates on a padded graph put no mass on padding nodes.

  Encodes a 4-node doubly linked list, pads every encoded component to 64
  entries, and checks that the estimate has one entry per padded node, with
  strictly positive values for the real nodes and exact zeros elsewhere.
  """
  padded_size = 64
  uniform_weight = 0.2

  schema, graph = self.build_doubly_linked_list_graph(4)
  builder = automaton_builder.AutomatonBuilder(schema)
  enc_graph, enc_meta = builder.encode_graph(graph)
  num_real_nodes = enc_meta.num_nodes

  # Pad each component of the encoded graph out to `padded_size`.
  padded_initial_to_in_tagged = (
      enc_graph.initial_to_in_tagged.pad_nonzeros(padded_size))
  padded_initial_to_special = jax_util.pad_to(
      enc_graph.initial_to_special, padded_size)
  padded_in_tagged_to_in_tagged = (
      enc_graph.in_tagged_to_in_tagged.pad_nonzeros(padded_size))
  padded_in_tagged_to_special = jax_util.pad_to(
      enc_graph.in_tagged_to_special, padded_size)
  padded_in_tagged_node_indices = jax_util.pad_to(
      enc_graph.in_tagged_node_indices, padded_size)
  padded_graph = automaton_builder.EncodedGraph(
      initial_to_in_tagged=padded_initial_to_in_tagged,
      initial_to_special=padded_initial_to_special,
      in_tagged_to_in_tagged=padded_in_tagged_to_in_tagged,
      in_tagged_to_special=padded_in_tagged_to_special,
      in_tagged_node_indices=padded_in_tagged_node_indices)
  padded_meta = automaton_builder.EncodedGraphMetadata(
      num_nodes=padded_size, num_input_tagged_nodes=padded_size)

  # Constant-valued weights and routing parameters.
  variant_weights = jnp.full([padded_size, 5], uniform_weight)
  routing = automaton_builder.RoutingParams(
      move=jnp.full([5, 6, 2, 2], uniform_weight),
      special=jnp.full([5, 3, 2, 3], uniform_weight))
  transition_matrix = builder.build_transition_matrix(
      routing, padded_graph, padded_meta)

  estimates = automaton_sampling.one_node_particle_estimate(
      builder,
      transition_matrix,
      variant_weights,
      start_machine_state=jnp.array([1., 0.]),
      node_index=0,
      steps=100,
      num_rollouts=100,
      max_possible_transitions=2,
      num_valid_nodes=num_real_nodes,
      rng=jax.random.PRNGKey(0))

  self.assertEqual(estimates.shape, (padded_size,))
  real_part = estimates[:num_real_nodes]
  padding_part = estimates[num_real_nodes:]
  self.assertTrue(jnp.all(real_part > 0))
  self.assertTrue(jnp.all(padding_part == 0))
def go(variant_weights, routing_params, eps):
  """Runs the all-nodes particle estimate on eps-smoothed, normalized inputs.

  `builder`, `enc_graph` and `enc_meta` are not parameters; they are resolved
  from the enclosing scope.

  Args:
    variant_weights: Variant weights. Mixed with a constant 1/5 term using
      `eps`, then renormalized along the last axis.
    routing_params: `RoutingParams` whose `move` and `special` arrays are
      mixed with a constant 1/5 term using `eps`, then divided by the "sum"
      routing reduction.
    eps: Mixing coefficient; with 0 the inputs are only normalized.

  Returns:
    The output of `automaton_sampling.all_nodes_particle_estimate`.
  """
  variant_weights = ((1 - eps) * variant_weights +
                     eps * jnp.ones_like(variant_weights) / 5)
  routing_params = automaton_builder.RoutingParams(
      move=((1 - eps) * routing_params.move +
            eps * jnp.ones_like(routing_params.move) / 5),
      special=((1 - eps) * routing_params.special +
               eps * jnp.ones_like(routing_params.special) / 5),
  )
  variant_weights = variant_weights / jnp.sum(
      variant_weights, axis=-1, keepdims=True)
  routing_params_sum = builder.routing_reduce(routing_params, "sum")
  # `jax.tree_multimap` has been removed from JAX, so divide the two fields
  # explicitly; this works on both old and current JAX releases.
  routing_params = automaton_builder.RoutingParams(
      move=jax.lax.div(routing_params.move, routing_params_sum.move),
      special=jax.lax.div(routing_params.special,
                          routing_params_sum.special),
  )
  tmat = builder.build_transition_matrix(routing_params, enc_graph, enc_meta)
  return automaton_sampling.all_nodes_particle_estimate(
      builder,
      tmat,
      variant_weights,
      jnp.pad(jnp.ones([4, 1]), [(0, 0), (0, 1)]),
      steps=100,
      rng=jax.random.PRNGKey(1),
      num_rollouts=10000,
      max_possible_transitions=2,  # only two edges to leave each node
      num_valid_nodes=enc_meta.num_nodes,
  )
def test_transition_sentinel_integers(self):
  """Test that the transition matrix puts each element in the right place.

  Builds the transition matrix for the loop graph from routing parameters
  filled with distinct sentinel integers, so every entry of the concatenated
  result can be traced back to the routing parameter it came from.
  """
  schema, test_graph = self.build_loop_graph()
  builder = automaton_builder.AutomatonBuilder(schema)
  encoded_graph, graph_meta = builder.encode_graph(test_graph, as_jax=False)

  # Apply to some sentinel integers to check correct indexing (with only one
  # variant and state, since indexing doesn't use those)
  # Each integer is of the form XYZ where
  #   X = {a:1, b:2}[node_type]
  #   Y = {initial:9, i0:0, i1:1}[in_edge]
  #   Z = {o0:0, o1:1, o2:2, finish:3, backtrack:4, fail:5}[action]
  sentinel_routing_params = automaton_builder.RoutingParams(
      move=jnp.array([[100, 101, 110, 111, 190, 191],
                      [200, 201, 202, 290, 291, 292]]).reshape(
                          (1, 12, 1, 1)),
      special=jnp.array([
          [103, 104, 105],
          [113, 114, 115],
          [193, 194, 195],
          [203, 204, 205],
          [293, 294, 295],
      ]).reshape((1, 5, 1, 3)))
  range_transition_matrix = builder.build_transition_matrix(
      sentinel_routing_params,
      encoded_graph,
      graph_meta,
  ).concatenated_transitions()
  # One variant, 4 + 6 "current" rows, one state, and 6 * 1 + 3 "next" columns
  # (six in-tagged destinations with one state each, then three specials).
  self.assertEqual(range_transition_matrix.shape, (1, 4 + 6, 1, 6 * 1 + 3))

  # Rows with 19x / 29x sentinels are the "initial" rows (Y = 9); the other
  # rows are labelled <node><in_edge>, in the same order as the columns.
  # pyformat: disable
  # pylint: disable=bad-continuation,bad-whitespace,g-inline-comment-too-close
  expected = np.array([
    #  a0i0  a0i1  a1i0  a1i1  b0i0           b1i0           specials        < next
    [ #  |     |     |     |     |              |            |-----------|     current V
      [[0,    0,    191,  0,    0,             190,           193, 194, 195]],  # ┬ a0
      [[0,    190,  0,    0,    191,           0,             193, 194, 195]],  # | a1
      [[0,    0,    0,    290,  292 / 2,       291 + 292 / 2, 293, 294, 295]],  # | b0
      [[291,  0,    0,    0,    290 + 292 / 2, 292 / 2,       293, 294, 295]],  # ┴ b1
      [[0,    0,    101,  0,    0,             100,           103, 104, 105]],  # ┬ a0i0
      [[0,    0,    111,  0,    0,             110,           113, 114, 115]],  # | a0i1
      [[0,    100,  0,    0,    101,           0,             103, 104, 105]],  # | a1i0
      [[0,    110,  0,    0,    111,           0,             113, 114, 115]],  # | a1i1
      [[0,    0,    0,    200,  202 / 2,       201 + 202 / 2, 203, 204, 205]],  # | b0i0
      [[201,  0,    0,    0,    200 + 202 / 2, 202 / 2,       203, 204, 205]],  # ┴ b1i0
    ]
  ])
  # pyformat: enable
  # pylint: enable=bad-continuation,bad-whitespace,g-inline-comment-too-close
  np.testing.assert_allclose(range_transition_matrix, expected)
def test_transition_all_ones(self):
  """Test the transition matrix of an all-ones routing parameter vector.

  With every routing parameter equal to one, the concatenated transition
  matrix is expected to be identical for each of the 3 variants and for both
  current states, as encoded by the `* 3` and the duplicated rows below.
  """
  schema, test_graph = self.build_loop_graph()
  builder = automaton_builder.AutomatonBuilder(schema)
  encoded_graph, graph_meta = builder.encode_graph(test_graph, as_jax=False)

  # The transition matrix for an all-ones routing params should be a
  # (weighted) directed adjacency matrix. We use 3 variants, 2 states.
  ones_routing_params = automaton_builder.RoutingParams(
      move=jnp.ones([3, 12, 2, 2]), special=jnp.ones([3, 5, 2, 3]))
  ones_transition_matrix = builder.build_transition_matrix(
      ones_routing_params,
      encoded_graph,
      graph_meta,
  ).concatenated_transitions()
  # 3 variants, 4 + 6 "current" rows, 2 current states, and 6 * 2 + 3 "next"
  # columns (two state columns per in-tagged destination, then three specials).
  self.assertEqual(ones_transition_matrix.shape, (3, 4 + 6, 2, 6 * 2 + 3))

  # Each "current" row holds two inner rows, one per current state; each
  # in-tagged "next" column header spans two columns, one per next state.
  # pyformat: disable
  # pylint: disable=bad-continuation,bad-whitespace,g-inline-comment-too-close
  expected = np.array([
    #    a0i0      a0i1      a1i0      a1i1      b0i0      b1i0     specials       < next
    [ # |------|  |------|  |------|  |------|  |------|  |------|  |----------|     current V
      [[  0,   0,   0,   0,   1,   1,   0,   0,   0,   0,   1,   1,   1,   1,   1],  # ┬ a0
       [  0,   0,   0,   0,   1,   1,   0,   0,   0,   0,   1,   1,   1,   1,   1]], # |
      [[  0,   0,   1,   1,   0,   0,   0,   0,   1,   1,   0,   0,   1,   1,   1],  # | a1
       [  0,   0,   1,   1,   0,   0,   0,   0,   1,   1,   0,   0,   1,   1,   1]], # |
      [[  0,   0,   0,   0,   0,   0,   1,   1, 0.5, 0.5, 1.5, 1.5,   1,   1,   1],  # | b0
       [  0,   0,   0,   0,   0,   0,   1,   1, 0.5, 0.5, 1.5, 1.5,   1,   1,   1]], # |
      [[  1,   1,   0,   0,   0,   0,   0,   0, 1.5, 1.5, 0.5, 0.5,   1,   1,   1],  # | b1
       [  1,   1,   0,   0,   0,   0,   0,   0, 1.5, 1.5, 0.5, 0.5,   1,   1,   1]], # ┴
      [[  0,   0,   0,   0,   1,   1,   0,   0,   0,   0,   1,   1,   1,   1,   1],  # ┬ a0i0
       [  0,   0,   0,   0,   1,   1,   0,   0,   0,   0,   1,   1,   1,   1,   1]], # |
      [[  0,   0,   0,   0,   1,   1,   0,   0,   0,   0,   1,   1,   1,   1,   1],  # | a0i1
       [  0,   0,   0,   0,   1,   1,   0,   0,   0,   0,   1,   1,   1,   1,   1]], # |
      [[  0,   0,   1,   1,   0,   0,   0,   0,   1,   1,   0,   0,   1,   1,   1],  # | a1i0
       [  0,   0,   1,   1,   0,   0,   0,   0,   1,   1,   0,   0,   1,   1,   1]], # |
      [[  0,   0,   1,   1,   0,   0,   0,   0,   1,   1,   0,   0,   1,   1,   1],  # | a1i1
       [  0,   0,   1,   1,   0,   0,   0,   0,   1,   1,   0,   0,   1,   1,   1]], # |
      [[  0,   0,   0,   0,   0,   0,   1,   1, 0.5, 0.5, 1.5, 1.5,   1,   1,   1],  # | b0i0
       [  0,   0,   0,   0,   0,   0,   1,   1, 0.5, 0.5, 1.5, 1.5,   1,   1,   1]], # |
      [[  1,   1,   0,   0,   0,   0,   0,   0, 1.5, 1.5, 0.5, 0.5,   1,   1,   1],  # | b1i0
       [  1,   1,   0,   0,   0,   0,   0,   0, 1.5, 1.5, 0.5, 0.5,   1,   1,   1]], # ┴
    ]
  ] * 3)
  # pyformat: enable
  # pylint: enable=bad-continuation,bad-whitespace,g-inline-comment-too-close
  np.testing.assert_allclose(ones_transition_matrix, expected)
def test_all_nodes_particle_estimate(self):
  """Particle estimates match hand-computed absorbing probabilities.

  Uses a 4-node doubly linked list with per-node variant weights chosen so
  the absorbing distribution from every start node can be worked out by hand,
  then checks the sampled estimate against it within a 2-sigma tolerance.
  """
  schema, graph = self.build_doubly_linked_list_graph(4)
  builder = automaton_builder.AutomatonBuilder(schema)
  enc_graph, enc_meta = builder.encode_graph(graph)

  # We set up the automaton with 5 variants and 2 states, but only use the
  # first state, to make sure that variants and states are interleaved
  # correctly.

  # Variant 0: move forward
  # Variant 1: move backward
  # Variant 2: finish
  # Variant 3: restart
  # Variant 4: fail
  variant_weights = jnp.array([
      # From node 0, go forward.
      [[1, 0, 0, 0, 0], [1, 0, 0, 0, 0], [.7, 0, .3, 0, 0], [0, 0, 1, 0, 0]],
      # From node 1, go backward with small failure probabilities.
      [[0, 0.9, 0, 0, 0.1], [0, 0.9, 0, 0, 0.1], [.7, 0, .3, 0, 0],
       [0, 0, 1, 0, 0]],
      # Node 2 bounces around and ultimately accepts on node 0.
      [[0.9, 0, 0.1, 0, 0], [0, 1, 0, 0, 0], [0.5, 0.5, 0, 0, 0],
       [0, 1, 0, 0, 0]],
      # Node 3 immediately accepts, or restarts after 0 or 1 steps.
      [[0, 0, 1, 0, 0], [1, 0, 0, 0, 0], [0, 0, 0, 1, 0],
       [0, 0.1, 0.8, 0.1, 0]],
  ])
  routing_params = automaton_builder.RoutingParams(
      move=jnp.broadcast_to(
          jnp.pad(
              jnp.array([
                  [1., 0., 1., 0., 1., 0.],
                  [0., 1., 0., 1., 0., 1.],
                  [0., 0., 0., 0., 0., 0.],
                  [0., 0., 0., 0., 0., 0.],
                  [0., 0., 0., 0., 0., 0.],
              ]).reshape([5, 6, 1, 1]), [(0, 0), (0, 0), (0, 0), (0, 1)]),
          [5, 6, 2, 2]),
      special=jnp.broadcast_to(
          jnp.array([
              [[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]],
              [[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]],
              [[1., 0., 0.], [1., 0., 0.], [1., 0., 0.]],
              [[0., 1., 0.], [0., 1., 0.], [0., 1., 0.]],
              [[0., 0., 1.], [0., 0., 1.], [0., 0., 1.]],
          ]).reshape([5, 3, 1, 3]), [5, 3, 2, 3]))

  @jax.jit
  def go(variant_weights, routing_params, eps):
    """Smooths with `eps`, renormalizes, and runs the particle estimate."""
    variant_weights = ((1 - eps) * variant_weights +
                       eps * jnp.ones_like(variant_weights) / 5)
    routing_params = automaton_builder.RoutingParams(
        move=((1 - eps) * routing_params.move +
              eps * jnp.ones_like(routing_params.move) / 5),
        special=((1 - eps) * routing_params.special +
                 eps * jnp.ones_like(routing_params.special) / 5),
    )
    variant_weights = variant_weights / jnp.sum(
        variant_weights, axis=-1, keepdims=True)
    routing_params_sum = builder.routing_reduce(routing_params, "sum")
    # `jax.tree_multimap` has been removed from JAX, so divide the two fields
    # explicitly; this works on both old and current JAX releases.
    routing_params = automaton_builder.RoutingParams(
        move=jax.lax.div(routing_params.move, routing_params_sum.move),
        special=jax.lax.div(routing_params.special,
                            routing_params_sum.special),
    )
    tmat = builder.build_transition_matrix(routing_params, enc_graph,
                                           enc_meta)
    return automaton_sampling.all_nodes_particle_estimate(
        builder,
        tmat,
        variant_weights,
        jnp.pad(jnp.ones([4, 1]), [(0, 0), (0, 1)]),
        steps=100,
        rng=jax.random.PRNGKey(1),
        num_rollouts=10000,
        max_possible_transitions=2,  # only two edges to leave each node
        num_valid_nodes=enc_meta.num_nodes,
    )

  # Absorbing probs follow the paths described above.
  expected_absorbing_probs = jnp.array([
      [0, 0, 0.3, 0.7],
      [0, 0, 0, 0.81],
      [1, 0, 0, 0],
      [0, 0, 0, 1],
  ])

  particle_probs = go(variant_weights, routing_params, eps=0)

  # With 10000 rollouts we expect a standard deviation of up to
  # sqrt(0.3*0.7/10000) ~= 5e-3. Check that we are within 2 sigma.
  np.testing.assert_allclose(
      particle_probs, expected_absorbing_probs, atol=2 * 5e-3)
def test_unroll_and_aggregate(self):
  """Checks chain unrolling and per-node aggregation against known values.

  Unrolls six steps from node 0 of a 4-node doubly linked list and compares
  the "initial_special", "in_tagged_states" and "in_tagged_special" entries,
  and then the per-node aggregate, against hand-written expected arrays.
  """
  schema, graph = self.build_doubly_linked_list_graph(4)
  builder = automaton_builder.AutomatonBuilder(schema)
  enc_graph, enc_meta = builder.encode_graph(graph)

  # Three variants, but all the same, just to check shape consistency
  variant_weights = jnp.broadcast_to(jnp.array([0.7, 0.3, 0.]), [4, 3])

  # In state 0: keep moving in the current direction with prob 0.1, swap
  # directions and states with prob 0.9 (except init, which is a special case)
  # In state 1: take a special action
  routing_params = automaton_builder.RoutingParams(
      move=jnp.broadcast_to(
          jnp.array([
              # from next, to next
              [[0.0, 0.9], [0.0, 0.0]],
              # from next, to prev
              [[0.1, 0.0], [0.0, 0.0]],
              # from prev, to next
              [[0.1, 0.0], [0.0, 0.0]],
              # from prev, to prev
              [[0.0, 0.9], [0.0, 0.0]],
              # from init, to next
              [[0.1, 0.0], [0.0, 0.0]],
              # from init, to prev
              [[0.1, 0.0], [0.0, 0.0]],
          ]), [3, 6, 2, 2]),
      special=jnp.broadcast_to(
          jnp.array([
              # from next
              [[0, 0, 0], [0.2, 0.3, 0.5]],
              # from prev
              [[0, 0, 0], [0.5, 0.2, 0.3]],
              # from init
              [[0.1, 0.3, 0.4], [0.0, 0.0, 1.0]],
          ]), [3, 3, 2, 3]))
  tmat = builder.build_transition_matrix(routing_params, enc_graph, enc_meta)

  # Start at node 0 in machine state 0 and unroll six steps.
  unrolled = automaton_builder.unroll_chain_steps(builder, tmat,
                                                  variant_weights,
                                                  jnp.array([1., 0.]),
                                                  node_index=0,
                                                  steps=6)

  # These match the "from init" row for state 0 in `special` above.
  expected_initial_special = np.array([0.1, 0.3, 0.4])
  np.testing.assert_allclose(unrolled["initial_special"],
                             expected_initial_special)

  # pyformat: disable
  # pylint: disable=bad-whitespace
  expected_in_tagged_states = np.array([
      # In-tagged node key:
      #  0 from 1, 0 from 3, 1 from 2, 1 from 0,
      #  2 from 3, 2 from 1, 3 from 0, 3 from 2
      # State key: [prob of being in state 0, prob of being in state 1]
      # -- First step from initial --
      [[0, 0], [0, 0], [0, 0], [1e-1, 0],
       [0, 0], [0, 0], [1e-1, 0], [0, 0]],
      # -- Second step --
      [[0, 9e-2], [0, 9e-2], [0, 0], [0, 0],
       [1e-2, 0], [1e-2, 0], [0, 0], [0, 0]],
      # -----------------
      [[0, 0], [0, 0], [1e-3, 9e-3], [0, 0],
       [0, 0], [0, 0], [0, 0], [1e-3, 9e-3]],
      # -----------------
      [[1e-4, 0], [1e-4, 0], [0, 0], [0, 0],
       [0, 9e-4], [0, 9e-4], [0, 0], [0, 0]],
      # -----------------
      [[0, 0], [0, 0], [0, 0], [1e-5, 9e-5],
       [0, 0], [0, 0], [1e-5, 9e-5], [0, 0]],
      # -----------------
      [[0, 9e-6], [0, 9e-6], [0, 0], [0, 0],
       [1e-6, 0], [1e-6, 0], [0, 0], [0, 0]],
  ])
  np.testing.assert_allclose(unrolled["in_tagged_states"],
                             expected_in_tagged_states, atol=1e-8)

  expected_in_tagged_special = np.array([
      # In-tagged node key:
      #  0 from 1, 0 from 3
      #  1 from 2, 1 from 0,
      #  2 from 3, 2 from 1
      #  3 from 0, 3 from 2
      # Action key: [finish, backtrack, fail] (cumulative)
      # -- First step from initial --
      # (no special actions because we just left the initial node)
      [[0, 0, 0]] * 8,
      # -- Second step --
      # (no special actions yet because everything was in state 0)
      [[0, 0, 0]] * 8,
      # -----------------
      [[1.8e-2, 2.7e-2, 4.5e-2], [4.5e-2, 1.8e-2, 2.7e-2],
       [0, 0, 0], [0, 0, 0],
       [0, 0, 0], [0, 0, 0],
       [0, 0, 0], [0, 0, 0]],
      # -----------------
      [[1.8e-2, 2.7e-2, 4.5e-2], [4.5e-2, 1.8e-2, 2.7e-2],
       [1.8e-3, 2.7e-3, 4.5e-3], [0, 0, 0],
       [0, 0, 0], [0, 0, 0],
       [0, 0, 0], [4.5e-3, 1.8e-3, 2.7e-3]],
      # -----------------
      [[1.8e-2, 2.7e-2, 4.5e-2], [4.5e-2, 1.8e-2, 2.7e-2],
       [1.8e-3, 2.7e-3, 4.5e-3], [0, 0, 0],
       [1.8e-4, 2.7e-4, 4.5e-4], [4.5e-4, 1.8e-4, 2.7e-4],
       [0, 0, 0], [4.5e-3, 1.8e-3, 2.7e-3]],
      # -----------------
      [[1.8e-2, 2.7e-2, 4.5e-2], [4.5e-2, 1.8e-2, 2.7e-2],
       [1.8e-3, 2.7e-3, 4.5e-3], [4.5e-5, 1.8e-5, 2.7e-5],
       [1.8e-4, 2.7e-4, 4.5e-4], [4.5e-4, 1.8e-4, 2.7e-4],
       [1.8e-5, 2.7e-5, 4.5e-5], [4.5e-3, 1.8e-3, 2.7e-3]],
  ])
  np.testing.assert_allclose(unrolled["in_tagged_special"],
                             expected_in_tagged_special, atol=1e-8)
  # pyformat: enable
  # pylint: enable=bad-whitespace

  unrolled_combined = automaton_builder.aggregate_unrolled_per_node(
      unrolled, 0, 0, tmat, enc_meta)

  # One row per step (the initial configuration plus six unrolled steps) and
  # one 5-vector per node. Judging from the values, each vector reads
  # [state 0, state 1, finish, backtrack, fail] -- TODO confirm.
  expected_unrolled = np.array([
      # "Zeroth" step: at initial node, no specials have happened yet
      [[1, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]],
      # The other steps are either the states from above, or sums of entries
      # in initial_special and in_tagged_special
      [[0, 0, 0.1, 0.3, 0.4], [0.1, 0, 0, 0, 0],
       [0, 0, 0, 0, 0], [0.1, 0, 0, 0, 0]],
      # -----------------
      [[0, 0.18, 0.1, 0.3, 0.4], [0, 0, 0, 0, 0],
       [0.02, 0, 0, 0, 0], [0, 0, 0, 0, 0]],
      # -----------------
      [[0, 0, 0.163, 0.34500003, 0.472], [0.001, 0.009, 0, 0, 0],
       [0, 0, 0, 0, 0], [0.001, 0.009, 0, 0, 0]],
      # -----------------
      [[0.0002, 0, 0.163, 0.34500003, 0.472], [0, 0, 0.0018, 0.0027, 0.0045],
       [0, 0.0018, 0, 0, 0], [0, 0, 0.0045, 0.0018, 0.0027]],
      # -----------------
      [[0, 0, 0.163, 0.34500003, 0.472],
       [1e-05, 9e-05, 0.0018, 0.0027, 0.0045],
       [0, 0, 0.00063, 0.00045, 0.00072],
       [1e-05, 9e-05, 0.0045, 0.0018, 0.0027]],
      # -----------------
      [[0, 1.8e-05, 0.163, 0.34500003, 0.472],
       [0, 0, 0.001845, 0.002718, 0.004527],
       [2e-06, 0, 0.00063, 0.00045, 0.00072],
       [0, 0, 0.004518, 0.001827, 0.002745]],
  ])
  np.testing.assert_allclose(unrolled_combined, expected_unrolled, atol=1e-8)
def test_all_nodes_absorbing_solve(self):
  """Absorbing-solve output matches hand-computed absorbing probabilities.

  Uses a 4-node doubly linked list with per-node variant weights chosen so
  that the absorbing distribution from each start node can be derived by
  hand, including the effect of `backtrack_fails_prob`.
  """
  schema, graph = self.build_doubly_linked_list_graph(4)
  builder = automaton_builder.AutomatonBuilder(schema)
  enc_graph, enc_meta = builder.encode_graph(graph)

  # The automaton has 5 variants and 2 states, but only the first state is
  # used; this checks that variants and states are interleaved correctly.
  #   Variant 0: move forward
  #   Variant 1: move backward
  #   Variant 2: finish
  #   Variant 3: restart
  #   Variant 4: fail
  per_node_variant_weights = jnp.array([
      # Starting from node 0: walk forward.
      [[1, 0, 0, 0, 0], [1, 0, 0, 0, 0], [.7, 0, .3, 0, 0], [0, 0, 1, 0, 0]],
      # Starting from node 1: walk backward, with a small chance of failing.
      [[0, 0.9, 0, 0, 0.1], [0, 0.9, 0, 0, 0.1], [.7, 0, .3, 0, 0],
       [0, 0, 1, 0, 0]],
      # Starting from node 2: bounce around, eventually accepting on node 0.
      [[0.9, 0, 0.1, 0, 0], [0, 1, 0, 0, 0], [0.5, 0.5, 0, 0, 0],
       [0, 1, 0, 0, 0]],
      # Starting from node 3: accept at once, or restart after 0 or 1 steps.
      [[0, 0, 1, 0, 0], [1, 0, 0, 0, 0], [0, 0, 0, 1, 0],
       [0, 0.1, 0.8, 0.1, 0]],
  ])

  # Tables are written for a single machine state and then zero-padded to
  # add the second one.
  move_single_state = jnp.array([
      [1., 0., 1., 0., 1., 0.],
      [0., 1., 0., 1., 0., 1.],
      [0., 0., 0., 0., 0., 0.],
      [0., 0., 0., 0., 0., 0.],
      [0., 0., 0., 0., 0., 0.],
  ]).reshape([5, 6, 1, 1])
  special_single_state = jnp.array([
      [[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]],
      [[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]],
      [[1., 0., 0.], [1., 0., 0.], [1., 0., 0.]],
      [[0., 1., 0.], [0., 1., 0.], [0., 1., 0.]],
      [[0., 0., 1.], [0., 0., 1.], [0., 0., 1.]],
  ]).reshape([5, 3, 1, 3])
  routing = automaton_builder.RoutingParams(
      move=jnp.pad(move_single_state, [(0, 0), (0, 0), (0, 1), (0, 1)]),
      special=jnp.pad(special_single_state,
                      [(0, 0), (0, 0), (0, 1), (0, 0)]))
  transition_matrix = builder.build_transition_matrix(
      routing, enc_graph, enc_meta)

  # All weight on the first machine state for each of the four nodes.
  start_states = jnp.pad(jnp.ones([4, 1]), [(0, 0), (0, 1)])
  solved_probs = automaton_builder.all_nodes_absorbing_solve(
      builder,
      transition_matrix,
      per_node_variant_weights,
      start_states,
      steps=1000,
      backtrack_fails_prob=0.01)

  # The expected values follow the paths described next to the weights.
  # From node 3 the automaton attempts to backtrack with probability 0.2, and
  # that attempt itself fails with probability 0.2 * 0.01 (set through
  # backtrack_fails_prob), so the absorbing probability there is
  # 0.8 / (0.8 + 0.2 * 0.01) = 0.997506.
  expected_probs = jnp.array([
      [0, 0, 0.3, 0.7],
      [0, 0, 0, 0.81],
      [1, 0, 0, 0],
      [0, 0, 0, 0.997506],
  ])
  jax.test_util.check_close(solved_probs, expected_probs)
def test_routing_reduce_correct(self, reduction):
  """Compare JAX implementations to a (slow but correct) iterative one.

  Args:
    reduction: Which aggregate to check: "sum", "max" or "softmax".

  Raises:
    ValueError: If `reduction` is not one of the supported names.
  """
  n_variants = 2
  n_states = 4

  def make_range_shaped(shape):
    """Returns a float32 array of `shape` holding 0, 1, 2, ... in order."""
    return np.arange(np.prod(shape)).reshape(shape).astype("float32")

  schema = self.build_simple_schema()
  builder = automaton_builder.AutomatonBuilder(schema)
  routing_params = automaton_builder.RoutingParams(
      move=make_range_shaped([
          n_variants,
          len(builder.in_out_route_types),
          n_states,
          n_states,
      ]),
      special=make_range_shaped([
          n_variants,
          len(builder.in_route_types),
          n_states,
          len(builder.special_actions),
      ]),
  )

  # Compute aggregates with JAX
  if reduction == "softmax":
    routing_aggregates = builder.routing_softmax(routing_params)
  else:
    routing_aggregates = builder.routing_reduce(
        routing_params, reduction=reduction)
  # Broadcast each aggregate to the shape of the matching parameter array so
  # both can be indexed identically below. `jax.tree_multimap` has been
  # removed from JAX, so the two fields are handled explicitly.
  routing_aggregates = automaton_builder.RoutingParams(
      move=np.array(
          jnp.broadcast_to(routing_aggregates.move,
                           routing_params.move.shape)),
      special=np.array(
          jnp.broadcast_to(routing_aggregates.special,
                           routing_params.special.shape)),
  )

  # Manual looping aggregates
  for variant in range(n_variants):
    for current_state in range(n_states):
      for in_route_type in builder.in_route_types:
        # Compute aggregates
        # Gather every value for this (variant, state, in-route): all move
        # entries over out edges and next states, then all special actions.
        distn_vals = []
        iroute_idx = builder.in_route_type_to_index[in_route_type]
        for out_edge_type in schema[in_route_type.node_type].out_edges:
          ioroute_idx = builder.in_out_route_type_to_index[
              automaton_builder.InOutRouteType(in_route_type.node_type,
                                               in_route_type.in_edge,
                                               out_edge_type)]
          for next_state in range(n_states):
            distn_vals.append(routing_params.move[variant, ioroute_idx,
                                                  current_state, next_state])

        for action_idx in range(len(builder.special_actions)):
          distn_vals.append(routing_params.special[variant, iroute_idx,
                                                   current_state, action_idx])

        if reduction == "sum":
          distn_aggregate = [sum(distn_vals)] * len(distn_vals)
        elif reduction == "max":
          distn_aggregate = [max(distn_vals)] * len(distn_vals)
        elif reduction == "softmax":
          distn_aggregate = list(jax.nn.softmax(jnp.array(distn_vals)))
        else:
          raise ValueError(f"Invalid reduction {reduction}")

        # Check them with the JAX version
        # `i` walks `distn_aggregate` in the same order the values were
        # gathered above.
        i = 0
        for out_edge_type in schema[in_route_type.node_type].out_edges:
          ioroute_idx = builder.in_out_route_type_to_index[
              automaton_builder.InOutRouteType(in_route_type.node_type,
                                               in_route_type.in_edge,
                                               out_edge_type)]
          for next_state in range(n_states):
            np.testing.assert_allclose(
                routing_aggregates.move[variant, ioroute_idx, current_state,
                                        next_state],
                distn_aggregate[i],
                rtol=1e-6)
            i += 1

        for action_idx in range(len(builder.special_actions)):
          np.testing.assert_allclose(
              routing_aggregates.special[variant, iroute_idx, current_state,
                                         action_idx],
              distn_aggregate[i],
              rtol=1e-6)
          i += 1