def fuse_Transpose_into_Constant(g):
    """Fuse Transpose layers into the Constant layers before them.

    :param g: the onnx graph
    """
    node_to_remove = []
    for node in g.node:
        if node.op_type != 'Transpose':
            continue
        prev_node = helper.find_node_by_output_name(g, node.input[0])
        if prev_node is None or prev_node.op_type != 'Constant':
            continue
        # Apply the permutation to the constant data offline.
        pre_shape, data_list = helper.constant_to_list(prev_node)
        w = np.reshape(data_list, pre_shape)
        w = w.transpose(node.attribute[0].ints)
        new_shape = w.shape
        w = w.flatten()
        new_tensor = onnx.helper.make_tensor(
            name=prev_node.name + '_data',
            data_type=prev_node.attribute[0].t.data_type,
            dims=new_shape,
            vals=w.tolist()
        )
        new_node = onnx.helper.make_node(
            'Constant',
            [],
            [node.output[0]],
            name=node.output[0],
            value=new_tensor
        )
        value_between = helper.find_value_by_name(g, prev_node.output[0])
        value_type = value_between.type.tensor_type.elem_type
        g.value_info.remove(value_between)
        g.node.extend([new_node])
        node_to_remove.append(node)
        node_to_remove.append(prev_node)
        # Replace any stale value info for the output with one that
        # carries the transposed shape.
        old_value = helper.find_value_by_name(g, new_node.output[0])
        if old_value is not None:
            g.value_info.remove(old_value)
        new_value = onnx.helper.make_tensor_value_info(
            name=new_node.output[0],
            elem_type=value_type,
            shape=new_shape
        )
        g.value_info.extend([new_value])
    for node in node_to_remove:
        g.node.remove(node)
    topological_sort(g)

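
# Illustrative check (not part of the original pass; the demo name is
# hypothetical): folding Transpose into a Constant is exact because the
# permutation can be applied to the payload offline.
def _demo_transpose_folding():
    w = np.arange(24.0).reshape(2, 3, 4)   # original Constant payload
    perm = [2, 0, 1]                       # Transpose 'perm' attribute
    folded = w.transpose(perm)             # data the fused Constant stores
    # Consumers now read the pre-transposed data; shape and values match
    # what a runtime Transpose would have produced.
    assert list(folded.shape) == [4, 2, 3]
    assert folded[0][0][0] == w[0][0][0]
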
def inference_resize_shape(g):
    for node in g.node:
        if node.op_type != 'Resize':
            continue
        output_value = helper.find_value_by_name(g, node.output[0])
        if output_value is None:
            output_value = helper.find_output_by_name(g, node.output[0])
        if output_value is not None:
            continue
        if len(node.input) == 4:
            # Output shape is given. Inputs: X, roi, scales, sizes.
            shape_node = helper.find_node_by_output_name(g, node.input[3])
            if shape_node is None or shape_node.op_type != 'Constant':
                continue
            _, shape_value = helper.constant_to_list(shape_node)
            output_value = onnx.helper.make_tensor_value_info(
                node.output[0], onnx.TensorProto.FLOAT,
                [int(v) for v in shape_value])
            g.value_info.extend([output_value])
            return True
        else:
            # The output shape is not given. Infer it from the scales.
            # Get the input shape.
            input_value = helper.find_value_by_name(g, node.input[0])
            if input_value is None:
                continue
            shape_value = helper.get_shape_from_value_info(input_value)
            scales_node = helper.find_node_by_output_name(g, node.input[2])
            if scales_node is None or scales_node.op_type != 'Constant':
                continue
            _, scales_value = helper.constant_to_list(scales_node)
            for i in range(len(shape_value)):
                shape_value[i] *= scales_value[i]
            output_value = onnx.helper.make_tensor_value_info(
                node.output[0], onnx.TensorProto.FLOAT,
                [int(v) for v in shape_value])
            g.value_info.extend([output_value])
            return True
    return False

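
# Hypothetical usage sketch (all names invented): build a tiny graph with
# a Resize whose 'sizes' input is a Constant, then let the pass above
# attach the inferred value_info. Assumes the module-level onnx/helper
# imports used throughout this file.
def _demo_inference_resize_shape():
    x = onnx.helper.make_tensor_value_info(
        'x', onnx.TensorProto.FLOAT, [1, 3, 32, 32])
    y = onnx.helper.make_tensor_value_info(
        'y', onnx.TensorProto.FLOAT, [1, 3, 64, 64])
    # roi and scales are present but empty; sizes carries the target shape.
    roi = onnx.helper.make_node(
        'Constant', [], ['roi'], name='roi',
        value=onnx.helper.make_tensor(
            'roi_t', onnx.TensorProto.FLOAT, [0], []))
    scales = onnx.helper.make_node(
        'Constant', [], ['scales'], name='scales',
        value=onnx.helper.make_tensor(
            'scales_t', onnx.TensorProto.FLOAT, [0], []))
    sizes = onnx.helper.make_node(
        'Constant', [], ['sizes'], name='sizes',
        value=onnx.helper.make_tensor(
            'sizes_t', onnx.TensorProto.INT64, [4], [1, 3, 64, 64]))
    resize = onnx.helper.make_node(
        'Resize', ['x', 'roi', 'scales', 'sizes'], ['resized'],
        name='resize')
    out = onnx.helper.make_node('Identity', ['resized'], ['y'], name='out')
    graph = onnx.helper.make_graph(
        [roi, scales, sizes, resize, out], 'demo', [x], [y])
    assert inference_resize_shape(graph)
    assert helper.find_value_by_name(graph, 'resized') is not None
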
def replace_ConstantOfShape_with_constant(g):
    """Replace ConstantOfShape with Constant.\\
    This is the first step of reshape constant folding.

    :param g: the input graph\\
    :return: if anything modified, return True.
    """
    node_to_remove = []
    for node in g.node:
        # Find a ConstantOfShape.
        if node.op_type != 'ConstantOfShape':
            continue
        # Check its input.
        input_value = helper.find_value_by_name(g, node.input[0])
        if input_value is None:
            input_value = helper.find_input_by_name(g, node.input[0])
        if input_value is None or len(
                input_value.type.tensor_type.shape.dim) == 0:
            continue
        # The shape input must itself be a Constant.
        pre_node = helper.find_node_by_output_name(g, node.input[0])
        if pre_node is None or pre_node.op_type != 'Constant':
            continue
        _, target_shape = helper.constant_to_list(pre_node)
        # Only fold when the target shape is rank-1.
        if len(target_shape) != 1:
            continue
        # The 'value' attribute is a one-element tensor, defaulting to
        # float 0 (reading it relies on onnx.numpy_helper).
        value_attr = helper.get_attribute_by_name(node, 'value')
        if value_attr is None:
            value = 0.0
        else:
            value = onnx.numpy_helper.to_array(value_attr.t).flatten()[0]
        # Replace the node with a filled rank-1 Constant.
        node_name = node.output[0]
        new_node = helper.list_to_constant(node_name, [target_shape[0]],
                                           [value] * target_shape[0])
        g.node.extend([new_node])
        # Remove the old node.
        node_to_remove.append(node)
        # Delete the value_info if this node was its only consumer.
        val_info_used = sum(
            [input_value.name in n.input for n in g.node])
        if val_info_used == 1:
            g.value_info.remove(input_value)
    replaced = len(node_to_remove) > 0
    for node in node_to_remove:
        g.node.remove(node)
    topological_sort(g)
    return replaced

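
# Equivalence behind the rewrite above: a ConstantOfShape whose shape
# input is a constant 1-D tensor [n] is just a filled Constant, i.e.
# np.full([n], value). Note the pass only folds rank-1 targets
# ([target_shape[0]]); higher-rank shapes fall through untouched.
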
def inference_upsample_shape(g):
    """For onnx v1.4.1+, onnx cannot infer the Upsample output shape. Let's\\
    do it ourselves. Each call infers the shape of the next Upsample that\\
    has no output shape yet.

    :param g: the graph\\
    :return: True if any Upsample shape is generated. Otherwise, False.
    """
    for node in g.node:
        if node.op_type != 'Upsample':
            continue
        output_value = helper.find_value_by_name(g, node.output[0])
        if output_value is None:
            output_value = helper.find_output_by_name(g, node.output[0])
        if output_value and helper.get_shape_from_value_info(output_value):
            continue
        # Get the input shape.
        input_value = helper.find_value_by_name(g, node.input[0])
        if input_value is None:
            # Shape for the input has not been generated yet.
            continue
        if not helper.get_shape_from_value_info(input_value):
            # Shape for the input is empty.
            continue
        input_shape = helper.get_shape_from_value_info(input_value)
        # Get the upsample scales.
        weight_node = helper.find_node_by_output_name(g, node.input[1])
        weight_shape, weight = helper.constant_to_list(weight_node)
        if len(input_shape) != weight_shape[0]:
            raise RuntimeError(
                "Mismatched input shape and scales shape: {} vs {}".format(
                    input_shape, weight_shape))
        # Calculate the output shape.
        output_shape = list(input_shape)
        for i in range(len(output_shape)):
            output_shape[i] = int(input_shape[i] * weight[i])
        output_value = onnx.helper.make_tensor_value_info(
            node.output[0], input_value.type.tensor_type.elem_type,
            output_shape)
        g.value_info.extend([output_value])
        return True
    return False

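
# Worked example for the shape rule above: input [1, 3, 32, 32] with
# Upsample scales [1.0, 1.0, 2.0, 2.0] gives
#   output[i] = int(input[i] * scales[i])  ->  [1, 3, 64, 64]
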
def replace_mul_to_bn(g):
    """Replace a single Mul node with a BatchNormalization node.

    :param g: input graph.
    """
    node_to_del = []
    for node in g.node:
        if node.op_type != 'Mul':
            continue
        mul_op_node = node
        # Only support two inputs: one OP node and one value node.
        if len(mul_op_node.input) != 2:
            continue
        input_op_node_name = mul_op_node.input[0]
        mul_value_node = helper.find_node_by_output_name(
            g, mul_op_node.input[1])
        if not mul_value_node or mul_value_node.op_type != 'Constant':
            continue
        _, previous_node_output_shape = helper.find_size_shape_from_value(
            helper.find_value_by_name(g, input_op_node_name))
        scale_shape, scale_data = helper.constant_to_list(mul_value_node)
        # Only allow 4-dim data input due to the hardware limitation.
        if len(previous_node_output_shape) != 4:
            continue
        # Channel dimension.
        c_dim = previous_node_output_shape[1]
        # Only allow channelwise Mul or scalar Mul.
        if scale_shape != [1, c_dim, 1, 1] and scale_shape != 1:
            continue
        ones = [1.0] * c_dim
        zeros = [0.0] * c_dim
        # Broadcast a scalar scale to per-channel; keep per-channel data
        # as-is.
        if scale_shape == 1:
            muls = scale_data * c_dim
        else:
            muls = list(scale_data)
        bn_name = mul_op_node.output[0]
        mean_value_node = helper.list_to_constant(bn_name + '_mean',
                                                  np.array(zeros).shape,
                                                  zeros)
        variance_value_node = helper.list_to_constant(bn_name + '_var',
                                                      np.array(ones).shape,
                                                      ones)
        bias_value_node = helper.list_to_constant(bn_name + '_add',
                                                  np.array(zeros).shape,
                                                  zeros)
        new_mul_value_node = helper.list_to_constant(bn_name + '_mul',
                                                     np.array(muls).shape,
                                                     muls)
        bn_node = onnx.helper.make_node(
            'BatchNormalization',
            [input_op_node_name,
             new_mul_value_node.output[0],
             bias_value_node.output[0],
             mean_value_node.output[0],
             variance_value_node.output[0]],
            [mul_op_node.output[0]],
            name=bn_name,
            epsilon=0.00000001)
        mid_val_info = helper.find_value_by_name(g, mul_op_node.output[0])
        scale_val_info = helper.find_value_by_name(g, mul_value_node.output[0])
        if mid_val_info is not None:
            g.value_info.remove(mid_val_info)
        if scale_val_info is not None:
            g.value_info.remove(scale_val_info)
        g.node.extend([bn_node, mean_value_node, variance_value_node,
                       bias_value_node, new_mul_value_node])
        node_to_del.extend([mul_op_node, mul_value_node])
    while node_to_del:
        g.node.remove(node_to_del.pop())
    topological_sort(g)

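
# Why the rewrite above is exact (up to epsilon): BatchNormalization
# computes
#   y = scale * (x - mean) / sqrt(var + eps) + bias
# so with mean = 0, var = 1, bias = 0 and eps = 1e-8 it collapses to
#   y = scale * x / sqrt(1 + 1e-8)  ~=  scale * x
# which is exactly the per-channel Mul being replaced.
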
def fuse_mul_and_add_into_gemm(g):
    node_to_del = []
    for node in g.node:
        if node.op_type != 'Add':
            continue
        add_node = node
        mul_node = helper.find_node_by_output_name(g, add_node.input[0])
        if not mul_node or mul_node.op_type != 'Mul':
            continue
        mul_const = helper.find_node_by_output_name(g, mul_node.input[1])
        if not mul_const or mul_const.op_type != 'Constant':
            continue
        add_const = helper.find_node_by_output_name(g, add_node.input[1])
        if not add_const or add_const.op_type != 'Constant':
            continue
        input_val = helper.find_value_by_name(g, mul_node.input[0])
        if not input_val:
            input_val = helper.find_input_by_name(g, mul_node.input[0])
        if not input_val:
            continue
        _, input_shape = helper.find_size_shape_from_value(input_val)
        if not input_shape:
            continue
        dim = int(np.prod(input_shape))
        if input_shape != [1, dim]:
            continue
        mul_const_shape, mul_const_data = helper.constant_to_list(mul_const)
        add_const_shape, __ = helper.constant_to_list(add_const)
        if len(mul_const_shape) != 1 or mul_const_shape[0] != dim:
            continue
        if len(add_const_shape) != 1 or add_const_shape[0] != dim:
            continue
        # Turn the elementwise scale into a diagonal Gemm weight matrix.
        b_data = np.zeros([dim, dim])
        for i in range(dim):
            b_data[i][i] = mul_const_data[i]
        b_data = b_data.flatten().tolist()
        b_tensor = onnx.helper.make_tensor(
            name=mul_const.name + '_tensor',
            data_type=mul_const.attribute[0].t.data_type,
            dims=[dim, dim],
            vals=b_data)
        b_const_node = onnx.helper.make_node('Constant', [],
                                             [mul_const.output[0]],
                                             value=b_tensor,
                                             name=mul_const.output[0])
        # Reshape the bias to [1, dim] for Gemm.
        add_const.attribute[0].t.dims.insert(0, 1)
        gemm_node = onnx.helper.make_node(
            'Gemm',
            [mul_node.input[0], b_const_node.output[0], add_const.output[0]],
            [add_node.output[0]],
            name=add_node.output[0])
        g.node.extend([gemm_node, b_const_node])
        node_to_del.extend([mul_const, mul_node, add_node])
        val_info_mid = helper.find_value_by_name(g, mul_node.output[0])
        val_info_mul_const = helper.find_value_by_name(g, mul_const.output[0])
        val_info_add_const = helper.find_value_by_name(g, add_const.output[0])
        if val_info_mid:
            g.value_info.remove(val_info_mid)
        if val_info_mul_const:
            g.value_info.remove(val_info_mul_const)
        if val_info_add_const:
            g.value_info.remove(val_info_add_const)
    while node_to_del:
        g.node.remove(node_to_del.pop())
    topological_sort(g)

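
# Quick numpy check of the fusion above (illustrative, hypothetical name):
# an elementwise x * s + b on a [1, dim] input equals x @ diag(s) + b,
# which is what the diagonal Gemm weight encodes.
def _demo_mul_add_as_gemm():
    x = np.random.rand(1, 4)
    s = np.random.rand(4)
    b = np.random.rand(4)
    assert np.allclose(x * s + b, x @ np.diag(s) + b.reshape(1, 4))
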
def fuse_mul_and_add_into_bn(g):
    node_to_del = []
    for node in g.node:
        if node.op_type != 'Add':
            continue
        add_node = node
        input_nodes_add = [
            helper.find_node_by_output_name(g, input_name)
            for input_name in add_node.input
        ]
        if any(n is None for n in input_nodes_add):
            continue
        mul_node, const_add = None, None
        for input_node_add in input_nodes_add:
            if input_node_add.op_type == 'Mul':
                mul_node = input_node_add
            elif input_node_add.op_type == 'Constant':
                const_add = input_node_add
        if not mul_node or not const_add:
            continue
        data_input_name, const_mul = None, None
        for input_name in mul_node.input:
            input_node = helper.find_node_by_output_name(g, input_name)
            if not input_node:
                data_input_name = input_name
            elif input_node.op_type == 'Constant':
                if not const_mul:
                    const_mul = input_node
                else:
                    data_input_name = input_name
            else:
                data_input_name = input_name
        if not const_mul:
            continue
        scale_shape, scale_data = helper.constant_to_list(const_mul)
        bias_shape, __ = helper.constant_to_list(const_add)
        c_dim = len(scale_data)
        if scale_shape != bias_shape:
            continue
        _, previous_node_output_shape = helper.find_size_shape_from_value(
            helper.find_value_by_name(g, data_input_name))
        # Only allow 4-dim data input due to the hardware limitation.
        if len(previous_node_output_shape) != 4:
            continue
        # Check if the Mul's dim matches the input channel dimension.
        if previous_node_output_shape[1] != c_dim:
            continue
        # BatchNormalization expects rank-1 [C] scale/bias tensors, so
        # strip the extra 1s from the constants in place.
        if scale_shape == [1, c_dim, 1, 1]:
            for _ in range(3):
                const_add.attribute[0].t.dims.remove(1)
                const_mul.attribute[0].t.dims.remove(1)
        elif scale_shape == [1, c_dim]:
            const_add.attribute[0].t.dims.remove(1)
            const_mul.attribute[0].t.dims.remove(1)
        else:
            continue
        bn_name = add_node.output[0]
        const_mean = helper.list_to_constant(bn_name + '_mean', [c_dim],
                                             [0.0 for _ in range(c_dim)])
        const_var = helper.list_to_constant(bn_name + '_var', [c_dim],
                                            [1.0 for _ in range(c_dim)])
        bn_node = onnx.helper.make_node(
            'BatchNormalization',
            [data_input_name, const_mul.output[0], const_add.output[0],
             const_mean.output[0], const_var.output[0]],
            [add_node.output[0]],
            name=bn_name,
            epsilon=0.00000001
        )
        mid_val_info = helper.find_value_by_name(g, mul_node.output[0])
        scale_val_info = helper.find_value_by_name(g, const_mul.output[0])
        bias_val_info = helper.find_value_by_name(g, const_add.output[0])
        g.value_info.remove(mid_val_info)
        g.value_info.remove(scale_val_info)
        g.value_info.remove(bias_val_info)
        new_scale_val_info = onnx.helper.make_tensor_value_info(
            const_mul.output[0], const_mul.attribute[0].t.data_type, [c_dim])
        new_bias_val_info = onnx.helper.make_tensor_value_info(
            const_add.output[0], const_add.attribute[0].t.data_type, [c_dim])
        mean_val_info = onnx.helper.make_tensor_value_info(
            const_mean.output[0], const_mean.attribute[0].t.data_type,
            [c_dim])
        var_val_info = onnx.helper.make_tensor_value_info(
            const_var.output[0], const_var.attribute[0].t.data_type, [c_dim])
        g.value_info.extend([new_scale_val_info, new_bias_val_info,
                             mean_val_info, var_val_info])
        g.node.extend([bn_node, const_mean, const_var])
        node_to_del.extend([mul_node, add_node])
    while node_to_del:
        g.node.remove(node_to_del.pop())
    topological_sort(g)

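
# Numeric counterpart for the fusion above (illustrative, hypothetical
# name): per-channel y = x * s + b equals BN(x, scale=s, bias=b, mean=0,
# var=1) up to the 1e-8 epsilon.
def _demo_mul_add_as_bn():
    x = np.random.rand(1, 3, 2, 2)
    s = np.random.rand(3).reshape(1, 3, 1, 1)
    b = np.random.rand(3).reshape(1, 3, 1, 1)
    bn = s * (x - 0.0) / np.sqrt(1.0 + 1e-8) + b
    assert np.allclose(x * s + b, bn, atol=1e-6)
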
def fuse_BN_with_Reshape_into_Gemm(g):
    """Fuse the following BN into the previous Gemm, even with Reshape or\\
    Squeeze and Unsqueeze surrounding.

    :param g: the graph
    """
    node_to_remove = []
    for node in g.node:
        # Check for the pattern: Gemm A BN B
        # Find the BatchNorm node.
        if node.op_type != 'BatchNormalization':
            continue
        bn_node = node
        # Find the A node.
        a_node = helper.find_node_by_output_name(g, node.input[0])
        if a_node is None or len(a_node.input) == 0:
            continue
        # Find the Gemm node.
        gemm_node = helper.find_node_by_output_name(g, a_node.input[0])
        if gemm_node is None or gemm_node.op_type != 'Gemm':
            continue
        # Find the B node.
        b_node_list = helper.find_following_nodes_by_input_value_name(
            g, bn_node.output[0])
        if len(b_node_list) == 0:
            the_output = helper.find_output_by_name(g, bn_node.output[0])
            if the_output is None:
                continue
            b_node = None
        elif len(b_node_list) > 1:
            continue
        else:
            b_node = b_node_list[0]
        # Check for branches.
        if len(helper.find_following_nodes_by_input_value_name(
                g, gemm_node.output[0])) > 1:
            continue
        if len(helper.find_following_nodes_by_input_value_name(
                g, a_node.output[0])) > 1:
            continue
        # Check the type of A.
        if a_node.op_type == 'Unsqueeze':
            axes = helper.get_attribute_by_name(a_node, 'axes')
            if axes.ints != [2]:
                continue
        elif a_node.op_type == 'Reshape':
            a = helper.constant_to_list(
                helper.find_node_by_output_name(g, a_node.input[1]))[1]
            if len(a) != 3 or a[2] != 1:
                continue
        else:
            continue
        # Check the type of B.
        if b_node is None:
            pass
        elif b_node.op_type == 'Flatten':
            pass
        elif b_node.op_type == 'Squeeze':
            axes = helper.get_attribute_by_name(b_node, 'axes')
            if axes.ints != [2]:
                continue
        elif b_node.op_type == 'Reshape':
            a = helper.constant_to_list(
                helper.find_node_by_output_name(g, b_node.input[1]))[1]
            if len(a) != 2:
                continue
        else:
            continue
        # Construct the new nodes.
        # Get the original weights.
        gemm_b_node = helper.find_node_by_output_name(g, gemm_node.input[1])
        gemm_b = helper.constant_to_numpy(gemm_b_node)
        gemm_c_node = helper.find_node_by_output_name(g, gemm_node.input[2])
        gemm_c = helper.constant_to_numpy(gemm_c_node)
        bn_scale_node = helper.find_node_by_output_name(g, bn_node.input[1])
        bn_scale = helper.constant_to_numpy(bn_scale_node)
        bn_bias_node = helper.find_node_by_output_name(g, bn_node.input[2])
        bn_bias = helper.constant_to_numpy(bn_bias_node)
        bn_mean_node = helper.find_node_by_output_name(g, bn_node.input[3])
        bn_mean = helper.constant_to_numpy(bn_mean_node)
        bn_var_node = helper.find_node_by_output_name(g, bn_node.input[4])
        bn_var = helper.constant_to_numpy(bn_var_node)
        # Apply the attributes.
        # epsilon
        epsilon = helper.get_attribute_by_name(bn_node, 'epsilon')
        if epsilon is None:
            epsilon = 0.00001
        else:
            epsilon = epsilon.f
        bn_var = bn_var + epsilon
        # alpha
        alpha = helper.get_attribute_by_name(gemm_node, 'alpha')
        if alpha is None:
            alpha = 1
        else:
            alpha = alpha.f
        gemm_b = gemm_b * alpha
        # beta
        beta = helper.get_attribute_by_name(gemm_node, 'beta')
        if beta is None:
            beta = 1
        else:
            beta = beta.f
        gemm_c = gemm_c * beta
        # transA
        transA = helper.get_attribute_by_name(gemm_node, 'transA')
        if transA is not None and transA.i == 1:
            raise RuntimeError("Do not support transA")
        # transB
        transB = helper.get_attribute_by_name(gemm_node, 'transB')
        if transB is not None and transB.i == 1:
            gemm_b = gemm_b.transpose()
        # Calculate the new weights.
        new_gemm_b = gemm_b * bn_scale / np.sqrt(bn_var)
        new_gemm_c = (gemm_c - bn_mean) * bn_scale / np.sqrt(bn_var) \
            + bn_bias
        # Replace the original weights.
        new_gemm_b_node = helper.numpy_to_constant(
            gemm_b_node.name + '_fused', new_gemm_b)
        new_gemm_c_node = helper.numpy_to_constant(
            gemm_c_node.name + '_fused', new_gemm_c)
        g.node.extend([new_gemm_b_node, new_gemm_c_node])
        # Modify the attributes.
        # alpha
        alpha = helper.get_attribute_by_name(gemm_node, 'alpha')
        if alpha is not None:
            alpha.f = 1.0
        # beta
        beta = helper.get_attribute_by_name(gemm_node, 'beta')
        if beta is not None:
            beta.f = 1.0
        # transB
        transB = helper.get_attribute_by_name(gemm_node, 'transB')
        if transB is not None:
            transB.i = 0
        # Remove the useless nodes.
        node_to_remove.extend([
            gemm_b_node, gemm_c_node, bn_node, bn_scale_node, bn_bias_node,
            bn_mean_node, bn_var_node, a_node
        ])
        if a_node.op_type == 'Reshape':
            node_to_remove.append(
                helper.find_node_by_output_name(g, a_node.input[1]))
        if b_node is not None:
            node_to_remove.append(b_node)
            if b_node.op_type == 'Reshape':
                node_to_remove.append(
                    helper.find_node_by_output_name(g, b_node.input[1]))
        # Delete the useless value infos.
        value = helper.find_value_by_name(g, a_node.output[0])
        g.value_info.remove(value)
        if a_node.op_type == 'Reshape':
            value = helper.find_value_by_name(g, a_node.input[1])
            g.value_info.remove(value)
        for i in range(1, 5):
            value = helper.find_value_by_name(g, bn_node.input[i])
            g.value_info.remove(value)
        value = helper.find_value_by_name(g, bn_node.output[0])
        if value is not None:
            g.value_info.remove(value)
        if b_node is not None:
            value = helper.find_value_by_name(g, gemm_node.output[0])
            g.value_info.remove(value)
            if b_node.op_type == 'Reshape':
                value = helper.find_value_by_name(g, b_node.input[1])
                g.value_info.remove(value)
        # Connect the new graph.
        # Connect the new Gemm weights.
        gemm_node.input[1] = new_gemm_b_node.output[0]
        gemm_node.input[2] = new_gemm_c_node.output[0]
        gemm_b_value = helper.find_value_by_name(g, gemm_b_node.output[0])
        gemm_c_value = helper.find_value_by_name(g, gemm_c_node.output[0])
        gemm_b_value.name = new_gemm_b_node.output[0]
        gemm_b_value.type.tensor_type.shape.dim[0].dim_value = \
            new_gemm_b.shape[0]
        gemm_b_value.type.tensor_type.shape.dim[1].dim_value = \
            new_gemm_b.shape[1]
        gemm_c_value.name = new_gemm_c_node.output[0]
        if b_node is None:
            # If the B node is None, set the Gemm output as the graph
            # output.
            output_value = helper.find_output_by_name(g, bn_node.output[0])
            g.output.remove(output_value)
            g.output.extend(
                [helper.find_value_by_name(g, gemm_node.output[0])])
        else:
            # Else, set the B node output as the Gemm output.
            gemm_node.output[0] = b_node.output[0]
    # Remove the useless nodes.
    for node in node_to_remove:
        g.node.remove(node)
    topological_sort(g)

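
# Checking the algebra above (illustrative, hypothetical name): for
#   Gemm(x) = x @ B + C  and  BN(t) = s * (t - m) / sqrt(v + eps) + b
# the fused weights are B' = B * s / sqrt(v + eps) and
# C' = (C - m) * s / sqrt(v + eps) + b, so BN(Gemm(x)) == x @ B' + C'.
def _demo_bn_gemm_folding():
    x = np.random.rand(1, 4)
    B = np.random.rand(4, 3)
    C = np.random.rand(3)
    s, b, m = np.random.rand(3), np.random.rand(3), np.random.rand(3)
    v = np.random.rand(3) + 0.5
    eps = 1e-5
    bn = s * ((x @ B + C) - m) / np.sqrt(v + eps) + b
    new_B = B * s / np.sqrt(v + eps)
    new_C = (C - m) * s / np.sqrt(v + eps) + b
    assert np.allclose(bn, x @ new_B + new_C)
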
def pattern_matmul_mul_add(g, matmul_node):
    # Check node match - Mul node.
    next_nodes = helper.find_nodes_by_input_name(g, matmul_node.output[0])
    if len(next_nodes) != 1:
        return
    if next_nodes[0].op_type != 'Mul':
        return
    mul_node = next_nodes[0]
    # Check node match - Add node.
    next_nodes = helper.find_nodes_by_input_name(g, mul_node.output[0])
    if len(next_nodes) != 1:
        return
    if next_nodes[0].op_type != 'Add':
        return
    add_node = next_nodes[0]
    # Check the Mul weight: it must be all ones (an identity scale).
    mul_weight_node = helper.find_node_by_output_name(g, mul_node.input[1])
    if mul_weight_node is None or mul_weight_node.op_type != 'Constant':
        return
    weight_size, mul_weight = helper.constant_to_list(mul_weight_node)
    for i in mul_weight:
        if i != 1:
            return
    channel = weight_size[0]
    # Check the Add weight.
    add_weight_node = helper.find_node_by_output_name(g, add_node.input[1])
    if add_weight_node is None or add_weight_node.op_type != 'Constant':
        return
    # Check the MatMul weight to see if it needs weight broadcast.
    matmul_weight_node = helper.find_node_by_output_name(
        g, matmul_node.input[1])
    if matmul_weight_node is None:
        return
    matmul_weight = helper.constant_to_numpy(matmul_weight_node)
    if matmul_weight.shape[1] == 1:
        # Weight broadcast: materialize the channel broadcast the Mul
        # performed on the MatMul output.
        new_matmul_weight = np.tile(matmul_weight, channel)
        new_matmul_weight_node = helper.numpy_to_constant(
            matmul_weight_node.name, new_matmul_weight)
        g.node.remove(matmul_weight_node)
        g.node.extend([new_matmul_weight_node])
        value = helper.find_value_by_name(g, matmul_weight_node.output[0])
        if value is not None:
            g.value_info.remove(value)
    # Remove the Mul node.
    g.node.remove(mul_weight_node)
    value = helper.find_value_by_name(g, mul_weight_node.output[0])
    if value is not None:
        g.value_info.remove(value)
    g.node.remove(mul_node)
    value = helper.find_value_by_name(g, mul_node.output[0])
    if value is not None:
        g.value_info.remove(value)
    # Fuse the MatMul and Add into a Gemm.
    gemm_node = onnx.helper.make_node(
        'Gemm',
        [matmul_node.input[0], matmul_node.input[1], add_node.input[1]],
        [add_node.output[0]],
        name=matmul_node.name,
        alpha=1.0,
        beta=1.0,
        transA=0,
        transB=0
    )
    g.node.extend([gemm_node])
    # Clean up.
    g.node.remove(matmul_node)
    g.node.remove(add_node)
    value = helper.find_value_by_name(g, matmul_node.output[0])
    if value is not None:
        g.value_info.remove(value)
    topological_sort(g)

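
# In the pattern above, the Mul weight is verified to be all ones, so the
# Mul is an identity and MatMul -> Mul -> Add reduces to a single
# Gemm(alpha=1, beta=1): y = x @ W + b. The np.tile on a [k, 1] MatMul
# weight just materializes, inside the weight, the channel broadcast the
# Mul performed on the MatMul output.
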
def replace_add_to_bn(g):
    """Replace a single Add node with a BatchNormalization node.

    :param g: input graph.
    """
    node_to_del = []
    for node in g.node:
        if node.op_type != 'Add':
            continue
        add_op_node = node
        # Only support two inputs: one OP node and one value node.
        if len(add_op_node.input) != 2:
            continue
        input_op_node_name = add_op_node.input[0]
        add_value_node = helper.find_node_by_output_name(
            g, add_op_node.input[1])
        if not add_value_node or add_value_node.op_type != 'Constant':
            continue
        prev_shape_value_info = helper.find_value_by_name(
            g, input_op_node_name)
        if prev_shape_value_info is None:
            prev_shape_value_info = helper.find_input_by_name(
                g, input_op_node_name)
        if prev_shape_value_info is None:
            continue
        _, previous_node_output_shape = helper.find_size_shape_from_value(
            prev_shape_value_info)
        bias_shape, bias_data = helper.constant_to_list(add_value_node)
        # Channel dimension.
        if len(previous_node_output_shape) > 1:
            c_dim = previous_node_output_shape[1]
        else:
            c_dim = 1
        # Only allow channelwise Add or scalar Add.
        if bias_shape != [1, c_dim, 1, 1] and bias_shape != 1:
            continue
        ones = [1.0] * c_dim
        zeros = [0.0] * c_dim
        # If the bias is a scalar, broadcast it to per-channel.
        if len(bias_data) == 1:
            bias = bias_data * c_dim
        else:
            bias = bias_data
        bn_name = add_op_node.output[0]
        mean_value_node = helper.list_to_constant(bn_name + '_mean',
                                                  np.array(zeros).shape,
                                                  zeros)
        variance_value_node = helper.list_to_constant(bn_name + '_var',
                                                      np.array(ones).shape,
                                                      ones)
        scale_value_node = helper.list_to_constant(bn_name + '_mul',
                                                   np.array(ones).shape,
                                                   ones)
        new_add_value_node = helper.list_to_constant(bn_name + '_add',
                                                     np.array(bias).shape,
                                                     bias)
        bn_node = onnx.helper.make_node(
            'BatchNormalization',
            [input_op_node_name,
             scale_value_node.output[0],
             new_add_value_node.output[0],
             mean_value_node.output[0],
             variance_value_node.output[0]],
            [add_op_node.output[0]],
            name=bn_name,
            epsilon=0.00000001)
        add_val_info = helper.find_value_by_name(g, add_value_node.output[0])
        if add_val_info is not None:
            g.value_info.remove(add_val_info)
        g.node.extend([bn_node, mean_value_node, variance_value_node,
                       scale_value_node, new_add_value_node])
        node_to_del.extend([add_op_node, add_value_node])
    while node_to_del:
        g.node.remove(node_to_del.pop())
    topological_sort(g)

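
# Counterpart of replace_mul_to_bn above: with scale = 1, mean = 0 and
# var = 1,
#   BN(x) = 1 * (x - 0) / sqrt(1 + 1e-8) + bias  ~=  x + bias
# so a channelwise (or scalar, broadcast to [c_dim]) Add becomes a single
# BatchNormalization node.
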