def build_doubly_linked_list_graph(self, length):
    """Builds a circular doubly-linked-list graph together with its schema.

    Args:
      length: Number of nodes. Node ``str(i)`` links forward (``next_out``) to
        node ``str((i + 1) % length)`` and backward (``prev_out``) to node
        ``str((i - 1) % length)``, so the ends wrap around into a ring.

    Returns:
      Tuple ``(schema, graph)``.
    """
    node_type = graph_types.NodeType("node")
    next_in = graph_types.InEdgeType("next_in")
    prev_in = graph_types.InEdgeType("prev_in")
    next_out = graph_types.OutEdgeType("next_out")
    prev_out = graph_types.OutEdgeType("prev_out")

    schema = {
        node_type:
        graph_types.NodeSchema(in_edges=[next_in, prev_in],
                               out_edges=[next_out, prev_out]),
    }

    def _neighbor(index, in_edge):
        # Indices wrap modulo `length`, closing the list into a ring.
        return graph_types.InputTaggedNode(
            node_id=graph_types.NodeId(str(index % length)), in_edge=in_edge)

    # Following "next" lands on the neighbor's "prev" input and vice versa.
    graph = {
        graph_types.NodeId(str(i)): graph_types.GraphNode(
            node_type, {
                next_out: [_neighbor(i + 1, prev_in)],
                prev_out: [_neighbor(i - 1, next_in)],
            })
        for i in range(length)
    }
    return schema, graph
  def test_routing_gates_to_probs(self):
    """Checks how routing gates are turned into routing probabilities.

    One (variant, in-route, state) slot gets gates whose total exceeds 1 and
    is expected to come out divided by that total (2.0). Another slot gets
    gates whose total is below 1 and is expected to put the remainder into
    the backtrack and fail entries.
    """
    builder = automaton_builder.AutomatonBuilder(self.build_simple_schema())
    logit = jax.scipy.special.logit

    def in_out_index(node_name, in_name, out_name):
      route = automaton_builder.InOutRouteType(
          graph_types.NodeType(node_name), graph_types.InEdgeType(in_name),
          graph_types.OutEdgeType(out_name))
      return builder.in_out_route_type_to_index[route]

    def in_index(node_name, in_name):
      route = automaton_builder.InRouteType(
          graph_types.NodeType(node_name), graph_types.InEdgeType(in_name))
      return builder.in_route_type_to_index[route]

    num_in_out_routes = len(builder.in_out_route_types)
    num_in_routes = len(builder.in_route_types)
    # move:               [variants, in_out_routes, fsm_states, fsm_states]
    # accept / backtrack: [variants, in_routes, fsm_states]
    move_gates = np.full([3, num_in_out_routes, 2, 2], 0.5)
    accept_gates = np.full([3, num_in_routes, 2], 0.5)
    backtrack_gates = np.full([3, num_in_routes, 2], 0.5)

    # Variant 0, state 0 of (b, bi_0): gates sum to more than 1.
    b_move_0 = in_out_index("b", "bi_0", "bo_0")
    b_move_1 = in_out_index("b", "bi_0", "bo_1")
    b_special = in_index("b", "bi_0")
    move_gates[0, b_move_0, 0, :] = [.2, .3]
    move_gates[0, b_move_1, 0, :] = [.4, .5]
    accept_gates[0, b_special, 0] = .6
    backtrack_gates[0, b_special, 0] = .3

    # Variant 2, state 1 of (a, ai_0): gates sum to less than 1.
    a_move = in_out_index("a", "ai_0", "ao_0")
    a_special = in_index("a", "ai_0")
    move_gates[2, a_move, 1, :] = [.1, .2]
    accept_gates[2, a_special, 1] = .3
    backtrack_gates[2, a_special, 1] = .75

    routing_probs = builder.routing_gates_to_probs(
        automaton_builder.RoutingGateParams(
            move_gates=logit(move_gates),
            accept_gates=logit(accept_gates),
            backtrack_gates=logit(backtrack_gates)))

    # Over-full slot: every entry is divided by the total of 2.0.
    np.testing.assert_allclose(routing_probs.move[0, b_move_0, 0, :],
                               np.array([.2, .3]) / 2.0)
    np.testing.assert_allclose(routing_probs.move[0, b_move_1, 0, :],
                               np.array([.4, .5]) / 2.0)
    np.testing.assert_allclose(routing_probs.special[0, b_special, 0, :],
                               np.array([.6, 0, 0]) / 2.0)

    # Under-full slot: the remainder goes to backtrack and fail.
    np.testing.assert_allclose(routing_probs.move[2, a_move, 1, :],
                               np.array([.1, .2]))
    np.testing.assert_allclose(routing_probs.special[2, a_special, 1, :],
                               np.array([.3, .3, .1]))
