Example #1
0
    def test_initial_routing_params_with_noise(self):
        """Noise perturbs initial routing params slightly but keeps them normalized."""
        builder = automaton_builder.AutomatonBuilder(
            self.build_simple_schema())
        shared_kwargs = {
            "num_fsm_states": 3,
            "num_variants": 2,
            "state_change_prob": 0.2,
            "move_prob": 0.9,
        }

        # A tiny amount of noise should leave the parameters nearly unchanged.
        baseline = builder.initialize_routing_params(
            key=None, noise_factor=0, **shared_kwargs)
        barely_perturbed = builder.initialize_routing_params(
            key=jax.random.PRNGKey(1234), noise_factor=1e-6, **shared_kwargs)
        for field in ("move", "special"):
            np.testing.assert_allclose(getattr(baseline, field),
                                       getattr(barely_perturbed, field),
                                       rtol=0.02)

        # With substantial noise, the summed routing params must still be 1.
        heavily_perturbed = builder.initialize_routing_params(
            key=jax.random.PRNGKey(1234), noise_factor=0.8, **shared_kwargs)
        totals = builder.routing_reduce(heavily_perturbed, "sum")
        for field in ("move", "special"):
            np.testing.assert_allclose(getattr(totals, field), 1.0, rtol=1e-6)
Example #2
0
    def test_constructor_information_removing_mappings(self):
        """Checks the index maps that drop information from route types.

        `in_out_route_to_in_route` must send each in/out route's index to the
        index of the in-route with the same node type and in-edge, and
        `in_route_to_node_type` must send each in-route's index to the index
        of its node type.
        """
        builder = automaton_builder.AutomatonBuilder(
            self.build_simple_schema())

        # Check consistency of information-removing mappings with the
        # corresponding pairs of lists.
        for in_out_route_type in builder.in_out_route_types:
            # Dropping the out-edge should land on the matching in-route.
            in_route_type = automaton_builder.InRouteType(
                in_out_route_type.node_type, in_out_route_type.in_edge)
            self.assertEqual(
                builder.in_out_route_to_in_route[
                    builder.in_out_route_type_to_index[in_out_route_type]],
                builder.in_route_type_to_index[in_route_type])

        for in_route_type in builder.in_route_types:
            # Dropping the in-edge should land on the node type.
            node_type = graph_types.NodeType(in_route_type.node_type)
            self.assertEqual(
                builder.in_route_to_node_type[
                    builder.in_route_type_to_index[in_route_type]],
                builder.node_type_to_index[node_type])
Example #3
0
    def initialize_routing_gates(self):
        """Just make sure that we can initialize routing gates.

        Checks the shapes of the move/accept/backtrack gates for both the
        noiseless (``key=None``) and the perturbed initialization.

        NOTE(review): this method has no ``test_`` prefix, so the unittest
        loader will not collect it automatically. Consider renaming it to
        ``test_initialize_routing_gates``; the current name is kept so any
        existing references keep working.
        """
        builder = automaton_builder.AutomatonBuilder(
            self.build_simple_schema())
        num_variants = 2
        num_fsm_states = 3

        def assert_gate_shapes(gates):
            # move: [variants, in_out_routes, fsm_states, fsm_states]
            self.assertEqual(
                gates.move_gates.shape,
                (num_variants, len(builder.in_out_route_types),
                 num_fsm_states, num_fsm_states))
            # accept / backtrack: [variants, in_routes, fsm_states]
            self.assertEqual(
                gates.accept_gates.shape,
                (num_variants, len(builder.in_route_types), num_fsm_states))
            self.assertEqual(
                gates.backtrack_gates.shape,
                (num_variants, len(builder.in_route_types), num_fsm_states))

        # Noiseless
        assert_gate_shapes(
            builder.initialize_routing_gates(key=None,
                                             logistic_noise=0,
                                             num_fsm_states=num_fsm_states,
                                             num_variants=num_variants))

        # Perturbed
        assert_gate_shapes(
            builder.initialize_routing_gates(key=jax.random.PRNGKey(0),
                                             logistic_noise=0.2,
                                             num_fsm_states=num_fsm_states,
                                             num_variants=num_variants))
