def test_ast_graph_optional_field_edges(self):
  """Checks the edges produced for an optional AST field.

  The first `return` carries a value, so its `value` field links to a child
  Constant node (with a matching edge back to the parent). The second `return`
  has no value, so its `value_out` edge points back at the Return node itself,
  entering through a `value_missing` in-edge.
  """
  itn = graph_types.InputTaggedNode
  node_id = graph_types.NodeId
  in_edge = graph_types.InEdgeType

  root = gast.parse("return 1\nreturn")
  graph, _ = py_ast_graphs.py_ast_to_graph(root)

  first_return = "root_body_0_item__Return"
  first_value = "root_body_0_item_value__Constant"
  second_return = "root_body_1_item__Return"

  # Value present: parent -> child, and child -> parent.
  self.assertEqual(
      graph[first_return].out_edges["value_out"],
      [itn(node_id=node_id(first_value), in_edge=in_edge("parent_in"))])
  self.assertEqual(
      graph[first_value].out_edges["parent_out"],
      [itn(node_id=node_id(first_return), in_edge=in_edge("value_in"))])

  # Value missing: the edge loops back onto the Return node.
  self.assertEqual(
      graph[second_return].out_edges["value_out"],
      [itn(node_id=node_id(second_return), in_edge=in_edge("value_missing"))])
def build_doubly_linked_list_graph(self, length):
  """Builds a circular doubly-linked-list graph together with its schema.

  Args:
    length: Number of nodes. Nodes are named "0" .. str(length - 1), and the
      list wraps around, so the last node's successor is node "0" and node
      "0"'s predecessor is the last node.

  Returns:
    Tuple (schema, graph). Every node has type "node"; its `next_out` edge
    enters the successor through `prev_in`, and its `prev_out` edge enters
    the predecessor through `next_in`.
  """
  node_type = graph_types.NodeType("node")
  next_in = graph_types.InEdgeType("next_in")
  prev_in = graph_types.InEdgeType("prev_in")
  next_out = graph_types.OutEdgeType("next_out")
  prev_out = graph_types.OutEdgeType("prev_out")

  schema = {
      node_type:
          graph_types.NodeSchema(
              in_edges=[next_in, prev_in], out_edges=[next_out, prev_out])
  }

  def wrapped_id(index):
    # Python's % is non-negative for a positive modulus, so -1 wraps to the
    # last node.
    return graph_types.NodeId(str(index % length))

  graph = {}
  for position in range(length):
    graph[graph_types.NodeId(str(position))] = graph_types.GraphNode(
        node_type, {
            next_out: [
                graph_types.InputTaggedNode(
                    node_id=wrapped_id(position + 1), in_edge=prev_in)
            ],
            prev_out: [
                graph_types.InputTaggedNode(
                    node_id=wrapped_id(position - 1), in_edge=next_in)
            ],
        })
  return schema, graph
def test_all_input_tagged_nodes(self):
  """Checks that all_input_tagged_nodes lists each distinct target once."""
  itn = graph_types.InputTaggedNode
  nid = graph_types.NodeId
  ine = graph_types.InEdgeType
  out = graph_types.OutEdgeType
  ntype = graph_types.NodeType

  # (note: python3 dicts maintain order, so B2 comes before B1)
  graph = {
      nid("A"):
          graph_types.GraphNode(
              ntype("a"), {
                  out("ao_0"): [
                      itn(nid("B1"), ine("bi_1")),
                      itn(nid("A"), ine("ai_1")),
                  ]
              }),
      nid("B2"):
          graph_types.GraphNode(
              ntype("b"), {
                  out("bo_0"): [itn(nid("A"), ine("ai_1"))],
                  out("bo_1"): [itn(nid("B1"), ine("bi_0"))],
              }),
      nid("B1"):
          graph_types.GraphNode(
              ntype("b"), {
                  out("bo_0"): [itn(nid("A"), ine("ai_1"))],
                  out("bo_1"): [itn(nid("B2"), ine("bi_0"))],
              }),
  }

  # A/ai_1 is the target of three different out-edges but appears only once.
  expected_itns = [
      itn(nid("A"), ine("ai_1")),
      itn(nid("B2"), ine("bi_0")),
      itn(nid("B1"), ine("bi_0")),
      itn(nid("B1"), ine("bi_1")),
  ]
  self.assertEqual(schema_util.all_input_tagged_nodes(graph), expected_itns)