Esempio n. 3
0
    def test_all_input_tagged_nodes(self):
        """Checks schema_util.all_input_tagged_nodes against an explicit list.

        (A, ai_1) is targeted by three different edges but appears only once
        in the expected result.
        """
        node_id = graph_types.NodeId
        node_type = graph_types.NodeType
        out_edge = graph_types.OutEdgeType

        def itn(target, in_name):
            return graph_types.InputTaggedNode(node_id(target),
                                               graph_types.InEdgeType(in_name))

        # Python 3 dicts preserve insertion order, so B2 is listed before B1.
        graph = {
            node_id("A"): graph_types.GraphNode(
                node_type("a"), {
                    out_edge("ao_0"): [itn("B1", "bi_1"), itn("A", "ai_1")],
                }),
            node_id("B2"): graph_types.GraphNode(
                node_type("b"), {
                    out_edge("bo_0"): [itn("A", "ai_1")],
                    out_edge("bo_1"): [itn("B1", "bi_0")],
                }),
            node_id("B1"): graph_types.GraphNode(
                node_type("b"), {
                    out_edge("bo_0"): [itn("A", "ai_1")],
                    out_edge("bo_1"): [itn("B2", "bi_0")],
                }),
        }
        expected_itns = [
            itn("A", "ai_1"),
            itn("B2", "bi_0"),
            itn("B1", "bi_0"),
            itn("B1", "bi_1"),
        ]

        self.assertEqual(schema_util.all_input_tagged_nodes(graph),
                         expected_itns)
Esempio n. 4
0
 def build_simple_schema(self):
     """Returns a small two-node-type schema used by the tests.

     Node type "a" has in-edges ai_0, ai_1 and out-edge ao_0; node type "b"
     has in-edge bi_0 and out-edges bo_0, bo_1.
     """
     a_schema = graph_types.NodeSchema(
         in_edges=[graph_types.InEdgeType("ai_0"),
                   graph_types.InEdgeType("ai_1")],
         out_edges=[graph_types.OutEdgeType("ao_0")])
     b_schema = graph_types.NodeSchema(
         in_edges=[graph_types.InEdgeType("bi_0")],
         out_edges=[graph_types.OutEdgeType("bo_0"),
                    graph_types.OutEdgeType("bo_1")])
     return {
         graph_types.NodeType("a"): a_schema,
         graph_types.NodeType("b"): b_schema,
     }
Esempio n. 5
0
    def test_constructor_information_removing_mappings(self):
        """Checks the builder's index projection tables against its type lists.

        Each in/out route's `in_out_route_to_in_route` entry must be the index
        of the (node type, in edge) in-route it was built from. Each in-route's
        `in_route_to_node_type` entry must be the index of its node type.
        """
        builder = automaton_builder.AutomatonBuilder(
            self.build_simple_schema())

        # Check consistency of information-removing mappings with the
        # corresponding pairs of lists.
        for in_out_route_type in builder.in_out_route_types:
            in_route_type = automaton_builder.InRouteType(
                in_out_route_type.node_type, in_out_route_type.in_edge)
            self.assertEqual(
                builder.in_out_route_to_in_route[
                    builder.in_out_route_type_to_index[in_out_route_type]],
                builder.in_route_type_to_index[in_route_type])

        for in_route_type in builder.in_route_types:
            node_type = graph_types.NodeType(in_route_type.node_type)
            self.assertEqual(
                builder.in_route_to_node_type[
                    builder.in_route_type_to_index[in_route_type]],
                builder.node_type_to_index[node_type])
Esempio n. 6
0
    def test_conforms_to_schema(self):
        """A graph matching the test schema passes assert_conforms_to_schema."""
        node_id = graph_types.NodeId
        node_type = graph_types.NodeType
        in_edge = graph_types.InEdgeType
        out_edge = graph_types.OutEdgeType
        itn = graph_types.InputTaggedNode

        test_schema = {
            node_type("a"): graph_types.NodeSchema(
                in_edges=[in_edge("ai_0"), in_edge("ai_1")],
                out_edges=[out_edge("ao_0")]),
            node_type("b"): graph_types.NodeSchema(
                in_edges=[in_edge("bi_0")],
                out_edges=[out_edge("bo_0"), out_edge("bo_1")]),
        }

        # Valid graph: every out-edge and every targeted in-edge is declared
        # for the corresponding node's type.
        test_graph = {
            node_id("A"): graph_types.GraphNode(
                node_type("a"), {
                    out_edge("ao_0"): [
                        itn(node_id("B"), in_edge("bi_0")),
                        itn(node_id("A"), in_edge("ai_1")),
                    ],
                }),
            node_id("B"): graph_types.GraphNode(
                node_type("b"), {
                    out_edge("bo_0"): [itn(node_id("A"), in_edge("ai_1"))],
                    out_edge("bo_1"): [itn(node_id("B"), in_edge("bi_0"))],
                }),
        }
        schema_util.assert_conforms_to_schema(test_graph, test_schema)
Esempio n. 7
0
 def test_does_not_conform_to_schema(self, graph, expected_error):
     """Checks that `graph` is rejected when validated against the test schema.

     Args:
       graph: Graph expected to violate the schema below.
       expected_error: Regex the raised ValueError's message must match.
     """
     test_schema = {
         graph_types.NodeType("a"):
         graph_types.NodeSchema(in_edges=[
             graph_types.InEdgeType("ai_0"),
             graph_types.InEdgeType("ai_1")
         ],
                                out_edges=[graph_types.OutEdgeType("ao_0")
                                           ]),
         graph_types.NodeType("b"):
         graph_types.NodeSchema(in_edges=[graph_types.InEdgeType("bi_0")],
                                out_edges=[
                                    graph_types.OutEdgeType("bo_0"),
                                    graph_types.OutEdgeType("bo_1")
                                ]),
     }
     # pylint: disable=g-error-prone-assert-raises
     with self.assertRaisesRegex(ValueError, expected_error):
         schema_util.assert_conforms_to_schema(graph, test_schema)
     # pylint: enable=g-error-prone-assert-raises
