コード例 #1
0
 def build_doubly_linked_list_graph(self, length):
     """Builds a circular doubly-linked-list graph together with its schema.

     Every node has type "node". Node i has a "next_out" edge to node
     (i + 1) % length, entering through that node's "prev_in" edge, and a
     "prev_out" edge to node (i - 1) % length, entering through that node's
     "next_in" edge. Because the indices wrap around, the last node links back
     to the first one.

     Args:
       length: Number of nodes. Node ids are the strings "0" .. str(length - 1).

     Returns:
       Tuple (schema, graph).
     """
     node_type = graph_types.NodeType("node")
     next_in = graph_types.InEdgeType("next_in")
     prev_in = graph_types.InEdgeType("prev_in")
     next_out = graph_types.OutEdgeType("next_out")
     prev_out = graph_types.OutEdgeType("prev_out")

     schema = {
         node_type:
             graph_types.NodeSchema(
                 in_edges=[next_in, prev_in],
                 out_edges=[next_out, prev_out]),
     }

     def link(target_index, in_edge):
         # One-element edge list targeting `in_edge` of node `target_index`,
         # wrapped into [0, length).
         return [
             graph_types.InputTaggedNode(
                 node_id=graph_types.NodeId(str(target_index % length)),
                 in_edge=in_edge)
         ]

     graph = {
         graph_types.NodeId(str(i)):
             graph_types.GraphNode(node_type, {
                 next_out: link(i + 1, prev_in),
                 prev_out: link(i - 1, next_in),
             }) for i in range(length)
     }
     return schema, graph
コード例 #2
0
  def test_routing_gates_to_probs(self):
    builder = automaton_builder.AutomatonBuilder(self.build_simple_schema())

    def move_index(node_type, in_edge, out_edge):
      """Index of the (in_edge -> out_edge) route through `node_type`."""
      route = automaton_builder.InOutRouteType(
          graph_types.NodeType(node_type), graph_types.InEdgeType(in_edge),
          graph_types.OutEdgeType(out_edge))
      return builder.in_out_route_type_to_index[route]

    def special_index(node_type, in_edge):
      """Index of the route entering `node_type` through `in_edge`."""
      route = automaton_builder.InRouteType(
          graph_types.NodeType(node_type), graph_types.InEdgeType(in_edge))
      return builder.in_route_type_to_index[route]

    # Every gate starts at 0.5. Shapes:
    #   move:              [variants, in_out_routes, fsm_states, fsm_states]
    #   accept, backtrack: [variants, in_routes, fsm_states]
    num_in_out_routes = len(builder.in_out_route_types)
    num_in_routes = len(builder.in_route_types)
    move_gates = np.full([3, num_in_out_routes, 2, 2], 0.5)
    accept_gates = np.full([3, num_in_routes, 2], 0.5)
    backtrack_gates = np.full([3, num_in_routes, 2], 0.5)

    # Variant 0, node "b", state 0: move + accept mass is
    # .2 + .3 + .4 + .5 + .6 = 2.0, i.e. more than 1.
    over_move1 = move_index("b", "bi_0", "bo_0")
    over_move2 = move_index("b", "bi_0", "bo_1")
    over_special = special_index("b", "bi_0")
    move_gates[0, over_move1, 0, :] = [.2, .3]
    move_gates[0, over_move2, 0, :] = [.4, .5]
    accept_gates[0, over_special, 0] = .6
    backtrack_gates[0, over_special, 0] = .3

    # Variant 2, node "a", state 1: move + accept mass is .1 + .2 + .3 = .6,
    # i.e. less than 1.
    under_move = move_index("a", "ai_0", "ao_0")
    under_special = special_index("a", "ai_0")
    move_gates[2, under_move, 1, :] = [.1, .2]
    accept_gates[2, under_special, 1] = .3
    backtrack_gates[2, under_special, 1] = .75

    logit = jax.scipy.special.logit
    routing_probs = builder.routing_gates_to_probs(
        automaton_builder.RoutingGateParams(
            move_gates=logit(move_gates),
            accept_gates=logit(accept_gates),
            backtrack_gates=logit(backtrack_gates)))

    # Over-full distribution: every entry is divided by the total mass (2.0)
    # and nothing is left over for the other special actions.
    np.testing.assert_allclose(routing_probs.move[0, over_move1, 0, :],
                               np.array([.2, .3]) / 2.0)
    np.testing.assert_allclose(routing_probs.move[0, over_move2, 0, :],
                               np.array([.4, .5]) / 2.0)
    np.testing.assert_allclose(routing_probs.special[0, over_special, 0, :],
                               np.array([.6, 0, 0]) / 2.0)

    # Under-full distribution: entries are kept as-is and the missing .4 is
    # split into .75 * .4 = .3 (backtrack gate) and the remaining .1.
    np.testing.assert_allclose(routing_probs.move[2, under_move, 1, :],
                               np.array([.1, .2]))
    np.testing.assert_allclose(routing_probs.special[2, under_special, 1, :],
                               np.array([.3, .3, .1]))
