def add_bn_on_skip_branch(g):
    """Insert a no-op BatchNormalization on the skip branch feeding an Add.

    Looks for 2-input Add (merge) nodes where one producer feeds two
    consumers (the split point of a residual connection), then inserts an
    identity BatchNormalization (scale=1, bias=0, mean=0, var=1) between
    the split point and the Add, rewiring the Add's input to the new node.

    :param g: the onnx graph
    """
    for n in g.node:
        # Find merge node (Add)
        if n.op_type != 'Add':
            continue
        if len(n.input) != 2:
            continue
        # TODO: Still need to consider more cases
        # Check if skip branch exists: one of the two producers must have
        # exactly two consumers.
        input_node_a = helper.find_node_by_output_name(g, n.input[0])
        input_node_b = helper.find_node_by_output_name(g, n.input[1])
        # Guard: an Add input may come straight from a graph input or an
        # initializer, in which case there is no producer node at all and
        # the original code would crash on `.output[0]`.
        if input_node_a is None or input_node_b is None:
            continue
        output_of_input_node_a = helper.find_nodes_by_input_name(
            g, input_node_a.output[0])
        output_of_input_node_b = helper.find_nodes_by_input_name(
            g, input_node_b.output[0])
        if len(output_of_input_node_a) == 1 and len(
                output_of_input_node_b) == 1:
            continue
        if len(output_of_input_node_a) == 2:
            split_node = input_node_a
        elif len(output_of_input_node_b) == 2:
            split_node = input_node_b
        else:
            continue
        # Get the channel number from value info
        value_name = split_node.output[0]
        value = helper.find_value_by_name(g, value_name)
        shape = helper.get_shape_from_value_info(value)
        channel = shape[1]
        # Construct 4 weights for an identity BN
        node_name = value_name + "_nop_bn"
        ones = [1.0] * channel
        zeros = [0.0] * channel
        scale_node = helper.list_to_constant(node_name + "_scale",
                                             [channel], ones)
        bias_node = helper.list_to_constant(node_name + "_bias",
                                            [channel], zeros)
        mean_node = helper.list_to_constant(node_name + "_mean",
                                            [channel], zeros)
        var_node = helper.list_to_constant(node_name + "_var",
                                           [channel], ones)
        # Construct BN node
        bn_node = onnx.helper.make_node(
            "BatchNormalization",
            [value_name, scale_node.output[0], bias_node.output[0],
             mean_node.output[0], var_node.output[0]],
            [node_name],
            name=node_name)
        # Reconnect the graph: the Add now reads the BN output instead of
        # the split-node output.
        replace_node_input(n, value_name, node_name)
        # Add node to the graph
        g.node.extend([bn_node, scale_node, bias_node, mean_node, var_node])
    topological_sort(g)
def replace_initializer_with_Constant(g):
    """Replace initializers with Constant and a corresponding value_info

    If the initializer has related input, remove it.

    :param g: the onnx graph
    """
    inputs_by_name = {value_info.name: value_info for value_info in g.input}
    for init in g.initializer:
        # Drop the graph input that mirrors this initializer, if any.
        if init.name in inputs_by_name:
            g.input.remove(inputs_by_name[init.name])
        # Give every consumer its own Constant; the first keeps the
        # original name, later ones get a numbered duplicate name.
        consumers = helper.find_nodes_by_input_name(g, init.name)
        for idx, consumer in enumerate(consumers):
            if idx > 0:
                out_name = init.name + "_duplicated_No" + str(idx)
            else:
                out_name = init.name
            modhelper.replace_node_input(consumer, init.name, out_name)
            constant = onnx.helper.make_node("Constant",
                                             [],
                                             [out_name],
                                             name=out_name,
                                             value=init)
            # Add node to lists
            g.node.extend([constant])
        # A stale value_info under the old name must be removed as well.
        stale_info = helper.find_value_by_name(g, init.name)
        if stale_info is not None:
            g.value_info.remove(stale_info)
    # Every initializer is now a Constant node; clear the list.
    while len(g.initializer) != 0:
        g.initializer.pop()
    topological_sort(g)
def find_first_sequential_outputs(g, node):
    """Walk forward from `node` and return the first graph output reached.

    Checks each output of `node` against the graph's outputs; if none
    matches, recurses into the first consumer of the node's first output.
    """
    for out_name in node.output:
        graph_output = helper.find_output_by_name(g, out_name)
        if graph_output is not None:
            return graph_output
    successor = helper.find_nodes_by_input_name(g, node.output[0])[0]
    return find_first_sequential_outputs(g, successor)
def replace_initializer_with_Constant(g):
    """Replace initializers with Constant and a corresponding value_info

    Unused initializers (consumed by no node) are removed along with their
    matching graph inputs. Used initializers become Constant nodes; when a
    node already carries the initializer's name, the Constant output gets a
    '_2' suffix and all consumers are rewired to it.

    :param g: the onnx graph
    """
    # Creat a set of existed node names
    node_names = set()
    for node in g.node:
        node_names.add(node.name)
    # Unused initializers should be removed
    unused_initializer = set()
    for tensor in g.initializer:
        unused_initializer.add(tensor.name)
    for node in g.node:
        for in_value in node.input:
            if in_value in unused_initializer:
                unused_initializer.remove(in_value)
    input_map = {i.name: i for i in g.input}
    for tensor in g.initializer:
        if tensor.name in unused_initializer:
            # Since ONNX IR v4 an initializer is not required to have a
            # matching graph input, so guard the lookup instead of
            # indexing (the original raised KeyError here).
            value_info = input_map.get(tensor.name)
            if value_info is not None:
                g.input.remove(value_info)
            continue
        # Convert init to a constant node
        if tensor.name not in node_names:
            new_name = tensor.name
        else:
            new_name = tensor.name + '_2'
        following_nodes = helper.find_nodes_by_input_name(g, tensor.name)
        for node in following_nodes:
            modhelper.replace_node_input(node, tensor.name, new_name)
        new_node = onnx.helper.make_node(
            "Constant",
            [],
            [new_name],
            name=new_name,
            value=tensor
        )
        # Add node to lists
        g.node.extend([new_node])
        # Move the matching input's value_info into value_info, if present
        # (guarded for the same IR v4 reason as above).
        value_info = input_map.get(tensor.name)
        if value_info is not None:
            g.value_info.extend([value_info])
            # Remove original input value info
            g.input.remove(value_info)
        # if value info exists, remove it as well.
        value_info = helper.find_value_by_name(g, tensor.name)
        if value_info is not None:
            g.value_info.remove(value_info)
    # Remove original initializer
    while len(g.initializer) != 0:
        g.initializer.pop()
