Ejemplo n.º 1
0
    def test_constructor_information_removing_mappings(self):
        """Checks the index-level projection tables against the route lists.

        For every in-out route, ``in_out_route_to_in_route`` must map that
        route's index to the index of the in-route obtained by dropping the
        out edge. For every in-route, ``in_route_to_node_type`` must map that
        route's index to the index of its node type.
        """
        builder = automaton_builder.AutomatonBuilder(
            self.build_simple_schema())

        # In-out route -> in-route: keep (node_type, in_edge), drop the out
        # edge, and compare the looked-up indices.
        for in_out_route_type in builder.in_out_route_types:
            in_route_type = automaton_builder.InRouteType(
                in_out_route_type.node_type, in_out_route_type.in_edge)
            self.assertEqual(
                builder.in_out_route_to_in_route[
                    builder.in_out_route_type_to_index[in_out_route_type]],
                builder.in_route_type_to_index[in_route_type])

        # In-route -> node type: keep only the node type and compare the
        # looked-up indices.
        for in_route_type in builder.in_route_types:
            node_type = graph_types.NodeType(in_route_type.node_type)
            self.assertEqual(
                builder.in_route_to_node_type[
                    builder.in_route_type_to_index[in_route_type]],
                builder.node_type_to_index[node_type])
Ejemplo n.º 2
0
    def test_constructor_actions_nodes_routes(self):
        """Checks the actions, node types and routes built for a simple schema.

        The builder is constructed with ``with_backtrack=False`` and
        ``with_fail=True``; the test pins the resulting special actions, node
        types, in-route types and in-out-route types as unordered sets.
        """
        builder = automaton_builder.AutomatonBuilder(
            self.build_simple_schema(), with_backtrack=False, with_fail=True)

        # Short local names for the constructors used in the expected sets.
        node_a = graph_types.NodeType("a")
        node_b = graph_types.NodeType("b")
        initial = automaton_builder.SOURCE_INITIAL
        in_route = automaton_builder.InRouteType
        in_out_route = automaton_builder.InOutRouteType
        in_edge = graph_types.InEdgeType
        out_edge = graph_types.OutEdgeType

        expected_special_actions = {
            automaton_builder.SpecialActions.FINISH,
            automaton_builder.SpecialActions.FAIL,
        }
        self.assertEqual(set(builder.special_actions),
                         expected_special_actions)

        expected_node_types = {node_a, node_b}
        self.assertEqual(set(builder.node_types), expected_node_types)

        expected_in_routes = {
            in_route(node_a, initial),
            in_route(node_a, in_edge("ai_0")),
            in_route(node_a, in_edge("ai_1")),
            in_route(node_b, initial),
            in_route(node_b, in_edge("bi_0")),
        }
        self.assertEqual(set(builder.in_route_types), expected_in_routes)

        expected_in_out_routes = {
            in_out_route(node_a, initial, out_edge("ao_0")),
            in_out_route(node_a, in_edge("ai_0"), out_edge("ao_0")),
            in_out_route(node_a, in_edge("ai_1"), out_edge("ao_0")),
            in_out_route(node_b, initial, out_edge("bo_0")),
            in_out_route(node_b, initial, out_edge("bo_1")),
            in_out_route(node_b, in_edge("bi_0"), out_edge("bo_0")),
            in_out_route(node_b, in_edge("bi_0"), out_edge("bo_1")),
        }
        self.assertEqual(set(builder.in_out_route_types),
                         expected_in_out_routes)
  def test_routing_gates_to_probs(self):
    """Checks conversion of routing gates into routing probabilities.

    Every gate starts at 0.5. Two distributions are then overridden. In the
    first, the move and accept gates add up to more than 1; the assertions
    expect every value to be divided by that total. In the second, they add
    up to less than 1; the assertions expect the leftover mass to go to
    backtrack and fail.
    """
    builder = automaton_builder.AutomatonBuilder(self.build_simple_schema())
    num_variants = 3
    num_states = 2
    num_in_out = len(builder.in_out_route_types)
    num_in = len(builder.in_route_types)

    # move: [variants, in_out_routes, fsm_states, fsm_states]
    # accept / backtrack: [variants, in_routes, fsm_states]
    move_gates = np.full([num_variants, num_in_out, num_states, num_states],
                         0.5)
    accept_gates = np.full([num_variants, num_in, num_states], 0.5)
    backtrack_gates = np.full([num_variants, num_in, num_states], 0.5)

    in_out_index = builder.in_out_route_type_to_index
    in_index = builder.in_route_type_to_index
    node_a = graph_types.NodeType("a")
    node_b = graph_types.NodeType("b")
    edge_ai_0 = graph_types.InEdgeType("ai_0")
    edge_bi_0 = graph_types.InEdgeType("bi_0")

    # Variant 0, node "b" entered via "bi_0", state 0: move and accept gates
    # add up to .2 + .3 + .4 + .5 + .6 = 2.0, i.e. more than 1.
    big_move_0 = in_out_index[automaton_builder.InOutRouteType(
        node_b, edge_bi_0, graph_types.OutEdgeType("bo_0"))]
    big_move_1 = in_out_index[automaton_builder.InOutRouteType(
        node_b, edge_bi_0, graph_types.OutEdgeType("bo_1"))]
    big_special = in_index[automaton_builder.InRouteType(node_b, edge_bi_0)]
    move_gates[0, big_move_0, 0, :] = [.2, .3]
    move_gates[0, big_move_1, 0, :] = [.4, .5]
    accept_gates[0, big_special, 0] = .6
    backtrack_gates[0, big_special, 0] = .3

    # Variant 2, node "a" entered via "ai_0", state 1: move and accept gates
    # add up to .1 + .2 + .3 = .6, i.e. less than 1.
    small_move = in_out_index[automaton_builder.InOutRouteType(
        node_a, edge_ai_0, graph_types.OutEdgeType("ao_0"))]
    small_special = in_index[automaton_builder.InRouteType(node_a, edge_ai_0)]
    move_gates[2, small_move, 1, :] = [.1, .2]
    accept_gates[2, small_special, 1] = .3
    backtrack_gates[2, small_special, 1] = .75

    # The builder consumes gates as logits, so invert the sigmoid first.
    logit = jax.scipy.special.logit
    routing_probs = builder.routing_gates_to_probs(
        automaton_builder.RoutingGateParams(
            move_gates=logit(move_gates),
            accept_gates=logit(accept_gates),
            backtrack_gates=logit(backtrack_gates)))

    # Over-full distribution: every value is expected to be divided by the
    # 2.0 total, with nothing left for the other special entries.
    np.testing.assert_allclose(routing_probs.move[0, big_move_0, 0, :],
                               np.array([.2, .3]) / 2.0)
    np.testing.assert_allclose(routing_probs.move[0, big_move_1, 0, :],
                               np.array([.4, .5]) / 2.0)
    np.testing.assert_allclose(routing_probs.special[0, big_special, 0, :],
                               np.array([.6, 0, 0]) / 2.0)

    # Under-full distribution: gate values are kept as-is; the remaining .4
    # is expected to split as .75 * .4 = .3 and .1 across the last two
    # special entries (presumably backtrack and fail).
    np.testing.assert_allclose(routing_probs.move[2, small_move, 1, :],
                               np.array([.1, .2]))
    np.testing.assert_allclose(routing_probs.special[2, small_special, 1, :],
                               np.array([.3, .3, .1]))