Exemplo n.º 1
0
 def log_axes_to_spatial_first_order(node, graph):
     """Emit debug logs describing a node and the shapes of its inputs.

     One entry message naming the node's op is logged, followed by one
     message per input name reporting that buffer's current shape.

     :param node: graph node whose op name and input names are logged
     :param graph: graph used to look up each input buffer by name
     """
     get_msg = code_to_message.get_debugging_message
     entry_msg = get_msg("DEBUG_AXES_TO_SPATIAL_FIRST_ORDER_ENTRY")
     snpe_converter_utils.log_debug(entry_msg(node.op.name))
     for name in node.input_names:
         format_size = get_msg("DEBUG_AXES_TO_SPATIAL_FIRST_ORDER_INPUT_SIZE")
         shape_text = str(graph.get_buffer(name).shape)
         snpe_converter_utils.log_debug(format_size(name, shape_text))
Exemplo n.º 2
0
    def squash_scale(graph):
        """Fold constant-bias elementwise_sum nodes into their producer's bias.

        Every matched ``elementwise_sum`` node that carries a ``bias`` attribute
        has that bias added (in place) to the bias of the op producing its
        first input; the sum node is then squashed out of the graph.

        :param graph: op graph that is searched and rewritten in place
        """
        def has_foldable_bias(nodes_tuple):
            # Only sum nodes carrying a constant bias are candidates, and their
            # producer is asserted (via log_assert) to have a bias to absorb it.
            candidate = nodes_tuple[0]
            if not hasattr(candidate.op, 'bias'):
                return False
            producer_ = graph.get_input_buffers(candidate)[0].producer
            log_assert(
                hasattr(producer_.op, 'bias'),
                code_to_message.get_error_message(
                    "ERROR_ADD_BIAS_PREV_NO_BIAS")(candidate.op.name,
                                                   producer_.op.name,
                                                   producer_.op.type))
            return True

        pattern = [("elementwise_sum", (), ())]
        matches = graph.get_matched_nodes(pattern,
                                          validator=has_foldable_bias)

        for matched in matches:
            sum_node = matched[0]
            in_buf = graph.get_input_buffers(sum_node)[0]
            producer = in_buf.producer
            producer.op.bias += sum_node.op.bias
            graph.squash(sum_node, in_buf.name)
            log_debug2(
                code_to_message.get_debugging_message(
                    "DEBUG_ELEMENTWISESUM_SQUASH")(sum_node.op.name,
                                                   producer.op.name,
                                                   producer.op.type))
Exemplo n.º 3
0
 def fetch(self, *keys, **kwargs):
     """Return float32 copies of the weights stored under the given keys.

     Each key is converted with ``str`` and looked up in ``self.weight_map``.
     The entry's ``consumed`` flag is set from the ``prunable`` keyword
     argument (default True), and a copy of its ``weights`` converted to
     float32 is collected, so later in-place edits by ops do not alter the
     stored data.

     :param keys: one or more weight keys
     :param prunable: keyword argument; whether the weights are consumed in a
         way that allows the producing node to be pruned later
     :return: a single array when exactly one key is given, otherwise a list
         of arrays in key order
     :raises KeyError: if a key is not present in ``self.weight_map``
     """
     ret = []
     # Prunable indicates whether the weights have been consumed in such a way as to
     # allow pruning of the node (e.g. Const ops that contain weights are consumed by
     # Conv/FC/etc and thus can be pruned from the network; Const ops that are inputs
     # to a node cannot).
     consumed = kwargs.get('prunable', True)
     for key in keys:
         key = str(key)
         # Look up the message id on its own, then format it with the key,
         # like every other debugging message in this converter.
         log_debug(
             code_to_message.get_debugging_message(
                 "DEBUG_RETRIEVE_WEIGHTS")(key))
         if key not in self.weight_map:
             raise KeyError(
                 code_to_message.get_error_message(
                     "ERROR_WEIGHTS_MISSING_KEY")(key))
         entry = self.weight_map[key]
         entry.consumed = consumed
         # Explicitly copy the data so if later ops modify it, the original data remains intact
         ret.append(numpy.require(entry.weights.copy(), dtype=numpy.float32))
     if len(ret) == 1:
         return ret[0]
     return ret