def test_schema_edges(self):
  """Checks schema edge types and computed schema edges in both naming modes.

  With `with_node_types=False` edge names are "SCHEMA_<out-edge>"; with
  `with_node_types=True` the source node type is appended as "_FROM_<type>".
  """
  mini_schema = {
      "foo": graph_types.NodeSchema(in_edges=["ignored"], out_edges=["a", "b"]),
      "bar": graph_types.NodeSchema(in_edges=["ignored"], out_edges=["b", "c"]),
  }

  def targets(node):
    return [graph_types.InputTaggedNode(node, "ignored")]

  mini_graph = {
      "foo_node":
          graph_types.GraphNode("foo", {
              "a": targets("bar_node"),
              "b": targets("foo_node"),
          }),
      "bar_node":
          graph_types.GraphNode("bar", {
              "b": targets("bar_node"),
              "c": targets("foo_node"),
          }),
  }

  # (with_node_types, expected edge types, expected edges), checked in order.
  cases = [
      (
          False,
          {"SCHEMA_a", "SCHEMA_b", "SCHEMA_c"},
          [
              ("foo_node", "bar_node", "SCHEMA_a"),
              ("foo_node", "foo_node", "SCHEMA_b"),
              ("bar_node", "bar_node", "SCHEMA_b"),
              ("bar_node", "foo_node", "SCHEMA_c"),
          ],
      ),
      (
          True,
          {
              "SCHEMA_a_FROM_foo", "SCHEMA_b_FROM_foo", "SCHEMA_b_FROM_bar",
              "SCHEMA_c_FROM_bar"
          },
          [
              ("foo_node", "bar_node", "SCHEMA_a_FROM_foo"),
              ("foo_node", "foo_node", "SCHEMA_b_FROM_foo"),
              ("bar_node", "bar_node", "SCHEMA_b_FROM_bar"),
              ("bar_node", "foo_node", "SCHEMA_c_FROM_bar"),
          ],
      ),
  ]
  for with_node_types, expected_types, expected_edges in cases:
    self.assertEqual(
        graph_edge_util.schema_edge_types(
            mini_schema, with_node_types=with_node_types), expected_types)
    self.assertEqual(
        graph_edge_util.compute_schema_edges(
            mini_graph, with_node_types=with_node_types), expected_edges)
def test_conforms_to_schema(self):
  """Checks that a graph matching its schema passes validation."""
  node_type = graph_types.NodeType
  node_id = graph_types.NodeId
  in_edge = graph_types.InEdgeType
  out_edge = graph_types.OutEdgeType
  itn = graph_types.InputTaggedNode

  test_schema = {
      node_type("a"):
          graph_types.NodeSchema(
              in_edges=[in_edge("ai_0"), in_edge("ai_1")],
              out_edges=[out_edge("ao_0")]),
      node_type("b"):
          graph_types.NodeSchema(
              in_edges=[in_edge("bi_0")],
              out_edges=[out_edge("bo_0"), out_edge("bo_1")]),
  }

  # Valid graph: every out-edge type of each node's schema is present, and
  # every target is entered through an in-edge declared for its type.
  test_graph = {
      node_id("A"):
          graph_types.GraphNode(
              node_type("a"), {
                  out_edge("ao_0"): [
                      itn(node_id("B"), in_edge("bi_0")),
                      itn(node_id("A"), in_edge("ai_1")),
                  ]
              }),
      node_id("B"):
          graph_types.GraphNode(
              node_type("b"), {
                  out_edge("bo_0"): [itn(node_id("A"), in_edge("ai_1"))],
                  out_edge("bo_1"): [itn(node_id("B"), in_edge("bi_0"))],
              }),
  }
  schema_util.assert_conforms_to_schema(test_graph, test_schema)
def connect(from_id, out_type, in_type, to_id):
  """Appends one directed edge from `from_id` to `to_id`.

  The edge is recorded on the node `result[from_id]`, where `result` is the
  graph mapping defined outside this function.

  Args:
    from_id: Source node id; must already be present in `result`.
    out_type: Out-edge type on the source node that the edge leaves through.
    in_type: In-edge type through which the edge enters `to_id`.
    to_id: Destination node id.
  """
  edges_by_type = result[from_id].out_edges
  # The first edge of a given out-edge type starts a new list.
  bucket = edges_by_type.setdefault(graph_types.OutEdgeType(out_type), [])
  bucket.append(
      graph_types.InputTaggedNode(
          node_id=to_id, in_edge=graph_types.InEdgeType(in_type)))
