def test_constructor_information_removing_mappings(self):
    """Check the index mappings that drop information from a route type.

    For a builder made from ``self.build_simple_schema()``:

    * ``in_out_route_to_in_route`` must send the index of every in-out route
      type to the index of the in-route type that has the same node type and
      in edge.
    * ``in_route_to_node_type`` must send the index of every in-route type to
      the index of that route's node type.
    """
    builder = automaton_builder.AutomatonBuilder(self.build_simple_schema())

    # Forgetting the out edge: in-out route index -> in-route index.
    for in_out_route_type in builder.in_out_route_types:
        in_route_type = automaton_builder.InRouteType(
            in_out_route_type.node_type, in_out_route_type.in_edge)
        self.assertEqual(
            builder.in_out_route_to_in_route[
                builder.in_out_route_type_to_index[in_out_route_type]],
            builder.in_route_type_to_index[in_route_type],
            msg=f"wrong in-route index for {in_out_route_type!r}")

    # Forgetting the in edge: in-route index -> node type index.
    for in_route_type in builder.in_route_types:
        node_type = graph_types.NodeType(in_route_type.node_type)
        self.assertEqual(
            builder.in_route_to_node_type[
                builder.in_route_type_to_index[in_route_type]],
            builder.node_type_to_index[node_type],
            msg=f"wrong node type index for {in_route_type!r}")
def test_constructor_actions_nodes_routes(self):
    """Check the actions, node types and route types a builder enumerates.

    The builder is created from ``self.build_simple_schema()`` with
    ``with_backtrack=False`` and ``with_fail=True``. Its special actions,
    node types, in-route types and in-out route types are each compared, as
    sets, against the expected contents.
    """
    ab = automaton_builder
    builder = ab.AutomatonBuilder(
        self.build_simple_schema(), with_backtrack=False, with_fail=True)

    node_a = graph_types.NodeType("a")
    node_b = graph_types.NodeType("b")
    initial = ab.SOURCE_INITIAL

    # Special actions: only FINISH and FAIL are expected with these flags.
    expected_actions = {ab.SpecialActions.FINISH, ab.SpecialActions.FAIL}
    self.assertEqual(set(builder.special_actions), expected_actions)

    # Node types.
    self.assertEqual(set(builder.node_types), {node_a, node_b})

    # In-route types, written as (node type, in edge) pairs.
    in_route_specs = [
        (node_a, initial),
        (node_a, graph_types.InEdgeType("ai_0")),
        (node_a, graph_types.InEdgeType("ai_1")),
        (node_b, initial),
        (node_b, graph_types.InEdgeType("bi_0")),
    ]
    expected_in_routes = {ab.InRouteType(*spec) for spec in in_route_specs}
    self.assertEqual(set(builder.in_route_types), expected_in_routes)

    # In-out route types, written as (node type, in edge, out edge) triples.
    in_out_route_specs = [
        (node_a, initial, graph_types.OutEdgeType("ao_0")),
        (node_a, graph_types.InEdgeType("ai_0"),
         graph_types.OutEdgeType("ao_0")),
        (node_a, graph_types.InEdgeType("ai_1"),
         graph_types.OutEdgeType("ao_0")),
        (node_b, initial, graph_types.OutEdgeType("bo_0")),
        (node_b, initial, graph_types.OutEdgeType("bo_1")),
        (node_b, graph_types.InEdgeType("bi_0"),
         graph_types.OutEdgeType("bo_0")),
        (node_b, graph_types.InEdgeType("bi_0"),
         graph_types.OutEdgeType("bo_1")),
    ]
    expected_in_out_routes = {
        ab.InOutRouteType(*spec) for spec in in_out_route_specs
    }
    self.assertEqual(set(builder.in_out_route_types), expected_in_out_routes)
