def add_nop_bn_after(g, value_names):
    """Insert a do-nothing BatchNormalization node behind each named value.

    For every name in ``value_names`` a BatchNormalization node called
    ``<name>_nop_bn`` is built with scale 1, bias 0, mean 0 and variance 1,
    each stored in its own Constant node. Nodes reported by
    ``helper.find_following_nodes_by_input_value_name`` are rewired to read
    the new node's output. When that lookup returns nothing, the graph output
    entry carrying the name is replaced by the new output instead. Names that
    cannot be resolved are reported on stdout and skipped. The graph is
    topologically sorted once after all names were processed.

    :param g: the graph
    :param value_names: a list of string which are the names of value_info.
    """
    lookups = (
        helper.find_value_by_name,
        helper.find_input_by_name,
        helper.find_output_by_name,
    )
    weight_specs = (
        ("_scale", 1.0),
        ("_bias", 0.0),
        ("_mean", 0.0),
        ("_var", 1.0),
    )

    for target_name in value_names:
        # Resolve the value info: intermediate values first, then graph
        # inputs, then graph outputs.
        info = None
        for lookup in lookups:
            info = lookup(g, target_name)
            if info is not None:
                break
        if info is None:
            print("Cannot find an value_info named {}".format(target_name))
            continue

        # The channel count is taken from the second entry of the shape.
        dims = helper.get_shape_from_value_info(info)
        n_channel = dims[1]

        # One Constant per BatchNormalization parameter.
        bn_name = target_name + "_nop_bn"
        weights = [
            helper.list_to_constant(bn_name + suffix, [n_channel],
                                    [fill] * n_channel)
            for suffix, fill in weight_specs
        ]
        bn = onnx.helper.make_node(
            "BatchNormalization",
            [target_name] + [weight.output[0] for weight in weights],
            [bn_name],
            name=bn_name)

        # Rewire the graph before the new nodes are appended to it.
        consumers = helper.find_following_nodes_by_input_value_name(
            g, target_name)
        if len(consumers) == 0:
            # No consumer was found: swap the graph output entry of the same
            # name for the new value while keeping the output order.
            replacement = onnx.helper.make_tensor_value_info(
                bn_name, info.type.tensor_type.elem_type, dims)
            drained = [g.output.pop() for _ in range(len(g.output))]
            for old_output in reversed(drained):
                if old_output.name == target_name:
                    g.output.extend([replacement])
                else:
                    g.output.extend([old_output])
        else:
            for consumer in consumers:
                replace_node_input(consumer, target_name, bn_name)

        g.node.extend([bn] + weights)

    topological_sort(g)
def add_bn_after(prev_node):
    """Append a do-nothing BatchNormalization node behind ``prev_node``.

    The new node is named ``<prev_node.output[0]>_nop_bn`` and uses scale 1,
    bias 0, mean 0 and variance 1.

    NOTE(review): the body reads ``g`` (the graph) and ``n`` (the node whose
    input is rewired) although neither is a parameter. They have to come from
    an enclosing or module-level scope that is not visible here -- confirm
    against the full file.

    :param prev_node: node whose first output feeds the new node
    """
    src_name = prev_node.output[0]

    # The channel count is taken from the second entry of the shape.
    src_info = helper.find_value_by_name(g, src_name)
    n_channel = helper.get_shape_from_value_info(src_info)[1]

    # One Constant per BatchNormalization parameter.
    bn_name = src_name + "_nop_bn"
    weights = [
        helper.list_to_constant(bn_name + suffix, [n_channel],
                                [fill] * n_channel)
        for suffix, fill in (("_scale", 1.0), ("_bias", 0.0),
                             ("_mean", 0.0), ("_var", 1.0))
    ]
    bn = onnx.helper.make_node(
        "BatchNormalization",
        [src_name] + [weight.output[0] for weight in weights],
        [bn_name],
        name=bn_name)

    # Make `n` read the BatchNormalization output, then register the nodes.
    replace_node_input(n, src_name, bn_name)
    g.node.extend([bn] + weights)
