def inference_resize_shape(g):
    """Generate the output value_info of one Resize node from its constant sizes.

    NOTE(review): a later definition of ``inference_resize_shape`` in this
    module overrides this one at import time.

    Only the 4-input form (X, roi, scales, sizes) is handled, and only when
    ``sizes`` has a non-empty name and is produced by a Constant node. Nodes
    whose output already has a value_info or graph output entry are skipped.
    At most one value_info is added per call.

    :param g: the graph
    :return: True if a value_info was added, otherwise False.
    """
    for node in g.node:
        if node.op_type != 'Resize':
            continue
        output_value = helper.find_value_by_name(g, node.output[0])
        if output_value is None:
            output_value = helper.find_output_by_name(g, node.output[0])
        if output_value is not None:
            continue
        # Currently only the 4-input form is supported. An empty input name
        # marks an omitted optional input, so there is no sizes tensor then.
        if len(node.input) != 4 or not node.input[3]:
            continue
        shape_node = helper.find_node_by_output_name(g, node.input[3])
        if shape_node is None or shape_node.op_type != 'Constant':
            continue
        _, shape_value = helper.constant_to_list(shape_node)
        output_value = onnx.helper.make_tensor_value_info(
            node.output[0], onnx.TensorProto.FLOAT,
            [int(v) for v in shape_value])
        g.value_info.extend([output_value])
        return True
    return False
def replace_Squeeze_with_Reshape(g):
    """Replace every Squeeze node with an equivalent Reshape node.

    The target shape is read from the dims of the Squeeze output's value_info
    (or graph output entry) and stored in a new Constant node named
    ``<squeeze name>_shape``. The Reshape keeps the Squeeze node's name, its
    first input and its outputs. The graph is topologically sorted at the end.

    :param g: the input graph
    :raises RuntimeError: if a Squeeze output has neither a value_info nor a
        graph output entry to read the shape from. The graph is left
        unmodified in that case.
    """
    node_to_remove = []
    node_to_add = []
    for node in g.node:
        if node.op_type != 'Squeeze':
            continue
        # Get the target shape from the output description.
        output_value = helper.find_value_by_name(g, node.output[0])
        if output_value is None:
            output_value = helper.find_output_by_name(g, node.output[0])
        if output_value is None:
            raise RuntimeError("Cannot get shape for Squeeze")
        shape = [dim.dim_value
                 for dim in output_value.type.tensor_type.shape.dim]
        const_node = helper.list_to_constant(node.name + "_shape",
                                             [len(shape)], shape)
        # Reshape with the same input, outputs and name as the Squeeze.
        new_node = onnx.helper.make_node(
            "Reshape", [node.input[0], node.name + "_shape"],
            node.output,
            name=node.name)
        node_to_add.extend([const_node, new_node])
        node_to_remove.append(node)
    # Mutate g.node only after the scan, so the loop never walks over nodes it
    # appended and an error above leaves the graph untouched.
    g.node.extend(node_to_add)
    for node in node_to_remove:
        g.node.remove(node)
    topological_sort(g)
def add_nop_bn_after(g, value_names):
    """Add do-nothing BatchNormalization nodes after the given values.

    For each name, a BatchNormalization node named ``<value_name>_nop_bn`` is
    created with Constant inputs scale=1, bias=0, mean=0 and var=1, and takes
    the value as its input. Every node that consumed the value is rewired to
    consume the BN output instead. If no node consumes the value, the graph
    output with that name is replaced by the BN output, at the same position.
    The graph is topologically sorted at the end.

    Names that cannot be found, or whose shape has no usable channel
    dimension (axis 1), are reported with ``print`` and skipped.

    :param g: the graph
    :param value_names: a list of strings which are the names of value_info,
        graph inputs or graph outputs.
    """
    for value_name in value_names:
        # Find the value first
        value = helper.find_value_by_name(g, value_name)
        if value is None:
            value = helper.find_input_by_name(g, value_name)
        if value is None:
            value = helper.find_output_by_name(g, value_name)
        if value is None:
            print("Cannot find an value_info named {}".format(value_name))
            continue
        # The BN constants are sized by axis 1, so that dim must exist and be
        # positive. An unknown dim presumably reads as 0 — TODO confirm with
        # helper.get_shape_from_value_info.
        shape = helper.get_shape_from_value_info(value)
        if not shape or len(shape) < 2 or shape[1] <= 0:
            print("Cannot add nop BN after {}: no channel dimension in "
                  "shape {}".format(value_name, shape))
            continue
        channel = shape[1]
        # Construct the 4 parameter constants
        node_name = value_name + "_nop_bn"
        ones = [1.0] * channel
        zeros = [0.0] * channel
        scale_node = helper.list_to_constant(node_name + "_scale", [channel],
                                             ones)
        bias_node = helper.list_to_constant(node_name + "_bias", [channel],
                                            zeros)
        mean_node = helper.list_to_constant(node_name + "_mean", [channel],
                                            zeros)
        var_node = helper.list_to_constant(node_name + "_var", [channel],
                                           ones)
        # Construct BN node
        bn_node = onnx.helper.make_node("BatchNormalization", [
            value_name, scale_node.output[0], bias_node.output[0],
            mean_node.output[0], var_node.output[0]
        ], [node_name],
                                        name=node_name)
        # Reconnect the graph
        following_nodes = helper.find_following_nodes_by_input_value_name(
            g, value_name)
        if len(following_nodes) > 0:
            for following_node in following_nodes:
                replace_node_input(following_node, value_name, node_name)
        else:
            # Nothing consumes the value: swap the graph output for the BN
            # output. Repeated message fields do not support item assignment,
            # so drain g.output and refill it in the original order.
            new_value = onnx.helper.make_tensor_value_info(
                node_name, value.type.tensor_type.elem_type, shape)
            output_values = []
            while len(g.output):
                output_values.append(g.output.pop())
            while output_values:
                output_value = output_values.pop()
                if output_value.name == value_name:
                    g.output.extend([new_value])
                else:
                    g.output.extend([output_value])
        # Add node to the graph
        g.node.extend([bn_node, scale_node, bias_node, mean_node, var_node])
    topological_sort(g)
