Ejemplo n.º 1
0
    def test_no_steps_no_axes(self):
        """Check ConvertSlice on a Slice that is fed only by `starts` and `ends`.

        The source graph is inferred with ``Slice.infer``, rewritten by
        ``ConvertSlice`` and then compared (from 'output_op', with op attribute
        checking) against a reference graph built around a 'strided_slice'
        node that receives starts, ends and strides inputs.
        """
        in_shape = int64_array([5, 10, 20])
        begin = int64_array([3, 2, 7])
        end = int64_array([5, 8, 15])
        # Unit steps; not referenced anywhere below in this test.
        unit_steps = int64_array([1, 1, 1])
        zero_mask = np.zeros([len(in_shape)], dtype=np.int64)

        # --- graph under test -------------------------------------------------
        slice_edges = [
            ('placeholder_1', 'placeholder_1_data'),
            ('placeholder_1_data', 'slice', {'in': 0}),
            ('starts', 'starts_data'),
            ('starts_data', 'slice', {'in': 1}),
            ('ends', 'ends_data'),
            ('ends_data', 'slice', {'in': 2}),
            ('slice', 'slice_data'),
            ('slice_data', 'output_op'),
            ('output_op', 'output_data'),
            ('output_data', 'op_output'),
        ]
        slice_attrs = {
            'placeholder_1_data': {'shape': in_shape},
            'starts': {'shape': begin.shape, 'value': begin},
            'starts_data': {'shape': begin.shape, 'value': begin},
            'ends': {'shape': end.shape, 'value': end},
            'ends_data': {'shape': end.shape, 'value': end},
        }
        graph = build_graph(self.nodes_attributes, slice_edges, slice_attrs,
                            nodes_with_edges_only=True)

        # Shape inference on the Slice node is run before the transformation.
        Slice.infer(Node(graph, 'slice'))
        ConvertSlice().find_and_replace_pattern(graph)

        converted_id = graph.get_node_id_by_name('slice_node')
        converted = Node(graph, converted_id)
        assert converted.type == 'StridedSlice', 'Something wrong with transformed Slice node'

        # --- reference graph --------------------------------------------------
        reference_edges = [
            ('placeholder_1', 'placeholder_1_data'),
            ('placeholder_1_data', 'strided_slice', {'in': 0}),
            ('starts', 'starts_data'),
            ('starts_data', 'strided_slice', {'in': 1}),
            ('ends', 'ends_data'),
            ('ends_data', 'strided_slice', {'in': 2}),
            ('strides', 'strides_data'),
            ('strides_data', 'strided_slice', {'in': 3}),
            ('strided_slice', 'slice_data'),
            ('slice_data', 'output_op'),
            ('output_op', 'output_data'),
            ('output_data', 'op_output'),
        ]
        strided_slice_attrs = {
            'new_axis_mask': zero_mask,
            'shrink_axis_mask': zero_mask,
            'ellipsis_mask': zero_mask,
            'begin_mask': np.ones([3]),
            'end_mask': np.ones([3]),
        }
        reference_attrs = {
            'placeholder_1_data': {'shape': in_shape},
            'strided_slice': strided_slice_attrs,
            'slice_data': {'shape': int64_array([2, 6, 8])},
        }
        graph_ref = build_graph(self.nodes_attributes, reference_edges, reference_attrs,
                                nodes_with_edges_only=True)

        matched, report = compare_graphs(graph, graph_ref, 'output_op', check_op_attrs=True)
        self.assertTrue(matched, report)