def replace_depthwise_1x1_with_bn(g):
    """Replace 1x1 depthwise Conv nodes with BatchNormalization nodes.

    A Conv node is converted only when all of the following hold:

    * ``group`` is present and not 1,
    * ``kernel_shape`` is exactly ``[1, 1]``,
    * all pads (if given) are zero,
    * all strides (if given) are 1 -- a strided convolution subsamples its
      input, which BatchNormalization cannot reproduce,
    * the weight comes from a Constant node with dims ``[group, 1, 1, 1]``,
      i.e. one output channel per input channel.

    The weight Constant (and its value_info, if any) is reshaped in place to
    one dimension and reused as the BatchNormalization scale. The Conv bias
    is reused when present, otherwise a zero bias is created. Mean is 0 and
    variance is 1. The new node keeps the Conv's name and outputs.

    :param g: the onnx graph
    """
    node_to_remove = []
    # New nodes are collected and appended after the scan so that g.node is
    # not modified while it is being iterated.
    nodes_to_add = []
    for node in g.node:
        # Check op_type
        if node.op_type != 'Conv':
            continue

        # Check attributes
        attr_map = {attr.name: attr for attr in node.attribute}
        if "group" not in attr_map or attr_map["group"].i == 1:
            continue
        group = attr_map["group"].i
        # kernel_shape is an optional attribute; without it the kernel size
        # is not checked here, so the node is left alone.
        if "kernel_shape" not in attr_map:
            continue
        if list(attr_map["kernel_shape"].ints) != [1, 1]:
            continue
        if "pads" in attr_map and sum(attr_map["pads"].ints) != 0:
            continue
        if "strides" in attr_map and any(
                stride != 1 for stride in attr_map["strides"].ints):
            continue

        # Check scale
        scale_node = helper.find_node_by_output_name(g, node.input[1])
        if scale_node is None or scale_node.op_type != 'Constant':
            continue
        weight_dims = scale_node.attribute[0].t.dims
        if len(weight_dims) != 4 or weight_dims[1] != 1:
            continue
        # With a channel multiplier the output has more channels than the
        # input, which is not an element-wise scaling.
        if weight_dims[0] != group:
            continue

        # Drop the three trailing dimensions: [C, 1, 1, 1] -> [C].
        for _ in range(3):
            weight_dims.pop()
        scale_info = helper.find_value_by_name(g, node.input[1])
        if scale_info is not None:
            for _ in range(3):
                scale_info.type.tensor_type.shape.dim.pop()

        # Check bias
        if len(node.input) == 3:
            bias_name = node.input[2]
        else:
            bias_name = node.name + "_bias"
            bias_node = helper.list_to_constant(bias_name, [group],
                                                [0.0] * group)
            nodes_to_add.append(bias_node)

        # Construct mean and vars
        mean_name = node.name + "_mean"
        mean_node = helper.list_to_constant(mean_name, [group],
                                            [0.0] * group)
        var_name = node.name + "_var"
        var_node = helper.list_to_constant(var_name, [group], [1.0] * group)
        nodes_to_add.extend([mean_node, var_node])

        # Convert
        bn_node = onnx.helper.make_node(
            op_type='BatchNormalization',
            inputs=[node.input[0], node.input[1], bias_name, mean_name,
                    var_name],
            outputs=node.output,
            name=node.name,
            epsilon=0.00001,
            momentum=0.9
        )
        nodes_to_add.append(bn_node)
        node_to_remove.append(node)

    g.node.extend(nodes_to_add)
    for node in node_to_remove:
        g.node.remove(node)
    topological_sort(g)