コード例 #3
0
    def test_all_input_tagged_nodes(self):
        def itn(node_id, in_edge):
            # Shorthand for an input-tagged node built from plain strings.
            return graph_types.InputTaggedNode(
                graph_types.NodeId(node_id), graph_types.InEdgeType(in_edge))

        out_edge = graph_types.OutEdgeType

        # Dicts preserve insertion order, and the nodes are inserted as
        # A, B2, B1, so B2 is expected before B1 below.
        graph = {}
        graph[graph_types.NodeId("A")] = graph_types.GraphNode(
            graph_types.NodeType("a"),
            {out_edge("ao_0"): [itn("B1", "bi_1"), itn("A", "ai_1")]})
        graph[graph_types.NodeId("B2")] = graph_types.GraphNode(
            graph_types.NodeType("b"), {
                out_edge("bo_0"): [itn("A", "ai_1")],
                out_edge("bo_1"): [itn("B1", "bi_0")],
            })
        graph[graph_types.NodeId("B1")] = graph_types.GraphNode(
            graph_types.NodeType("b"), {
                out_edge("bo_0"): [itn("A", "ai_1")],
                out_edge("bo_1"): [itn("B2", "bi_0")],
            })

        expected_itns = [
            itn("A", "ai_1"),
            itn("B2", "bi_0"),
            itn("B1", "bi_0"),
            itn("B1", "bi_1"),
        ]
        self.assertEqual(schema_util.all_input_tagged_nodes(graph),
                         expected_itns)
コード例 #4
0
 def build_simple_schema(self):
     """Returns a small schema with two node types.

     Node type "a" has in-edges ai_0, ai_1 and out-edge ao_0; node type "b"
     has in-edge bi_0 and out-edges bo_0, bo_1.
     """
     in_edge = graph_types.InEdgeType
     out_edge = graph_types.OutEdgeType
     schema = {}
     schema[graph_types.NodeType("a")] = graph_types.NodeSchema(
         in_edges=[in_edge("ai_0"), in_edge("ai_1")],
         out_edges=[out_edge("ao_0")])
     schema[graph_types.NodeType("b")] = graph_types.NodeSchema(
         in_edges=[in_edge("bi_0")],
         out_edges=[out_edge("bo_0"), out_edge("bo_1")])
     return schema
コード例 #5
0
 def connect(from_id, out_type, in_type, to_id):
     """Adds a directed edge from node `from_id` to node `to_id`.

     The edge leaves `from_id` through out-edge type `out_type` and enters
     `to_id` through in-edge type `in_type`. It is appended to the out-edge
     list for `out_type`, which is created first if missing. `result` is not a
     parameter; it comes from the enclosing scope (presumably a dict mapping
     node ids to GraphNode objects -- confirm against the enclosing function).
     """
     edges_for_type = result[from_id].out_edges.setdefault(
         graph_types.OutEdgeType(out_type), [])
     edges_for_type.append(
         graph_types.InputTaggedNode(
             node_id=to_id, in_edge=graph_types.InEdgeType(in_type)))
コード例 #6
0
 def test_does_not_conform_to_schema(self, graph, expected_error):
     in_edge = graph_types.InEdgeType
     out_edge = graph_types.OutEdgeType
     test_schema = {
         graph_types.NodeType("a"):
             graph_types.NodeSchema(
                 in_edges=[in_edge("ai_0"), in_edge("ai_1")],
                 out_edges=[out_edge("ao_0")]),
         graph_types.NodeType("b"):
             graph_types.NodeSchema(
                 in_edges=[in_edge("bi_0")],
                 out_edges=[out_edge("bo_0"), out_edge("bo_1")]),
     }
     # The invalid graph must be rejected with a ValueError whose message
     # matches `expected_error` (used as a regex). The lint is presumably
     # triggered because the call under test is itself named assert_*.
     # pylint: disable=g-error-prone-assert-raises
     with self.assertRaisesRegex(ValueError, expected_error):
         schema_util.assert_conforms_to_schema(graph, test_schema)
     # pylint: enable=g-error-prone-assert-raises
コード例 #7
0
    def test_constructor_actions_nodes_routes(self):
        builder = automaton_builder.AutomatonBuilder(
            self.build_simple_schema(), with_backtrack=False, with_fail=True)

        node_a = graph_types.NodeType("a")
        node_b = graph_types.NodeType("b")
        initial = automaton_builder.SOURCE_INITIAL
        # Ways to enter each node type: the initial source plus each in-edge.
        entry_sources = {
            node_a: [
                initial,
                graph_types.InEdgeType("ai_0"),
                graph_types.InEdgeType("ai_1"),
            ],
            node_b: [initial, graph_types.InEdgeType("bi_0")],
        }
        # Out-edge types of each node type.
        exits = {
            node_a: [graph_types.OutEdgeType("ao_0")],
            node_b: [
                graph_types.OutEdgeType("bo_0"),
                graph_types.OutEdgeType("bo_1"),
            ],
        }

        # with_backtrack=False, with_fail=True: only FINISH and FAIL expected.
        self.assertEqual(
            set(builder.special_actions), {
                automaton_builder.SpecialActions.FINISH,
                automaton_builder.SpecialActions.FAIL,
            })
        self.assertEqual(set(builder.node_types), {node_a, node_b})

        # One in-route per (node type, entry source) pair.
        expected_in_routes = {
            automaton_builder.InRouteType(node_type, source)
            for node_type, sources in entry_sources.items()
            for source in sources
        }
        self.assertEqual(set(builder.in_route_types), expected_in_routes)

        # One in-out route per (node type, entry source, out-edge) triple.
        expected_in_out_routes = {
            automaton_builder.InOutRouteType(node_type, source, out_edge)
            for node_type, sources in entry_sources.items()
            for source in sources
            for out_edge in exits[node_type]
        }
        self.assertEqual(set(builder.in_out_route_types),
                         expected_in_out_routes)
