コード例 #1
0
    def create_graph():
        """Build the nested-Loop fixture graph for the TI add-output tests.

        The inner body graph is attached as the body of ``Loop_2``, which lives
        in the outer body graph. The outer body graph is then passed to
        ``ti_create_main_graph``, which installs it as the body of ``Loop``.
        Both ``Loop`` and ``Loop_2`` are requested as additional outputs.

        Returns:
            tuple: ``(main_graph, outer_body)`` -- the top-level graph and the
            body graph that contains ``Loop_2``.
        """
        inner_edges = [
            *connect('cond_2_int', 'cond_2_int_out'),
            *connect('in_2_int', 'OUT_2'),
            *connect('ones', 'OUT_2'),
            *connect('OUT_2', 'OUT_2_out'),
            *connect('in_2_int', 'in_2_int_out'),
        ]
        inner_body = build_graph(nodes_attrs=ti_sub_graph_2_nodes,
                                 edges=inner_edges,
                                 nodes_with_edges_only=True)

        outer_edges = [
            *connect('cond_2', '1:Loop_2'),
            *connect('IN_2', '0:Loop_2'),
            *connect('Loop_2:0', 'Loop_2_out'),
            *connect('in_1_int', 'in_1_int_out'),
            *connect('cond_1_int', 'cond_1_int_out'),
        ]
        outer_body = build_graph(nodes_attrs=ti_sub_graph_1_nodes,
                                 edges=outer_edges,
                                 nodes_with_edges_only=True)

        inner_loop = Node(outer_body, 'Loop_2')
        inner_loop.body = inner_body
        # Input edges 0 and 1 get external_port_id equal to their port index;
        # the single output edge continues the numbering with id 2.
        for port_idx in (0, 1):
            inner_loop.in_edge(port_idx)['external_port_id'] = port_idx
        inner_loop.out_edge(0)['external_port_id'] = 2

        main_graph = ti_create_main_graph(outer_body)
        main_graph.graph['additional_outputs'] = ['Loop', 'Loop_2']
        return main_graph, outer_body
コード例 #2
0
def ti_create_main_graph(body):
    """Build the top-level test graph whose ``Loop`` node runs ``body``.

    ``Loop`` is fed by ``M``, ``cond``, ``IN_2`` and ``IN_1`` on input ports
    0-3, and its output port 0 is connected to ``OUT_1``.

    Args:
        body: graph installed as the body of the ``Loop`` node.

    Returns:
        The constructed main graph.
    """
    edge_list = [
        *connect('M', '0:Loop'),
        *connect('cond', '1:Loop'),
        *connect('IN_2', '2:Loop'),
        *connect('IN_1', "3:Loop"),
        *connect('Loop:0', 'OUT_1'),
    ]
    graph = build_graph(nodes_attrs=ti_main_graph_nodes,
                        edges=edge_list,
                        nodes_with_edges_only=True)

    loop = Node(graph, 'Loop')
    loop.body = body
    # Each input edge gets external_port_id equal to its port index (0-3);
    # the output edge takes the next id, 4.
    for port_idx in range(4):
        loop.in_edge(port_idx)['external_port_id'] = port_idx
    loop.out_edge(0)['external_port_id'] = 4

    return graph
コード例 #3
0
    def test_add_output_1(self):
        """Check AddOutputRecursive on a Loop nested inside another Loop.

        ``Loop_2`` (body: ``sub_graph_2``) sits inside ``sub_graph_1``, which is
        the body of ``Loop`` in ``main_graph``. Both loops are requested as
        additional outputs. After the transformation the test expects:

        * ``Loop`` to gain exactly one output port (and port-map entry) whose
          destination is a ``Result`` with shape ``[5, 10, 4, 64, 54]``;
        * ``Loop_2`` to keep its number of output ports, with its output 0
          additionally routed through an ``Unsqueeze`` into a ``Result`` of
          shape ``[1, 10, 4, 64, 54]``.
        """
        sub_graph_2 = build_graph(nodes_attrs=sub_graph_2_nodes,
                                  edges=[
                                      *connect('cond_2_int', 'cond_2_int_out'),
                                      *connect('in_2_int', 'OUT_2'),
                                      *connect('ones', 'OUT_2'),
                                      *connect('OUT_2', 'OUT_2_out'),
                                      *connect('in_2_int', 'in_2_int_out')
                                  ],
                                  nodes_with_edges_only=True)

        sub_graph_1 = build_graph(nodes_attrs=sub_graph_1_nodes,
                                  edges=[
                                      *connect('M_2', '0:Loop_2'),
                                      *connect('cond_2', '1:Loop_2'),
                                      *connect('IN_2', '2:Loop_2'),
                                      *connect('Loop_2:0', 'Loop_2_out'),
                                      *connect('in_1_int', 'in_1_int_out'),
                                      *connect('cond_1_int', 'cond_1_int_out')
                                  ],
                                  nodes_with_edges_only=True)
        loop_node_1 = Node(sub_graph_1, 'Loop_2')
        loop_node_1.body = sub_graph_2

        main_graph = build_graph(nodes_attrs=main_graph_nodes,
                                 edges=[
                                     *connect('M', '0:Loop'),
                                     *connect('cond', '1:Loop'),
                                     *connect('IN_2', '2:Loop'),
                                     *connect('IN_1', "3:Loop"),
                                     *connect('Loop:0', 'OUT_1')
                                 ],
                                 nodes_with_edges_only=True)
        loop_node = Node(main_graph, 'Loop')
        loop_node.body = sub_graph_1
        main_graph.graph['additional_outputs'] = ['Loop', 'Loop_2']

        # Snapshot port counts before the transformation so the assertions
        # can check the deltas it introduces.
        loop_node_output_port_map_len = len(loop_node.output_port_map)
        loop_node_out_ports_len = len(loop_node.out_ports())
        loop_2_out_ports_len = len(loop_node_1.out_ports())
        # NOTE(review): presumably the highest internal_layer_id in the
        # sub_graph_1_nodes fixture -- confirm against the fixture definition.
        max_layer_id = 5

        AddOutputRecursive().find_and_replace_pattern(main_graph)

        loop_node = Node(main_graph, 'Loop')
        self.assertEqual(len(loop_node.output_port_map),
                         loop_node_output_port_map_len + 1)
        self.assertEqual(len(loop_node.out_ports()),
                         loop_node_out_ports_len + 1)
        self.assertEqual(
            loop_node.out_port(1).get_destination().node.op, 'Result')
        # np.array_equal checks the shape before the values, so a
        # different-rank shape fails cleanly instead of broadcasting (which
        # np.all(a == b) would do for a length-1 shape) or raising.
        self.assertTrue(
            np.array_equal(loop_node.out_port(1).data.get_shape(),
                           int64_array([5, 10, 4, 64, 54])))

        last_node = Node(sub_graph_1, 'Loop_2')
        self.assertEqual(len(last_node.out_ports()), loop_2_out_ports_len)
        # Loop_2 output 0 keeps its original consumer; the second destination
        # is the Unsqueeze inserted by the transformation.
        unsq_node = last_node.out_port(0).get_destinations()[1].node
        self.assertEqual(unsq_node.op, 'Unsqueeze')
        self.assertEqual(
            unsq_node.out_port(0).get_destination().node.op, 'Result')
        self.assertEqual(
            unsq_node.out_port(0).get_destination().node.internal_layer_id,
            max_layer_id + 3)
        self.assertTrue(
            np.array_equal(unsq_node.out_port(0).data.get_shape(),
                           int64_array([1, 10, 4, 64, 54])))