def rename_all_node_name(g):
    """Append ``"_kn"`` to the name and the first output of every node.

    Nodes whose first output is a graph output are skipped entirely (neither
    the node name nor the output is renamed), so graph output names stay the
    same. For every renamed output, the inputs of the consuming nodes and
    the matching value_info entry are renamed too.

    :param g: the onnx graph
    """
    for node in g.node:
        old_output_name = node.output[0]
        # In order to keep the graph output names, skip output nodes.
        if helper.find_output_by_name(g, old_output_name) is not None:
            continue
        new_output_name = old_output_name + "_kn"
        # Rename the input of all the following nodes
        for following_node in helper.find_following_nodes_by_input_value_name(
                g, old_output_name):
            replace_node_input(following_node, old_output_name,
                               new_output_name)
        # Rename value info
        value_info = helper.find_value_by_name(g, old_output_name)
        if value_info is not None:
            value_info.name = new_output_name
        # Rename the node itself
        node.output[0] = new_output_name
        node.name = node.name + "_kn"
def find_first_sequential_outputs(g, node):
    """Walk downstream from ``node`` and return the first graph output reached.

    At each step, the current node's outputs are checked against the graph
    outputs. If none of them is a graph output, the walk moves to the first
    node that consumes the current node's first output. The walk is
    iterative, so long chains cannot hit Python's recursion limit.

    :param g: the graph
    :param node: the node to start from
    :return: the graph output value, or None when the walk reaches a node
        that has no outputs or whose first output has no consumer.
    """
    while True:
        for value_name in node.output:
            value = helper.find_output_by_name(g, value_name)
            if value is not None:
                return value
        if not node.output:
            return None
        consumers = helper.find_nodes_by_input_name(g, node.output[0])
        if not consumers:
            # Dead end: nothing downstream of this node is a graph output.
            return None
        node = consumers[0]
def add_nop_conv_after(g, value_names):
    """Add do-nothing depthwise Conv nodes after the given values.

    For each name, a Conv node named ``<value_name>_nop_conv`` is created. It
    has ``group`` = channel count, a 1x1 kernel, stride 1, no padding, and an
    all-ones weight constant of shape ``[channel, 1, 1, 1]``. It takes the
    value as input. Every node that consumed the value is rewired to consume
    the Conv output instead. If no node consumes the value, the graph output
    with that name is replaced by the Conv output, at the same position. The
    graph is topologically sorted at the end.

    Names that cannot be found, or whose shape is not 4D with a positive
    channel dimension (axis 1), are reported with ``print`` and skipped.

    :param g: the graph
    :param value_names: a list of strings which are the names of value_info,
        graph inputs or graph outputs.
    """
    for value_name in value_names:
        # Find the value first
        value = helper.find_value_by_name(g, value_name)
        if value is None:
            value = helper.find_input_by_name(g, value_name)
        if value is None:
            value = helper.find_output_by_name(g, value_name)
        if value is None:
            print("Cannot find an value_info named {}".format(value_name))
            continue
        # The Conv below uses 2D kernel/stride/pad attributes, so it only fits
        # a 4D input. The channel count sizes the weight and sets ``group``.
        shape = helper.get_shape_from_value_info(value)
        if not shape or len(shape) != 4 or shape[1] <= 0:
            print("Cannot add nop Conv after {}: expect a 4D shape with a "
                  "known channel, got {}".format(value_name, shape))
            continue
        channel = shape[1]
        # Construct the single all-ones weight
        node_name = value_name + "_nop_conv"
        ones = [1.0] * channel
        weight_node = helper.list_to_constant(node_name + "_weight",
                                              [channel, 1, 1, 1], ones)
        # Construct the depthwise Conv node
        conv_node = onnx.helper.make_node("Conv",
                                          [value_name, weight_node.output[0]],
                                          [node_name],
                                          name=node_name,
                                          dilations=[1, 1],
                                          group=channel,
                                          kernel_shape=[1, 1],
                                          pads=[0, 0, 0, 0],
                                          strides=[1, 1])
        # Reconnect the graph
        following_nodes = helper.find_following_nodes_by_input_value_name(
            g, value_name)
        if len(following_nodes) > 0:
            for following_node in following_nodes:
                replace_node_input(following_node, value_name, node_name)
        else:
            # Nothing consumes the value: swap the graph output for the Conv
            # output. Repeated message fields do not support item assignment,
            # so drain g.output and refill it in the original order.
            new_value = onnx.helper.make_tensor_value_info(
                node_name, value.type.tensor_type.elem_type, shape)
            output_values = []
            while len(g.output):
                output_values.append(g.output.pop())
            while output_values:
                output_value = output_values.pop()
                if output_value.name == value_name:
                    g.output.extend([new_value])
                else:
                    g.output.extend([output_value])
        # Add node to the graph
        g.node.extend([conv_node, weight_node])
    topological_sort(g)