コード例 #8
0
def encode_maze(maze):
    """Encode a boolean mask array as a graph.

    We assume a [row, col]-indexed coordinate system, oriented so that going
    right corresponds to changing indices by (0, +1), and going down
    corresponds to changing indices by (+1, 0).

    Every True cell (i, j) becomes a node with id "cell_{i}_{j}". Its node
    type is "cell_" followed by one character per direction, in the order
    L, R, U, D: the direction letter if the neighboring cell in that direction
    is inside the array and True, otherwise "x" (e.g. "cell_LxxD"). For each
    such neighbor the node gets an out-edge "{dir}_out" whose single target is
    the neighbor's "{reverse_dir}_in" in-edge.

    Args:
      maze: Maze, as a 2D boolean array <bool[num_rows, num_cols]> indexed as
        maze[row, col], where True corresponds to cells that should be nodes
        in the graph.

    Returns:
      Tuple (graph, idx_to_cell): the encoded graph, and a list of the
      (row, col) coordinates of each cell in the graph, in the same row-major
      order in which the nodes were inserted into the graph.
    """
    # (direction, reverse direction, (d_row, d_col))
    shifts = [
        ("L", "R", (0, -1)),
        ("R", "L", (0, 1)),
        ("U", "D", (-1, 0)),
        ("D", "U", (1, 0)),
    ]
    num_rows, num_cols = maze.shape[0], maze.shape[1]
    graph = {}
    idx_to_cell = []
    for i in range(num_rows):
        for j in range(num_cols):
            if not maze[i, j]:
                continue
            typename = "cell_"
            out_edges = {}
            for name, in_name, (di, dj) in shifts:
                ni = i + di
                nj = j + dj
                if 0 <= ni < num_rows and 0 <= nj < num_cols and maze[ni, nj]:
                    typename += name
                    # Enter the neighbor through the edge facing back at us.
                    out_edges[graph_types.OutEdgeType(f"{name}_out")] = [
                        graph_types.InputTaggedNode(
                            graph_types.NodeId(f"cell_{ni}_{nj}"),
                            graph_types.InEdgeType(f"{in_name}_in"))
                    ]
                else:
                    typename += "x"
            graph[graph_types.NodeId(f"cell_{i}_{j}")] = graph_types.GraphNode(
                graph_types.NodeType(typename), out_edges)
            idx_to_cell.append((i, j))

    return graph, idx_to_cell
コード例 #9
0
ファイル: maze_task.py プロジェクト: tallamjr/google-research
def maze_primitive_edges(maze_graph):
    """List the primitive-action edges of an encoded maze graph.

    For every node and every direction in DIRECTION_ORDERING, emits exactly
    one edge (source, destination, action_index), where action_index is the
    direction's position in DIRECTION_ORDERING. If the node has a
    "{direction}_out" edge, the destination is that edge's target node;
    otherwise the destination is the node itself (a self-loop).

    Args:
      maze_graph: Encoded graph representing the maze, e.g. as produced by
        encode_maze.

    Returns:
      List of (source_node_id, dest_node_id, action_index) tuples, grouped by
      source node in the graph's iteration order.

    Raises:
      ValueError: If a "{direction}_out" edge does not have exactly one target.
    """
    primitives = []
    for node_id, node_info in maze_graph.items():
        for action_index, direction in enumerate(DIRECTION_ORDERING):
            out_key = graph_types.OutEdgeType(f"{direction}_out")
            if out_key in node_info.out_edges:
                # Single-element unpacking: raises ValueError if the edge has
                # zero or several targets.
                (dest,) = node_info.out_edges[out_key]
                primitives.append((node_id, dest.node_id, action_index))
            else:
                # No neighbor in this direction: the primitive is a self-loop.
                primitives.append((node_id, node_id, action_index))

    return primitives
