def log_axes_to_spatial_first_order(node, graph):
    """Trace a node entering axes-to-spatial-first reordering.

    Emits one entry message for the node itself, then one message per input
    buffer showing the buffer's current shape.
    """
    snpe_converter_utils.log_debug(
        code_to_message.get_debugging_message(
            "DEBUG_AXES_TO_SPATIAL_FIRST_ORDER_ENTRY")(node.op.name))
    for name in node.input_names:
        shape_str = str(graph.get_buffer(name).shape)
        snpe_converter_utils.log_debug(
            code_to_message.get_debugging_message(
                "DEBUG_AXES_TO_SPATIAL_FIRST_ORDER_INPUT_SIZE")(name, shape_str))
def squash_scale(graph):
    """Fold elementwise_sum nodes that carry a bias into their producer.

    For each matched sum node, its bias is added onto the preceding node's
    bias and the sum node is squashed out of the graph.
    """

    def _has_squashable_bias(nodes_tuple):
        # Only sums that carry a bias qualify; the producer must also expose
        # a bias to fold into (asserted, not silently skipped).
        candidate = nodes_tuple[0]
        if not hasattr(candidate.op, 'bias'):
            return False
        producer = graph.get_input_buffers(candidate)[0].producer
        log_assert(
            hasattr(producer.op, 'bias'),
            code_to_message.get_error_message("ERROR_ADD_BIAS_PREV_NO_BIAS")(
                candidate.op.name, producer.op.name, producer.op.type))
        return True

    sequence = [("elementwise_sum", (), ())]
    for nodes in graph.get_matched_nodes(sequence, validator=_has_squashable_bias):
        sum_node = nodes[0]
        in_buf = graph.get_input_buffers(sum_node)[0]
        producer = in_buf.producer
        producer.op.bias += sum_node.op.bias
        graph.squash(sum_node, in_buf.name)
        log_debug2(
            code_to_message.get_debugging_message("DEBUG_ELEMENTWISESUM_SQUASH")(
                sum_node.op.name, producer.op.name, producer.op.type))
def fetch(self, *keys, **kwargs):
    """Retrieve weights for *keys* from the weight map.

    Each fetched entry is marked via its ``consumed`` flag (controlled by the
    ``prunable`` keyword argument, default True). Consumed indicates whether
    the weights have been used in such a way as to allow pruning of the node
    (e.g. Const ops that contain weights consumed by Conv/FC/etc can be
    pruned from the network; Const ops that are inputs to a node cannot).

    :param keys: one or more weight-map keys; each is converted with str().
    :param kwargs: ``prunable`` (bool, default True) — value stored on each
        fetched entry's ``consumed`` attribute.
    :return: a single float32 numpy array when exactly one key is given,
        otherwise a list of arrays in key order.
    :raises KeyError: if any key is missing from the weight map.
    """
    ret = []
    consumed = kwargs.get('prunable', True)
    for key in keys:
        key = str(key)
        # BUG FIX: the key must be fed to the message formatter; previously
        # it was embedded inside the message-id string literal
        # ("DEBUG_RETRIEVE_WEIGHTS, key"), so the lookup used a wrong id and
        # the key was never logged.
        log_debug(
            code_to_message.get_debugging_message(
                "DEBUG_RETRIEVE_WEIGHTS")(key))
        if key not in self.weight_map:
            raise KeyError(
                code_to_message.get_error_message(
                    "ERROR_WEIGHTS_MISSING_KEY")(key))
        self.weight_map[key].consumed = consumed
        # Explicitly copy the data so if later ops modify it, the original
        # data remains intact.
        ret.append(
            numpy.require(self.weight_map[key].weights.copy(),
                          dtype=numpy.float32))
    if len(ret) == 1:
        return ret[0]
    return ret
