示例#1
0
    def test_deletion_3(self):
        """Expect the input whose size along the concat axis is 0 to be dropped.

        Four constant inputs feed a ``Concat`` along axis 1; ``input_3`` has
        shape ``[5, 0]``. After ``ConcatOdInputEraserAndPortsReconnect`` runs,
        the graph must compare equal to a reference in which only
        ``input_0`` .. ``input_2`` are connected to ``concat``.
        """
        const_inputs = [
            ('input_0', [5, 3]),
            ('input_1', [5, 1]),
            ('input_2', [5, 5]),
            ('input_3', [5, 0]),
        ]
        nodes = {}
        for name, shape in const_inputs:
            nodes.update(shaped_const_with_data(name, shape))
        concat_attrs = {'type': 'Concat', 'axis': 1}
        nodes.update(regular_op_with_shaped_data('concat', [5, 9], concat_attrs))
        nodes.update(result())

        links_before = [
            ('input_0', '0:concat'),
            ('input_1', '1:concat'),
            ('input_2', '2:concat'),
            ('input_3', '3:concat'),
            ('concat', 'output'),
        ]
        # The reference wiring is identical except that input_3 is absent.
        links_after = [(src, dst) for src, dst in links_before if src != 'input_3']

        # connect() is called once per link for each list, so the two edge
        # lists do not share any edge objects.
        edges_before = [edge for src, dst in links_before for edge in connect(src, dst)]
        edges_after = [edge for src, dst in links_after for edge in connect(src, dst)]

        graph = build_graph(nodes, edges_before, nodes_with_edges_only=True)
        ConcatOdInputEraserAndPortsReconnect().find_and_replace_pattern(graph)
        graph_ref = build_graph(nodes, edges_after, nodes_with_edges_only=True)

        is_equal, message = compare_graphs(graph, graph_ref, 'output', check_op_attrs=True)
        self.assertTrue(is_equal, message)
示例#2
0
    def test_deletion_trailing_unconnected_ports(self):
        """Expect a trailing, unconnected ``Concat`` input port to be removed.

        The graph has a single constant input wired to port 0 of ``concat``.
        An extra input port 1 is added without any edge attached to it. After
        ``ConcatOdInputEraserAndPortsReconnect`` runs, the graph must still
        compare equal to the reference wiring and port 1 must no longer be
        present among the ``concat`` input ports.
        """
        nodes = {
            **shaped_const_with_data('input_0', [5, 3]),
            **regular_op_with_shaped_data('concat', [5, 3], {
                                              'type': 'Concat',
                                              'axis': 1
                                          }),
            **result(),
        }
        edges_before = [
            *connect('input_0', '0:concat'),
            *connect('concat', 'output'),
        ]
        # The wiring is expected to stay the same; only the dangling port goes.
        edges_after = [
            *connect('input_0', '0:concat'),
            *connect('concat', 'output'),
        ]

        graph = build_graph(nodes, edges_before, nodes_with_edges_only=True)
        # Add an input port that no edge is connected to.
        Node(graph, 'concat').add_input_port(1)
        ConcatOdInputEraserAndPortsReconnect().find_and_replace_pattern(graph)
        graph_ref = build_graph(nodes, edges_after, nodes_with_edges_only=True)

        (flag, resp) = compare_graphs(graph,
                                      graph_ref,
                                      'output',
                                      check_op_attrs=True)
        self.assertTrue(flag, resp)
        # assertNotIn reports the member and the container on failure, unlike
        # assertTrue(... not in ...), which only says "False is not true".
        self.assertNotIn(1, Node(graph, 'concat').in_ports())
示例#3
0
    def test_negative(self):
        """Expect the graph to be left unchanged when every input has shape ``[1]``.

        Four constant inputs of shape ``[1]`` feed a ``Concat`` with output
        shape ``[4]``. After ``ConcatOdInputEraserAndPortsReconnect`` runs,
        the graph must compare equal to a reference built from the very same
        nodes and edges.
        """
        nodes = {}
        for name in ('input_0', 'input_1', 'input_2', 'input_3'):
            nodes.update(shaped_const_with_data(name, [1]))
        nodes.update(regular_op_with_shaped_data('concat', [4], {'type': 'Concat'}))
        nodes.update(result())

        links = [
            ('input_0', '0:concat'),
            ('input_1', '1:concat'),
            ('input_2', '2:concat'),
            ('input_3', '3:concat'),
            ('concat', 'output'),
        ]
        edges = [edge for src, dst in links for edge in connect(src, dst)]

        graph = build_graph(nodes, edges, nodes_with_edges_only=True)
        ConcatOdInputEraserAndPortsReconnect().find_and_replace_pattern(graph)
        # The reference is built from the same description as the input graph.
        graph_ref = build_graph(nodes, edges, nodes_with_edges_only=True)

        is_equal, message = compare_graphs(graph, graph_ref, 'output', check_op_attrs=True)
        self.assertTrue(is_equal, message)
示例#4
0
    def test_assertion_error(self):
        """Expect ``AssertionError`` when every ``Concat`` input has shape ``[0]``.

        Four constant inputs of shape ``[0]`` feed a ``Concat`` with output
        shape ``[0]``; running ``ConcatOdInputEraserAndPortsReconnect`` on this
        graph must raise ``AssertionError``.
        """
        nodes = {}
        for name in ('input_0', 'input_1', 'input_2', 'input_3'):
            nodes.update(shaped_const_with_data(name, [0]))
        nodes.update(regular_op_with_shaped_data('concat', [0], {'type': 'Concat'}))
        nodes.update(result())

        links = [
            ('input_0', '0:concat'),
            ('input_1', '1:concat'),
            ('input_2', '2:concat'),
            ('input_3', '3:concat'),
            ('concat', 'output'),
        ]
        edges = [edge for src, dst in links for edge in connect(src, dst)]

        graph = build_graph(nodes, edges, nodes_with_edges_only=True)
        # Create the transformation outside the guarded region so that only
        # the call itself is expected to raise.
        apply_transformation = ConcatOdInputEraserAndPortsReconnect().find_and_replace_pattern
        with self.assertRaises(AssertionError):
            apply_transformation(graph)