def inference_cov_shape(g):
    """Infer missing output shapes of Conv nodes with 4D input.

    NOTE(review): a later definition of ``inference_cov_shape`` in this
    module overrides this one at import time.

    For every Conv whose input and kernel shapes are known but whose output
    has no shape yet, the output shape is
    ``[input[0], kernel[0], H_out, W_out]``. The spatial dims use the
    explicit-padding formula with ``pads``, ``strides`` and ``dilations``.
    Absent attributes fall back to the ONNX defaults (pads 0, strides 1,
    dilations 1). A one-element ``strides``/``dilations`` attribute (seen in
    PyTorch models) is expanded in place to two elements. An existing
    shapeless value_info or graph output entry is overwritten in place;
    otherwise a new value_info is appended.

    :param g: the graph
    :return: True if at least one output shape was written, otherwise False.
    """
    processed = False
    for node in g.node:
        if node.op_type != 'Conv':
            continue
        input_value_info = helper.find_value_by_name(g, node.input[0])
        if input_value_info is None:
            input_value_info = helper.find_input_by_name(g, node.input[0])
        if input_value_info is None:
            continue
        kernel_value_info = helper.find_value_by_name(g, node.input[1])
        output_value_info = helper.find_value_by_name(g, node.output[0])
        if output_value_info is None:
            output_value_info = helper.find_output_by_name(g, node.output[0])
        if output_value_info is not None and \
                helper.get_shape_from_value_info(output_value_info):
            continue
        _, kernel_shape = helper.find_size_shape_from_value(kernel_value_info)
        _, input_shape = helper.find_size_shape_from_value(input_value_info)
        if not input_shape or not kernel_shape:
            continue
        strides_attr = helper.get_attribute_by_name(node, 'strides')
        pads_attr = helper.get_attribute_by_name(node, 'pads')
        dilations_attr = helper.get_attribute_by_name(node, 'dilations')
        strides = strides_attr.ints if strides_attr is not None else [1, 1]
        pads = pads_attr.ints if pads_attr is not None else [0, 0, 0, 0]
        dilation = (dilations_attr.ints
                    if dilations_attr is not None else [1, 1])
        # PyTorch models may give a single number for both spatial axes.
        # Expand it (on the node attribute itself when present) and carry on;
        # returning here would abort the whole pass.
        if len(strides) == 1:
            strides.append(strides[0])
        if len(dilation) == 1:
            dilation.append(dilation[0])
        H = math.floor((input_shape[2] + pads[0] + pads[2] -
                        dilation[0] * (kernel_shape[2] - 1) - 1) /
                       strides[0] + 1)
        W = math.floor((input_shape[3] + pads[1] + pads[3] -
                        dilation[1] * (kernel_shape[3] - 1) - 1) /
                       strides[1] + 1)
        output_shape = [input_shape[0], kernel_shape[0], H, W]
        new_output_value_info = onnx.helper.make_tensor_value_info(
            node.output[0], input_value_info.type.tensor_type.elem_type,
            output_shape)
        if output_value_info is None:
            g.value_info.extend([new_output_value_info])
        else:
            # Overwrite in place: the stale entry may live in g.output rather
            # than g.value_info, so removing it from g.value_info can fail.
            output_value_info.CopyFrom(new_output_value_info)
        processed = True
    return processed
def rename_output_name(g, original_name, new_name):
    """Rename a graph output and every reference to it.

    Renames the graph output entry, the matching value_info (if any), the
    producing node's output slot that holds ``original_name`` (if a producer
    exists), and the inputs of every consuming node. Logs an error and
    returns without changes if no graph output is named ``original_name``.

    :param g: the graph
    :param original_name: current name of the graph output
    :param new_name: name to give it
    """
    # Output
    output_value = helper.find_output_by_name(g, original_name)
    if output_value is None:
        logging.error("Cannot find output value named %s", original_name)
        return
    output_value.name = new_name
    # Value Info
    value_info = helper.find_value_by_name(g, original_name)
    if value_info is not None:
        value_info.name = new_name
    # Node output: rename the slot that actually carries the value, which is
    # not necessarily output[0] for multi-output nodes.
    producer = helper.find_node_by_output_name(g, original_name)
    if producer is not None:
        for index, name in enumerate(producer.output):
            if name == original_name:
                producer.output[index] = new_name
    # Node input
    for node in helper.find_nodes_by_input_name(g, original_name):
        replace_node_input(node, original_name, new_name)