Esempio n. 8
0
 def get_graph_node_id(
     type_name,
     path_from_root,
 ):
     """Returns the node ID for `type_name` at `path_from_root`, creating it.

     The ID has the form ``root_<p0>_<p1>...__<type_name>``. If that ID is not
     yet a key of `result` (a mapping taken from the enclosing scope, not
     visible here), a new GraphNode of type `type_name` with no out-edges is
     stored under it.
     """
     path_suffix = "".join("_" + str(step) for step in path_from_root)
     node_id = graph_types.NodeId("root" + path_suffix + "__" + type_name)
     if node_id not in result:
         result[node_id] = graph_types.GraphNode(
             node_type=graph_types.NodeType(type_name), out_edges={})
     return node_id
def encode_maze(maze):
    """Encode a boolean mask array as a graph.

    We assume a [row, col]-indexed coordinate system, oriented so that going
    right corresponds to changing indices as (0, +1), and going down
    corresponds to changing indices as (+1, 0).

    Each True cell (i, j) becomes a node with ID ``cell_{i}_{j}``. Its type name
    is ``cell_`` followed by one character per direction, in the order L, R,
    U, D. The character is the direction letter if the neighboring cell in
    that direction is inside the array and True, and ``x`` otherwise. For each
    such neighbor, the node gets a ``{dir}_out`` edge to the neighbor's
    ``{reverse_dir}_in`` input.

    Args:
      maze: Maze, as a boolean array <bool[num_rows, num_cols]>, where True
        corresponds to cells that should be nodes in the graph.

    Returns:
      Encoded graph, along with a list of the (row, col) coordinates of each
      cell in the graph, in row-major order (the order nodes were added).
    """
    num_rows = maze.shape[0]
    num_cols = maze.shape[1]
    # (direction, reverse direction, (dr, dc))
    shifts = [
        ("L", "R", (0, -1)),
        ("R", "L", (0, 1)),
        ("U", "D", (-1, 0)),
        ("D", "U", (1, 0)),
    ]
    graph = {}
    idx_to_cell = []
    for i in range(num_rows):
        for j in range(num_cols):
            if not maze[i, j]:
                continue
            type_chars = []
            out_edges = {}
            for name, in_name, (di, dj) in shifts:
                ni = i + di
                nj = j + dj
                if (0 <= ni < num_rows and 0 <= nj < num_cols
                        and maze[ni, nj]):
                    type_chars.append(name)
                    # The neighbor receives this edge on its opposite side.
                    out_edges[graph_types.OutEdgeType(f"{name}_out")] = [
                        graph_types.InputTaggedNode(
                            graph_types.NodeId(f"cell_{ni}_{nj}"),
                            graph_types.InEdgeType(f"{in_name}_in"))
                    ]
                else:
                    type_chars.append("x")
            graph[graph_types.NodeId(f"cell_{i}_{j}")] = graph_types.GraphNode(
                graph_types.NodeType("cell_" + "".join(type_chars)),
                out_edges)
            idx_to_cell.append((i, j))

    return graph, idx_to_cell
Esempio n. 10
0
def build_maze_schema(min_neighbors):
    """Build a schema for a 2d gridworld maze.

    A node type like "cell_LxUD" corresponds to a cell that has neighbors in
    the left, up, and down directions. The in and out edges for this cell
    would be "L_out", "U_out", "D_out", "L_in", "U_in", "D_in".

    Args:
      min_neighbors: Minimum number of neighbors a grid cell will have. For
        instance, if min_neighbors=2 then the schema will only have types for
        cells that are connected to at least 2 other cells.

    Returns:
      Schema describing the maze.
    """
    directions = ("L", "R", "U", "D")
    maze_schema = {}
    # Each flag tuple is (has_l, has_r, has_u, has_d).
    for flags in itertools.product([False, True], repeat=4):
        if sum(flags) < min_neighbors:
            continue

        present = [d for d, used in zip(directions, flags) if used]
        type_name = "cell_" + "".join(
            d if used else "x" for d, used in zip(directions, flags))
        maze_schema[graph_types.NodeType(type_name)] = graph_types.NodeSchema(
            [graph_types.InEdgeType(d + "_in") for d in present],
            [graph_types.OutEdgeType(d + "_out") for d in present])
    return maze_schema
    def test_component_shapes(self,
                              component,
                              embed_edges,
                              expected_dims,
                              extra_config=None):
        """Checks the output shapes of an end-to-end stack component.

        Parses CONFIG (plus `extra_config`, if given) and runs the component's
        `init` on all-zero placeholder inputs for a 16-node graph. It then
        compares the node and edge output shapes with `expected_dims`.
        """
        gin.clear_config()
        gin.parse_config(CONFIG)
        if extra_config:
            gin.parse_config(extra_config)

        # Run the computation with placeholder inputs.
        (node_out,
         edge_out), _ = end_to_end_stack.ALL_COMPONENTS[component].init(
             jax.random.PRNGKey(0),
             graph_context=end_to_end_stack.SharedGraphContext(
                 bundle=graph_bundle.zeros_like_padded_example(
                     graph_bundle.PaddingConfig(
                         static_max_metadata=automaton_builder.
                         EncodedGraphMetadata(num_nodes=16,
                                              num_input_tagged_nodes=32),
                         max_initial_transitions=11,
                         max_in_tagged_transitions=12,
                         max_edges=13)),
                 static_metadata=automaton_builder.EncodedGraphMetadata(
                     num_nodes=16, num_input_tagged_nodes=32),
                 edge_types_to_indices={"foo": 0},
                 builder=automaton_builder.AutomatonBuilder({
                     graph_types.NodeType("node"):
                     graph_types.NodeSchema(
                         in_edges=[graph_types.InEdgeType("in")],
                         # Out-edges are OutEdgeType values (previously
                         # mistakenly built with InEdgeType).
                         out_edges=[graph_types.OutEdgeType("out")])
                 }),
                 edges_are_embedded=embed_edges),
             node_embeddings=jnp.zeros((16, NODE_DIM)),
             edge_embeddings=jnp.zeros((16, 16, EDGE_DIM)))

        self.assertEqual(node_out.shape, (16, expected_dims["node"]))
        self.assertEqual(edge_out.shape, (16, 16, expected_dims["edge"]))
