def test_routing_gates_to_probs(self):
    """Check how routing gates are normalized into routing probabilities.

    Gate probabilities are converted to logits and passed to
    `builder.routing_gates_to_probs`. Starting from uniform 0.5 gates, two
    distributions are modified: one whose gates sum to more than 1 (expected
    to be divided evenly by its total) and one whose gates sum to less than 1
    (expected to assign the remainder to backtrack and fail).
    """
    builder = automaton_builder.AutomatonBuilder(self.build_simple_schema())

    def in_out_index(node, in_edge, out_edge):
        # Position of the (node, in edge, out edge) route in the move gates.
        route = automaton_builder.InOutRouteType(
            graph_types.NodeType(node), graph_types.InEdgeType(in_edge),
            graph_types.OutEdgeType(out_edge))
        return builder.in_out_route_type_to_index[route]

    def in_index(node, in_edge):
        # Position of the (node, in edge) route in the accept/backtrack gates.
        route = automaton_builder.InRouteType(
            graph_types.NodeType(node), graph_types.InEdgeType(in_edge))
        return builder.in_route_type_to_index[route]

    n_variants = 3
    n_fsm_states = 2
    n_in_out_routes = len(builder.in_out_route_types)
    n_in_routes = len(builder.in_route_types)

    # move:               [variants, in_out_routes, fsm_states, fsm_states]
    # accept / backtrack: [variants, in_routes, fsm_states]
    move_gates = np.full(
        [n_variants, n_in_out_routes, n_fsm_states, n_fsm_states], 0.5)
    accept_gates = np.full([n_variants, n_in_routes, n_fsm_states], 0.5)
    backtrack_gates = np.full([n_variants, n_in_routes, n_fsm_states], 0.5)

    # Variant 0, state 0, node "b" via "bi_0": gates sum to more than 1.
    over_move_0 = in_out_index("b", "bi_0", "bo_0")
    move_gates[0, over_move_0, 0, :] = [.2, .3]
    over_move_1 = in_out_index("b", "bi_0", "bo_1")
    move_gates[0, over_move_1, 0, :] = [.4, .5]
    over_special = in_index("b", "bi_0")
    accept_gates[0, over_special, 0] = .6
    backtrack_gates[0, over_special, 0] = .3

    # Variant 2, state 1, node "a" via "ai_0": gates sum to less than 1.
    under_move = in_out_index("a", "ai_0", "ao_0")
    move_gates[2, under_move, 1, :] = [.1, .2]
    under_special = in_index("a", "ai_0")
    accept_gates[2, under_special, 1] = .3
    backtrack_gates[2, under_special, 1] = .75

    to_logits = jax.scipy.special.logit
    gate_params = automaton_builder.RoutingGateParams(
        move_gates=to_logits(move_gates),
        accept_gates=to_logits(accept_gates),
        backtrack_gates=to_logits(backtrack_gates))
    routing_probs = builder.routing_gates_to_probs(gate_params)

    # Over-full distribution: every value is divided by the total (2.0) and
    # the two remaining special entries get 0.
    np.testing.assert_allclose(routing_probs.move[0, over_move_0, 0, :],
                               np.array([.2, .3]) / 2.0)
    np.testing.assert_allclose(routing_probs.move[0, over_move_1, 0, :],
                               np.array([.4, .5]) / 2.0)
    np.testing.assert_allclose(routing_probs.special[0, over_special, 0, :],
                               np.array([.6, 0, 0]) / 2.0)

    # Under-full distribution: values are kept as-is and the leftover mass is
    # split between backtrack and fail.
    np.testing.assert_allclose(routing_probs.move[2, under_move, 1, :],
                               np.array([.1, .2]))
    np.testing.assert_allclose(routing_probs.special[2, under_special, 1, :],
                               np.array([.3, .3, .1]))
Ejemplo n.º 2
0
    def test_constructor_actions_nodes_routes(self):
        """Check the actions, node types and route types a builder exposes.

        The builder is built from the simple schema with backtracking
        disabled and failure enabled; FINISH and FAIL are expected to be the
        only special actions.
        """
        builder = automaton_builder.AutomatonBuilder(
            self.build_simple_schema(), with_backtrack=False, with_fail=True)

        node_a = graph_types.NodeType("a")
        node_b = graph_types.NodeType("b")
        initial = automaton_builder.SOURCE_INITIAL
        ai_0 = graph_types.InEdgeType("ai_0")
        ai_1 = graph_types.InEdgeType("ai_1")
        bi_0 = graph_types.InEdgeType("bi_0")
        ao_0 = graph_types.OutEdgeType("ao_0")
        bo_0 = graph_types.OutEdgeType("bo_0")
        bo_1 = graph_types.OutEdgeType("bo_1")

        expected_actions = {
            automaton_builder.SpecialActions.FINISH,
            automaton_builder.SpecialActions.FAIL,
        }
        self.assertEqual(set(builder.special_actions), expected_actions)

        self.assertEqual(set(builder.node_types), {node_a, node_b})

        # Expected in-routes: the initial source plus each incoming edge,
        # per node.
        a_sources = (initial, ai_0, ai_1)
        b_sources = (initial, bi_0)
        expected_in_routes = (
            {automaton_builder.InRouteType(node_a, src) for src in a_sources}
            | {automaton_builder.InRouteType(node_b, src)
               for src in b_sources})
        self.assertEqual(set(builder.in_route_types), expected_in_routes)

        # Expected in-out routes: each source of a node paired with each of
        # that node's outgoing edges.
        expected_in_out_routes = (
            {automaton_builder.InOutRouteType(node_a, src, out)
             for src in a_sources for out in (ao_0,)}
            | {automaton_builder.InOutRouteType(node_b, src, out)
               for src in b_sources for out in (bo_0, bo_1)})
        self.assertEqual(
            set(builder.in_out_route_types), expected_in_out_routes)
