import logging
from typing import List

import onnx

# These imports assume the transformers tool modules shipped with the onnxruntime package.
from onnxruntime.transformers.float16 import float_to_float16_max_diff
from onnxruntime.transformers.fusion_utils import FusionUtils
from onnxruntime.transformers.onnx_model import OnnxModel

logger = logging.getLogger(__name__)


def optimize_fp16_onnx_with_cast(input_onnx_path, optimized_onnx_path, epsilon):
    m = onnx.load(input_onnx_path)
    onnx_model = OnnxModel(m)

    # Replace the whole graph with Cast -> LayerNormalization (fp32) -> Cast.
    nodes_to_remove = onnx_model.nodes()
    nodes_to_add = [
        onnx.helper.make_node("Cast", ["input"], ["fp32_input"], "cast_input", to=1),
        onnx.helper.make_node("Cast", ["layer_norm.weight"], ["fp32_layer_norm.weight"], "cast_weight", to=1),
        onnx.helper.make_node("Cast", ["layer_norm.bias"], ["fp32_layer_norm.bias"], "cast_bias", to=1),
        onnx.helper.make_node(
            "LayerNormalization",
            ["fp32_input", "fp32_layer_norm.weight", "fp32_layer_norm.bias"],
            ["fp32_output"],
            "layer_norm",
            epsilon=epsilon,
        ),  # use fp32 epsilon
        onnx.helper.make_node("Cast", ["fp32_output"], ["output"], "cast_output", to=10),
    ]

    onnx_model.remove_nodes(nodes_to_remove)
    onnx_model.add_nodes(nodes_to_add)
    onnx_model.prune_graph()
    onnx_model.save_model_to_file(optimized_onnx_path)
def optimize_fp16_onnx_no_cast(input_onnx_path, optimized_onnx_path, epsilon):
    m = onnx.load(input_onnx_path)
    onnx_model = OnnxModel(m)

    # get_weight/get_bias are helpers defined elsewhere in the source; they return
    # the names of the tensors feeding the LayerNorm weight and bias.
    weight_name = get_weight(onnx_model)
    bias_name = get_bias(onnx_model)

    # Keep the nodes producing the weight and bias tensors; replace everything else
    # with a single fp16 LayerNormalization node.
    nodes_to_remove = [n for n in onnx_model.nodes() if n.output[0] != weight_name and n.output[0] != bias_name]
    node_to_add = onnx.helper.make_node(
        "LayerNormalization", ["input", weight_name, bias_name], ["output"], "layer_norm", epsilon=epsilon
    )

    onnx_model.remove_nodes(nodes_to_remove)
    onnx_model.add_node(node_to_add)
    onnx_model.prune_graph()
    onnx_model.save_model_to_file(optimized_onnx_path)
def optimize_fp16_onnx_no_cast(input_onnx_path, optimized_onnx_path, epsilon):
    m = onnx.load(input_onnx_path)
    onnx_model = OnnxModel(m)

    nodes_to_remove = onnx_model.nodes()
    node_to_add = onnx.helper.make_node(
        "LayerNormalization",
        ["input", "layer_norm.weight", "layer_norm.bias"],
        ["output"],
        "layer_norm",
        epsilon=epsilon,
    )

    onnx_model.remove_nodes(nodes_to_remove)
    onnx_model.add_node(node_to_add)
    onnx_model.prune_graph()
    onnx_model.save_model_to_file(optimized_onnx_path)
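# Usage sketch (illustrative, not from the original source): produce both optimized
# variants of a LayerNorm test model so their parity can be compared. The file
# names and epsilon value here are assumptions.
def _demo_optimize_layer_norm():
    input_path = "layer_norm_fp16.onnx"  # hypothetical exported fp16 model
    # fp32 compute path: Cast(fp16->fp32) -> LayerNormalization -> Cast(fp32->fp16)
    optimize_fp16_onnx_with_cast(input_path, "layer_norm_with_cast.onnx", epsilon=1e-5)
    # pure fp16 compute path: a single LayerNormalization node
    optimize_fp16_onnx_no_cast(input_path, "layer_norm_no_cast.onnx", epsilon=1e-5)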
def remove_useless_reshape_nodes(model: OnnxModel):
    """Remove Reshape nodes that are not needed based on symbolic shape inference:
    the input and output have the same shape.
    """
    shape_infer = model.infer_runtime_shape(update=True)
    if shape_infer is None:
        return

    nodes_to_remove = []
    for node in model.nodes():
        if node.op_type == "Reshape":
            input_shape = shape_infer.get_edge_shape(node.input[0])
            output_shape = shape_infer.get_edge_shape(node.output[0])
            if input_shape and output_shape and input_shape == output_shape:
                logger.info(f"Remove reshape node {node.name} since its input shape is same as output: {input_shape}")
                nodes_to_remove.append(node)

    if nodes_to_remove:
        for node in nodes_to_remove:
            # Reconnect consumers of the Reshape output directly to its input, then drop the node.
            model.replace_input_of_all_nodes(node.output[0], node.input[0])
            model.remove_node(node)
        model.prune_graph()
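# Usage sketch (illustrative): load a model, strip no-op Reshape nodes, and save.
# The paths are assumptions; the helper silently does nothing if symbolic shape
# inference fails.
def _demo_remove_useless_reshape(input_path: str, output_path: str):
    model = OnnxModel(onnx.load(input_path))
    remove_useless_reshape_nodes(model)
    model.save_model_to_file(output_path)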
def optimize_fp16_onnx_with_cast(input_onnx_path, optimized_onnx_path, epsilon):
    m = onnx.load(input_onnx_path)
    onnx_model = OnnxModel(m)

    weight_name = get_weight(onnx_model)
    bias_name = get_bias(onnx_model)

    nodes_to_remove = [n for n in onnx_model.nodes() if n.output[0] != weight_name and n.output[0] != bias_name]
    nodes_to_add = [
        onnx.helper.make_node("Cast", ["input"], ["fp32_input"], "cast_input", to=1),
        onnx.helper.make_node("Cast", [weight_name], ["fp32_layer_norm.weight"], "cast_weight", to=1),
        onnx.helper.make_node("Cast", [bias_name], ["fp32_layer_norm.bias"], "cast_bias", to=1),
        onnx.helper.make_node(
            "LayerNormalization",
            ["fp32_input", "fp32_layer_norm.weight", "fp32_layer_norm.bias"],
            ["fp32_output"],
            "layer_norm",
            epsilon=epsilon,
        ),  # use fp32 epsilon
        onnx.helper.make_node("Cast", ["fp32_output"], ["output"], "cast_output", to=10),
    ]

    onnx_model.remove_nodes(nodes_to_remove)
    onnx_model.add_nodes(nodes_to_add)
    onnx_model.prune_graph()
    onnx_model.save_model_to_file(optimized_onnx_path)
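# Note on the Cast "to" attributes used above: they are onnx.TensorProto enum
# values, to=1 for FLOAT (fp32) and to=10 for FLOAT16. Using the named constants
# would make the intent explicit; this sketch just verifies the mapping.
assert onnx.TensorProto.FLOAT == 1
assert onnx.TensorProto.FLOAT16 == 10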
def auto_mixed_precision(onnx_model: OnnxModel, op_block_list: List[str] = ["Add", "LayerNormalization", "FastGelu"]):
    """Convert a GPT-2 model to mixed precision.
    It detects whether the original model has fp16 weights, and sets parameters for float16 conversion automatically.

    Args:
        onnx_model (OnnxModel): optimized ONNX model
        op_block_list (List[str], optional): operators to keep in float32.
            Defaults to ["Add", "LayerNormalization", "FastGelu"].

    Returns:
        parameters (dict): a dictionary of parameters used in float16 conversion
    """
    op_full_set = set([node.op_type for node in onnx_model.nodes()])
    fp32_op_set = set(op_block_list)
    fp16_op_set = op_full_set.difference(fp32_op_set)
    logger.info(f"fp32 op: {fp32_op_set} fp16 op: {fp16_op_set}")

    # logits is the first output
    logits_output_name = onnx_model.graph().output[0].name

    # We use the weight of the last MatMul node to detect whether the model was stored with float16 weights from training.
    is_weight_fp16_precision = False
    output_name_to_node = onnx_model.output_name_to_node()
    assert logits_output_name in output_name_to_node
    node = output_name_to_node[logits_output_name]
    last_matmul_node = None
    if node.op_type == "MatMul":
        last_matmul_node = node
        logger.info(f"Found last MatMul node for logits: {node.name}")
        initializer = None
        for input in node.input:
            initializer = onnx_model.get_initializer(input)
            if initializer is not None:
                break

        # When the maximum difference of values after converting float to float16 is lower than a threshold (1e-6),
        # we can deduce that the weights are stored in float16 precision.
        max_diff = float_to_float16_max_diff(initializer)
        logger.debug(f"max diff of converting weights in last MatMul node {node.name}: {max_diff}")
        is_weight_fp16_precision = max_diff < 1e-6
    else:
        logger.warning(f"Failed to find MatMul node for logits. Found {node.op_type} of node {node.name}")

    if is_weight_fp16_precision:
        keep_io_types = []
        node_block_list = []
    else:
        # When the original weights are float32, keeping the logits and the last MatMul in float32 gives better precision.
        keep_io_types = [logits_output_name]
        node_block_list = [last_matmul_node.name]

    parameters = {
        "keep_io_types": keep_io_types,
        "op_block_list": op_block_list,
        "node_block_list": node_block_list,
        "force_fp16_initializers": is_weight_fp16_precision,
    }

    logger.info(f"auto_mixed_precision parameters: {parameters}")
    onnx_model.convert_float_to_float16(use_symbolic_shape_infer=True, **parameters)

    fusion_utils = FusionUtils(onnx_model)
    fusion_utils.remove_cascaded_cast_nodes()
    fusion_utils.remove_useless_cast_nodes()

    return parameters
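# Usage sketch (illustrative): apply auto mixed precision to an already optimized
# GPT-2 model and save the result. The file names are assumptions; the returned
# parameters record what the heuristic decided.
def _demo_auto_mixed_precision():
    model = OnnxModel(onnx.load("gpt2_optimized.onnx"))  # hypothetical input
    parameters = auto_mixed_precision(model)  # default op_block_list keeps Add/LayerNormalization/FastGelu in fp32
    logger.info(f"fp16 weights detected: {parameters['force_fp16_initializers']}")
    model.save_model_to_file("gpt2_fp16.onnx")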