def squash_batchnorm(graph):
    """Fold batchnorm scale/shift into a preceding fully_connected node.

    Matches fully_connected -> batchnorm pairs, multiplies the FC weights by
    the BN weights (scale), folds scale and shift into the FC bias, then
    squashes the BN node out of the graph.
    """
    sequence = [("fully_connected", (), ()), ("batchnorm", (), ())]
    matched_node_list = graph.get_matched_nodes(sequence)

    for node_tuple in matched_node_list:
        # sanity check
        log_assert(
            len(node_tuple) == len(sequence),
            "ERROR: Pattern matching for squash batchnorm returned extra nodes. Got {} nodes, Expected {}.",
            len(node_tuple), len(sequence))

        fc_node = node_tuple[0]
        bn_node = node_tuple[1]
        bn_input_buffer = graph.get_input_buffers(bn_node)[0]
        weights_list = []
        if bn_input_buffer.axis_format == AxisTracker.AxisFormat.NCS:
            # The FC weights are not yet transposed as that happens in
            # axes_to_spatial_first later, so we need to transpose for BN
            # weight broadcasting and then revert
            for weights in fc_node.op.weights_list:
                weights = numpy.transpose(weights, (2, 3, 1, 0))
                weights = (weights * bn_node.op.weights)
                weights_list.append(numpy.transpose(weights, (3, 2, 0, 1)))
        else:
            weights_list = [(weights * bn_node.op.weights)
                            for weights in fc_node.op.weights_list]
        # NOTE(review): the weights are read from fc_node.op.weights_list but
        # written back to fc_node.op.weights — confirm this asymmetry is
        # intentional and not a typo for fc_node.op.weights_list (compare the
        # convolution variant, which reads and writes op.weights).
        fc_node.op.weights = weights_list
        fc_node.op.bias = fc_node.op.bias * bn_node.op.weights + bn_node.op.bias
        graph.squash(bn_node, bn_input_buffer.name)
        log_debug2(
            code_to_message.get_debugging_message("DEBUG_BATCHNORM_SQUASH")
            (bn_node.op.name, fc_node.op.type, fc_node.op.name))
def squash_scale(graph):
    """Fold elementwise_product scale nodes into the preceding batchnorm.

    For each matched product node carrying weights, both the batchnorm's
    weights and bias are multiplied by the scale, then the product node is
    squashed out of the graph.
    """

    def _is_scale_after_batchnorm(nodes_tuple):
        # Only products that carry weights qualify; the producer is asserted
        # to be a batchnorm (not silently skipped).
        mul_node = nodes_tuple[0]
        if not hasattr(mul_node.op, 'weights'):
            return False
        producer = graph.get_input_buffers(mul_node)[0].producer
        log_assert(
            producer.op.type == op_adapter.BatchnormOp.TRANSLATION_KEY,
            code_to_message.get_error_message(
                "ERROR_MUL_SCALE_PREV_NOT_BATCHNORM")(producer.op.name,
                                                      producer.op.type))
        return True

    matches = graph.get_matched_nodes([("elementwise_product", (), ())],
                                      validator=_is_scale_after_batchnorm)
    for nodes in matches:
        mul_node = nodes[0]
        in_buf = graph.get_input_buffers(mul_node)[0]
        producer = in_buf.producer
        scale = mul_node.op.weights
        producer.op.weights *= scale
        producer.op.bias *= scale
        graph.squash(mul_node, in_buf.name)
        log_debug2(
            code_to_message.get_debugging_message(
                "DEBUG_ELEMENTWISEPRODUCT_SQUASH")(mul_node.op.name,
                                                   producer.op.name,
                                                   producer.op.type))
def remove_noop(node, graph):
    """Prune a constant node whose weights were consumed internally.

    If the node's first output is marked consumed in the weights registry
    (e.g. folded into a Conv/FC layer), the node is removed from the graph.
    """
    out_name = node.output_names[0]
    if not graph.weights.consumed(out_name):
        return
    log_debug(
        code_to_message.get_debugging_message("DEBUG_CONSTANT_PRUNED")(out_name))
    graph.prune(node)
def axes_to_snpe_order(self, node, graph):
    """Apply spatial-first axis reordering for channel-shuffle, with tracing.

    Delegates the actual reordering to the base class and logs name, axis
    format and shape of every input and output buffer afterwards.
    """
    log_debug(
        code_to_message.get_debugging_message(
            "DEBUG_AXES_TO_SNPE_ORDER_ENTRY")(node.op.name))
    super(OptimizeChannelShuffleTranslation,
          self).axes_to_spatial_first_order(node, graph)
    for in_buf in graph.get_input_buffers(node):
        log_debug("input {} {} {}", in_buf.name, in_buf.axis_format,
                  in_buf.shape)
    for out_buf in graph.get_output_buffers(node):
        log_debug("output {} {} {}", out_buf.name, out_buf.axis_format,
                  out_buf.shape)