def add_bn_on_skip_branch(g):
    """Insert a do-nothing BatchNormalization on the direct edge into an Add.

    For each two-input Add node, the producers of both inputs are looked up.
    If one producer's first output is read by exactly two nodes (input ``a``
    is preferred over input ``b``), a BatchNormalization with scale 1,
    bias 0, mean 0 and variance 1 is created on that output and the Add is
    rewired to read the BatchNormalization result.

    Add nodes are skipped when an input has no producing node (for example a
    graph input) or when the branching value has no value_info, because the
    channel count cannot be determined then.

    :param g: the onnx graph
    """
    for n in g.node:
        # Find merge node (Add)
        if n.op_type != 'Add':
            continue
        if len(n.input) != 2:
            continue
        # TODO: Still need to consider more cases
        # Check if skip branch exist
        input_node_a = helper.find_node_by_output_name(g, n.input[0])
        input_node_b = helper.find_node_by_output_name(g, n.input[1])
        # find_node_by_output_name yields None when no node produces the
        # value; there is nothing to branch from in that case.
        if input_node_a is None or input_node_b is None:
            continue
        output_of_input_node_a = helper.find_nodes_by_input_name(
            g, input_node_a.output[0])
        output_of_input_node_b = helper.find_nodes_by_input_name(
            g, input_node_b.output[0])
        # Only a producer with exactly two consumers is treated as the
        # branching point; every other fan-out combination is skipped.
        if len(output_of_input_node_a) == 2:
            split_node = input_node_a
        elif len(output_of_input_node_b) == 2:
            split_node = input_node_b
        else:
            continue

        # Get the channel number from value info
        value_name = split_node.output[0]
        value = helper.find_value_by_name(g, value_name)
        if value is None:
            print("Cannot find an value_info named {}".format(value_name))
            continue
        shape = helper.get_shape_from_value_info(value)
        channel = shape[1]

        # Construct 4 weights
        node_name = value_name + "_nop_bn"
        ones = [1.0] * channel
        zeros = [0.0] * channel
        scale_node = helper.list_to_constant(node_name + "_scale", [channel],
                                             ones)
        bias_node = helper.list_to_constant(node_name + "_bias", [channel],
                                            zeros)
        mean_node = helper.list_to_constant(node_name + "_mean", [channel],
                                            zeros)
        var_node = helper.list_to_constant(node_name + "_var", [channel],
                                           ones)

        # Construct BN node
        bn_node = onnx.helper.make_node(
            "BatchNormalization",
            [
                value_name,
                scale_node.output[0],
                bias_node.output[0],
                mean_node.output[0],
                var_node.output[0],
            ],
            [node_name],
            name=node_name)

        # Reconnect the graph
        replace_node_input(n, value_name, node_name)

        # Add node to the graph
        g.node.extend([bn_node, scale_node, bias_node, mean_node, var_node])

    topological_sort(g)
def replace_Squeeze_with_Reshape(g):
    """Swap every Squeeze node for a Reshape to its recorded output shape.

    The target shape is read from the value_info (or, failing that, the graph
    output entry) of the Squeeze output and stored in a new Constant named
    ``<node name>_shape``. The Reshape keeps the Squeeze's first input, its
    outputs and its name.

    :param g: the input graph
    :raises RuntimeError: if no shape information exists for a Squeeze output
    """
    replaced = []
    for squeeze in g.node:
        if squeeze.op_type != 'Squeeze':
            continue

        # Locate the shape information of the Squeeze result.
        result_name = squeeze.output[0]
        result_info = helper.find_value_by_name(g, result_name)
        if result_info is None:
            result_info = helper.find_output_by_name(g, result_name)
            if result_info is None:
                raise RuntimeError("Cannot get shape for Squeeze")

        target_dims = []
        for dim in result_info.type.tensor_type.shape.dim:
            target_dims.append(dim.dim_value)

        # Build the shape Constant and the Reshape that takes over.
        shape_name = squeeze.name + "_shape"
        shape_const = helper.list_to_constant(shape_name, [len(target_dims)],
                                              target_dims)
        reshape = onnx.helper.make_node(
            "Reshape",
            [squeeze.input[0], shape_name],
            squeeze.output,
            name=squeeze.name)

        g.node.extend([shape_const, reshape])
        replaced.append(squeeze)

    # Drop the Squeeze nodes that were superseded.
    for squeeze in replaced:
        g.node.remove(squeeze)

    topological_sort(g)
def polish_RESIZE_input_param_node(g, resize_node_name):
    """Turn the shape-style fourth input of a Resize node into scales.

    The node producing ``resize_node_name`` is looked up, the constant feeding
    its input at index 3 is read as an integer array (a leading 0 is replaced
    by 1) and divided element-wise by the first four dimensions recorded for
    the node's first input. The quotient is stored in a new FLOAT Constant
    that becomes input 2, and the input at index 3 is removed.

    :param g: the graph
    :param resize_node_name: name of an output of the Resize node
    """
    target = helper.find_node_by_output_name(g, resize_node_name)

    # Requested output shape, taken from the constant behind input 3.
    sizes_producer = helper.find_node_by_output_name(g, target.input[3])
    wanted_shape = helper.constant_to_numpy(sizes_producer).astype(int)
    # handle 0 batch size which is invalid
    if wanted_shape[0] == 0:
        wanted_shape[0] = 1

    # Current shape of the data input, first four dimensions.
    source_info = helper.find_value_by_name(g, target.input[0])
    source_dims = source_info.type.tensor_type.shape.dim
    current_shape = np.array(
        [source_dims[axis].dim_value for axis in range(4)])

    target.input.remove(target.input[3])

    scales = np.array(wanted_shape / current_shape).astype(float)
    scales_node = helper.list_to_constant(
        'resize_scales_node_' + target.name,
        scales.shape,
        scales,
        data_type=onnx.helper.TensorProto.FLOAT)

    target.input[2] = scales_node.name
    g.node.extend([scales_node])

    other.topological_sort(g)
