def build_doubly_linked_list_graph(self, length):
  """Helper method to build a circular doubly-linked-list graph and schema.

  Every node has the single type "node". Node i's "next_out" edge reaches
  node (i + 1) % length through that node's "prev_in" edge, and its
  "prev_out" edge reaches node (i - 1) % length through that node's
  "next_in" edge, so the list wraps around at both ends.

  Args:
    length: Number of nodes in the list.

  Returns:
    Tuple (schema, graph).
  """
  node_type = graph_types.NodeType("node")
  schema = {
      node_type:
          graph_types.NodeSchema(
              in_edges=[
                  graph_types.InEdgeType("next_in"),
                  graph_types.InEdgeType("prev_in"),
              ],
              out_edges=[
                  graph_types.OutEdgeType("next_out"),
                  graph_types.OutEdgeType("prev_out"),
              ]),
  }

  def neighbor(position, in_edge_name):
    # Python's % keeps the index in [0, length) even for position == -1.
    return graph_types.InputTaggedNode(
        node_id=graph_types.NodeId(str(position % length)),
        in_edge=graph_types.InEdgeType(in_edge_name))

  graph = {
      graph_types.NodeId(str(idx)):
          graph_types.GraphNode(
              node_type, {
                  graph_types.OutEdgeType("next_out"): [
                      neighbor(idx + 1, "prev_in")
                  ],
                  graph_types.OutEdgeType("prev_out"): [
                      neighbor(idx - 1, "next_in")
                  ],
              }) for idx in range(length)
  }
  return schema, graph
def test_routing_gates_to_probs(self):
  """Checks how routing gates are converted into normalized probabilities.

  All gates start at 0.5. Two distributions are then overridden:
  one whose move + accept gates add up to 2.0, which is expected to be
  scaled down by 2 with nothing left for backtracking or failing; and one
  whose move + accept gates add up to 0.6, where the remaining 0.4 is
  expected to be split into 0.75 * 0.4 = 0.3 for backtracking and 0.1 for
  failing.
  """
  builder = automaton_builder.AutomatonBuilder(self.build_simple_schema())

  def in_out_route(node, in_edge, out_edge):
    """Index of an (in-edge, out-edge) route of `node` in the move gates."""
    return builder.in_out_route_type_to_index[
        automaton_builder.InOutRouteType(
            graph_types.NodeType(node), graph_types.InEdgeType(in_edge),
            graph_types.OutEdgeType(out_edge))]

  def in_route(node, in_edge):
    """Index of an in-edge route of `node` in the accept/backtrack gates."""
    return builder.in_route_type_to_index[automaton_builder.InRouteType(
        graph_types.NodeType(node), graph_types.InEdgeType(in_edge))]

  num_in_out_routes = len(builder.in_out_route_types)
  num_in_routes = len(builder.in_route_types)
  # move: [variants, in_out_routes, fsm_states, fsm_states]
  move_gates = np.full([3, num_in_out_routes, 2, 2], 0.5)
  # accept / backtrack: [variants, in_routes, fsm_states]
  accept_gates = np.full([3, num_in_routes, 2], 0.5)
  backtrack_gates = np.full([3, num_in_routes, 2], 0.5)

  # Distribution 1 (variant 0, state 0, node "b" entered via "bi_0"):
  # the gates sum to more than 1.
  d1_move1 = in_out_route("b", "bi_0", "bo_0")
  d1_move2 = in_out_route("b", "bi_0", "bo_1")
  d1_special = in_route("b", "bi_0")
  move_gates[0, d1_move1, 0, :] = [.2, .3]
  move_gates[0, d1_move2, 0, :] = [.4, .5]
  accept_gates[0, d1_special, 0] = .6
  backtrack_gates[0, d1_special, 0] = .3

  # Distribution 2 (variant 2, state 1, node "a" entered via "ai_0"):
  # the gates sum to less than 1.
  d2_move = in_out_route("a", "ai_0", "ao_0")
  d2_special = in_route("a", "ai_0")
  move_gates[2, d2_move, 1, :] = [.1, .2]
  accept_gates[2, d2_special, 1] = .3
  backtrack_gates[2, d2_special, 1] = .75

  logit = jax.scipy.special.logit
  routing_probs = builder.routing_gates_to_probs(
      automaton_builder.RoutingGateParams(
          move_gates=logit(move_gates),
          accept_gates=logit(accept_gates),
          backtrack_gates=logit(backtrack_gates)))

  # First distribution: every entry is divided evenly by the total (2.0).
  np.testing.assert_allclose(routing_probs.move[0, d1_move1, 0, :],
                             np.array([.2, .3]) / 2.0)
  np.testing.assert_allclose(routing_probs.move[0, d1_move2, 0, :],
                             np.array([.4, .5]) / 2.0)
  np.testing.assert_allclose(routing_probs.special[0, d1_special, 0, :],
                             np.array([.6, 0, 0]) / 2.0)

  # Second distribution: the remainder goes to backtrack and fail.
  np.testing.assert_allclose(routing_probs.move[2, d2_move, 1, :],
                             np.array([.1, .2]))
  np.testing.assert_allclose(routing_probs.special[2, d2_special, 1, :],
                             np.array([.3, .3, .1]))