Example #4
0
 def test_one_node_particle_estimate_padding(self):
   """Padded nodes get a zero estimate; real nodes get a positive one."""
   padded_size = 64
   schema, graph = self.build_doubly_linked_list_graph(4)
   builder = automaton_builder.AutomatonBuilder(schema)
   encoded, metadata = builder.encode_graph(graph)
   real_nodes = metadata.num_nodes

   # Pad every component of the encoded graph out to `padded_size`.
   padded_graph = automaton_builder.EncodedGraph(
       initial_to_in_tagged=encoded.initial_to_in_tagged.pad_nonzeros(
           padded_size),
       initial_to_special=jax_util.pad_to(encoded.initial_to_special,
                                          padded_size),
       in_tagged_to_in_tagged=encoded.in_tagged_to_in_tagged.pad_nonzeros(
           padded_size),
       in_tagged_to_special=jax_util.pad_to(encoded.in_tagged_to_special,
                                            padded_size),
       in_tagged_node_indices=jax_util.pad_to(encoded.in_tagged_node_indices,
                                              padded_size))
   padded_metadata = automaton_builder.EncodedGraphMetadata(
       num_nodes=padded_size, num_input_tagged_nodes=padded_size)

   # Uniform weights over 5 variants, with 2 FSM states.
   uniform_variants = jnp.full([padded_size, 5], 0.2)
   uniform_routing = automaton_builder.RoutingParams(
       move=jnp.full([5, 6, 2, 2], 0.2), special=jnp.full([5, 3, 2, 3], 0.2))
   transitions = builder.build_transition_matrix(uniform_routing, padded_graph,
                                                 padded_metadata)

   estimate = automaton_sampling.one_node_particle_estimate(
       builder,
       transitions,
       uniform_variants,
       start_machine_state=jnp.array([1., 0.]),
       node_index=0,
       steps=100,
       num_rollouts=100,
       max_possible_transitions=2,
       num_valid_nodes=real_nodes,
       rng=jax.random.PRNGKey(0))

   self.assertEqual(estimate.shape, (padded_size,))
   # Real nodes receive a positive estimate; padding entries stay at zero.
   self.assertTrue(jnp.all(estimate[:real_nodes] > 0))
   self.assertTrue(jnp.all(estimate[real_nodes:] == 0))
  def test_initial_routing_params_noiseless(self):
    """Noiseless initialization splits probability mass deterministically."""
    num_states = 3
    schema = self.build_simple_schema()
    builder = automaton_builder.AutomatonBuilder(schema)
    params = builder.initialize_routing_params(
        key=None,
        num_fsm_states=num_states,
        num_variants=2,
        state_change_prob=0.2,
        move_prob=0.9,
        noise_factor=0)

    # Out-degree of each in/out route's node type, shaped to broadcast
    # against [variants, in_out_routes, selected state pairs].
    out_degree = np.array([
        len(schema[route.node_type].out_edges)
        for route in builder.in_out_route_types
    ])[None, :, None]

    # Staying in the same state: move_prob * (1 - state_change_prob), split
    # evenly across the out-edges.
    diagonal = np.arange(num_states)
    stay_moves = params.move[:, :, diagonal, diagonal]
    np.testing.assert_allclose(
        stay_moves,
        np.broadcast_to(0.9 * 0.8 / out_degree, stay_moves.shape))

    # Changing state: move_prob * state_change_prob, split across the other
    # states and across the out-edges.
    off_diagonal = [(src, dst)
                    for src in range(num_states)
                    for dst in range(num_states)
                    if src != dst]
    sources = [src for src, _ in off_diagonal]
    targets = [dst for _, dst in off_diagonal]
    change_moves = params.move[:, :, sources, targets]
    np.testing.assert_allclose(
        change_moves,
        np.broadcast_to(0.9 * 0.2 / ((num_states - 1) * out_degree),
                        change_moves.shape))

    # The remaining 1 - move_prob is split evenly three ways over the
    # special actions.
    np.testing.assert_allclose(params.special, 0.1 / 3)
  def test_routing_gates_to_probs(self):
    """Gate parameters convert to normalized routing probabilities."""
    builder = automaton_builder.AutomatonBuilder(self.build_simple_schema())

    def in_out_index(node, in_edge, out_edge):
      return builder.in_out_route_type_to_index[
          automaton_builder.InOutRouteType(
              graph_types.NodeType(node), graph_types.InEdgeType(in_edge),
              graph_types.OutEdgeType(out_edge))]

    def in_index(node, in_edge):
      return builder.in_route_type_to_index[
          automaton_builder.InRouteType(
              graph_types.NodeType(node), graph_types.InEdgeType(in_edge))]

    # Gate layouts:
    #   move:             [variants, in_out_routes, fsm_states, fsm_states]
    #   accept/backtrack: [variants, in_routes, fsm_states]
    num_in_out = len(builder.in_out_route_types)
    num_in = len(builder.in_route_types)
    move_gates = np.full([3, num_in_out, 2, 2], 0.5)
    accept_gates = np.full([3, num_in, 2], 0.5)
    backtrack_gates = np.full([3, num_in, 2], 0.5)

    # Distribution 1 (variant 0, node "b" via "bi_0", state 0): the move and
    # accept gates add up to 2.0, i.e. more than 1.
    b_to_o0 = in_out_index("b", "bi_0", "bo_0")
    b_to_o1 = in_out_index("b", "bi_0", "bo_1")
    b_in = in_index("b", "bi_0")
    move_gates[0, b_to_o0, 0, :] = [.2, .3]
    move_gates[0, b_to_o1, 0, :] = [.4, .5]
    accept_gates[0, b_in, 0] = .6
    backtrack_gates[0, b_in, 0] = .3

    # Distribution 2 (variant 2, node "a" via "ai_0", state 1): the move and
    # accept gates add up to 0.6, i.e. less than 1.
    a_to_o0 = in_out_index("a", "ai_0", "ao_0")
    a_in = in_index("a", "ai_0")
    move_gates[2, a_to_o0, 1, :] = [.1, .2]
    accept_gates[2, a_in, 1] = .3
    backtrack_gates[2, a_in, 1] = .75

    logit = jax.scipy.special.logit
    probs = builder.routing_gates_to_probs(
        automaton_builder.RoutingGateParams(
            move_gates=logit(move_gates),
            accept_gates=logit(accept_gates),
            backtrack_gates=logit(backtrack_gates)))

    # Oversubscribed: everything is scaled down by the 2.0 total, and the
    # backtrack/fail entries are zero.
    np.testing.assert_allclose(probs.move[0, b_to_o0, 0, :],
                               np.array([.2, .3]) / 2.0)
    np.testing.assert_allclose(probs.move[0, b_to_o1, 0, :],
                               np.array([.4, .5]) / 2.0)
    np.testing.assert_allclose(probs.special[0, b_in, 0, :],
                               np.array([.6, 0, 0]) / 2.0)

    # Undersubscribed: move/accept are kept as-is, and the 0.4 remainder is
    # split 0.75 / 0.25 between backtrack and fail.
    np.testing.assert_allclose(probs.move[2, a_to_o0, 1, :],
                               np.array([.1, .2]))
    np.testing.assert_allclose(probs.special[2, a_in, 1, :],
                               np.array([.3, .3, .1]))
Example #7
0
    def test_constructor_actions_nodes_routes(self):
        """Builder enumerates the expected special actions, nodes and routes."""
        builder = automaton_builder.AutomatonBuilder(
            self.build_simple_schema(), with_backtrack=False, with_fail=True)

        actions = automaton_builder.SpecialActions
        # Backtracking is disabled here, so only FINISH and FAIL are expected.
        self.assertEqual(set(builder.special_actions),
                         {actions.FINISH, actions.FAIL})

        node_a = graph_types.NodeType("a")
        node_b = graph_types.NodeType("b")
        self.assertEqual(set(builder.node_types), {node_a, node_b})

        initial = automaton_builder.SOURCE_INITIAL
        in_a0 = graph_types.InEdgeType("ai_0")
        in_a1 = graph_types.InEdgeType("ai_1")
        in_b0 = graph_types.InEdgeType("bi_0")
        out_a0 = graph_types.OutEdgeType("ao_0")
        out_b0 = graph_types.OutEdgeType("bo_0")
        out_b1 = graph_types.OutEdgeType("bo_1")

        # Expected: each node type entered initially or via each in-edge.
        expected_in_routes = {
            automaton_builder.InRouteType(node, edge)
            for node, edge in [
                (node_a, initial), (node_a, in_a0), (node_a, in_a1),
                (node_b, initial), (node_b, in_b0),
            ]
        }
        self.assertEqual(set(builder.in_route_types), expected_in_routes)

        # Expected: each in-route combined with each out-edge of its node.
        expected_in_out_routes = {
            automaton_builder.InOutRouteType(node, in_edge, out_edge)
            for node, in_edges, out_edges in [
                (node_a, [initial, in_a0, in_a1], [out_a0]),
                (node_b, [initial, in_b0], [out_b0, out_b1]),
            ]
            for in_edge in in_edges
            for out_edge in out_edges
        }
        self.assertEqual(set(builder.in_out_route_types),
                         expected_in_out_routes)