Ejemplo n.º 2
0
    def test_1(self):
        """
        Slice on a non-constant path that cuts several dimensions at once.

        After ``Slice.infer`` and ``ConvertSlice`` the node named 'slice_node'
        must be of type 'Crop', and the graph must match a reference graph
        holding a 'crop' node with the corresponding axis/offset/dim values.
        :return:
        """
        # --- graph under test -------------------------------------------------
        source_edges = [
            ('placeholder_1', 'placeholder_1_data'),
            ('placeholder_1_data', 'slice'),
            ('slice', 'slice_data'),
            ('slice_data', 'output_op'),
            ('output_op', 'output_data'),
            ('output_data', 'op_output'),
        ]
        source_attrs = {
            'placeholder_1_data': {'shape': np.array([4, 5, 6])},
            'slice': {
                'start': np.array([1, 2, 3]),
                'end': np.array([3, 4, 4]),
                'axis': None,
            },
        }
        graph = build_graph(nodes_attributes, source_edges, source_attrs,
                            nodes_with_edges_only=True)

        Slice.infer(Node(graph, 'slice'))

        ConvertSlice().find_and_replace_pattern(graph)
        graph.clean_up()

        converted_id = graph.get_node_id_by_name('slice_node')
        converted = Node(graph, converted_id)
        assert converted.type == 'Crop', 'Something wrong with transformed Slice node'

        # --- reference graph --------------------------------------------------
        reference_edges = [
            ('placeholder_1', 'placeholder_1_data'),
            ('placeholder_1_data', 'crop'),
            ('crop', 'slice_data'),
            ('slice_data', 'output_op'),
            ('output_op', 'output_data'),
            ('output_data', 'op_output'),
        ]
        reference_attrs = {
            'placeholder_1_data': {'shape': np.array([4, 5, 6])},
            'crop': {
                'axis': np.array([0, 1, 2]),
                'offset': np.array([1, 2, 3]),
                'dim': np.array([2, 2, 1]),
            },
        }
        graph_ref = build_graph(nodes_attributes, reference_edges, reference_attrs,
                                nodes_with_edges_only=True)

        matched, report = compare_graphs(graph, graph_ref, 'output_op', check_op_attrs=True)
        self.assertTrue(matched, report)
Ejemplo n.º 3
0
    def test_2(self):
        """
        Slice with a single slicing dimension (described by the original
        author as the "constant path" case).

        The Slice is inferred, converted by ``ConvertSlice`` and the result is
        compared with a reference graph containing a 'strided_slice' node.
        """
        # --- graph under test -------------------------------------------------
        source_edges = [
            ('placeholder_1', 'placeholder_1_data'),
            ('placeholder_1_data', 'slice'),
            ('slice', 'slice_data'),
            ('slice_data', 'output_op'),
            ('output_op', 'output_data'),
        ]
        source_attrs = {
            'placeholder_1_data': {'shape': np.array([4, 5, 6])},
            'slice': {
                'start': np.array([1]),
                'end': np.array([3]),
                'axis': None,
            },
            'output_op': {'is_output': True},
        }
        graph = build_graph(nodes_attributes, source_edges, source_attrs)

        Slice.infer(Node(graph, 'slice'))
        ConvertSlice().find_and_replace_pattern(graph)

        # --- reference graph --------------------------------------------------
        reference_edges = [
            ('placeholder_1', 'placeholder_1_data'),
            ('placeholder_1_data', 'strided_slice'),
            ('strided_slice', 'slice_data'),
            ('slice_data', 'output_op'),
            ('output_op', 'output_data'),
        ]
        # One slice object per dimension of the [4, 5, 6] input.
        expected_slices = np.array([slice(1, 3, 1),
                                    slice(0, 5, 1),
                                    slice(0, 6, 1)])
        reference_attrs = {
            'placeholder_1_data': {'shape': np.array([4, 5, 6])},
            'strided_slice': {
                'slices': expected_slices,
                'shrink_axis_mask': np.array([False, False, False]),
            },
            'output_op': {'is_output': True},
        }
        graph_ref = build_graph(nodes_attributes, reference_edges, reference_attrs)

        matched, report = compare_graphs(graph, graph_ref, 'output_op', check_op_attrs=True)
        self.assertTrue(matched, report)