def test_all_input_tagged_nodes(self):
  """Checks that all_input_tagged_nodes deduplicates and orders edge targets.

  A/ai_1 is the target of three different edges but must appear once. The
  expected order follows the graph's node insertion order (A, B2, B1), with
  B1's in-edges listed as bi_0 before bi_1.
  """

  def itn(node_id, in_edge):
    return graph_types.InputTaggedNode(
        graph_types.NodeId(node_id), graph_types.InEdgeType(in_edge))

  # Python 3 dicts keep insertion order, so B2 is visited before B1.
  graph = {
      graph_types.NodeId("A"):
          graph_types.GraphNode(
              graph_types.NodeType("a"), {
                  graph_types.OutEdgeType("ao_0"): [
                      itn("B1", "bi_1"),
                      itn("A", "ai_1"),
                  ]
              }),
      graph_types.NodeId("B2"):
          graph_types.GraphNode(
              graph_types.NodeType("b"), {
                  graph_types.OutEdgeType("bo_0"): [itn("A", "ai_1")],
                  graph_types.OutEdgeType("bo_1"): [itn("B1", "bi_0")],
              }),
      graph_types.NodeId("B1"):
          graph_types.GraphNode(
              graph_types.NodeType("b"), {
                  graph_types.OutEdgeType("bo_0"): [itn("A", "ai_1")],
                  graph_types.OutEdgeType("bo_1"): [itn("B2", "bi_0")],
              }),
  }
  expected = [
      itn("A", "ai_1"),
      itn("B2", "bi_0"),
      itn("B1", "bi_0"),
      itn("B1", "bi_1"),
  ]
  self.assertEqual(schema_util.all_input_tagged_nodes(graph), expected)
def build_simple_schema(self):
  """Returns the small two-type schema shared by these tests.

  Type "a" has in-edges ai_0/ai_1 and out-edge ao_0; type "b" has in-edge
  bi_0 and out-edges bo_0/bo_1.
  """

  def ins(*names):
    return [graph_types.InEdgeType(name) for name in names]

  def outs(*names):
    return [graph_types.OutEdgeType(name) for name in names]

  return {
      graph_types.NodeType("a"):
          graph_types.NodeSchema(
              in_edges=ins("ai_0", "ai_1"), out_edges=outs("ao_0")),
      graph_types.NodeType("b"):
          graph_types.NodeSchema(
              in_edges=ins("bi_0"), out_edges=outs("bo_0", "bo_1")),
  }
def connect(from_id, out_type, in_type, to_id):
  """Adds directed edges to the graph.

  Appends an edge from `from_id` to `to_id` (arriving on `in_type`) under the
  `out_type` out-edge list of `result[from_id]`, creating that list on first
  use. `result` is not a parameter; it is looked up in the enclosing scope
  and must map node ids to objects with an `out_edges` mapping.
  """
  edges_by_type = result[from_id].out_edges
  target = graph_types.InputTaggedNode(
      node_id=to_id, in_edge=graph_types.InEdgeType(in_type))
  edges_by_type.setdefault(graph_types.OutEdgeType(out_type), []).append(target)
def test_does_not_conform_to_schema(self, graph, expected_error):
  """Checks that assert_conforms_to_schema rejects `graph` with a ValueError.

  Args:
    graph: A graph that violates the test schema in some way.
    expected_error: Regex the ValueError message must match.

  NOTE(review): both arguments are presumably supplied by a parameterized
  decorator that is outside this view.
  """
  in_edge = graph_types.InEdgeType
  out_edge = graph_types.OutEdgeType
  schema = {
      graph_types.NodeType("a"):
          graph_types.NodeSchema(
              in_edges=[in_edge("ai_0"), in_edge("ai_1")],
              out_edges=[out_edge("ao_0")]),
      graph_types.NodeType("b"):
          graph_types.NodeSchema(
              in_edges=[in_edge("bi_0")],
              out_edges=[out_edge("bo_0"), out_edge("bo_1")]),
  }
  # pylint: disable=g-error-prone-assert-raises
  with self.assertRaisesRegex(ValueError, expected_error):
    schema_util.assert_conforms_to_schema(graph, schema)
  # pylint: enable=g-error-prone-assert-raises