def add_nop_conv_after(g, value_names):
    """Insert a do-nothing depthwise Conv node behind each named value.

    For every name in ``value_names`` a 1x1 Conv node called
    ``<name>_nop_conv`` is built with ``group`` equal to the channel count,
    stride 1, no padding and an all-ones weight of shape
    ``[channel, 1, 1, 1]``. Nodes reported by
    ``helper.find_following_nodes_by_input_value_name`` are rewired to read
    the new node's output. When that lookup returns nothing, the graph output
    entry carrying the name is replaced by the new output instead. Names that
    cannot be resolved are reported on stdout and skipped. The graph is
    topologically sorted once after all names were processed.

    :param g: the graph
    :param value_names: a list of string which are the names of value_info.
    """
    lookups = (
        helper.find_value_by_name,
        helper.find_input_by_name,
        helper.find_output_by_name,
    )

    for target_name in value_names:
        # Resolve the value info: intermediate values first, then graph
        # inputs, then graph outputs.
        info = None
        for lookup in lookups:
            info = lookup(g, target_name)
            if info is not None:
                break
        if info is None:
            print("Cannot find an value_info named {}".format(target_name))
            continue

        # The channel count is taken from the second entry of the shape.
        dims = helper.get_shape_from_value_info(info)
        n_channel = dims[1]

        # Single all-ones weight, one 1x1 kernel per channel.
        conv_name = target_name + "_nop_conv"
        weight = helper.list_to_constant(conv_name + "_weight",
                                         [n_channel, 1, 1, 1],
                                         [1.0] * n_channel)

        conv = onnx.helper.make_node(
            "Conv",
            [target_name, weight.output[0]],
            [conv_name],
            name=conv_name,
            dilations=[1, 1],
            group=n_channel,
            kernel_shape=[1, 1],
            pads=[0, 0, 0, 0],
            strides=[1, 1])

        # Rewire the graph before the new nodes are appended to it.
        consumers = helper.find_following_nodes_by_input_value_name(
            g, target_name)
        if len(consumers) == 0:
            # No consumer was found: swap the graph output entry of the same
            # name for the new value while keeping the output order.
            replacement = onnx.helper.make_tensor_value_info(
                conv_name, info.type.tensor_type.elem_type, dims)
            drained = [g.output.pop() for _ in range(len(g.output))]
            for old_output in reversed(drained):
                if old_output.name == target_name:
                    g.output.extend([replacement])
                else:
                    g.output.extend([old_output])
        else:
            for consumer in consumers:
                replace_node_input(consumer, target_name, conv_name)

        g.node.extend([conv, weight])

    topological_sort(g)
def replace_shape_with_constant(g):
    """Replace Shape nodes whose input shape is fully known with Constants.

    This is the first step of reshape constant folding. A Shape node is
    replaced only when its input has a recorded shape (in value_info or among
    the graph inputs) and every dimension is strictly positive. The new
    Constant takes over the Shape node's output name.

    If the input's value_info entry is referenced by exactly one node, the
    entry is removed. Values found only among the graph inputs are never
    removed from ``g.value_info`` because they are not stored there.

    :param g: the input graph
    :return: if anything modified, return true.
    """
    node_to_remove = []
    for node in g.node:
        # Find a Shape
        if node.op_type != 'Shape':
            continue

        # Check its input
        input_value = helper.find_value_by_name(g, node.input[0])
        from_value_info = input_value is not None
        if input_value is None:
            input_value = helper.find_input_by_name(g, node.input[0])
        if input_value is None:
            continue
        dims = input_value.type.tensor_type.shape.dim
        if len(dims) == 0:
            continue

        # Check for case where dimension could be 0 or -1
        input_shape = [d.dim_value for d in dims]
        if not all(dim_value > 0 for dim_value in input_shape):
            continue

        # Replace it
        node_name = node.output[0]
        new_node = helper.list_to_constant(node_name, [len(input_shape)],
                                           input_shape)
        g.node.extend([new_node])
        node_to_remove.append(node)

        # If the input value_info is not used by any other node, delete it.
        # The Shape node itself is still in g.node here, hence the count 1.
        val_info_used = sum(
            input_value.name in consumer.input for consumer in g.node)
        if from_value_info and val_info_used == 1:
            g.value_info.remove(input_value)

    replaced = len(node_to_remove) > 0
    for node in node_to_remove:
        g.node.remove(node)

    topological_sort(g)

    return replaced