コード例 #10
0
    def test_conforms_to_schema(self):
        def itn(node_id, in_edge):
            # Shorthand for an input-tagged node built from plain strings.
            return graph_types.InputTaggedNode(
                graph_types.NodeId(node_id), graph_types.InEdgeType(in_edge))

        in_edge = graph_types.InEdgeType
        out_edge = graph_types.OutEdgeType
        test_schema = {
            graph_types.NodeType("a"):
                graph_types.NodeSchema(
                    in_edges=[in_edge("ai_0"), in_edge("ai_1")],
                    out_edges=[out_edge("ao_0")]),
            graph_types.NodeType("b"):
                graph_types.NodeSchema(
                    in_edges=[in_edge("bi_0")],
                    out_edges=[out_edge("bo_0"), out_edge("bo_1")]),
        }

        # Valid graph: every node, out-edge and in-edge type it uses is in the
        # schema, each schema out-edge is populated, and all targets exist.
        test_graph = {
            graph_types.NodeId("A"):
                graph_types.GraphNode(
                    graph_types.NodeType("a"),
                    {out_edge("ao_0"): [itn("B", "bi_0"), itn("A", "ai_1")]}),
            graph_types.NodeId("B"):
                graph_types.GraphNode(
                    graph_types.NodeType("b"), {
                        out_edge("bo_0"): [itn("A", "ai_1")],
                        out_edge("bo_1"): [itn("B", "bi_0")],
                    }),
        }
        # Must not raise.
        schema_util.assert_conforms_to_schema(test_graph, test_schema)
コード例 #11
0
def build_maze_schema(min_neighbors):
    """Build a schema for a 2d gridworld maze.

    There is one node type for every subset of the four directions (left,
    right, up, down) containing at least `min_neighbors` directions. The type
    name spells the subset in L, R, U, D order with "x" for absent directions:
    e.g. "cell_LxUD" is a cell with neighbors to the left, up, and down, whose
    edges are "L_in", "U_in", "D_in" and "L_out", "U_out", "D_out".

    Args:
      min_neighbors: Minimum number of neighbors a grid cell will have. For
        instance, if min_neighbors=2 then the schema will only have types for
        cells that are connected to at least 2 other cells.

    Returns:
      Schema (dict from node type to NodeSchema) describing the maze.
    """
    directions = ("L", "R", "U", "D")
    maze_schema = {}
    for flags in itertools.product([False, True], repeat=len(directions)):
        present = [d for d, used in zip(directions, flags) if used]
        if len(present) < min_neighbors:
            continue
        type_name = "cell_" + "".join(
            d if used else "x" for d, used in zip(directions, flags))
        maze_schema[graph_types.NodeType(type_name)] = graph_types.NodeSchema(
            [graph_types.InEdgeType(d + "_in") for d in present],
            [graph_types.OutEdgeType(d + "_out") for d in present])
    return maze_schema
コード例 #12
0
    def test_automaton_layer_abstract_init(self, shared, variant_weights,
                                           use_gate, estimator_type, **kwargs):
        # A one-node-type schema plus an all-zero encoded graph is enough to
        # exercise initialization.
        schema = {
            graph_types.NodeType("a"):
                graph_types.NodeSchema(
                    in_edges=[graph_types.InEdgeType("ai_0")],
                    out_edges=[graph_types.OutEdgeType("ao_0")]),
        }
        builder = automaton_builder.AutomatonBuilder(schema)

        def zero_sparse_operator():
            # All-zero SparseCoordOperator with 128 entries.
            return sparse_operator.SparseCoordOperator(
                input_indices=jnp.zeros((128, 1), dtype=jnp.int32),
                output_indices=jnp.zeros((128, 2), dtype=jnp.int32),
                values=jnp.zeros((128, ), dtype=jnp.float32),
            )

        def zero_int_vector(size):
            return jnp.zeros((size, ), dtype=jnp.int32)

        encoded_graph = automaton_builder.EncodedGraph(
            initial_to_in_tagged=zero_sparse_operator(),
            initial_to_special=zero_int_vector(32),
            in_tagged_to_in_tagged=zero_sparse_operator(),
            in_tagged_to_special=zero_int_vector(64),
            in_tagged_node_indices=zero_int_vector(64),
        )

        def graph_metadata():
            return automaton_builder.EncodedGraphMetadata(
                num_nodes=32, num_input_tagged_nodes=64)

        # Minimal model wrapping the layer; it pretends that the encoded graph
        # and the variant weights depend on its (otherwise ignored) input.
        class TestModel(flax.deprecated.nn.Module):
            def apply(self, dummy_ignored):
                def tie_to_input(tree):
                    return jax.tree_map(
                        lambda leaf: jax.lax.tie_in(dummy_ignored, leaf), tree)

                return automaton_layer.FiniteStateGraphAutomaton(
                    encoded_graph=tie_to_input(encoded_graph),
                    variant_weights=tie_to_input(variant_weights()),
                    dynamic_metadata=graph_metadata(),
                    static_metadata=graph_metadata(),
                    builder=builder,
                    num_out_edges=3,
                    num_intermediate_states=4,
                    share_states_across_edges=shared,
                    use_gate_parameterization=use_gate,
                    estimator_type=estimator_type,
                    name="the_layer",
                    **kwargs)

        with side_outputs.collect_side_outputs() as side:
            with flax.deprecated.nn.stochastic(jax.random.PRNGKey(0)):
                # NOTE: init_by_shape reportedly breaks the custom_vjp, so a
                # concrete scalar input is used with init instead.
                abstract_out, unused_params = TestModel.init(
                    jax.random.PRNGKey(1234), jnp.zeros((), jnp.float32))
        del unused_params

        self.assertEqual(abstract_out.shape, (3, 32, 32))
        if estimator_type == "one_sample":
            log_prob_key = "/the_layer/one_sample_log_prob_per_edge_per_node"
            self.assertIn(log_prob_key, side)
            self.assertEqual(side[log_prob_key].shape, (3, 32))