def test_ast_graph_unique_field_edges(self):
  """Checks the paired edges produced for a unique field (Expr.value).

  The Expr node links down to its Call child through `value_out`, and the
  Call links back up to the Expr through `parent_out`.
  """
  expr_id = "root_body_0_item__Expr"
  call_id = "root_body_0_item_value__Call"

  graph, _ = py_ast_graphs.py_ast_to_graph(gast.parse("print(1)"))

  # Parent -> child, entering the child through `parent_in`.
  expected_down = [
      graph_types.InputTaggedNode(
          node_id=graph_types.NodeId(call_id),
          in_edge=graph_types.InEdgeType("parent_in"))
  ]
  self.assertEqual(graph[expr_id].out_edges["value_out"], expected_down)

  # Child -> parent, entering the parent through `value_in`.
  expected_up = [
      graph_types.InputTaggedNode(
          node_id=graph_types.NodeId(expr_id),
          in_edge=graph_types.InEdgeType("value_in"))
  ]
  self.assertEqual(graph[call_id].out_edges["parent_out"], expected_up)
def encode_maze(maze):
  """Encode a boolean mask array as a graph.

  We assume a [row, col]-indexed coordinate system, oriented so that going
  right corresponds to changing indices as (0, +1), and going down
  corresponds to changing indices as (+1, 0).

  Each True cell (i, j) becomes a node with id "cell_{i}_{j}". The directions
  L, R, U, D are checked in that order. When the neighbor in a direction is
  inside the array and True, the node gets an out-edge "<dir>_out" that enters
  the neighbor through its "<reverse dir>_in" in-edge. The node type is
  "cell_" followed by one character per direction: the direction letter when
  that neighbor exists, otherwise "x". For example, a top-left cell whose
  right and lower neighbors are open has type "cell_xRxD".

  Args:
    maze: Maze, as a 2D boolean array <bool[height, width]> indexed as
      maze[row, col], where True corresponds to cells that should be nodes in
      the graph.

  Returns:
    Tuple (graph, idx_to_cell): the encoded graph, along with a list of the
    (row, col) coordinates of each node, in row-major order (the same order
    in which nodes are inserted into the graph).
  """
  # (direction, reverse direction, (dr, dc))
  shifts = [
      ("L", "R", (0, -1)),
      ("R", "L", (0, 1)),
      ("U", "D", (-1, 0)),
      ("D", "U", (1, 0)),
  ]
  num_rows, num_cols = maze.shape[0], maze.shape[1]

  graph = {}
  idx_to_cell = []
  for i in range(num_rows):
    for j in range(num_cols):
      if not maze[i, j]:
        continue  # Walls are not part of the graph.
      typename = "cell_"
      out_edges = {}
      for name, in_name, (di, dj) in shifts:
        ni = i + di
        nj = j + dj
        if 0 <= ni < num_rows and 0 <= nj < num_cols and maze[ni, nj]:
          typename += name
          out_edges[graph_types.OutEdgeType(f"{name}_out")] = [
              graph_types.InputTaggedNode(
                  graph_types.NodeId(f"cell_{ni}_{nj}"),
                  graph_types.InEdgeType(f"{in_name}_in"))
          ]
        else:
          # Out of bounds or a wall: record the missing direction as "x".
          typename += "x"
      graph[graph_types.NodeId(f"cell_{i}_{j}")] = graph_types.GraphNode(
          graph_types.NodeType(typename), out_edges)
      idx_to_cell.append((i, j))
  return graph, idx_to_cell
