def export_onnx(
    model,
    device,
    onnx_model_path: str,
    verbose: bool = False,
    use_external_data_format: bool = False,
    has_position_ids: bool = True,
    has_attention_mask: bool = True,
    input_ids_dtype: torch.dtype = torch.int32,
    position_ids_dtype: torch.dtype = torch.int32,
    attention_mask_dtype: torch.dtype = torch.int32,
):
    """Export GPT-2 model with past state to ONNX model.

    Args:
        model: GPT2Model or GPT2LMHeadModel (distinguished at runtime by output width).
        device: torch device to place dummy inputs on.
        onnx_model_path (str): output onnx file path.
        verbose (bool, optional): print verbose information. Defaults to False.
        use_external_data_format (bool, optional): store weights as external data
            consolidated into one file. Defaults to False.
        has_position_ids (bool, optional): include position_ids input. Defaults to True.
        has_attention_mask (bool, optional): include attention_mask input. Defaults to True.
        input_ids_dtype: dtype of input_ids dummy input. Defaults to torch.int32.
        position_ids_dtype: dtype of position_ids dummy input. Defaults to torch.int32.
        attention_mask_dtype: dtype of attention_mask dummy input. Defaults to torch.int32.
    """
    config: GPT2Config = model.config
    num_layer = config.n_layer

    dummy_inputs = Gpt2Helper.get_dummy_inputs(
        batch_size=1,
        past_sequence_length=1,
        sequence_length=1,
        num_attention_heads=config.num_attention_heads,
        hidden_size=config.hidden_size,
        num_layer=num_layer,
        vocab_size=config.vocab_size,
        device=device,
        float16=False,
        has_position_ids=has_position_ids,
        has_attention_mask=has_attention_mask,
        input_ids_dtype=input_ids_dtype,
        position_ids_dtype=position_ids_dtype,
        attention_mask_dtype=attention_mask_dtype,
    )
    input_list = dummy_inputs.to_list()

    # Trace once with dummy inputs so the exporter sees concrete tensors.
    with torch.no_grad():
        outputs = model(*input_list)

    past_names = [f"past_{i}" for i in range(num_layer)]
    present_names = [f"present_{i}" for i in range(num_layer)]

    # GPT2Model outputs last_state; GPT2LMHeadModel outputs logits (prediction_scores).
    # The last axis of the first output tells us which model class we traced.
    assert outputs[0].shape[2] == config.vocab_size or outputs[0].shape[2] == config.hidden_size

    output_names = ["logits" if outputs[0].shape[2] == config.vocab_size else "last_state"] + present_names

    # Shape of input tensors:
    #    input_ids: (batch_size, seq_len)
    #    past_{i}:  (2, batch_size, num_heads, past_seq_len, hidden_size/num_heads)
    #    attention_mask: (batch_size, past_seq_len + seq_len)
    # Shape of output tensors:
    #    last_state: (batch_size, seq_len, hidden_size)
    #      or logits: (batch_size, seq_len, vocab_size)
    #    present_{i}: (2, batch_size, num_heads, past_seq_len + seq_len, hidden_size/num_heads)
    dynamic_axes = {
        "input_ids": {0: "batch_size", 1: "seq_len"},
        output_names[0]: {0: "batch_size", 1: "seq_len"},
    }
    for name in past_names:
        dynamic_axes[name] = {1: "batch_size", 3: "past_seq_len"}
    for name in present_names:
        dynamic_axes[name] = {1: "batch_size", 3: "total_seq_len"}

    input_names = ["input_ids"]
    if has_position_ids:
        dynamic_axes["position_ids"] = {0: "batch_size", 1: "seq_len"}
        input_names.append("position_ids")
    if has_attention_mask:
        dynamic_axes["attention_mask"] = {0: "batch_size", 1: "total_seq_len"}
        input_names.append("attention_mask")
    input_names.extend(past_names)

    # Sanity check: model returns (first_output, presents) with one present per layer.
    assert len(outputs) == 2 and len(outputs[1]) == num_layer

    logger.info(
        f"Shapes: input_ids={dummy_inputs.input_ids.shape} past={dummy_inputs.past[0].shape} output={outputs[0].shape} present={outputs[1][0].shape}"
    )

    Path(onnx_model_path).parent.mkdir(parents=True, exist_ok=True)

    # Export once; when using external data format we let PyTorch export to a temp
    # directory first, then consolidate external data into one file. This mirrors
    # the T5 encoder/decoder exporters and avoids duplicating the export call.
    with tempfile.TemporaryDirectory() as tmp_dir_name:
        temp_onnx_model_path = os.path.join(tmp_dir_name, "gpt2.onnx")
        Path(temp_onnx_model_path).parent.mkdir(parents=True, exist_ok=True)

        torch_onnx_export(
            model,
            args=tuple(input_list),
            f=temp_onnx_model_path if use_external_data_format else onnx_model_path,
            export_params=True,
            input_names=input_names,
            output_names=output_names,
            dynamic_axes=dynamic_axes,
            opset_version=11,
            do_constant_folding=True,
            use_external_data_format=use_external_data_format,
            verbose=verbose,
        )

        if use_external_data_format:
            model = onnx.load_model(temp_onnx_model_path, load_external_data=True)
            OnnxModel.save(
                model,
                onnx_model_path,
                save_as_external_data=True,
                all_tensors_to_one_file=True,
            )