Esempio n. 12
0
class SchemaUtilTest(parameterized.TestCase):
    """Tests for the graph/schema helpers in `schema_util`."""

    def _build_test_schema(self):
        """Returns the schema shared by the schema-conformance tests.

        Node type "a" has in-edges ai_0, ai_1 and out-edge ao_0; node type "b"
        has in-edge bi_0 and out-edges bo_0, bo_1. A fresh dict is built on
        every call so tests never share mutable state through it.
        """
        return {
            graph_types.NodeType("a"):
            graph_types.NodeSchema(
                in_edges=[
                    graph_types.InEdgeType("ai_0"),
                    graph_types.InEdgeType("ai_1")
                ],
                out_edges=[graph_types.OutEdgeType("ao_0")]),
            graph_types.NodeType("b"):
            graph_types.NodeSchema(
                in_edges=[graph_types.InEdgeType("bi_0")],
                out_edges=[
                    graph_types.OutEdgeType("bo_0"),
                    graph_types.OutEdgeType("bo_1")
                ]),
        }

    def test_conforms_to_schema(self):
        """A graph matching the test schema passes assert_conforms_to_schema."""
        test_schema = self._build_test_schema()

        # Valid graph
        test_graph = {
            graph_types.NodeId("A"):
            graph_types.GraphNode(
                graph_types.NodeType("a"), {
                    graph_types.OutEdgeType("ao_0"): [
                        graph_types.InputTaggedNode(
                            graph_types.NodeId("B"),
                            graph_types.InEdgeType("bi_0")),
                        graph_types.InputTaggedNode(
                            graph_types.NodeId("A"),
                            graph_types.InEdgeType("ai_1"))
                    ]
                }),
            graph_types.NodeId("B"):
            graph_types.GraphNode(
                graph_types.NodeType("b"), {
                    graph_types.OutEdgeType("bo_0"): [
                        graph_types.InputTaggedNode(
                            graph_types.NodeId("A"),
                            graph_types.InEdgeType("ai_1"))
                    ],
                    graph_types.OutEdgeType("bo_1"): [
                        graph_types.InputTaggedNode(
                            graph_types.NodeId("B"),
                            graph_types.InEdgeType("bi_0"))
                    ]
                })
        }
        schema_util.assert_conforms_to_schema(test_graph, test_schema)

    @parameterized.named_parameters(
        {
            "testcase_name": "missing_node",
            "graph": {
                graph_types.NodeId("A"):
                graph_types.GraphNode(
                    graph_types.NodeType("a"), {
                        graph_types.OutEdgeType("ao_0"): [
                            graph_types.InputTaggedNode(
                                graph_types.NodeId("B"),
                                graph_types.InEdgeType("bi_0"))
                        ]
                    })
            },
            "expected_error": "Node A has connection to missing node B"
        }, {
            "testcase_name": "bad_node_type",
            "graph": {
                graph_types.NodeId("A"):
                graph_types.GraphNode(
                    graph_types.NodeType("z"), {
                        graph_types.OutEdgeType("ao_0"): [
                            graph_types.InputTaggedNode(
                                graph_types.NodeId("A"),
                                graph_types.InEdgeType("ai_0"))
                        ]
                    })
            },
            "expected_error": "Node A's type z not in schema"
        }, {
            "testcase_name": "missing_out_edge",
            "graph": {
                graph_types.NodeId("A"):
                graph_types.GraphNode(graph_types.NodeType("a"),
                                      {graph_types.OutEdgeType("ao_0"): []})
            },
            "expected_error": "Node A missing out edge of type ao_0"
        }, {
            "testcase_name": "bad_out_edge_type",
            "graph": {
                graph_types.NodeId("A"):
                graph_types.GraphNode(
                    graph_types.NodeType("a"), {
                        graph_types.OutEdgeType("ao_0"): [
                            graph_types.InputTaggedNode(
                                graph_types.NodeId("A"),
                                graph_types.InEdgeType("ai_0"))
                        ],
                        "foo": [
                            graph_types.InputTaggedNode(
                                graph_types.NodeId("A"),
                                graph_types.InEdgeType("ai_0"))
                        ]
                    })
            },
            "expected_error": "Node A has out-edges of invalid type foo"
        }, {
            "testcase_name": "bad_in_edge_type",
            "graph": {
                graph_types.NodeId("A"):
                graph_types.GraphNode(
                    graph_types.NodeType("a"), {
                        graph_types.OutEdgeType("ao_0"): [
                            graph_types.InputTaggedNode(
                                graph_types.NodeId("A"),
                                graph_types.InEdgeType("bar"))
                        ],
                    })
            },
            "expected_error": "Node A has in-edges of invalid type bar"
        })
    def test_does_not_conform_to_schema(self, graph, expected_error):
        """Each malformed graph is rejected with a matching ValueError.

        Args:
          graph: Graph violating the test schema (from the parameters above).
          expected_error: Regex the raised ValueError's message must match.
        """
        test_schema = self._build_test_schema()
        # pylint: disable=g-error-prone-assert-raises
        with self.assertRaisesRegex(ValueError, expected_error):
            schema_util.assert_conforms_to_schema(graph, test_schema)
        # pylint: enable=g-error-prone-assert-raises

    def test_all_input_tagged_nodes(self):
        """Checks all_input_tagged_nodes against an explicit expected list.

        (A, ai_1) is targeted by three different edges but appears only once
        in the expected result.
        """
        # (note: python3 dicts maintain order, so B2 comes before B1)
        graph = {
            graph_types.NodeId("A"):
            graph_types.GraphNode(
                graph_types.NodeType("a"), {
                    graph_types.OutEdgeType("ao_0"): [
                        graph_types.InputTaggedNode(
                            graph_types.NodeId("B1"),
                            graph_types.InEdgeType("bi_1")),
                        graph_types.InputTaggedNode(
                            graph_types.NodeId("A"),
                            graph_types.InEdgeType("ai_1"))
                    ]
                }),
            graph_types.NodeId("B2"):
            graph_types.GraphNode(
                graph_types.NodeType("b"), {
                    graph_types.OutEdgeType("bo_0"): [
                        graph_types.InputTaggedNode(
                            graph_types.NodeId("A"),
                            graph_types.InEdgeType("ai_1"))
                    ],
                    graph_types.OutEdgeType("bo_1"): [
                        graph_types.InputTaggedNode(
                            graph_types.NodeId("B1"),
                            graph_types.InEdgeType("bi_0"))
                    ]
                }),
            graph_types.NodeId("B1"):
            graph_types.GraphNode(
                graph_types.NodeType("b"), {
                    graph_types.OutEdgeType("bo_0"): [
                        graph_types.InputTaggedNode(
                            graph_types.NodeId("A"),
                            graph_types.InEdgeType("ai_1"))
                    ],
                    graph_types.OutEdgeType("bo_1"): [
                        graph_types.InputTaggedNode(
                            graph_types.NodeId("B2"),
                            graph_types.InEdgeType("bi_0"))
                    ]
                }),
        }
        expected_itns = [
            graph_types.InputTaggedNode(graph_types.NodeId("A"),
                                        graph_types.InEdgeType("ai_1")),
            graph_types.InputTaggedNode(graph_types.NodeId("B2"),
                                        graph_types.InEdgeType("bi_0")),
            graph_types.InputTaggedNode(graph_types.NodeId("B1"),
                                        graph_types.InEdgeType("bi_0")),
            graph_types.InputTaggedNode(graph_types.NodeId("B1"),
                                        graph_types.InEdgeType("bi_1")),
        ]

        actual_itns = schema_util.all_input_tagged_nodes(graph)
        self.assertEqual(actual_itns, expected_itns)
