def build_doubly_linked_list_graph(self, length):
  """Constructs a circular doubly-linked-list graph and its schema.

  All nodes share the type "node" and are named "0" .. str(length - 1).
  Node `i` reaches node `(i + 1) % length` through "next_out" (arriving on
  that node's "prev_in" edge) and node `(i - 1) % length` through "prev_out"
  (arriving on that node's "next_in" edge).

  Args:
    length: Number of nodes in the list.

  Returns:
    Tuple (schema, graph).
  """
  node_type = graph_types.NodeType("node")
  next_in = graph_types.InEdgeType("next_in")
  prev_in = graph_types.InEdgeType("prev_in")
  next_out = graph_types.OutEdgeType("next_out")
  prev_out = graph_types.OutEdgeType("prev_out")

  schema = {
      node_type:
          graph_types.NodeSchema(
              in_edges=[next_in, prev_in], out_edges=[next_out, prev_out]),
  }

  def make_node(index):
    # The modulo wraps the ends around so the list forms a ring.
    successor = graph_types.NodeId(str((index + 1) % length))
    predecessor = graph_types.NodeId(str((index - 1) % length))
    return graph_types.GraphNode(
        node_type, {
            next_out: [
                graph_types.InputTaggedNode(
                    node_id=successor, in_edge=prev_in)
            ],
            prev_out: [
                graph_types.InputTaggedNode(
                    node_id=predecessor, in_edge=next_in)
            ],
        })

  graph = {
      graph_types.NodeId(str(index)): make_node(index)
      for index in range(length)
  }
  return schema, graph
def test_ast_graph_optional_field_edges(self):
  """Checks optional-field edges, both when the field is set and when not.

  The first `return` carries a value, so its "value_out" edge leads to the
  constant and the constant points back at it. The second `return` has no
  value, so its "value_out" edge loops back to itself as "value_missing".
  """
  graph, _ = py_ast_graphs.py_ast_to_graph(gast.parse("return 1\nreturn"))

  def tagged(node_id, in_edge):
    return graph_types.InputTaggedNode(
        node_id=graph_types.NodeId(node_id),
        in_edge=graph_types.InEdgeType(in_edge))

  # (source node, out edge type, the single expected destination)
  expectations = [
      ("root_body_0_item__Return", "value_out",
       tagged("root_body_0_item_value__Constant", "parent_in")),
      ("root_body_0_item_value__Constant", "parent_out",
       tagged("root_body_0_item__Return", "value_in")),
      ("root_body_1_item__Return", "value_out",
       tagged("root_body_1_item__Return", "value_missing")),
  ]
  for source, out_edge, destination in expectations:
    self.assertEqual(graph[source].out_edges[out_edge], [destination])
def encode_maze(maze):
  """Converts a boolean occupancy mask into a graph of connected cells.

  Cells are addressed as [row, col]: moving right changes the index by
  (0, +1) and moving down changes it by (+1, 0).

  Every True cell becomes a node with ID "cell_<row>_<col>". For each of the
  directions L, R, U, D (in that order), a cell whose neighbor in that
  direction lies inside the array and is also True gets a "<direction>_out"
  edge to that neighbor, tagged with the opposite direction's "_in" edge type.
  The node type is "cell_" followed by one character per direction: the
  direction letter when the neighbor exists, "x" otherwise.

  Args:
    maze: Maze, as a two-dimensional boolean array indexed as [row, col],
      where True corresponds to cells that should be nodes in the graph.

  Returns:
    Tuple (graph, idx_to_cell): the encoded graph, and the (row, col)
    coordinates of each cell in the graph, in row-major order.
  """
  # (direction, opposite direction, (row delta, column delta))
  directions = (
      ("L", "R", (0, -1)),
      ("R", "L", (0, 1)),
      ("U", "D", (-1, 0)),
      ("D", "U", (1, 0)),
  )
  num_rows = maze.shape[0]
  num_cols = maze.shape[1]

  graph = {}
  idx_to_cell = []
  for row in range(num_rows):
    for col in range(num_cols):
      if not maze[row, col]:
        continue

      type_chars = []
      out_edges = {}
      for direction, opposite, (d_row, d_col) in directions:
        n_row = row + d_row
        n_col = col + d_col
        in_bounds = 0 <= n_row < num_rows and 0 <= n_col < num_cols
        # Only index the array once the neighbor is known to be in bounds.
        if in_bounds and maze[n_row, n_col]:
          type_chars.append(direction)
          out_edges[graph_types.OutEdgeType(f"{direction}_out")] = [
              graph_types.InputTaggedNode(
                  graph_types.NodeId(f"cell_{n_row}_{n_col}"),
                  graph_types.InEdgeType(f"{opposite}_in"))
          ]
        else:
          type_chars.append("x")

      node_id = graph_types.NodeId(f"cell_{row}_{col}")
      node_type = graph_types.NodeType("cell_" + "".join(type_chars))
      graph[node_id] = graph_types.GraphNode(node_type, out_edges)
      idx_to_cell.append((row, col))
  return graph, idx_to_cell