def export_onnx(
    decoder: Union[T5Decoder, T5DecoderInit],
    device: torch.device,
    onnx_model_path: str,
    verbose: bool = True,
    use_external_data_format: bool = False,
    use_int32_inputs: bool = False,
):
    """Export decoder to ONNX

    Args:
        decoder (Union[T5Decoder, T5DecoderNoPastState]): decoder object
        device (torch.device): device of decoder object
        onnx_model_path (str): onnx path
        verbose (bool, optional): print verbose information. Defaults to True.
        use_external_data_format (bool, optional): use external data format or not. Defaults to False.
        use_int32_inputs (bool, optional): use int32 inputs
    """
    assert isinstance(decoder, (T5Decoder, T5DecoderInit))

    # T5Decoder consumes past key/values; T5DecoderInit starts from scratch.
    has_past = isinstance(decoder, T5Decoder)

    dummy_inputs = T5DecoderInputs.create_dummy(
        decoder.config,
        batch_size=2,
        encode_sequence_length=3,
        past_decode_sequence_length=5 if has_past else 0,
        device=device,
        use_int32_inputs=use_int32_inputs,
    )
    export_args = dummy_inputs.to_list()

    num_layers = decoder.config.num_layers
    past_names = PastKeyValuesHelper.get_past_names(num_layers, present=False)
    present_names = PastKeyValuesHelper.get_past_names(num_layers, present=True)
    # Self-attention present states come first; cross-attention states follow.
    present_self_names = present_names[: 2 * num_layers]

    input_past_names = past_names if has_past else []
    # With past state only self-attention caches are new outputs; otherwise
    # the initial run emits both self and cross caches.
    output_present_names = present_self_names if has_past else present_names
    output_names = ["logits"] + output_present_names

    # Expected tensor shapes (sequence_length==1):
    #   inputs:
    #     input_ids:              (batch_size, sequence_length)
    #     encoder_attention_mask: (batch_size, encode_sequence_length)
    #     encoder_hidden_states:  (batch_size, encode_sequence_length, hidden_size)
    #     past_self_*:            (batch_size, num_heads, past_decode_sequence_length, head_size)
    #     past_cross_*:           (batch_size, num_heads, encode_sequence_length, head_size)
    #   outputs:
    #     logits:      (batch_size, sequence_length, vocab_size)
    #     past_self_*: (batch_size, num_heads, past_decode_sequence_length + sequence_length, head_size)
    #     past_cross_*:(batch_size, num_heads, encode_sequence_length, head_size)
    input_names = ["input_ids", "encoder_attention_mask", "encoder_hidden_states"]
    input_names.extend(input_past_names)

    dynamic_axes = {
        "input_ids": {
            0: "batch_size",
            # 1: 'sequence_length'
        },
        "encoder_attention_mask": {0: "batch_size", 1: "encode_sequence_length"},
        "encoder_hidden_states": {0: "batch_size", 1: "encode_sequence_length"},
        "logits": {
            0: "batch_size",
            # 1: 'sequence_length'
        },
    }

    for name in input_past_names:
        seq_axis = "past_decode_sequence_length" if "self" in name else "encode_sequence_length"
        dynamic_axes[name] = {0: "batch_size", 2: seq_axis}

    for name in output_present_names:
        if "cross" in name:
            dynamic_axes[name] = {0: "batch_size", 2: "encode_sequence_length"}
        elif has_past:  # self-attention past state grows by one step
            dynamic_axes[name] = {
                0: "batch_size",
                2: "past_decode_sequence_length + 1",
            }
        else:
            dynamic_axes[name] = {
                0: "batch_size",
                # 2: 'sequence_length'
            }

    Path(onnx_model_path).parent.mkdir(parents=True, exist_ok=True)

    with tempfile.TemporaryDirectory() as staging_dir:
        staged_model_path = os.path.join(staging_dir, "decoder.onnx")
        Path(staged_model_path).parent.mkdir(parents=True, exist_ok=True)

        torch_onnx_export(
            decoder,
            args=tuple(export_args),
            f=staged_model_path if use_external_data_format else onnx_model_path,
            export_params=True,
            input_names=input_names,
            output_names=output_names,
            dynamic_axes=dynamic_axes,
            opset_version=12,
            do_constant_folding=True,
            use_external_data_format=use_external_data_format,
            verbose=verbose,
        )

        if use_external_data_format:
            # Consolidate the staged external-data tensors into a single file.
            model = onnx.load_model(staged_model_path, load_external_data=True)
            OnnxModel.save(
                model,
                onnx_model_path,
                save_as_external_data=True,
                all_tensors_to_one_file=True,
            )