Esempio n. 13
0
    def test_encode_maze(self):
        """Tests that an encoded maze is correct and matches the schema."""
        # 3x5 grid with a single blocked cell at row 1, column 1.
        maze = np.array([
            [1, 1, 1, 1, 1],
            [1, 0, 1, 1, 1],
            [1, 1, 1, 1, 1],
        ]).astype(bool)

        encoded_graph, coordinates = maze_schema.encode_maze(maze)

        # Every cell except the blocked one, in row-major order.
        expected_coords = [(r, c) for r in range(3) for c in range(5)
                           if (r, c) != (1, 1)]
        self.assertEqual(coordinates, expected_coords)

        node_id = graph_types.NodeId
        out_edge = graph_types.OutEdgeType

        def neighbor(target, in_name):
            return graph_types.InputTaggedNode(node_id(target),
                                               graph_types.InEdgeType(in_name))

        # Top-left corner: only right and down neighbors.
        self.assertEqual(
            encoded_graph[node_id("cell_0_0")],
            graph_types.GraphNode(
                graph_types.NodeType("cell_xRxD"), {
                    out_edge("R_out"): [neighbor("cell_0_1", "L_in")],
                    out_edge("D_out"): [neighbor("cell_1_0", "U_in")],
                }))

        # Right end of the middle row: left, up and down neighbors.
        self.assertEqual(
            encoded_graph[node_id("cell_1_4")],
            graph_types.GraphNode(
                graph_types.NodeType("cell_LxUD"), {
                    out_edge("L_out"): [neighbor("cell_1_3", "R_in")],
                    out_edge("U_out"): [neighbor("cell_0_4", "D_in")],
                    out_edge("D_out"): [neighbor("cell_2_4", "U_in")],
                }))

        # Check schema validity.
        schema_util.assert_conforms_to_schema(encoded_graph,
                                              maze_schema.build_maze_schema(2))
