def add_bn_before_add(g):
    """Insert do-nothing BatchNormalization nodes before the inputs of Add nodes.

    For every two-input Add (merge) node whose inputs are both produced by
    non-Constant nodes, a pass-through BN (scale=1, bias=0, mean=0, var=1) is
    inserted on each input that is not already a single-consumer
    BatchNormalization. The graph is topologically sorted afterwards.

    :param g: the graph
    """
    for n in g.node:
        # Find merge node (Add)
        if n.op_type != 'Add':
            continue
        if len(n.input) != 2:
            continue
        # Get two inputs
        input_node_a = helper.find_node_by_output_name(g, n.input[0])
        input_node_b = helper.find_node_by_output_name(g, n.input[1])
        # Skip constant input add
        if input_node_a is None or input_node_a.op_type == 'Constant':
            continue
        if input_node_b is None or input_node_b.op_type == 'Constant':
            continue

        def add_bn_after(prev_node):
            # Insert a pass-through BN between prev_node and the Add node `n`
            # captured from the enclosing loop.
            # Get the channel number from value info
            value_name = prev_node.output[0]
            value = helper.find_value_by_name(g, value_name)
            shape = helper.get_shape_from_value_info(value)
            channel = shape[1]
            # Construct 4 weights: identity scale/var, zero bias/mean
            node_name = value_name + "_nop_bn"
            ones = [1.0] * channel
            zeros = [0.0] * channel
            scale_node = helper.list_to_constant(node_name + "_scale",
                                                 [channel], ones)
            bias_node = helper.list_to_constant(node_name + "_bias",
                                                [channel], zeros)
            mean_node = helper.list_to_constant(node_name + "_mean",
                                                [channel], zeros)
            var_node = helper.list_to_constant(node_name + "_var",
                                               [channel], ones)
            # Construct BN node
            bn_node = onnx.helper.make_node(
                "BatchNormalization",
                [value_name,
                 scale_node.output[0],
                 bias_node.output[0],
                 mean_node.output[0],
                 var_node.output[0]],
                [node_name],
                name=node_name,
                epsilon=0.00000001)
            # Reconnect the graph: the Add now reads the BN output
            replace_node_input(n, value_name, node_name)
            # Add node to the graph
            g.node.extend(
                [bn_node, scale_node, bias_node, mean_node, var_node])

        # Insert BN unless the input is already a BN with this Add as its
        # only consumer.
        if not input_node_a.op_type == 'BatchNormalization' or len(
                helper.find_following_nodes_by_input_value_name(
                    g, input_node_a.output[0])) > 1:
            add_bn_after(input_node_a)
        if not input_node_b.op_type == 'BatchNormalization' or len(
                helper.find_following_nodes_by_input_value_name(
                    g, input_node_b.output[0])) > 1:
            add_bn_after(input_node_b)
    topological_sort(g)
def duplicate_shared_Flatten(g):
    """To feed our compiler, bind Flatten with Gemm. If the output of one\\
    Flatten goes to two Gemm nodes, duplicate the Flatten.

    :param g: the graph
    """
    for flatten in g.node:
        # Only interested in Flatten nodes.
        if flatten.op_type != 'Flatten':
            continue
        # Collect the Gemm nodes consuming this Flatten's output.
        consumers = helper.find_following_nodes_by_input_value_name(
            g, flatten.output[0])
        if len(consumers) < 2:
            continue
        gemm_consumers = [c for c in consumers if c.op_type == 'Gemm']
        if len(gemm_consumers) < 2:
            continue
        # The first Gemm keeps the original Flatten; every other Gemm gets
        # its own private copy.
        for idx, gemm in enumerate(gemm_consumers[1:], start=1):
            copy_name = flatten.name + "_copy" + str(idx)
            copy_node = onnx.helper.make_node(
                "Flatten",
                flatten.input,
                [copy_name],
                name=copy_name,
                axis=1
            )
            replace_node_input(gemm, flatten.output[0], copy_name)
            g.node.extend([copy_node])
    topological_sort(g)