Exemplo n.º 4
0
    def squash_batchnorm(graph):
        """Fold a batchnorm that directly follows a fully_connected node into the FC op.

        For every fully_connected -> batchnorm match, each tensor in the FC
        op's ``weights_list`` is multiplied by the batchnorm weights and the FC
        bias becomes ``fc.bias * bn.weights + bn.bias``. The batchnorm node is
        then squashed out of the graph.

        :param graph: op graph that is searched and rewritten in place
        """
        sequence = [("fully_connected", (), ()), ("batchnorm", (), ())]

        matched_node_list = graph.get_matched_nodes(sequence)

        for node_tuple in matched_node_list:
            # sanity check
            log_assert(
                len(node_tuple) == len(sequence),
                "ERROR: Pattern matching for squash batchnorm returned extra nodes. Got {} nodes, Expected {}.",
                len(node_tuple), len(sequence))

            fc_node = node_tuple[0]
            bn_node = node_tuple[1]
            bn_input_buffer = graph.get_input_buffers(bn_node)[0]
            bn_weights = bn_node.op.weights

            if bn_input_buffer.axis_format == AxisTracker.AxisFormat.NCS:
                # The FC weights are not yet transposed as that happens in axes_to_spatial_first later,
                # so we need to transpose for BN weight broadcasting and then revert.
                # (3, 2, 0, 1) is the inverse permutation of (2, 3, 1, 0).
                weights_list = [
                    numpy.transpose(
                        numpy.transpose(weights, (2, 3, 1, 0)) * bn_weights,
                        (3, 2, 0, 1))
                    for weights in fc_node.op.weights_list
                ]
            else:
                weights_list = [weights * bn_weights
                                for weights in fc_node.op.weights_list]

            # The scaled tensors are derived from ``weights_list`` and must be
            # stored back there. Writing them to a separate ``weights``
            # attribute would leave ``weights_list`` unscaled next to a
            # batchnorm-adjusted bias.
            fc_node.op.weights_list = weights_list
            fc_node.op.bias = fc_node.op.bias * bn_weights + bn_node.op.bias
            graph.squash(bn_node, bn_input_buffer.name)
            log_debug2(
                code_to_message.get_debugging_message("DEBUG_BATCHNORM_SQUASH")
                (bn_node.op.name, fc_node.op.type, fc_node.op.name))
Exemplo n.º 5
0
    def squash_scale(graph):
        """Fold constant-weight elementwise_product nodes into the preceding op.

        Every matched ``elementwise_product`` node that carries ``weights`` is
        expected (checked via log_assert) to be fed by a batchnorm op. The
        producer's weights and bias are both multiplied in place by the product's
        weights, and the product node is squashed out of the graph.

        :param graph: op graph that is searched and rewritten in place
        """
        def is_scale_after_batchnorm(nodes_tuple):
            mul_node = nodes_tuple[0]
            if not hasattr(mul_node.op, 'weights'):
                return False
            producer_ = graph.get_input_buffers(mul_node)[0].producer
            log_assert(
                producer_.op.type == op_adapter.BatchnormOp.TRANSLATION_KEY,
                code_to_message.get_error_message(
                    "ERROR_MUL_SCALE_PREV_NOT_BATCHNORM")(producer_.op.name,
                                                          producer_.op.type))
            return True

        pattern = [("elementwise_product", (), ())]
        matches = graph.get_matched_nodes(pattern,
                                          validator=is_scale_after_batchnorm)

        for matched in matches:
            mul_node = matched[0]
            in_buf = graph.get_input_buffers(mul_node)[0]
            producer = in_buf.producer
            scale = mul_node.op.weights
            # Multiplying both weights and bias by the scale is equivalent to
            # scaling an output of the form weights * x + bias.
            producer.op.weights *= scale
            producer.op.bias *= scale
            graph.squash(mul_node, in_buf.name)
            log_debug2(
                code_to_message.get_debugging_message(
                    "DEBUG_ELEMENTWISEPRODUCT_SQUASH")(mul_node.op.name,
                                                       producer.op.name,
                                                       producer.op.type))
Exemplo n.º 6
0
 def remove_noop(node, graph):
     """Prune ``node`` when its first output has been consumed internally.

     If ``graph.weights.consumed`` reports the node's first output name as
     consumed (i.e. it was an input to a weight layer and used internally),
     a debug message is logged and the node is pruned. Otherwise the graph is
     left unchanged.
     """
     output_name = node.output_names[0]
     if not graph.weights.consumed(output_name):
         return
     log_debug(
         code_to_message.get_debugging_message("DEBUG_CONSTANT_PRUNED")(
             output_name))
     graph.prune(node)
Exemplo n.º 7
0
 def axes_to_snpe_order(self, node, graph):
     """Delegate axis handling to the parent class and log the resulting buffers.

     Logs an entry message, calls the superclass ``axes_to_spatial_first_order``
     for ``node``, and then logs the name, axis format and shape of every input
     buffer followed by every output buffer of ``node``.
     """
     log_debug(
         code_to_message.get_debugging_message(
             "DEBUG_AXES_TO_SNPE_ORDER_ENTRY")(node.op.name))
     parent = super(OptimizeChannelShuffleTranslation, self)
     parent.axes_to_spatial_first_order(node, graph)
     for fmt, get_buffers in (("input {} {} {}", graph.get_input_buffers),
                              ("output {} {} {}", graph.get_output_buffers)):
         for buf in get_buffers(node):
             log_debug(fmt, buf.name, buf.axis_format, buf.shape)