def test_ast_graph_unique_field_edges(self):
  """Checks the parent/child edge pair produced for a single-valued field.

  The expression statement's "value_out" edge must lead to the call node, and
  the call node's "parent_out" edge must lead back to the statement.
  """
  graph, _ = py_ast_graphs.py_ast_to_graph(gast.parse("print(1)"))

  expr_name = "root_body_0_item__Expr"
  call_name = "root_body_0_item_value__Call"

  expected_child = graph_types.InputTaggedNode(
      node_id=graph_types.NodeId(call_name),
      in_edge=graph_types.InEdgeType("parent_in"))
  self.assertEqual(graph[expr_name].out_edges["value_out"], [expected_child])

  expected_parent = graph_types.InputTaggedNode(
      node_id=graph_types.NodeId(expr_name),
      in_edge=graph_types.InEdgeType("value_in"))
  self.assertEqual(graph[call_name].out_edges["parent_out"], [expected_parent])
def get_graph_node_id(
    type_name,
    path_from_root,
):
  """Returns the ID for a graph node, registering the node if it is new.

  The ID is "root", then "_<part>" for every element of `path_from_root`,
  then "__<type_name>". When the `result` mapping from the surrounding scope
  has no entry under that ID, a `GraphNode` of type `type_name` with no
  out-edges is stored there.

  Args:
    type_name: Name of the node's type; also the suffix of the ID.
    path_from_root: Iterable of path components; each is converted with `str`.

  Returns:
    The node ID.
  """
  path_suffix = "".join("_" + str(part) for part in path_from_root)
  node_id = graph_types.NodeId("root" + path_suffix + "__" + type_name)
  if node_id in result:
    return node_id
  result[node_id] = graph_types.GraphNode(
      node_type=graph_types.NodeType(type_name), out_edges={})
  return node_id
def test_conforms_to_schema(self):
  """Checks that a well-formed two-node graph passes schema validation."""
  type_a = graph_types.NodeType("a")
  type_b = graph_types.NodeType("b")
  ai_0 = graph_types.InEdgeType("ai_0")
  ai_1 = graph_types.InEdgeType("ai_1")
  bi_0 = graph_types.InEdgeType("bi_0")
  ao_0 = graph_types.OutEdgeType("ao_0")
  bo_0 = graph_types.OutEdgeType("bo_0")
  bo_1 = graph_types.OutEdgeType("bo_1")
  node_a = graph_types.NodeId("A")
  node_b = graph_types.NodeId("B")

  test_schema = {
      type_a: graph_types.NodeSchema(in_edges=[ai_0, ai_1], out_edges=[ao_0]),
      type_b: graph_types.NodeSchema(in_edges=[bi_0], out_edges=[bo_0, bo_1]),
  }

  # Valid graph: every out edge type of each node is populated and every
  # destination uses an in edge type declared for the destination's node type.
  test_graph = {
      node_a:
          graph_types.GraphNode(
              type_a, {
                  ao_0: [
                      graph_types.InputTaggedNode(node_b, bi_0),
                      graph_types.InputTaggedNode(node_a, ai_1),
                  ],
              }),
      node_b:
          graph_types.GraphNode(
              type_b, {
                  bo_0: [graph_types.InputTaggedNode(node_a, ai_1)],
                  bo_1: [graph_types.InputTaggedNode(node_b, bi_0)],
              }),
  }
  schema_util.assert_conforms_to_schema(test_graph, test_schema)