def add_nop_bn_after(g, value_names):
    """Add do-nothing BatchNormalization nodes after the given value info. It will\\
    take the given names as the inputs of the new node and replace the inputs\\
    of the following nodes.

    :param g: the graph\\
    :param value_names: a list of string which are the names of value_info.
    """
    for value_name in value_names:
        # Resolve the value info: value_info first, then graph input, then
        # graph output.
        value = helper.find_value_by_name(g, value_name) \
            or helper.find_input_by_name(g, value_name) \
            or helper.find_output_by_name(g, value_name)
        if value is None:
            print("Cannot find an value_info named {}".format(value_name))
            continue
        # Channel count comes from dimension 1 of the value's shape.
        shape = helper.get_shape_from_value_info(value)
        channel = shape[1]
        # Build the 4 identity weights (scale=1, bias=0, mean=0, var=1).
        base_name = value_name + "_nop_bn"
        scale_node = helper.list_to_constant(base_name + "_scale", [channel],
                                             [1.0] * channel)
        bias_node = helper.list_to_constant(base_name + "_bias", [channel],
                                            [0.0] * channel)
        mean_node = helper.list_to_constant(base_name + "_mean", [channel],
                                            [0.0] * channel)
        var_node = helper.list_to_constant(base_name + "_var", [channel],
                                           [1.0] * channel)
        # Build the pass-through BN node itself.
        bn_node = onnx.helper.make_node(
            "BatchNormalization",
            [value_name,
             scale_node.output[0],
             bias_node.output[0],
             mean_node.output[0],
             var_node.output[0]],
            [base_name],
            name=base_name)
        # Reconnect: consumers of the value now read the BN output.
        consumers = helper.find_following_nodes_by_input_value_name(
            g, value_name)
        if consumers:
            for consumer in consumers:
                replace_node_input(consumer, value_name, base_name)
        else:
            # No consumers: the value is a graph output. Replace that output
            # entry with the BN output, preserving output order.
            new_value = onnx.helper.make_tensor_value_info(
                base_name,
                value.type.tensor_type.elem_type,
                shape)
            previous_outputs = list(g.output)
            del g.output[:]
            for output_value in previous_outputs:
                if output_value.name == value_name:
                    g.output.extend([new_value])
                else:
                    g.output.extend([output_value])
        # Register the new nodes in the graph.
        g.node.extend([bn_node, scale_node, bias_node, mean_node, var_node])
    topological_sort(g)
def change_first_conv_from_bgr_to_rgb(m):
    """For input channel format BGR model, use this function to change the first
    conv weight to adapt the input into RGB.

    :param m: the model proto
    """
    g = m.graph
    # The graph input must feed exactly one node, and that node must be a
    # 3-input-channel Conv.
    consumers = helper.find_following_nodes_by_input_value_name(
        g, g.input[0].name)
    if len(consumers) > 1:
        return False
    conv = consumers[0]
    if conv.op_type != 'Conv':
        return False
    weight_value = helper.find_value_by_name(g, conv.input[1])
    weight_shape = helper.get_shape_from_value_info(weight_value)
    if weight_shape[1] != 3:
        return False
    # Reverse the input-channel axis (B,G,R -> R,G,B) of the conv weight.
    weight_node = helper.find_node_by_output_name(g, weight_value.name)
    weight_np = helper.constant_to_numpy(weight_node)
    shuffled = np.ascontiguousarray(weight_np[:, ::-1, :, :])
    new_node = helper.numpy_to_constant(weight_value.name, shuffled)
    # Swap the weight constant and re-sort the graph.
    g.node.remove(weight_node)
    g.node.extend([new_node])
    other.topological_sort(g)
    return True
