# Example #1 (scraped listing header; the "0" that followed it is presumably a vote count)
    def test_single_stride_slice_removal(self):
        """Check that a single StridedSlice between placeholder_data and output_op is erased.

        After ``UselessStridedSliceEraser`` runs, the graph must match a
        reference in which ``placeholder_data`` (shape [4, 5, 6]) feeds
        ``output_op`` directly. The slice's own attributes come from
        ``nodes_attributes`` (defined elsewhere) -- presumably they describe a
        slice that keeps the whole input; confirm there.
        """
        input_edges = [
            ('placeholder', 'placeholder_data'),
            ('placeholder_data', 'strided_slice'),
            ('strided_slice_input_1_data', 'strided_slice'),
            ('strided_slice_input_2_data', 'strided_slice'),
            ('strided_slice_input_3_data', 'strided_slice'),
            ('strided_slice', 'strided_slice_data'),
            ('strided_slice_data', 'output_op'),
        ]
        actual = build_graph(nodes_attributes, input_edges, {},
                             nodes_with_edges_only=True)

        eraser = UselessStridedSliceEraser()
        eraser.find_and_replace_pattern(actual)

        expected_edges = [
            ('placeholder', 'placeholder_data'),
            ('placeholder_data', 'output_op'),
        ]
        expected_attrs = {'placeholder_data': {'shape': np.array([4, 5, 6])}}
        expected = build_graph(nodes_attributes, expected_edges, expected_attrs)

        flag, resp = compare_graphs(actual, expected, 'output_op',
                                    check_op_attrs=True)
        self.assertTrue(flag, resp)
    def test_consecutive_stride_slices_removal(self):
        """Check that two chained StridedSlice ops are both erased.

        Both slices take their extra inputs from the same three
        ``strided_slice_input_*_data`` nodes. After the eraser and shape
        inference run, ``placeholder_data`` (shape [4, 1, 6]) must feed
        ``output_op`` directly.
        """
        # Edge order is kept exactly as listed, because build_graph may
        # derive port numbers from it (assumption -- see build_graph).
        first_slice_edges = [
            ('placeholder', 'placeholder_data'),
            ('placeholder_data', 'strided_slice'),
            ('strided_slice_input_1_data', 'strided_slice'),
            ('strided_slice_input_2_data', 'strided_slice'),
            ('strided_slice_input_3_data', 'strided_slice'),
            ('strided_slice', 'strided_slice_data'),
        ]
        second_slice_edges = [
            ('strided_slice_data', 'strided_slice_2'),
            ('strided_slice_input_1_data', 'strided_slice_2'),
            ('strided_slice_input_2_data', 'strided_slice_2'),
            ('strided_slice_input_3_data', 'strided_slice_2'),
            ('strided_slice_2', 'strided_slice_2_data'),
        ]
        output_edges = [('strided_slice_2_data', 'output_op')]

        actual = build_graph(nodes_attributes,
                             first_slice_edges + second_slice_edges + output_edges,
                             {},
                             nodes_with_edges_only=True)

        UselessStridedSliceEraser().find_and_replace_pattern(actual)
        shape_inference(actual)

        expected = build_graph(
            nodes_attributes,
            [('placeholder', 'placeholder_data'),
             ('placeholder_data', 'output_op')],
            {'placeholder_data': {'shape': int64_array([4, 1, 6])}},
        )

        flag, resp = compare_graphs(actual, expected, 'output_op',
                                    check_op_attrs=True)
        self.assertTrue(flag, resp)
    def test_single_stride_slice_with_shrink_and_new_removal(self):
        """Check that a StridedSlice with both shrink and new-axis masks becomes Unsqueeze -> Squeeze.

        The slice has ``shrink_axis_mask`` [0, 1, 0, 0] and ``new_axis_mask``
        [0, 0, 1, 0]. The expected result is an Unsqueeze with axis constant
        [2] (shape [4, 1, 6] -> [4, 1, 1, 6]), followed by a Squeeze that
        writes into the original ``strided_slice_data`` (shape [4, 1, 6]).
        """
        slice_masks = {
            'shrink_axis_mask': int64_array([0, 1, 0, 0]),
            'new_axis_mask': int64_array([0, 0, 1, 0]),
        }
        input_edges = [
            ('placeholder', 'placeholder_data'),
            ('placeholder_data', 'strided_slice'),
            ('strided_slice_input_1_data', 'strided_slice'),
            ('strided_slice_input_2_data', 'strided_slice'),
            ('strided_slice_input_3_data', 'strided_slice'),
            ('strided_slice', 'strided_slice_data'),
            ('strided_slice_data', 'output_op'),
        ]
        input_attrs = {
            'strided_slice': slice_masks,
            'strided_slice_data': {'shape': int64_array([4, 1, 6])},
        }
        actual = build_graph(nodes_attributes, input_edges, input_attrs,
                             nodes_with_edges_only=True)
        # NOTE(review): the layout is presumably read by the eraser or by
        # shape inference -- confirm which one needs it.
        actual.graph['layout'] = 'NCHW'

        UselessStridedSliceEraser().find_and_replace_pattern(actual)
        shape_inference(actual)

        expected_edges = [
            ('placeholder', 'placeholder_data'),
            ('placeholder_data', 'unsqueeze'),
            ('unsqueeze_const', 'unsqueeze_const_data'),
            ('unsqueeze_const_data', 'unsqueeze'),
            ('unsqueeze', 'unsqueeze_data'),
            ('unsqueeze_data', 'squeeze'),
            ('squeeze_const', 'squeeze_const_data'),
            ('squeeze_const_data', 'squeeze'),
            ('squeeze', 'strided_slice_data'),
            ('strided_slice_data', 'output_op'),
        ]
        expected_attrs = {
            'placeholder_data': {'shape': int64_array([4, 1, 6])},
            'unsqueeze_data': {'shape': int64_array([4, 1, 1, 6])},
            'strided_slice_data': {'shape': int64_array([4, 1, 6])},
            'unsqueeze_const': {'value': int64_array([2])},
        }
        expected = build_graph(nodes_attributes, expected_edges, expected_attrs,
                               nodes_with_edges_only=True)

        flag, resp = compare_graphs(actual, expected, 'output_op',
                                    check_op_attrs=True)
        self.assertTrue(flag, resp)
 def test_negative(self):
     graph = build_graph(nodes_attributes, [
         ('placeholder', 'placeholder_data'),
         ('placeholder_data', 'strided_slice'),
         ('strided_slice_input_1_data', 'strided_slice'),
         ('strided_slice_input_2_data', 'strided_slice'),
         ('strided_slice_input_3_data', 'strided_slice'),
         ('strided_slice', 'strided_slice_data'),
         ('strided_slice_data', 'output_op'),
     ], {'strided_slice_data': {
         'value': []
     }},
                         nodes_with_edges_only=True)
     graph_ref = graph.copy()
     UselessStridedSliceEraser().find_and_replace_pattern(graph)
     (flag, resp) = compare_graphs(graph,
                                   graph_ref,
                                   'output_op',
                                   check_op_attrs=True)
     self.assertTrue(flag, resp)