Example #8
0
    def test_transition_sentinel_integers(self):
        """Test that the transition matrix puts each element in the right place.

        Every routing parameter is a distinct "sentinel" integer that encodes
        its own node type, in-edge and action. The expected matrix then shows
        exactly which parameter lands in each (current, next) cell.
        """
        schema, test_graph = self.build_loop_graph()
        builder = automaton_builder.AutomatonBuilder(schema)
        encoded_graph, graph_meta = builder.encode_graph(test_graph,
                                                         as_jax=False)

        # Apply to some sentinel integers to check correct indexing (with only one
        # variant and state, since indexing doesn't use those)
        # Each integer is of the form XYZ where
        #   X = {a:1, b:2}[node_type]
        #   Y = {initial:9, i0:0, i1:1}[in_edge]
        #   Z = {o0:0, o1:1, o2:2, finish:3, backtrack:4, fail:5}[action]
        sentinel_routing_params = automaton_builder.RoutingParams(
            move=jnp.array([[100, 101, 110, 111, 190, 191],
                            [200, 201, 202, 290, 291, 292]]).reshape(
                                (1, 12, 1, 1)),
            special=jnp.array([
                [103, 104, 105],
                [113, 114, 115],
                [193, 194, 195],
                [203, 204, 205],
                [293, 294, 295],
            ]).reshape((1, 5, 1, 3)))

        range_transition_matrix = builder.build_transition_matrix(
            sentinel_routing_params,
            encoded_graph,
            graph_meta,
        ).concatenated_transitions()
        # Rows: 4 nodes (initial position) followed by 6 input-tagged nodes.
        # Columns: 6 input-tagged nodes x 1 state, then 3 special actions.
        self.assertEqual(range_transition_matrix.shape,
                         (1, 4 + 6, 1, 6 * 1 + 3))

        # NOTE(review): B nodes' o2 weight shows up halved in two columns
        # (e.g. 292 / 2), which suggests that out-edge has two targets whose
        # mass is split evenly. This is inferred from the expected values;
        # confirm against build_loop_graph.
        # pyformat: disable
        # pylint: disable=bad-continuation,bad-whitespace,g-inline-comment-too-close
        expected = np.array([
            #  a0i0 a0i1 a1i0 a1i1       b0i0       b1i0   specials    < next
            [  #   |    |    |    |          |          |    |---------|   current V
                [[0, 0, 191, 0, 0, 190, 193, 194, 195]],  # ┬ a0
                [[0, 190, 0, 0, 191, 0, 193, 194, 195]],  # | a1
                [[0, 0, 0, 290, 292 / 2, 291 + 292 / 2, 293, 294,
                  295]],  # | b0
                [[291, 0, 0, 0, 290 + 292 / 2, 292 / 2, 293, 294,
                  295]],  # ┴ b1
                [[0, 0, 101, 0, 0, 100, 103, 104, 105]],  # ┬ a0i0
                [[0, 0, 111, 0, 0, 110, 113, 114, 115]],  # | a0i1
                [[0, 100, 0, 0, 101, 0, 103, 104, 105]],  # | a1i0
                [[0, 110, 0, 0, 111, 0, 113, 114, 115]],  # | a1i1
                [[0, 0, 0, 200, 202 / 2, 201 + 202 / 2, 203, 204,
                  205]],  # | b0i0
                [[201, 0, 0, 0, 200 + 202 / 2, 202 / 2, 203, 204,
                  205]],  # ┴ b1i0
            ]
        ])
        # pyformat: enable
        # pylint: enable=bad-continuation,bad-whitespace,g-inline-comment-too-close
        np.testing.assert_allclose(range_transition_matrix, expected)
 def __post_init__(self):
     """Derives `schema`, `edge_types` and `builder` from `ast_spec`."""
     self.schema = generic_ast_graphs.build_ast_graph_schema(self.ast_spec)
     # Union of every edge-type source below, deduplicated and then sorted so
     # the resulting order is deterministic.
     all_edge_types = {graph_edge_util.SAME_IDENTIFIER_EDGE_TYPE}
     all_edge_types.update(graph_edge_util.PROGRAM_GRAPH_EDGE_TYPES)
     all_edge_types.update(graph_edge_util.schema_edge_types(self.schema))
     all_edge_types.update(
         graph_edge_util.nth_child_edge_types(EDGE_NTH_CHILD_MAX))
     self.edge_types = sorted(all_edge_types)
     self.builder = automaton_builder.AutomatonBuilder(self.schema)
def automaton_model(example,
                    graph_metadata,
                    edge_types_to_indices,
                    variant_edge_types=(),
                    platt_scale=False,
                    with_backtrack=True):
    """Automaton-based module for edge supervision task.

    Args:
      example: Example to run the automaton on.
      graph_metadata: Statically-known metadata about the graph size. If
        encoded_graph is padded, this should reflect the padded size, not the
        original size.
      edge_types_to_indices: Mapping from edge type names to edge type indices.
      variant_edge_types: Edge types to use as variants. Assumes without
        checking that the given variants are mutually exclusive (at most one
        edge of one of these types exists between any pair of nodes).
      platt_scale: Whether to scale and shift the logits produced by the
        automaton. This can be viewed as a form of Platt scaling applied to the
        automaton logits. If True, this allows the model's output probabilities
        to sum to more than 1, so that it can express one-to-many relations.
      with_backtrack: Whether the automaton can restart the search as an
        action.

    Returns:
      <float32[num_nodes, num_nodes]> matrix of binary logits for a weighted
      adjacency matrix corresponding to the predicted output edges.
    """
    # With no variant edge types, the automaton runs without variant weights.
    variant_weights = None
    if variant_edge_types:
        variant_weights = variants_from_edges(
            example,
            graph_metadata,
            [edge_types_to_indices[name] for name in variant_edge_types],
            len(edge_types_to_indices))

    builder = automaton_builder.AutomatonBuilder(
        py_ast_graphs.SCHEMA, with_backtrack=with_backtrack)
    # Only one output edge type is requested, so drop its leading axis.
    absorbing_probs = automaton_layer.FiniteStateGraphAutomaton(
        encoded_graph=example.automaton_graph,
        variant_weights=variant_weights,
        static_metadata=graph_metadata,
        dynamic_metadata=example.graph_metadata,
        builder=builder,
        num_out_edges=1,
        share_states_across_edges=True).squeeze(axis=0)

    logits = model_util.safe_logit(absorbing_probs)
    return model_util.ScaleAndShift(logits) if platt_scale else logits
  def test_constructor_inverse_mappings(self):
    """Each `*_to_index` mapping is the inverse of its corresponding list."""
    builder = automaton_builder.AutomatonBuilder(self.build_simple_schema())

    lists_and_indices = (
        (builder.node_types, builder.node_type_to_index),
        (builder.in_route_types, builder.in_route_type_to_index),
        (builder.in_out_route_types, builder.in_out_route_type_to_index),
    )
    for items, to_index in lists_and_indices:
      for position, item in enumerate(items):
        self.assertEqual(to_index[item], position)
  def test_transition_all_ones(self):
    """Test the transition matrix of an all-ones routing parameter vector.

    With every routing parameter equal to 1, the concatenated transition
    matrix reduces to a (weighted) adjacency pattern of the loop graph,
    repeated for each of the 3 variants.
    """
    schema, test_graph = self.build_loop_graph()
    builder = automaton_builder.AutomatonBuilder(schema)
    encoded_graph, graph_meta = builder.encode_graph(test_graph, as_jax=False)

    # The transition matrix for an all-ones routing params should be a
    # (weighted) directed adjacency matrix. We use 3 variants, 2 states.
    ones_routing_params = automaton_builder.RoutingParams(
        move=jnp.ones([3, 12, 2, 2]), special=jnp.ones([3, 5, 2, 3]))

    ones_transition_matrix = builder.build_transition_matrix(
        ones_routing_params,
        encoded_graph,
        graph_meta,
    ).concatenated_transitions()
    # Rows: 4 nodes (initial position) followed by 6 input-tagged nodes.
    # Columns: 6 input-tagged nodes x 2 states, then 3 special actions.
    self.assertEqual(ones_transition_matrix.shape, (3, 4 + 6, 2, 6 * 2 + 3))

    # NOTE(review): the 0.5 / 1.5 entries in B rows suggest one B out-edge has
    # two targets and its weight is split evenly between them. This is
    # inferred from the expected values; confirm against build_loop_graph.
    # pyformat: disable
    # pylint: disable=bad-continuation,bad-whitespace,g-inline-comment-too-close
    expected = np.array([
        #  a0i0    a0i1    a1i0    a1i1    b0i0    b1i0    specials  < next
      [ #|------|-------|-------|-------|-------|-------| |--------| current V
        [[ 0,  0,  0,  0,  1,  1,  0,  0,  0,  0,  1,  1,  1,  1,  1],  # ┬ a0
         [ 0,  0,  0,  0,  1,  1,  0,  0,  0,  0,  1,  1,  1,  1,  1]], # |
        [[ 0,  0,  1,  1,  0,  0,  0,  0,  1,  1,  0,  0,  1,  1,  1],  # | a1
         [ 0,  0,  1,  1,  0,  0,  0,  0,  1,  1,  0,  0,  1,  1,  1]], # |
        [[ 0,  0,  0,  0,  0,  0,  1,  1,0.5,0.5,1.5,1.5,  1,  1,  1],  # | b0
         [ 0,  0,  0,  0,  0,  0,  1,  1,0.5,0.5,1.5,1.5,  1,  1,  1]], # |
        [[ 1,  1,  0,  0,  0,  0,  0,  0,1.5,1.5,0.5,0.5,  1,  1,  1],  # | b1
         [ 1,  1,  0,  0,  0,  0,  0,  0,1.5,1.5,0.5,0.5,  1,  1,  1]], # ┴
        [[ 0,  0,  0,  0,  1,  1,  0,  0,  0,  0,  1,  1,  1,  1,  1],  # ┬ a0i0
         [ 0,  0,  0,  0,  1,  1,  0,  0,  0,  0,  1,  1,  1,  1,  1]], # |
        [[ 0,  0,  0,  0,  1,  1,  0,  0,  0,  0,  1,  1,  1,  1,  1],  # | a0i1
         [ 0,  0,  0,  0,  1,  1,  0,  0,  0,  0,  1,  1,  1,  1,  1]], # |
        [[ 0,  0,  1,  1,  0,  0,  0,  0,  1,  1,  0,  0,  1,  1,  1],  # | a1i0
         [ 0,  0,  1,  1,  0,  0,  0,  0,  1,  1,  0,  0,  1,  1,  1]], # |
        [[ 0,  0,  1,  1,  0,  0,  0,  0,  1,  1,  0,  0,  1,  1,  1],  # | a1i1
         [ 0,  0,  1,  1,  0,  0,  0,  0,  1,  1,  0,  0,  1,  1,  1]], # |
        [[ 0,  0,  0,  0,  0,  0,  1,  1,0.5,0.5,1.5,1.5,  1,  1,  1],  # | b0i0
         [ 0,  0,  0,  0,  0,  0,  1,  1,0.5,0.5,1.5,1.5,  1,  1,  1]], # |
        [[ 1,  1,  0,  0,  0,  0,  0,  0,1.5,1.5,0.5,0.5,  1,  1,  1],  # | b1i0
         [ 1,  1,  0,  0,  0,  0,  0,  0,1.5,1.5,0.5,0.5,  1,  1,  1]], # ┴
      ]
    ] * 3)
    # pyformat: enable
    # pylint: enable=bad-continuation,bad-whitespace,g-inline-comment-too-close
    np.testing.assert_allclose(ones_transition_matrix, expected)
  def test_graph_encoding_size(self):
    """Encoded sizes match the loop graph's nodes and possible transitions."""
    schema, loop_graph = self.build_loop_graph()
    builder = automaton_builder.AutomatonBuilder(schema)
    encoded, metadata = builder.encode_graph(loop_graph, as_jax=False)

    # The loop graph has 4 nodes and 6 input-tagged nodes.
    self.assertEqual(metadata.num_nodes, 4)
    self.assertEqual(metadata.num_input_tagged_nodes, 6)

    # Initial transitions are counted once per NODE: two A nodes with 2
    # outgoing transitions each, plus two B nodes with 4 each.
    expected_initial = 2 * 2 + 2 * 4
    self.assertEqual(encoded.initial_to_in_tagged.values.shape[0],
                     expected_initial)

    # Ordinary transitions are counted once per input-tagged node, which gives
    # 4 per node (A: 2*2 = 4, B: 1*4 = 4) over 4 nodes.
    expected_in_tagged = 4 * 4
    self.assertEqual(encoded.in_tagged_to_in_tagged.values.shape[0],
                     expected_in_tagged)
