Example #1
0
def if_create_main_graph():
    """Build the graph holding the nested ``If_2`` node and its two branches.

    Three graphs are assembled, each with ``nodes_with_edges_only=True``:

    * then-branch of ``If_2``: ``in_2_int`` and ``ones`` feed ``OUT_2``,
      which feeds ``OUT_2_out``;
    * else-branch of ``If_2``: the same wiring with ``_else``-suffixed names;
    * the graph that contains ``If_2``: ``cond_2`` feeds port 0 and ``IN_2``
      feeds port 1 of ``If_2``, ``If_2`` output 0 feeds ``If_2_out``, and
      ``in_1_int`` feeds ``in_1_int_out``.

    The two branch graphs are attached to the ``If_2`` node as its
    ``then_graph`` and ``else_graph``.

    :return: the graph that contains the configured ``If_2`` node.
    """

    def _graph_from_pairs(nodes_attrs, pairs):
        # Expand every (source, destination) pair through ``connect``,
        # keeping the edges in exactly the order the pairs are listed.
        edges = []
        for src, dst in pairs:
            edges.extend(connect(src, dst))
        return build_graph(nodes_attrs=nodes_attrs,
                           edges=edges,
                           nodes_with_edges_only=True)

    then_branch = _graph_from_pairs(if_sub_graph_2_then_nodes, [
        ('in_2_int', 'OUT_2'),
        ('ones', 'OUT_2'),
        ('OUT_2', 'OUT_2_out'),
    ])
    else_branch = _graph_from_pairs(if_sub_graph_2_else_nodes, [
        ('in_2_int_else', 'OUT_2_else'),
        ('ones_else', 'OUT_2_else'),
        ('OUT_2_else', 'OUT_2_out_else'),
    ])
    outer = _graph_from_pairs(if_sub_graph_1_then_nodes, [
        ('cond_2', '0:If_2'),
        ('IN_2', '1:If_2'),
        ('If_2:0', 'If_2_out'),
        ('in_1_int', 'in_1_int_out'),
    ])

    # Wire the branch bodies into the nested If node.
    nested_if = Node(outer, 'If_2')
    nested_if.then_graph = then_branch
    nested_if.else_graph = else_branch
    return outer
Example #2
0
    def test_add_output_1(self):
        """Add an output from inside a nested If and check both If nodes.

        ``if_create_main_graph`` returns ``sub_graph_1``, which contains
        ``If_2``. That graph becomes the then-branch of the top-level ``If``
        node in ``main_graph``. ``main_graph.graph['additional_outputs']``
        is set to ``['If', ['If_2', 'in_1_int']]`` (presumably a path from
        the outer If down to ``in_1_int``; confirm against
        ``AddOutputRecursive``), and the transformation is run.

        Afterwards the test expects:

        * ``If`` to have exactly one more output port than before, with
          port 1 feeding a ``Result`` and carrying shape [1, 4, 64, 54];
        * ``If_2`` to keep its port count, with port 0 gaining a
          ``Result`` as its second destination and carrying the same shape.
        """
        expected_shape = int64_array([1, 4, 64, 54])

        sub_graph_1 = if_create_main_graph()
        if_node_1 = Node(sub_graph_1, 'If_2')

        sub_graph_1_else = build_graph(
            nodes_attrs=if_sub_graph_1_else_nodes,
            edges=[*connect('in_1_int', 'in_1_int_out')],
            nodes_with_edges_only=True)

        main_graph = build_graph(nodes_attrs=if_main_graph_nodes,
                                 edges=[
                                     *connect('cond', '0:If'),
                                     *connect('IN_1', '1:If'),
                                     *connect('IN_2', "2:If"),
                                     *connect('If:0', 'OUT_1')
                                 ],
                                 nodes_with_edges_only=True)
        if_node = Node(main_graph, 'If')
        if_node.then_graph = sub_graph_1
        if_node.else_graph = sub_graph_1_else

        # Record port counts before the transformation so the assertions
        # can check how many ports it added.
        if_node_out_ports_len = len(if_node.out_ports())
        if_2_node_out_ports_len = len(if_node_1.out_ports())

        main_graph.graph['additional_outputs'] = ['If', ['If_2', 'in_1_int']]

        AddOutputRecursive().find_and_replace_pattern(main_graph)

        # Outer If: exactly one new output port (index 1) ending in a Result.
        if_node = Node(main_graph, 'If')
        self.assertEqual(len(if_node.out_ports()), if_node_out_ports_len + 1)
        self.assertEqual(
            if_node.out_port(1).get_destination().node.op, 'Result')
        # assert_array_equal checks the shapes match exactly and prints both
        # arrays on failure, unlike assertTrue(np.all(a == b)).
        np.testing.assert_array_equal(
            if_node.out_port(1).data.get_shape(), expected_shape)

        # Inner If_2: no new port; the existing port 0 gets an extra Result
        # consumer.
        last_node = Node(sub_graph_1, 'If_2')
        self.assertEqual(len(last_node.out_ports()), if_2_node_out_ports_len)
        # NOTE(review): index 1 assumes the pre-existing destination stays
        # first in get_destinations(); this depends on destination ordering.
        self.assertEqual(
            last_node.out_port(0).get_destinations()[1].node.op, 'Result')
        np.testing.assert_array_equal(
            last_node.out_port(0).data.get_shape(), expected_shape)