Exemplo n.º 8
0
    def fold_concats(graph):
        """Merge concatenation inputs that are themselves concatenations on the same axis.

        For each matched concat node, every input buffer produced by another
        concat op is replaced, at the same position, by that producer's own
        input names. The producer concat is pruned once the current concat was
        its only consumer, and the current concat is registered as a consumer
        of each of the producer's input buffers.

        :param graph: op graph that is searched and rewritten in place
        """
        def validate_concat_axis(nodes_tuple):
            """Return False if any concat-produced input concatenates along a different axis."""
            concat_node_ = nodes_tuple[0]
            concat_node_input_bufs_ = graph.get_input_buffers(concat_node_)
            for buf_ in concat_node_input_bufs_:
                if buf_.producer.op.type == op_adapter.ConcatOp.TRANSLATION_KEY:
                    prev_concat_node_ = buf_.producer
                    # only fold concats with same axis
                    if prev_concat_node_.op.axis != concat_node_.op.axis:
                        log_debug2(
                            "Found concat node({}) with a concat input, but axis does not match for input ({}), "
                            "{} != {} ", concat_node_.op.name,
                            prev_concat_node_.op.name,
                            prev_concat_node_.op.axis, concat_node_.op.axis)
                        return False

            return True

        # Pattern: a concatenation whose inputs include concatenation outputs.
        # The exact meaning of "FLEXIBLE_NUM_BUFS"/"ANY" is defined by
        # graph.get_matched_nodes (presumably: any number of inputs, at least
        # one produced by a concatenation) — confirm there.
        sequence = [("concatenation", ("FLEXIBLE_NUM_BUFS", [("concatenation",
                                                              "ANY")]), ())]
        matched_node_list = graph.get_matched_nodes(
            sequence, validator=validate_concat_axis)

        for node_tuple in matched_node_list:
            concat_node = node_tuple[0]
            # Snapshot of the input buffers; concat_node.input_names is mutated below.
            concat_node_input_bufs = graph.get_input_buffers(concat_node)

            for buf in concat_node_input_bufs:
                if buf.producer.op.type == op_adapter.ConcatOp.TRANSLATION_KEY:
                    prev_concat_buf = buf  # for readability
                    prev_concat_node = prev_concat_buf.producer

                    # remove prev concat as input from current concat and replace with prev concat's input names
                    # (prev_concat_inputs is a reference to the producer's list, not a copy)
                    prev_concat_inputs = prev_concat_node.input_names
                    idx = concat_node.input_names.index(prev_concat_buf.name)
                    concat_node.input_names.remove(prev_concat_buf.name)
                    concat_node.input_names[
                        idx:
                        idx] = prev_concat_inputs  # extend the inputs in the same index as prev concat

                    # NOTE(review): if prev_concat_buf feeds concat_node more than once, a
                    # second pass would call consumers.remove(concat_node) and graph.prune
                    # again — confirm this case cannot occur or is handled by the graph.
                    prev_concat_buf.consumers.remove(concat_node)

                    # we can prune the prev concat node if the current concat was the only consumer.
                    if len(prev_concat_buf.consumers) == 0:
                        graph.prune(prev_concat_node)

                    # Register the current concat as a consumer of each of the prev concat's
                    # input bufs. The prev concat is not removed from those consumer sets here;
                    # presumably graph.prune does that when it prunes the node — TODO confirm.
                    for input_name in prev_concat_inputs:
                        input_buf = graph.get_buffer(input_name)
                        input_buf.consumers.add(concat_node)

                    log_debug2(
                        code_to_message.get_debugging_message(
                            "DEBUG_CONCAT_FOLD")(prev_concat_node.op.name,
                                                 concat_node.op.name))
