Exemplo n.º 1
0
    def replace_pattern(graph: Graph, match: dict):
        """Remove a StridedSlice op that leaves its input data unchanged.

        The op is treated as useless when, after undoing the effect of its
        ``shrink_axis_mask`` (re-inserting size-1 axes) and ``new_axis_mask``
        (dropping added axes) on the output shape, that shape equals the input
        shape and every ``slice`` entry in ``node.slices`` has step 1. In that
        case ConvertGroupedStridedSlice.add_squeeze_for_shrink /
        add_unsqueeze_for_new are called for the respective masks (presumably
        inserting Squeeze/Unsqueeze so the rank change is kept), the op's
        non-data inputs are disconnected, and the op plus its output data node
        are removed with ``remove_op_node_with_data_node``.

        Ops whose output value is already known are left alone (see the
        early return below).

        :param graph: graph being transformed
        :param match: pattern match; ``match['strided_slice']`` is the StridedSlice op node
        """
        node_ss = match['strided_slice']

        if node_ss.out_port(0).data.get_value() is not None:
            # StridedSlices (SS) in shape-calculating sub-graphs should not be deleted that easily.
            # Example:
            # In RetinaNetFilteredDetectionsReplacement there is an SS that slices the first batch.
            # Such an SS could be deleted for batch 1, but that must happen while reshaping the model.
            return

        input_data_node = node_ss.in_node(0)
        out_shape = node_ss.out_node(0).shape

        # The masks are not modified below, so evaluate them once.
        has_shrink = not np.all(node_ss.shrink_axis_mask == 0)
        has_new_axis = not np.all(node_ss.new_axis_mask == 0)

        if has_shrink:
            # Re-insert the size-1 axes that shrink_axis_mask removed from the output.
            out_shape = list(out_shape)
            for i, bit in enumerate(node_ss.shrink_axis_mask):
                if bit == 1:
                    out_shape.insert(i, 1)
            out_shape = int64_array(out_shape)

        if has_new_axis:
            # Drop the axes that new_axis_mask added; go backwards so earlier indices stay valid.
            out_shape = list(out_shape)
            for i in reversed(range(len(node_ss.new_axis_mask))):
                if node_ss.new_axis_mask[i] == 1:
                    out_shape.pop(i)
            out_shape = int64_array(out_shape)

        if not np.array_equal(input_data_node.shape, out_shape):
            return
        # `slices` may hold non-slice entries (e.g. np.newaxis, which is None); they have no
        # `.step` attribute and do not stride over data, so only real slices are checked.
        if not all(elem.step == 1 for elem in node_ss.slices if isinstance(elem, slice)):
            return

        if has_shrink:
            ConvertGroupedStridedSlice.add_squeeze_for_shrink(graph, node_ss)
        if has_new_axis:
            ConvertGroupedStridedSlice.add_unsqueeze_for_new(graph, node_ss)

        log.info("Useless StridedSlice op '{}' has been detected".format(node_ss.id))
        # Disconnect inputs 1, 2 and (if present) 3 so only the data input remains, which is
        # what remove_op_node_with_data_node expects.
        graph.remove_edge(node_ss.in_node(1).id, node_ss.id)
        graph.remove_edge(node_ss.in_node(2).id, node_ss.id)
        if len(node_ss.in_nodes()) > 3:
            graph.remove_edge(node_ss.in_node(3).id, node_ss.id)

        remove_op_node_with_data_node(graph, node_ss)
Exemplo n.º 2
0
    def replace_pattern(graph: Graph, match: dict):
        """Remove a StridedSlice op that leaves its input data unchanged.

        The op is treated as useless when, after undoing the effect of its
        ``shrink_axis_mask`` (re-inserting size-1 axes) and ``new_axis_mask``
        (dropping added axes) on the output shape, that shape equals the input
        shape and every ``slice`` entry in ``node.slices`` has step 1. In that
        case ConvertGroupedStridedSlice.add_squeeze_for_shrink /
        add_unsqueeze_for_new are called for the respective masks (presumably
        inserting Squeeze/Unsqueeze so the rank change is kept), the op's
        non-data inputs are disconnected, and the op plus its output data node
        are removed with ``remove_op_node_with_data_node``.

        :param graph: graph being transformed
        :param match: pattern match; ``match['strided_slice']`` is the StridedSlice op node
        """
        node_ss = match['strided_slice']
        # NOTE(review): there is no early exit for ops whose output value is already known
        # (e.g. inside shape-computing sub-graphs); confirm removing those is intended.
        input_data_node = node_ss.in_node(0)
        out_shape = node_ss.out_node(0).shape

        # The masks are not modified below, so evaluate them once.
        has_shrink = not np.all(node_ss.shrink_axis_mask == 0)
        has_new_axis = not np.all(node_ss.new_axis_mask == 0)

        if has_shrink:
            # Re-insert the size-1 axes that shrink_axis_mask removed from the output.
            out_shape = list(out_shape)
            for i, bit in enumerate(node_ss.shrink_axis_mask):
                if bit == 1:
                    out_shape.insert(i, 1)
            out_shape = int64_array(out_shape)

        if has_new_axis:
            # Drop the axes that new_axis_mask added; go backwards so earlier indices stay valid.
            out_shape = list(out_shape)
            for i in reversed(range(len(node_ss.new_axis_mask))):
                if node_ss.new_axis_mask[i] == 1:
                    out_shape.pop(i)
            out_shape = int64_array(out_shape)

        if not np.array_equal(input_data_node.shape, out_shape):
            return
        # `slices` may hold non-slice entries (e.g. np.newaxis, which is None); they have no
        # `.step` attribute and do not stride over data, so only real slices are checked.
        if not all(elem.step == 1 for elem in node_ss.slices if isinstance(elem, slice)):
            return

        if has_shrink:
            ConvertGroupedStridedSlice.add_squeeze_for_shrink(graph, node_ss)
        if has_new_axis:
            ConvertGroupedStridedSlice.add_unsqueeze_for_new(graph, node_ss)

        log.info("Useless StridedSlice op '{}' has been detected".format(node_ss.id))
        # Disconnect inputs 1, 2 and (if present) 3 so only the data input remains, which is
        # what remove_op_node_with_data_node expects.
        graph.remove_edge(node_ss.in_node(1).id, node_ss.id)
        graph.remove_edge(node_ss.in_node(2).id, node_ss.id)
        if len(node_ss.in_nodes()) > 3:
            graph.remove_edge(node_ss.in_node(3).id, node_ss.id)

        remove_op_node_with_data_node(graph, node_ss)