コード例 #13
0
class SchemaUtilTest(parameterized.TestCase):
    """Tests for the schema_util module."""

    @staticmethod
    def _build_test_schema():
        """Returns the schema shared by the schema-conformance tests.

        Node type "a" has in-edges ai_0, ai_1 and out-edge ao_0; node type "b"
        has in-edge bi_0 and out-edges bo_0, bo_1. A fresh dict is returned on
        every call so tests cannot affect each other.
        """
        return {
            graph_types.NodeType("a"):
                graph_types.NodeSchema(
                    in_edges=[
                        graph_types.InEdgeType("ai_0"),
                        graph_types.InEdgeType("ai_1")
                    ],
                    out_edges=[graph_types.OutEdgeType("ao_0")]),
            graph_types.NodeType("b"):
                graph_types.NodeSchema(
                    in_edges=[graph_types.InEdgeType("bi_0")],
                    out_edges=[
                        graph_types.OutEdgeType("bo_0"),
                        graph_types.OutEdgeType("bo_1")
                    ]),
        }

    def test_conforms_to_schema(self):
        test_schema = self._build_test_schema()

        # Valid graph
        test_graph = {
            graph_types.NodeId("A"):
            graph_types.GraphNode(
                graph_types.NodeType("a"), {
                    graph_types.OutEdgeType("ao_0"): [
                        graph_types.InputTaggedNode(
                            graph_types.NodeId("B"),
                            graph_types.InEdgeType("bi_0")),
                        graph_types.InputTaggedNode(
                            graph_types.NodeId("A"),
                            graph_types.InEdgeType("ai_1"))
                    ]
                }),
            graph_types.NodeId("B"):
            graph_types.GraphNode(
                graph_types.NodeType("b"), {
                    graph_types.OutEdgeType("bo_0"): [
                        graph_types.InputTaggedNode(
                            graph_types.NodeId("A"),
                            graph_types.InEdgeType("ai_1"))
                    ],
                    graph_types.OutEdgeType("bo_1"): [
                        graph_types.InputTaggedNode(
                            graph_types.NodeId("B"),
                            graph_types.InEdgeType("bi_0"))
                    ]
                })
        }
        schema_util.assert_conforms_to_schema(test_graph, test_schema)

    @parameterized.named_parameters(
        {
            "testcase_name": "missing_node",
            "graph": {
                graph_types.NodeId("A"):
                graph_types.GraphNode(
                    graph_types.NodeType("a"), {
                        graph_types.OutEdgeType("ao_0"): [
                            graph_types.InputTaggedNode(
                                graph_types.NodeId("B"),
                                graph_types.InEdgeType("bi_0"))
                        ]
                    })
            },
            "expected_error": "Node A has connection to missing node B"
        }, {
            "testcase_name": "bad_node_type",
            "graph": {
                graph_types.NodeId("A"):
                graph_types.GraphNode(
                    graph_types.NodeType("z"), {
                        graph_types.OutEdgeType("ao_0"): [
                            graph_types.InputTaggedNode(
                                graph_types.NodeId("A"),
                                graph_types.InEdgeType("ai_0"))
                        ]
                    })
            },
            "expected_error": "Node A's type z not in schema"
        }, {
            "testcase_name": "missing_out_edge",
            "graph": {
                graph_types.NodeId("A"):
                graph_types.GraphNode(graph_types.NodeType("a"),
                                      {graph_types.OutEdgeType("ao_0"): []})
            },
            "expected_error": "Node A missing out edge of type ao_0"
        }, {
            "testcase_name": "bad_out_edge_type",
            "graph": {
                graph_types.NodeId("A"):
                graph_types.GraphNode(
                    graph_types.NodeType("a"), {
                        graph_types.OutEdgeType("ao_0"): [
                            graph_types.InputTaggedNode(
                                graph_types.NodeId("A"),
                                graph_types.InEdgeType("ai_0"))
                        ],
                        "foo": [
                            graph_types.InputTaggedNode(
                                graph_types.NodeId("A"),
                                graph_types.InEdgeType("ai_0"))
                        ]
                    })
            },
            "expected_error": "Node A has out-edges of invalid type foo"
        }, {
            "testcase_name": "bad_in_edge_type",
            "graph": {
                graph_types.NodeId("A"):
                graph_types.GraphNode(
                    graph_types.NodeType("a"), {
                        graph_types.OutEdgeType("ao_0"): [
                            graph_types.InputTaggedNode(
                                graph_types.NodeId("A"),
                                graph_types.InEdgeType("bar"))
                        ],
                    })
            },
            "expected_error": "Node A has in-edges of invalid type bar"
        })
    def test_does_not_conform_to_schema(self, graph, expected_error):
        test_schema = self._build_test_schema()
        # pylint: disable=g-error-prone-assert-raises
        with self.assertRaisesRegex(ValueError, expected_error):
            schema_util.assert_conforms_to_schema(graph, test_schema)
        # pylint: enable=g-error-prone-assert-raises

    def test_all_input_tagged_nodes(self):
        # Dicts preserve insertion order (Python 3.7+), and the nodes are
        # inserted as A, B2, B1, so B2 is expected before B1.
        graph = {
            graph_types.NodeId("A"):
            graph_types.GraphNode(
                graph_types.NodeType("a"), {
                    graph_types.OutEdgeType("ao_0"): [
                        graph_types.InputTaggedNode(
                            graph_types.NodeId("B1"),
                            graph_types.InEdgeType("bi_1")),
                        graph_types.InputTaggedNode(
                            graph_types.NodeId("A"),
                            graph_types.InEdgeType("ai_1"))
                    ]
                }),
            graph_types.NodeId("B2"):
            graph_types.GraphNode(
                graph_types.NodeType("b"), {
                    graph_types.OutEdgeType("bo_0"): [
                        graph_types.InputTaggedNode(
                            graph_types.NodeId("A"),
                            graph_types.InEdgeType("ai_1"))
                    ],
                    graph_types.OutEdgeType("bo_1"): [
                        graph_types.InputTaggedNode(
                            graph_types.NodeId("B1"),
                            graph_types.InEdgeType("bi_0"))
                    ]
                }),
            graph_types.NodeId("B1"):
            graph_types.GraphNode(
                graph_types.NodeType("b"), {
                    graph_types.OutEdgeType("bo_0"): [
                        graph_types.InputTaggedNode(
                            graph_types.NodeId("A"),
                            graph_types.InEdgeType("ai_1"))
                    ],
                    graph_types.OutEdgeType("bo_1"): [
                        graph_types.InputTaggedNode(
                            graph_types.NodeId("B2"),
                            graph_types.InEdgeType("bi_0"))
                    ]
                }),
        }
        expected_itns = [
            graph_types.InputTaggedNode(graph_types.NodeId("A"),
                                        graph_types.InEdgeType("ai_1")),
            graph_types.InputTaggedNode(graph_types.NodeId("B2"),
                                        graph_types.InEdgeType("bi_0")),
            graph_types.InputTaggedNode(graph_types.NodeId("B1"),
                                        graph_types.InEdgeType("bi_0")),
            graph_types.InputTaggedNode(graph_types.NodeId("B1"),
                                        graph_types.InEdgeType("bi_1")),
        ]

        actual_itns = schema_util.all_input_tagged_nodes(graph)
        self.assertEqual(actual_itns, expected_itns)