def test_constructor_actions_nodes_routes(self):
  """Checks the actions, node types and routes derived from the simple schema.

  With backtracking disabled and failing enabled, the only expected special
  actions are FINISH and FAIL. Each node type is expected to get one in-route
  for SOURCE_INITIAL plus one per schema in-edge. Each of those in-routes is
  expected to pair with every out-edge of its node type to form an in-out
  route.
  """
  ab = automaton_builder
  builder = ab.AutomatonBuilder(
      self.build_simple_schema(), with_backtrack=False, with_fail=True)

  type_a = graph_types.NodeType("a")
  type_b = graph_types.NodeType("b")
  ai_0 = graph_types.InEdgeType("ai_0")
  ai_1 = graph_types.InEdgeType("ai_1")
  bi_0 = graph_types.InEdgeType("bi_0")
  ao_0 = graph_types.OutEdgeType("ao_0")
  bo_0 = graph_types.OutEdgeType("bo_0")
  bo_1 = graph_types.OutEdgeType("bo_1")
  initial = ab.SOURCE_INITIAL

  self.assertEqual(
      set(builder.special_actions),
      {ab.SpecialActions.FINISH, ab.SpecialActions.FAIL})
  self.assertEqual(set(builder.node_types), {type_a, type_b})
  self.assertEqual(
      set(builder.in_route_types), {
          ab.InRouteType(type_a, initial),
          ab.InRouteType(type_a, ai_0),
          ab.InRouteType(type_a, ai_1),
          ab.InRouteType(type_b, initial),
          ab.InRouteType(type_b, bi_0),
      })
  self.assertEqual(
      set(builder.in_out_route_types), {
          ab.InOutRouteType(type_a, initial, ao_0),
          ab.InOutRouteType(type_a, ai_0, ao_0),
          ab.InOutRouteType(type_a, ai_1, ao_0),
          ab.InOutRouteType(type_b, initial, bo_0),
          ab.InOutRouteType(type_b, initial, bo_1),
          ab.InOutRouteType(type_b, bi_0, bo_0),
          ab.InOutRouteType(type_b, bi_0, bo_1),
      })
def encode_maze(maze):
  """Encode a boolean mask array as a graph.

  We assume a [row, col]-indexed coordinate system, oriented so that going
  right corresponds to changing indices by (0, +1), and going down
  corresponds to changing indices by (+1, 0).

  Each True cell (i, j) becomes a node with id "cell_{i}_{j}". Its node type
  is "cell_" followed by one character per direction, in the order L, R, U,
  D. The character is the direction letter if the neighbor in that direction
  lies inside the array and is True, and "x" otherwise (e.g. "cell_xRxD").
  For every such neighbor the node gets an out-edge "{dir}_out" that arrives
  on the neighbor's "{reverse_dir}_in" in-edge.

  Args:
    maze: Maze, as a boolean array <bool[num_rows, num_cols]>, where True
      corresponds to cells that should be nodes in the graph.

  Returns:
    Tuple (graph, idx_to_cell): the encoded graph, along with a list of the
    (row, col) coordinates of each cell, in the row-major order in which the
    nodes were added to the graph.
  """
  # (direction, reverse direction, (dr, dc))
  shifts = [
      ("L", "R", (0, -1)),
      ("R", "L", (0, 1)),
      ("U", "D", (-1, 0)),
      ("D", "U", (1, 0)),
  ]
  num_rows = maze.shape[0]
  num_cols = maze.shape[1]
  graph = {}
  idx_to_cell = []
  for i in range(num_rows):
    for j in range(num_cols):
      if not maze[i, j]:
        continue
      typename = "cell_"
      out_edges = {}
      for name, in_name, (di, dj) in shifts:
        ni = i + di
        nj = j + dj
        if 0 <= ni < num_rows and 0 <= nj < num_cols and maze[ni, nj]:
          typename += name
          # The neighbor receives the edge on the side facing this cell, e.g.
          # this cell's "R_out" arrives on the right neighbor's "L_in".
          out_edges[graph_types.OutEdgeType(f"{name}_out")] = [
              graph_types.InputTaggedNode(
                  graph_types.NodeId(f"cell_{ni}_{nj}"),
                  graph_types.InEdgeType(f"{in_name}_in"))
          ]
        else:
          typename += "x"
      graph[graph_types.NodeId(f"cell_{i}_{j}")] = graph_types.GraphNode(
          graph_types.NodeType(typename), out_edges)
      idx_to_cell.append((i, j))
  return graph, idx_to_cell
def maze_primitive_edges(maze_graph):
  """Lists the primitive movement actions available from every maze cell.

  For every node, and for every direction in `DIRECTION_ORDERING`, this emits
  one (source, destination, action_index) triple. `action_index` is the
  direction's position in `DIRECTION_ORDERING`. The destination is the
  neighbor reached through the node's "{direction}_out" edge if the node has
  one; otherwise the move is a self-loop back to the node.

  Args:
    maze_graph: Encoded graph representing the maze.

  Returns:
    List of (source node id, destination node id, action index) tuples, one
    per (node, direction) pair.
  """
  primitives = []
  for node_id, node_info in maze_graph.items():
    for action_index, direction in enumerate(DIRECTION_ORDERING):
      out_key = graph_types.OutEdgeType(f"{direction}_out")
      if out_key in node_info.out_edges:
        # Each direction must have exactly one target; the single-element
        # unpack raises ValueError if that invariant is violated.
        (dest,) = node_info.out_edges[out_key]
        primitives.append((node_id, dest.node_id, action_index))
      else:
        primitives.append((node_id, node_id, action_index))
  return primitives