def rename_all_node_name(g):
    """Rename every node: new_name = old_name + "_kn".

    Nodes whose first output is a graph output are left untouched so output
    names stay stable. For renamed nodes, the first output value and all
    consumer inputs are renamed consistently.

    :param g: the onnx graph
    """
    suffix = "_kn"
    for node in g.node:
        old_output = node.output[0]
        # Keep graph output names unchanged: skip nodes feeding an output.
        if helper.find_output_by_name(g, old_output) is not None:
            continue
        new_output = old_output + suffix
        # Point every consumer at the renamed output value.
        for consumer in helper.find_following_nodes_by_input_value_name(
                g, old_output):
            replace_node_input(consumer, old_output, new_output)
        # Rename the matching value info, if any.
        value_info = helper.find_value_by_name(g, old_output)
        if value_info is not None:
            value_info.name = new_output
        # Finally rename the node itself.
        node.output[0] = new_output
        node.name = node.name + suffix
def add_nop_conv_after(g, value_names):
    """Add do-nothing depthwise Conv nodes after the given value info. It will\\
    take the given names as the inputs of the new node and replace the inputs\\
    of the following nodes.

    :param g: the graph\\
    :param value_names: a list of string which are the names of value_info.
    """
    for value_name in value_names:
        # Resolve the value info: value_info first, then graph input, then
        # graph output.
        value = helper.find_value_by_name(g, value_name) \
            or helper.find_input_by_name(g, value_name) \
            or helper.find_output_by_name(g, value_name)
        if value is None:
            print("Cannot find an value_info named {}".format(value_name))
            continue
        # Channel count comes from dimension 1 of the value's shape.
        shape = helper.get_shape_from_value_info(value)
        channel = shape[1]
        # All-ones [channel, 1, 1, 1] weight: a depthwise 1x1 identity conv.
        base_name = value_name + "_nop_conv"
        weight_node = helper.list_to_constant(base_name + "_weight",
                                              [channel, 1, 1, 1],
                                              [1.0] * channel)
        # Build the identity depthwise Conv node.
        conv_node = onnx.helper.make_node(
            "Conv",
            [value_name, weight_node.output[0]],
            [base_name],
            name=base_name,
            dilations=[1, 1],
            group=channel,
            kernel_shape=[1, 1],
            pads=[0, 0, 0, 0],
            strides=[1, 1])
        # Reconnect: consumers of the value now read the Conv output.
        consumers = helper.find_following_nodes_by_input_value_name(
            g, value_name)
        if consumers:
            for consumer in consumers:
                replace_node_input(consumer, value_name, base_name)
        else:
            # No consumers: the value is a graph output. Replace that output
            # entry with the Conv output, preserving output order.
            new_value = onnx.helper.make_tensor_value_info(
                base_name,
                value.type.tensor_type.elem_type,
                shape)
            previous_outputs = list(g.output)
            del g.output[:]
            for output_value in previous_outputs:
                if output_value.name == value_name:
                    g.output.extend([new_value])
                else:
                    g.output.extend([output_value])
        # Register the new nodes in the graph.
        g.node.extend([conv_node, weight_node])
    topological_sort(g)