def inference_upsample_shape(g):
    """Generate the output shape of the next Upsample node that lacks one.

    For onnx v1.4.1+, onnx cannot infer the Upsample output shape, so it is
    done here. The output shape is ``int(input_dim * scale)`` per axis. The
    scales come from a Constant second input, or from the ``scales``
    attribute when the node has a single input (the Upsample-7 form). The
    input may be a value_info or a graph input. Nodes whose input shape or
    constant scales are unavailable are skipped. An existing shapeless
    value_info or graph output entry is overwritten in place; otherwise a
    new value_info is appended. Only one node is handled per call.

    :param g: the graph
    :return: True if any Upsample shape is generated. Otherwise, False.
    :raises RuntimeError: if the number of scales differs from the input rank.
    """
    for node in g.node:
        if node.op_type != 'Upsample':
            continue
        output_value = helper.find_value_by_name(g, node.output[0])
        if output_value is None:
            output_value = helper.find_output_by_name(g, node.output[0])
        if output_value is not None and \
                helper.get_shape_from_value_info(output_value):
            continue
        # Get input shape (value_info or graph input)
        input_value = helper.find_value_by_name(g, node.input[0])
        if input_value is None:
            input_value = helper.find_input_by_name(g, node.input[0])
        if input_value is None:
            continue
        input_shape = helper.get_shape_from_value_info(input_value)
        if not input_shape:
            continue
        # Get upsample scales
        if len(node.input) > 1:
            weight_node = helper.find_node_by_output_name(g, node.input[1])
            if weight_node is None or weight_node.op_type != 'Constant':
                continue
            weight_shape, weight = helper.constant_to_list(weight_node)
        else:
            scales_attr = helper.get_attribute_by_name(node, 'scales')
            if scales_attr is None:
                continue
            weight = list(scales_attr.floats)
            weight_shape = [len(weight)]
        if len(input_shape) != weight_shape[0]:
            raise RuntimeError(
                "Unmatch input shape and weight shape: {} vs {}".format(
                    input_shape, weight_shape))
        # Calculate shape
        output_shape = [int(dim * scale)
                        for dim, scale in zip(input_shape, weight)]
        new_output_value = onnx.helper.make_tensor_value_info(
            node.output[0], input_value.type.tensor_type.elem_type,
            output_shape)
        if output_value is None:
            g.value_info.extend([new_output_value])
        else:
            # Update in place. Appending a second entry would leave the
            # shapeless one to be found first on the next call, so the same
            # node would be re-processed forever.
            output_value.CopyFrom(new_output_value)
        return True
    return False
def change_output_shape(g, target_list):
    """Overwrite the static dims of graph outputs.

    Each target is parsed with ``parse_shape_change_input`` into a name and
    a list of ints. The graph output with that name gets its dims set to
    those ints. Targets that cannot be parsed, name no graph output, or have
    a different rank are reported with ``print`` (except when the parser
    returns None) and skipped.

    :param g: the graph
    :param target_list: the unparsed shape-change specifications
    """
    for target in target_list:
        # Only the parse step is guarded; errors elsewhere are real bugs and
        # must not be swallowed.
        try:
            name, shape = parse_shape_change_input(target)
        except TypeError:
            # This happens when the parser function returns None.
            continue
        except ValueError:
            # This happens when the input cannot be converted into int.
            print("Cannot parse {} into name and int".format(target))
            continue
        output_value = helper.find_output_by_name(g, name)
        if output_value is None:
            print("Cannot find output {}".format(name))
            continue
        dims = output_value.type.tensor_type.shape.dim
        if len(shape) != len(dims):
            print("The dimension doesn't match for output {}".format(name))
            continue
        for dim, size in zip(dims, shape):
            dim.dim_value = size
def inference_resize_shape(g):
    """Generate the output value_info for the first Resize node that has none.

    How the output shape is obtained:

    * ``sizes`` case: when the node has a 4th input with a non-empty name and
      it comes from a Constant node, that constant is the output shape.
    * ``scales`` case: otherwise the output shape is ``int(dim * scale)`` per
      axis. The scales are the 2nd input for the 2-input (opset-10) form and
      the 3rd input otherwise, and must come from a Constant node.

    The output element type is the input's when it is known, else FLOAT.
    Nodes whose output already has a value_info or graph output entry, or
    whose sizes/scales/input shape are unavailable, are skipped. At most one
    value_info is added per call.

    :param g: the graph
    :return: True if a value_info was added, otherwise False.
    """
    for node in g.node:
        if node.op_type != 'Resize':
            continue
        output_value = helper.find_value_by_name(g, node.output[0])
        if output_value is None:
            output_value = helper.find_output_by_name(g, node.output[0])
        if output_value is not None:
            continue
        input_value = helper.find_value_by_name(g, node.input[0])
        if input_value is None:
            input_value = helper.find_input_by_name(g, node.input[0])
        # Resize outputs the same element type as its input.
        elem_type = onnx.TensorProto.FLOAT
        if input_value is not None and input_value.type.tensor_type.elem_type:
            elem_type = input_value.type.tensor_type.elem_type
        # An empty input name marks an omitted optional input.
        if len(node.input) == 4 and node.input[3]:
            # input: X, roi, scales, sizes
            sizes_node = helper.find_node_by_output_name(g, node.input[3])
            if sizes_node is None or sizes_node.op_type != 'Constant':
                continue
            _, sizes = helper.constant_to_list(sizes_node)
            output_shape = [int(v) for v in sizes]
        else:
            # No sizes given: derive from the input shape and scales.
            if input_value is None:
                continue
            input_shape = helper.get_shape_from_value_info(input_value)
            scales_index = 1 if len(node.input) == 2 else 2
            if len(node.input) <= scales_index or not node.input[scales_index]:
                continue
            scales_node = helper.find_node_by_output_name(
                g, node.input[scales_index])
            if scales_node is None or scales_node.op_type != 'Constant':
                continue
            _, scales = helper.constant_to_list(scales_node)
            if not input_shape or len(scales) != len(input_shape):
                continue
            output_shape = [int(dim * scale)
                            for dim, scale in zip(input_shape, scales)]
        output_value = onnx.helper.make_tensor_value_info(
            node.output[0], elem_type, output_shape)
        g.value_info.extend([output_value])
        return True
    return False
