def test_routing_gates_to_probs(self):
    """Check conversion of routing gates into routing probabilities.

    All gates start at 0.5, then two distributions are overridden:

    * Variant 0, node "b" entered via "bi_0", state 0: the move and accept
      gates add up to 2.0, and the test expects each of them to be divided
      by 2.0, with nothing left over for the remaining special actions.
    * Variant 2, node "a" entered via "ai_0", state 1: the move and accept
      gates add up to less than 1, and the test expects them to be kept
      as-is, with the remainder assigned to backtrack and fail.
    """
    builder = automaton_builder.AutomatonBuilder(self.build_simple_schema())

    def move_index(node, in_edge, out_edge):
        # Index of the (node, in-edge, out-edge) route along the move axis.
        route = automaton_builder.InOutRouteType(
            graph_types.NodeType(node),
            graph_types.InEdgeType(in_edge),
            graph_types.OutEdgeType(out_edge))
        return builder.in_out_route_type_to_index[route]

    def special_index(node, in_edge):
        # Index of the (node, in-edge) route along the special-action axis.
        route = automaton_builder.InRouteType(
            graph_types.NodeType(node), graph_types.InEdgeType(in_edge))
        return builder.in_route_type_to_index[route]

    n_in_out_routes = len(builder.in_out_route_types)
    n_in_routes = len(builder.in_route_types)

    # Move gates: [variants, in_out_routes, fsm_states, fsm_states]
    move_gates = np.full([3, n_in_out_routes, 2, 2], 0.5)
    # Accept and backtrack gates: [variants, in_routes, fsm_states]
    accept_gates = np.full([3, n_in_routes, 2], 0.5)
    backtrack_gates = np.full([3, n_in_routes, 2], 0.5)

    # Set one distribution to sum to more than 1.
    idx_over_move_a = move_index("b", "bi_0", "bo_0")
    move_gates[0, idx_over_move_a, 0, :] = [.2, .3]
    idx_over_move_b = move_index("b", "bi_0", "bo_1")
    move_gates[0, idx_over_move_b, 0, :] = [.4, .5]
    idx_over_special = special_index("b", "bi_0")
    accept_gates[0, idx_over_special, 0] = .6
    backtrack_gates[0, idx_over_special, 0] = .3

    # Set another to sum to less than 1.
    idx_under_move = move_index("a", "ai_0", "ao_0")
    move_gates[2, idx_under_move, 1, :] = [.1, .2]
    idx_under_special = special_index("a", "ai_0")
    accept_gates[2, idx_under_special, 1] = .3
    backtrack_gates[2, idx_under_special, 1] = .75

    # The builder consumes logits, so map the gate values through logit.
    to_logits = jax.scipy.special.logit
    routing_gates = automaton_builder.RoutingGateParams(
        move_gates=to_logits(move_gates),
        accept_gates=to_logits(accept_gates),
        backtrack_gates=to_logits(backtrack_gates))
    routing_probs = builder.routing_gates_to_probs(routing_gates)

    # First distribution: everything should be divided evenly by the total.
    overfull_total = 2.0
    np.testing.assert_allclose(
        routing_probs.move[0, idx_over_move_a, 0, :],
        np.array([.2, .3]) / overfull_total)
    np.testing.assert_allclose(
        routing_probs.move[0, idx_over_move_b, 0, :],
        np.array([.4, .5]) / overfull_total)
    np.testing.assert_allclose(
        routing_probs.special[0, idx_over_special, 0, :],
        np.array([.6, 0, 0]) / overfull_total)

    # Second distribution: the remainder should be assigned to backtrack
    # and fail.
    np.testing.assert_allclose(
        routing_probs.move[2, idx_under_move, 1, :],
        np.array([.1, .2]))
    np.testing.assert_allclose(
        routing_probs.special[2, idx_under_special, 1, :],
        np.array([.3, .3, .1]))