Ejemplo n.º 4
0
    def test_1(self):
        """
        Slice on a non-constant path that cuts several dimensions at once.

        The Slice is inferred, converted by ``ConvertSlice`` and the result is
        compared (from 'output_op', with op attribute checking) against a
        reference graph that contains a 'crop' node.
        :return:
        """
        # --- graph under test -------------------------------------------------
        source_edges = [
            ('placeholder_1', 'placeholder_1_data'),
            ('placeholder_1_data', 'slice'),
            ('slice', 'slice_data'),
            ('slice_data', 'output_op'),
            ('output_op', 'output_data'),
        ]
        source_attrs = {
            'placeholder_1_data': {'shape': np.array([4, 5, 6])},
            'slice': {
                'start': np.array([1, 2, 3]),
                'end': np.array([3, 4, 4]),
                'axis': None,
            },
            'output_op': {'is_output': True},
        }
        graph = build_graph(nodes_attributes, source_edges, source_attrs)

        Slice.infer(Node(graph, 'slice'))
        ConvertSlice().find_and_replace_pattern(graph)

        # --- reference graph --------------------------------------------------
        reference_edges = [
            ('placeholder_1', 'placeholder_1_data'),
            ('placeholder_1_data', 'crop'),
            ('crop', 'slice_data'),
            ('slice_data', 'output_op'),
            ('output_op', 'output_data'),
        ]
        reference_attrs = {
            'placeholder_1_data': {'shape': np.array([4, 5, 6])},
            'crop': {
                'axis': np.array([0, 1, 2]),
                'offset': np.array([1, 2, 3]),
            },
            'output_op': {'is_output': True},
            # NOTE(review): 'dim' is keyed here as a node name, not as an
            # attribute of 'crop' -- presumably relies on a 'dim' entry in
            # nodes_attributes; confirm this is intended.
            'dim': {'dim': np.array([2, 2, 1])},
        }
        graph_ref = build_graph(nodes_attributes, reference_edges, reference_attrs)

        matched, report = compare_graphs(graph, graph_ref, 'output_op', check_op_attrs=True)
        self.assertTrue(matched, report)
Ejemplo n.º 5
0
    def test_convert_slice_to_strided_slice_without_axes_and_steps(self):
        """Convert a Slice that has only `starts` and `ends` inputs (no axes, no steps).

        The reference graph extends ``pattern_ref_graph`` with four gather
        branches feeding 'ss_begin_concat' and four feeding 'ss_end_concat';
        the transformed graph is compared with it starting from 'result'.
        """
        # --- graph under test -------------------------------------------------
        source_edges = [
            *connect('input:0', '0:slice'),
            *connect('starts:0', '1:slice'),
            *connect('ends:0', '2:slice'),
            *connect('slice:0', '0:result'),
        ]
        source_attrs = {
            'starts': {'value': int64_array([0, 0, 0, 0]), 'shape': [4]},
            'ends': {'value': int64_array([1, 2, 150, 150]), 'shape': [4]},
        }
        graph = build_graph(
            nodes_attrs=nodes_attributes,
            edges=source_edges,
            update_attributes=source_attrs,
            nodes_with_edges_only=True,
        )

        # --- reference graph --------------------------------------------------
        # The first consumer of each cast output is wired with connect(), the
        # remaining consumers of the same output with connect_data().
        begin_edges = [
            *connect('ss_begin_cast:0', '0:ss_begin_gather_0'),
            *connect('ss_begin_gather_0:0', '0:ss_begin_concat'),
            *connect_data('ss_begin_cast:0', '0:ss_begin_gather_1'),
            *connect('ss_begin_gather_1:0', '1:ss_begin_concat'),
            *connect_data('ss_begin_cast:0', '0:ss_begin_gather_2'),
            *connect('ss_begin_gather_2:0', '2:ss_begin_concat'),
            *connect_data('ss_begin_cast:0', '0:ss_begin_gather_3'),
            *connect('ss_begin_gather_3:0', '3:ss_begin_concat'),
        ]
        end_edges = [
            *connect('ss_end_cast:0', '0:ss_end_gather_0'),
            *connect('ss_end_gather_0:0', '0:ss_end_concat'),
            *connect_data('ss_end_cast:0', '0:ss_end_gather_1'),
            *connect('ss_end_gather_1:0', '1:ss_end_concat'),
            *connect_data('ss_end_cast:0', '0:ss_end_gather_2'),
            *connect('ss_end_gather_2:0', '2:ss_end_concat'),
            *connect_data('ss_end_cast:0', '0:ss_end_gather_3'),
            *connect('ss_end_gather_3:0', '3:ss_end_concat'),
        ]
        reference_attrs = {
            'starts': {'value': int64_array([0, 0, 0, 0]), 'shape': [4]},
            'ends': {'value': int64_array([1, 2, 150, 150]), 'shape': [4]},
            'ss_strides': {'value': int64_array([1, 1, 1, 1]), 'shape': [4]},
            'ss': {
                'begin_mask': int64_array([1, 1, 1, 1]),
                'end_mask': int64_array([1, 1, 1, 1]),
            },
        }
        ref_graph = build_graph(
            nodes_attrs=nodes_attributes,
            edges=pattern_ref_graph + begin_edges + end_edges,
            update_attributes=reference_attrs,
        )

        ConvertSlice().find_and_replace_pattern(graph)

        matched, report = compare_graphs(graph, ref_graph, 'result', check_op_attrs=True)
        self.assertTrue(matched, report)