def test_ast_graph_sequence_field_edges(self):
  """Checks the edges produced for a sequence field (Module.body).

  Sequence fields connect three nodes per element: the parent, a per-element
  helper node, and the element (child) itself.
  """
  # Six top-level statements: print(1) ... print(6).
  source = "".join(f"print({k})\n" for k in range(1, 7))
  graph, _ = py_ast_graphs.py_ast_to_graph(gast.parse(source))

  def helper(index):
    return graph_types.NodeId(f"root_body_{index}__Module_body-seq-helper")

  def tagged(target, in_edge):
    return [
        graph_types.InputTaggedNode(
            node_id=target, in_edge=graph_types.InEdgeType(in_edge))
    ]

  # Child edges from the parent node.
  module = graph["root__Module"]
  self.assertLen(module.out_edges["body_out_all"], 6)
  self.assertEqual(module.out_edges["body_out_first"],
                   tagged(helper(0), "parent_in"))
  self.assertEqual(module.out_edges["body_out_last"],
                   tagged(helper(5), "parent_in"))

  # Edges from the first sequence helper.
  first_helper = graph[helper(0)]
  self.assertEqual(first_helper.out_edges["parent_out"],
                   tagged(graph_types.NodeId("root__Module"), "body_in"))
  self.assertEqual(
      first_helper.out_edges["item_out"],
      tagged(graph_types.NodeId("root_body_0_item__Expr"), "parent_in"))
  # The first element has no predecessor, so prev_out points back at itself.
  self.assertEqual(first_helper.out_edges["prev_out"],
                   tagged(helper(0), "prev_missing"))
  self.assertEqual(first_helper.out_edges["next_out"],
                   tagged(helper(1), "prev_in"))

  # Parent edge of the item.
  item = graph["root_body_0_item__Expr"]
  self.assertEqual(item.out_edges["parent_out"], tagged(helper(0), "item_in"))
class SchemaUtilTest(parameterized.TestCase):
  """Tests for schema_util validation and input-tagged-node listing."""

  @staticmethod
  def _make_ab_schema():
    """Builds the two-type schema shared by the conformance tests.

    Type "a" accepts in-edges ai_0/ai_1 and has out-edge ao_0; type "b"
    accepts in-edge bi_0 and has out-edges bo_0/bo_1.
    """
    in_edge = graph_types.InEdgeType
    out_edge = graph_types.OutEdgeType
    return {
        graph_types.NodeType("a"):
            graph_types.NodeSchema(
                in_edges=[in_edge("ai_0"), in_edge("ai_1")],
                out_edges=[out_edge("ao_0")]),
        graph_types.NodeType("b"):
            graph_types.NodeSchema(
                in_edges=[in_edge("bi_0")],
                out_edges=[out_edge("bo_0"), out_edge("bo_1")]),
    }

  def test_conforms_to_schema(self):
    """A graph that matches the schema passes validation without raising."""
    node_id = graph_types.NodeId
    in_edge = graph_types.InEdgeType
    out_edge = graph_types.OutEdgeType
    itn = graph_types.InputTaggedNode

    # Valid graph
    test_graph = {
        node_id("A"):
            graph_types.GraphNode(
                graph_types.NodeType("a"), {
                    out_edge("ao_0"): [
                        itn(node_id("B"), in_edge("bi_0")),
                        itn(node_id("A"), in_edge("ai_1")),
                    ]
                }),
        node_id("B"):
            graph_types.GraphNode(
                graph_types.NodeType("b"), {
                    out_edge("bo_0"): [itn(node_id("A"), in_edge("ai_1"))],
                    out_edge("bo_1"): [itn(node_id("B"), in_edge("bi_0"))],
                }),
    }
    schema_util.assert_conforms_to_schema(test_graph, self._make_ab_schema())

  @parameterized.named_parameters(
      {
          "testcase_name": "missing_node",
          "graph": {
              graph_types.NodeId("A"):
                  graph_types.GraphNode(
                      graph_types.NodeType("a"), {
                          graph_types.OutEdgeType("ao_0"): [
                              graph_types.InputTaggedNode(
                                  graph_types.NodeId("B"),
                                  graph_types.InEdgeType("bi_0"))
                          ]
                      })
          },
          "expected_error": "Node A has connection to missing node B",
      },
      {
          "testcase_name": "bad_node_type",
          "graph": {
              graph_types.NodeId("A"):
                  graph_types.GraphNode(
                      graph_types.NodeType("z"), {
                          graph_types.OutEdgeType("ao_0"): [
                              graph_types.InputTaggedNode(
                                  graph_types.NodeId("A"),
                                  graph_types.InEdgeType("ai_0"))
                          ]
                      })
          },
          "expected_error": "Node A's type z not in schema",
      },
      {
          "testcase_name": "missing_out_edge",
          "graph": {
              graph_types.NodeId("A"):
                  graph_types.GraphNode(
                      graph_types.NodeType("a"),
                      {graph_types.OutEdgeType("ao_0"): []})
          },
          "expected_error": "Node A missing out edge of type ao_0",
      },
      {
          "testcase_name": "bad_out_edge_type",
          "graph": {
              graph_types.NodeId("A"):
                  graph_types.GraphNode(
                      graph_types.NodeType("a"), {
                          graph_types.OutEdgeType("ao_0"): [
                              graph_types.InputTaggedNode(
                                  graph_types.NodeId("A"),
                                  graph_types.InEdgeType("ai_0"))
                          ],
                          "foo": [
                              graph_types.InputTaggedNode(
                                  graph_types.NodeId("A"),
                                  graph_types.InEdgeType("ai_0"))
                          ],
                      })
          },
          "expected_error": "Node A has out-edges of invalid type foo",
      },
      {
          "testcase_name": "bad_in_edge_type",
          "graph": {
              graph_types.NodeId("A"):
                  graph_types.GraphNode(
                      graph_types.NodeType("a"), {
                          graph_types.OutEdgeType("ao_0"): [
                              graph_types.InputTaggedNode(
                                  graph_types.NodeId("A"),
                                  graph_types.InEdgeType("bar"))
                          ],
                      })
          },
          "expected_error": "Node A has in-edges of invalid type bar",
      })
  def test_does_not_conform_to_schema(self, graph, expected_error):
    """Each malformed graph is rejected with a ValueError matching the regex."""
    test_schema = self._make_ab_schema()
    # pylint: disable=g-error-prone-assert-raises
    with self.assertRaisesRegex(ValueError, expected_error):
      schema_util.assert_conforms_to_schema(graph, test_schema)
    # pylint: enable=g-error-prone-assert-raises

  def test_all_input_tagged_nodes(self):
    """Checks that all_input_tagged_nodes lists each distinct target once."""
    itn = graph_types.InputTaggedNode
    nid = graph_types.NodeId
    ine = graph_types.InEdgeType
    out = graph_types.OutEdgeType
    ntype = graph_types.NodeType

    # (note: python3 dicts maintain order, so B2 comes before B1)
    graph = {
        nid("A"):
            graph_types.GraphNode(
                ntype("a"), {
                    out("ao_0"): [
                        itn(nid("B1"), ine("bi_1")),
                        itn(nid("A"), ine("ai_1")),
                    ]
                }),
        nid("B2"):
            graph_types.GraphNode(
                ntype("b"), {
                    out("bo_0"): [itn(nid("A"), ine("ai_1"))],
                    out("bo_1"): [itn(nid("B1"), ine("bi_0"))],
                }),
        nid("B1"):
            graph_types.GraphNode(
                ntype("b"), {
                    out("bo_0"): [itn(nid("A"), ine("ai_1"))],
                    out("bo_1"): [itn(nid("B2"), ine("bi_0"))],
                }),
    }

    # A/ai_1 is the target of three different out-edges but appears only once.
    expected_itns = [
        itn(nid("A"), ine("ai_1")),
        itn(nid("B2"), ine("bi_0")),
        itn(nid("B1"), ine("bi_0")),
        itn(nid("B1"), ine("bi_1")),
    ]
    self.assertEqual(schema_util.all_input_tagged_nodes(graph), expected_itns)