def test_constructor_actions_nodes_routes(self):
    """Check the actions, node types and route types the builder exposes.

    The builder is constructed from the simple schema with backtracking
    disabled and fail enabled; its special actions, node types, in-route
    types and in-out-route types are compared against the expected sets.
    """
    builder = automaton_builder.AutomatonBuilder(
        self.build_simple_schema(), with_backtrack=False, with_fail=True)

    initial = automaton_builder.SOURCE_INITIAL
    node_a = graph_types.NodeType("a")
    node_b = graph_types.NodeType("b")

    # For each node type: (ways to enter the node, out-edges leaving it).
    edges_by_node = {
        node_a: (
            [
                initial,
                graph_types.InEdgeType("ai_0"),
                graph_types.InEdgeType("ai_1"),
            ],
            [graph_types.OutEdgeType("ao_0")],
        ),
        node_b: (
            [initial, graph_types.InEdgeType("bi_0")],
            [
                graph_types.OutEdgeType("bo_0"),
                graph_types.OutEdgeType("bo_1"),
            ],
        ),
    }

    expected_actions = {
        automaton_builder.SpecialActions.FINISH,
        automaton_builder.SpecialActions.FAIL,
    }
    expected_nodes = {node_a, node_b}
    # One in-route per (node, entry point) pair.
    expected_in_routes = {
        automaton_builder.InRouteType(node, entry)
        for node, (entries, _) in edges_by_node.items()
        for entry in entries
    }
    # One in-out route per (node, entry point, out-edge) triple.
    expected_in_out_routes = {
        automaton_builder.InOutRouteType(node, entry, exit_edge)
        for node, (entries, exit_edges) in edges_by_node.items()
        for entry in entries
        for exit_edge in exit_edges
    }

    self.assertEqual(set(builder.special_actions), expected_actions)
    self.assertEqual(set(builder.node_types), expected_nodes)
    self.assertEqual(set(builder.in_route_types), expected_in_routes)
    self.assertEqual(
        set(builder.in_out_route_types), expected_in_out_routes)
def test_routing_reduce_correct(self, reduction):
    """Compare JAX implementations to a (slow but correct) iterative one.

    Args:
        reduction: Name of the reduction under test. "softmax" exercises
            `builder.routing_softmax`; any other value is forwarded to
            `builder.routing_reduce`. The reference implementation below
            handles "sum", "max" and "softmax".

    Raises:
        ValueError: If the reference implementation does not know
            `reduction`.
    """
    n_variants = 2
    n_states = 4

    def make_range_shaped(shape):
        # Distinct value per entry, so misplaced entries change the result.
        return np.arange(np.prod(shape)).reshape(shape).astype("float32")

    schema = self.build_simple_schema()
    builder = automaton_builder.AutomatonBuilder(schema)
    routing_params = automaton_builder.RoutingParams(
        move=make_range_shaped([
            n_variants,
            len(builder.in_out_route_types),
            n_states,
            n_states,
        ]),
        special=make_range_shaped([
            n_variants,
            len(builder.in_route_types),
            n_states,
            len(builder.special_actions),
        ]),
    )

    # Compute aggregates with JAX
    if reduction == "softmax":
        routing_aggregates = builder.routing_softmax(routing_params)
    else:
        routing_aggregates = builder.routing_reduce(
            routing_params, reduction=reduction)
        # Broadcast the reduced values back to the parameter shapes so they
        # can be indexed like the parameters. `jax.tree_multimap` has been
        # removed from JAX; `jax.tree_util.tree_map` accepts several trees.
        routing_aggregates = jax.tree_util.tree_map(
            lambda s, p: np.array(jnp.broadcast_to(s, p.shape)),
            routing_aggregates, routing_params)

    n_actions = len(builder.special_actions)

    # Manual looping aggregates
    for variant in range(n_variants):
        for current_state in range(n_states):
            for in_route_type in builder.in_route_types:
                iroute_idx = builder.in_route_type_to_index[in_route_type]

                # Enumerate every entry of this distribution once, as
                # (field name, index) pairs: all moves first, then all
                # special actions. The same list drives both the reference
                # computation and the comparison, so they cannot drift.
                entries = []
                for out_edge_type in schema[
                        in_route_type.node_type].out_edges:
                    ioroute_idx = builder.in_out_route_type_to_index[
                        automaton_builder.InOutRouteType(
                            in_route_type.node_type,
                            in_route_type.in_edge,
                            out_edge_type)]
                    entries.extend(
                        ("move",
                         (variant, ioroute_idx, current_state, next_state))
                        for next_state in range(n_states))
                entries.extend(
                    ("special",
                     (variant, iroute_idx, current_state, action_idx))
                    for action_idx in range(n_actions))

                # Compute aggregates
                distn_vals = [
                    getattr(routing_params, field)[index]
                    for field, index in entries
                ]
                if reduction == "sum":
                    distn_aggregate = [sum(distn_vals)] * len(distn_vals)
                elif reduction == "max":
                    distn_aggregate = [max(distn_vals)] * len(distn_vals)
                elif reduction == "softmax":
                    distn_aggregate = list(
                        jax.nn.softmax(jnp.array(distn_vals)))
                else:
                    raise ValueError(f"Invalid reduction {reduction}")

                # Check them with the JAX version
                for (field, index), expected in zip(entries,
                                                    distn_aggregate):
                    np.testing.assert_allclose(
                        getattr(routing_aggregates, field)[index],
                        expected,
                        rtol=1e-6)