def fuse_Add_into_Conv(g):
    """
    Fuse Add layers into the Conv layers before

    :param g: the onnx graph
    """
    node_to_remove = []
    for node in g.node:
        if node.op_type != 'Add':
            continue
        conv_node = helper.find_node_by_output_name(g, node.input[0])
        cons_node = helper.find_node_by_output_name(g, node.input[1])
        if conv_node is None or cons_node is None:
            continue
        if conv_node.op_type != 'Conv' or cons_node.op_type != 'Constant':
            continue
        if len(conv_node.input) > 2:
            continue
        # This layer should be fused. Connect the constant node into the convolution node.
        add_node = node
        conv_node.input.extend([cons_node.output[0]])
        old_value = helper.find_value_by_name(g, conv_node.output[0])
        conv_node.output[0] = add_node.output[0]
        # Remove the original conv_node output value info
        g.value_info.remove(old_value)
        # Remove current node
        node_to_remove.append(add_node)
    # Apply changes to the model
    for node in node_to_remove:
        g.node.remove(node)

def polish_RESIZE_input_param_node(g, resize_node_name):
    resize_node = helper.find_node_by_output_name(g, resize_node_name)
    shape_data_node = helper.find_node_by_output_name(g, resize_node.input[3])
    shape_data = helper.constant_to_numpy(shape_data_node).astype(int)

    # handle 0 batch size which is invalid
    if shape_data[0] == 0:
        shape_data[0] = 1

    pre_node_output_value_info = helper.find_value_by_name(
        g, resize_node.input[0])
    ori_shape = np.array([
        pre_node_output_value_info.type.tensor_type.shape.dim[0].dim_value,
        pre_node_output_value_info.type.tensor_type.shape.dim[1].dim_value,
        pre_node_output_value_info.type.tensor_type.shape.dim[2].dim_value,
        pre_node_output_value_info.type.tensor_type.shape.dim[3].dim_value
    ])

    resize_node.input.remove(resize_node.input[3])
    resize_scales = np.array(shape_data / ori_shape).astype(float)
    resize_scale_node = helper.list_to_constant(
        'resize_scales_node_' + resize_node.name,
        resize_scales.shape,
        resize_scales,
        data_type=onnx.helper.TensorProto.FLOAT)
    resize_node.input[2] = resize_scale_node.name
    g.node.extend([resize_scale_node])

    other.topological_sort(g)

def add_bn_before_add(g):
    for n in g.node:
        # Find merge node (Add)
        if n.op_type != 'Add':
            continue
        if len(n.input) != 2:
            continue
        # Get two inputs
        input_node_a = helper.find_node_by_output_name(g, n.input[0])
        input_node_b = helper.find_node_by_output_name(g, n.input[1])
        # Skip constant input add
        if input_node_a is None or input_node_a.op_type == 'Constant':
            continue
        if input_node_b is None or input_node_b.op_type == 'Constant':
            continue

        def add_bn_after(prev_node):
            # Get the channel number from value info
            value_name = prev_node.output[0]
            value = helper.find_value_by_name(g, value_name)
            shape = helper.get_shape_from_value_info(value)
            channel = shape[1]
            # Construct 4 weights
            node_name = value_name + "_nop_bn"
            ones = [1.0] * channel
            zeros = [0.0] * channel
            scale_node = helper.list_to_constant(node_name + "_scale",
                                                 [channel], ones)
            bias_node = helper.list_to_constant(node_name + "_bias",
                                                [channel], zeros)
            mean_node = helper.list_to_constant(node_name + "_mean",
                                                [channel], zeros)
            var_node = helper.list_to_constant(node_name + "_var",
                                               [channel], ones)
            # Construct BN node
            bn_node = onnx.helper.make_node(
                "BatchNormalization",
                [value_name, scale_node.output[0], bias_node.output[0],
                 mean_node.output[0], var_node.output[0]],
                [node_name],
                name=node_name,
                epsilon=0.00000001)
            # Reconnect the graph
            replace_node_input(n, value_name, node_name)
            # Add node to the graph
            g.node.extend(
                [bn_node, scale_node, bias_node, mean_node, var_node])

        if not input_node_a.op_type == 'BatchNormalization' or len(
                helper.find_following_nodes_by_input_value_name(
                    g, input_node_a.output[0])) > 1:
            add_bn_after(input_node_a)
        if not input_node_b.op_type == 'BatchNormalization' or len(
                helper.find_following_nodes_by_input_value_name(
                    g, input_node_b.output[0])) > 1:
            add_bn_after(input_node_b)
    topological_sort(g)

def add_bn_on_skip_branch(g):
    for n in g.node:
        # Find merge node (Add)
        if n.op_type != 'Add':
            continue
        if len(n.input) != 2:
            continue
        # TODO: Still need to consider more cases
        # Check if skip branch exist
        input_node_a = helper.find_node_by_output_name(g, n.input[0])
        output_of_input_node_a = helper.find_nodes_by_input_name(
            g, input_node_a.output[0])
        input_node_b = helper.find_node_by_output_name(g, n.input[1])
        output_of_input_node_b = helper.find_nodes_by_input_name(
            g, input_node_b.output[0])
        if len(output_of_input_node_a) == 1 and len(
                output_of_input_node_b) == 1:
            continue
        if len(output_of_input_node_a) == 2:
            split_node = input_node_a
        elif len(output_of_input_node_b) == 2:
            split_node = input_node_b
        else:
            continue
        # Get the channel number from value info
        value_name = split_node.output[0]
        value = helper.find_value_by_name(g, value_name)
        shape = helper.get_shape_from_value_info(value)
        channel = shape[1]
        # Construct 4 weights
        node_name = value_name + "_nop_bn"
        ones = [1.0] * channel
        zeros = [0.0] * channel
        scale_node = helper.list_to_constant(node_name + "_scale", [channel],
                                             ones)
        bias_node = helper.list_to_constant(node_name + "_bias", [channel],
                                            zeros)
        mean_node = helper.list_to_constant(node_name + "_mean", [channel],
                                            zeros)
        var_node = helper.list_to_constant(node_name + "_var", [channel],
                                           ones)
        # Construct BN node
        bn_node = onnx.helper.make_node(
            "BatchNormalization",
            [value_name, scale_node.output[0], bias_node.output[0],
             mean_node.output[0], var_node.output[0]],
            [node_name],
            name=node_name)
        # Reconnect the graph
        replace_node_input(n, value_name, node_name)
        # Add node to the graph
        g.node.extend([bn_node, scale_node, bias_node, mean_node, var_node])
    topological_sort(g)

def make_UpsamplingBilinear2d_value_info(g, resize_node_name):
    resize_node = helper.find_node_by_output_name(g, resize_node_name)
    shape_data_node = helper.find_node_by_output_name(g, resize_node.input[3])
    shape_data = helper.constant_to_numpy(shape_data_node).astype(int)

    l_shape_data = list(shape_data)
    if l_shape_data[0] == 0:
        l_shape_data[0] = 1 + l_shape_data[0]
    shape_data = np.array(l_shape_data)

    new_output_value_info = onnx.helper.make_tensor_value_info(
        resize_node.output[0],
        onnx.helper.TensorProto.FLOAT,
        shape_data.tolist())
    g.value_info.extend([new_output_value_info])

def replace_Reshape_with_Flatten(g):
    """
    Replace a Reshape node with a Flatten node if applicable.

    :param g: the onnx graph
    """
    node_to_remove = []
    for node in g.node:
        if node.op_type != 'Reshape':
            continue
        found = False
        # Flatten must be followed by Gemm
        for i in g.node:
            if len(i.input) == 0 or i.input[0] != node.output[0]:
                continue
            if i.op_type == 'Gemm':
                found = True
                break
        if not found:
            continue
        shape_node = helper.find_node_by_output_name(g, node.input[1])
        if shape_node.op_type != 'Constant':
            continue
        # Replace it
        node.op_type = "Flatten"
        for _ in range(len(node.attribute)):
            node.attribute.pop()
        shape_value = helper.find_value_by_name(g, shape_node.output[0])
        node.input.pop()
        node_to_remove.append(shape_node)
        g.value_info.remove(shape_value)
    for node in node_to_remove:
        g.node.remove(node)

def change_first_conv_from_bgr_to_rgb(m):
    """For a model expecting BGR input, change the first Conv weight so that
    the model accepts RGB input instead.

    :param m: the model proto
    """
    # Check for first node.
    g = m.graph
    input_name = g.input[0].name
    first_nodes = helper.find_following_nodes_by_input_value_name(
        g, input_name)
    if len(first_nodes) > 1:
        return False
    first_node = first_nodes[0]
    # Now we have the first node. Check this first node.
    if first_node.op_type != 'Conv':
        return False
    weight_value = helper.find_value_by_name(g, first_node.input[1])
    weight_shape = helper.get_shape_from_value_info(weight_value)
    if weight_shape[1] != 3:
        return False
    # Do weight shuffle
    weight_node = helper.find_node_by_output_name(g, weight_value.name)
    weight_np = helper.constant_to_numpy(weight_node)
    b_channel = np.expand_dims(weight_np[:, 0, :, :], axis=1)
    g_channel = np.expand_dims(weight_np[:, 1, :, :], axis=1)
    r_channel = np.expand_dims(weight_np[:, 2, :, :], axis=1)
    new_np = np.concatenate((r_channel, g_channel, b_channel), axis=1)
    new_node = helper.numpy_to_constant(weight_value.name, new_np)
    # Replace the weight and topological sort
    g.node.remove(weight_node)
    g.node.extend([new_node])
    other.topological_sort(g)
    return True

def fuse_conv_and_add_into_conv(g):
    node_to_del = []
    for node in g.node:
        if node.op_type != 'Add':
            continue
        add_node = node
        add_const = helper.find_node_by_output_name(g, add_node.input[1])
        if not add_const or add_const.op_type != 'Constant':
            continue
        conv_node = helper.find_node_by_output_name(g, add_node.input[0])
        if not conv_node or conv_node.op_type != 'Conv':
            continue
        weight_node = helper.find_node_by_output_name(g, conv_node.input[1])
        if not weight_node or weight_node.op_type != 'Constant':
            continue

        m_dim = weight_node.attribute[0].t.dims[0]
        if add_const.attribute[0].t.dims != [1, m_dim, 1, 1]:
            continue
        for _ in range(3):
            add_const.attribute[0].t.dims.remove(1)

        conv_node.input.extend([add_const.output[0]])
        conv_node.output.pop()
        conv_node.output.extend([add_node.output[0]])
        node_to_del.append(add_node)

        old_add_const_val_info = helper.find_value_by_name(
            g, add_node.input[1])
        old_conv_output_val_info = helper.find_value_by_name(
            g, conv_node.output[0])
        if old_add_const_val_info:
            g.value_info.remove(old_add_const_val_info)
        if old_conv_output_val_info:
            g.value_info.remove(old_conv_output_val_info)

        new_add_const_val_info = onnx.helper.make_tensor_value_info(
            add_const.output[0],
            add_const.attribute[0].t.data_type,
            add_const.attribute[0].t.dims)
        g.value_info.extend([new_add_const_val_info])

    while node_to_del:
        g.node.remove(node_to_del.pop())

    topological_sort(g)

def replace_depthwise_1x1_with_bn(g):
    """Replace a 1x1 DepthwiseConv node with a BN node if applicable.

    :param g: the onnx graph
    """
    node_to_remove = []
    for node in g.node:
        # Check op_type
        if node.op_type != 'Conv':
            continue
        # Check attributes
        attr_map = {attr.name: attr for attr in node.attribute}
        if "group" not in attr_map or attr_map["group"].i == 1:
            continue
        if attr_map["kernel_shape"].ints[0] != 1 or attr_map["kernel_shape"].ints[1] != 1:
            continue
        if "pads" in attr_map and sum(attr_map["pads"].ints) != 0:
            continue
        # Check scale
        scale_node = helper.find_node_by_output_name(g, node.input[1])
        if scale_node is None or scale_node.attribute[0].t.dims[1] != 1:
            continue
        scale_node.attribute[0].t.dims.pop()
        scale_node.attribute[0].t.dims.pop()
        scale_node.attribute[0].t.dims.pop()
        scale_info = helper.find_value_by_name(g, node.input[1])
        if scale_info is not None:
            scale_info.type.tensor_type.shape.dim.pop()
            scale_info.type.tensor_type.shape.dim.pop()
            scale_info.type.tensor_type.shape.dim.pop()
        # Check bias
        if len(node.input) == 3:
            bias_name = node.input[2]
        else:
            bias_name = node.name + "_bias"
            bias_node = helper.list_to_constant(bias_name,
                                                [attr_map["group"].i],
                                                [0.0] * attr_map["group"].i)
            g.node.extend([bias_node])
        # Construct mean and vars
        mean_name = node.name + "_mean"
        mean_node = helper.list_to_constant(mean_name, [attr_map["group"].i],
                                            [0.0] * attr_map["group"].i)
        var_name = node.name + "_var"
        var_node = helper.list_to_constant(var_name, [attr_map["group"].i],
                                           [1.0] * attr_map["group"].i)
        g.node.extend([mean_node, var_node])
        # Convert
        bn_node = onnx.helper.make_node(
            op_type='BatchNormalization',
            inputs=[node.input[0], node.input[1], bias_name,
                    mean_name, var_name],
            outputs=node.output,
            name=node.name,
            epsilon=0.00001,
            momentum=0.9
        )
        g.node.extend([bn_node])
        node_to_remove.append(node)
    for node in node_to_remove:
        g.node.remove(node)
    topological_sort(g)

def fuse_Transpose_into_Constant(g):
    """
    Fuse Transpose layers into the Constant layers before

    :param g: the onnx graph
    """
    node_to_remove = []
    for node in g.node:
        if node.op_type != 'Transpose':
            continue
        prev_node = helper.find_node_by_output_name(g, node.input[0])
        if prev_node is None or prev_node.op_type != 'Constant':
            continue

        pre_shape, data_list = helper.constant_to_list(prev_node)
        w = np.reshape(data_list, pre_shape)
        w = w.transpose(node.attribute[0].ints)
        new_shape = w.shape
        w = w.flatten()
        new_tensor = onnx.helper.make_tensor(
            name=prev_node.name + '_data',
            data_type=prev_node.attribute[0].t.data_type,
            dims=new_shape,
            vals=w.tolist()
        )
        new_node = onnx.helper.make_node(
            'Constant',
            [],
            [node.output[0]],
            name=node.output[0],
            value=new_tensor
        )

        value_between = helper.find_value_by_name(g, prev_node.output[0])
        value_type = value_between.type.tensor_type.elem_type
        g.value_info.remove(value_between)

        g.node.extend([new_node])
        node_to_remove.append(node)
        node_to_remove.append(prev_node)

        if new_node.output[0] not in [i.name for i in g.value_info]:
            new_value = onnx.helper.make_tensor_value_info(
                name=new_node.output[0],
                elem_type=value_type,
                shape=new_shape
            )
            g.value_info.extend([new_value])
        if new_node.output[0]:
            val_info_to_del = helper.find_value_by_name(g, new_node.output[0])
            g.value_info.remove(val_info_to_del)

    for node in node_to_remove:
        g.node.remove(node)
    topological_sort(g)

def inference_resize_shape(g):
    for node in g.node:
        if node.op_type != 'Resize':
            continue
        output_value = helper.find_value_by_name(g, node.output[0])
        output_value = helper.find_output_by_name(
            g, node.output[0]) if output_value is None else output_value
        if output_value is not None:
            continue

        # currently, only support 4 input
        if len(node.input) == 4:
            # input: X, roi, scales, sizes
            shape_node = helper.find_node_by_output_name(g, node.input[3])
            if shape_node.op_type != 'Constant':
                continue
            _, shape_value = helper.constant_to_list(shape_node)
            output_value = onnx.helper.make_tensor_value_info(
                node.output[0],
                onnx.TensorProto.FLOAT,
                [int(v) for v in shape_value])
            g.value_info.extend([output_value])
            return True
        else:
            # If output shape is not given, infer it from scales
            # Get the input shape
            input_value = helper.find_value_by_name(g, node.input[0])
            if input_value is None:
                continue
            shape_value = helper.get_shape_from_value_info(input_value)
            scales_node = helper.find_node_by_output_name(g, node.input[2])
            if scales_node.op_type != 'Constant':
                continue
            _, scales_value = helper.constant_to_list(scales_node)
            for i in range(len(shape_value)):
                shape_value[i] *= scales_value[i]
            output_value = onnx.helper.make_tensor_value_info(
                node.output[0],
                onnx.TensorProto.FLOAT,
                [int(v) for v in shape_value])
            g.value_info.extend([output_value])
            return True
    return False

def replace_ConstantOfShape_with_constant(g):
    """Replace ConstantOfShape with Constant.\\
    This is the first step of reshape constant folding.

    :param g: the input graph\\
    :return: if anything modified, return true.
    """
    node_to_remove = []
    for node in g.node:
        # Find a ConstantOfShape
        if node.op_type != 'ConstantOfShape':
            continue
        # Check input
        input_value = helper.find_value_by_name(g, node.input[0])
        if input_value is None:
            input_value = helper.find_input_by_name(g, node.input[0])
        if input_value is None or len(
                input_value.type.tensor_type.shape.dim) == 0:
            continue
        # Replace with a constant node
        pre_node = helper.find_node_by_output_name(g, node.input[0])
        _, target_shape = helper.constant_to_list(pre_node)
        value = helper.get_attribute_by_name(node, 'value').i
        node_name = node.output[0]
        new_node = helper.list_to_constant(node_name, [target_shape[0]],
                                           [value] * target_shape[0])
        g.node.extend([new_node])
        # remove old node
        node_to_remove.append(node)
        # delete value_info
        val_info_used = sum(
            [input_value.name in node.input for node in g.node])
        if val_info_used == 1:
            g.value_info.remove(input_value)
    replaced = len(node_to_remove) > 0
    for node in node_to_remove:
        g.node.remove(node)
    topological_sort(g)
    return replaced

def add_bn_before_activation(g):
    activation_nodes = set(['Relu', 'Clip', 'PRelu', 'LeakyRelu'])
    previous_nodes = set(['Conv', 'BatchNormalization'])
    for n in g.node:
        # Find activation node
        if n.op_type not in activation_nodes:
            continue
        # Get input
        input_node = helper.find_node_by_output_name(g, n.input[0])
        if input_node is None or input_node.op_type in previous_nodes:
            continue

        def add_bn_after(prev_node):
            # Get the channel number from value info
            value_name = prev_node.output[0]
            value = helper.find_value_by_name(g, value_name)
            shape = helper.get_shape_from_value_info(value)
            channel = shape[1]
            # Construct 4 weights
            node_name = value_name + "_nop_bn"
            ones = [1.0] * channel
            zeros = [0.0] * channel
            scale_node = helper.list_to_constant(node_name + "_scale",
                                                 [channel], ones)
            bias_node = helper.list_to_constant(node_name + "_bias",
                                                [channel], zeros)
            mean_node = helper.list_to_constant(node_name + "_mean",
                                                [channel], zeros)
            var_node = helper.list_to_constant(node_name + "_var",
                                               [channel], ones)
            # Construct BN node
            bn_node = onnx.helper.make_node(
                "BatchNormalization",
                [value_name, scale_node.output[0], bias_node.output[0],
                 mean_node.output[0], var_node.output[0]],
                [node_name],
                name=node_name,
                epsilon=0.00000001)
            # Reconnect the graph
            replace_node_input(n, value_name, node_name)
            # Add node to the graph
            g.node.extend(
                [bn_node, scale_node, bias_node, mean_node, var_node])

        add_bn_after(input_node)
    topological_sort(g)

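# The add_bn_* passes above insert a BatchNormalization that is numerically
# (almost) a no-op. BatchNormalization computes
#
#     y = scale * (x - mean) / sqrt(var + epsilon) + bias
#
# so with scale = 1, bias = 0, mean = 0, var = 1 and the tiny epsilon used
# here (1e-8) this reduces to y = x / sqrt(1 + 1e-8) ~= x. The inserted node
# therefore only changes the graph structure, presumably so that a backend
# sees an explicit per-channel op on every branch, without changing the
# computed values.
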
def rename_output_name(g, original_name, new_name):
    # Output
    output_value = helper.find_output_by_name(g, original_name)
    if output_value is None:
        logging.error("Cannot find output value named " + original_name)
        return
    output_value.name = new_name
    # Value Info
    value_info = helper.find_value_by_name(g, original_name)
    if value_info is not None:
        value_info.name = new_name
    # Node output
    node = helper.find_node_by_output_name(g, original_name)
    node.output[0] = new_name
    # Node input
    nodes = helper.find_nodes_by_input_name(g, original_name)
    for node in nodes:
        replace_node_input(node, original_name, new_name)

def duplicate_param_shared_constant(g):
    for node in g.node:
        input_names = set()
        for n, input_node_name in enumerate(node.input):
            param_data_node = helper.find_node_by_output_name(
                g, input_node_name)
            if param_data_node is None or param_data_node.op_type != 'Constant':
                continue
            if param_data_node.name not in input_names:
                input_names.add(input_node_name)
                continue

            duplicated_node = copy.deepcopy(param_data_node)

            new_node_name = param_data_node.name + '_' + str(n)
            duplicated_node.name = new_node_name
            duplicated_node.output[0] = new_node_name

            node.input[n] = new_node_name
            g.node.extend([duplicated_node])

def fuse_consecutive_reducemean(g):
    node_to_del = []
    for node in g.node:
        # Find consecutive ReduceMean
        if node.op_type != 'ReduceMean':
            continue
        pre_node = helper.find_node_by_output_name(g, node.input[0])
        if pre_node is None or pre_node.op_type != 'ReduceMean':
            continue
        # Check attributes
        pre_keepdims = helper.get_var_attribute_by_name(
            pre_node, 'keepdims', 'int')
        pre_axes = helper.get_list_attribute_by_name(pre_node, 'axes', 'int')
        cur_keepdims = helper.get_var_attribute_by_name(
            node, 'keepdims', 'int')
        cur_axes = helper.get_list_attribute_by_name(node, 'axes', 'int')
        if pre_keepdims != 0 or cur_keepdims != 0:
            continue
        axes = sorted(pre_axes + cur_axes)
        if axes != [2, 3]:
            continue
        # Merge two ReduceMean into GlobalAveragePool.
        new_gap_node = onnx.helper.make_node(
            'GlobalAveragePool',
            [pre_node.input[0]],
            [node.output[0] + '_intermedia'],
            name=node.name + '_gap')
        new_flatten_node = onnx.helper.make_node(
            'Flatten',
            [node.output[0] + '_intermedia'],
            [node.output[0]],
            name=node.name + '_flatten',
            axis=1)
        # Clean up
        g.node.extend([new_gap_node, new_flatten_node])
        node_to_del.extend([pre_node, node])
        mid_val_info = helper.find_value_by_name(g, node.input[0])
        if mid_val_info:
            g.value_info.remove(mid_val_info)

    while node_to_del:
        node = node_to_del.pop()
        g.node.remove(node)

    topological_sort(g)

def inference_upsample_shape(g):
    """For onnx v1.4.1+, onnx cannot infer the Upsample output shape. Let's\\
    do it ourselves. Each call only infers the shape of the next Upsample\\
    that has no output shape yet.

    :param g: the graph\\
    :return: True if any Upsample shape is generated. Otherwise, False.
    """
    for node in g.node:
        if node.op_type != 'Upsample':
            continue
        output_value = helper.find_value_by_name(g, node.output[0])
        if output_value is None:
            output_value = helper.find_output_by_name(g, node.output[0])
        if output_value and helper.get_shape_from_value_info(output_value):
            continue
        # Get input shape
        input_value = helper.find_value_by_name(g, node.input[0])
        if input_value is None:
            continue  # raise RuntimeError("Shape for {} has not been generated.".format(node.input[0]))
        if not helper.get_shape_from_value_info(input_value):
            continue  # raise RuntimeError("Shape for {} is empty.".format(node.input[0]))
        input_shape = helper.get_shape_from_value_info(input_value)
        # Get upsample weight
        weight_node = helper.find_node_by_output_name(g, node.input[1])
        weight_shape, weight = helper.constant_to_list(weight_node)
        if len(input_shape) != weight_shape[0]:
            raise RuntimeError(
                "Mismatched input shape and weight shape: {} vs {}".format(
                    input_shape, weight_shape))
        # Calculate shape
        output_shape = list(input_shape)
        for i in range(len(output_shape)):
            output_shape[i] = int(input_shape[i] * weight[i])
        output_value = onnx.helper.make_tensor_value_info(
            node.output[0], input_value.type.tensor_type.elem_type,
            output_shape)
        g.value_info.extend([output_value])
        return True
    return False

def transpose_B_in_Gemm(g):
    """
    If transB is set in Gemm, transpose it

    :param g: the onnx graph
    """
    for node in g.node:
        if node.op_type != 'Gemm':
            continue
        do_it = False
        for attr in node.attribute:
            if attr.name == "transB":
                if attr.i == 1:
                    attr.i = 0
                    do_it = True
                break
        if not do_it:
            continue
        # Transpose the weight and its output value
        w_node = helper.find_node_by_output_name(g, node.input[1])
        w_output = helper.find_value_by_name(g, node.input[1])
        dim_0 = w_output.type.tensor_type.shape.dim[0].dim_value
        dim_1 = w_output.type.tensor_type.shape.dim[1].dim_value
        w_output.type.tensor_type.shape.dim[0].dim_value = dim_1
        w_output.type.tensor_type.shape.dim[1].dim_value = dim_0
        w_node.attribute[0].t.dims[0] = dim_1
        w_node.attribute[0].t.dims[1] = dim_0
        if w_node.attribute[0].t.raw_data:
            raw_data = w_node.attribute[0].t.raw_data
            fl_data = [i[0] for i in struct.iter_unpack('f', raw_data)]
        else:
            fl_data = w_node.attribute[0].t.float_data
        w = np.reshape(fl_data, (dim_0, dim_1))
        w = w.transpose((1, 0)).flatten()
        if w_node.attribute[0].t.raw_data:
            buf = struct.pack('%sf' % len(w), *w)
            w_node.attribute[0].t.raw_data = buf
        else:
            for i in range(len(fl_data)):
                w_node.attribute[0].t.float_data[i] = w[i]

def fuse_consecutive_transposes(g):
    node_to_del = []
    for node in g.node:
        if node.op_type != 'Transpose':
            continue
        pre_node = helper.find_node_by_output_name(g, node.input[0])
        if pre_node is None or pre_node.op_type != 'Transpose':
            continue

        pre_permutation = list(pre_node.attribute[0].ints)
        cur_permutation = list(node.attribute[0].ints)
        if len(pre_permutation) != len(cur_permutation):
            continue

        new_permutation = []
        for ind in cur_permutation:
            new_permutation.append(pre_permutation[ind])

        new_trans_node = onnx.helper.make_node(
            'Transpose',
            [pre_node.input[0]],
            [node.output[0]],
            name=node.name,
            perm=new_permutation
        )

        g.node.extend([new_trans_node])
        node_to_del.extend([pre_node, node])

        mid_val_info = helper.find_value_by_name(g, node.input[0])
        if mid_val_info:
            g.value_info.remove(mid_val_info)

    while node_to_del:
        node = node_to_del.pop()
        g.node.remove(node)

    topological_sort(g)

def replace_mul_to_bn(g):
    """Replace single Mul node with Batchnorm node.

    :param g: input graph.
    :return:
    """
    node_to_del = []
    for node in g.node:
        if node.op_type != 'Mul':
            continue

        mul_op_node = node

        # only support one input node
        if len(mul_op_node.input) != 2:  # OP node and value node
            continue

        input_op_node_name = mul_op_node.input[0]
        mul_value_node = helper.find_node_by_output_name(
            g, mul_op_node.input[1])
        if not mul_value_node or mul_value_node.op_type != 'Constant':
            continue

        _, previous_node_output_shape = helper.find_size_shape_from_value(
            helper.find_value_by_name(g, input_op_node_name))

        scale_shape, scale_data = helper.constant_to_list(mul_value_node)

        # only allow 4 dim data input due to the hardware limitation
        if len(previous_node_output_shape) != 4:
            continue

        # channel dimension
        c_dim = previous_node_output_shape[1]

        # only allow channelwise mul or const mul
        if scale_shape != [1, c_dim, 1, 1] and scale_shape != 1:
            continue

        ones = [1.0] * c_dim
        zeros = [0.0] * c_dim
        muls = scale_data * c_dim

        bn_name = mul_op_node.output[0]
        mean_value_node = helper.list_to_constant(bn_name + '_mean',
                                                  np.array(zeros).shape,
                                                  zeros)
        variance_value_node = helper.list_to_constant(bn_name + '_var',
                                                      np.array(ones).shape,
                                                      ones)
        bias_value_node = helper.list_to_constant(bn_name + '_add',
                                                  np.array(zeros).shape,
                                                  zeros)
        new_mul_value_node = helper.list_to_constant(bn_name + '_mul',
                                                     np.array(muls).shape,
                                                     muls)

        bn_node = onnx.helper.make_node(
            'BatchNormalization',
            [input_op_node_name, new_mul_value_node.output[0],
             bias_value_node.output[0], mean_value_node.output[0],
             variance_value_node.output[0]],
            [mul_op_node.output[0]],
            name=bn_name,
            epsilon=0.00000001)

        mid_val_info = helper.find_value_by_name(g, mul_op_node.output[0])
        scale_val_info = helper.find_value_by_name(g, mul_value_node.output[0])
        g.value_info.remove(mid_val_info)
        g.value_info.remove(scale_val_info)

        g.node.extend([bn_node])
        g.node.extend([mean_value_node])
        g.node.extend([variance_value_node])
        g.node.extend([bias_value_node])
        g.node.extend([new_mul_value_node])

        node_to_del.extend([mul_op_node])
        node_to_del.extend([mul_value_node])

    while node_to_del:
        g.node.remove(node_to_del.pop())

    topological_sort(g)

def replace_ReduceMean_with_GlobalAveragePool(g):
    """Replace ReduceMean with a GlobalAveragePool node when available.
    If there is a preceding Transpose, check the Transpose and the ReduceMean
    together. If keepdims is set to 0, add a Flatten.

    :param g: the input graph
    """
    node_to_remove = []
    for node in g.node:
        # Find a ReduceMean layer
        if node.op_type != 'ReduceMean':
            continue
        # Find if it has a previous Transpose whose attribute meets the need.
        prev_node = helper.find_node_by_output_name(g, node.input[0])
        if prev_node is not None and prev_node.op_type != 'Transpose':
            prev_node = None
        if prev_node is not None:
            perm = helper.get_list_attribute_by_name(prev_node, 'perm', 'int')
            if perm != [0, 2, 3, 1]:
                prev_node = None
        # Check attributes
        axes = helper.get_list_attribute_by_name(node, 'axes', 'int')
        keepdims = helper.get_var_attribute_by_name(node, 'keepdims', 'int')
        if axes is None:
            continue
        if prev_node is None and axes != [2, 3]:
            continue
        if prev_node is not None and axes != [1, 2]:
            continue
        if keepdims is None:
            keepdims = 1
        # Replace it with GlobalAveragePool
        if prev_node:
            input_list = prev_node.input
        else:
            input_list = node.input
        if keepdims == 1:
            output_list = node.output
        else:
            output_list = [node.output[0] + '_before_flatten']
            flatten_node = onnx.helper.make_node("Flatten",
                                                 output_list,
                                                 node.output,
                                                 name=node.name + "_flatten",
                                                 axis=1)
            g.node.extend([flatten_node])
        new_node = onnx.helper.make_node("GlobalAveragePool",
                                         input_list,
                                         output_list,
                                         name=node.name)
        g.node.extend([new_node])
        node_to_remove.append(node)
        if prev_node:
            value = helper.find_value_by_name(g, prev_node.output[0])
            if value:
                g.value_info.remove(value)
            node_to_remove.append(prev_node)
    for node in node_to_remove:
        g.node.remove(node)
    topological_sort(g)

def split_ConvTranspose(model):
    """To feed our compiler, split ConvTranspose into Upsample and Conv.

    :param model: the model
    """
    node_to_delete = []
    # Change model properties for upsample.
    if model.ir_version < 3:
        print("Warning: Current model IR version is not fully supported.")
    model.ir_version = 4
    model.opset_import[0].version = 9
    g = model.graph
    # Get a ConvTranspose layer
    for node in g.node:
        # Find a ConvTranspose node
        if node.op_type != 'ConvTranspose':
            continue
        # Check auto_pad
        auto_pad_proto = helper.get_attribute_by_name(node, "auto_pad")
        if auto_pad_proto is not None:
            print("Currently cannot split ConvTranspose with auto_pad")
            continue
        # Check output_shape
        output_shape_proto = helper.get_attribute_by_name(node, "output_shape")
        if output_shape_proto is not None:
            print("Currently cannot split ConvTranspose with output_shape")
            continue
        # Get input shape
        input_value = helper.find_value_by_name(g, node.input[0])
        if input_value is None:
            input_value = helper.find_input_by_name(g, node.input[0])
        if input_value is None:
            print("Cannot get value info named {}.".format(node.input[0]))
            exit(1)
        input_shape = helper.get_shape_from_value_info(input_value)
        # Get attributes
        attr = deconv_to_conv_info_extraction(input_shape, node)
        # Generate Upsample scales
        upsample_output_shape = list(input_shape)
        upsample_output_shape[2] = (input_shape[2] - 1) * attr["strides"][0] + 1
        upsample_output_shape[3] = (input_shape[3] - 1) * attr["strides"][1] + 1
        upsample_node_name = node.name + "_inner_upsample"
        upsample_scale_name = upsample_node_name + "_scales"
        scales_np = np.ones([4]).astype('float32')
        scales_np[2] = float(upsample_output_shape[2]) / input_shape[2]
        scales_np[3] = float(upsample_output_shape[3]) / input_shape[3]
        scales_node = helper.numpy_to_constant(upsample_scale_name, scales_np)
        # Generate an Upsample layer and an internal value info
        upsample_node = onnx.helper.make_node(
            "Upsample",
            [node.input[0], upsample_scale_name],
            [upsample_node_name],
            name=upsample_node_name,
            mode="zeros"
        )
        upsample_value_info = onnx.helper.make_tensor_value_info(
            upsample_node_name,
            input_value.type.tensor_type.elem_type,
            upsample_output_shape
        )
        # Check the weight layer, it may need a transpose
        if attr["group"] != input_shape[1]:
            weight_node = helper.find_node_by_output_name(g, node.input[1])
            weight_np = helper.constant_to_numpy(weight_node)
            new_weight_np = np.transpose(weight_np, [1, 0, 2, 3])
            new_weight_node = helper.numpy_to_constant(node.input[1],
                                                       new_weight_np)
            node_to_delete.append(weight_node)
            g.node.extend([new_weight_node])
            value = helper.find_value_by_name(g, node.input[1])
            g.value_info.remove(value)
        # Generate a Conv layer
        conv_node_name = node.name + "_inner_conv"
        conv_node_input = [upsample_node_name]
        conv_node_input.extend(node.input[1:])
        conv_node = onnx.helper.make_node(
            "Conv",
            conv_node_input,
            [node.output[0]],
            name=conv_node_name,
            pads=[int(i) for i in attr["conv_pads"]],
            dilations=[int(i) for i in attr["dilations"]],
            group=int(attr["group"]),
            kernel_shape=[int(i) for i in attr["kernel_shape"]],
            strides=[int(1), int(1)]
        )
        # Reconnect the graph
        g.node.extend([scales_node, upsample_node, conv_node])
        g.value_info.extend([upsample_value_info])
        node_to_delete.append(node)
    # Delete useless nodes
    for node in node_to_delete:
        g.node.remove(node)
    topological_sort(g)

def fuse_BN_into_Gemm(g):
    """Fuse the following BN into the previous Gemm.

    :param g: the graph
    """
    node_to_remove = []
    for node in g.node:
        # Check for BN and Gemm
        if node.op_type != 'BatchNormalization':
            continue
        gemm_node = helper.find_node_by_output_name(g, node.input[0])
        if gemm_node is None:
            continue
        if gemm_node.op_type != 'Gemm':
            continue
        if len(
                helper.find_following_nodes_by_input_value_name(
                    g, gemm_node.output[0])) > 1:
            continue
        bn_node = node
        # Get original weights
        gemm_b_node = helper.find_node_by_output_name(g, gemm_node.input[1])
        gemm_b = helper.constant_to_numpy(gemm_b_node)
        gemm_c_node = helper.find_node_by_output_name(g, gemm_node.input[2])
        gemm_c = helper.constant_to_numpy(gemm_c_node)
        bn_scale_node = helper.find_node_by_output_name(g, bn_node.input[1])
        bn_scale = helper.constant_to_numpy(bn_scale_node)
        bn_bias_node = helper.find_node_by_output_name(g, bn_node.input[2])
        bn_bias = helper.constant_to_numpy(bn_bias_node)
        bn_mean_node = helper.find_node_by_output_name(g, bn_node.input[3])
        bn_mean = helper.constant_to_numpy(bn_mean_node)
        bn_var_node = helper.find_node_by_output_name(g, bn_node.input[4])
        bn_var = helper.constant_to_numpy(bn_var_node)
        # Apply attributes
        # epsilon
        epsilon = helper.get_attribute_by_name(bn_node, 'epsilon')
        if epsilon is None:
            epsilon = 0.00001
        else:
            epsilon = epsilon.f
        bn_var = bn_var + epsilon
        # alpha
        alpha = helper.get_attribute_by_name(gemm_node, 'alpha')
        if alpha is None:
            alpha = 1
        else:
            alpha = alpha.f
        gemm_b = gemm_b * alpha
        # beta
        beta = helper.get_attribute_by_name(gemm_node, 'beta')
        if beta is None:
            beta = 1
        else:
            beta = beta.f
        gemm_c = gemm_c * beta
        # transA
        transA = helper.get_attribute_by_name(gemm_node, 'transA')
        if transA is not None and transA.i == 1:
            raise RuntimeError("Do not support transA")
        # transB
        transB = helper.get_attribute_by_name(gemm_node, 'transB')
        if transB is not None and transB.i == 1:
            gemm_b = gemm_b.transpose()
        # Calculate new weights
        new_gemm_b = gemm_b * bn_scale / np.sqrt(bn_var)
        new_gemm_c = (gemm_c - bn_mean) * bn_scale / np.sqrt(bn_var) + bn_bias
        # Replace original weights
        new_gemm_b_node = helper.numpy_to_constant(gemm_b_node.name + '_fused',
                                                   new_gemm_b)
        new_gemm_c_node = helper.numpy_to_constant(gemm_c_node.name + '_fused',
                                                   new_gemm_c)
        g.node.extend([new_gemm_b_node, new_gemm_c_node])
        node_to_remove.extend([
            gemm_b_node, gemm_c_node, bn_node, bn_scale_node, bn_bias_node,
            bn_mean_node, bn_var_node
        ])
        # Modify attributes
        # alpha
        alpha = helper.get_attribute_by_name(gemm_node, 'alpha')
        if alpha is not None:
            alpha.f = 1.0
        # beta
        beta = helper.get_attribute_by_name(gemm_node, 'beta')
        if beta is not None:
            beta.f = 1.0
        # transB
        transB = helper.get_attribute_by_name(gemm_node, 'transB')
        if transB is not None:
            transB.i = 0
        # Connect the new graph
        gemm_node.input[1] = new_gemm_b_node.output[0]
        gemm_node.input[2] = new_gemm_c_node.output[0]
        gemm_b_value = helper.find_value_by_name(g, gemm_b_node.output[0])
        gemm_c_value = helper.find_value_by_name(g, gemm_c_node.output[0])
        gemm_b_value.name = new_gemm_b_node.output[0]
        gemm_c_value.name = new_gemm_c_node.output[0]
        gemm_value = helper.find_value_by_name(g, gemm_node.output[0])
        g.value_info.remove(gemm_value)
        gemm_node.output[0] = bn_node.output[0]
        for i in range(1, 5):
            value = helper.find_value_by_name(g, bn_node.input[i])
            g.value_info.remove(value)
    # Remove useless nodes
    for node in node_to_remove:
        g.node.remove(node)
    topological_sort(g)

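# Why fuse_BN_into_Gemm only needs two new constants: with transA/transB
# already folded away, Gemm computes y = x.dot(B) + C and BatchNormalization
# computes z = scale * (y - mean) / sqrt(var + eps) + bias. Substituting gives
#
#     z = x.dot(B * scale / sqrt(var + eps))
#         + (C - mean) * scale / sqrt(var + eps) + bias
#
# which is exactly new_gemm_b and new_gemm_c above, so the BN node and its
# four weight constants can be dropped.
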
def replace_dilated_conv(g):
    """
    If the dilation of a convolution is not (1, 1), replace it with a regular
    convolution with an expanded kernel.

    :param g: the input graph
    """
    node_to_remove = []
    for node in g.node:
        # Check if this is a conv layer
        if node.op_type != 'Conv':
            continue
        # Check if this has dilation
        has_dilations = False
        has_strides = False
        for attr in node.attribute:
            if attr.name == "dilations":
                dilations = list(attr.ints)
                if dilations != [1, 1]:
                    has_dilations = True
            if attr.name == "strides":
                strides = list(attr.ints)
                if strides != [1, 1]:
                    has_strides = True
        if has_dilations and has_strides:
            print("Warning: Both strides and dilations are set in ",
                  node.name)
            continue
        if not has_dilations:
            continue
        # Construct new kernel
        w_node = helper.find_node_by_output_name(g, node.input[1])
        w_output = helper.find_value_by_name(g, node.input[1])
        shape = list(w_node.attribute[0].t.dims)
        # get original weight from float_data or raw data
        weight = list(w_node.attribute[0].t.float_data)
        if len(weight) == 0:
            # Unpack from raw data
            raw_data = w_node.attribute[0].t.raw_data
            weight = [i[0] for i in struct.iter_unpack('f', raw_data)]
        weight = np.array(weight)
        weight = np.reshape(weight, shape)
        new_shape = copy.copy(shape)
        new_shape[2] = 1 + (shape[2] - 1) * dilations[0]
        new_shape[3] = 1 + (shape[3] - 1) * dilations[1]
        new_weight = np.zeros(new_shape)
        for batch in range(shape[0]):
            for ch in range(shape[1]):
                for h in range(shape[2]):
                    nh = h * dilations[0]
                    for w in range(shape[3]):
                        nw = w * dilations[1]
                        new_weight[batch, ch, nh, nw] = weight[batch, ch, h, w]
        tensor = onnx.helper.make_tensor(w_node.attribute[0].t.name,
                                         w_node.attribute[0].t.data_type,
                                         new_shape,
                                         new_weight.ravel())
        new_w_node = onnx.helper.make_node("Constant", [],
                                           list(w_node.output),
                                           name=w_node.name,
                                           value=tensor)
        g.node.extend([new_w_node])
        node_to_remove.append(w_node)
        # Modify attributes and value info shapes
        w_output.type.tensor_type.shape.dim[2].dim_value = new_shape[2]
        w_output.type.tensor_type.shape.dim[3].dim_value = new_shape[3]
        for attr in node.attribute:
            if attr.name == "kernel_shape":
                attr.ints[0] = new_shape[2]
                attr.ints[1] = new_shape[3]
            if attr.name == "dilations":
                attr.ints[0] = 1
                attr.ints[1] = 1
    # Remove old weight nodes
    for node in node_to_remove:
        g.node.remove(node)

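# replace_dilated_conv relies on the identity that a kernel of size k with
# dilation d covers the same input taps as a zero-interleaved kernel of size
# (k - 1) * d + 1 with dilation 1. For example, a 3x3 kernel with dilation 2
# becomes a 5x5 kernel whose entries are the original weights at every other
# position and zeros elsewhere, which is what the nested loops above build.
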
def fuse_mul_and_add_into_bn(g):
    node_to_del = []
    for node in g.node:
        if node.op_type != 'Add':
            continue
        add_node = node
        input_nodes_add = [
            helper.find_node_by_output_name(g, input_name)
            for input_name in add_node.input
        ]
        if any([n is None for n in input_nodes_add]):
            continue
        mul_node, const_add = None, None
        for input_node_add in input_nodes_add:
            if input_node_add.op_type == 'Mul':
                mul_node = input_node_add
            elif input_node_add.op_type == 'Constant':
                const_add = input_node_add
            else:
                pass
        if not mul_node or not const_add:
            continue

        data_input_name, const_mul = None, None
        for input_name in mul_node.input:
            input_node = helper.find_node_by_output_name(g, input_name)
            if not input_node:
                data_input_name = input_name
            elif input_node.op_type == 'Constant':
                if not const_mul:
                    const_mul = input_node
                else:
                    data_input_name = input_name
            else:
                data_input_name = input_name
        if not const_mul:
            continue

        scale_shape, scale_data = helper.constant_to_list(const_mul)
        bias_shape, __ = helper.constant_to_list(const_add)
        c_dim = len(scale_data)
        if scale_shape != bias_shape:
            continue

        _, previous_node_output_shape = helper.find_size_shape_from_value(
            helper.find_value_by_name(g, data_input_name))
        # only allow 4 dim data input due to the hardware limitation
        if len(previous_node_output_shape) != 4:
            continue
        # check if mul's dim and input channel dimension are matched
        if previous_node_output_shape[1] != c_dim:
            continue

        if scale_shape == [1, c_dim, 1, 1]:
            # remove all '1'
            for _ in range(3):
                const_add.attribute[0].t.dims.remove(1)
                const_mul.attribute[0].t.dims.remove(1)
        elif scale_shape == [1, c_dim]:
            # remove all '1'
            const_add.attribute[0].t.dims.remove(1)
            const_mul.attribute[0].t.dims.remove(1)
        else:
            continue

        bn_name = add_node.output[0]
        const_mean = helper.list_to_constant(bn_name + '_mean', [c_dim],
                                             [0.0 for _ in range(c_dim)])
        const_var = helper.list_to_constant(bn_name + '_var', [c_dim],
                                            [1.0 for _ in range(c_dim)])

        bn_node = onnx.helper.make_node(
            'BatchNormalization',
            [data_input_name, const_mul.output[0], const_add.output[0],
             const_mean.output[0], const_var.output[0]],
            [add_node.output[0]],
            name=bn_name,
            epsilon=0.00000001
        )

        mid_val_info = helper.find_value_by_name(g, mul_node.output[0])
        scale_val_info = helper.find_value_by_name(g, const_mul.output[0])
        bias_val_info = helper.find_value_by_name(g, const_add.output[0])
        g.value_info.remove(mid_val_info)
        g.value_info.remove(scale_val_info)
        g.value_info.remove(bias_val_info)

        new_scale_val_info = onnx.helper.make_tensor_value_info(
            const_mul.output[0], const_mul.attribute[0].t.data_type, [c_dim])
        new_bias_val_info = onnx.helper.make_tensor_value_info(
            const_add.output[0], const_add.attribute[0].t.data_type, [c_dim])
        mean_val_info = onnx.helper.make_tensor_value_info(
            const_mean.output[0], const_mean.attribute[0].t.data_type, [c_dim])
        var_val_info = onnx.helper.make_tensor_value_info(
            const_var.output[0], const_var.attribute[0].t.data_type, [c_dim])

        g.value_info.extend([new_scale_val_info])
        g.value_info.extend([new_bias_val_info])
        g.value_info.extend([mean_val_info])
        g.value_info.extend([var_val_info])

        g.node.extend([bn_node])
        g.node.extend([const_mean])
        g.node.extend([const_var])

        node_to_del.extend([mul_node, add_node])

    while node_to_del:
        g.node.remove(node_to_del.pop())

    topological_sort(g)

def fuse_BN_with_Reshape_into_Gemm(g):
    """Fuse the following BN into the previous Gemm, even with Reshape or \\
    Squeeze and Unsqueeze surrounding.

    :param g: the graph
    """
    node_to_remove = []
    for node in g.node:
        # Check for BN and Gemm pattern: Gemm A BN B
        # Find BatchNorm Node
        if node.op_type != 'BatchNormalization':
            continue
        bn_node = node
        # Find A Node
        a_node = helper.find_node_by_output_name(g, node.input[0])
        if a_node is None or len(a_node.input) == 0:
            continue
        # Find Gemm Node
        gemm_node = helper.find_node_by_output_name(g, a_node.input[0])
        if gemm_node is None or gemm_node.op_type != 'Gemm':
            continue
        # Find B Node
        b_node_list = helper.find_following_nodes_by_input_value_name(
            g, bn_node.output[0])
        if len(b_node_list) == 0:
            the_output = helper.find_output_by_name(g, bn_node.output[0])
            if the_output is None:
                continue
            b_node = None
        elif len(b_node_list) > 1:
            continue
        else:
            b_node = b_node_list[0]
        # Check for branches
        if len(
                helper.find_following_nodes_by_input_value_name(
                    g, gemm_node.output[0])) > 1:
            continue
        if len(
                helper.find_following_nodes_by_input_value_name(
                    g, a_node.output[0])) > 1:
            continue
        # Check type of A
        if a_node.op_type == 'Unsqueeze':
            axes = helper.get_attribute_by_name(a_node, 'axes')
            if axes.ints != [2]:
                continue
        elif a_node.op_type == 'Reshape':
            a = helper.constant_to_list(
                helper.find_node_by_output_name(g, a_node.input[1]))[1]
            if len(a) != 3 or a[2] != 1:
                continue
        else:
            continue
        # Check type of B
        if b_node is None:
            pass
        elif b_node.op_type == 'Flatten':
            pass
        elif b_node.op_type == 'Squeeze':
            axes = helper.get_attribute_by_name(b_node, 'axes')
            if axes.ints != [2]:
                continue
        elif b_node.op_type == 'Reshape':
            a = helper.constant_to_list(
                helper.find_node_by_output_name(g, b_node.input[1]))[1]
            if len(a) != 2:
                continue
        else:
            continue
        # Construct new Nodes
        # Get original weights
        gemm_b_node = helper.find_node_by_output_name(g, gemm_node.input[1])
        gemm_b = helper.constant_to_numpy(gemm_b_node)
        gemm_c_node = helper.find_node_by_output_name(g, gemm_node.input[2])
        gemm_c = helper.constant_to_numpy(gemm_c_node)
        bn_scale_node = helper.find_node_by_output_name(g, bn_node.input[1])
        bn_scale = helper.constant_to_numpy(bn_scale_node)
        bn_bias_node = helper.find_node_by_output_name(g, bn_node.input[2])
        bn_bias = helper.constant_to_numpy(bn_bias_node)
        bn_mean_node = helper.find_node_by_output_name(g, bn_node.input[3])
        bn_mean = helper.constant_to_numpy(bn_mean_node)
        bn_var_node = helper.find_node_by_output_name(g, bn_node.input[4])
        bn_var = helper.constant_to_numpy(bn_var_node)
        # Apply attributes
        # epsilon
        epsilon = helper.get_attribute_by_name(bn_node, 'epsilon')
        if epsilon is None:
            epsilon = 0.00001
        else:
            epsilon = epsilon.f
        bn_var = bn_var + epsilon
        # alpha
        alpha = helper.get_attribute_by_name(gemm_node, 'alpha')
        if alpha is None:
            alpha = 1
        else:
            alpha = alpha.f
        gemm_b = gemm_b * alpha
        # beta
        beta = helper.get_attribute_by_name(gemm_node, 'beta')
        if beta is None:
            beta = 1
        else:
            beta = beta.f
        gemm_c = gemm_c * beta
        # transA
        transA = helper.get_attribute_by_name(gemm_node, 'transA')
        if transA is not None and transA.i == 1:
            raise RuntimeError("Do not support transA")
        # transB
        transB = helper.get_attribute_by_name(gemm_node, 'transB')
        if transB is not None and transB.i == 1:
            gemm_b = gemm_b.transpose()
        # Calculate new weights
        new_gemm_b = gemm_b * bn_scale / np.sqrt(bn_var)
        new_gemm_c = (gemm_c - bn_mean) * bn_scale / np.sqrt(bn_var) + bn_bias
        # Replace original weights
        new_gemm_b_node = helper.numpy_to_constant(gemm_b_node.name + '_fused',
                                                   new_gemm_b)
        new_gemm_c_node = helper.numpy_to_constant(gemm_c_node.name + '_fused',
                                                   new_gemm_c)
        g.node.extend([new_gemm_b_node, new_gemm_c_node])
        # Modify attributes
        # alpha
        alpha = helper.get_attribute_by_name(gemm_node, 'alpha')
        if alpha is not None:
            alpha.f = 1.0
        # beta
        beta = helper.get_attribute_by_name(gemm_node, 'beta')
        if beta is not None:
            beta.f = 1.0
        # transB
        transB = helper.get_attribute_by_name(gemm_node, 'transB')
        if transB is not None:
            transB.i = 0
        # Remove useless nodes
        node_to_remove.extend([
            gemm_b_node, gemm_c_node, bn_node, bn_scale_node, bn_bias_node,
            bn_mean_node, bn_var_node, a_node
        ])
        if a_node.op_type == 'Reshape':
            node_to_remove.append(
                helper.find_node_by_output_name(g, a_node.input[1]))
        if b_node is not None:
            node_to_remove.append(b_node)
            if b_node.op_type == 'Reshape':
                node_to_remove.append(
                    helper.find_node_by_output_name(g, b_node.input[1]))
        # Delete useless value infos
        value = helper.find_value_by_name(g, a_node.output[0])
        g.value_info.remove(value)
        if a_node.op_type == 'Reshape':
            value = helper.find_value_by_name(g, a_node.input[1])
            g.value_info.remove(value)
        for i in range(1, 5):
            value = helper.find_value_by_name(g, bn_node.input[i])
            g.value_info.remove(value)
        value = helper.find_value_by_name(g, bn_node.output[0])
        if value is not None:
            g.value_info.remove(value)
        if b_node is not None:
            value = helper.find_value_by_name(g, gemm_node.output[0])
            g.value_info.remove(value)
            if b_node.op_type == 'Reshape':
                value = helper.find_value_by_name(g, b_node.input[1])
                g.value_info.remove(value)
        # Connect the new graph
        # Connect Gemm new weights
        gemm_node.input[1] = new_gemm_b_node.output[0]
        gemm_node.input[2] = new_gemm_c_node.output[0]
        gemm_b_value = helper.find_value_by_name(g, gemm_b_node.output[0])
        gemm_c_value = helper.find_value_by_name(g, gemm_c_node.output[0])
        gemm_b_value.name = new_gemm_b_node.output[0]
        gemm_b_value.type.tensor_type.shape.dim[0].dim_value = new_gemm_b.shape[0]
        gemm_b_value.type.tensor_type.shape.dim[1].dim_value = new_gemm_b.shape[1]
        gemm_c_value.name = new_gemm_c_node.output[0]
        if b_node is None:
            # If b node is None, set the Gemm output as the graph output
            output_value = helper.find_output_by_name(g, bn_node.output[0])
            g.output.remove(output_value)
            g.output.extend(
                [helper.find_value_by_name(g, gemm_node.output[0])])
        else:
            # Else, set node B output as gemm output
            gemm_node.output[0] = b_node.output[0]
    # Remove useless nodes
    for node in node_to_remove:
        g.node.remove(node)
    topological_sort(g)

def pattern_matmul_mul_add(g, matmul_node):
    # Check node match - Mul node
    next_nodes = helper.find_nodes_by_input_name(g, matmul_node.output[0])
    if len(next_nodes) != 1:
        return
    if next_nodes[0].op_type != 'Mul':
        return
    mul_node = next_nodes[0]
    # Check node match - Add node
    next_nodes = helper.find_nodes_by_input_name(g, mul_node.output[0])
    if len(next_nodes) != 1:
        return
    if next_nodes[0].op_type != 'Add':
        return
    add_node = next_nodes[0]
    # Check Mul weight
    mul_weight_node = helper.find_node_by_output_name(g, mul_node.input[1])
    if mul_weight_node.op_type != 'Constant':
        return
    weight_size, mul_weight = helper.constant_to_list(mul_weight_node)
    for i in mul_weight:
        if i != 1:
            return
    channel = weight_size[0]
    # Check Add weight
    add_weight_node = helper.find_node_by_output_name(g, add_node.input[1])
    if add_weight_node.op_type != 'Constant':
        return
    # Check MatMul weight to see if it need weight broadcast
    matmul_weight_node = helper.find_node_by_output_name(g, matmul_node.input[1])
    matmul_weight = helper.constant_to_numpy(matmul_weight_node)
    if matmul_weight.shape[1] == 1:
        # Weight broadcast
        new_matmul_weight = np.tile(matmul_weight, channel)
        new_matmul_weight_node = helper.numpy_to_constant(
            matmul_weight_node.name, new_matmul_weight)
        g.node.remove(matmul_weight_node)
        g.node.extend([new_matmul_weight_node])
        value = helper.find_value_by_name(g, matmul_weight_node.output[0])
        if value is not None:
            g.value_info.remove(value)
    # Remove Mul node
    g.node.remove(mul_weight_node)
    value = helper.find_value_by_name(g, mul_weight_node.output[0])
    if value is not None:
        g.value_info.remove(value)
    g.node.remove(mul_node)
    value = helper.find_value_by_name(g, mul_node.output[0])
    if value is not None:
        g.value_info.remove(value)
    # Fuse Matmul and Add
    gemm_node = onnx.helper.make_node(
        'Gemm',
        [matmul_node.input[0], matmul_node.input[1], add_node.input[1]],
        [add_node.output[0]],
        name=matmul_node.name,
        alpha=1.0,
        beta=1.0,
        transA=0,
        transB=0
    )
    g.node.extend([gemm_node])
    # Clean up
    g.node.remove(matmul_node)
    g.node.remove(add_node)
    value = helper.find_value_by_name(g, matmul_node.output[0])
    if value is not None:
        g.value_info.remove(value)
    other.topological_sort(g)

def fuse_Gemm_into_Gemm(g):
    """Fuse the previous Gemm into the following Gemm.

    :param g: the graph
    """
    node_to_remove = []
    for node in g.node:
        # Check for Gemm and Gemm
        if node.op_type != 'Gemm':
            continue
        prev_node = helper.find_node_by_output_name(g, node.input[0])
        if prev_node is None:
            continue
        if prev_node.op_type != 'Gemm':
            continue
        # Get original weights
        prev_b_node = helper.find_node_by_output_name(g, prev_node.input[1])
        prev_b = helper.constant_to_numpy(prev_b_node)
        prev_c_node = helper.find_node_by_output_name(g, prev_node.input[2])
        prev_c = helper.constant_to_numpy(prev_c_node)
        b_node = helper.find_node_by_output_name(g, node.input[1])
        b = helper.constant_to_numpy(b_node)
        c_node = helper.find_node_by_output_name(g, node.input[2])
        c = helper.constant_to_numpy(c_node)
        # Apply attributes
        # alpha
        alpha = helper.get_attribute_by_name(node, 'alpha')
        if alpha is None:
            alpha = 1
        else:
            alpha = alpha.f
        b = b * alpha
        alpha = helper.get_attribute_by_name(prev_node, 'alpha')
        if alpha is None:
            alpha = 1
        else:
            alpha = alpha.f
        prev_b = prev_b * alpha
        # beta
        beta = helper.get_attribute_by_name(node, 'beta')
        if beta is None:
            beta = 1
        else:
            beta = beta.f
        c = c * beta
        beta = helper.get_attribute_by_name(prev_node, 'beta')
        if beta is None:
            beta = 1
        else:
            beta = beta.f
        prev_c = prev_c * beta
        # transA
        transA = helper.get_attribute_by_name(node, 'transA')
        if transA is not None and transA.i == 1:
            raise RuntimeError("Do not support transA")
        transA = helper.get_attribute_by_name(prev_node, 'transA')
        if transA is not None and transA.i == 1:
            raise RuntimeError("Do not support transA")
        # transB
        transB = helper.get_attribute_by_name(node, 'transB')
        if transB is not None and transB.i == 1:
            b = b.transpose()
        transB = helper.get_attribute_by_name(prev_node, 'transB')
        if transB is not None and transB.i == 1:
            prev_b = prev_b.transpose()
        # Calculate new weights
        new_b = prev_b.dot(b)
        new_c = prev_c.dot(b) + c
        # Replace original weights
        new_b_node = helper.numpy_to_constant(b_node.name + '_fused', new_b)
        new_c_node = helper.numpy_to_constant(c_node.name + '_fused', new_c)
        g.node.extend([new_b_node, new_c_node])
        node_to_remove.extend(
            [b_node, c_node, prev_b_node, prev_c_node, prev_node])
        # Modify attributes
        # alpha
        alpha = helper.get_attribute_by_name(node, 'alpha')
        if alpha is not None:
            alpha.f = 1.0
        # beta
        beta = helper.get_attribute_by_name(node, 'beta')
        if beta is not None:
            beta.f = 1.0
        # transB
        transB = helper.get_attribute_by_name(node, 'transB')
        if transB is not None:
            transB.i = 0
        # Connect the new graph
        node.input[0] = prev_node.input[0]
        prev_value = helper.find_value_by_name(g, prev_node.output[0])
        g.value_info.remove(prev_value)
        for i in range(1, 3):
            value = helper.find_value_by_name(g, prev_node.input[i])
            g.value_info.remove(value)
            value = helper.find_value_by_name(g, node.input[i])
            g.value_info.remove(value)
        node.input[1] = new_b_node.output[0]
        node.input[2] = new_c_node.output[0]
    # Remove useless nodes
    for node in node_to_remove:
        g.node.remove(node)
    topological_sort(g)

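# The algebra behind fuse_Gemm_into_Gemm: once alpha/beta are multiplied into
# the constants and transB is folded into the weight layout, each Gemm is
# y = x.dot(B) + C. Two in a row therefore collapse to a single Gemm:
#
#     (x.dot(B1) + C1).dot(B2) + C2 = x.dot(B1.dot(B2)) + (C1.dot(B2) + C2)
#
# which matches new_b = prev_b.dot(b) and new_c = prev_c.dot(b) + c above.
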
def fuse_mul_and_add_into_gemm(g):
    node_to_del = []
    for node in g.node:
        if node.op_type != 'Add':
            continue
        add_node = node
        mul_node = helper.find_node_by_output_name(g, add_node.input[0])
        if not mul_node or mul_node.op_type != 'Mul':
            continue
        mul_const = helper.find_node_by_output_name(g, mul_node.input[1])
        if not mul_const or mul_const.op_type != 'Constant':
            continue
        add_const = helper.find_node_by_output_name(g, add_node.input[1])
        if not add_const or add_const.op_type != 'Constant':
            continue

        input_val = helper.find_value_by_name(g, mul_node.input[0])
        if not input_val:
            input_val = helper.find_input_by_name(g, mul_node.input[0])
        if not input_val:
            continue

        _, input_shape = helper.find_size_shape_from_value(input_val)
        if not input_shape:
            continue

        dim = int(np.prod(input_shape))
        if input_shape != [1, dim]:
            continue

        mul_const_shape, mul_const_data = helper.constant_to_list(mul_const)
        add_const_shape, __ = helper.constant_to_list(add_const)

        if len(mul_const_shape) != 1 or mul_const_shape[0] != dim:
            continue
        if len(add_const_shape) != 1 or add_const_shape[0] != dim:
            continue

        b_data = np.zeros([dim, dim])
        for i in range(dim):
            b_data[i][i] = mul_const_data[i]
        b_data = b_data.flatten().tolist()
        b_tensor = onnx.helper.make_tensor(
            name=mul_const.name + '_tensor',
            data_type=mul_const.attribute[0].t.data_type,
            dims=[dim, dim],
            vals=b_data)
        b_const_node = onnx.helper.make_node('Constant', [],
                                             [mul_const.output[0]],
                                             value=b_tensor,
                                             name=mul_const.output[0])

        add_const.attribute[0].t.dims.insert(0, 1)

        gemm_node = onnx.helper.make_node(
            'Gemm',
            [mul_node.input[0], b_const_node.output[0], add_const.output[0]],
            [add_node.output[0]],
            name=add_node.output[0])

        g.node.extend([gemm_node, b_const_node])
        node_to_del.extend([mul_const, mul_node, add_node])

        val_info_mid = helper.find_value_by_name(g, mul_node.output[0])
        val_info_mul_const = helper.find_value_by_name(g, mul_const.output[0])
        val_info_add_const = helper.find_value_by_name(g, add_const.output[0])
        if val_info_mid:
            g.value_info.remove(val_info_mid)
        if val_info_mul_const:
            g.value_info.remove(val_info_mul_const)
        if val_info_add_const:
            g.value_info.remove(val_info_add_const)

    while node_to_del:
        g.node.remove(node_to_del.pop())

    topological_sort(g)

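# A minimal usage sketch, not part of the original API: the passes in this
# module mutate the graph in place, so a driver only needs to load a model,
# apply the passes it wants and save the result. The selection and ordering
# below are illustrative assumptions, not a prescribed pipeline.
def _example_optimize(model_path, output_path):
    """Hypothetical driver chaining a few of the passes defined above."""
    m = onnx.load(model_path)
    g = m.graph
    fuse_Transpose_into_Constant(g)
    fuse_Add_into_Conv(g)
    fuse_mul_and_add_into_bn(g)
    fuse_BN_into_Gemm(g)
    replace_Reshape_with_Flatten(g)
    # The shape-inference helpers return True while they still make progress.
    while inference_upsample_shape(g) or inference_resize_shape(g):
        pass
    onnx.save(m, output_path)
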