def test_conforms_to_schema(self):
  """A graph whose node types, edge types and edge targets all match passes."""
  in_edge = graph_types.InEdgeType
  out_edge = graph_types.OutEdgeType
  schema = {
      graph_types.NodeType("a"):
          graph_types.NodeSchema(
              in_edges=[in_edge("ai_0"), in_edge("ai_1")],
              out_edges=[out_edge("ao_0")]),
      graph_types.NodeType("b"):
          graph_types.NodeSchema(
              in_edges=[in_edge("bi_0")],
              out_edges=[out_edge("bo_0"), out_edge("bo_1")]),
  }

  def target(node_id, in_edge_name):
    return graph_types.InputTaggedNode(
        graph_types.NodeId(node_id), in_edge(in_edge_name))

  # Valid graph
  valid_graph = {
      graph_types.NodeId("A"):
          graph_types.GraphNode(
              graph_types.NodeType("a"), {
                  out_edge("ao_0"): [target("B", "bi_0"), target("A", "ai_1")]
              }),
      graph_types.NodeId("B"):
          graph_types.GraphNode(
              graph_types.NodeType("b"), {
                  out_edge("bo_0"): [target("A", "ai_1")],
                  out_edge("bo_1"): [target("B", "bi_0")],
              }),
  }
  schema_util.assert_conforms_to_schema(valid_graph, schema)
def build_maze_schema(min_neighbors):
  """Build a schema for a 2d gridworld maze.

  One node type is created for each combination of neighbor directions
  (left, right, up, down) that has at least `min_neighbors` neighbors. The
  type name is "cell_" followed by, for each direction in that order, its
  letter (L/R/U/D) if present or "x" if absent. For instance "cell_LxUD" is a
  cell with neighbors to the left, up and down, whose in-edges are "L_in",
  "U_in", "D_in" and whose out-edges are "L_out", "U_out", "D_out".

  Args:
    min_neighbors: Minimum number of neighbors a grid cell will have. For
      instance, if min_neighbors=2 then the schema will only have types for
      cells that are connected to at least 2 other cells.

  Returns:
    Schema describing the maze.
  """
  directions = ("L", "R", "U", "D")
  maze_schema = {}
  for flags in itertools.product([False, True], repeat=len(directions)):
    if sum(flags) < min_neighbors:
      continue
    present = [d for d, used in zip(directions, flags) if used]
    type_name = "cell_" + "".join(
        d if used else "x" for d, used in zip(directions, flags))
    maze_schema[graph_types.NodeType(type_name)] = graph_types.NodeSchema(
        in_edges=[graph_types.InEdgeType(d + "_in") for d in present],
        out_edges=[graph_types.OutEdgeType(d + "_out") for d in present])
  return maze_schema