def replace_Sum_with_Adds(g):
    """Replace every Sum node with an equivalent construct made of Add nodes.

    A 1-input Sum is an identity and is removed; a 2-input Sum becomes a
    single Add; an n-input Sum (n > 2) becomes a chain of n-1 Adds whose last
    node keeps the original Sum's name and output. A 0-input Sum is an error.

    Fixes: the original code called ``helper.find_value_by_name`` without the
    graph argument, which does not match its signature used everywhere else;
    also uses the bare ``replace_node_input`` for consistency with the rest
    of this module.

    :param g: the graph
    """
    node_to_del = []
    for node in g.node:
        # Check for Sum
        if node.op_type != 'Sum':
            continue
        input_count = len(node.input)
        if input_count == 1:
            # Identity: reroute consumers to the Sum's input, drop the node
            # and its now-dangling value info.
            following_nodes = helper.find_following_nodes_by_input_value_name(
                g, node.output[0])
            for following_node in following_nodes:
                replace_node_input(following_node, node.output[0],
                                   node.input[0])
            node_to_del.append(node)
            value = helper.find_value_by_name(g, node.output[0])
            if value is not None:
                g.value_info.remove(value)
        elif input_count == 2:
            # Sum of two inputs is exactly an Add; retype in place.
            node.op_type = 'Add'
            continue
        elif input_count > 2:
            # Chain n-1 Adds. First Add takes inputs 0 and 1; each later Add
            # takes the previous result plus the next input; the last Add
            # reuses the Sum's own name and output so consumers are untouched.
            first_node = onnx.helper.make_node(
                "Add",
                [node.input[0], node.input[1]],
                [node.output[0] + '_replacement_1'],
                name=node.name + '_replacement_1'
            )
            last_node = onnx.helper.make_node(
                "Add",
                [node.output[0] + '_replacement_' + str(input_count - 2),
                 node.input[input_count - 1]],
                [node.output[0]],
                name=node.name
            )
            g.node.extend([first_node, last_node])
            for i in range(2, input_count - 1):
                new_node = onnx.helper.make_node(
                    "Add",
                    [node.output[0] + '_replacement_' + str(i - 1),
                     node.input[i]],
                    [node.output[0] + '_replacement_' + str(i)],
                    name=node.name + '_replacement_' + str(i)
                )
                g.node.extend([new_node])
            node_to_del.append(node)
        else:
            logging.error("Sum node must have at least 1 input.")
            quit(1)
    while node_to_del:
        g.node.remove(node_to_del.pop())
    topological_sort(g)