Example #14
0
    def test_all_nodes_absorbing_solve_explicit_conv(self):
        """explicit_conv=True and explicit_conv=False agree on values and VJPs."""
        schema, graph = self.build_doubly_linked_list_graph(4)
        builder = automaton_builder.AutomatonBuilder(schema)
        enc_graph, enc_meta = builder.encode_graph(graph)

        variant_weights = jax.random.dirichlet(jax.random.PRNGKey(0),
                                               jnp.ones((4, 4, 5)))
        routing_params = builder.initialize_routing_params(
            jax.random.PRNGKey(1), num_fsm_states=3, num_variants=5)
        start_states = jax.random.dirichlet(jax.random.PRNGKey(0),
                                            jnp.ones((4, 3)))
        primals = (routing_params, variant_weights, start_states)

        def solve(routing_params, variant_weights, start_states,
                  explicit_conv=True):
            transitions = builder.build_transition_matrix(
                routing_params, enc_graph, enc_meta)
            return automaton_builder.all_nodes_absorbing_solve(
                builder,
                transitions,
                variant_weights,
                start_states,
                steps=1000,
                backtrack_fails_prob=0.01,
                explicit_conv=explicit_conv)

        # Linearize both code paths at the same point.
        conv_out, conv_vjp = jax.vjp(solve, *primals)
        plain_out, plain_vjp = jax.vjp(
            functools.partial(solve, explicit_conv=False), *primals)

        # Forward values and pulled-back cotangents must both match.
        jax.test_util.check_close(conv_out, plain_out)
        cotangent = jax.random.normal(jax.random.PRNGKey(0), conv_out.shape)
        jax.test_util.check_close(conv_vjp(cotangent), plain_vjp(cotangent))
    def test_component_shapes(self,
                              component,
                              embed_edges,
                              expected_dims,
                              extra_config=None):
        """Checks the node/edge output shapes of an end-to-end component.

        Args:
          component: Key into `end_to_end_stack.ALL_COMPONENTS`.
          embed_edges: Passed through as `edges_are_embedded`.
          expected_dims: Mapping with "node" and "edge" keys giving the
            expected trailing output dimensions.
          extra_config: Optional extra gin config parsed after `CONFIG`.
        """
        gin.clear_config()
        gin.parse_config(CONFIG)
        if extra_config:
            gin.parse_config(extra_config)

        num_nodes = 16
        num_input_tagged_nodes = 32
        # Single-node-type schema: in_edges take InEdgeType and out_edges take
        # OutEdgeType.
        schema = {
            graph_types.NodeType("node"):
                graph_types.NodeSchema(
                    in_edges=[graph_types.InEdgeType("in")],
                    out_edges=[graph_types.OutEdgeType("out")]),
        }

        # Run the computation with placeholder inputs.
        graph_context = end_to_end_stack.SharedGraphContext(
            bundle=graph_bundle.zeros_like_padded_example(
                graph_bundle.PaddingConfig(
                    static_max_metadata=automaton_builder.EncodedGraphMetadata(
                        num_nodes=num_nodes,
                        num_input_tagged_nodes=num_input_tagged_nodes),
                    max_initial_transitions=11,
                    max_in_tagged_transitions=12,
                    max_edges=13)),
            static_metadata=automaton_builder.EncodedGraphMetadata(
                num_nodes=num_nodes,
                num_input_tagged_nodes=num_input_tagged_nodes),
            edge_types_to_indices={"foo": 0},
            builder=automaton_builder.AutomatonBuilder(schema),
            edges_are_embedded=embed_edges)
        (node_out, edge_out), _ = end_to_end_stack.ALL_COMPONENTS[
            component].init(
                jax.random.PRNGKey(0),
                graph_context=graph_context,
                node_embeddings=jnp.zeros((num_nodes, NODE_DIM)),
                edge_embeddings=jnp.zeros((num_nodes, num_nodes, EDGE_DIM)))

        self.assertEqual(node_out.shape, (num_nodes, expected_dims["node"]))
        self.assertEqual(edge_out.shape,
                         (num_nodes, num_nodes, expected_dims["edge"]))
    def test_automaton_layer_abstract_init(self, shared, variant_weights,
                                           use_gate, estimator_type, **kwargs):
        """Initializes the automaton layer on inputs that depend on a model input.

        Args:
          shared: Passed to the layer as `share_states_across_edges`.
          variant_weights: Zero-argument callable returning the variant-weight
            pytree to feed to the layer.
          use_gate: Passed to the layer as `use_gate_parameterization`.
          estimator_type: Passed to the layer as `estimator_type`; when it is
            "one_sample", the per-edge-per-node log-prob side output is also
            checked.
          **kwargs: Extra keyword arguments forwarded to
            `FiniteStateGraphAutomaton`.
        """
        # Create a simple schema and empty encoded graph.
        schema = {
            graph_types.NodeType("a"):
            graph_types.NodeSchema(in_edges=[graph_types.InEdgeType("ai_0")],
                                   out_edges=[graph_types.OutEdgeType("ao_0")
                                              ]),
        }
        builder = automaton_builder.AutomatonBuilder(schema)
        # All-zero placeholder graph. The lengths 32 / 64 match num_nodes /
        # num_input_tagged_nodes in the metadata passed to the layer.
        encoded_graph = automaton_builder.EncodedGraph(
            initial_to_in_tagged=sparse_operator.SparseCoordOperator(
                input_indices=jnp.zeros((128, 1), dtype=jnp.int32),
                output_indices=jnp.zeros((128, 2), dtype=jnp.int32),
                values=jnp.zeros((128, ), dtype=jnp.float32),
            ),
            initial_to_special=jnp.zeros((32, ), dtype=jnp.int32),
            in_tagged_to_in_tagged=sparse_operator.SparseCoordOperator(
                input_indices=jnp.zeros((128, 1), dtype=jnp.int32),
                output_indices=jnp.zeros((128, 2), dtype=jnp.int32),
                values=jnp.zeros((128, ), dtype=jnp.float32),
            ),
            in_tagged_to_special=jnp.zeros((64, ), dtype=jnp.int32),
            in_tagged_node_indices=jnp.zeros((64, ), dtype=jnp.int32),
        )

        # Make sure the layer can be initialized and applied within a model.
        # This model is fairly simple; it just pretends that the encoded graph and
        # variants depend on the input.
        class TestModel(flax.deprecated.nn.Module):
            def apply(self, dummy_ignored):
                # tie_in makes every leaf nominally depend on `dummy_ignored`.
                # NOTE(review): jax.lax.tie_in became a no-op with omnistaging
                # and was later removed from JAX; confirm against the pinned
                # JAX version.
                abstract_encoded_graph = jax.tree_map(
                    lambda y: jax.lax.tie_in(dummy_ignored, y), encoded_graph)
                abstract_variant_weights = jax.tree_map(
                    lambda y: jax.lax.tie_in(dummy_ignored, y),
                    variant_weights())
                return automaton_layer.FiniteStateGraphAutomaton(
                    encoded_graph=abstract_encoded_graph,
                    variant_weights=abstract_variant_weights,
                    dynamic_metadata=automaton_builder.EncodedGraphMetadata(
                        num_nodes=32, num_input_tagged_nodes=64),
                    static_metadata=automaton_builder.EncodedGraphMetadata(
                        num_nodes=32, num_input_tagged_nodes=64),
                    builder=builder,
                    num_out_edges=3,
                    num_intermediate_states=4,
                    share_states_across_edges=shared,
                    use_gate_parameterization=use_gate,
                    estimator_type=estimator_type,
                    name="the_layer",
                    **kwargs)

        with side_outputs.collect_side_outputs() as side:
            with flax.deprecated.nn.stochastic(jax.random.PRNGKey(0)):
                # For some reason init_by_shape breaks the custom_vjp?
                abstract_out, unused_params = TestModel.init(
                    jax.random.PRNGKey(1234), jnp.zeros((), jnp.float32))

        del unused_params
        # Output is (num_out_edges, num_nodes, num_nodes).
        self.assertEqual(abstract_out.shape, (3, 32, 32))

        if estimator_type == "one_sample":
            # The one-sample estimator should also record a log-prob side
            # output of shape (num_out_edges, num_nodes).
            log_prob_key = "/the_layer/one_sample_log_prob_per_edge_per_node"
            self.assertIn(log_prob_key, side)
            self.assertEqual(side[log_prob_key].shape, (3, 32))
    def test_all_nodes_particle_estimate(self):
        """Checks particle-estimated absorbing probabilities for every start.

        Builds a hand-designed automaton over a 4-node doubly linked list whose
        absorbing probabilities can be derived by hand (see the comments on
        `variant_weights`), then compares the Monte Carlo estimate produced by
        `automaton_sampling.all_nodes_particle_estimate` against them.
        """
        schema, graph = self.build_doubly_linked_list_graph(4)
        builder = automaton_builder.AutomatonBuilder(schema)
        enc_graph, enc_meta = builder.encode_graph(graph)

        # We set up the automaton with 5 variants and 2 states, but only use the
        # first state, to make sure that variants and states are interleaved
        # correctly.

        # Variant 0: move forward
        # Variant 1: move backward
        # Variant 2: finish
        # Variant 3: restart
        # Variant 4: fail
        variant_weights = jnp.array([
            # From node 0, go forward.
            [[1, 0, 0, 0, 0], [1, 0, 0, 0, 0], [.7, 0, .3, 0, 0],
             [0, 0, 1, 0, 0]],
            # From node 1, go backward with small failure probabilities.
            [[0, 0.9, 0, 0, 0.1], [0, 0.9, 0, 0, 0.1], [.7, 0, .3, 0, 0],
             [0, 0, 1, 0, 0]],
            # Node 2 bounces around and ultimately accepts on node 0.
            [[0.9, 0, 0.1, 0, 0], [0, 1, 0, 0, 0], [0.5, 0.5, 0, 0, 0],
             [0, 1, 0, 0, 0]],
            # Node 3 immediately accepts, or restarts after 0 or 1 steps.
            [[0, 0, 1, 0, 0], [1, 0, 0, 0, 0], [0, 0, 0, 1, 0],
             [0, 0.1, 0.8, 0.1, 0]],
        ])
        # The second (unused) state's parameters are broadcast copies of the
        # first state's, rather than zero padding.
        routing_params = automaton_builder.RoutingParams(
            move=jnp.broadcast_to(
                jnp.pad(
                    jnp.array([
                        [1., 0., 1., 0., 1., 0.],
                        [0., 1., 0., 1., 0., 1.],
                        [0., 0., 0., 0., 0., 0.],
                        [0., 0., 0., 0., 0., 0.],
                        [0., 0., 0., 0., 0., 0.],
                    ]).reshape([5, 6, 1, 1]), [(0, 0), (0, 0), (0, 0),
                                               (0, 1)]), [5, 6, 2, 2]),
            special=jnp.broadcast_to(
                jnp.array([
                    [[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]],
                    [[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]],
                    [[1., 0., 0.], [1., 0., 0.], [1., 0., 0.]],
                    [[0., 1., 0.], [0., 1., 0.], [0., 1., 0.]],
                    [[0., 0., 1.], [0., 0., 1.], [0., 0., 1.]],
                ]).reshape([5, 3, 1, 3]), [5, 3, 2, 3]))

        @jax.jit
        def go(variant_weights, routing_params, eps):
            """Smooths and renormalizes the parameters, then runs the estimator.

            `eps` linearly mixes a constant into every weight before
            renormalization; with eps=0 (the only value used below) the mixing
            step is a no-op.
            """
            variant_weights = ((1 - eps) * variant_weights +
                               eps * jnp.ones_like(variant_weights) / 5)
            routing_params = automaton_builder.RoutingParams(
                move=((1 - eps) * routing_params.move +
                      eps * jnp.ones_like(routing_params.move) / 5),
                special=((1 - eps) * routing_params.special +
                         eps * jnp.ones_like(routing_params.special) / 5),
            )
            variant_weights = variant_weights / jnp.sum(
                variant_weights, axis=-1, keepdims=True)
            routing_params_sum = builder.routing_reduce(routing_params, "sum")
            # `jax.tree_multimap` has been removed from JAX;
            # `jax.tree_util.tree_map` accepts several trees and maps over
            # their leaves in lockstep.
            routing_params = jax.tree_util.tree_map(jax.lax.div,
                                                    routing_params,
                                                    routing_params_sum)
            tmat = builder.build_transition_matrix(routing_params, enc_graph,
                                                   enc_meta)
            return automaton_sampling.all_nodes_particle_estimate(
                builder,
                tmat,
                variant_weights,
                jnp.pad(jnp.ones([4, 1]), [(0, 0), (0, 1)]),
                steps=100,
                rng=jax.random.PRNGKey(1),
                num_rollouts=10000,
                max_possible_transitions=2,  # only two edges to leave each node
                num_valid_nodes=enc_meta.num_nodes,
            )

        # Absorbing probs follow the paths described above.
        expected_absorbing_probs = jnp.array([
            [0, 0, 0.3, 0.7],
            [0, 0, 0, 0.81],
            [1, 0, 0, 0],
            [0, 0, 0, 1],
        ])

        particle_probs = go(variant_weights, routing_params, eps=0)

        # With 10000 rollouts we expect a standard deviation of up to
        # sqrt(0.3*0.7/10000) ~= 5e-3. Check that we are within 2 sigma.
        np.testing.assert_allclose(particle_probs,
                                   expected_absorbing_probs,
                                   atol=2 * 5e-3)