def remove_nodes(g, cut_nodes=None, cut_types=None):
    """Remove nodes by name or op type and prune everything left dangling.

    NOTE(review): a later definition of ``remove_nodes`` in this module
    overrides this one at import time.

    Steps:

    1. Delete nodes whose name is in ``cut_nodes`` or whose op type is in
       ``cut_types``.
    2. Delete non-Constant nodes that can no longer be reached from the graph
       inputs, initializers and Constant nodes. Empty input names mark
       omitted optional inputs and count as reachable.
    3. Delete Constant nodes that nothing consumes, and value_info entries
       that no remaining node or graph input refers to.
    4. Rebuild ``g.output`` in its original order:
       * reachable outputs are kept;
       * an output that a deleted node led to (found with
         ``find_first_sequential_outputs``) is replaced by that node's
         reachable inputs which have value_info and were not already graph
         outputs;
       * other outputs are dropped.

    :param g: the graph
    :param cut_nodes: names of nodes to delete (default: none).
    :param cut_types: op types of nodes to delete (default: none).
    """
    if cut_nodes is None:
        cut_nodes = []
    if cut_types is None:
        cut_types = []
    output_mapping = {}
    new_output = set()

    def map_outputs(nodes):
        # Record, per downstream graph output, which inputs of the nodes about
        # to be deleted could stand in for it.
        for node in nodes:
            original_output = find_first_sequential_outputs(g, node)
            if original_output is None:
                continue
            candidates = output_mapping.setdefault(original_output.name, [])
            for input_name in node.input:
                value = helper.find_value_by_name(g, input_name)
                if (value is not None
                        and helper.find_output_by_name(g, input_name) is None
                        and value.name not in new_output):
                    candidates.append(value)
                    new_output.add(value.name)

    # Find and remove target nodes
    targets = [node for node in g.node
               if node.name in cut_nodes or node.op_type in cut_types]
    map_outputs(targets)
    for node in targets:
        g.node.remove(node)

    # Remove unreachable nodes
    visited_values = {value.name for value in g.input}
    visited_values.update(init.name for init in g.initializer)
    unused_constant_map = {}
    unreachable = []
    for node in g.node:
        if node.op_type == 'Constant':
            visited_values.add(node.output[0])
            unused_constant_map[node.output[0]] = node
            continue
        if all(not name or name in visited_values for name in node.input):
            visited_values.update(node.output)
        else:
            unreachable.append(node)
    map_outputs(unreachable)
    for node in unreachable:
        g.node.remove(node)

    # Remove unused constants
    for node in g.node:
        for input_name in node.input:
            unused_constant_map.pop(input_name, None)
    for node in unused_constant_map.values():
        g.node.remove(node)

    # Remove unreachable value infos
    reachable_values = {value.name for value in g.input}
    for node in g.node:
        reachable_values.update(node.input)
        reachable_values.update(node.output)
    stale_values = [value_info for value_info in g.value_info
                    if value_info.name not in reachable_values]
    for value_info in stale_values:
        g.value_info.remove(value_info)

    # Reorder output. Repeated message fields do not support item
    # assignment, so drain g.output and refill it in the original order.
    output_values = []
    while len(g.output):
        output_values.append(g.output.pop())
    output_values.reverse()
    for output_value in output_values:
        if output_value.name in reachable_values:
            logging.info("Keep output %s", output_value.name)
            g.output.extend([output_value])
        elif output_value.name in output_mapping:
            real_outputs = [value for value in output_mapping[output_value.name]
                            if value.name in reachable_values]
            logging.info("Replace output %s with %s", output_value.name,
                         [value.name for value in real_outputs])
            g.output.extend(real_outputs)
        else:
            logging.info("Abandon output %s", output_value.name)