Esempio n. 14
0
    def build_loop_graph(self):
        """Helper method to build this complex graph, to test graph encodings.

      ┌───────<──┐  ┌───────<─────────<─────┐
      │          │  │                       │
      │    [ao_0]│  ↓[ai_0]                 │
      │          (a0)                       │
      │    [ai_1]↑  │[ao_1]                 │
      │          │  │                       │
      │    [ao_0]│  ↓[ai_0]                 │
      ↓          (a1)                       ↑
      │    [ai_1]↑  │[ao_1]                 │
      │          │  │                       │
      │    [bo_0]│  ↓[bi_0]                 │
      │          ╭──╮───────>[bo_2]────┐    │
      │          │b0│───────>[bo_2]──┐ │    │
      │          │  │<──[bi_0]─────<─┘ │    │
      │          ╰──╯<──[bi_0]─────<─┐ │    │
      │    [bi_0]↑  │[bo_1]          │ │    ↑
      │          │  │                │ │    │
      │    [bo_0]│  ↓[bi_0]          │ │    │
      │          ╭──╮───────>[bo_2]──┘ │    │
      │          │b1│───────>[bo_2]──┐ │    │
      ↓          │  │<──[bi_0]─────<─┘ │    │
      │          ╰──╯<──[bi_0]─────<───┘    │
      │    [bi_0]↑  │[bo_1]                 │
      │          │  │                       │
      └───────>──┘  └───────>─────────>─────┘

    Returns:
      Tuple (schema, graph) for the above structure.
    """
        a = graph_types.NodeType("a")
        b = graph_types.NodeType("b")

        # Node type "b" has a single in-edge type, bi_0.
        ai_0 = graph_types.InEdgeType("ai_0")
        ai_1 = graph_types.InEdgeType("ai_1")
        bi_0 = graph_types.InEdgeType("bi_0")

        ao_0 = graph_types.OutEdgeType("ao_0")
        ao_1 = graph_types.OutEdgeType("ao_1")
        bo_0 = graph_types.OutEdgeType("bo_0")
        bo_1 = graph_types.OutEdgeType("bo_1")
        bo_2 = graph_types.OutEdgeType("bo_2")

        a0 = graph_types.NodeId("a0")
        a1 = graph_types.NodeId("a1")
        b0 = graph_types.NodeId("b0")
        b1 = graph_types.NodeId("b1")

        schema = {
            a:
            graph_types.NodeSchema(in_edges=[ai_0, ai_1],
                                   out_edges=[ao_0, ao_1]),
            b:
            graph_types.NodeSchema(in_edges=[bi_0],
                                   out_edges=[bo_0, bo_1, bo_2]),
        }
        test_graph = {
            a0:
            graph_types.GraphNode(
                a, {
                    ao_0: [graph_types.InputTaggedNode(b1, bi_0)],
                    ao_1: [graph_types.InputTaggedNode(a1, ai_0)]
                }),
            a1:
            graph_types.GraphNode(
                a, {
                    ao_0: [graph_types.InputTaggedNode(a0, ai_1)],
                    ao_1: [graph_types.InputTaggedNode(b0, bi_0)]
                }),
            b0:
            graph_types.GraphNode(
                b, {
                    bo_0: [graph_types.InputTaggedNode(a1, ai_1)],
                    bo_1: [graph_types.InputTaggedNode(b1, bi_0)],
                    # bo_2 fans out to both b nodes, including a self-loop.
                    bo_2: [
                        graph_types.InputTaggedNode(b0, bi_0),
                        graph_types.InputTaggedNode(b1, bi_0)
                    ]
                }),
            b1:
            graph_types.GraphNode(
                b, {
                    bo_0: [graph_types.InputTaggedNode(b0, bi_0)],
                    bo_1: [graph_types.InputTaggedNode(a0, ai_0)],
                    bo_2: [
                        graph_types.InputTaggedNode(b0, bi_0),
                        graph_types.InputTaggedNode(b1, bi_0)
                    ]
                }),
        }
        return schema, test_graph