def replace_ConstantOfShape_with_constant(g):
    """Replace ConstantOfShape nodes fed by a Constant with a Constant node.

    A ConstantOfShape node is replaced when its shape input has recorded
    shape information and is produced by a Constant node. The new Constant
    takes over the output name, has the shape given by that Constant and is
    filled with the node's ``value`` attribute.

    The fill value is read from the tensor stored in the ``value`` attribute
    (ONNX defines it as a one-element tensor). If the attribute is missing,
    0.0 is used, which is the ONNX default. An attribute that carries no
    tensor falls back to its integer field.

    If the shape input's value_info entry is referenced by exactly one node,
    the entry is removed. Values found only among the graph inputs are not
    stored in ``g.value_info`` and are therefore left alone.

    :param g: the input graph
    :return: if anything modified, return true.
    """
    # Imported here so the block does not depend on the file's import list.
    from onnx import numpy_helper

    node_to_remove = []
    for node in g.node:
        # Find a ConstantOfShape
        if node.op_type != 'ConstantOfShape':
            continue

        # Check input
        input_value = helper.find_value_by_name(g, node.input[0])
        from_value_info = input_value is not None
        if input_value is None:
            input_value = helper.find_input_by_name(g, node.input[0])
        if input_value is None or len(
                input_value.type.tensor_type.shape.dim) == 0:
            continue

        # The target shape can only be folded when a Constant provides it.
        pre_node = helper.find_node_by_output_name(g, node.input[0])
        if pre_node is None or pre_node.op_type != 'Constant':
            continue
        _, target_shape = helper.constant_to_list(pre_node)
        out_shape = [int(dim) for dim in target_shape]
        if len(out_shape) == 0:
            continue
        n_elem = 1
        for dim in out_shape:
            n_elem *= dim

        # Determine the fill value.
        value_attr = None
        for attr in node.attribute:
            if attr.name == 'value':
                value_attr = attr
        if value_attr is None:
            fill = 0.0
        elif value_attr.HasField('t'):
            fill = numpy_helper.to_array(value_attr.t).flatten()[0].item()
        else:
            fill = value_attr.i

        # Replace to constant node
        node_name = node.output[0]
        new_node = helper.list_to_constant(node_name, out_shape,
                                           [fill] * n_elem)
        g.node.extend([new_node])

        # remove old node
        node_to_remove.append(node)

        # Delete the value_info if no other node uses it. The
        # ConstantOfShape node itself is still in g.node, hence the count 1.
        val_info_used = sum(
            input_value.name in consumer.input for consumer in g.node)
        if from_value_info and val_info_used == 1:
            g.value_info.remove(input_value)

    replaced = len(node_to_remove) > 0
    for node in node_to_remove:
        g.node.remove(node)

    topological_sort(g)

    return replaced