Exemple #18
0
    def test_unroll_and_aggregate(self):
        """Checks unroll_chain_steps and aggregate_unrolled_per_node by hand.

        Unrolls a small hand-specified automaton on a 4-node doubly linked list
        for 6 steps starting from node 0, and compares the per-step state and
        special-action probabilities against manually derived values.
        """
        schema, graph = self.build_doubly_linked_list_graph(4)
        builder = automaton_builder.AutomatonBuilder(schema)
        enc_graph, enc_meta = builder.encode_graph(graph)

        # Three variants, but all the same, just to check shape consistency
        variant_weights = jnp.broadcast_to(jnp.array([0.7, 0.3, 0.]), [4, 3])

        # In state 0: keep moving in the current direction with prob 0.1, swap
        # directions and states with prob 0.9 (except init, which is a special case)
        # In state 1: take a special action
        routing_params = automaton_builder.RoutingParams(
            move=jnp.broadcast_to(
                jnp.array([
                    # from next, to next
                    [[0.0, 0.9], [0.0, 0.0]],
                    # from next, to prev
                    [[0.1, 0.0], [0.0, 0.0]],
                    # from prev, to next
                    [[0.1, 0.0], [0.0, 0.0]],
                    # from prev, to prev
                    [[0.0, 0.9], [0.0, 0.0]],
                    # from init, to next
                    [[0.1, 0.0], [0.0, 0.0]],
                    # from init, to prev
                    [[0.1, 0.0], [0.0, 0.0]],
                ]),
                [3, 6, 2, 2]),
            special=jnp.broadcast_to(
                jnp.array([
                    # from next
                    [[0, 0, 0], [0.2, 0.3, 0.5]],
                    # from prev
                    [[0, 0, 0], [0.5, 0.2, 0.3]],
                    # from init
                    [[0.1, 0.3, 0.4], [0.0, 0.0, 1.0]],
                ]),
                [3, 3, 2, 3]))
        tmat = builder.build_transition_matrix(routing_params, enc_graph,
                                               enc_meta)

        # Unroll 6 steps starting at node 0 in state 0 (initial state vector
        # [1., 0.]).
        unrolled = automaton_builder.unroll_chain_steps(builder,
                                                        tmat,
                                                        variant_weights,
                                                        jnp.array([1., 0.]),
                                                        node_index=0,
                                                        steps=6)

        expected_initial_special = np.array([0.1, 0.3, 0.4])
        np.testing.assert_allclose(unrolled["initial_special"],
                                   expected_initial_special)

        # pyformat: disable
        # pylint: disable=bad-whitespace
        expected_in_tagged_states = np.array([
            # In-tagged node key:
            #  0 from 1,     0 from 3,     1 from 2,     1 from 0,
            #  2 from 3,     2 from 1,     3 from 0,     3 from 2
            # State key: [prob of being in state 0, prob of being in state 1]
            # -- First step from initial --
            [[0, 0], [0, 0], [0, 0], [1e-1, 0], [0, 0], [0, 0], [1e-1, 0],
             [0, 0]],
            # -- Second step --
            [[0, 9e-2], [0, 9e-2], [0, 0], [0, 0], [1e-2, 0], [1e-2, 0],
             [0, 0], [0, 0]],
            # -----------------
            [[0, 0], [0, 0], [1e-3, 9e-3], [0, 0], [0, 0], [0, 0], [0, 0],
             [1e-3, 9e-3]],
            # -----------------
            [[1e-4, 0], [1e-4, 0], [0, 0], [0, 0], [0, 9e-4], [0, 9e-4],
             [0, 0], [0, 0]],
            # -----------------
            [[0, 0], [0, 0], [0, 0], [1e-5, 9e-5], [0, 0], [0, 0],
             [1e-5, 9e-5], [0, 0]],
            # -----------------
            [[0, 9e-6], [0, 9e-6], [0, 0], [0, 0], [1e-6, 0], [1e-6, 0],
             [0, 0], [0, 0]],
        ])
        np.testing.assert_allclose(unrolled["in_tagged_states"],
                                   expected_in_tagged_states,
                                   atol=1e-8)

        expected_in_tagged_special = np.array([
            # In-tagged node key:
            #  0 from 1,                 0 from 3
            #  1 from 2,                 1 from 0,
            #  2 from 3,                 2 from 1
            #  3 from 0,                 3 from 2
            # Action key: [finish, backtrack, fail] (cumulative)
            # -- First step from initial --
            # (no special actions because we just left the initial node)
            [[0, 0, 0]] * 8,
            # -- Second step --
            # (no special actions yet because everything was in state 0)
            [[0, 0, 0]] * 8,
            # -----------------
            [[1.8e-2, 2.7e-2, 4.5e-2], [4.5e-2, 1.8e-2, 2.7e-2], [0, 0, 0],
             [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0]],
            # -----------------
            [[1.8e-2, 2.7e-2, 4.5e-2], [4.5e-2, 1.8e-2, 2.7e-2],
             [1.8e-3, 2.7e-3, 4.5e-3], [0, 0, 0], [0, 0, 0], [0, 0, 0],
             [0, 0, 0], [4.5e-3, 1.8e-3, 2.7e-3]],
            # -----------------
            [[1.8e-2, 2.7e-2, 4.5e-2], [4.5e-2, 1.8e-2, 2.7e-2],
             [1.8e-3, 2.7e-3, 4.5e-3], [0, 0, 0], [1.8e-4, 2.7e-4, 4.5e-4],
             [4.5e-4, 1.8e-4, 2.7e-4], [0, 0, 0], [4.5e-3, 1.8e-3, 2.7e-3]],
            # -----------------
            [[1.8e-2, 2.7e-2, 4.5e-2], [4.5e-2, 1.8e-2, 2.7e-2],
             [1.8e-3, 2.7e-3, 4.5e-3], [4.5e-5, 1.8e-5, 2.7e-5],
             [1.8e-4, 2.7e-4, 4.5e-4], [4.5e-4, 1.8e-4, 2.7e-4],
             [1.8e-5, 2.7e-5, 4.5e-5], [4.5e-3, 1.8e-3, 2.7e-3]],
        ])
        np.testing.assert_allclose(unrolled["in_tagged_special"],
                                   expected_in_tagged_special,
                                   atol=1e-8)
        # pyformat: enable
        # pylint: enable=bad-whitespace

        # NOTE(review): the two 0 arguments presumably select the start node
        # (matching node_index=0 above) and an initial state/variant index --
        # confirm against the helper's signature.
        unrolled_combined = automaton_builder.aggregate_unrolled_per_node(
            unrolled, 0, 0, tmat, enc_meta)

        # One row per step (7 = "zeroth" step + 6 unrolled steps), one entry
        # per node. Each 5-vector appears to be [P(state 0), P(state 1),
        # finish, backtrack, fail] -- TODO confirm the column order.
        expected_unrolled = np.array([
            # "Zeroth" step: at initial node, no specials have happened yet
            [[1, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0],
             [0, 0, 0, 0, 0]],
            # The other steps are either the states from above, or sums of entries
            # in initial_special and in_tagged_special
            [[0, 0, 0.1, 0.3, 0.4], [0.1, 0, 0, 0, 0], [0, 0, 0, 0, 0],
             [0.1, 0, 0, 0, 0]],
            # -----------------
            [[0, 0.18, 0.1, 0.3, 0.4], [0, 0, 0, 0, 0], [0.02, 0, 0, 0, 0],
             [0, 0, 0, 0, 0]],
            # -----------------
            [[0, 0, 0.163, 0.34500003, 0.472], [0.001, 0.009, 0, 0, 0],
             [0, 0, 0, 0, 0], [0.001, 0.009, 0, 0, 0]],
            # -----------------
            [[0.0002, 0, 0.163, 0.34500003, 0.472],
             [0, 0, 0.0018, 0.0027, 0.0045], [0, 0.0018, 0, 0, 0],
             [0, 0, 0.0045, 0.0018, 0.0027]],
            # -----------------
            [[0, 0, 0.163, 0.34500003, 0.472],
             [1e-05, 9e-05, 0.0018, 0.0027, 0.0045],
             [0, 0, 0.00063, 0.00045, 0.00072],
             [1e-05, 9e-05, 0.0045, 0.0018, 0.0027]],
            # -----------------
            [[0, 1.8e-05, 0.163, 0.34500003, 0.472],
             [0, 0, 0.001845, 0.002718, 0.004527],
             [2e-06, 0, 0.00063, 0.00045, 0.00072],
             [0, 0, 0.004518, 0.001827, 0.002745]],
        ])

        np.testing.assert_allclose(unrolled_combined,
                                   expected_unrolled,
                                   atol=1e-8)