def test_routing_gates_to_probs(self):
    """Check ``routing_gates_to_probs`` on over- and under-full gate sets.

    All gates start at 0.5. Two (variant, in-route, state) slots are then
    overwritten: one whose gates are meant to sum to more than 1 and one
    whose gates are meant to sum to less than 1. The gates are passed
    through ``logit`` before being handed to the builder, and the resulting
    ``move`` and ``special`` probabilities for those two slots are compared
    with the expected values.
    """
    builder = automaton_builder.AutomatonBuilder(self.build_simple_schema())

    def in_out_route_index(node, in_edge, out_edge):
        # Index of the in-out route type (node, in_edge, out_edge).
        route = automaton_builder.InOutRouteType(
            graph_types.NodeType(node),
            graph_types.InEdgeType(in_edge),
            graph_types.OutEdgeType(out_edge))
        return builder.in_out_route_type_to_index[route]

    def in_route_index(node, in_edge):
        # Index of the in-route type (node, in_edge).
        route = automaton_builder.InRouteType(
            graph_types.NodeType(node), graph_types.InEdgeType(in_edge))
        return builder.in_route_type_to_index[route]

    num_in_out_routes = len(builder.in_out_route_types)
    num_in_routes = len(builder.in_route_types)

    # Move gates: [variants, in_out_routes, fsm_states, fsm_states]
    move_gates = np.full((3, num_in_out_routes, 2, 2), 0.5)
    # Accept and backtrack gates: [variants, in_routes, fsm_states]
    accept_gates = np.full((3, num_in_routes, 2), 0.5)
    backtrack_gates = np.full((3, num_in_routes, 2), 0.5)

    # Set one distribution to sum to more than 1.
    idx_over_move_0 = in_out_route_index("b", "bi_0", "bo_0")
    move_gates[0, idx_over_move_0, 0, :] = [0.2, 0.3]
    idx_over_move_1 = in_out_route_index("b", "bi_0", "bo_1")
    move_gates[0, idx_over_move_1, 0, :] = [0.4, 0.5]
    idx_over_special = in_route_index("b", "bi_0")
    accept_gates[0, idx_over_special, 0] = 0.6
    backtrack_gates[0, idx_over_special, 0] = 0.3

    # Set another to sum to less than 1.
    idx_under_move = in_out_route_index("a", "ai_0", "ao_0")
    move_gates[2, idx_under_move, 1, :] = [0.1, 0.2]
    idx_under_special = in_route_index("a", "ai_0")
    accept_gates[2, idx_under_special, 1] = 0.3
    backtrack_gates[2, idx_under_special, 1] = 0.75

    # The builder is given logits rather than raw gate values.
    to_logit = jax.scipy.special.logit
    routing_gates = automaton_builder.RoutingGateParams(
        move_gates=to_logit(move_gates),
        accept_gates=to_logit(accept_gates),
        backtrack_gates=to_logit(backtrack_gates))
    routing_probs = builder.routing_gates_to_probs(routing_gates)

    check_close = np.testing.assert_allclose

    # Check probabilities for first distribution: should divide evenly.
    # (0.2 + 0.3 + 0.4 + 0.5 of moves plus 0.6 of accept is 2.0, so every
    # entry is expected to come out halved, leaving nothing for the last two
    # special entries.)
    check_close(routing_probs.move[0, idx_over_move_0, 0, :],
                np.array([0.2, 0.3]) / 2.0)
    check_close(routing_probs.move[0, idx_over_move_1, 0, :],
                np.array([0.4, 0.5]) / 2.0)
    check_close(routing_probs.special[0, idx_over_special, 0, :],
                np.array([0.6, 0, 0]) / 2.0)

    # Check probabilities for second distribution: should assign remainder to
    # backtrack and fail.
    # NOTE(review): the special vector is presumably ordered
    # (accept, backtrack, fail) -- confirm against automaton_builder.
    check_close(routing_probs.move[2, idx_under_move, 1, :],
                np.array([0.1, 0.2]))
    check_close(routing_probs.special[2, idx_under_special, 1, :],
                np.array([0.3, 0.3, 0.1]))