def export_onnx(
    encoder: T5Encoder,
    device: torch.device,
    onnx_model_path: str,
    verbose: bool = True,
    use_external_data_format: bool = False,
    use_int32_inputs: bool = False,
):
    """Export encoder to ONNX

    Args:
        encoder (T5Encoder): encoder object
        device (torch.device): device of encoder object
        onnx_model_path (str): onnx path
        verbose (bool, optional): print verbose information. Defaults to True.
        use_external_data_format (bool, optional): use external data format or not. Defaults to False.
    """
    config = encoder.config
    dummy_inputs = T5EncoderInputs.create_dummy(
        batch_size=2,
        sequence_length=4,
        vocab_size=config.vocab_size,
        device=device,
        use_int32_inputs=use_int32_inputs,
    )

    # Both inputs and the single output share the same two dynamic axes.
    axes = {0: "batch_size", 1: "sequence_length"}
    dynamic_axes = {name: dict(axes) for name in ("input_ids", "attention_mask", "hidden_states")}

    Path(onnx_model_path).parent.mkdir(parents=True, exist_ok=True)

    with tempfile.TemporaryDirectory() as staging_dir:
        staged_model_path = os.path.join(staging_dir, "encoder.onnx")
        Path(staged_model_path).parent.mkdir(parents=True, exist_ok=True)

        torch_onnx_export(
            encoder,
            args=tuple(dummy_inputs.to_list()),
            f=staged_model_path if use_external_data_format else onnx_model_path,
            export_params=True,
            input_names=["input_ids", "attention_mask"],
            output_names=["hidden_states"],
            dynamic_axes=dynamic_axes,
            opset_version=12,
            do_constant_folding=True,
            use_external_data_format=use_external_data_format,
            verbose=verbose,
        )

        if use_external_data_format:
            # Consolidate the staged external-data tensors into a single file.
            model = onnx.load_model(staged_model_path, load_external_data=True)
            OnnxModel.save(
                model,
                onnx_model_path,
                save_as_external_data=True,
                all_tensors_to_one_file=True,
            )