def fold_concats(graph):
    """Merge chained concatenation nodes that share the same axis.

    When a concat node has another concat as one of its inputs and both use
    the same axis, the upstream concat's inputs are spliced directly into the
    downstream concat (preserving position), and the upstream concat is
    pruned once it has no remaining consumers.
    """

    def validate_concat_axis(nodes_tuple):
        # Reject the match if any concat-producing input uses a different
        # axis than the downstream concat — folding would change semantics.
        concat_node_ = nodes_tuple[0]
        concat_node_input_bufs_ = graph.get_input_buffers(concat_node_)
        for buf_ in concat_node_input_bufs_:
            if buf_.producer.op.type == op_adapter.ConcatOp.TRANSLATION_KEY:
                prev_concat_node_ = buf_.producer
                # only fold concats with same axis
                if prev_concat_node_.op.axis != concat_node_.op.axis:
                    log_debug2(
                        "Found concat node({}) with a concat input, but axis does not match for input ({}), "
                        "{} != {} ", concat_node_.op.name,
                        prev_concat_node_.op.name, prev_concat_node_.op.axis,
                        concat_node_.op.axis)
                    return False
        return True

    sequence = [("concatenation",
                 ("FLEXIBLE_NUM_BUFS", [("concatenation", "ANY")]), ())]
    matched_node_list = graph.get_matched_nodes(
        sequence, validator=validate_concat_axis)

    for node_tuple in matched_node_list:
        concat_node = node_tuple[0]
        concat_node_input_bufs = graph.get_input_buffers(concat_node)

        for buf in concat_node_input_bufs:
            if buf.producer.op.type == op_adapter.ConcatOp.TRANSLATION_KEY:
                prev_concat_buf = buf  # for readability
                prev_concat_node = prev_concat_buf.producer

                # remove prev concat as input from current concat and replace
                # with prev concat's input names
                prev_concat_inputs = prev_concat_node.input_names
                idx = concat_node.input_names.index(prev_concat_buf.name)
                concat_node.input_names.remove(prev_concat_buf.name)
                concat_node.input_names[
                    idx:
                    idx] = prev_concat_inputs  # extend the inputs in the same index as prev concat

                prev_concat_buf.consumers.remove(concat_node)
                # we can prune the prev concat node if the current concat was
                # the only consumer.
                if len(prev_concat_buf.consumers) == 0:
                    graph.prune(prev_concat_node)

                # remove prev concat as consumer for prev concat's input bufs
                # and replace with current concat
                for input_name in prev_concat_inputs:
                    input_buf = graph.get_buffer(input_name)
                    input_buf.consumers.add(concat_node)

                log_debug2(
                    code_to_message.get_debugging_message(
                        "DEBUG_CONCAT_FOLD")(prev_concat_node.op.name,
                                             concat_node.op.name))
def squash_batchnorm(graph):
    """Fold batchnorm scale/shift into a preceding 4D convolution.

    Matches convolution -> batchnorm pairs whose BN input buffer is rank 4,
    multiplies the conv weights by the BN weights (scale), folds scale and
    shift into the conv bias, then squashes the BN node out of the graph.
    """

    def _bn_follows_rank4_input(nodes_tuple):
        bn = nodes_tuple[1]
        bn_in = graph.get_input_buffers(bn)[0]
        return (bn.op.type == op_adapter.BatchnormOp.TRANSLATION_KEY
                and bn_in.rank() == 4)

    sequence = [("convolution", (), ()), ("batchnorm", (), ())]
    matches = graph.get_matched_nodes(sequence,
                                      validator=_bn_follows_rank4_input)

    for nodes in matches:
        # sanity check
        log_assert(
            len(nodes) == len(sequence),
            "ERROR: Pattern matching for squash batchnorm returned extra nodes. Got {} nodes, Expected {}.",
            len(nodes), len(sequence))

        conv_node, bn_node = nodes
        bn_in_buf = graph.get_input_buffers(bn_node)[0]

        if bn_in_buf.axis_format == AxisTracker.AxisFormat.NCS:
            # The Conv weights are not yet transposed as that happens in
            # axes_to_spatial_first later, so we need to transpose for BN
            # weight broadcasting and then revert
            folded = numpy.transpose(conv_node.op.weights, (2, 3, 1, 0))
            folded = folded * bn_node.op.weights
            folded = numpy.transpose(folded, (3, 2, 0, 1))
        else:
            folded = conv_node.op.weights * bn_node.op.weights

        conv_node.op.weights = folded
        conv_node.op.bias = (conv_node.op.bias * bn_node.op.weights
                             + bn_node.op.bias)
        graph.squash(bn_node, bn_in_buf.name)
        log_debug2(
            code_to_message.get_debugging_message("DEBUG_BATCHNORM_SQUASH")(
                bn_node.op.name, conv_node.op.type, conv_node.op.name))