Ejemplo n.º 3
0
    def test_routing_reduce_correct(self, reduction):
        """Compare JAX implementations to a (slow but correct) iterative one.

        For each (variant, current state, in-route) triple, the move entries
        reachable from that in-route (ordered by out edge, then next state)
        and its special-action entries are flattened into one list. A
        reference aggregate of that list is computed in plain Python and
        compared elementwise with the JAX result.

        Args:
          reduction: "softmax" checks `builder.routing_softmax`; "sum" and
            "max" check `builder.routing_reduce(..., reduction=reduction)`.
            Any other value ends in a ValueError from the reference
            computation, unless `routing_reduce` rejects it first.
        """
        n_variants = 2
        n_states = 4

        def make_range_shaped(shape):
            # Distinct, increasing values so index mix-ups change the result.
            return np.arange(np.prod(shape)).reshape(shape).astype("float32")

        schema = self.build_simple_schema()
        builder = automaton_builder.AutomatonBuilder(schema)
        n_special = len(builder.special_actions)
        routing_params = automaton_builder.RoutingParams(
            move=make_range_shaped([
                n_variants,
                len(builder.in_out_route_types),
                n_states,
                n_states,
            ]),
            special=make_range_shaped([
                n_variants,
                len(builder.in_route_types),
                n_states,
                n_special,
            ]),
        )

        # Compute aggregates with JAX.
        if reduction == "softmax":
            routing_aggregates = builder.routing_softmax(routing_params)
        else:
            routing_aggregates = builder.routing_reduce(routing_params,
                                                        reduction=reduction)
            # Broadcast each aggregate leaf to the shape of the matching
            # parameter so it can be indexed elementwise below.
            # `jax.tree_multimap` has been removed from JAX;
            # `jax.tree_util.tree_map` accepts several trees instead.
            routing_aggregates = jax.tree_util.tree_map(
                lambda s, p: np.array(jnp.broadcast_to(s, p.shape)),
                routing_aggregates, routing_params)

        # The in-out route indices reachable from an in-route do not depend
        # on the variant or state, so look them up once.
        out_route_indices = {
            in_route_type: [
                builder.in_out_route_type_to_index[
                    automaton_builder.InOutRouteType(
                        in_route_type.node_type, in_route_type.in_edge,
                        out_edge_type)]
                for out_edge_type in schema[
                    in_route_type.node_type].out_edges
            ]
            for in_route_type in builder.in_route_types
        }

        def flatten_distribution(params, variant, current_state,
                                 in_route_type):
            # Moves ordered by out edge then next state, then special actions.
            iroute_idx = builder.in_route_type_to_index[in_route_type]
            values = [
                params.move[variant, ioroute_idx, current_state, next_state]
                for ioroute_idx in out_route_indices[in_route_type]
                for next_state in range(n_states)
            ]
            values.extend(
                params.special[variant, iroute_idx, current_state, action_idx]
                for action_idx in range(n_special))
            return values

        # Manual looping aggregates, checked against the JAX version.
        for variant in range(n_variants):
            for current_state in range(n_states):
                for in_route_type in builder.in_route_types:
                    distn_vals = flatten_distribution(
                        routing_params, variant, current_state, in_route_type)

                    if reduction == "sum":
                        distn_aggregate = [sum(distn_vals)] * len(distn_vals)
                    elif reduction == "max":
                        distn_aggregate = [max(distn_vals)] * len(distn_vals)
                    elif reduction == "softmax":
                        distn_aggregate = list(
                            jax.nn.softmax(jnp.array(distn_vals)))
                    else:
                        raise ValueError(f"Invalid reduction {reduction}")

                    jax_vals = flatten_distribution(
                        routing_aggregates, variant, current_state,
                        in_route_type)
                    for actual, expected in zip(jax_vals, distn_aggregate):
                        np.testing.assert_allclose(actual, expected, rtol=1e-6)