def replace_mul_to_bn(g):
    """Replace single Mul node with Batchnorm node.

    A Mul is converted when its second input is produced by a Constant node,
    its first input has a recorded 4-dimensional shape, and the constant has
    shape ``[1, C, 1, 1]`` (C being the second dimension of the input) or is
    reported with shape ``1``. The BatchNormalization uses the constant as
    scale, bias 0, mean 0 and variance 1, and keeps the Mul's output name.
    A single-element constant is repeated C times; a channel-wise constant
    is used as is.

    The Mul node and its Constant are removed from the graph, together with
    their value_info entries when such entries exist.

    :param g: input graph.
    :return:
    """
    node_to_del = []
    for node in g.node:
        if node.op_type != 'Mul':
            continue

        mul_op_node = node

        # only support one input node
        if len(mul_op_node.input) != 2:  # OP node and value node
            continue

        input_op_node_name = mul_op_node.input[0]
        mul_value_node = helper.find_node_by_output_name(
            g, mul_op_node.input[1])
        if mul_value_node is None or mul_value_node.op_type != 'Constant':
            continue

        # The data input may be an intermediate value or a graph input.
        prev_shape_value_info = helper.find_value_by_name(
            g, input_op_node_name)
        if prev_shape_value_info is None:
            prev_shape_value_info = helper.find_input_by_name(
                g, input_op_node_name)
        if prev_shape_value_info is None:
            continue

        _, previous_node_output_shape = helper.find_size_shape_from_value(
            prev_shape_value_info)
        scale_shape, scale_data = helper.constant_to_list(mul_value_node)

        # only allow 4 dim data input due to the hardware limitation
        if len(previous_node_output_shape) != 4:
            continue

        # channel dimension
        c_dim = previous_node_output_shape[1]

        # only allow channelwise mul or const mul
        if scale_shape != [1, c_dim, 1, 1] and scale_shape != 1:
            continue

        ones = [1.0] * c_dim
        zeros = [0.0] * c_dim
        # If the scale is a scalar, expand it; a channel-wise scale already
        # has one entry per channel and must not be repeated.
        if len(scale_data) == 1:
            muls = scale_data * c_dim
        else:
            muls = scale_data

        bn_name = mul_op_node.output[0]
        mean_value_node = helper.list_to_constant(bn_name + '_mean',
                                                  np.array(zeros).shape,
                                                  zeros)
        variance_value_node = helper.list_to_constant(bn_name + '_var',
                                                      np.array(ones).shape,
                                                      ones)
        bias_value_node = helper.list_to_constant(bn_name + '_add',
                                                  np.array(zeros).shape,
                                                  zeros)
        new_mul_value_node = helper.list_to_constant(bn_name + '_mul',
                                                     np.array(muls).shape,
                                                     muls)

        bn_node = onnx.helper.make_node(
            'BatchNormalization',
            [
                input_op_node_name,
                new_mul_value_node.output[0],
                bias_value_node.output[0],
                mean_value_node.output[0],
                variance_value_node.output[0],
            ],
            [mul_op_node.output[0]],
            name=bn_name,
            epsilon=0.00000001)

        # Drop the value_info entries of the Mul output and of the old
        # constant; either may be absent, so only existing ones are removed.
        for stale_name in (mul_op_node.output[0], mul_value_node.output[0]):
            stale_info = helper.find_value_by_name(g, stale_name)
            if stale_info is not None:
                g.value_info.remove(stale_info)

        g.node.extend([
            bn_node,
            mean_value_node,
            variance_value_node,
            bias_value_node,
            new_mul_value_node,
        ])

        node_to_del.extend([mul_op_node, mul_value_node])

    while node_to_del:
        g.node.remove(node_to_del.pop())

    topological_sort(g)