def fuse_BN_into_Gemm(g):
    """Fuse the following BN into the previous Gemm.

    For every BatchNormalization whose input is the sole consumer of a Gemm
    output, fold the BN's scale/bias/mean/var into the Gemm's B and C weight
    constants, reset the Gemm attributes to their fused defaults, and remove
    the BN and the now-unused constant nodes and value infos.

    :param g: the graph
    """
    node_to_remove = []
    for node in g.node:
        # Check for BN and Gemm
        if node.op_type != 'BatchNormalization':
            continue
        gemm_node = helper.find_node_by_output_name(g, node.input[0])
        if gemm_node is None:
            continue
        if gemm_node.op_type != 'Gemm':
            continue
        # Only fuse when the BN is the Gemm's single consumer.
        if len(
                helper.find_following_nodes_by_input_value_name(
                    g, gemm_node.output[0])) > 1:
            continue
        bn_node = node
        # Get original weights (all expected to be Constant nodes)
        gemm_b_node = helper.find_node_by_output_name(g, gemm_node.input[1])
        gemm_b = helper.constant_to_numpy(gemm_b_node)
        gemm_c_node = helper.find_node_by_output_name(g, gemm_node.input[2])
        gemm_c = helper.constant_to_numpy(gemm_c_node)
        bn_scale_node = helper.find_node_by_output_name(g, bn_node.input[1])
        bn_scale = helper.constant_to_numpy(bn_scale_node)
        bn_bias_node = helper.find_node_by_output_name(g, bn_node.input[2])
        bn_bias = helper.constant_to_numpy(bn_bias_node)
        bn_mean_node = helper.find_node_by_output_name(g, bn_node.input[3])
        bn_mean = helper.constant_to_numpy(bn_mean_node)
        bn_var_node = helper.find_node_by_output_name(g, bn_node.input[4])
        bn_var = helper.constant_to_numpy(bn_var_node)
        # Apply attributes
        # epsilon: folded into the variance before the sqrt
        epsilon = helper.get_attribute_by_name(bn_node, 'epsilon')
        if epsilon is None:
            epsilon = 0.00001
        else:
            epsilon = epsilon.f
        bn_var = bn_var + epsilon
        # alpha: pre-scale B
        alpha = helper.get_attribute_by_name(gemm_node, 'alpha')
        if alpha is None:
            alpha = 1
        else:
            alpha = alpha.f
        gemm_b = gemm_b * alpha
        # beta: pre-scale C
        beta = helper.get_attribute_by_name(gemm_node, 'beta')
        if beta is None:
            beta = 1
        else:
            beta = beta.f
        gemm_c = gemm_c * beta
        # transA: unsupported
        transA = helper.get_attribute_by_name(gemm_node, 'transA')
        if transA is not None and transA.i == 1:
            raise RuntimeError("Do not support transA")
        # transB: materialize the transpose so the fused Gemm uses transB=0
        transB = helper.get_attribute_by_name(gemm_node, 'transB')
        if transB is not None and transB.i == 1:
            gemm_b = gemm_b.transpose()
        # Calculate new weights:
        # y = scale * (xB + c - mean) / sqrt(var + eps) + bias
        new_gemm_b = gemm_b * bn_scale / np.sqrt(bn_var)
        new_gemm_c = (gemm_c - bn_mean) * bn_scale / np.sqrt(bn_var) + bn_bias
        # Replace original weights
        new_gemm_b_node = helper.numpy_to_constant(gemm_b_node.name + '_fused',
                                                   new_gemm_b)
        new_gemm_c_node = helper.numpy_to_constant(gemm_c_node.name + '_fused',
                                                   new_gemm_c)
        g.node.extend([new_gemm_b_node, new_gemm_c_node])
        node_to_remove.extend([
            gemm_b_node, gemm_c_node, bn_node, bn_scale_node, bn_bias_node,
            bn_mean_node, bn_var_node
        ])
        # Modify attributes: the fusion already applied them
        # alpha
        alpha = helper.get_attribute_by_name(gemm_node, 'alpha')
        if alpha is not None:
            alpha.f = 1.0
        # beta
        beta = helper.get_attribute_by_name(gemm_node, 'beta')
        if beta is not None:
            beta.f = 1.0
        # transB
        transB = helper.get_attribute_by_name(gemm_node, 'transB')
        if transB is not None:
            transB.i = 0
        # Connect the new graph: swap in fused weights, rename value infos,
        # and let the Gemm take over the BN's output name.
        gemm_node.input[1] = new_gemm_b_node.output[0]
        gemm_node.input[2] = new_gemm_c_node.output[0]
        gemm_b_value = helper.find_value_by_name(g, gemm_b_node.output[0])
        gemm_c_value = helper.find_value_by_name(g, gemm_c_node.output[0])
        gemm_b_value.name = new_gemm_b_node.output[0]
        gemm_c_value.name = new_gemm_c_node.output[0]
        gemm_value = helper.find_value_by_name(g, gemm_node.output[0])
        g.value_info.remove(gemm_value)
        gemm_node.output[0] = bn_node.output[0]
        # Drop value infos of the removed BN weight inputs.
        for i in range(1, 5):
            value = helper.find_value_by_name(g, bn_node.input[i])
            g.value_info.remove(value)
    # Remove useless nodes
    for node in node_to_remove:
        g.node.remove(node)
    topological_sort(g)
