Example #1
def verify(
    model: Union[torch.nn.Module, torch.jit.ScriptModule],
    input_args: Tuple[Any, ...],
    input_kwargs: Optional[Mapping[str, Any]] = None,
    do_constant_folding: bool = True,
    dynamic_axes: Optional[
        Mapping[str, Union[Mapping[int, str], Mapping[str, Sequence[int]]]]
    ] = None,
    input_names: Optional[Sequence[str]] = None,
    output_names: Optional[Sequence[str]] = None,
    training: torch.onnx.TrainingMode = torch.onnx.TrainingMode.EVAL,
    opset_version: Optional[int] = None,
    keep_initializers_as_inputs: bool = True,
    verbose: bool = False,
    fixed_batch_size: bool = False,
    use_external_data: bool = False,
    additional_test_inputs: Optional[Sequence[Tuple[Any, ...]]] = None,
    remained_onnx_input_idx: Optional[Sequence[int]] = None,
    flatten: bool = True,
    ort_providers: Sequence[str] = _ORT_PROVIDERS,
    rtol: float = 0.001,
    atol: float = 1e-7,
    **_,
):
    """Verify model export to ONNX with ONNX Runtime.

    Args:
        model (torch.nn.Module or torch.jit.ScriptModule): See :func:`torch.onnx.export`.
        input_args (tuple): See :func:`torch.onnx.export`.
        input_kwargs (dict): See :func:`torch.onnx.export`.
        do_constant_folding (bool, optional): See :func:`torch.onnx.export`.
        dynamic_axes (dict, optional): See :func:`torch.onnx.export`.
        input_names (list, optional): See :func:`torch.onnx.export`.
        output_names (list, optional): See :func:`torch.onnx.export`.
        training (torch.onnx.TrainingMode): See :func:`torch.onnx.export`.
        opset_version (int, optional): See :func:`torch.onnx.export`.
        keep_initializers_as_inputs (bool, optional): See :func:`torch.onnx.export`.
        verbose (bool, optional): See :func:`torch.onnx.export`.
        fixed_batch_size (bool, optional): Legacy argument, used only by rnn test cases.
        use_external_data (bool, optional): Explicitly specify whether to export the
            model with external data.
        additional_test_inputs (list, optional): List of tuples. Each tuple is a group of
            input arguments to test. Currently only *args are supported.
        remained_onnx_input_idx (list, optional): If provided, only the inputs at
            these indices are passed to the ONNX model. Use this when the PyTorch model
            has unused inputs: since unused inputs are removed from the exported ONNX
            model, passing all of the original inputs would raise an error about
            unexpected inputs. This parameter tells the verifier which inputs to pass
            to the ONNX model.
        flatten (bool, optional): Default True. If True, nested list/tuple/dict inputs
            are unpacked into a flat list of Tensors before being fed to ONNX. Set this
            to False to preserve nested structures for ONNX, which is usually needed
            when exporting ScriptModules.
        ort_providers (sequence, optional): ONNX Runtime providers to use.
        rtol (float, optional): Relative tolerance used when comparing ONNX Runtime and PyTorch outputs.
        atol (float, optional): Absolute tolerance used when comparing ONNX Runtime and PyTorch outputs.

    Raises:
        AssertionError: if outputs from ONNX model and PyTorch model are not
            equal up to specified precision.
    """
    if training == torch.onnx.TrainingMode.TRAINING:
        model.train()
    elif training == torch.onnx.TrainingMode.EVAL:
        model.eval()
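    # Export runs without gradient tracking; the ONNX model goes to an in-memory
    # buffer by default, or to a temporary file when use_external_data is set.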
    with torch.no_grad(), contextlib.ExitStack() as stack:
        model_f: Union[str, io.BytesIO] = io.BytesIO()
        if use_external_data:
            tmpdir_path = stack.enter_context(tempfile.TemporaryDirectory())
            model_f = os.path.join(tmpdir_path, "model.onnx")

        inputs_for_export = _prepare_input_for_export(input_args, input_kwargs)

        # TODO(#77679): remove this and treat mutating model separately.
        model_copy = _try_clone_model(model)
        utils._export(
            model,
            inputs_for_export,
            model_f,
            opset_version=opset_version,
            do_constant_folding=do_constant_folding,
            keep_initializers_as_inputs=keep_initializers_as_inputs,
            dynamic_axes=dynamic_axes,
            input_names=input_names,
            output_names=output_names,
            fixed_batch_size=fixed_batch_size,
            training=training,
            verbose=verbose,
        )

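        # Create an ONNX Runtime inference session from the exported model.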
        ort_session = _ort_session(model_f, ort_providers)

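        # Compare ONNX Runtime outputs against the cloned PyTorch model's outputs.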
        _compare_ort_pytorch_model(
            model_copy,
            ort_session,
            input_args,
            input_kwargs,
            additional_test_inputs,
            remained_onnx_input_idx,
            flatten,
            rtol,
            atol,
        )
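
A minimal usage sketch for the verify function above, assuming it is the one exposed as torch.onnx.verification.verify and that the onnxruntime package is installed (the small AddModel module is made up for illustration):

import torch
from torch.onnx import verification

class AddModel(torch.nn.Module):
    def forward(self, x, y):
        return x + y

# Export AddModel to ONNX and check that ONNX Runtime reproduces the
# PyTorch outputs within the given tolerances.
verification.verify(
    AddModel(),
    (torch.randn(2, 3), torch.randn(2, 3)),
    opset_version=15,
    rtol=1e-3,
    atol=1e-7,
)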
Example #2
def _export(*args, **kwargs):
    from torch.onnx import utils
    result = utils._export(*args, **kwargs)
    return result
Example #3
def _export(*args, **kwargs):
    from torch.onnx import utils
    return utils._export(*args, **kwargs)
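
Examples #2 and #3 are thin wrappers that simply delegate to the internal torch.onnx.utils._export. A hedged sketch of how such a wrapper might be invoked, reusing only arguments that already appear in the utils._export call of Example #1 (note that _export is a private API whose exact signature may change between PyTorch releases; AddModel is again a made-up module):

import io
import torch

class AddModel(torch.nn.Module):
    def forward(self, x, y):
        return x + y

# Export to an in-memory buffer, mirroring the utils._export call in Example #1.
buf = io.BytesIO()
_export(
    AddModel(),
    (torch.randn(2, 3), torch.randn(2, 3)),
    buf,
    opset_version=15,
    do_constant_folding=True,
)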