def test_automaton_layer_abstract_init(self, shared, variant_weights, use_gate,
                                       estimator_type, **kwargs):
  """Checks FiniteStateGraphAutomaton initializes inside a flax model.

  NOTE(review): the arguments are presumably supplied by a parameterized
  decorator outside this view.

  Args:
    shared: Forwarded to the layer as `share_states_across_edges`.
    variant_weights: Zero-argument callable. Its result is tree-mapped through
      `tie_in` and passed to the layer as `variant_weights`.
    use_gate: Forwarded to the layer as `use_gate_parameterization`.
    estimator_type: Forwarded to the layer. When it is "one_sample", the test
      also checks that a per-edge, per-node log-prob side output was recorded.
    **kwargs: Extra keyword arguments forwarded to the layer.
  """
  # Create a simple schema and empty encoded graph.
  schema = {
      graph_types.NodeType("a"):
          graph_types.NodeSchema(in_edges=[graph_types.InEdgeType("ai_0")],
                                 out_edges=[graph_types.OutEdgeType("ao_0")
                                           ]),
  }
  builder = automaton_builder.AutomatonBuilder(schema)
  # All-zero placeholder arrays. NOTE(review): the sizes (128 sparse entries,
  # 32 nodes, 64 input-tagged nodes) appear chosen to match the metadata
  # passed to the layer below; confirm if changing either.
  encoded_graph = automaton_builder.EncodedGraph(
      initial_to_in_tagged=sparse_operator.SparseCoordOperator(
          input_indices=jnp.zeros((128, 1), dtype=jnp.int32),
          output_indices=jnp.zeros((128, 2), dtype=jnp.int32),
          values=jnp.zeros((128, ), dtype=jnp.float32),
      ),
      initial_to_special=jnp.zeros((32, ), dtype=jnp.int32),
      in_tagged_to_in_tagged=sparse_operator.SparseCoordOperator(
          input_indices=jnp.zeros((128, 1), dtype=jnp.int32),
          output_indices=jnp.zeros((128, 2), dtype=jnp.int32),
          values=jnp.zeros((128, ), dtype=jnp.float32),
      ),
      in_tagged_to_special=jnp.zeros((64, ), dtype=jnp.int32),
      in_tagged_node_indices=jnp.zeros((64, ), dtype=jnp.int32),
  )

  # Make sure the layer can be initialized and applied within a model.
  # This model is fairly simple; it just pretends that the encoded graph and
  # variants depend on the input.
  class TestModel(flax.deprecated.nn.Module):

    def apply(self, dummy_ignored):
      # `tie_in` makes every leaf appear to depend on the (otherwise unused)
      # model input, so the layer sees abstract values rather than constants.
      abstract_encoded_graph = jax.tree_map(
          lambda y: jax.lax.tie_in(dummy_ignored, y), encoded_graph)
      abstract_variant_weights = jax.tree_map(
          lambda y: jax.lax.tie_in(dummy_ignored, y), variant_weights())
      return automaton_layer.FiniteStateGraphAutomaton(
          encoded_graph=abstract_encoded_graph,
          variant_weights=abstract_variant_weights,
          dynamic_metadata=automaton_builder.EncodedGraphMetadata(
              num_nodes=32, num_input_tagged_nodes=64),
          static_metadata=automaton_builder.EncodedGraphMetadata(
              num_nodes=32, num_input_tagged_nodes=64),
          builder=builder,
          num_out_edges=3,
          num_intermediate_states=4,
          share_states_across_edges=shared,
          use_gate_parameterization=use_gate,
          estimator_type=estimator_type,
          # The name determines the side-output key prefix checked below.
          name="the_layer",
          **kwargs)

  with side_outputs.collect_side_outputs() as side:
    with flax.deprecated.nn.stochastic(jax.random.PRNGKey(0)):
      # For some reason init_by_shape breaks the custom_vjp?
      abstract_out, unused_params = TestModel.init(
          jax.random.PRNGKey(1234), jnp.zeros((), jnp.float32))
      del unused_params

  # One [num_nodes, num_nodes] output per out-edge (num_out_edges=3).
  self.assertEqual(abstract_out.shape, (3, 32, 32))

  if estimator_type == "one_sample":
    log_prob_key = "/the_layer/one_sample_log_prob_per_edge_per_node"
    self.assertIn(log_prob_key, side)
    self.assertEqual(side[log_prob_key].shape, (3, 32))
