def set_upsample_mode_to_align_corner(g):
    """Set the mode of all the Upsample nodes to align_corner.

    :param g: the input graph
    """
    for node in g.node:
        # Find an Upsample node.
        if node.op_type != 'Upsample':
            continue
        attribute = helper.get_attribute_by_name(node, "mode")
        if isinstance(attribute.s, bytes):
            attribute.s = "align_corner".encode('utf-8')
        else:
            attribute.s = "align_corner"

def replace_ConstantOfShape_with_constant(g):
    """Replace ConstantOfShape with Constant.\\
    This is the first step of reshape constant folding.

    :param g: the input graph\\
    :return: if anything was modified, return True.
    """
    node_to_remove = []
    for node in g.node:
        # Find a ConstantOfShape node.
        if node.op_type != 'ConstantOfShape':
            continue
        # Check the input.
        input_value = helper.find_value_by_name(g, node.input[0])
        if input_value is None:
            input_value = helper.find_input_by_name(g, node.input[0])
        if input_value is None or len(
                input_value.type.tensor_type.shape.dim) == 0:
            continue
        # Replace it with a Constant node.
        pre_node = helper.find_node_by_output_name(g, node.input[0])
        _, target_shape = helper.constant_to_list(pre_node)
        value = helper.get_attribute_by_name(node, 'value').i
        node_name = node.output[0]
        new_node = helper.list_to_constant(node_name, [target_shape[0]],
                                           [value] * target_shape[0])
        g.node.extend([new_node])
        # Remove the old node.
        node_to_remove.append(node)
        # Delete the value_info if this was its only consumer.
        val_info_used = sum(
            [input_value.name in n.input for n in g.node])
        if val_info_used == 1:
            g.value_info.remove(input_value)
    replaced = len(node_to_remove) > 0
    for node in node_to_remove:
        g.node.remove(node)
    topological_sort(g)
    return replaced

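# For example, a ConstantOfShape whose input is a constant holding [4] and
# whose value attribute holds v is replaced by a Constant producing the 1-D
# tensor [v, v, v, v]. Note that the replacement above only covers 1-D
# targets: it reads target_shape[0] and ignores any further dimensions.
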
def split_ConvTranspose(model):
    """To feed our compiler, split ConvTranspose into Upsample and Conv.

    :param model: the model
    """
    node_to_delete = []
    # Change model properties for Upsample.
    if model.ir_version < 3:
        print("Warning: The current model IR version is not fully supported.")
    model.ir_version = 4
    model.opset_import[0].version = 9
    g = model.graph
    for node in g.node:
        # Find a ConvTranspose node.
        if node.op_type != 'ConvTranspose':
            continue
        # Check auto_pad.
        auto_pad_proto = helper.get_attribute_by_name(node, "auto_pad")
        if auto_pad_proto is not None:
            print("Cannot split ConvTranspose with auto_pad yet.")
            continue
        # Check output_shape.
        output_shape_proto = helper.get_attribute_by_name(node, "output_shape")
        if output_shape_proto is not None:
            print("Cannot split ConvTranspose with output_shape yet.")
            continue
        # Get the input shape.
        input_value = helper.find_value_by_name(g, node.input[0])
        if input_value is None:
            input_value = helper.find_input_by_name(g, node.input[0])
        if input_value is None:
            print("Cannot get value info named {}.".format(node.input[0]))
            exit(1)
        input_shape = helper.get_shape_from_value_info(input_value)
        # Get attributes.
        attr = deconv_to_conv_info_extraction(input_shape, node)
        # Generate Upsample scales.
        upsample_output_shape = list(input_shape)
        upsample_output_shape[2] = (input_shape[2] - 1) * attr["strides"][0] + 1
        upsample_output_shape[3] = (input_shape[3] - 1) * attr["strides"][1] + 1
        upsample_node_name = node.name + "_inner_upsample"
        upsample_scale_name = upsample_node_name + "_scales"
        scales_np = np.ones([4]).astype('float32')
        scales_np[2] = float(upsample_output_shape[2]) / input_shape[2]
        scales_np[3] = float(upsample_output_shape[3]) / input_shape[3]
        scales_node = helper.numpy_to_constant(upsample_scale_name, scales_np)
        # Generate an Upsample layer and an internal value info.
        upsample_node = onnx.helper.make_node(
            "Upsample",
            [node.input[0], upsample_scale_name],
            [upsample_node_name],
            name=upsample_node_name,
            mode="zeros"
        )
        upsample_value_info = onnx.helper.make_tensor_value_info(
            upsample_node_name,
            input_value.type.tensor_type.elem_type,
            upsample_output_shape
        )
        # Check the weight layer. It may need a transpose.
        if attr["group"] != input_shape[1]:
            weight_node = helper.find_node_by_output_name(g, node.input[1])
            weight_np = helper.constant_to_numpy(weight_node)
            new_weight_np = np.transpose(weight_np, [1, 0, 2, 3])
            new_weight_node = helper.numpy_to_constant(node.input[1],
                                                       new_weight_np)
            node_to_delete.append(weight_node)
            g.node.extend([new_weight_node])
            value = helper.find_value_by_name(g, node.input[1])
            g.value_info.remove(value)
        # Generate a Conv layer.
        conv_node_name = node.name + "_inner_conv"
        conv_node_input = [upsample_node_name]
        conv_node_input.extend(node.input[1:])
        conv_node = onnx.helper.make_node(
            "Conv",
            conv_node_input,
            [node.output[0]],
            name=conv_node_name,
            pads=[int(i) for i in attr["conv_pads"]],
            dilations=[int(i) for i in attr["dilations"]],
            group=int(attr["group"]),
            kernel_shape=[int(i) for i in attr["kernel_shape"]],
            strides=[int(1), int(1)]
        )
        # Reconnect the graph.
        g.node.extend([scales_node, upsample_node, conv_node])
        g.value_info.extend([upsample_value_info])
        node_to_delete.append(node)
    # Delete useless nodes.
    for node in node_to_delete:
        g.node.remove(node)
    topological_sort(g)

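# Note on the split above: a stride-s ConvTranspose is equivalent to
# inserting s - 1 zeros between neighboring input pixels (giving spatial
# size (H - 1) * s + 1, the Upsample output shape computed above) and then
# running a stride-1 Conv; deconv_to_conv_info_extraction is expected to
# supply the matching "conv_pads" for that Conv. The Upsample mode "zeros"
# is a compiler-specific extension, not part of the ONNX specification.
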
def inference_cov_shape(g):
    processed = False
    for node in g.node:
        # Only process Conv nodes whose output shapes need to be inferred.
        if node.op_type != 'Conv':
            continue
        # If the input shape is not ready yet, skip.
        input_value_info = helper.find_value_by_name(g, node.input[0])
        if not input_value_info:
            input_value_info = helper.find_input_by_name(g, node.input[0])
        if not input_value_info:
            continue
        _, input_shape = helper.find_size_shape_from_value(input_value_info)
        if not input_shape:
            continue
        # If the output shape is already there, skip.
        output_value_info = helper.find_value_by_name(g, node.output[0])
        if not output_value_info:
            output_value_info = helper.find_output_by_name(g, node.output[0])
        if output_value_info and \
                helper.get_shape_from_value_info(output_value_info):
            continue
        # Now start the inference.
        # If auto_pad is set, use the auto_pad.
        auto_pad = helper.get_var_attribute_by_name(node, 'auto_pad', 'string')
        pads = None
        if auto_pad is not None and auto_pad != 'NOTSET':
            if auto_pad == 'SAME_LOWER' or auto_pad == 'SAME_UPPER':
                new_output_value_info = onnx.helper.make_tensor_value_info(
                    node.output[0],
                    input_value_info.type.tensor_type.elem_type,
                    input_shape
                )
                if output_value_info:
                    g.value_info.remove(output_value_info)
                g.value_info.extend([new_output_value_info])
                processed = True
                continue
            elif auto_pad == 'VALID':
                pads = [0, 0, 0, 0]
            else:
                print("Unrecognized auto_pad value: " + str(auto_pad))
                exit(1)
        kernel_value_info = helper.find_value_by_name(g, node.input[1])
        _, kernel_shape = helper.find_size_shape_from_value(kernel_value_info)
        if not input_shape or not kernel_shape:
            continue
        strides = helper.get_attribute_by_name(node, 'strides').ints
        if not pads:
            pads = helper.get_attribute_by_name(node, 'pads').ints
        dilation = helper.get_attribute_by_name(node, 'dilations').ints
        # PyTorch models may have strides or dilations with only one number.
        if len(strides) == 1:
            strides.append(strides[0])
        if len(dilation) == 1:
            dilation.append(dilation[0])
        H = math.floor((input_shape[2] + pads[0] + pads[2] -
                        dilation[0] * (kernel_shape[2] - 1) - 1) /
                       strides[0] + 1)
        W = math.floor((input_shape[3] + pads[1] + pads[3] -
                        dilation[1] * (kernel_shape[3] - 1) - 1) /
                       strides[1] + 1)
        output_shape = [input_shape[0], kernel_shape[0], H, W]
        new_output_value_info = onnx.helper.make_tensor_value_info(
            node.output[0],
            input_value_info.type.tensor_type.elem_type,
            output_shape
        )
        processed = True
        if output_value_info:
            g.value_info.remove(output_value_info)
        g.value_info.extend([new_output_value_info])
    return processed

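# The H/W computation above follows the standard ONNX Conv shape rule:
#     out = floor((in + pad_begin + pad_end - dilation * (kernel - 1) - 1)
#                 / stride + 1)
# For example, a 224x224 input with kernel 7, pads [3, 3, 3, 3], stride 2
# and dilation 1 gives floor((224 + 3 + 3 - 6 - 1) / 2 + 1) = 112.
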
def fuse_BN_into_Gemm(g):
    """Fuse the following BN into the previous Gemm.

    :param g: the graph
    """
    node_to_remove = []
    for node in g.node:
        # Check for a BN following a Gemm.
        if node.op_type != 'BatchNormalization':
            continue
        gemm_node = helper.find_node_by_output_name(g, node.input[0])
        if gemm_node is None:
            continue
        if gemm_node.op_type != 'Gemm':
            continue
        if len(
                helper.find_following_nodes_by_input_value_name(
                    g, gemm_node.output[0])) > 1:
            continue
        bn_node = node
        # Get original weights.
        gemm_b_node = helper.find_node_by_output_name(g, gemm_node.input[1])
        gemm_b = helper.constant_to_numpy(gemm_b_node)
        gemm_c_node = helper.find_node_by_output_name(g, gemm_node.input[2])
        gemm_c = helper.constant_to_numpy(gemm_c_node)
        bn_scale_node = helper.find_node_by_output_name(g, bn_node.input[1])
        bn_scale = helper.constant_to_numpy(bn_scale_node)
        bn_bias_node = helper.find_node_by_output_name(g, bn_node.input[2])
        bn_bias = helper.constant_to_numpy(bn_bias_node)
        bn_mean_node = helper.find_node_by_output_name(g, bn_node.input[3])
        bn_mean = helper.constant_to_numpy(bn_mean_node)
        bn_var_node = helper.find_node_by_output_name(g, bn_node.input[4])
        bn_var = helper.constant_to_numpy(bn_var_node)
        # Apply attributes.
        # epsilon
        epsilon = helper.get_attribute_by_name(bn_node, 'epsilon')
        if epsilon is None:
            epsilon = 0.00001
        else:
            epsilon = epsilon.f
        bn_var = bn_var + epsilon
        # alpha
        alpha = helper.get_attribute_by_name(gemm_node, 'alpha')
        if alpha is None:
            alpha = 1
        else:
            alpha = alpha.f
        gemm_b = gemm_b * alpha
        # beta
        beta = helper.get_attribute_by_name(gemm_node, 'beta')
        if beta is None:
            beta = 1
        else:
            beta = beta.f
        gemm_c = gemm_c * beta
        # transA
        transA = helper.get_attribute_by_name(gemm_node, 'transA')
        if transA is not None and transA.i == 1:
            raise RuntimeError("Do not support transA")
        # transB
        transB = helper.get_attribute_by_name(gemm_node, 'transB')
        if transB is not None and transB.i == 1:
            gemm_b = gemm_b.transpose()
        # Calculate new weights.
        new_gemm_b = gemm_b * bn_scale / np.sqrt(bn_var)
        new_gemm_c = (gemm_c - bn_mean) * bn_scale / np.sqrt(bn_var) + bn_bias
        # Replace original weights.
        new_gemm_b_node = helper.numpy_to_constant(gemm_b_node.name + '_fused',
                                                   new_gemm_b)
        new_gemm_c_node = helper.numpy_to_constant(gemm_c_node.name + '_fused',
                                                   new_gemm_c)
        g.node.extend([new_gemm_b_node, new_gemm_c_node])
        node_to_remove.extend([
            gemm_b_node, gemm_c_node, bn_node, bn_scale_node, bn_bias_node,
            bn_mean_node, bn_var_node
        ])
        # Modify attributes.
        # alpha
        alpha = helper.get_attribute_by_name(gemm_node, 'alpha')
        if alpha is not None:
            alpha.f = 1.0
        # beta
        beta = helper.get_attribute_by_name(gemm_node, 'beta')
        if beta is not None:
            beta.f = 1.0
        # transB
        transB = helper.get_attribute_by_name(gemm_node, 'transB')
        if transB is not None:
            transB.i = 0
        # Connect the new graph.
        gemm_node.input[1] = new_gemm_b_node.output[0]
        gemm_node.input[2] = new_gemm_c_node.output[0]
        gemm_b_value = helper.find_value_by_name(g, gemm_b_node.output[0])
        gemm_c_value = helper.find_value_by_name(g, gemm_c_node.output[0])
        gemm_b_value.name = new_gemm_b_node.output[0]
        gemm_c_value.name = new_gemm_c_node.output[0]
        gemm_value = helper.find_value_by_name(g, gemm_node.output[0])
        g.value_info.remove(gemm_value)
        gemm_node.output[0] = bn_node.output[0]
        for i in range(1, 5):
            value = helper.find_value_by_name(g, bn_node.input[i])
            g.value_info.remove(value)
    # Remove useless nodes.
    for node in node_to_remove:
        g.node.remove(node)
    topological_sort(g)

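# The fusion above relies on the following algebra. With the Gemm computing
# y = x.dot(B) + C and the BatchNormalization computing
# z = scale * (y - mean) / sqrt(var + epsilon) + bias, substituting y gives
#     z = x.dot(B * scale / sqrt(var + epsilon))
#         + (C - mean) * scale / sqrt(var + epsilon) + bias
# which is again a single Gemm with the fused weights computed above. Below
# is a minimal numeric sketch of that identity, for illustration only: it is
# not called by any pass and assumes numpy is imported as np, as elsewhere
# in this file. Here var already includes epsilon, matching the code above.
def _check_bn_gemm_fusion_algebra():
    np.random.seed(0)
    x = np.random.randn(2, 3).astype('float32')
    b = np.random.randn(3, 4).astype('float32')
    c = np.random.randn(4).astype('float32')
    scale = np.random.randn(4).astype('float32')
    bias = np.random.randn(4).astype('float32')
    mean = np.random.randn(4).astype('float32')
    var = np.random.rand(4).astype('float32') + 0.00001  # var + epsilon
    # Gemm followed by BN, computed directly.
    direct = scale * (x.dot(b) + c - mean) / np.sqrt(var) + bias
    # One Gemm with fused weights.
    fused_b = b * scale / np.sqrt(var)
    fused_c = (c - mean) * scale / np.sqrt(var) + bias
    assert np.allclose(x.dot(fused_b) + fused_c, direct, atol=1e-5)
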
def fuse_Gemm_into_Gemm(g):
    """Fuse the previous Gemm into the following Gemm.

    :param g: the graph
    """
    node_to_remove = []
    for node in g.node:
        # Check for two consecutive Gemm nodes.
        if node.op_type != 'Gemm':
            continue
        prev_node = helper.find_node_by_output_name(g, node.input[0])
        if prev_node is None:
            continue
        if prev_node.op_type != 'Gemm':
            continue
        # Get original weights.
        prev_b_node = helper.find_node_by_output_name(g, prev_node.input[1])
        prev_b = helper.constant_to_numpy(prev_b_node)
        prev_c_node = helper.find_node_by_output_name(g, prev_node.input[2])
        prev_c = helper.constant_to_numpy(prev_c_node)
        b_node = helper.find_node_by_output_name(g, node.input[1])
        b = helper.constant_to_numpy(b_node)
        c_node = helper.find_node_by_output_name(g, node.input[2])
        c = helper.constant_to_numpy(c_node)
        # Apply attributes.
        # alpha
        alpha = helper.get_attribute_by_name(node, 'alpha')
        if alpha is None:
            alpha = 1
        else:
            alpha = alpha.f
        b = b * alpha
        alpha = helper.get_attribute_by_name(prev_node, 'alpha')
        if alpha is None:
            alpha = 1
        else:
            alpha = alpha.f
        prev_b = prev_b * alpha
        # beta
        beta = helper.get_attribute_by_name(node, 'beta')
        if beta is None:
            beta = 1
        else:
            beta = beta.f
        c = c * beta
        beta = helper.get_attribute_by_name(prev_node, 'beta')
        if beta is None:
            beta = 1
        else:
            beta = beta.f
        prev_c = prev_c * beta
        # transA
        transA = helper.get_attribute_by_name(node, 'transA')
        if transA is not None and transA.i == 1:
            raise RuntimeError("Do not support transA")
        transA = helper.get_attribute_by_name(prev_node, 'transA')
        if transA is not None and transA.i == 1:
            raise RuntimeError("Do not support transA")
        # transB
        transB = helper.get_attribute_by_name(node, 'transB')
        if transB is not None and transB.i == 1:
            b = b.transpose()
        transB = helper.get_attribute_by_name(prev_node, 'transB')
        if transB is not None and transB.i == 1:
            prev_b = prev_b.transpose()
        # Calculate new weights.
        new_b = prev_b.dot(b)
        new_c = prev_c.dot(b) + c
        # Replace original weights.
        new_b_node = helper.numpy_to_constant(b_node.name + '_fused', new_b)
        new_c_node = helper.numpy_to_constant(c_node.name + '_fused', new_c)
        g.node.extend([new_b_node, new_c_node])
        node_to_remove.extend(
            [b_node, c_node, prev_b_node, prev_c_node, prev_node])
        # Modify attributes.
        # alpha
        alpha = helper.get_attribute_by_name(node, 'alpha')
        if alpha is not None:
            alpha.f = 1.0
        # beta
        beta = helper.get_attribute_by_name(node, 'beta')
        if beta is not None:
            beta.f = 1.0
        # transB
        transB = helper.get_attribute_by_name(node, 'transB')
        if transB is not None:
            transB.i = 0
        # Connect the new graph.
        node.input[0] = prev_node.input[0]
        prev_value = helper.find_value_by_name(g, prev_node.output[0])
        g.value_info.remove(prev_value)
        for i in range(1, 3):
            value = helper.find_value_by_name(g, prev_node.input[i])
            g.value_info.remove(value)
            value = helper.find_value_by_name(g, node.input[i])
            g.value_info.remove(value)
        node.input[1] = new_b_node.output[0]
        node.input[2] = new_c_node.output[0]
    # Remove useless nodes.
    for node in node_to_remove:
        g.node.remove(node)
    topological_sort(g)

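# The fusion above uses the identity
#     (x.dot(B1) + C1).dot(B2) + C2 = x.dot(B1.dot(B2)) + (C1.dot(B2) + C2)
# (after alpha, beta and transB have been folded into the weights), so the
# two Gemm layers collapse into one with B = B1.dot(B2) and
# C = C1.dot(B2) + C2, matching new_b and new_c above.
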
def fuse_BN_with_Reshape_into_Gemm(g):
    """Fuse the following BN into the previous Gemm, even with Reshape or \\
    Squeeze and Unsqueeze surrounding.

    :param g: the graph
    """
    node_to_remove = []
    for node in g.node:
        # Check for the pattern: Gemm -> A -> BN -> B
        # Find the BatchNormalization node.
        if node.op_type != 'BatchNormalization':
            continue
        bn_node = node
        # Find node A.
        a_node = helper.find_node_by_output_name(g, node.input[0])
        if a_node is None or len(a_node.input) == 0:
            continue
        # Find the Gemm node.
        gemm_node = helper.find_node_by_output_name(g, a_node.input[0])
        if gemm_node is None or gemm_node.op_type != 'Gemm':
            continue
        # Find node B.
        b_node_list = helper.find_following_nodes_by_input_value_name(
            g, bn_node.output[0])
        if len(b_node_list) == 0:
            the_output = helper.find_output_by_name(g, bn_node.output[0])
            if the_output is None:
                continue
            b_node = None
        elif len(b_node_list) > 1:
            continue
        else:
            b_node = b_node_list[0]
        # Check for branches.
        if len(
                helper.find_following_nodes_by_input_value_name(
                    g, gemm_node.output[0])) > 1:
            continue
        if len(
                helper.find_following_nodes_by_input_value_name(
                    g, a_node.output[0])) > 1:
            continue
        # Check the type of A.
        if a_node.op_type == 'Unsqueeze':
            axes = helper.get_attribute_by_name(a_node, 'axes')
            if axes.ints != [2]:
                continue
        elif a_node.op_type == 'Reshape':
            a = helper.constant_to_list(
                helper.find_node_by_output_name(g, a_node.input[1]))[1]
            if len(a) != 3 or a[2] != 1:
                continue
        else:
            continue
        # Check the type of B.
        if b_node is None:
            pass
        elif b_node.op_type == 'Flatten':
            pass
        elif b_node.op_type == 'Squeeze':
            axes = helper.get_attribute_by_name(b_node, 'axes')
            if axes.ints != [2]:
                continue
        elif b_node.op_type == 'Reshape':
            a = helper.constant_to_list(
                helper.find_node_by_output_name(g, b_node.input[1]))[1]
            if len(a) != 2:
                continue
        else:
            continue
        # Construct new nodes.
        # Get original weights.
        gemm_b_node = helper.find_node_by_output_name(g, gemm_node.input[1])
        gemm_b = helper.constant_to_numpy(gemm_b_node)
        gemm_c_node = helper.find_node_by_output_name(g, gemm_node.input[2])
        gemm_c = helper.constant_to_numpy(gemm_c_node)
        bn_scale_node = helper.find_node_by_output_name(g, bn_node.input[1])
        bn_scale = helper.constant_to_numpy(bn_scale_node)
        bn_bias_node = helper.find_node_by_output_name(g, bn_node.input[2])
        bn_bias = helper.constant_to_numpy(bn_bias_node)
        bn_mean_node = helper.find_node_by_output_name(g, bn_node.input[3])
        bn_mean = helper.constant_to_numpy(bn_mean_node)
        bn_var_node = helper.find_node_by_output_name(g, bn_node.input[4])
        bn_var = helper.constant_to_numpy(bn_var_node)
        # Apply attributes.
        # epsilon
        epsilon = helper.get_attribute_by_name(bn_node, 'epsilon')
        if epsilon is None:
            epsilon = 0.00001
        else:
            epsilon = epsilon.f
        bn_var = bn_var + epsilon
        # alpha
        alpha = helper.get_attribute_by_name(gemm_node, 'alpha')
        if alpha is None:
            alpha = 1
        else:
            alpha = alpha.f
        gemm_b = gemm_b * alpha
        # beta
        beta = helper.get_attribute_by_name(gemm_node, 'beta')
        if beta is None:
            beta = 1
        else:
            beta = beta.f
        gemm_c = gemm_c * beta
        # transA
        transA = helper.get_attribute_by_name(gemm_node, 'transA')
        if transA is not None and transA.i == 1:
            raise RuntimeError("Do not support transA")
        # transB
        transB = helper.get_attribute_by_name(gemm_node, 'transB')
        if transB is not None and transB.i == 1:
            gemm_b = gemm_b.transpose()
        # Calculate new weights.
        new_gemm_b = gemm_b * bn_scale / np.sqrt(bn_var)
        new_gemm_c = (gemm_c - bn_mean) * bn_scale / np.sqrt(bn_var) + bn_bias
        # Replace original weights.
        new_gemm_b_node = helper.numpy_to_constant(gemm_b_node.name + '_fused',
                                                   new_gemm_b)
        new_gemm_c_node = helper.numpy_to_constant(gemm_c_node.name + '_fused',
                                                   new_gemm_c)
        g.node.extend([new_gemm_b_node, new_gemm_c_node])
        # Modify attributes.
        # alpha
        alpha = helper.get_attribute_by_name(gemm_node, 'alpha')
        if alpha is not None:
            alpha.f = 1.0
        # beta
        beta = helper.get_attribute_by_name(gemm_node, 'beta')
        if beta is not None:
            beta.f = 1.0
        # transB
        transB = helper.get_attribute_by_name(gemm_node, 'transB')
        if transB is not None:
            transB.i = 0
        # Remove useless nodes.
        node_to_remove.extend([
            gemm_b_node, gemm_c_node, bn_node, bn_scale_node, bn_bias_node,
            bn_mean_node, bn_var_node, a_node
        ])
        if a_node.op_type == 'Reshape':
            node_to_remove.append(
                helper.find_node_by_output_name(g, a_node.input[1]))
        if b_node is not None:
            node_to_remove.append(b_node)
            if b_node.op_type == 'Reshape':
                node_to_remove.append(
                    helper.find_node_by_output_name(g, b_node.input[1]))
        # Delete useless value infos.
        value = helper.find_value_by_name(g, a_node.output[0])
        g.value_info.remove(value)
        if a_node.op_type == 'Reshape':
            value = helper.find_value_by_name(g, a_node.input[1])
            g.value_info.remove(value)
        for i in range(1, 5):
            value = helper.find_value_by_name(g, bn_node.input[i])
            g.value_info.remove(value)
        value = helper.find_value_by_name(g, bn_node.output[0])
        if value is not None:
            g.value_info.remove(value)
        if b_node is not None:
            value = helper.find_value_by_name(g, gemm_node.output[0])
            g.value_info.remove(value)
            if b_node.op_type == 'Reshape':
                value = helper.find_value_by_name(g, b_node.input[1])
                g.value_info.remove(value)
        # Connect the new graph.
        # Connect the new Gemm weights.
        gemm_node.input[1] = new_gemm_b_node.output[0]
        gemm_node.input[2] = new_gemm_c_node.output[0]
        gemm_b_value = helper.find_value_by_name(g, gemm_b_node.output[0])
        gemm_c_value = helper.find_value_by_name(g, gemm_c_node.output[0])
        gemm_b_value.name = new_gemm_b_node.output[0]
        gemm_b_value.type.tensor_type.shape.dim[0].dim_value = \
            new_gemm_b.shape[0]
        gemm_b_value.type.tensor_type.shape.dim[1].dim_value = \
            new_gemm_b.shape[1]
        gemm_c_value.name = new_gemm_c_node.output[0]
        if b_node is None:
            # If node B is None, set the Gemm output as the graph output.
            output_value = helper.find_output_by_name(g, bn_node.output[0])
            g.output.remove(output_value)
            g.output.extend(
                [helper.find_value_by_name(g, gemm_node.output[0])])
        else:
            # Otherwise, set node B's output as the Gemm output.
            gemm_node.output[0] = b_node.output[0]
    # Remove useless nodes.
    for node in node_to_remove:
        g.node.remove(node)
    topological_sort(g)

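# This variant targets the pattern some exporters produce for BatchNorm
# after a fully-connected layer, where the 2-D Gemm output is expanded to
# 3-D (via Unsqueeze or Reshape) before BatchNormalization and flattened
# back afterwards. The fusion algebra is the same as in fuse_BN_into_Gemm.
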