def build_loop_graph(self):
  """Helper method to build this complex graph, to test graph encodings.

      ┌───────<──┐  ┌───────<─────────<─────┐
      │          │  │                       │
      │    [ao_0]│  ↓[ai_0]                 │
      │          (a0)                       │
      │    [ai_1]↑  │[ao_1]                 │
      │          │  │                       │
      │    [ao_0]│  ↓[ai_0]                 │
      ↓          (a1)                       ↑
      │    [ai_1]↑  │[ao_1]                 │
      │          │  │                       │
      │    [bo_0]│  ↓[bi_0]                 │
      │          ╭──╮───────>[bo_2]────┐    │
      │          │b0│───────>[bo_2]──┐ │    │
      │          │  │<──[bi_0]─────<─┘ │    │
      │          ╰──╯<──[bi_0]─────<─┐ │    │
      │    [bi_0]↑  │[bo_1]          │ │    ↑
      │          │  │                │ │    │
      │    [bo_0]│  ↓[bi_0]          │ │    │
      │          ╭──╮───────>[bo_2]──┘ │    │
      │          │b1│───────>[bo_2]──┐ │    │
      ↓          │  │<──[bi_0]─────<─┘ │    │
      │          ╰──╯<──[bi_0]─────<───┘    │
      │    [bi_0]↑  │[bo_1]                 │
      │          │  │                       │
      └───────>──┘  └───────>─────────>─────┘

  Returns:
    Tuple (schema, graph) for the above structure.
  """
  a = graph_types.NodeType("a")
  b = graph_types.NodeType("b")
  ai_0 = graph_types.InEdgeType("ai_0")
  ai_1 = graph_types.InEdgeType("ai_1")
  # Type "b" has a single in-edge type, so every edge into b0/b1 uses bi_0.
  bi_0 = graph_types.InEdgeType("bi_0")
  ao_0 = graph_types.OutEdgeType("ao_0")
  ao_1 = graph_types.OutEdgeType("ao_1")
  bo_0 = graph_types.OutEdgeType("bo_0")
  bo_1 = graph_types.OutEdgeType("bo_1")
  bo_2 = graph_types.OutEdgeType("bo_2")
  a0 = graph_types.NodeId("a0")
  a1 = graph_types.NodeId("a1")
  b0 = graph_types.NodeId("b0")
  b1 = graph_types.NodeId("b1")

  schema = {
      a: graph_types.NodeSchema(in_edges=[ai_0, ai_1], out_edges=[ao_0, ao_1]),
      b: graph_types.NodeSchema(in_edges=[bi_0], out_edges=[bo_0, bo_1, bo_2]),
  }
  test_graph = {
      a0:
          graph_types.GraphNode(
              a, {
                  ao_0: [graph_types.InputTaggedNode(b1, bi_0)],
                  ao_1: [graph_types.InputTaggedNode(a1, ai_0)]
              }),
      a1:
          graph_types.GraphNode(
              a, {
                  ao_0: [graph_types.InputTaggedNode(a0, ai_1)],
                  ao_1: [graph_types.InputTaggedNode(b0, bi_0)]
              }),
      b0:
          graph_types.GraphNode(
              b, {
                  bo_0: [graph_types.InputTaggedNode(a1, ai_1)],
                  bo_1: [graph_types.InputTaggedNode(b1, bi_0)],
                  # bo_2 fans out to both b nodes, including b0 itself.
                  bo_2: [
                      graph_types.InputTaggedNode(b0, bi_0),
                      graph_types.InputTaggedNode(b1, bi_0)
                  ]
              }),
      b1:
          graph_types.GraphNode(
              b, {
                  bo_0: [graph_types.InputTaggedNode(b0, bi_0)],
                  bo_1: [graph_types.InputTaggedNode(a0, ai_0)],
                  # bo_2 fans out to both b nodes, including b1 itself.
                  bo_2: [
                      graph_types.InputTaggedNode(b0, bi_0),
                      graph_types.InputTaggedNode(b1, bi_0)
                  ]
              }),
  }
  return schema, test_graph