class SchemaUtilTest(parameterized.TestCase):
  """Tests for schema conformance checking and input-tagged-node listing."""

  @staticmethod
  def _make_test_schema():
    """Returns the two-type schema ("a" and "b") used by conformance tests."""
    return {
        graph_types.NodeType("a"):
            graph_types.NodeSchema(
                in_edges=[
                    graph_types.InEdgeType(name) for name in ("ai_0", "ai_1")
                ],
                out_edges=[graph_types.OutEdgeType("ao_0")]),
        graph_types.NodeType("b"):
            graph_types.NodeSchema(
                in_edges=[graph_types.InEdgeType("bi_0")],
                out_edges=[
                    graph_types.OutEdgeType(name) for name in ("bo_0", "bo_1")
                ]),
    }

  def test_conforms_to_schema(self):
    """A graph whose node types, edge types and targets all match passes."""

    def target(node_id, in_edge):
      return graph_types.InputTaggedNode(
          graph_types.NodeId(node_id), graph_types.InEdgeType(in_edge))

    # Valid graph
    valid_graph = {
        graph_types.NodeId("A"):
            graph_types.GraphNode(
                graph_types.NodeType("a"), {
                    graph_types.OutEdgeType("ao_0"): [
                        target("B", "bi_0"),
                        target("A", "ai_1"),
                    ]
                }),
        graph_types.NodeId("B"):
            graph_types.GraphNode(
                graph_types.NodeType("b"), {
                    graph_types.OutEdgeType("bo_0"): [target("A", "ai_1")],
                    graph_types.OutEdgeType("bo_1"): [target("B", "bi_0")],
                }),
    }
    schema_util.assert_conforms_to_schema(valid_graph,
                                          self._make_test_schema())

  @parameterized.named_parameters(
      {
          "testcase_name": "missing_node",
          "graph": {
              graph_types.NodeId("A"):
                  graph_types.GraphNode(
                      graph_types.NodeType("a"), {
                          graph_types.OutEdgeType("ao_0"): [
                              graph_types.InputTaggedNode(
                                  graph_types.NodeId("B"),
                                  graph_types.InEdgeType("bi_0"))
                          ]
                      })
          },
          "expected_error": "Node A has connection to missing node B"
      }, {
          "testcase_name": "bad_node_type",
          "graph": {
              graph_types.NodeId("A"):
                  graph_types.GraphNode(
                      graph_types.NodeType("z"), {
                          graph_types.OutEdgeType("ao_0"): [
                              graph_types.InputTaggedNode(
                                  graph_types.NodeId("A"),
                                  graph_types.InEdgeType("ai_0"))
                          ]
                      })
          },
          "expected_error": "Node A's type z not in schema"
      }, {
          "testcase_name": "missing_out_edge",
          "graph": {
              graph_types.NodeId("A"):
                  graph_types.GraphNode(graph_types.NodeType("a"),
                                        {graph_types.OutEdgeType("ao_0"): []})
          },
          "expected_error": "Node A missing out edge of type ao_0"
      }, {
          "testcase_name": "bad_out_edge_type",
          "graph": {
              graph_types.NodeId("A"):
                  graph_types.GraphNode(
                      graph_types.NodeType("a"), {
                          graph_types.OutEdgeType("ao_0"): [
                              graph_types.InputTaggedNode(
                                  graph_types.NodeId("A"),
                                  graph_types.InEdgeType("ai_0"))
                          ],
                          "foo": [
                              graph_types.InputTaggedNode(
                                  graph_types.NodeId("A"),
                                  graph_types.InEdgeType("ai_0"))
                          ]
                      })
          },
          "expected_error": "Node A has out-edges of invalid type foo"
      }, {
          "testcase_name": "bad_in_edge_type",
          "graph": {
              graph_types.NodeId("A"):
                  graph_types.GraphNode(
                      graph_types.NodeType("a"), {
                          graph_types.OutEdgeType("ao_0"): [
                              graph_types.InputTaggedNode(
                                  graph_types.NodeId("A"),
                                  graph_types.InEdgeType("bar"))
                          ],
                      })
          },
          "expected_error": "Node A has in-edges of invalid type bar"
      })
  def test_does_not_conform_to_schema(self, graph, expected_error):
    """Each malformed graph is rejected with a matching ValueError message."""
    schema = self._make_test_schema()
    # pylint: disable=g-error-prone-assert-raises
    with self.assertRaisesRegex(ValueError, expected_error):
      schema_util.assert_conforms_to_schema(graph, schema)
    # pylint: enable=g-error-prone-assert-raises

  def test_all_input_tagged_nodes(self):
    """Edge targets are deduplicated and ordered by node, then in-edge.

    A/ai_1 is the target of three edges but appears once. The expected order
    follows the graph's node insertion order (A, B2, B1), with B1's in-edges
    listed as bi_0 before bi_1.
    """

    def itn(node_id, in_edge):
      return graph_types.InputTaggedNode(
          graph_types.NodeId(node_id), graph_types.InEdgeType(in_edge))

    # (note: python3 dicts maintain order, so B2 comes before B1)
    graph = {
        graph_types.NodeId("A"):
            graph_types.GraphNode(
                graph_types.NodeType("a"), {
                    graph_types.OutEdgeType("ao_0"): [
                        itn("B1", "bi_1"),
                        itn("A", "ai_1"),
                    ]
                }),
        graph_types.NodeId("B2"):
            graph_types.GraphNode(
                graph_types.NodeType("b"), {
                    graph_types.OutEdgeType("bo_0"): [itn("A", "ai_1")],
                    graph_types.OutEdgeType("bo_1"): [itn("B1", "bi_0")],
                }),
        graph_types.NodeId("B1"):
            graph_types.GraphNode(
                graph_types.NodeType("b"), {
                    graph_types.OutEdgeType("bo_0"): [itn("A", "ai_1")],
                    graph_types.OutEdgeType("bo_1"): [itn("B2", "bi_0")],
                }),
    }
    expected = [
        itn("A", "ai_1"),
        itn("B2", "bi_0"),
        itn("B1", "bi_0"),
        itn("B1", "bi_1"),
    ]
    self.assertEqual(schema_util.all_input_tagged_nodes(graph), expected)