def rename_output_name(g, original_name, new_name):
    """Rename a graph output and every reference to it.

    Updates the graph output entry, any value_info with the old name, the
    producing node's output slot, and all consumer nodes' inputs. Logs an
    error and returns if no graph output carries `original_name`.

    :param g: the onnx graph
    :param original_name: current name of the output value
    :param new_name: name to rename it to
    """
    # Output
    output_value = helper.find_output_by_name(g, original_name)
    if output_value is None:
        logging.error("Cannot find output value named " + original_name)
        return
    output_value.name = new_name
    # Value Info
    value_info = helper.find_value_by_name(g, original_name)
    if value_info is not None:
        value_info.name = new_name
    # Node output. Guard against a missing producer, and rename the slot
    # that actually matches: the original blindly overwrote output[0],
    # which is wrong for multi-output producers.
    node = helper.find_node_by_output_name(g, original_name)
    if node is not None:
        for i, out_name in enumerate(node.output):
            if out_name == original_name:
                node.output[i] = new_name
    # Node input
    nodes = helper.find_nodes_by_input_name(g, original_name)
    for node in nodes:
        replace_node_input(node, original_name, new_name)
def pattern_matmul_mul_add(g, matmul_node):
    """Fuse a MatMul -> Mul(all-ones) -> Add chain into a single Gemm.

    The Mul must multiply by a constant vector of ones (a no-op, used only
    to carry the channel count); the Add bias becomes Gemm input C. If the
    MatMul weight has a trailing dimension of 1, it is tiled (broadcast)
    across the channel count first. Returns early (no change) when the
    pattern does not match.

    :param g: the onnx graph
    :param matmul_node: the candidate MatMul node to start matching from
    """
    # Check node match - Mul node
    next_nodes = helper.find_nodes_by_input_name(g, matmul_node.output[0])
    if len(next_nodes) != 1:
        return
    if next_nodes[0].op_type != 'Mul':
        return
    mul_node = next_nodes[0]
    # Check node match - Add node
    next_nodes = helper.find_nodes_by_input_name(g, mul_node.output[0])
    if len(next_nodes) != 1:
        return
    if next_nodes[0].op_type != 'Add':
        return
    add_node = next_nodes[0]
    # Check Mul weight: it must be a Constant of all ones.
    # Guard None: the weight may live in an initializer instead of a
    # Constant node; the original crashed on `.op_type` in that case.
    mul_weight_node = helper.find_node_by_output_name(g, mul_node.input[1])
    if mul_weight_node is None or mul_weight_node.op_type != 'Constant':
        return
    weight_size, mul_weight = helper.constant_to_list(mul_weight_node)
    for i in mul_weight:
        if i != 1:
            return
    channel = weight_size[0]
    # Check Add weight (same None guard as above)
    add_weight_node = helper.find_node_by_output_name(g, add_node.input[1])
    if add_weight_node is None or add_weight_node.op_type != 'Constant':
        return
    # Check MatMul weight to see if it needs a weight broadcast
    matmul_weight_node = helper.find_node_by_output_name(
        g, matmul_node.input[1])
    if matmul_weight_node is None:
        return
    matmul_weight = helper.constant_to_numpy(matmul_weight_node)
    if matmul_weight.shape[1] == 1:
        # Weight broadcast: tile the single column across all channels.
        new_matmul_weight = np.tile(matmul_weight, channel)
        new_matmul_weight_node = helper.numpy_to_constant(
            matmul_weight_node.name, new_matmul_weight)
        g.node.remove(matmul_weight_node)
        g.node.extend([new_matmul_weight_node])
        value = helper.find_value_by_name(g, matmul_weight_node.output[0])
        if value is not None:
            g.value_info.remove(value)
    # Remove the no-op Mul node and its weight.
    g.node.remove(mul_weight_node)
    value = helper.find_value_by_name(g, mul_weight_node.output[0])
    if value is not None:
        g.value_info.remove(value)
    g.node.remove(mul_node)
    value = helper.find_value_by_name(g, mul_node.output[0])
    if value is not None:
        g.value_info.remove(value)
    # Fuse MatMul and Add into Gemm (C = A*B + bias, no transposes).
    gemm_node = onnx.helper.make_node(
        'Gemm',
        [matmul_node.input[0], matmul_node.input[1], add_node.input[1]],
        [add_node.output[0]],
        name=matmul_node.name,
        alpha=1.0,
        beta=1.0,
        transA=0,
        transB=0
    )
    g.node.extend([gemm_node])
    # Clean up
    g.node.remove(matmul_node)
    g.node.remove(add_node)
    value = helper.find_value_by_name(g, matmul_node.output[0])
    if value is not None:
        g.value_info.remove(value)
    other.topological_sort(g)