def test_ast_graph_sequence_field_edges(self):
  """Checks the edges generated for a sequence-valued field.

  A sequence field links three kinds of nodes: the parent, one helper node per
  element, and the element itself. This test inspects the parent's edges, the
  first helper's edges, and the first element's parent edge.
  """
  source = textwrap.dedent("""\
      print(1)
      print(2)
      print(3)
      print(4)
      print(5)
      print(6)
      """)
  graph, _ = py_ast_graphs.py_ast_to_graph(gast.parse(source))

  def tagged(node_id, in_edge):
    return graph_types.InputTaggedNode(
        node_id=graph_types.NodeId(node_id),
        in_edge=graph_types.InEdgeType(in_edge))

  module_name = "root__Module"
  first_helper = "root_body_0__Module_body-seq-helper"
  second_helper = "root_body_1__Module_body-seq-helper"
  last_helper = "root_body_5__Module_body-seq-helper"
  first_item = "root_body_0_item__Expr"

  # Child edges from the parent node
  module_edges = graph[module_name].out_edges
  self.assertLen(module_edges["body_out_all"], 6)
  self.assertEqual(module_edges["body_out_first"],
                   [tagged(first_helper, "parent_in")])
  self.assertEqual(module_edges["body_out_last"],
                   [tagged(last_helper, "parent_in")])

  # Edges from the sequence helper
  helper_edges = graph[first_helper].out_edges
  self.assertEqual(helper_edges["parent_out"],
                   [tagged(module_name, "body_in")])
  self.assertEqual(helper_edges["item_out"],
                   [tagged(first_item, "parent_in")])
  # The first helper has no predecessor, so the edge loops back to itself.
  self.assertEqual(helper_edges["prev_out"],
                   [tagged(first_helper, "prev_missing")])
  self.assertEqual(helper_edges["next_out"],
                   [tagged(second_helper, "prev_in")])

  # Parent edge of the item
  item_edges = graph[first_item].out_edges
  self.assertEqual(item_edges["parent_out"], [tagged(first_helper, "item_in")])