Exemple #19
0
    def test_all_nodes_absorbing_solve(self):
        """Checks the exact absorbing-probability solver for every start node.

        Uses the same hand-designed automaton as the particle-estimate test, but
        solves for the absorbing probabilities deterministically (with a small
        `backtrack_fails_prob`) and compares against closed-form values.
        """
        schema, graph = self.build_doubly_linked_list_graph(4)
        builder = automaton_builder.AutomatonBuilder(schema)
        enc_graph, enc_meta = builder.encode_graph(graph)

        # We set up the automaton with 5 variants and 2 states, but only use the
        # first state, to make sure that variants and states are interleaved
        # correctly.

        # Variant 0: move forward
        # Variant 1: move backward
        # Variant 2: finish
        # Variant 3: restart
        # Variant 4: fail
        variant_weights = jnp.array([
            # From node 0, go forward.
            [[1, 0, 0, 0, 0], [1, 0, 0, 0, 0], [.7, 0, .3, 0, 0],
             [0, 0, 1, 0, 0]],
            # From node 1, go backward with small failure probabilities.
            [[0, 0.9, 0, 0, 0.1], [0, 0.9, 0, 0, 0.1], [.7, 0, .3, 0, 0],
             [0, 0, 1, 0, 0]],
            # Node 2 bounces around and ultimately accepts on node 0.
            [[0.9, 0, 0.1, 0, 0], [0, 1, 0, 0, 0], [0.5, 0.5, 0, 0, 0],
             [0, 1, 0, 0, 0]],
            # Node 3 immediately accepts, or restarts after 0 or 1 steps.
            [[0, 0, 1, 0, 0], [1, 0, 0, 0, 0], [0, 0, 0, 1, 0],
             [0, 0.1, 0.8, 0.1, 0]],
        ])
        # The unused second state is zero-padded in both parameter arrays.
        routing_params = automaton_builder.RoutingParams(
            move=jnp.pad(
                jnp.array([
                    [1., 0., 1., 0., 1., 0.],
                    [0., 1., 0., 1., 0., 1.],
                    [0., 0., 0., 0., 0., 0.],
                    [0., 0., 0., 0., 0., 0.],
                    [0., 0., 0., 0., 0., 0.],
                ]).reshape([5, 6, 1, 1]), [(0, 0), (0, 0), (0, 1), (0, 1)]),
            special=jnp.pad(
                jnp.array([
                    [[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]],
                    [[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]],
                    [[1., 0., 0.], [1., 0., 0.], [1., 0., 0.]],
                    [[0., 1., 0.], [0., 1., 0.], [0., 1., 0.]],
                    [[0., 0., 1.], [0., 0., 1.], [0., 0., 1.]],
                ]).reshape([5, 3, 1, 3]), [(0, 0), (0, 0), (0, 1), (0, 0)]))
        tmat = builder.build_transition_matrix(routing_params, enc_graph,
                                               enc_meta)

        # Absorbing probs follow the paths described above.
        # Note that when starting at node 3, with probability 0.2 the automaton
        # tries to backtrack, but with probability 0.2 * 0.01 backtracking fails
        # (as specified by backtrack_fails_prob) and thus the total absorbing
        # probability is 0.8 / (0.8 + 0.2 * 0.01) = 0.997506
        expected_absorbing_probs = jnp.array([
            [0, 0, 0.3, 0.7],
            [0, 0, 0, 0.81],
            [1, 0, 0, 0],
            [0, 0, 0, 0.997506],
        ])
        absorbing_probs = automaton_builder.all_nodes_absorbing_solve(
            builder,
            tmat,
            variant_weights,
            jnp.pad(jnp.ones([4, 1]), [(0, 0), (0, 1)]),
            steps=1000,
            backtrack_fails_prob=0.01)

        # `jax.test_util.check_close` is no longer part of JAX's public API;
        # use NumPy's comparison like the other tests in this file. The
        # tolerance covers float32 round-off and the 6-digit rounding of
        # 0.997506 above.
        np.testing.assert_allclose(absorbing_probs,
                                   expected_absorbing_probs,
                                   rtol=1e-6,
                                   atol=1e-5)
