Ejemplo n.º 1
0
    def test_1(self):
        """
        Slice with explicit start/end on all three axes of a [4, 5, 6] input
        (non-constant path, multiple slicing dimensions).

        ConvertSlice is expected to turn the Slice node into a Crop whose
        offsets equal ``start`` and whose sizes equal ``end - start``.
        """
        src_edges = [
            ('placeholder_1', 'placeholder_1_data'),
            ('placeholder_1_data', 'slice'),
            ('slice', 'slice_data'),
            ('slice_data', 'output_op'),
            ('output_op', 'output_data'),
            ('output_data', 'op_output'),
        ]
        src_attrs = {
            'placeholder_1_data': {'shape': np.array([4, 5, 6])},
            'slice': {
                'start': np.array([1, 2, 3]),
                'end': np.array([3, 4, 4]),
                'axis': None,
            },
        }
        graph = build_graph(nodes_attributes, src_edges, src_attrs,
                            nodes_with_edges_only=True)

        Slice.infer(Node(graph, 'slice'))
        ConvertSlice().find_and_replace_pattern(graph)
        graph.clean_up()

        # The replacement keeps the original node name, so look it up by name.
        converted = Node(graph, graph.get_node_id_by_name('slice_node'))
        assert converted.type == 'Crop', 'Something wrong with transformed Slice node'

        expected_edges = [
            ('placeholder_1', 'placeholder_1_data'),
            ('placeholder_1_data', 'crop'),
            ('crop', 'slice_data'),
            ('slice_data', 'output_op'),
            ('output_op', 'output_data'),
            ('output_data', 'op_output'),
        ]
        expected_attrs = {
            'placeholder_1_data': {'shape': np.array([4, 5, 6])},
            'crop': {
                'axis': np.array([0, 1, 2]),
                'offset': np.array([1, 2, 3]),
                'dim': np.array([2, 2, 1]),
            },
        }
        graph_ref = build_graph(nodes_attributes, expected_edges, expected_attrs,
                                nodes_with_edges_only=True)

        flag, resp = compare_graphs(graph, graph_ref, 'output_op',
                                    check_op_attrs=True)
        self.assertTrue(flag, resp)
Ejemplo n.º 2
0
    def test_no_steps_no_axes(self):
        """
        Slice fed only with data, starts and ends (no steps and no axes inputs).

        ConvertSlice is expected to produce a StridedSlice that gains a fourth
        (strides) input. Its new_axis/shrink_axis/ellipsis masks are all zero,
        its begin/end masks are all one, and its output shape is
        ``ends - starts``.
        """
        input_shape = int64_array([5, 10, 20])
        starts_value = int64_array([3, 2, 7])
        ends_value = int64_array([5, 8, 15])
        rank = len(input_shape)
        masks_value = np.zeros([rank], dtype=np.int64)
        graph = build_graph(self.nodes_attributes,
                            [('placeholder_1', 'placeholder_1_data'),
                             ('placeholder_1_data', 'slice', {'in': 0}),
                             ('starts', 'starts_data'),
                             ('starts_data', 'slice', {'in': 1}),
                             ('ends', 'ends_data'),
                             ('ends_data', 'slice', {'in': 2}),
                             ('slice', 'slice_data'),
                             ('slice_data', 'output_op'),
                             ('output_op', 'output_data'),
                             ('output_data', 'op_output')
                             ],
                            {'placeholder_1_data': {'shape': input_shape},
                             'starts': {'shape': starts_value.shape, 'value': starts_value},
                             'starts_data': {'shape': starts_value.shape, 'value': starts_value},
                             'ends': {'shape': ends_value.shape, 'value': ends_value},
                             'ends_data': {'shape': ends_value.shape, 'value': ends_value},
                             }, nodes_with_edges_only=True
                            )
        slice_node = Node(graph, 'slice')
        Slice.infer(slice_node)

        pattern = ConvertSlice()
        pattern.find_and_replace_pattern(graph)

        ss_node = Node(graph, graph.get_node_id_by_name('slice_node'))
        assert ss_node.type == 'StridedSlice', 'Something wrong with transformed Slice node'

        # The reference graph wires a 'strides' input into port 3; its contents
        # come from self.nodes_attributes.
        graph_ref = build_graph(self.nodes_attributes,
                                [('placeholder_1', 'placeholder_1_data'),
                                 ('placeholder_1_data', 'strided_slice', {'in': 0}),
                                 ('starts', 'starts_data'),
                                 ('starts_data', 'strided_slice', {'in': 1}),
                                 ('ends', 'ends_data'),
                                 ('ends_data', 'strided_slice', {'in': 2}),
                                 ('strides', 'strides_data'),
                                 ('strides_data', 'strided_slice', {'in': 3}),
                                 ('strided_slice', 'slice_data'),
                                 ('slice_data', 'output_op'),
                                 ('output_op', 'output_data'),
                                 ('output_data', 'op_output')
                                 ],
                                {'placeholder_1_data': {'shape': input_shape},
                                 'strided_slice': {'new_axis_mask': masks_value, 'shrink_axis_mask': masks_value,
                                                   'ellipsis_mask': masks_value, 'begin_mask': np.ones([rank]),
                                                   'end_mask': np.ones([rank])},
                                 # ends - starts = [5-3, 8-2, 15-7]
                                 'slice_data': {'shape': int64_array([2, 6, 8])}
                                 }, nodes_with_edges_only=True
                                )
        (flag, resp) = compare_graphs(graph, graph_ref, 'output_op', check_op_attrs=True)
        self.assertTrue(flag, resp)