# Example no. 15
    def test_constructor_actions_nodes_routes(self):
        """Checks what an AutomatonBuilder derives from the simple schema.

        The builder is created with backtracking disabled and failure enabled;
        its special actions, node types, in-route types and in/out-route types
        are each compared against a hand-written expected set.
        """
        builder = automaton_builder.AutomatonBuilder(
            self.build_simple_schema(), with_backtrack=False, with_fail=True)

        node_a = graph_types.NodeType("a")
        node_b = graph_types.NodeType("b")
        initial = automaton_builder.SOURCE_INITIAL
        ai_0, ai_1, bi_0 = (
            graph_types.InEdgeType(name) for name in ("ai_0", "ai_1", "bi_0"))
        ao_0, bo_0, bo_1 = (
            graph_types.OutEdgeType(name) for name in ("ao_0", "bo_0", "bo_1"))

        actions = automaton_builder.SpecialActions
        self.assertEqual(set(builder.special_actions),
                         {actions.FINISH, actions.FAIL})

        self.assertEqual(set(builder.node_types), {node_a, node_b})

        route = automaton_builder.InRouteType
        expected_in_routes = {
            route(node_a, initial),
            route(node_a, ai_0),
            route(node_a, ai_1),
            route(node_b, initial),
            route(node_b, bi_0),
        }
        self.assertEqual(set(builder.in_route_types), expected_in_routes)

        # Expected in/out routes: for each node type, every source listed
        # below paired with every out-edge type listed for that node type.
        in_out = automaton_builder.InOutRouteType
        expected_in_out_routes = (
            {in_out(node_a, src, ao_0) for src in (initial, ai_0, ai_1)}
            | {in_out(node_b, src, dst)
               for src in (initial, bi_0)
               for dst in (bo_0, bo_1)})
        self.assertEqual(set(builder.in_out_route_types),
                         expected_in_out_routes)
    def test_automaton_layer_abstract_init(self, shared, variant_weights,
                                           use_gate, estimator_type, **kwargs):
        """Checks that the automaton layer can be initialized inside a model.

        A one-node-type schema and an all-zero encoded graph, sized for 32
        nodes and 64 input-tagged nodes (matching the metadata passed to the
        layer), are fed to a `flax.deprecated.nn` model. The encoded graph and
        variant weights are made to nominally depend on the model input before
        the layer is built.

        Args:
          shared: Passed to the layer as `share_states_across_edges`.
          variant_weights: Zero-argument callable whose result is passed to the
            layer as `variant_weights`.
          use_gate: Passed to the layer as `use_gate_parameterization`.
          estimator_type: Passed to the layer as `estimator_type`; when it is
            "one_sample", the per-edge/per-node log-prob side output is also
            checked.
          **kwargs: Extra keyword arguments forwarded to the layer.
        """
        # Create a simple schema and empty encoded graph.
        schema = {
            graph_types.NodeType("a"):
            graph_types.NodeSchema(in_edges=[graph_types.InEdgeType("ai_0")],
                                   out_edges=[graph_types.OutEdgeType("ao_0")]),
        }
        builder = automaton_builder.AutomatonBuilder(schema)

        def empty_sparse_operator():
            """Returns a fresh all-zero SparseCoordOperator with 128 entries."""
            return sparse_operator.SparseCoordOperator(
                input_indices=jnp.zeros((128, 1), dtype=jnp.int32),
                output_indices=jnp.zeros((128, 2), dtype=jnp.int32),
                values=jnp.zeros((128, ), dtype=jnp.float32),
            )

        encoded_graph = automaton_builder.EncodedGraph(
            initial_to_in_tagged=empty_sparse_operator(),
            initial_to_special=jnp.zeros((32, ), dtype=jnp.int32),
            in_tagged_to_in_tagged=empty_sparse_operator(),
            in_tagged_to_special=jnp.zeros((64, ), dtype=jnp.int32),
            in_tagged_node_indices=jnp.zeros((64, ), dtype=jnp.int32),
        )

        # Make sure the layer can be initialized and applied within a model.
        # This model is fairly simple; it just pretends that the encoded graph
        # and variants depend on the input.
        class TestModel(flax.deprecated.nn.Module):

            def apply(self, dummy_ignored):

                def tie_to_input(tree):
                    # NOTE(review): newer JAX releases made `jax.lax.tie_in` a
                    # no-op and later removed it; confirm the pinned JAX
                    # version still provides it.
                    return jax.tree_util.tree_map(
                        lambda leaf: jax.lax.tie_in(dummy_ignored, leaf), tree)

                return automaton_layer.FiniteStateGraphAutomaton(
                    encoded_graph=tie_to_input(encoded_graph),
                    variant_weights=tie_to_input(variant_weights()),
                    dynamic_metadata=automaton_builder.EncodedGraphMetadata(
                        num_nodes=32, num_input_tagged_nodes=64),
                    static_metadata=automaton_builder.EncodedGraphMetadata(
                        num_nodes=32, num_input_tagged_nodes=64),
                    builder=builder,
                    num_out_edges=3,
                    num_intermediate_states=4,
                    share_states_across_edges=shared,
                    use_gate_parameterization=use_gate,
                    estimator_type=estimator_type,
                    name="the_layer",
                    **kwargs)

        with side_outputs.collect_side_outputs() as side:
            with flax.deprecated.nn.stochastic(jax.random.PRNGKey(0)):
                # NOTE: `init` is used rather than `init_by_shape`; the original
                # author observed that `init_by_shape` seemed to break the
                # layer's custom_vjp.
                abstract_out, unused_params = TestModel.init(
                    jax.random.PRNGKey(1234), jnp.zeros((), jnp.float32))

        del unused_params
        self.assertEqual(abstract_out.shape, (3, 32, 32))

        if estimator_type == "one_sample":
            log_prob_key = "/the_layer/one_sample_log_prob_per_edge_per_node"
            self.assertIn(log_prob_key, side)
            self.assertEqual(side[log_prob_key].shape, (3, 32))