Exemple #20
0
  Args:
    maze_graph: Encoded graph representing the maze.

  Returns:
    List of edges corresponding to primitive actions in the maze.
  """
    primitives = []
    for node_id, node_info in maze_graph.items():
        for i, direction in enumerate(DIRECTION_ORDERING):
            out_key = graph_types.OutEdgeType(f"{direction}_out")
            if out_key in node_info.out_edges:
                dest, = node_info.out_edges[out_key]
                primitives.append((node_id, dest.node_id, i))
            else:
                primitives.append((node_id, node_id, i))

    return primitives


SCHEMA = maze_schema.build_maze_schema(2)

# The maze environment has no use for backtracking, so the shared builder is
# created with `with_backtrack=False`.
BUILDER = automaton_builder.AutomatonBuilder(SCHEMA, with_backtrack=False)

# Fixed maximum sizes (edges, transitions, and node counts in the encoded
# metadata) used when padding maze graph bundles.
PADDING_CONFIG = graph_bundle.PaddingConfig(
    max_edges=1024,
    max_in_tagged_transitions=2048,
    max_initial_transitions=512,
    static_max_metadata=automaton_builder.EncodedGraphMetadata(
        num_input_tagged_nodes=512,
        num_nodes=256,
    ),
)
Exemple #21
0
    def test_routing_reduce_correct(self, reduction):
        """Compare JAX implementations to a (slow but correct) iterative one.

        Fills the routing parameters with distinct values, aggregates them with
        the builder's vectorized routines, and checks every entry against an
        explicit per-distribution loop over the schema.

        Args:
          reduction: Which aggregate to check: "sum", "max" or "softmax". Any
            other value raises ValueError.
        """
        n_variants = 2
        n_states = 4

        def make_range_shaped(shape):
            # Distinct entries, so a mis-indexed comparison cannot match by
            # accident.
            return np.arange(np.prod(shape)).reshape(shape).astype("float32")

        schema = self.build_simple_schema()
        builder = automaton_builder.AutomatonBuilder(schema)
        num_special = len(builder.special_actions)
        routing_params = automaton_builder.RoutingParams(
            move=make_range_shaped([
                n_variants,
                len(builder.in_out_route_types),
                n_states,
                n_states,
            ]),
            special=make_range_shaped([
                n_variants,
                len(builder.in_route_types),
                n_states,
                num_special,
            ]),
        )

        # Compute aggregates with JAX
        if reduction == "softmax":
            routing_aggregates = builder.routing_softmax(routing_params)
        else:
            routing_aggregates = builder.routing_reduce(routing_params,
                                                        reduction=reduction)
            # Broadcast each reduced leaf back to the shape of the matching
            # parameter leaf so both can be indexed identically below.
            # (`jax.tree_multimap` has been removed from JAX;
            # `jax.tree_util.tree_map` accepts multiple trees.)
            routing_aggregates = jax.tree_util.tree_map(
                lambda s, p: np.array(jnp.broadcast_to(s, p.shape)),
                routing_aggregates, routing_params)

        # Manual looping aggregates
        for variant in range(n_variants):
            for current_state in range(n_states):
                for in_route_type in builder.in_route_types:
                    iroute_idx = builder.in_route_type_to_index[in_route_type]
                    # In/out route indices for every out edge of this node
                    # type, in schema order. Computed once so the aggregation
                    # pass and the checking pass share the exact same order.
                    ioroute_indices = [
                        builder.in_out_route_type_to_index[
                            automaton_builder.InOutRouteType(
                                in_route_type.node_type, in_route_type.in_edge,
                                out_edge_type)]
                        for out_edge_type in schema[
                            in_route_type.node_type].out_edges
                    ]

                    # Flatten one outgoing distribution: every move entry
                    # (out edge x next state), followed by every special action.
                    distn_vals = [
                        routing_params.move[variant, ioroute_idx,
                                            current_state, next_state]
                        for ioroute_idx in ioroute_indices
                        for next_state in range(n_states)
                    ]
                    distn_vals.extend(
                        routing_params.special[variant, iroute_idx,
                                               current_state, action_idx]
                        for action_idx in range(num_special))

                    if reduction == "sum":
                        distn_aggregate = [sum(distn_vals)] * len(distn_vals)
                    elif reduction == "max":
                        distn_aggregate = [max(distn_vals)] * len(distn_vals)
                    elif reduction == "softmax":
                        distn_aggregate = list(
                            jax.nn.softmax(jnp.array(distn_vals)))
                    else:
                        raise ValueError(f"Invalid reduction {reduction}")

                    # Check them with the JAX version, consuming the expected
                    # values in the order they were flattened above.
                    i = 0
                    for ioroute_idx in ioroute_indices:
                        for next_state in range(n_states):
                            np.testing.assert_allclose(
                                routing_aggregates.move[variant, ioroute_idx,
                                                        current_state,
                                                        next_state],
                                distn_aggregate[i],
                                rtol=1e-6)
                            i += 1

                    for action_idx in range(num_special):
                        np.testing.assert_allclose(
                            routing_aggregates.special[variant, iroute_idx,
                                                       current_state,
                                                       action_idx],
                            distn_aggregate[i],
                            rtol=1e-6)
                        i += 1
Exemple #22
0
    if isinstance(value, gast.AST):
      fields[field_name] = [py_ast_to_generic(value)]
    elif isinstance(value, list):
      if value and isinstance(value[0], gast.AST):
        fields[field_name] = [py_ast_to_generic(child) for child in value]
    else:
      # Doesn't contain any AST nodes, so ignore it.
      pass

  return generic_ast_graphs.GenericASTNode(
      node_id=id(tree), node_type=type(tree).__name__, fields=fields)


# Default definitions used elsewhere: the graph schema built from the
# PY_AST_SPECS node specifications, and an automaton builder over that schema
# (constructed with the builder's default options). Both are created once, at
# import time.
SCHEMA = generic_ast_graphs.build_ast_graph_schema(PY_AST_SPECS)
BUILDER = automaton_builder.AutomatonBuilder(SCHEMA)


def py_ast_to_graph(
    tree):
  """Convert an unsimplified AST into a graph, with a forward mapping.

  Args:
    tree: The (unsimplified) AST for the program.

  Returns:
    - Graph representing the tree.
    - Dictionary that maps from AST node ids to graph node ids.
  """
  return generic_ast_graphs.ast_to_graph(
      root=py_ast_to_generic(tree),