def if_create_main_graph():
    """Assemble the graph that holds the nested ``If_2`` node, with both bodies attached.

    Every graph is created by ``build_graph`` with ``nodes_with_edges_only=True``.
    Three graphs are built, in this order:

    * then-body of ``If_2`` (``if_sub_graph_2_then_nodes``): ``in_2_int`` and
      ``ones`` both feed ``OUT_2``, and ``OUT_2`` feeds ``OUT_2_out``;
    * else-body of ``If_2`` (``if_sub_graph_2_else_nodes``): the same wiring
      using the ``*_else`` node names;
    * the graph containing ``If_2`` (``if_sub_graph_1_then_nodes``):
      ``cond_2`` -> ``'0:If_2'``, ``IN_2`` -> ``'1:If_2'``,
      ``'If_2:0'`` -> ``If_2_out``, and ``in_1_int`` -> ``in_1_int_out``.
      The ``'N:node'`` / ``'node:N'`` strings are presumably ``connect``'s
      input-port / output-port notation; confirm against ``connect``.

    The first two graphs are stored on ``If_2`` as ``then_graph`` and
    ``else_graph``.

    NOTE(review): despite the name, the returned graph is not the top-level
    graph. ``test_add_output_1`` installs it as the ``then_graph`` of the
    outer ``If`` node.

    Returns:
        The graph containing ``If_2``.
    """

    def _build(nodes_attrs, *links):
        # Flatten connect(src, dst) for every (src, dst) pair, in order.
        # This is the same edge list as [*connect(a, b), *connect(c, d), ...].
        edge_list = [edge for src, dst in links for edge in connect(src, dst)]
        return build_graph(nodes_attrs=nodes_attrs, edges=edge_list,
                           nodes_with_edges_only=True)

    inner_then = _build(if_sub_graph_2_then_nodes,
                        ('in_2_int', 'OUT_2'),
                        ('ones', 'OUT_2'),
                        ('OUT_2', 'OUT_2_out'))
    inner_else = _build(if_sub_graph_2_else_nodes,
                        ('in_2_int_else', 'OUT_2_else'),
                        ('ones_else', 'OUT_2_else'),
                        ('OUT_2_else', 'OUT_2_out_else'))
    holder = _build(if_sub_graph_1_then_nodes,
                    ('cond_2', '0:If_2'),
                    ('IN_2', '1:If_2'),
                    ('If_2:0', 'If_2_out'),
                    ('in_1_int', 'in_1_int_out'))

    # Hook the two bodies onto the nested If node.
    nested_if = Node(holder, 'If_2')
    nested_if.then_graph = inner_then
    nested_if.else_graph = inner_else
    return holder
def test_add_output_1(self):
    """Check that AddOutputRecursive adds outputs through a nested ``If``.

    Setup:
      * ``sub_graph_1`` (from ``if_create_main_graph``) contains ``If_2``.
        It is installed as the ``then_graph`` of the outer ``If``.
      * ``sub_graph_1_else`` only wires ``in_1_int`` -> ``in_1_int_out``.
      * ``main_graph`` wires ``cond``, ``IN_1`` and ``IN_2`` into ports 0..2 of
        ``If``, and ``If`` output 0 into ``OUT_1``.
      * ``main_graph.graph['additional_outputs']`` is set to
        ``['If', ['If_2', 'in_1_int']]`` before the transform runs.

    Expectations after ``AddOutputRecursive().find_and_replace_pattern``:
      * the outer ``If`` has exactly one more output port, and port 1
        feeds a ``Result`` with shape [1, 4, 64, 54];
      * ``If_2`` keeps its output-port count, and output port 0 gains a
        second consumer (index 1) that is a ``Result``; the port's shape is
        [1, 4, 64, 54].
    """
    expected_shape = int64_array([1, 4, 64, 54])

    sub_graph_1 = if_create_main_graph()
    if_node_1 = Node(sub_graph_1, 'If_2')

    sub_graph_1_else = build_graph(
        nodes_attrs=if_sub_graph_1_else_nodes,
        edges=[*connect('in_1_int', 'in_1_int_out')],
        nodes_with_edges_only=True)

    main_graph = build_graph(
        nodes_attrs=if_main_graph_nodes,
        edges=[*connect('cond', '0:If'),
               *connect('IN_1', '1:If'),
               *connect('IN_2', "2:If"),
               *connect('If:0', 'OUT_1')],
        nodes_with_edges_only=True)
    if_node = Node(main_graph, 'If')
    if_node.then_graph = sub_graph_1
    if_node.else_graph = sub_graph_1_else

    # Snapshot port counts before the transform so the change can be asserted.
    if_node_out_ports_len = len(if_node.out_ports())
    if_2_node_out_ports_len = len(if_node_1.out_ports())

    main_graph.graph['additional_outputs'] = ['If', ['If_2', 'in_1_int']]

    AddOutputRecursive().find_and_replace_pattern(main_graph)

    # Outer If: one new output port that ends in a Result.
    # Re-fetch the node from the graph after the transform.
    if_node = Node(main_graph, 'If')
    self.assertEqual(len(if_node.out_ports()), if_node_out_ports_len + 1)
    self.assertEqual(if_node.out_port(1).get_destination().node.op, 'Result')
    # assert_array_equal reports both arrays on mismatch, unlike
    # assertTrue(np.all(a == b)), which only says "False is not true".
    np.testing.assert_array_equal(if_node.out_port(1).data.get_shape(),
                                  expected_shape)

    # Nested If_2: same port count, extra Result consumer on output port 0.
    last_node = Node(sub_graph_1, 'If_2')
    self.assertEqual(len(last_node.out_ports()), if_2_node_out_ports_len)
    destinations = last_node.out_port(0).get_destinations()
    # Fail with a clear message instead of an IndexError if the consumer is missing.
    self.assertGreater(len(destinations), 1,
                       'If_2 output port 0 did not gain an additional consumer')
    # NOTE(review): relies on the added consumer being at index 1 of
    # get_destinations(), as the original test did.
    self.assertEqual(destinations[1].node.op, 'Result')
    np.testing.assert_array_equal(last_node.out_port(0).data.get_shape(),
                                  expected_shape)