class SchemaUtilTest(parameterized.TestCase):
  """Tests for schema conformance checks and input-tagged-node enumeration."""

  def test_conforms_to_schema(self):
    """A graph whose nodes and edges all match the schema is accepted."""
    node_type = graph_types.NodeType
    node_id = graph_types.NodeId
    in_edge = graph_types.InEdgeType
    out_edge = graph_types.OutEdgeType
    tagged = graph_types.InputTaggedNode

    test_schema = {
        node_type("a"):
            graph_types.NodeSchema(
                in_edges=[in_edge("ai_0"), in_edge("ai_1")],
                out_edges=[out_edge("ao_0")]),
        node_type("b"):
            graph_types.NodeSchema(
                in_edges=[in_edge("bi_0")],
                out_edges=[out_edge("bo_0"), out_edge("bo_1")]),
    }

    # Valid graph
    test_graph = {
        node_id("A"):
            graph_types.GraphNode(
                node_type("a"), {
                    out_edge("ao_0"): [
                        tagged(node_id("B"), in_edge("bi_0")),
                        tagged(node_id("A"), in_edge("ai_1")),
                    ],
                }),
        node_id("B"):
            graph_types.GraphNode(
                node_type("b"), {
                    out_edge("bo_0"): [tagged(node_id("A"), in_edge("ai_1"))],
                    out_edge("bo_1"): [tagged(node_id("B"), in_edge("bi_0"))],
                }),
    }
    schema_util.assert_conforms_to_schema(test_graph, test_schema)

  # Each case is a tuple (testcase_name, graph, expected_error).
  @parameterized.named_parameters(
      (
          "missing_node",
          {
              graph_types.NodeId("A"):
                  graph_types.GraphNode(
                      graph_types.NodeType("a"), {
                          graph_types.OutEdgeType("ao_0"): [
                              graph_types.InputTaggedNode(
                                  graph_types.NodeId("B"),
                                  graph_types.InEdgeType("bi_0"))
                          ],
                      }),
          },
          "Node A has connection to missing node B",
      ),
      (
          "bad_node_type",
          {
              graph_types.NodeId("A"):
                  graph_types.GraphNode(
                      graph_types.NodeType("z"), {
                          graph_types.OutEdgeType("ao_0"): [
                              graph_types.InputTaggedNode(
                                  graph_types.NodeId("A"),
                                  graph_types.InEdgeType("ai_0"))
                          ],
                      }),
          },
          "Node A's type z not in schema",
      ),
      (
          "missing_out_edge",
          {
              graph_types.NodeId("A"):
                  graph_types.GraphNode(
                      graph_types.NodeType("a"),
                      {graph_types.OutEdgeType("ao_0"): []}),
          },
          "Node A missing out edge of type ao_0",
      ),
      (
          "bad_out_edge_type",
          {
              graph_types.NodeId("A"):
                  graph_types.GraphNode(
                      graph_types.NodeType("a"), {
                          graph_types.OutEdgeType("ao_0"): [
                              graph_types.InputTaggedNode(
                                  graph_types.NodeId("A"),
                                  graph_types.InEdgeType("ai_0"))
                          ],
                          "foo": [
                              graph_types.InputTaggedNode(
                                  graph_types.NodeId("A"),
                                  graph_types.InEdgeType("ai_0"))
                          ],
                      }),
          },
          "Node A has out-edges of invalid type foo",
      ),
      (
          "bad_in_edge_type",
          {
              graph_types.NodeId("A"):
                  graph_types.GraphNode(
                      graph_types.NodeType("a"), {
                          graph_types.OutEdgeType("ao_0"): [
                              graph_types.InputTaggedNode(
                                  graph_types.NodeId("A"),
                                  graph_types.InEdgeType("bar"))
                          ],
                      }),
          },
          "Node A has in-edges of invalid type bar",
      ),
  )
  def test_does_not_conform_to_schema(self, graph, expected_error):
    """A malformed graph is rejected with a ValueError matching the message."""
    node_type = graph_types.NodeType
    in_edge = graph_types.InEdgeType
    out_edge = graph_types.OutEdgeType

    test_schema = {
        node_type("a"):
            graph_types.NodeSchema(
                in_edges=[in_edge("ai_0"), in_edge("ai_1")],
                out_edges=[out_edge("ao_0")]),
        node_type("b"):
            graph_types.NodeSchema(
                in_edges=[in_edge("bi_0")],
                out_edges=[out_edge("bo_0"), out_edge("bo_1")]),
    }
    # pylint: disable=g-error-prone-assert-raises
    with self.assertRaisesRegex(ValueError, expected_error):
      schema_util.assert_conforms_to_schema(graph, test_schema)
    # pylint: enable=g-error-prone-assert-raises

  def test_all_input_tagged_nodes(self):
    """Checks the exact list returned by `all_input_tagged_nodes`."""
    node_type = graph_types.NodeType
    node_id = graph_types.NodeId
    in_edge = graph_types.InEdgeType
    out_edge = graph_types.OutEdgeType
    tagged = graph_types.InputTaggedNode

    # (note: python3 dicts maintain order, so B2 comes before B1)
    graph = {
        node_id("A"):
            graph_types.GraphNode(
                node_type("a"), {
                    out_edge("ao_0"): [
                        tagged(node_id("B1"), in_edge("bi_1")),
                        tagged(node_id("A"), in_edge("ai_1")),
                    ],
                }),
        node_id("B2"):
            graph_types.GraphNode(
                node_type("b"), {
                    out_edge("bo_0"): [tagged(node_id("A"), in_edge("ai_1"))],
                    out_edge("bo_1"): [tagged(node_id("B1"), in_edge("bi_0"))],
                }),
        node_id("B1"):
            graph_types.GraphNode(
                node_type("b"), {
                    out_edge("bo_0"): [tagged(node_id("A"), in_edge("ai_1"))],
                    out_edge("bo_1"): [tagged(node_id("B2"), in_edge("bi_0"))],
                }),
    }
    expected_itns = [
        tagged(node_id("A"), in_edge("ai_1")),
        tagged(node_id("B2"), in_edge("bi_0")),
        tagged(node_id("B1"), in_edge("bi_0")),
        tagged(node_id("B1"), in_edge("bi_1")),
    ]
    actual_itns = schema_util.all_input_tagged_nodes(graph)
    self.assertEqual(actual_itns, expected_itns)