def build_ast_graph_schema(ast_spec):
  """Builds a graph schema for an AST.

  This logic is described in Appendix B.1 of the paper.

  Each AST node becomes a new node type, whose edges are determined by its
  fields (along with whether it has a parent). Additionally, we generate
  helper node types for each grammar category that appears as a sequence
  (i.e. one for sequences of statements, another for sequences of
  expressions).

  Nodes with a parent get two edge types for that parent:
    - (in) "parent_in": the edge from the parent to this node
    - (out) "parent_out": the edge from this node to its parent

  ONE_CHILD fields become two edge types:
    - (in) "{field_name}_in": the edge from the child to this node
    - (out) "{field_name}_out": the edge from this node to the child

  OPTIONAL_CHILD fields become three edge types (to account for the invariant
  where every outgoing edge type must have at least one edge):
    - (in) "{field_name}_in": the edge from the child to this node (if it
      exists)
    - (in) "{field_name}_missing": a loopback edge from this node to itself,
      used as a sentinel value for when the child is missing
    - (out) "{field_name}_out": if there is a child of this type, this edge
      connects to that child; otherwise, it is a loopback edge to this node
      with incoming type "{field_name}_missing".

  NONEMPTY_SEQUENCE or SEQUENCE fields become 4 or 5 edge types:
    - (in) "{field_name}_in": used for EVERY edge from a child-item-helper to
      this node
    - (in) "{field_name}_missing": loopback edges for missing outgoing edges
      (only for SEQUENCE)
    - (out) "{field_name}_out_all": edges to every child-item-helper, or a
      loop back to "{field_name}_missing" if the sequence is empty
    - (out) "{field_name}_out_first": edges to the first child-item-helper,
      or a loop back to "{field_name}_missing" if the sequence is empty
    - (out) "{field_name}_out_last": edges to the last child-item-helper, or
      a loop back to "{field_name}_missing" if the sequence is empty

  NO_CHILDREN fields must be empty ([] or None) and will throw an error
  otherwise. IGNORE fields will be simply ignored. Neither kind contributes
  any edge types to the schema.

  Child item helpers are added between any AST node with a sequence of
  children and the children in that sequence. These helpers have the
  following edge types:
    - "item_{in/out}": between the helper and the child
    - "parent_{in/out}": between the helper and the parent
    - "next_{in/out/missing}": between this helper and the next one in the
      sequence, or a loop from "out" to "missing" if this is the last element
    - "prev_{in/out/missing}": between this helper and the previous one in
      the sequence, or a loop from "out" to "missing" if this is the first
      element

  There is one helper type for each unique value in `sequence_item_types`
  across all node types, named "{category}-seq-helper".

  Args:
    ast_spec: AST spec to generate a schema for.

  Returns:
    GraphSchema with nodes as described above.

  Raises:
    ValueError: If a field has an unexpected field type, or if a sequence
      helper type name collides with an AST node type.
  """
  result = {}
  seen_sequence_categories = set()

  # Build node schemas for each AST node
  for node_type, node_spec in ast_spec.items():
    node_schema = graph_types.NodeSchema([], [])

    # Add possible edge types
    if node_spec.has_parent:
      node_schema.in_edges.append(graph_types.InEdgeType("parent_in"))
      node_schema.out_edges.append(graph_types.OutEdgeType("parent_out"))

    for field, field_type in node_spec.fields.items():
      if field_type in {FieldType.ONE_CHILD, FieldType.OPTIONAL_CHILD}:
        node_schema.in_edges.append(graph_types.InEdgeType(f"{field}_in"))
        node_schema.out_edges.append(graph_types.OutEdgeType(f"{field}_out"))
        if field_type == FieldType.OPTIONAL_CHILD:
          node_schema.in_edges.append(
              graph_types.InEdgeType(f"{field}_missing"))
      elif field_type in {FieldType.SEQUENCE, FieldType.NONEMPTY_SEQUENCE}:
        seen_sequence_categories.add(node_spec.sequence_item_types[field])
        node_schema.in_edges.append(graph_types.InEdgeType(f"{field}_in"))
        node_schema.out_edges.append(
            graph_types.OutEdgeType(f"{field}_out_all"))
        node_schema.out_edges.append(
            graph_types.OutEdgeType(f"{field}_out_first"))
        node_schema.out_edges.append(
            graph_types.OutEdgeType(f"{field}_out_last"))
        if field_type == FieldType.SEQUENCE:
          node_schema.in_edges.append(
              graph_types.InEdgeType(f"{field}_missing"))
      elif field_type in {FieldType.NO_CHILDREN, FieldType.IGNORE}:
        # No edges for these fields.
        pass
      else:
        raise ValueError(f"Unexpected field type {field_type}")

    result[graph_types.NodeType(node_type)] = node_schema

  # Build node schemas for each category helper
  for category in sorted(seen_sequence_categories):
    helper_type = graph_types.NodeType(f"{category}-seq-helper")
    # Explicit check (not `assert`) so it still runs under `python -O`.
    if helper_type in result:
      raise ValueError(
          f"Sequence helper type {helper_type} collides with an AST node type")
    node_schema = graph_types.NodeSchema(
        in_edges=[
            graph_types.InEdgeType("parent_in"),
            graph_types.InEdgeType("item_in"),
            graph_types.InEdgeType("next_in"),
            graph_types.InEdgeType("next_missing"),
            graph_types.InEdgeType("prev_in"),
            graph_types.InEdgeType("prev_missing")
        ],
        out_edges=[
            graph_types.OutEdgeType("parent_out"),
            graph_types.OutEdgeType("item_out"),
            graph_types.OutEdgeType("next_out"),
            graph_types.OutEdgeType("prev_out")
        ])
    result[helper_type] = node_schema

  return result
