def test_gpt2_past_fp16(self):
    input_model_path = _get_test_model_path('gpt2_past')
    model = OnnxModel(load_model(input_model_path, format=None, load_external_data=True))
    model.convert_float_to_float16(keep_io_types=False, use_symbolic_shape_infer=False)
    # The first graph input (token ids) is integer, so only the remaining inputs are expected to be float16.
    for input in model.graph().input[1:]:
        self.assertEqual(input.type.tensor_type.elem_type, TensorProto.FLOAT16)
    for output in model.graph().output:
        self.assertEqual(output.type.tensor_type.elem_type, TensorProto.FLOAT16)
def auto_mixed_precision(onnx_model: OnnxModel,
                         op_block_list: List[str] = ['Add', 'LayerNormalization', 'FastGelu']):
    """Convert GPT-2 model to mixed precision.
       It detects whether the original model has fp16 precision weights, and sets parameters for float16 conversion automatically.

    Args:
        onnx_model (OnnxModel): optimized ONNX model
        op_block_list (List[str], optional): operator types to keep in float32. Defaults to ['Add', 'LayerNormalization', 'FastGelu'].

    Returns:
        parameters(dict): a dictionary of parameters used in float16 conversion
    """
    op_full_set = set([node.op_type for node in onnx_model.nodes()])
    fp32_op_set = set(op_block_list)
    fp16_op_set = op_full_set.difference(fp32_op_set)
    logger.info(f"fp32 op: {fp32_op_set} fp16 op: {fp16_op_set}")

    # logits is the first output
    logits_output_name = onnx_model.graph().output[0].name

    # We use the weight in the last MatMul node to detect whether the model is stored with float16 weights from training.
    is_weight_fp16_precision = False
    output_name_to_node = onnx_model.output_name_to_node()
    assert logits_output_name in output_name_to_node
    node = output_name_to_node[logits_output_name]
    last_matmul_node = None
    if node.op_type == "MatMul":
        last_matmul_node = node
        logger.info(f"Found last MatMul node for logits: {node.name}")
        initializer = None
        for input in node.input:
            initializer = onnx_model.get_initializer(input)
            if initializer is not None:
                break

        # When the max difference of values after converting float to float16 is lower than a threshold (1e-6),
        # we can deduce that the weights are stored in float16 precision.
        max_diff = float_to_float16_max_diff(initializer)
        logger.debug(f"max diff of converting weights in last MatMul node {node.name}: {max_diff}")
        is_weight_fp16_precision = (max_diff < 1E-6)
    else:
        logger.warning(f"Failed to find MatMul node for logits. Found {node.op_type} of node {node.name}")

    if is_weight_fp16_precision:
        keep_io_types = []
        node_block_list = []
    else:
        # When the original weights are float32 precision, keeping logits and the last MatMul in float32 can give better precision.
        keep_io_types = [logits_output_name]
        # Guard against the case where no MatMul node was found for logits.
        node_block_list = [last_matmul_node.name] if last_matmul_node else []

    parameters = {
        "keep_io_types": keep_io_types,
        "op_block_list": op_block_list,
        "node_block_list": node_block_list,
        "force_fp16_initializers": is_weight_fp16_precision
    }

    logger.info(f"auto_mixed_precision parameters: {parameters}")
    onnx_model.convert_float_to_float16(use_symbolic_shape_infer=True, **parameters)

    fusion_utils = FusionUtils(onnx_model)
    fusion_utils.remove_cascaded_cast_nodes()
    fusion_utils.remove_useless_cast_nodes()

    return parameters
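
# Minimal usage sketch (illustrative, not part of the original source): convert an already
# optimized GPT-2 ONNX model to mixed precision with the defaults above and save the result.
# The file paths are hypothetical placeholders; load_model comes from the onnx package and
# save_model_to_file is OnnxModel's save helper.
if __name__ == "__main__":
    from onnx import load_model

    model = OnnxModel(load_model("gpt2_optimized.onnx", load_external_data=True))  # hypothetical input path
    conversion_parameters = auto_mixed_precision(model)
    logger.info(f"mixed precision conversion used: {conversion_parameters}")
    model.save_model_to_file("gpt2_fp16.onnx")  # hypothetical output path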