def fuse_BN_with_Reshape_into_Gemm(g):
    """Fuse the following BN into the previous Gemm, even with Reshape or \\
    Squeeze and Unsqueeze surrounding.

    Pattern matched: Gemm -> A -> BN -> (B | graph output), where A is an
    Unsqueeze(axes=[2]) or a Reshape to a rank-3 shape ending in 1, and B is
    a Flatten, Squeeze(axes=[2]), rank-2 Reshape, or absent (BN feeds a graph
    output). The BN parameters are folded into the Gemm weights and the
    surrounding shape nodes are removed.

    Fixes: the B-node Squeeze check previously queried the 'axes' attribute
    on ``a_node`` instead of ``b_node``, crashing with AttributeError when A
    was a Reshape (no 'axes' attribute) and otherwise validating the wrong
    node's axes.

    :param g: the graph
    """
    node_to_remove = []
    for node in g.node:
        # Check for BN and Gemm pattern: Gemm A BN B
        # Find BatchNorm Node
        if node.op_type != 'BatchNormalization':
            continue
        bn_node = node
        # Find A Node
        a_node = helper.find_node_by_output_name(g, node.input[0])
        if a_node is None or len(a_node.input) == 0:
            continue
        # Find Gemm Node
        gemm_node = helper.find_node_by_output_name(g, a_node.input[0])
        if gemm_node is None or gemm_node.op_type != 'Gemm':
            continue
        # Find B Node (or None if the BN output is a graph output)
        b_node_list = helper.find_following_nodes_by_input_value_name(
            g, bn_node.output[0])
        if len(b_node_list) == 0:
            the_output = helper.find_output_by_name(g, bn_node.output[0])
            if the_output is None:
                continue
            b_node = None
        elif len(b_node_list) > 1:
            continue
        else:
            b_node = b_node_list[0]
        # Check for branches: Gemm and A must each have a single consumer.
        if len(
                helper.find_following_nodes_by_input_value_name(
                    g, gemm_node.output[0])) > 1:
            continue
        if len(
                helper.find_following_nodes_by_input_value_name(
                    g, a_node.output[0])) > 1:
            continue
        # Check type of A
        if a_node.op_type == 'Unsqueeze':
            axes = helper.get_attribute_by_name(a_node, 'axes')
            if axes.ints != [2]:
                continue
        elif a_node.op_type == 'Reshape':
            a = helper.constant_to_list(
                helper.find_node_by_output_name(g, a_node.input[1]))[1]
            if len(a) != 3 or a[2] != 1:
                continue
        else:
            continue
        # Check type of B
        if b_node is None:
            pass
        elif b_node.op_type == 'Flatten':
            pass
        elif b_node.op_type == 'Squeeze':
            # Bug fix: inspect the B node's axes, not the A node's.
            axes = helper.get_attribute_by_name(b_node, 'axes')
            if axes.ints != [2]:
                continue
        elif b_node.op_type == 'Reshape':
            a = helper.constant_to_list(
                helper.find_node_by_output_name(g, b_node.input[1]))[1]
            if len(a) != 2:
                continue
        else:
            continue
        # Construct new Nodes
        # Get original weights (all expected to be Constant nodes)
        gemm_b_node = helper.find_node_by_output_name(g, gemm_node.input[1])
        gemm_b = helper.constant_to_numpy(gemm_b_node)
        gemm_c_node = helper.find_node_by_output_name(g, gemm_node.input[2])
        gemm_c = helper.constant_to_numpy(gemm_c_node)
        bn_scale_node = helper.find_node_by_output_name(g, bn_node.input[1])
        bn_scale = helper.constant_to_numpy(bn_scale_node)
        bn_bias_node = helper.find_node_by_output_name(g, bn_node.input[2])
        bn_bias = helper.constant_to_numpy(bn_bias_node)
        bn_mean_node = helper.find_node_by_output_name(g, bn_node.input[3])
        bn_mean = helper.constant_to_numpy(bn_mean_node)
        bn_var_node = helper.find_node_by_output_name(g, bn_node.input[4])
        bn_var = helper.constant_to_numpy(bn_var_node)
        # Apply attributes
        # epsilon: folded into the variance before the sqrt
        epsilon = helper.get_attribute_by_name(bn_node, 'epsilon')
        if epsilon is None:
            epsilon = 0.00001
        else:
            epsilon = epsilon.f
        bn_var = bn_var + epsilon
        # alpha: pre-scale B
        alpha = helper.get_attribute_by_name(gemm_node, 'alpha')
        if alpha is None:
            alpha = 1
        else:
            alpha = alpha.f
        gemm_b = gemm_b * alpha
        # beta: pre-scale C
        beta = helper.get_attribute_by_name(gemm_node, 'beta')
        if beta is None:
            beta = 1
        else:
            beta = beta.f
        gemm_c = gemm_c * beta
        # transA: unsupported
        transA = helper.get_attribute_by_name(gemm_node, 'transA')
        if transA is not None and transA.i == 1:
            raise RuntimeError("Do not support transA")
        # transB: materialize the transpose so the fused Gemm uses transB=0
        transB = helper.get_attribute_by_name(gemm_node, 'transB')
        if transB is not None and transB.i == 1:
            gemm_b = gemm_b.transpose()
        # Calculate new weights:
        # y = scale * (xB + c - mean) / sqrt(var + eps) + bias
        new_gemm_b = gemm_b * bn_scale / np.sqrt(bn_var)
        new_gemm_c = (gemm_c - bn_mean) * bn_scale / np.sqrt(bn_var) + bn_bias
        # Replace original weights
        new_gemm_b_node = helper.numpy_to_constant(gemm_b_node.name + '_fused',
                                                   new_gemm_b)
        new_gemm_c_node = helper.numpy_to_constant(gemm_c_node.name + '_fused',
                                                   new_gemm_c)
        g.node.extend([new_gemm_b_node, new_gemm_c_node])
        # Modify attributes: the fusion already applied them
        # alpha
        alpha = helper.get_attribute_by_name(gemm_node, 'alpha')
        if alpha is not None:
            alpha.f = 1.0
        # beta
        beta = helper.get_attribute_by_name(gemm_node, 'beta')
        if beta is not None:
            beta.f = 1.0
        # transB
        transB = helper.get_attribute_by_name(gemm_node, 'transB')
        if transB is not None:
            transB.i = 0
        # Remove useless nodes
        node_to_remove.extend([
            gemm_b_node, gemm_c_node, bn_node, bn_scale_node, bn_bias_node,
            bn_mean_node, bn_var_node, a_node
        ])
        if a_node.op_type == 'Reshape':
            node_to_remove.append(
                helper.find_node_by_output_name(g, a_node.input[1]))
        if b_node is not None:
            node_to_remove.append(b_node)
            if b_node.op_type == 'Reshape':
                node_to_remove.append(
                    helper.find_node_by_output_name(g, b_node.input[1]))
        # Delete useless value infos
        value = helper.find_value_by_name(g, a_node.output[0])
        g.value_info.remove(value)
        if a_node.op_type == 'Reshape':
            value = helper.find_value_by_name(g, a_node.input[1])
            g.value_info.remove(value)
        for i in range(1, 5):
            value = helper.find_value_by_name(g, bn_node.input[i])
            g.value_info.remove(value)
        value = helper.find_value_by_name(g, bn_node.output[0])
        if value is not None:
            g.value_info.remove(value)
        if b_node is not None:
            value = helper.find_value_by_name(g, gemm_node.output[0])
            g.value_info.remove(value)
            if b_node.op_type == 'Reshape':
                value = helper.find_value_by_name(g, b_node.input[1])
                g.value_info.remove(value)
        # Connect the new graph
        # Connect Gemm new weights and fix the B value's recorded shape.
        gemm_node.input[1] = new_gemm_b_node.output[0]
        gemm_node.input[2] = new_gemm_c_node.output[0]
        gemm_b_value = helper.find_value_by_name(g, gemm_b_node.output[0])
        gemm_c_value = helper.find_value_by_name(g, gemm_c_node.output[0])
        gemm_b_value.name = new_gemm_b_node.output[0]
        gemm_b_value.type.tensor_type.shape.dim[
            0].dim_value = new_gemm_b.shape[0]
        gemm_b_value.type.tensor_type.shape.dim[
            1].dim_value = new_gemm_b.shape[1]
        gemm_c_value.name = new_gemm_c_node.output[0]
        if b_node is None:
            # If b node is None, set the Gemm output as the graph output
            output_value = helper.find_output_by_name(g, bn_node.output[0])
            g.output.remove(output_value)
            g.output.extend(
                [helper.find_value_by_name(g, gemm_node.output[0])])
        else:
            # Else, set node B output as gemm output
            gemm_node.output[0] = b_node.output[0]
    # Remove useless nodes
    for node in node_to_remove:
        g.node.remove(node)
    topological_sort(g)