def fuse_mul_and_add_into_bn(g):
    """Fuse ``Mul`` by a constant followed by ``Add`` of a constant into BN.

    For each Add node whose inputs are produced by one Mul node and one
    Constant node, and whose Mul has a Constant input of the same shape as
    the Add constant, a BatchNormalization node replaces both. It reads the
    Mul's data input, uses the Mul constant as scale and the Add constant as
    bias (both reshaped in place to one dimension), mean 0 and variance 1,
    and keeps the Add's output name.

    The pattern is skipped when

    * an Add input has no producing node,
    * the Mul output is read by any node other than this Add (removing the
      Mul would leave that node without its input),
    * the data input has no recorded shape, is not 4-dimensional, or its
      second dimension differs from the number of scale elements,
    * the constants are not shaped ``[1, C, 1, 1]`` or ``[1, C]``.

    :param g: input graph.
    """
    node_to_del = []
    for add_node in g.node:
        if add_node.op_type != 'Add':
            continue

        input_nodes_add = [
            helper.find_node_by_output_name(g, input_name)
            for input_name in add_node.input
        ]
        if any(producer is None for producer in input_nodes_add):
            continue

        # Identify the Mul branch and the constant bias.
        mul_node, const_add = None, None
        for producer in input_nodes_add:
            if producer.op_type == 'Mul':
                mul_node = producer
            elif producer.op_type == 'Constant':
                const_add = producer
        if mul_node is None or const_add is None:
            continue

        # The Mul is deleted below, so nothing else may depend on it.
        if len(helper.find_nodes_by_input_name(g, mul_node.output[0])) != 1:
            continue

        # The first Constant input of the Mul is the scale; the other input
        # is the data.
        data_input_name, const_mul = None, None
        for input_name in mul_node.input:
            input_node = helper.find_node_by_output_name(g, input_name)
            if (input_node is not None and input_node.op_type == 'Constant'
                    and const_mul is None):
                const_mul = input_node
            else:
                data_input_name = input_name
        if const_mul is None:
            continue

        scale_shape, scale_data = helper.constant_to_list(const_mul)
        bais_shape, __ = helper.constant_to_list(const_add)
        c_dim = len(scale_data)
        if scale_shape != bais_shape:
            continue

        # The data input may be an intermediate value or a graph input.
        data_value_info = helper.find_value_by_name(g, data_input_name)
        if data_value_info is None:
            data_value_info = helper.find_input_by_name(g, data_input_name)
        if data_value_info is None:
            continue
        _, previous_node_output_shape = helper.find_size_shape_from_value(
            data_value_info)

        # only allow 4 dim data input due to the hardware limitation
        if len(previous_node_output_shape) != 4:
            continue

        # check if mul's dim and input channel dimension are matched
        if previous_node_output_shape[1] != c_dim:
            continue

        # Number of size-1 dimensions to strip so that the constants become
        # one-dimensional.
        if scale_shape == [1, c_dim, 1, 1]:
            ones_to_drop = 3
        elif scale_shape == [1, c_dim]:
            # NOTE(review): against a 4-D input a [1, C] constant would
            # normally broadcast over the last axis rather than the channel
            # axis -- confirm this case is intended.
            ones_to_drop = 1
        else:
            continue
        for _ in range(ones_to_drop):
            const_add.attribute[0].t.dims.remove(1)
            const_mul.attribute[0].t.dims.remove(1)

        bn_name = add_node.output[0]
        const_mean = helper.list_to_constant(bn_name + '_mean', [c_dim],
                                             [0.0 for _ in range(c_dim)])
        const_var = helper.list_to_constant(bn_name + '_var', [c_dim],
                                            [1.0 for _ in range(c_dim)])

        bn_node = onnx.helper.make_node(
            'BatchNormalization',
            [
                data_input_name,
                const_mul.output[0],
                const_add.output[0],
                const_mean.output[0],
                const_var.output[0],
            ],
            [add_node.output[0]],
            name=bn_name,
            epsilon=0.00000001
        )

        # Remove outdated value_info entries: the Mul output disappears and
        # both constants change shape. Entries that do not exist are skipped.
        for stale_name in (mul_node.output[0], const_mul.output[0],
                           const_add.output[0]):
            stale_info = helper.find_value_by_name(g, stale_name)
            if stale_info is not None:
                g.value_info.remove(stale_info)

        # Register value_info for the reshaped and the new constants.
        g.value_info.extend([
            onnx.helper.make_tensor_value_info(
                constant.output[0], constant.attribute[0].t.data_type,
                [c_dim])
            for constant in (const_mul, const_add, const_mean, const_var)
        ])

        g.node.extend([bn_node, const_mean, const_var])
        node_to_del.extend([mul_node, add_node])

    while node_to_del:
        g.node.remove(node_to_del.pop())

    topological_sort(g)