def match_channelshuffle(graph):
    """Detect reshape -> permute -> reshape channel-shuffle patterns and
    replace each with a single ChannelShuffleOp.

    The pattern splits channels (4D -> 5D reshape), swaps the group axes
    (permute [0, 2, 1, 3, 4]), and merges channels back (5D -> 4D reshape),
    leaving the overall shape unchanged.
    """

    def is_valid_channelshuffle(nodes_tuple):
        def check_for_valid_reshape_1(node):
            # 4D -> 5D reshape that splits channels while preserving batch
            # and spatial dims.
            input_buffer = graph.get_input_buffers(node)[0]
            output_buffer = graph.get_output_buffers(node)[0]
            reshape_1_input_shape = input_buffer.shape
            reshape_1_output_shape = output_buffer.shape

            return (len(reshape_1_input_shape) == 4
                    and len(reshape_1_output_shape) == 5
                    and reshape_1_input_shape[0] == reshape_1_output_shape[0]
                    and reshape_1_input_shape[2] == reshape_1_output_shape[3]
                    and reshape_1_input_shape[3] == reshape_1_output_shape[4])

        def check_for_valid_permute(node):
            # Assuming the input shape is N[GC']HW
            return node.op.type == op_adapter.PermuteOp.TRANSLATION_KEY and node.op.order == [
                0, 2, 1, 3, 4
            ]

        def check_for_valid_reshape_2(node):
            # 5D -> 4D reshape that merges the shuffled groups back into a
            # single channel dim, preserving batch and spatial dims.
            input_buffer = graph.get_input_buffers(node)[0]
            output_buffer = graph.get_output_buffers(node)[0]
            reshape_2_input_shape = input_buffer.shape
            reshape_2_output_shape = output_buffer.shape

            return (len(reshape_2_input_shape) == 5
                    and len(reshape_2_output_shape) == 4
                    and reshape_2_input_shape[0] == reshape_2_output_shape[0]
                    and reshape_2_input_shape[3] == reshape_2_output_shape[2]
                    and reshape_2_input_shape[4] == reshape_2_output_shape[3])

        first_, second_, third_ = nodes_tuple
        input_shape_ = graph.get_input_buffers(first_)[0].shape
        output_shape_ = graph.get_output_buffers(third_)[0].shape

        # Channel shuffle must be shape-preserving end to end.
        return ((output_shape_ == input_shape_)
                and check_for_valid_reshape_1(first_)
                and check_for_valid_permute(second_)
                and check_for_valid_reshape_2(third_))

    sequence = [("reshape", (), ("MATCH_NUM_BUFS", [("permute", "ALL")])),
                ("permute", (), ("MATCH_NUM_BUFS", [("reshape", "ALL")])),
                ("reshape", (), ())]
    matched_node_list = graph.get_matched_nodes(
        sequence, validator=is_valid_channelshuffle)

    for node_tuple in matched_node_list:
        # ChannelShuffle Op found,
        # Squash Permute and 2nd Reshape Op and
        # Replace 1st ReshapeOp with ShuffleOp
        first, second, third = node_tuple

        third_input_buffer = graph.get_input_buffers(third)[0]
        graph.squash(third, third_input_buffer.name)

        second_input_buffer = graph.get_input_buffers(second)[0]
        graph.squash(second, second_input_buffer.name)

        output_shape = first.op.output_shape
        # Assuming the shape is N[GC']HW
        groups = output_shape[1]
        shuffle_op = op_adapter.ChannelShuffleOp(None, groups=groups)
        shuffle_op.name = graph.naming_policy.get_op_name(shuffle_op)
        graph.replace(first.op, shuffle_op)
        log_debug2(
            code_to_message.get_debugging_message(
                "DEBUG_CHANNEL_SHUFFLE_REPLACE")(first.op.name,
                                                 second.op.name,
                                                 third.op.name,
                                                 shuffle_op.name))