Ejemplo n.º 3
0
    def test_2(self):
        """
        Slice a [4, 5, 6] input with a single start/end pair and no axis.

        The range [1, 3) applies to the first axis. ConvertSlice is expected
        to emit a StridedSlice in which the remaining axes keep their full
        extent and shrink_axis_mask is all False.
        """
        src_edges = [
            ('placeholder_1', 'placeholder_1_data'),
            ('placeholder_1_data', 'slice'),
            ('slice', 'slice_data'),
            ('slice_data', 'output_op'),
            ('output_op', 'output_data'),
        ]
        src_attrs = {
            'placeholder_1_data': {'shape': np.array([4, 5, 6])},
            'slice': {'start': np.array([1]), 'end': np.array([3]), 'axis': None},
            'output_op': {'is_output': True},
        }
        graph = build_graph(nodes_attributes, src_edges, src_attrs)

        Slice.infer(Node(graph, 'slice'))
        ConvertSlice().find_and_replace_pattern(graph)

        expected_slices = np.array([slice(1, 3, 1),
                                    slice(0, 5, 1),
                                    slice(0, 6, 1)])
        expected_edges = [
            ('placeholder_1', 'placeholder_1_data'),
            ('placeholder_1_data', 'strided_slice'),
            ('strided_slice', 'slice_data'),
            ('slice_data', 'output_op'),
            ('output_op', 'output_data'),
        ]
        expected_attrs = {
            'placeholder_1_data': {'shape': np.array([4, 5, 6])},
            'strided_slice': {
                'slices': expected_slices,
                'shrink_axis_mask': np.array([False, False, False]),
            },
            'output_op': {'is_output': True},
        }
        graph_ref = build_graph(nodes_attributes, expected_edges, expected_attrs)

        flag, resp = compare_graphs(graph, graph_ref, 'output_op',
                                    check_op_attrs=True)
        self.assertTrue(flag, resp)