def inference_cov_shape(g):
    """Infer missing output shapes of Conv nodes with 4D input.

    For every Conv whose input and kernel shapes are known but whose output
    has no shape yet, the output shape is
    ``[input[0], kernel[0], H_out, W_out]`` with:

    * ``auto_pad`` SAME_UPPER / SAME_LOWER: ``ceil(in / stride)`` per spatial
      axis, as the ONNX Conv spec defines;
    * ``auto_pad`` VALID: the explicit-padding formula with zero pads;
    * otherwise: the explicit-padding formula with the ``pads`` attribute.

    Absent ``strides``/``pads``/``dilations`` attributes fall back to the
    ONNX defaults (1 / 0 / 1). A one-element ``strides`` or ``dilations``
    attribute (seen in PyTorch models) is expanded in place to two elements.
    An existing shapeless value_info or graph output entry is overwritten in
    place; otherwise a new value_info is appended.

    :param g: the graph
    :return: True if at least one output shape was written, otherwise False.
    :raises RuntimeError: on an unrecognized ``auto_pad`` value.
    """
    processed = False
    for node in g.node:
        # Check for Conv output shape need to be inferred.
        if node.op_type != 'Conv':
            continue
        # Input shape is not ready yet. Skip.
        input_value_info = helper.find_value_by_name(g, node.input[0])
        if input_value_info is None:
            input_value_info = helper.find_input_by_name(g, node.input[0])
        if input_value_info is None:
            continue
        _, input_shape = helper.find_size_shape_from_value(input_value_info)
        if not input_shape:
            continue
        # Output shape is already there. Skip.
        output_value_info = helper.find_value_by_name(g, node.output[0])
        if output_value_info is None:
            output_value_info = helper.find_output_by_name(g, node.output[0])
        if output_value_info is not None and \
                helper.get_shape_from_value_info(output_value_info):
            continue
        # The kernel gives the output channel count, even for SAME padding.
        kernel_value_info = helper.find_value_by_name(g, node.input[1])
        _, kernel_shape = helper.find_size_shape_from_value(kernel_value_info)
        if not kernel_shape:
            continue
        strides_attr = helper.get_attribute_by_name(node, 'strides')
        dilations_attr = helper.get_attribute_by_name(node, 'dilations')
        strides = strides_attr.ints if strides_attr is not None else [1, 1]
        dilation = (dilations_attr.ints
                    if dilations_attr is not None else [1, 1])
        # PyTorch models may give a single number for both spatial axes.
        # Expand it (on the node attribute itself when present) and carry on;
        # returning here would abort the whole pass.
        if len(strides) == 1:
            strides.append(strides[0])
        if len(dilation) == 1:
            dilation.append(dilation[0])
        auto_pad = helper.get_var_attribute_by_name(node, 'auto_pad', 'string')
        if auto_pad in ('SAME_LOWER', 'SAME_UPPER'):
            H = math.ceil(input_shape[2] / strides[0])
            W = math.ceil(input_shape[3] / strides[1])
        else:
            if auto_pad is None or auto_pad == 'NOTSET':
                pads_attr = helper.get_attribute_by_name(node, 'pads')
                pads = (pads_attr.ints
                        if pads_attr is not None else [0, 0, 0, 0])
            elif auto_pad == 'VALID':
                pads = [0, 0, 0, 0]
            else:
                raise RuntimeError(
                    "Unrecognized auto_pad value: " + str(auto_pad))
            H = math.floor((input_shape[2] + pads[0] + pads[2] -
                            dilation[0] * (kernel_shape[2] - 1) - 1) /
                           strides[0] + 1)
            W = math.floor((input_shape[3] + pads[1] + pads[3] -
                            dilation[1] * (kernel_shape[3] - 1) - 1) /
                           strides[1] + 1)
        output_shape = [input_shape[0], kernel_shape[0], H, W]
        new_output_value_info = onnx.helper.make_tensor_value_info(
            node.output[0], input_value_info.type.tensor_type.elem_type,
            output_shape)
        if output_value_info is None:
            g.value_info.extend([new_output_value_info])
        else:
            # Overwrite in place: the stale entry may live in g.output rather
            # than g.value_info, so removing it from g.value_info can fail.
            output_value_info.CopyFrom(new_output_value_info)
        processed = True
    return processed