コード例 #14
0
def build_ast_graph_schema(ast_spec):
    """Builds a graph schema for an AST.

    This logic is described in Appendix B.1 of the paper.

    Each AST node becomes a new node type, whose edges are determined by its
    fields (along with whether it has a parent). Additionally, we generate
    helper node types for each grammar category that appears as a sequence
    (i.e. one for sequences of statements, another for sequences of
    expressions).

    Nodes with a parent get two edge types for that parent:
    - (in) "parent_in": the edge from the parent to this node
    - (out) "parent_out": the edge from this node to its parent

    ONE_CHILD fields become two edge types:
    - (in) "{field_name}_in": the edge from the child to this node
    - (out) "{field_name}_out": the edge from this node to the child

    OPTIONAL_CHILD fields become three edge types (to account for the
    invariant where every outgoing edge type must have at least one edge):
    - (in) "{field_name}_in": the edge from the child to this node (if it
        exists)
    - (in) "{field_name}_missing": a loopback edge from this node to itself,
        used as a sentinel value for when the child is missing
    - (out) "{field_name}_out": if there is a child of this type, this edge
        connects to that child; otherwise, it is a loopback edge to this node
        with incoming type "{field_name}_missing".

    NONEMPTY_SEQUENCE or SEQUENCE fields become 4 or 5 edge types:
    - (in) "{field_name}_in": used for EVERY edge from a child-item-helper to
        this node
    - (in) "{field_name}_missing": loopback edges for missing outgoing edges
        (only for SEQUENCE)
    - (out) "{field_name}_out_all": edges to every child-item-helper, or a loop
        back to "{field_name}_missing" if the sequence is empty
    - (out) "{field_name}_out_first": edges to the first child-item-helper, or
        a loop back to "{field_name}_missing" if the sequence is empty
    - (out) "{field_name}_out_last": edges to the last child-item-helper, or a
        loop back to "{field_name}_missing" if the sequence is empty

    NO_CHILDREN and IGNORE fields contribute no edge types. (NO_CHILDREN
    fields are expected to be empty, i.e. [] or None, but this function does
    not check that.)

    Child item helpers are added between any AST node with a sequence of
    children and the children in that sequence. These helpers have the
    following edge types:
    - "item_{in/out}": between the helper and the child
    - "parent_{in/out}": between the helper and the parent
    - "next_{in/out/missing}": between this helper and the next one in the
        sequence, or a loop from "out" to "missing" if this is the last element
    - "prev_{in/out/missing}": between this helper and the previous one in the
        sequence, or a loop from "out" to "missing" if this is the first
        element
    There is one helper type, named "{category}-seq-helper", for each unique
    value in `sequence_item_types` across all node types.

    Args:
      ast_spec: AST spec to generate a schema for. Maps node type names to
        node specs exposing `has_parent`, `fields` and `sequence_item_types`.

    Returns:
      GraphSchema with nodes as described above.

    Raises:
      ValueError: if a field has an unexpected field type, or if a generated
        sequence helper type name collides with an AST node type name.
    """
    result = {}
    seen_sequence_categories = set()

    # Build node schemas for each AST node.
    for node_type, node_spec in ast_spec.items():
        node_schema = graph_types.NodeSchema([], [])

        # Add possible edge types.
        if node_spec.has_parent:
            node_schema.in_edges.append(graph_types.InEdgeType("parent_in"))
            node_schema.out_edges.append(graph_types.OutEdgeType("parent_out"))

        for field, field_type in node_spec.fields.items():
            if field_type in {FieldType.ONE_CHILD, FieldType.OPTIONAL_CHILD}:
                node_schema.in_edges.append(
                    graph_types.InEdgeType(f"{field}_in"))
                node_schema.out_edges.append(
                    graph_types.OutEdgeType(f"{field}_out"))
                if field_type == FieldType.OPTIONAL_CHILD:
                    # Sentinel target for the loopback used when the child is
                    # absent, since every out edge type needs >= 1 edge.
                    node_schema.in_edges.append(
                        graph_types.InEdgeType(f"{field}_missing"))
            elif field_type in {
                    FieldType.SEQUENCE, FieldType.NONEMPTY_SEQUENCE
            }:
                # Remember the item category so a helper type is generated.
                seen_sequence_categories.add(
                    node_spec.sequence_item_types[field])
                node_schema.in_edges.append(
                    graph_types.InEdgeType(f"{field}_in"))
                for suffix in ("all", "first", "last"):
                    node_schema.out_edges.append(
                        graph_types.OutEdgeType(f"{field}_out_{suffix}"))
                if field_type == FieldType.SEQUENCE:
                    # Only possibly-empty sequences need a loopback sentinel.
                    node_schema.in_edges.append(
                        graph_types.InEdgeType(f"{field}_missing"))
            elif field_type in {FieldType.NO_CHILDREN, FieldType.IGNORE}:
                # No edges for these fields.
                pass
            else:
                raise ValueError(f"Unexpected field type {field_type}")

        result[graph_types.NodeType(node_type)] = node_schema

    # Build node schemas for each category helper (sorted for determinism).
    for category in sorted(seen_sequence_categories):
        helper_type = graph_types.NodeType(f"{category}-seq-helper")
        if helper_type in result:
            # Explicit check (not `assert`) so it also runs under `python -O`.
            raise ValueError(
                f"Sequence helper node type {helper_type!r} collides with an "
                f"AST node type of the same name")
        result[helper_type] = graph_types.NodeSchema(
            in_edges=[
                graph_types.InEdgeType("parent_in"),
                graph_types.InEdgeType("item_in"),
                graph_types.InEdgeType("next_in"),
                graph_types.InEdgeType("next_missing"),
                graph_types.InEdgeType("prev_in"),
                graph_types.InEdgeType("prev_missing"),
            ],
            out_edges=[
                graph_types.OutEdgeType("parent_out"),
                graph_types.OutEdgeType("item_out"),
                graph_types.OutEdgeType("next_out"),
                graph_types.OutEdgeType("prev_out"),
            ])

    return result
