Exemplo n.º 1
0
    def replace_pattern(graph: Graph, match: dict):
        """Erase a StridedSlice that leaves its input tensor unchanged.

        The op ``match['strided_slice']`` is considered useless when both of these hold:

        * After undoing its masks, its output shape equals its input shape. Undoing the
          masks means re-inserting a size-1 dimension at each ``shrink_axis_mask``
          position and dropping each ``new_axis_mask`` position.
        * Every entry of its ``slices`` has step 1.

        In that case:

        * Any shrink/new-axis effect is handed to the ``ConvertGroupedStridedSlice``
          squeeze/unsqueeze helpers.
        * The non-data inputs (ports 1, 2 and, if present, 3) are detached.
        * The op is removed together with its output data node.

        StridedSlices whose output data already carries a value are left untouched.

        :param graph: graph being transformed; modified in place.
        :param match: pattern match; ``match['strided_slice']`` is the StridedSlice op node.
        """
        node_ss = match['strided_slice']

        if node_ss.out_port(0).data.get_value() is not None:
            # StridedSlices (SS) in shape-calculating sub-graphs should not be deleted that easily.
            # Example: RetinaNetFilteredDetectionsReplacement has an SS that slices the first batch.
            # Deleting it is only valid for batch 1, and that must be handled while reshaping the model.
            return

        output_data_node = node_ss.out_node(0)
        input_data_node = node_ss.in_node(0)

        out_shape = output_data_node.shape

        # Re-insert a size-1 dimension at every position flagged in shrink_axis_mask.
        if not np.all(node_ss.shrink_axis_mask == 0):
            out_shape = list(out_shape)
            for i in range(len(node_ss.shrink_axis_mask)):
                if node_ss.shrink_axis_mask[i] == 1:
                    out_shape.insert(i, 1)
            out_shape = int64_array(out_shape)

        # Drop every position flagged in new_axis_mask. Iterate right to left so that popping
        # does not shift the indices that are still to be visited.
        if not np.all(node_ss.new_axis_mask == 0):
            out_shape = list(out_shape)
            for i in reversed(range(len(node_ss.new_axis_mask))):
                if node_ss.new_axis_mask[i] == 1:
                    out_shape.pop(i)
            out_shape = int64_array(out_shape)

        # NOTE(review): assumes every entry of `slices` is a `slice` object. A None/newaxis
        # entry would raise AttributeError on `.step` — confirm against StridedSlice infer.
        if np.array_equal(input_data_node.shape, out_shape) and \
                all(elem.step == 1 for elem in node_ss.slices):
            if not np.all(node_ss.shrink_axis_mask == 0):
                ConvertGroupedStridedSlice.add_squeeze_for_shrink(graph, node_ss)
            if not np.all(node_ss.new_axis_mask == 0):
                ConvertGroupedStridedSlice.add_unsqueeze_for_new(graph, node_ss)

            log.info("Useless StridedSlice op '{}' has been detected".format(node_ss.id))

            # Remove the inputs so the StridedSlice has just one input with data, which lets us
            # use 'remove_op_node_with_data_node'.
            # The input count must be read BEFORE detaching anything. Once ports 1 and 2 are
            # gone, len(in_nodes()) can no longer exceed 3, so the strides edge would be kept.
            has_strides_input = len(node_ss.in_nodes()) > 3
            graph.remove_edge(node_ss.in_node(1).id, node_ss.id)
            graph.remove_edge(node_ss.in_node(2).id, node_ss.id)
            if has_strides_input:
                graph.remove_edge(node_ss.in_node(3).id, node_ss.id)

            remove_op_node_with_data_node(graph, node_ss)