def fuse_BN_with_Reshape_into_Gemm(g):
    """Fuse the following BN into the previous Gemm, even with Reshape or
    Squeeze and Unsqueeze surrounding.

    Matched pattern: ``Gemm -> A -> BatchNormalization -> B``, where

    * A is ``Unsqueeze`` with ``axes == [2]``, or a ``Reshape`` to a constant
      3-element shape whose last entry is 1;
    * B is absent (the BN output is a graph output), ``Flatten``, ``Squeeze``
      with ``axes == [2]``, or a ``Reshape`` to a constant 2-element shape;
    * the Gemm, A and BN outputs have no other consumers.

    The Gemm's alpha/beta/transB and the BN scale/bias/mean/var/epsilon are
    folded into new ``*_fused`` B and C constants. alpha and beta are reset
    to 1 and transB to 0. A, BN, B, the old constants and their value_infos
    are removed, and the graph is topologically sorted. Patterns whose
    weights or reshape targets are not produced by Constant nodes are
    skipped.

    :param g: the graph
    :raises RuntimeError: if a matched Gemm has ``transA == 1``.
    """
    node_to_remove = []
    node_to_add = []

    def drop_value_info(name):
        # Remove the value_info entry for ``name`` when one exists.
        value = helper.find_value_by_name(g, name)
        if value is not None:
            g.value_info.remove(value)

    def has_axes(n, expected):
        # True when node ``n`` has an ``axes`` attribute equal to ``expected``.
        axes = helper.get_attribute_by_name(n, 'axes')
        return axes is not None and list(axes.ints) == expected

    def reshape_target(n):
        # Constant target shape of Reshape node ``n``, or None if unknown.
        if len(n.input) < 2:
            return None
        shape_node = helper.find_node_by_output_name(g, n.input[1])
        if shape_node is None or shape_node.op_type != 'Constant':
            return None
        return helper.constant_to_list(shape_node)[1]

    for node in g.node:
        # Check for BN and Gemm pattern: Gemm A BN B
        # Find BatchNorm Node
        if node.op_type != 'BatchNormalization':
            continue
        bn_node = node
        # Find A Node
        a_node = helper.find_node_by_output_name(g, bn_node.input[0])
        if a_node is None or len(a_node.input) == 0:
            continue
        # Find Gemm Node
        gemm_node = helper.find_node_by_output_name(g, a_node.input[0])
        if gemm_node is None or gemm_node.op_type != 'Gemm':
            continue
        # Find B Node. No consumer is only fine if BN feeds a graph output.
        b_node_list = helper.find_following_nodes_by_input_value_name(
            g, bn_node.output[0])
        if len(b_node_list) == 0:
            if helper.find_output_by_name(g, bn_node.output[0]) is None:
                continue
            b_node = None
        elif len(b_node_list) > 1:
            continue
        else:
            b_node = b_node_list[0]
        # Check for branches
        if len(helper.find_following_nodes_by_input_value_name(
                g, gemm_node.output[0])) > 1:
            continue
        if len(helper.find_following_nodes_by_input_value_name(
                g, a_node.output[0])) > 1:
            continue
        # Check type of A
        if a_node.op_type == 'Unsqueeze':
            if not has_axes(a_node, [2]):
                continue
        elif a_node.op_type == 'Reshape':
            a = reshape_target(a_node)
            if a is None or len(a) != 3 or a[2] != 1:
                continue
        else:
            continue
        # Check type of B
        if b_node is None or b_node.op_type == 'Flatten':
            pass
        elif b_node.op_type == 'Squeeze':
            # Check B's own axes (this used to read A's by mistake).
            if not has_axes(b_node, [2]):
                continue
        elif b_node.op_type == 'Reshape':
            b = reshape_target(b_node)
            if b is None or len(b) != 2:
                continue
        else:
            continue
        # Get original weights. Gemm's C is needed to absorb the BN shift.
        # All lookups happen before any mutation, so skipping is safe.
        if len(gemm_node.input) < 3 or len(bn_node.input) < 5:
            continue
        gemm_b_node = helper.find_node_by_output_name(g, gemm_node.input[1])
        gemm_c_node = helper.find_node_by_output_name(g, gemm_node.input[2])
        bn_scale_node = helper.find_node_by_output_name(g, bn_node.input[1])
        bn_bias_node = helper.find_node_by_output_name(g, bn_node.input[2])
        bn_mean_node = helper.find_node_by_output_name(g, bn_node.input[3])
        bn_var_node = helper.find_node_by_output_name(g, bn_node.input[4])
        weight_nodes = [gemm_b_node, gemm_c_node, bn_scale_node,
                        bn_bias_node, bn_mean_node, bn_var_node]
        if any(n is None for n in weight_nodes):
            continue
        gemm_b = helper.constant_to_numpy(gemm_b_node)
        gemm_c = helper.constant_to_numpy(gemm_c_node)
        bn_scale = helper.constant_to_numpy(bn_scale_node)
        bn_bias = helper.constant_to_numpy(bn_bias_node)
        bn_mean = helper.constant_to_numpy(bn_mean_node)
        bn_var = helper.constant_to_numpy(bn_var_node)
        # Apply attributes
        # epsilon
        epsilon = helper.get_attribute_by_name(bn_node, 'epsilon')
        epsilon = 0.00001 if epsilon is None else epsilon.f
        bn_var = bn_var + epsilon
        # alpha
        alpha_attr = helper.get_attribute_by_name(gemm_node, 'alpha')
        alpha = 1 if alpha_attr is None else alpha_attr.f
        gemm_b = gemm_b * alpha
        # beta
        beta_attr = helper.get_attribute_by_name(gemm_node, 'beta')
        beta = 1 if beta_attr is None else beta_attr.f
        gemm_c = gemm_c * beta
        # transA
        transA = helper.get_attribute_by_name(gemm_node, 'transA')
        if transA is not None and transA.i == 1:
            raise RuntimeError("Do not support transA")
        # transB
        transB = helper.get_attribute_by_name(gemm_node, 'transB')
        if transB is not None and transB.i == 1:
            gemm_b = gemm_b.transpose()
        # Calculate new weights
        new_gemm_b = gemm_b * bn_scale / np.sqrt(bn_var)
        new_gemm_c = (gemm_c - bn_mean) * bn_scale / np.sqrt(bn_var) + bn_bias
        # Replace original weights (added to g.node after the scan)
        new_gemm_b_node = helper.numpy_to_constant(gemm_b_node.name + '_fused',
                                                   new_gemm_b)
        new_gemm_c_node = helper.numpy_to_constant(gemm_c_node.name + '_fused',
                                                   new_gemm_c)
        node_to_add.extend([new_gemm_b_node, new_gemm_c_node])
        # Modify attributes: the folded weights already include them.
        if alpha_attr is not None:
            alpha_attr.f = 1.0
        if beta_attr is not None:
            beta_attr.f = 1.0
        if transB is not None:
            transB.i = 0
        # Remove useless nodes
        node_to_remove.extend([
            gemm_b_node, gemm_c_node, bn_node, bn_scale_node, bn_bias_node,
            bn_mean_node, bn_var_node, a_node
        ])
        if a_node.op_type == 'Reshape':
            node_to_remove.append(
                helper.find_node_by_output_name(g, a_node.input[1]))
        if b_node is not None:
            node_to_remove.append(b_node)
            if b_node.op_type == 'Reshape':
                node_to_remove.append(
                    helper.find_node_by_output_name(g, b_node.input[1]))
        # Delete useless value infos
        drop_value_info(a_node.output[0])
        if a_node.op_type == 'Reshape':
            drop_value_info(a_node.input[1])
        for name in bn_node.input[1:5]:
            drop_value_info(name)
        drop_value_info(bn_node.output[0])
        if b_node is not None:
            drop_value_info(gemm_node.output[0])
            if b_node.op_type == 'Reshape':
                drop_value_info(b_node.input[1])
        # Connect Gemm new weights
        gemm_node.input[1] = new_gemm_b_node.output[0]
        gemm_node.input[2] = new_gemm_c_node.output[0]
        gemm_b_value = helper.find_value_by_name(g, gemm_b_node.output[0])
        if gemm_b_value is not None:
            gemm_b_value.name = new_gemm_b_node.output[0]
            dims = gemm_b_value.type.tensor_type.shape.dim
            if len(dims) == 2:
                dims[0].dim_value = new_gemm_b.shape[0]
                dims[1].dim_value = new_gemm_b.shape[1]
        gemm_c_value = helper.find_value_by_name(g, gemm_c_node.output[0])
        if gemm_c_value is not None:
            gemm_c_value.name = new_gemm_c_node.output[0]
        if b_node is None:
            # If b node is None, set the Gemm output as the graph output
            output_value = helper.find_output_by_name(g, bn_node.output[0])
            gemm_value = helper.find_value_by_name(g, gemm_node.output[0])
            if gemm_value is None:
                # No value_info for the Gemm output: reuse the graph output
                # entry under the Gemm output name.
                output_value.name = gemm_node.output[0]
            else:
                g.output.remove(output_value)
                g.output.extend([gemm_value])
        else:
            # Else, set node B output as gemm output
            gemm_node.output[0] = b_node.output[0]
    # Mutate g.node only after the scan.
    g.node.extend(node_to_add)
    for node in node_to_remove:
        g.node.remove(node)
    topological_sort(g)