def test_encode_maze(self):
  """Checks an encoded maze's coordinates, two sample nodes, and its schema."""
  maze = np.array([
      [1, 1, 1, 1, 1],
      [1, 0, 1, 1, 1],
      [1, 1, 1, 1, 1],
  ]).astype(bool)
  encoded_graph, coordinates = maze_schema.encode_maze(maze)

  # Coordinates: every cell in row-major order, except the wall at (1, 1).
  expected_coords = [
      (r, c) for r in range(3) for c in range(5) if (r, c) != (1, 1)
  ]
  self.assertEqual(coordinates, expected_coords)

  itn = graph_types.InputTaggedNode
  nid = graph_types.NodeId
  ine = graph_types.InEdgeType
  out = graph_types.OutEdgeType

  # Top-left corner: only the right and down neighbors exist.
  expected_corner = graph_types.GraphNode(
      graph_types.NodeType("cell_xRxD"), {
          out("R_out"): [itn(nid("cell_0_1"), ine("L_in"))],
          out("D_out"): [itn(nid("cell_1_0"), ine("U_in"))],
      })
  self.assertEqual(encoded_graph[nid("cell_0_0")], expected_corner)

  # Middle of the right column: no right neighbor (out of bounds).
  expected_right_edge = graph_types.GraphNode(
      graph_types.NodeType("cell_LxUD"), {
          out("L_out"): [itn(nid("cell_1_3"), ine("R_in"))],
          out("U_out"): [itn(nid("cell_0_4"), ine("D_in"))],
          out("D_out"): [itn(nid("cell_2_4"), ine("U_in"))],
      })
  self.assertEqual(encoded_graph[nid("cell_1_4")], expected_right_edge)

  # Check schema validity.
  schema_util.assert_conforms_to_schema(encoded_graph,
                                        maze_schema.build_maze_schema(2))