Ejemplo n.º 6
0
    def test_3(self):
        """
        Slice with a single slicing dimension selected through an explicit
        axis, on a graph whose layout is set to 'NHWC'.

        After ``Slice.infer`` and ``ConvertSlice`` the node named 'slice_node'
        must be of type 'StridedSlice', and the graph must match a reference
        graph whose 'strided_slice' node is fed by three placeholders.
        """
        # --- graph under test -------------------------------------------------
        source_edges = [
            ('placeholder_1', 'placeholder_1_data'),
            ('placeholder_1_data', 'slice'),
            ('slice', 'slice_data'),
            ('slice_data', 'output_op'),
            ('output_op', 'output_data'),
            ('output_data', 'op_output'),
        ]
        source_attrs = {
            'placeholder_1_data': {'shape': np.array([1, 5, 6])},
            'slice': {
                'start': np.array([1]),
                'end': np.array([3]),
                'axis': np.array([1]),
            },
        }
        graph = build_graph(nodes_attributes, source_edges, source_attrs,
                            nodes_with_edges_only=True)
        graph.graph['layout'] = 'NHWC'

        Slice.infer(Node(graph, 'slice'))

        ConvertSlice().find_and_replace_pattern(graph)
        graph.clean_up()

        converted_id = graph.get_node_id_by_name('slice_node')
        converted = Node(graph, converted_id)
        assert converted.type == 'StridedSlice', 'Something wrong with transformed Slice node'

        # --- reference graph --------------------------------------------------
        reference_edges = [
            ('placeholder_1', 'placeholder_1_data'),
            ('placeholder_2', 'placeholder_2_data'),
            ('placeholder_3', 'placeholder_3_data'),
            ('placeholder_1_data', 'strided_slice'),
            ('placeholder_2_data', 'strided_slice'),
            ('placeholder_3_data', 'strided_slice'),
            ('strided_slice', 'slice_data'),
            ('slice_data', 'output_op'),
            ('output_op', 'output_data'),
            ('output_data', 'op_output'),
        ]
        # One slice object per dimension of the [1, 5, 6] input.
        expected_slices = np.array([slice(0, 1, 1),
                                    slice(1, 3, 1),
                                    slice(0, 6, 1)])
        reference_attrs = {
            'placeholder_1_data': {'shape': np.array([1, 5, 6])},
            'strided_slice': {
                'slices': expected_slices,
                'shrink_axis_mask': np.array([False, False, False]),
            },
        }
        graph_ref = build_graph(nodes_attributes, reference_edges, reference_attrs,
                                nodes_with_edges_only=True)

        matched, report = compare_graphs(graph, graph_ref, 'output_op', check_op_attrs=True)
        self.assertTrue(matched, report)