def test_all_input_tagged_nodes(self):
  """Checks the exact list returned by `all_input_tagged_nodes`."""
  type_a = graph_types.NodeType("a")
  type_b = graph_types.NodeType("b")
  ao_0 = graph_types.OutEdgeType("ao_0")
  bo_0 = graph_types.OutEdgeType("bo_0")
  bo_1 = graph_types.OutEdgeType("bo_1")

  def tagged(node_name, in_edge_name):
    return graph_types.InputTaggedNode(
        graph_types.NodeId(node_name), graph_types.InEdgeType(in_edge_name))

  # (note: python3 dicts maintain order, so B2 comes before B1)
  graph = {
      graph_types.NodeId("A"):
          graph_types.GraphNode(
              type_a, {
                  ao_0: [tagged("B1", "bi_1"), tagged("A", "ai_1")],
              }),
      graph_types.NodeId("B2"):
          graph_types.GraphNode(
              type_b, {
                  bo_0: [tagged("A", "ai_1")],
                  bo_1: [tagged("B1", "bi_0")],
              }),
      graph_types.NodeId("B1"):
          graph_types.GraphNode(
              type_b, {
                  bo_0: [tagged("A", "ai_1")],
                  bo_1: [tagged("B2", "bi_0")],
              }),
  }
  expected_itns = [
      tagged("A", "ai_1"),
      tagged("B2", "bi_0"),
      tagged("B1", "bi_0"),
      tagged("B1", "bi_1"),
  ]
  actual_itns = schema_util.all_input_tagged_nodes(graph)
  self.assertEqual(actual_itns, expected_itns)