def replace_add_to_bn(g):
    """Replace single Add node with Batchnorm node.

    An Add is converted when its second input is produced by a Constant
    node, its first input has recorded shape information (value_info or graph
    input), and the constant has shape ``[1, C, 1, 1]`` or is reported with
    shape ``1``. C is the second dimension of the input, or 1 when the input
    has fewer than two dimensions. The BatchNormalization uses scale 1, the
    constant as bias (a single-element constant is repeated C times), mean 0
    and variance 1, and keeps the Add's output name.

    The Add node and its Constant are removed from the graph; the constant's
    value_info entry is removed as well when it exists.

    :param g: input graph.
    :return:
    """
    node_to_del = []
    for node in g.node:
        if node.op_type != 'Add':
            continue

        add_op_node = node

        # only support one input node
        if len(add_op_node.input) != 2:  # OP node and value node
            continue

        input_op_node_name = add_op_node.input[0]
        add_value_node = helper.find_node_by_output_name(
            g, add_op_node.input[1])
        if add_value_node is None or add_value_node.op_type != 'Constant':
            continue

        # The data input may be an intermediate value or a graph input.
        prev_shape_value_info = helper.find_value_by_name(
            g, input_op_node_name)
        if prev_shape_value_info is None:
            prev_shape_value_info = helper.find_input_by_name(
                g, input_op_node_name)
        if prev_shape_value_info is None:
            continue

        _, previous_node_output_shape = helper.find_size_shape_from_value(
            prev_shape_value_info)
        bias_shape, bias_data = helper.constant_to_list(add_value_node)

        # channel dimension
        if len(previous_node_output_shape) > 1:
            c_dim = previous_node_output_shape[1]
        else:
            c_dim = 1

        # only allow channelwise add or const add
        if bias_shape != [1, c_dim, 1, 1] and bias_shape != 1:
            continue

        ones = [1.0] * c_dim
        zeros = [0.0] * c_dim
        # If bias is a scalar, expand it.
        if len(bias_data) == 1:
            bias = bias_data * c_dim
        else:
            bias = bias_data

        bn_name = add_op_node.output[0]
        mean_value_node = helper.list_to_constant(bn_name + '_mean',
                                                  np.array(zeros).shape,
                                                  zeros)
        variance_value_node = helper.list_to_constant(bn_name + '_var',
                                                      np.array(ones).shape,
                                                      ones)
        scale_value_node = helper.list_to_constant(bn_name + '_mul',
                                                   np.array(ones).shape,
                                                   ones)
        new_add_value_node = helper.list_to_constant(bn_name + '_add',
                                                     np.array(bias).shape,
                                                     bias)

        bn_node = onnx.helper.make_node(
            'BatchNormalization',
            [
                input_op_node_name,
                scale_value_node.output[0],
                new_add_value_node.output[0],
                mean_value_node.output[0],
                variance_value_node.output[0],
            ],
            [add_op_node.output[0]],
            name=bn_name,
            epsilon=0.00000001)

        # The old constant goes away; drop its value_info when there is one.
        add_val_info = helper.find_value_by_name(g, add_value_node.output[0])
        if add_val_info is not None:
            g.value_info.remove(add_val_info)

        g.node.extend([
            bn_node,
            mean_value_node,
            variance_value_node,
            scale_value_node,
            new_add_value_node,
        ])

        node_to_del.extend([add_op_node, add_value_node])

    while node_to_del:
        g.node.remove(node_to_del.pop())

    topological_sort(g)
def replace_split_with_slices(g):
    """Replace split node with slice nodes.

    Each output of a Split node becomes a Slice node named after that output.
    The slice reads the Split's input and takes its ``starts``, ``ends`` and
    ``axes`` from three new one-element Constant nodes.

    Section sizes come from the ``split`` attribute when it is present and
    non-empty. Otherwise the input dimension along ``axis`` (default 0) is
    divided evenly among the outputs using integer division. Only node
    attributes are consulted. Split nodes whose input has no recorded shape
    are left untouched.

    :param g: input graph.
    :return:
    """
    node_to_remove = []
    for node in g.node:
        # Find a Split
        if node.op_type != 'Split':
            continue

        input_value = helper.find_value_by_name(g, node.input[0])
        if input_value is None:
            input_value = helper.find_input_by_name(g, node.input[0])
        if input_value is None:
            continue
        _, shape = helper.find_size_shape_from_value(input_value)
        if len(shape) == 0:
            continue

        output_val_names = list(node.output)

        axis = 0
        split = []
        for item in node.attribute:
            if item.name == 'axis':
                axis = item.i
            if item.name == 'split':
                split = list(item.ints)

        # Work out the size of every section.
        if len(split) > 0:
            sizes = split
        else:
            length = input_value.type.tensor_type.shape.dim[axis].dim_value
            width = length // len(output_val_names)
            sizes = [width] * len(output_val_names)

        # Emit one Slice per output; `pos` is the running end offset.
        pos = 0
        for i, (new_node_name, size) in enumerate(
                zip(output_val_names, sizes)):
            start = pos
            pos += size

            # Construct starts, ends, axes
            starts_name = new_node_name + '_starts_' + str(i)
            ends_name = new_node_name + '_ends_' + str(i)
            axes_name = new_node_name + '_axes_' + str(i)
            starts_node = helper.list_to_constant(starts_name, (1, ),
                                                  [int(start)])
            ends_node = helper.list_to_constant(ends_name, (1, ),
                                                [int(pos)])
            axes_node = helper.list_to_constant(axes_name, (1, ),
                                                [int(axis)])

            # Construct node
            new_node = onnx.helper.make_node(
                op_type='Slice',
                inputs=[node.input[0], starts_name, ends_name, axes_name],
                outputs=[new_node_name],
                name=new_node_name)
            g.node.extend([starts_node, ends_node, axes_node, new_node])

        node_to_remove.append(node)

    for old_node in node_to_remove:
        g.node.remove(old_node)
    topological_sort(g)