def verify(
    model: Union[torch.nn.Module, torch.jit.ScriptModule],
    input_args: Tuple[Any, ...],
    input_kwargs: Optional[Mapping[str, Any]] = None,
    do_constant_folding: bool = True,
    dynamic_axes: Optional[
        Mapping[str, Union[Mapping[int, str], Mapping[str, Sequence[int]]]]
    ] = None,
    input_names: Optional[Sequence[str]] = None,
    output_names: Optional[Sequence[str]] = None,
    training: torch.onnx.TrainingMode = torch.onnx.TrainingMode.EVAL,
    opset_version: Optional[int] = None,
    keep_initializers_as_inputs: bool = True,
    verbose: bool = False,
    fixed_batch_size: bool = False,
    use_external_data: bool = False,
    additional_test_inputs: Optional[Sequence[Tuple[Any, ...]]] = None,
    remained_onnx_input_idx: Optional[Sequence[int]] = None,
    flatten: bool = True,
    ort_providers: Sequence[str] = _ORT_PROVIDERS,
    rtol: float = 0.001,
    atol: float = 1e-7,
    **_,
):
    """Export a model to ONNX and check it against PyTorch using ONNX Runtime.

    The model is exported with the given options, loaded into an ONNX Runtime
    session, and the outputs of that session are compared with the outputs of
    a clone of the PyTorch model taken before the export.

    Args:
        model (torch.nn.Module or torch.jit.ScriptModule): See
            :func:`torch.onnx.export`.
        input_args (tuple): See :func:`torch.onnx.export`.
        input_kwargs (dict, optional): See :func:`torch.onnx.export`.
        do_constant_folding (bool, optional): See :func:`torch.onnx.export`.
        dynamic_axes (dict, optional): See :func:`torch.onnx.export`.
        input_names (list, optional): See :func:`torch.onnx.export`.
        output_names (list, optional): See :func:`torch.onnx.export`.
        training (torch.onnx.TrainingMode): See :func:`torch.onnx.export`.
            ``TRAINING`` calls ``model.train()`` and ``EVAL`` calls
            ``model.eval()`` before exporting; any other value leaves the
            model's mode untouched.
        opset_version (int, optional): See :func:`torch.onnx.export`.
        keep_initializers_as_inputs (bool, optional): See
            :func:`torch.onnx.export`.
        verbose (bool, optional): See :func:`torch.onnx.export`.
        fixed_batch_size (bool, optional): Legacy argument, used only by rnn
            test cases.
        use_external_data (bool, optional): Explicitly specify whether to
            export the model with external data. When True the model is
            written to ``model.onnx`` inside a temporary directory instead of
            an in-memory buffer.
        additional_test_inputs (list, optional): List of tuples. Each tuple is
            a group of input arguments to test. Currently only ``*args`` are
            supported.
        remained_onnx_input_idx (list, optional): If provided, only the
            specified inputs will be passed to the ONNX model. Supply a list
            when there are unused inputs in the model. Since unused inputs
            will be removed in the exported ONNX model, supplying all inputs
            will cause an error on unexpected inputs. This parameter tells the
            verifier which inputs to pass into the ONNX model.
        flatten (bool, optional): Default True. If True, unpack nested
            list/tuple/dict inputs into a flattened list of Tensors for ONNX.
            Set this to False if nested structures are to be preserved for
            ONNX, which is usually the case with exporting ScriptModules.
        ort_providers (sequence, optional): ONNX Runtime providers to use.
        rtol (float, optional): Relative tolerance in comparison between ONNX
            and PyTorch outputs.
        atol (float, optional): Absolute tolerance in comparison between ONNX
            and PyTorch outputs.
        **_: Extra keyword arguments are accepted and ignored.

    Raises:
        AssertionError: If outputs from ONNX model and PyTorch model are not
            equal up to specified precision.
    """
    # Switch the model into the requested mode before exporting.
    if training == torch.onnx.TrainingMode.TRAINING:
        model.train()
    elif training == torch.onnx.TrainingMode.EVAL:
        model.eval()

    # Options forwarded verbatim to the exporter.
    export_options = dict(
        opset_version=opset_version,
        do_constant_folding=do_constant_folding,
        keep_initializers_as_inputs=keep_initializers_as_inputs,
        dynamic_axes=dynamic_axes,
        input_names=input_names,
        output_names=output_names,
        fixed_batch_size=fixed_batch_size,
        training=training,
        verbose=verbose,
    )

    with torch.no_grad(), contextlib.ExitStack() as stack:
        export_target: Union[str, io.BytesIO]
        if use_external_data:
            # The temporary directory is registered on the ExitStack, so it is
            # cleaned up when this `with` block exits.
            scratch_dir = stack.enter_context(tempfile.TemporaryDirectory())
            export_target = os.path.join(scratch_dir, "model.onnx")
        else:
            export_target = io.BytesIO()

        export_inputs = _prepare_input_for_export(input_args, input_kwargs)

        # TODO(#77679): remove this and treat mutating model separately.
        # The clone is taken before the export and is the model that gets
        # compared against ONNX Runtime below.
        reference_model = _try_clone_model(model)
        utils._export(model, export_inputs, export_target, **export_options)

        session = _ort_session(export_target, ort_providers)
        _compare_ort_pytorch_model(
            reference_model,
            session,
            input_args,
            input_kwargs,
            additional_test_inputs,
            remained_onnx_input_idx,
            flatten,
            rtol,
            atol,
        )
def _export(*args, **kwargs):
    """Forward all arguments to :func:`torch.onnx.utils._export`.

    This wrapper was previously defined twice in a row with identical
    behaviour; the second definition silently replaced the first, leaving the
    first as dead code. Only one definition is kept.

    Args:
        *args: Positional arguments passed through unchanged.
        **kwargs: Keyword arguments passed through unchanged.

    Returns:
        Whatever ``torch.onnx.utils._export`` returns for the given arguments.
    """
    # Imported at call time rather than at module import time.
    # NOTE(review): presumably this avoids a circular import - confirm.
    from torch.onnx import utils

    return utils._export(*args, **kwargs)