Ejemplo n.º 7
0
 def test_convert_slice_to_strided_slice(self, input_shape, start, end,
                                         axes, steps, ss_begin_parts: tuple,
                                         ss_end_parts: tuple, ss_steps,
                                         ss_begin_mask, ss_end_mask):
     """Convert a five-input Slice and compare it with an expected StridedSlice graph.

     The source graph wires `input`, `start`, `end`, `axes` and `steps` into
     a 'Slice' op. The reference graph assembles the StridedSlice begin/end
     inputs with Concat nodes: a constant first part, the Cast-converted
     original value, and a constant last part.

     :param input_shape: shape given to the 'input' Parameter node
     :param start: value of the 'start' constant
     :param end: value of the 'end' constant
     :param axes: value of the 'axes' constant
     :param steps: value of the 'steps' constant
     :param ss_begin_parts: indexable pair; [0] and [1] are the constant
         first/last parts concatenated around the converted start
     :param ss_end_parts: indexable pair; [0] and [1] are the constant
         first/last parts concatenated around the converted end
     :param ss_steps: value of the expected 'ss_steps' constant
     :param ss_begin_mask: expected 'begin_mask' attribute of the StridedSlice
     :param ss_end_mask: expected 'end_mask' attribute of the StridedSlice
     """
     # Each call returns a fresh attribute dict / array, so nothing is shared
     # between the nodes that use them.
     def cast_attrs():
         return {'op': 'Cast', 'type': 'Convert', 'dst_type': np.int64}

     def concat_attrs():
         return {'type': 'Concat', 'op': 'Concat', 'axis': 0}

     def zero_mask():
         return np.zeros(len(input_shape), dtype=np.int64)

     # --- graph under test -------------------------------------------------
     source_nodes = {
         **regular_op_with_shaped_data('input', input_shape, {'type': 'Parameter'}),
         **valued_const_with_data('start', start),
         **valued_const_with_data('end', end),
         **valued_const_with_data('axes', axes),
         **valued_const_with_data('steps', steps),
         **regular_op_with_empty_data('slice', {'type': None, 'op': 'Slice'}),
         **result('result'),
     }
     source_edges = [
         *connect('input', 'slice'),
         *connect('start', '1:slice'),
         *connect('end', '2:slice'),
         *connect('axes', '3:slice'),
         *connect('steps', '4:slice'),
         *connect('slice', 'result'),
     ]
     graph = build_graph(nodes_attrs=source_nodes, edges=source_edges)

     # --- reference graph --------------------------------------------------
     reference_nodes = {
         **regular_op_with_shaped_data('input', input_shape, {'type': 'Parameter'}),
         **valued_const_with_data('start', start),
         **valued_const_with_data('begin_first_part', ss_begin_parts[0]),
         **valued_const_with_data('begin_last_part', ss_begin_parts[1]),
         **regular_op_with_empty_data('convert_start', cast_attrs()),
         **regular_op_with_empty_data('ss_begin', concat_attrs()),
         **valued_const_with_data('end', end),
         **valued_const_with_data('end_first_part', ss_end_parts[0]),
         **valued_const_with_data('end_last_part', ss_end_parts[1]),
         **regular_op_with_empty_data('convert_end', cast_attrs()),
         **regular_op_with_empty_data('ss_end', concat_attrs()),
         **const('ss_steps', ss_steps),
         **empty_data('ss_steps_d'),
         **regular_op_with_empty_data('ss', {
             'op': 'StridedSlice',
             'type': 'StridedSlice',
             'begin_mask': ss_begin_mask,
             'end_mask': ss_end_mask,
             'new_axis_mask': zero_mask(),
             'shrink_axis_mask': zero_mask(),
             'ellipsis_mask': zero_mask(),
         }),
         **result('result'),
     }
     reference_edges = [
         *connect('input', 'ss'),
         *connect('begin_first_part', 'ss_begin'),
         *connect('start', 'convert_start'),
         *connect('convert_start', '1:ss_begin'),
         *connect('begin_last_part', '2:ss_begin'),
         *connect('ss_begin', '1:ss'),
         *connect('end_first_part', 'ss_end'),
         *connect('end', 'convert_end'),
         *connect('convert_end', '1:ss_end'),
         *connect('end_last_part', '2:ss_end'),
         *connect('ss_end', '2:ss'),
         *connect('ss_steps', '3:ss'),
         *connect('ss', 'result'),
     ]
     ref_graph = build_graph(nodes_attrs=reference_nodes, edges=reference_edges)

     ConvertSlice().find_and_replace_pattern(graph)

     matched, report = compare_graphs(graph, ref_graph, 'result', check_op_attrs=True)
     self.assertTrue(matched, report)