Exemplo n.º 2
0
    def replace_pattern(graph: Graph, match: dict):
        """Erase a StridedSlice that leaves its input tensor unchanged.

        The op ``match['strided_slice']`` is considered useless when both of these hold:

        * After undoing its masks, its output shape equals its input shape. Undoing the
          masks means re-inserting a size-1 dimension at each ``shrink_axis_mask``
          position and dropping each ``new_axis_mask`` position.
        * Every entry of its ``slices`` has step 1.

        In that case:

        * Any shrink/new-axis effect is handed to the ``ConvertGroupedStridedSlice``
          squeeze/unsqueeze helpers.
        * The non-data inputs (ports 1, 2 and, if present, 3) are detached.
        * The op is removed together with its output data node.

        :param graph: graph being transformed; modified in place.
        :param match: pattern match; ``match['strided_slice']`` is the StridedSlice op node.
        """
        node_ss = match['strided_slice']
        output_data_node = node_ss.out_node(0)
        input_data_node = node_ss.in_node(0)

        out_shape = output_data_node.shape

        # Re-insert a size-1 dimension at every position flagged in shrink_axis_mask.
        if not np.all(node_ss.shrink_axis_mask == 0):
            out_shape = list(out_shape)
            for i in range(len(node_ss.shrink_axis_mask)):
                if node_ss.shrink_axis_mask[i] == 1:
                    out_shape.insert(i, 1)
            out_shape = int64_array(out_shape)

        # Drop every position flagged in new_axis_mask. Iterate right to left so that popping
        # does not shift the indices that are still to be visited.
        if not np.all(node_ss.new_axis_mask == 0):
            out_shape = list(out_shape)
            for i in reversed(range(len(node_ss.new_axis_mask))):
                if node_ss.new_axis_mask[i] == 1:
                    out_shape.pop(i)
            out_shape = int64_array(out_shape)

        # NOTE(review): assumes every entry of `slices` is a `slice` object. A None/newaxis
        # entry would raise AttributeError on `.step` — confirm against StridedSlice infer.
        if np.array_equal(input_data_node.shape, out_shape) and \
                all(elem.step == 1 for elem in node_ss.slices):
            if not np.all(node_ss.shrink_axis_mask == 0):
                ConvertGroupedStridedSlice.add_squeeze_for_shrink(graph, node_ss)
            if not np.all(node_ss.new_axis_mask == 0):
                ConvertGroupedStridedSlice.add_unsqueeze_for_new(graph, node_ss)

            log.info("Useless StridedSlice op '{}' has been detected".format(node_ss.id))

            # Remove the inputs so the StridedSlice has just one input with data, which lets us
            # use 'remove_op_node_with_data_node'.
            # The input count must be read BEFORE detaching anything. Once ports 1 and 2 are
            # gone, len(in_nodes()) can no longer exceed 3, so the strides edge would be kept.
            has_strides_input = len(node_ss.in_nodes()) > 3
            graph.remove_edge(node_ss.in_node(1).id, node_ss.id)
            graph.remove_edge(node_ss.in_node(2).id, node_ss.id)
            if has_strides_input:
                graph.remove_edge(node_ss.in_node(3).id, node_ss.id)

            remove_op_node_with_data_node(graph, node_ss)
Exemplo n.º 3
0
    def test_8(self):
        # Scenario:
        # - Two StridedSlices cut the adjacent ranges [0:18) and [18:36) of axis 1 out of a
        #   [1, 54, 54, 3] NHWC tensor, and both feed the same Concat.
        # - The expected result is a single Split on axis 1 with three [1, 18, 54, 3] outputs.

        def shape(*dims):
            # Build a fresh array per node so that no two nodes share one object.
            return np.array(dims)

        def slices_on_axis_1(start, stop):
            # Unit-step slices covering the whole tensor except [start:stop) on axis 1.
            return np.array([slice(0, 1, 1), slice(start, stop, 1), slice(0, 54, 1), slice(0, 3, 1)])

        # Edge order matters: build_graph presumably assigns port numbers in list order.
        graph = build_graph(
            nodes_attributes,
            [
                ('placeholder_1', 'placeholder_1_data'),
                ('placeholder_1_data', 'sslice_1'),
                ('sslice_1', 'sslice_1_data'),
                ('placeholder_1_data', 'sslice_2'),
                ('sslice_2', 'sslice_2_data'),
                ('sslice_1_data', 'concat_1'),
                ('sslice_2_data', 'concat_1'),
                ('concat_1', 'concat_1_data'),
            ],
            {
                'placeholder_1_data': {'shape': shape(1, 54, 54, 3)},
                'sslice_1': {'slices': slices_on_axis_1(0, 18)},
                'sslice_1_data': {'shape': shape(1, 18, 54, 3)},
                'sslice_2': {'slices': slices_on_axis_1(18, 36)},
                'sslice_2_data': {'shape': shape(1, 18, 54, 3)},
                'concat_1_data': {'shape': shape(1, 54, 54, 3), 'is_output': True},
            },
        )
        graph.graph['layout'] = 'NHWC'

        # NOTE(review): this reference wires split_3_data (presumably the [36:54) piece) into
        # concat_1, although sslice_2 selects [18:36) (split_2_data). Also, concat_1_data keeps
        # the shape [1, 54, 54, 3], which two 18-wide pieces cannot produce. Both are kept as-is;
        # confirm whether compare_graphs actually distinguishes them.
        graph_ref = build_graph(
            nodes_attributes,
            [
                ('placeholder_1', 'placeholder_1_data'),
                ('placeholder_1_data', 'split_1'),
                ('split_1', 'split_1_data'),
                ('split_1', 'split_2_data'),
                ('split_1', 'split_3_data'),
                ('split_1_data', 'concat_1'),
                ('split_3_data', 'concat_1'),
                ('concat_1', 'concat_1_data'),
            ],
            {
                'placeholder_1_data': {'shape': shape(1, 54, 54, 3)},
                'split_1': {'axis': 1},
                'split_1_data': {'shape': shape(1, 18, 54, 3)},
                'split_2_data': {'shape': shape(1, 18, 54, 3)},
                'split_3_data': {'shape': shape(1, 18, 54, 3)},
                'concat_1_data': {'shape': shape(1, 54, 54, 3), 'is_output': True},
            },
        )

        transformation = ConvertGroupedStridedSlice()
        transformation.find_and_replace_pattern(graph)

        flag, resp = compare_graphs(graph, graph_ref, 'concat_1_data', check_op_attrs=True)
        self.assertTrue(flag, resp)