Ejemplo n.º 4
0
    def test_1(self):
        """
        Slice with start/end on all three axes of a [4, 5, 6] input
        (non-constant path, multiple slicing dimensions).

        ConvertSlice is expected to replace the Slice node with a Crop whose
        offsets equal ``start`` and whose sizes equal ``end - start``.
        """
        graph = build_graph(
            nodes_attributes, [('placeholder_1', 'placeholder_1_data'),
                               ('placeholder_1_data', 'slice'),
                               ('slice', 'slice_data'),
                               ('slice_data', 'output_op'),
                               ('output_op', 'output_data')], {
                                   'placeholder_1_data': {
                                       'shape': np.array([4, 5, 6])
                                   },
                                   'slice': {
                                       'start': np.array([1, 2, 3]),
                                       'end': np.array([3, 4, 4]),
                                       'axis': None
                                   },
                                   'output_op': {
                                       'is_output': True
                                   },
                               })
        slice_node = Node(graph, 'slice')
        Slice.infer(slice_node)

        pattern = ConvertSlice()
        pattern.find_and_replace_pattern(graph)

        graph_ref = build_graph(
            nodes_attributes, [('placeholder_1', 'placeholder_1_data'),
                               ('placeholder_1_data', 'crop'),
                               ('crop', 'slice_data'),
                               ('slice_data', 'output_op'),
                               ('output_op', 'output_data')], {
                                   'placeholder_1_data': {
                                       'shape': np.array([4, 5, 6])
                                   },
                                   'crop': {
                                       'axis': np.array([0, 1, 2]),
                                       'offset': np.array([1, 2, 3]),
                                       # Crop sizes (end - start) are an attribute
                                       # of the crop node itself, not of a
                                       # separate 'dim' node.
                                       'dim': np.array([2, 2, 1]),
                                   },
                                   'output_op': {
                                       'is_output': True
                                   },
                               })
        (flag, resp) = compare_graphs(graph,
                                      graph_ref,
                                      'output_op',
                                      check_op_attrs=True)
        self.assertTrue(flag, resp)
Ejemplo n.º 5
0
    def test_3(self):
        """
        Slice a [1, 5, 6] input along axis 1 only (start=1, end=3) in NHWC layout.

        ConvertSlice is expected to produce a StridedSlice with three inputs.
        It keeps the full extent of axes 0 and 2, slices [1:3] on axis 1, and
        has shrink_axis_mask all False.
        """
        src_edges = [
            ('placeholder_1', 'placeholder_1_data'),
            ('placeholder_1_data', 'slice'),
            ('slice', 'slice_data'),
            ('slice_data', 'output_op'),
            ('output_op', 'output_data'),
            ('output_data', 'op_output'),
        ]
        src_attrs = {
            'placeholder_1_data': {'shape': np.array([1, 5, 6])},
            'slice': {
                'start': np.array([1]),
                'end': np.array([3]),
                'axis': np.array([1]),
            },
        }
        graph = build_graph(nodes_attributes, src_edges, src_attrs,
                            nodes_with_edges_only=True)
        graph.graph['layout'] = 'NHWC'

        Slice.infer(Node(graph, 'slice'))
        ConvertSlice().find_and_replace_pattern(graph)
        graph.clean_up()

        converted = Node(graph, graph.get_node_id_by_name('slice_node'))
        assert converted.type == 'StridedSlice', 'Something wrong with transformed Slice node'

        expected_edges = [
            ('placeholder_1', 'placeholder_1_data'),
            ('placeholder_2', 'placeholder_2_data'),
            ('placeholder_3', 'placeholder_3_data'),
            ('placeholder_1_data', 'strided_slice'),
            ('placeholder_2_data', 'strided_slice'),
            ('placeholder_3_data', 'strided_slice'),
            ('strided_slice', 'slice_data'),
            ('slice_data', 'output_op'),
            ('output_op', 'output_data'),
            ('output_data', 'op_output'),
        ]
        expected_attrs = {
            'placeholder_1_data': {'shape': np.array([1, 5, 6])},
            'strided_slice': {
                'slices': np.array([slice(0, 1, 1),
                                    slice(1, 3, 1),
                                    slice(0, 6, 1)]),
                'shrink_axis_mask': np.array([False, False, False]),
            },
        }
        graph_ref = build_graph(nodes_attributes, expected_edges, expected_attrs,
                                nodes_with_edges_only=True)

        flag, resp = compare_graphs(graph, graph_ref, 'output_op',
                                    check_op_attrs=True)
        self.assertTrue(flag, resp)