def export_pytorch(
    tokenizer: PreTrainedTokenizer,
    model: PreTrainedModel,
    config: OnnxConfig,
    opset: int,
    output: Path,
) -> Tuple[List[str], List[str]]:
    """
    Export a PyTorch model to an ONNX Intermediate Representation (IR)

    Args:
        tokenizer ([`PreTrainedTokenizer`]):
            The tokenizer used for encoding the data.
        model ([`PreTrainedModel`]):
            The model to export.
        config ([`~onnx.config.OnnxConfig`]):
            The ONNX configuration associated with the exported model.
        opset (`int`):
            The version of the ONNX operator set to use.
        output (`Path`):
            Destination of the exported ONNX model; `output.as_posix()` is handed to `torch.onnx.export` as `f`.

    Returns:
        `Tuple[List[str], List[str]]`: A tuple with the matched inputs reported by
        `ensure_model_and_config_inputs_match`, and the output names taken from `config.outputs`.

    Raises:
        TypeError: If `model` is not a [`PreTrainedModel`] instance.
        ValueError: If the model inputs and the dummy inputs generated from `config` do not match.
    """
    # Fail early with a clear message instead of skipping the export and
    # tripping over unbound locals at the final `return`.
    if not isinstance(model, PreTrainedModel):
        raise TypeError(
            f"export_pytorch expects a PreTrainedModel instance, got {type(model).__name__}"
        )

    import torch
    from torch.onnx import export as onnx_export

    logger.info(f"Using framework PyTorch: {torch.__version__}")
    with torch.no_grad():
        model.config.return_dict = True
        model.eval()

        # Check if we need to override certain configuration item
        if config.values_override is not None:
            logger.info(f"Overriding {len(config.values_override)} configuration item(s)")
            for override_config_key, override_config_value in config.values_override.items():
                logger.info(f"\t- {override_config_key} -> {override_config_value}")
                setattr(model.config, override_config_key, override_config_value)

        # Ensure inputs match
        # TODO: Check when exporting QA we provide "is_pair=True"
        model_inputs = config.generate_dummy_inputs(tokenizer, framework=TensorType.PYTORCH)
        inputs_match, matched_inputs = ensure_model_and_config_inputs_match(model, model_inputs.keys())
        onnx_outputs = list(config.outputs.keys())

        if not inputs_match:
            raise ValueError("Model and config inputs doesn't match")

        config.patch_ops()
        try:
            export_kwargs = {
                "f": output.as_posix(),
                "input_names": list(config.inputs.keys()),
                "output_names": onnx_outputs,
                "dynamic_axes": dict(chain(config.inputs.items(), config.outputs.items())),
                "do_constant_folding": True,
                "opset_version": opset,
            }

            # PyTorch deprecated the `enable_onnx_checker` and `use_external_data_format` arguments in v1.11,
            # so we only pass them on older versions for backwards compatibility
            if parse(torch.__version__) <= parse("1.10.99"):
                export_kwargs["use_external_data_format"] = config.use_external_data_format(
                    model.num_parameters()
                )
                export_kwargs["enable_onnx_checker"] = True

            # export can work with named args but the dict containing named args
            # has to be the last element of the args tuple.
            onnx_export(model, (model_inputs,), **export_kwargs)
        finally:
            # Always undo the op patching, even when the export itself fails.
            config.restore_ops()

    return matched_inputs, onnx_outputs
def export(
    tokenizer: PreTrainedTokenizer, model: PreTrainedModel, config: OnnxConfig, opset: int, output: Path
) -> Tuple[List[str], List[str]]:
    """
    Export a PyTorch backed pipeline to ONNX Intermediate Representation (IR)

    Args:
        tokenizer ([`PreTrainedTokenizer`]):
            The tokenizer passed to `config.generate_dummy_inputs` to build the dummy export inputs.
        model ([`PreTrainedModel`]):
            The model to export. It is put in eval mode and `model.config.return_dict` is set to `True`.
        config ([`OnnxConfig`]):
            The ONNX configuration providing input/output names, dynamic axes and optional value overrides.
        opset (`int`):
            The version of the ONNX operator set to use.
        output (`Path`):
            Destination of the exported ONNX model; `output.as_posix()` is handed to `torch.onnx.export` as `f`.

    Returns:
        `Tuple[List[str], List[str]]`: A tuple with the matched inputs reported by
        `ensure_model_and_config_inputs_match`, and the output names taken from `config.outputs`.

    Raises:
        ImportError: If PyTorch is not installed.
        AssertionError: If the installed PyTorch version does not support dict inputs for ONNX export.
        ValueError: If the model inputs and the dummy inputs generated from `config` do not match.
    """
    if not is_torch_available():
        raise ImportError("Cannot convert because PyTorch is not installed. Please install torch first.")

    import torch

    # Aliased so the imported exporter does not shadow this function's own name.
    from torch.onnx import export as onnx_export

    from ..file_utils import torch_version

    if not is_torch_onnx_dict_inputs_support_available():
        raise AssertionError(f"Unsupported PyTorch version, minimum required is 1.8.0, got: {torch_version}")

    logger.info(f"Using framework PyTorch: {torch.__version__}")
    with torch.no_grad():
        model.config.return_dict = True
        model.eval()

        # Check if we need to override certain configuration item
        if config.values_override is not None:
            logger.info(f"Overriding {len(config.values_override)} configuration item(s)")
            for override_config_key, override_config_value in config.values_override.items():
                logger.info(f"\t- {override_config_key} -> {override_config_value}")
                setattr(model.config, override_config_key, override_config_value)

        # Ensure inputs match
        # TODO: Check when exporting QA we provide "is_pair=True"
        model_inputs = config.generate_dummy_inputs(tokenizer, framework=TensorType.PYTORCH)
        inputs_match, matched_inputs = ensure_model_and_config_inputs_match(model, model_inputs.keys())
        onnx_outputs = list(config.outputs.keys())

        if not inputs_match:
            raise ValueError("Model and config inputs doesn't match")

        config.patch_ops()
        try:
            export_kwargs = {
                "f": output.as_posix(),
                "input_names": list(config.inputs.keys()),
                "output_names": onnx_outputs,
                "dynamic_axes": dict(chain(config.inputs.items(), config.outputs.items())),
                "do_constant_folding": True,
                "opset_version": opset,
            }

            # PyTorch deprecated the `enable_onnx_checker` and `use_external_data_format` arguments in v1.11,
            # so they are only passed on older versions (same version gate as `export_pytorch`).
            if parse(torch.__version__) <= parse("1.10.99"):
                export_kwargs["use_external_data_format"] = config.use_external_data_format(
                    model.num_parameters()
                )
                export_kwargs["enable_onnx_checker"] = True

            # export can work with named args but the dict containing named args
            # has to be the last element of the args tuple.
            onnx_export(model, (model_inputs,), **export_kwargs)
        finally:
            # Always undo the op patching, even when the export itself fails.
            config.restore_ops()

    return matched_inputs, onnx_outputs