def build_loop_graph(self):
  """Helper method to build a small cyclic graph, to test graph encodings.

  The graph has two nodes of type "a" (a0, a1) and two nodes of type "b"
  (b0, b1). Each line below reads
  "source --[out edge type]--> destination [in edge type]":

    a0 --[ao_0]--> b1 [bi_0]
    a0 --[ao_1]--> a1 [ai_0]
    a1 --[ao_0]--> a0 [ai_1]
    a1 --[ao_1]--> b0 [bi_0]
    b0 --[bo_0]--> a1 [ai_1]
    b0 --[bo_1]--> b1 [bi_0]
    b0 --[bo_2]--> b0 [bi_0]
    b0 --[bo_2]--> b1 [bi_0]
    b1 --[bo_0]--> b0 [bi_0]
    b1 --[bo_1]--> a0 [ai_0]
    b1 --[bo_2]--> b0 [bi_0]
    b1 --[bo_2]--> b1 [bi_0]

  Returns:
    Tuple (schema, graph) for the above structure.
  """
  a = graph_types.NodeType("a")
  b = graph_types.NodeType("b")

  ai_0 = graph_types.InEdgeType("ai_0")
  ai_1 = graph_types.InEdgeType("ai_1")
  # Type "b" has a single in edge type.
  bi_0 = graph_types.InEdgeType("bi_0")

  ao_0 = graph_types.OutEdgeType("ao_0")
  ao_1 = graph_types.OutEdgeType("ao_1")
  bo_0 = graph_types.OutEdgeType("bo_0")
  bo_1 = graph_types.OutEdgeType("bo_1")
  bo_2 = graph_types.OutEdgeType("bo_2")

  a0 = graph_types.NodeId("a0")
  a1 = graph_types.NodeId("a1")
  b0 = graph_types.NodeId("b0")
  b1 = graph_types.NodeId("b1")

  schema = {
      a: graph_types.NodeSchema(in_edges=[ai_0, ai_1], out_edges=[ao_0, ao_1]),
      b: graph_types.NodeSchema(in_edges=[bi_0], out_edges=[bo_0, bo_1, bo_2]),
  }
  test_graph = {
      a0:
          graph_types.GraphNode(
              a, {
                  ao_0: [graph_types.InputTaggedNode(b1, bi_0)],
                  ao_1: [graph_types.InputTaggedNode(a1, ai_0)],
              }),
      a1:
          graph_types.GraphNode(
              a, {
                  ao_0: [graph_types.InputTaggedNode(a0, ai_1)],
                  ao_1: [graph_types.InputTaggedNode(b0, bi_0)],
              }),
      b0:
          graph_types.GraphNode(
              b, {
                  bo_0: [graph_types.InputTaggedNode(a1, ai_1)],
                  bo_1: [graph_types.InputTaggedNode(b1, bi_0)],
                  # bo_2 fans out to both "b" nodes, including b0 itself.
                  bo_2: [
                      graph_types.InputTaggedNode(b0, bi_0),
                      graph_types.InputTaggedNode(b1, bi_0),
                  ],
              }),
      b1:
          graph_types.GraphNode(
              b, {
                  bo_0: [graph_types.InputTaggedNode(b0, bi_0)],
                  bo_1: [graph_types.InputTaggedNode(a0, ai_0)],
                  # bo_2 fans out to both "b" nodes, including b1 itself.
                  bo_2: [
                      graph_types.InputTaggedNode(b0, bi_0),
                      graph_types.InputTaggedNode(b1, bi_0),
                  ],
              }),
  }
  return schema, test_graph
def test_encode_maze(self):
  """Checks coordinates, sample nodes, and schema validity of a maze graph."""
  maze = np.array([
      [1, 1, 1, 1, 1],
      [1, 0, 1, 1, 1],
      [1, 1, 1, 1, 1],
  ]).astype(bool)
  encoded_graph, coordinates = maze_schema.encode_maze(maze)

  def edge_to(cell_name, in_edge_name):
    return [
        graph_types.InputTaggedNode(
            graph_types.NodeId(cell_name),
            graph_types.InEdgeType(in_edge_name))
    ]

  # Check coordinates: every cell except the blocked one at (1, 1), listed
  # row by row.
  expected_coords = [
      (r, c) for r in range(3) for c in range(5) if (r, c) != (1, 1)
  ]
  self.assertEqual(coordinates, expected_coords)

  # Check a few nodes.
  top_left = graph_types.GraphNode(
      graph_types.NodeType("cell_xRxD"), {
          graph_types.OutEdgeType("R_out"): edge_to("cell_0_1", "L_in"),
          graph_types.OutEdgeType("D_out"): edge_to("cell_1_0", "U_in"),
      })
  self.assertEqual(encoded_graph[graph_types.NodeId("cell_0_0")], top_left)

  right_middle = graph_types.GraphNode(
      graph_types.NodeType("cell_LxUD"), {
          graph_types.OutEdgeType("L_out"): edge_to("cell_1_3", "R_in"),
          graph_types.OutEdgeType("U_out"): edge_to("cell_0_4", "D_in"),
          graph_types.OutEdgeType("D_out"): edge_to("cell_2_4", "U_in"),
      })
  self.assertEqual(encoded_graph[graph_types.NodeId("cell_1_4")], right_middle)

  # Check schema validity.
  schema = maze_schema.build_maze_schema(2)
  schema_util.assert_conforms_to_schema(encoded_graph, schema)