def create_graph():
    """Build a main graph whose 'Loop' body contains a nested 'Loop_2'.

    The innermost body graph is attached as the body of the 'Loop_2' node of
    the outer body graph. The outer body graph is then used as the body of
    the main graph's 'Loop' node via ``ti_create_main_graph``. Both 'Loop'
    and 'Loop_2' are requested through ``graph['additional_outputs']``.

    Returns:
        tuple: ``(main_graph, outer_body)``, where ``outer_body`` is the
        graph that contains the 'Loop_2' node.
    """
    inner_edges = [
        *connect('cond_2_int', 'cond_2_int_out'),
        *connect('in_2_int', 'OUT_2'),
        *connect('ones', 'OUT_2'),
        *connect('OUT_2', 'OUT_2_out'),
        *connect('in_2_int', 'in_2_int_out'),
    ]
    inner_body = build_graph(nodes_attrs=ti_sub_graph_2_nodes,
                             edges=inner_edges,
                             nodes_with_edges_only=True)

    outer_edges = [
        *connect('cond_2', '1:Loop_2'),
        *connect('IN_2', '0:Loop_2'),
        *connect('Loop_2:0', 'Loop_2_out'),
        *connect('in_1_int', 'in_1_int_out'),
        *connect('cond_1_int', 'cond_1_int_out'),
    ]
    outer_body = build_graph(nodes_attrs=ti_sub_graph_1_nodes,
                             edges=outer_edges,
                             nodes_with_edges_only=True)

    nested_loop = Node(outer_body, 'Loop_2')
    nested_loop.body = inner_body
    # Input ports 0 and 1 of Loop_2 get external ids equal to their index;
    # its single output edge gets external id 2.
    for port in (0, 1):
        nested_loop.in_edge(port)['external_port_id'] = port
    nested_loop.out_edge(0)['external_port_id'] = 2

    main_graph = ti_create_main_graph(outer_body)
    main_graph.graph['additional_outputs'] = ['Loop', 'Loop_2']
    return main_graph, outer_body
def ti_create_main_graph(body):
    """Build the top-level graph holding a 'Loop' node with the given body.

    'M', 'cond', 'IN_2' and 'IN_1' feed input ports 0-3 of 'Loop', and its
    output port 0 feeds 'OUT_1'. Each of those Loop edges is tagged with an
    ``external_port_id``: input ports 0-3 get ids 0-3, and output port 0
    gets id 4.

    Args:
        body: Graph to attach as the body of the 'Loop' node.

    Returns:
        The newly built main graph.
    """
    edges = [
        *connect('M', '0:Loop'),
        *connect('cond', '1:Loop'),
        *connect('IN_2', '2:Loop'),
        *connect('IN_1', '3:Loop'),
        *connect('Loop:0', 'OUT_1'),
    ]
    graph = build_graph(nodes_attrs=ti_main_graph_nodes,
                        edges=edges,
                        nodes_with_edges_only=True)

    loop = Node(graph, 'Loop')
    loop.body = body
    for port in range(4):
        loop.in_edge(port)['external_port_id'] = port
    # Output ids continue after the four input ids.
    loop.out_edge(0)['external_port_id'] = 4
    return graph
def test_add_output_1(self):
    """AddOutputRecursive exposes the nested 'Loop_2' output through 'Loop'.

    Builds main_graph -> 'Loop' (body: sub_graph_1) -> 'Loop_2'
    (body: sub_graph_2), requests both loops as additional outputs, runs the
    transformation and checks that:
      * 'Loop' gains exactly one output-port-map entry and one output port,
        whose destination is a Result with shape [5, 10, 4, 64, 54];
      * 'Loop_2' keeps its output-port count, and its output 0 gets an extra
        Unsqueeze -> Result branch with shape [1, 10, 4, 64, 54].
    """
    sub_graph_2 = build_graph(nodes_attrs=sub_graph_2_nodes,
                              edges=[*connect('cond_2_int', 'cond_2_int_out'),
                                     *connect('in_2_int', 'OUT_2'),
                                     *connect('ones', 'OUT_2'),
                                     *connect('OUT_2', 'OUT_2_out'),
                                     *connect('in_2_int', 'in_2_int_out')],
                              nodes_with_edges_only=True)

    sub_graph_1 = build_graph(nodes_attrs=sub_graph_1_nodes,
                              edges=[*connect('M_2', '0:Loop_2'),
                                     *connect('cond_2', '1:Loop_2'),
                                     *connect('IN_2', '2:Loop_2'),
                                     *connect('Loop_2:0', 'Loop_2_out'),
                                     *connect('in_1_int', 'in_1_int_out'),
                                     *connect('cond_1_int', 'cond_1_int_out')],
                              nodes_with_edges_only=True)
    loop_node_1 = Node(sub_graph_1, 'Loop_2')
    loop_node_1.body = sub_graph_2

    main_graph = build_graph(nodes_attrs=main_graph_nodes,
                             edges=[*connect('M', '0:Loop'),
                                    *connect('cond', '1:Loop'),
                                    *connect('IN_2', '2:Loop'),
                                    *connect('IN_1', "3:Loop"),
                                    *connect('Loop:0', 'OUT_1')],
                             nodes_with_edges_only=True)
    loop_node = Node(main_graph, 'Loop')
    loop_node.body = sub_graph_1
    main_graph.graph['additional_outputs'] = ['Loop', 'Loop_2']

    # Snapshot port counts before the transformation so growth can be checked.
    loop_node_output_port_map_len = len(loop_node.output_port_map)
    loop_node_out_ports_len = len(loop_node.out_ports())
    loop_2_out_ports_len = len(loop_node_1.out_ports())
    # Used to compute the expected internal_layer_id of the new Result below.
    max_layer_id = 5

    AddOutputRecursive().find_and_replace_pattern(main_graph)

    # Re-fetch the node so the assertions see the transformed graph.
    loop_node = Node(main_graph, 'Loop')
    self.assertEqual(len(loop_node.output_port_map), loop_node_output_port_map_len + 1)
    self.assertEqual(len(loop_node.out_ports()), loop_node_out_ports_len + 1)
    self.assertEqual(loop_node.out_port(1).get_destination().node.op, 'Result')
    # np.array_equal also compares lengths; `np.all(a == b)` would broadcast
    # and could pass or error out confusingly on a wrong-rank shape.
    self.assertTrue(np.array_equal(loop_node.out_port(1).data.get_shape(),
                                   int64_array([5, 10, 4, 64, 54])))

    last_node = Node(sub_graph_1, 'Loop_2')
    self.assertEqual(len(last_node.out_ports()), loop_2_out_ports_len)
    # The added branch is the second consumer of Loop_2's output 0.
    unsq_node = last_node.out_port(0).get_destinations()[1].node
    self.assertEqual(unsq_node.op, 'Unsqueeze')
    self.assertEqual(unsq_node.out_port(0).get_destination().node.op, 'Result')
    self.assertEqual(unsq_node.out_port(0).get_destination().node.internal_layer_id,
                     max_layer_id + 3)
    self.assertTrue(np.array_equal(unsq_node.out_port(0).data.get_shape(),
                                   int64_array([1, 10, 4, 64, 54])))