Exemplo n.º 9
0
    def squash_batchnorm(graph):
        """Fold a rank-4 batchnorm that follows a convolution into the conv op.

        A convolution -> batchnorm match is accepted when the second node is a
        batchnorm op whose input buffer has rank 4. The conv weights are
        multiplied by the batchnorm weights, the conv bias becomes
        ``conv.bias * bn.weights + bn.bias``, and the batchnorm node is
        squashed out of the graph.

        :param graph: op graph that is searched and rewritten in place
        """
        def is_rank4_batchnorm(nodes_tuple):
            candidate = nodes_tuple[1]
            candidate_input = graph.get_input_buffers(candidate)[0]
            if candidate.op.type != op_adapter.BatchnormOp.TRANSLATION_KEY:
                return False
            return candidate_input.rank() == 4

        pattern = [("convolution", (), ()), ("batchnorm", (), ())]
        matches = graph.get_matched_nodes(
            pattern, validator=is_rank4_batchnorm)

        for matched in matches:
            # sanity check
            log_assert(
                len(matched) == len(pattern),
                "ERROR: Pattern matching for squash batchnorm returned extra nodes. Got {} nodes, Expected {}.",
                len(matched), len(pattern))

            conv_node, bn_node = matched[0], matched[1]
            bn_in_buf = graph.get_input_buffers(bn_node)[0]
            scale = bn_node.op.weights

            if bn_in_buf.axis_format == AxisTracker.AxisFormat.NCS:
                # Conv weights are transposed only later in axes_to_spatial_first,
                # so permute them for BN weight broadcasting and then undo it;
                # (3, 2, 0, 1) is the inverse permutation of (2, 3, 1, 0).
                folded = numpy.transpose(
                    numpy.transpose(conv_node.op.weights, (2, 3, 1, 0)) * scale,
                    (3, 2, 0, 1))
            else:
                folded = conv_node.op.weights * scale

            conv_node.op.weights = folded
            conv_node.op.bias = conv_node.op.bias * scale + bn_node.op.bias
            graph.squash(bn_node, bn_in_buf.name)
            log_debug2(
                code_to_message.get_debugging_message("DEBUG_BATCHNORM_SQUASH")
                (bn_node.op.name, conv_node.op.type, conv_node.op.name))
Exemplo n.º 10
0
    def match_channelshuffle(graph):
        """Replace reshape -> permute -> reshape channel-shuffle chains with a ChannelShuffleOp.

        A match is accepted when all of the following hold:
          * the chain's final output shape equals its input shape;
          * the first reshape goes from rank 4 to rank 5, carrying input dims
            0, 2, 3 to output dims 0, 3, 4;
          * the middle node is a permute with order [0, 2, 1, 3, 4];
          * the last reshape goes from rank 5 to rank 4, carrying input dims
            0, 3, 4 to output dims 0, 2, 3.
        The last reshape and the permute are squashed, and the first reshape's
        op is replaced by a ChannelShuffleOp whose ``groups`` is dim 1 of that
        reshape's ``output_shape``.

        :param graph: op graph that is searched and rewritten in place
        """
        def io_shapes(node):
            # (input shape, output shape) of the node's first input/output buffer
            return (graph.get_input_buffers(node)[0].shape,
                    graph.get_output_buffers(node)[0].shape)

        def is_split_reshape(node):
            in_shape, out_shape = io_shapes(node)
            if len(in_shape) != 4 or len(out_shape) != 5:
                return False
            return (in_shape[0] == out_shape[0]
                    and in_shape[2] == out_shape[3]
                    and in_shape[3] == out_shape[4])

        def is_group_transpose(node):
            # Assuming the input shape is N[GC']HW
            return (node.op.type == op_adapter.PermuteOp.TRANSLATION_KEY
                    and node.op.order == [0, 2, 1, 3, 4])

        def is_merge_reshape(node):
            in_shape, out_shape = io_shapes(node)
            if len(in_shape) != 5 or len(out_shape) != 4:
                return False
            return (in_shape[0] == out_shape[0]
                    and in_shape[3] == out_shape[2]
                    and in_shape[4] == out_shape[3])

        def is_valid_channelshuffle(nodes_tuple):
            split_, transpose_, merge_ = nodes_tuple
            chain_in = graph.get_input_buffers(split_)[0].shape
            chain_out = graph.get_output_buffers(merge_)[0].shape
            return (chain_out == chain_in
                    and is_split_reshape(split_)
                    and is_group_transpose(transpose_)
                    and is_merge_reshape(merge_))

        sequence = [("reshape", (), ("MATCH_NUM_BUFS", [("permute", "ALL")])),
                    ("permute", (), ("MATCH_NUM_BUFS", [("reshape", "ALL")])),
                    ("reshape", (), ())]
        matches = graph.get_matched_nodes(
            sequence, validator=is_valid_channelshuffle)

        for split_node, transpose_node, merge_node in matches:
            # ChannelShuffle found: squash the trailing reshape and the permute,
            # then swap the leading reshape's op for a ChannelShuffleOp.
            graph.squash(merge_node,
                         graph.get_input_buffers(merge_node)[0].name)
            graph.squash(transpose_node,
                         graph.get_input_buffers(transpose_node)[0].name)

            # Assuming the shape is N[GC']HW, dim 1 is the group count G.
            groups = split_node.op.output_shape[1]
            shuffle_op = op_adapter.ChannelShuffleOp(None, groups=groups)
            shuffle_op.name = graph.naming_policy.get_op_name(shuffle_op)
            graph.replace(split_node.op, shuffle_op)
            log_debug2(
                code_to_message.get_debugging_message(
                    "DEBUG_CHANNEL_SHUFFLE_REPLACE")(split_node.op.name,
                                                     transpose_node.op.name,
                                                     merge_node.op.name,
                                                     shuffle_op.name))