# Example no. 17
def build_ast_graph_schema(ast_spec):
    """Builds a graph schema for an AST.

    This logic is described in Appendix B.1 of the paper.

    Each AST node becomes a new node type, whose edges are determined by its
    fields (along with whether it has a parent). Additionally, we generate
    helper node types for each grammar category that appears as a sequence
    (e.g. one for sequences of statements, another for sequences of
    expressions).

    Nodes with a parent get two edge types for that parent:
    - (in) "parent_in": the edge from the parent to this node
    - (out) "parent_out": the edge from this node to its parent

    ONE_CHILD fields become two edge types:
    - (in) "{field_name}_in": the edge from the child to this node
    - (out) "{field_name}_out": the edge from this node to the child

    OPTIONAL_CHILD fields become three edge types (to account for the
    invariant where every outgoing edge type must have at least one edge):
    - (in) "{field_name}_in": the edge from the child to this node (if it
        exists)
    - (in) "{field_name}_missing": a loopback edge from this node to itself,
        used as a sentinel value for when the child is missing
    - (out) "{field_name}_out": if there is a child of this type, this edge
        connects to that child; otherwise, it is a loopback edge to this node
        with incoming type "{field_name}_missing".

    NONEMPTY_SEQUENCE or SEQUENCE fields become 4 or 5 edge types:
    - (in) "{field_name}_in": used for EVERY edge from a child-item-helper to
        this node
    - (in) "{field_name}_missing": loopback edges for missing outgoing edges
        (only for SEQUENCE)
    - (out) "{field_name}_out_all": edges to every child-item-helper, or a
        loop back to "{field_name}_missing" if the sequence is empty
    - (out) "{field_name}_out_first": edges to the first child-item-helper,
        or a loop back to "{field_name}_missing" if the sequence is empty
    - (out) "{field_name}_out_last": edges to the last child-item-helper, or
        a loop back to "{field_name}_missing" if the sequence is empty

    NO_CHILDREN and IGNORE fields contribute no edge types. (NO_CHILDREN
    fields are expected to be empty ([] or None) in a concrete AST; this
    function never inspects field values, so that is not checked here.)

    Child item helpers are added between any AST node with a sequence of
    children and the children in that sequence. These helpers have the
    following edge types:
    - "item_{in/out}": between the helper and the child
    - "parent_{in/out}": between the helper and the parent
    - "next_{in/out/missing}": between this helper and the next one in the
        sequence, or a loop from "out" to "missing" if this is the last
        element
    - "prev_{in/out/missing}": between this helper and the previous one in
        the sequence, or a loop from "out" to "missing" if this is the first
        element
    There is one helper type, named "{category}-seq-helper", for each unique
    value in `sequence_item_types` across all node types.

    Within each AST node schema, edge types are listed in this order: the
    parent edge (if any), then each field's edges in the iteration order of
    `node_spec.fields`.

    Args:
      ast_spec: Mapping from AST node type name to its spec. Each spec must
        provide `has_parent`, `fields` (field name -> FieldType) and, for
        sequence fields, `sequence_item_types` (field name -> category).

    Returns:
      GraphSchema (a dict from NodeType to NodeSchema) with one entry per AST
      node type, in the iteration order of `ast_spec`, followed by one helper
      entry per sequence category, in sorted category order.

    Raises:
      ValueError: if a field has an unsupported FieldType, or if a generated
        helper type name collides with an AST node type name.
    """
    result = {}
    seen_sequence_categories = set()

    # Build node schemas for each AST node.
    for node_type, node_spec in ast_spec.items():
        node_schema = graph_types.NodeSchema([], [])

        # Add possible edge types.
        if node_spec.has_parent:
            node_schema.in_edges.append(graph_types.InEdgeType("parent_in"))
            node_schema.out_edges.append(graph_types.OutEdgeType("parent_out"))

        for field, field_type in node_spec.fields.items():
            if field_type in {FieldType.ONE_CHILD, FieldType.OPTIONAL_CHILD}:
                node_schema.in_edges.append(
                    graph_types.InEdgeType(f"{field}_in"))
                node_schema.out_edges.append(
                    graph_types.OutEdgeType(f"{field}_out"))
                if field_type == FieldType.OPTIONAL_CHILD:
                    # Sentinel target for the loopback used when the child
                    # is absent.
                    node_schema.in_edges.append(
                        graph_types.InEdgeType(f"{field}_missing"))
            elif field_type in {
                    FieldType.SEQUENCE, FieldType.NONEMPTY_SEQUENCE
            }:
                seen_sequence_categories.add(
                    node_spec.sequence_item_types[field])
                node_schema.in_edges.append(
                    graph_types.InEdgeType(f"{field}_in"))
                node_schema.out_edges.append(
                    graph_types.OutEdgeType(f"{field}_out_all"))
                node_schema.out_edges.append(
                    graph_types.OutEdgeType(f"{field}_out_first"))
                node_schema.out_edges.append(
                    graph_types.OutEdgeType(f"{field}_out_last"))
                if field_type == FieldType.SEQUENCE:
                    # Only possibly-empty sequences need a sentinel target.
                    node_schema.in_edges.append(
                        graph_types.InEdgeType(f"{field}_missing"))
            elif field_type in {FieldType.NO_CHILDREN, FieldType.IGNORE}:
                # No edges for these fields.
                pass
            else:
                raise ValueError(f"Unexpected field type {field_type}")

        result[graph_types.NodeType(node_type)] = node_schema

    # Build node schemas for each category helper. Sorting makes the helper
    # order deterministic, since set iteration order is not.
    for category in sorted(seen_sequence_categories):
        helper_type = graph_types.NodeType(f"{category}-seq-helper")
        # Explicit check rather than `assert`, which is stripped under -O and
        # would let a helper silently overwrite an AST node schema.
        if helper_type in result:
            raise ValueError(
                f"Sequence helper type {helper_type!r} collides with an AST "
                "node type of the same name")
        node_schema = graph_types.NodeSchema(
            in_edges=[
                graph_types.InEdgeType("parent_in"),
                graph_types.InEdgeType("item_in"),
                graph_types.InEdgeType("next_in"),
                graph_types.InEdgeType("next_missing"),
                graph_types.InEdgeType("prev_in"),
                graph_types.InEdgeType("prev_missing")
            ],
            out_edges=[
                graph_types.OutEdgeType("parent_out"),
                graph_types.OutEdgeType("item_out"),
                graph_types.OutEdgeType("next_out"),
                graph_types.OutEdgeType("prev_out")
            ])
        result[helper_type] = node_schema

    return result