def build_loop_graph(self):
  """Helper method to build a small graph with loops, to test graph encodings.

  Node types:
    a: in-edges ai_0, ai_1; out-edges ao_0, ao_1
    b: in-edge bi_0; out-edges bo_0, bo_1, bo_2

  Edges, written as source[out-edge] -> target[in-edge]:
    a0[ao_0] -> b1[bi_0]
    a0[ao_1] -> a1[ai_0]
    a1[ao_0] -> a0[ai_1]
    a1[ao_1] -> b0[bi_0]
    b0[bo_0] -> a1[ai_1]
    b0[bo_1] -> b1[bi_0]
    b0[bo_2] -> b0[bi_0], b1[bi_0]
    b1[bo_0] -> b0[bi_0]
    b1[bo_1] -> a0[ai_0]
    b1[bo_2] -> b0[bi_0], b1[bi_0]

  The bo_2 edges each have two targets and include self-loops on b0 and b1.

  Returns:
    Tuple (schema, graph) for the above structure.
  """
  a = graph_types.NodeType("a")
  b = graph_types.NodeType("b")

  ai_0 = graph_types.InEdgeType("ai_0")
  ai_1 = graph_types.InEdgeType("ai_1")
  bi_0 = graph_types.InEdgeType("bi_0")

  ao_0 = graph_types.OutEdgeType("ao_0")
  ao_1 = graph_types.OutEdgeType("ao_1")
  bo_0 = graph_types.OutEdgeType("bo_0")
  bo_1 = graph_types.OutEdgeType("bo_1")
  bo_2 = graph_types.OutEdgeType("bo_2")

  a0 = graph_types.NodeId("a0")
  a1 = graph_types.NodeId("a1")
  b0 = graph_types.NodeId("b0")
  b1 = graph_types.NodeId("b1")

  schema = {
      a: graph_types.NodeSchema(in_edges=[ai_0, ai_1], out_edges=[ao_0, ao_1]),
      b: graph_types.NodeSchema(in_edges=[bi_0], out_edges=[bo_0, bo_1, bo_2]),
  }
  test_graph = {
      a0:
          graph_types.GraphNode(
              a, {
                  ao_0: [graph_types.InputTaggedNode(b1, bi_0)],
                  ao_1: [graph_types.InputTaggedNode(a1, ai_0)]
              }),
      a1:
          graph_types.GraphNode(
              a, {
                  ao_0: [graph_types.InputTaggedNode(a0, ai_1)],
                  ao_1: [graph_types.InputTaggedNode(b0, bi_0)]
              }),
      b0:
          graph_types.GraphNode(
              b, {
                  bo_0: [graph_types.InputTaggedNode(a1, ai_1)],
                  bo_1: [graph_types.InputTaggedNode(b1, bi_0)],
                  bo_2: [
                      graph_types.InputTaggedNode(b0, bi_0),
                      graph_types.InputTaggedNode(b1, bi_0)
                  ]
              }),
      b1:
          graph_types.GraphNode(
              b, {
                  bo_0: [graph_types.InputTaggedNode(b0, bi_0)],
                  bo_1: [graph_types.InputTaggedNode(a0, ai_0)],
                  bo_2: [
                      graph_types.InputTaggedNode(b0, bi_0),
                      graph_types.InputTaggedNode(b1, bi_0)
                  ]
              }),
  }
  return schema, test_graph
def test_encode_maze(self):
  """Tests that an encoded maze is correct and matches the schema."""
  # A 3x5 grid that is open everywhere except for a hole at (1, 1).
  maze = np.array([
      [1, 1, 1, 1, 1],
      [1, 0, 1, 1, 1],
      [1, 1, 1, 1, 1],
  ]).astype(bool)
  encoded_graph, coordinates = maze_schema.encode_maze(maze)

  # Cells are reported in row-major order, skipping the hole.
  self.assertEqual(
      coordinates,
      [(r, c) for r in range(3) for c in range(5) if (r, c) != (1, 1)])

  def edge_to(cell, in_edge):
    return [
        graph_types.InputTaggedNode(
            graph_types.NodeId(cell), graph_types.InEdgeType(in_edge))
    ]

  # Top-left corner: open neighbors only to the right and below.
  self.assertEqual(
      encoded_graph[graph_types.NodeId("cell_0_0")],
      graph_types.GraphNode(
          graph_types.NodeType("cell_xRxD"), {
              graph_types.OutEdgeType("R_out"): edge_to("cell_0_1", "L_in"),
              graph_types.OutEdgeType("D_out"): edge_to("cell_1_0", "U_in"),
          }))
  # Right end of the middle row: open neighbors left, up and down.
  self.assertEqual(
      encoded_graph[graph_types.NodeId("cell_1_4")],
      graph_types.GraphNode(
          graph_types.NodeType("cell_LxUD"), {
              graph_types.OutEdgeType("L_out"): edge_to("cell_1_3", "R_in"),
              graph_types.OutEdgeType("U_out"): edge_to("cell_0_4", "D_in"),
              graph_types.OutEdgeType("D_out"): edge_to("cell_2_4", "U_in"),
          }))

  # Every open cell in this maze has at least two open neighbors, so the
  # schema restricted to such cell types must accept the whole graph.
  schema_util.assert_conforms_to_schema(encoded_graph,
                                        maze_schema.build_maze_schema(2))