def remove_nodes(g, cut_nodes=None, cut_types=None):
    """Remove nodes by name or op type and prune everything left dangling.

    Steps:

    1. Delete nodes whose name is in ``cut_nodes`` or whose op type is in
       ``cut_types``. Their inputs that have value_info and are not already
       graph outputs are appended to ``g.output`` (each name at most once).
    2. Delete non-Constant nodes that can no longer be reached from the graph
       inputs, initializers and Constant nodes, exposing their inputs the
       same way. Empty input names mark omitted optional inputs and count as
       reachable.
    3. Delete Constant nodes that nothing consumes.
    4. Delete value_info and graph output entries that no remaining node or
       graph input refers to.

    :param g: the graph
    :param cut_nodes: names of nodes to delete (default: none).
    :param cut_types: op types of nodes to delete (default: none).
    """
    if cut_nodes is None:
        cut_nodes = []
    if cut_types is None:
        cut_types = []
    exposed = set()

    def remove_and_expose(nodes):
        # Delete ``nodes`` (last first) and promote their eligible inputs to
        # graph outputs. ``exposed`` keeps shared inputs from being added
        # twice.
        new_output = []
        for node in reversed(nodes):
            for input_name in node.input:
                value = helper.find_value_by_name(g, input_name)
                if (value is None or value.name in exposed
                        or helper.find_output_by_name(g, input_name)
                        is not None):
                    continue
                exposed.add(value.name)
                new_output.append(value)
            g.node.remove(node)
        g.output.extend(new_output)

    # Find target nodes, remove them and add new outputs
    remove_and_expose([node for node in g.node
                       if node.name in cut_nodes or node.op_type in cut_types])

    # Remove unreachable nodes
    visited_values = {value.name for value in g.input}
    visited_values.update(init.name for init in g.initializer)
    unused_constant_map = {}
    unreachable = []
    for node in g.node:
        if node.op_type == 'Constant':
            visited_values.add(node.output[0])
            unused_constant_map[node.output[0]] = node
            continue
        if all(not name or name in visited_values for name in node.input):
            visited_values.update(node.output)
        else:
            unreachable.append(node)
    remove_and_expose(unreachable)

    # Remove unused constants
    for node in g.node:
        for input_name in node.input:
            unused_constant_map.pop(input_name, None)
    for node in unused_constant_map.values():
        g.node.remove(node)

    # Remove unreachable value infos and outputs
    reachable_values = {value.name for value in g.input}
    for node in g.node:
        reachable_values.update(node.input)
        reachable_values.update(node.output)
    for value_info in [v for v in g.value_info
                       if v.name not in reachable_values]:
        g.value_info.remove(value_info)
    for value_info in [v for v in g.output
                       if v.name not in reachable_values]:
        g.output.remove(value_info)