コード例 #15
0
    def build_loop_graph(self):
        """Helper method to build this complex graph, to test graph encodings.

      ┌───────<──┐  ┌───────<─────────<─────┐
      │          │  │                       │
      │    [ao_0]│  ↓[ai_0]                 │
      │          (a0)                       │
      │    [ai_1]↑  │[ao_1]                 │
      │          │  │                       │
      │    [ao_0]│  ↓[ai_0]                 │
      ↓          (a1)                       ↑
      │    [ai_1]↑  │[ao_1]                 │
      │          │  │                       │
      │    [bo_0]│  ↓[bi_0]                 │
      │          ╭──╮───────>[bo_2]────┐    │
      │          │b0│───────>[bo_2]──┐ │    │
      │          │  │<──[bi_0]─────<─┘ │    │
      │          ╰──╯<──[bi_0]─────<─┐ │    │
      │    [bi_0]↑  │[bo_1]          │ │    ↑
      │          │  │                │ │    │
      │    [bo_0]│  ↓[bi_0]          │ │    │
      │          ╭──╮───────>[bo_2]──┘ │    │
      │          │b1│───────>[bo_2]──┐ │    │
      ↓          │  │<──[bi_0]─────<─┘ │    │
      │          ╰──╯<──[bi_0]─────<───┘    │
      │    [bi_0]↑  │[bo_1]                 │
      │          │  │                       │
      └───────>──┘  └───────>─────────>─────┘

        Returns:
          Tuple (schema, graph) for the above structure.
        """
        # Node types.
        a = graph_types.NodeType("a")
        b = graph_types.NodeType("b")

        # Incoming edge types. Type "b" has a single incoming edge type.
        ai_0 = graph_types.InEdgeType("ai_0")
        ai_1 = graph_types.InEdgeType("ai_1")
        bi_0 = graph_types.InEdgeType("bi_0")

        # Outgoing edge types.
        ao_0 = graph_types.OutEdgeType("ao_0")
        ao_1 = graph_types.OutEdgeType("ao_1")
        bo_0 = graph_types.OutEdgeType("bo_0")
        bo_1 = graph_types.OutEdgeType("bo_1")
        bo_2 = graph_types.OutEdgeType("bo_2")

        # Node ids.
        a0 = graph_types.NodeId("a0")
        a1 = graph_types.NodeId("a1")
        b0 = graph_types.NodeId("b0")
        b1 = graph_types.NodeId("b1")

        schema = {
            a:
            graph_types.NodeSchema(in_edges=[ai_0, ai_1],
                                   out_edges=[ao_0, ao_1]),
            b:
            graph_types.NodeSchema(in_edges=[bi_0],
                                   out_edges=[bo_0, bo_1, bo_2]),
        }
        test_graph = {
            # a0: ao_0 loops around the left side into b1; ao_1 goes to a1.
            a0:
            graph_types.GraphNode(
                a, {
                    ao_0: [graph_types.InputTaggedNode(b1, bi_0)],
                    ao_1: [graph_types.InputTaggedNode(a1, ai_0)]
                }),
            # a1: ao_0 goes back up to a0; ao_1 goes down to b0.
            a1:
            graph_types.GraphNode(
                a, {
                    ao_0: [graph_types.InputTaggedNode(a0, ai_1)],
                    ao_1: [graph_types.InputTaggedNode(b0, bi_0)]
                }),
            # b0: bo_0 up to a1, bo_1 down to b1, bo_2 fans out to b0 and b1.
            b0:
            graph_types.GraphNode(
                b, {
                    bo_0: [graph_types.InputTaggedNode(a1, ai_1)],
                    bo_1: [graph_types.InputTaggedNode(b1, bi_0)],
                    bo_2: [
                        graph_types.InputTaggedNode(b0, bi_0),
                        graph_types.InputTaggedNode(b1, bi_0)
                    ]
                }),
            # b1: bo_0 up to b0, bo_1 loops around the right side into a0,
            # bo_2 fans out to b0 and b1.
            b1:
            graph_types.GraphNode(
                b, {
                    bo_0: [graph_types.InputTaggedNode(b0, bi_0)],
                    bo_1: [graph_types.InputTaggedNode(a0, ai_0)],
                    bo_2: [
                        graph_types.InputTaggedNode(b0, bi_0),
                        graph_types.InputTaggedNode(b1, bi_0)
                    ]
                }),
        }
        return schema, test_graph