Ejemplo n.º 8
0
 def test_convert_slice_to_strided_slice_without_axes_and_steps(self):
     """Convert a Slice that has only `start` and `end` inputs (no axes, no steps).

     The reference graph builds the StridedSlice begin/end inputs with
     Concat nodes whose constant first and last parts are empty arrays, and
     expects unit steps and all-ones begin/end masks for the
     three-dimensional input.
     """
     # Each call returns a fresh attribute dict, so nothing is shared between
     # the nodes that use them.
     def cast_attrs():
         return {'op': 'Cast', 'type': 'Convert', 'dst_type': np.int64}

     def concat_attrs():
         return {'type': 'Concat', 'op': 'Concat', 'axis': 0}

     # --- graph under test -------------------------------------------------
     source_nodes = {
         **regular_op_with_shaped_data('input', int64_array([2, 5, 10]), {'type': 'Parameter'}),
         **valued_const_with_data('start', np.array([0, 0, 0])),
         **valued_const_with_data('end', np.array([1, 3, 5])),
         **regular_op_with_empty_data('slice', {'type': None, 'op': 'Slice'}),
         **result('result'),
     }
     source_edges = [
         *connect('input', 'slice'),
         *connect('start', '1:slice'),
         *connect('end', '2:slice'),
         *connect('slice', 'result'),
     ]
     graph = build_graph(nodes_attrs=source_nodes, edges=source_edges)

     # --- reference graph --------------------------------------------------
     reference_nodes = {
         **regular_op_with_shaped_data('input', int64_array([2, 5, 10]), {'type': 'Parameter'}),
         **valued_const_with_data('start', np.array([0, 0, 0])),
         **valued_const_with_data('begin_first_part', int64_array([])),
         **valued_const_with_data('begin_last_part', int64_array([])),
         **regular_op_with_empty_data('convert_start', cast_attrs()),
         **regular_op_with_empty_data('ss_begin', concat_attrs()),
         **valued_const_with_data('end', np.array([1, 3, 5])),
         **valued_const_with_data('end_first_part', int64_array([])),
         **valued_const_with_data('end_last_part', int64_array([])),
         **regular_op_with_empty_data('convert_end', cast_attrs()),
         **regular_op_with_empty_data('ss_end', concat_attrs()),
         **const('ss_steps', int64_array([1, 1, 1])),
         **empty_data('ss_steps_d'),
         **regular_op_with_empty_data('ss', {
             'op': 'StridedSlice',
             'type': 'StridedSlice',
             'begin_mask': int64_array([1, 1, 1]),
             'end_mask': int64_array([1, 1, 1]),
             'new_axis_mask': np.zeros(3, dtype=np.int64),
             'shrink_axis_mask': np.zeros(3, dtype=np.int64),
             'ellipsis_mask': np.zeros(3, dtype=np.int64),
         }),
         **result('result'),
     }
     reference_edges = [
         *connect('input', 'ss'),
         *connect('begin_first_part', 'ss_begin'),
         *connect('start', 'convert_start'),
         *connect('convert_start', '1:ss_begin'),
         *connect('begin_last_part', '2:ss_begin'),
         *connect('ss_begin', '1:ss'),
         *connect('end_first_part', 'ss_end'),
         *connect('end', 'convert_end'),
         *connect('convert_end', '1:ss_end'),
         *connect('end_last_part', '2:ss_end'),
         *connect('ss_end', '2:ss'),
         *connect('ss_steps', '3:ss'),
         *connect('ss', 'result'),
     ]
     ref_graph = build_graph(nodes_attrs=reference_nodes, edges=reference_edges)

     ConvertSlice().find_and_replace_pattern(graph)

     matched, report = compare_graphs(graph, ref_graph, 'result', check_op_attrs=True)
     self.assertTrue(matched, report)