コード例 #16
0
    def test_encode_maze(self):
        """Checks `encode_maze` on a small 3x5 maze with a single wall.

        Verifies the returned coordinate list, the encoding of two sample
        cells (a corner and an edge cell), and that the whole encoded graph
        conforms to the maze schema.
        """
        num_rows, num_cols = 3, 5
        wall = (1, 1)

        # Every cell is open except the wall.
        maze = np.ones((num_rows, num_cols), dtype=bool)
        maze[wall] = False

        graph, coords = maze_schema.encode_maze(maze)

        # Coordinates list every open cell in row-major order.
        self.assertEqual(
            coords,
            [(row, col)
             for row in range(num_rows)
             for col in range(num_cols)
             if (row, col) != wall])

        def edge_to(target, in_edge):
            """Single-element edge list pointing at `target` via `in_edge`."""
            return [
                graph_types.InputTaggedNode(
                    graph_types.NodeId(target), graph_types.InEdgeType(in_edge))
            ]

        # Top-left corner: neighbors only to the right and below.
        self.assertEqual(
            graph[graph_types.NodeId("cell_0_0")],
            graph_types.GraphNode(
                graph_types.NodeType("cell_xRxD"), {
                    graph_types.OutEdgeType("R_out"):
                        edge_to("cell_0_1", "L_in"),
                    graph_types.OutEdgeType("D_out"):
                        edge_to("cell_1_0", "U_in"),
                }))

        # Right edge, middle row: neighbors left, above and below.
        self.assertEqual(
            graph[graph_types.NodeId("cell_1_4")],
            graph_types.GraphNode(
                graph_types.NodeType("cell_LxUD"), {
                    graph_types.OutEdgeType("L_out"):
                        edge_to("cell_1_3", "R_in"),
                    graph_types.OutEdgeType("U_out"):
                        edge_to("cell_0_4", "D_in"),
                    graph_types.OutEdgeType("D_out"):
                        edge_to("cell_2_4", "U_in"),
                }))

        # The full encoding must be valid under the maze schema.
        schema_util.assert_conforms_to_schema(
            graph, maze_schema.build_maze_schema(2))