# Example #1 (original header: "Beispiel #1")
# variant 0
def _export_util(
    model: torch.nn.Module,
    args: Sequence[Any],
    f: IO,
    **kwargs: Any,
) -> Any:
    """Wrap operator type to export

    Copied from torch.onnx.utils.export, to get output values.

    Temporarily replaces ``torch.onnx.utils._model_to_graph`` with
    ``_model_to_graph_with_value_names`` so the exported graph keeps
    value names, then delegates the actual export to ``torch_export``.

    Args:
        model: Module to export.
        args: Example inputs forwarded to the exporter.
        f: File-like object the ONNX model is written to.
        **kwargs: Forwarded to ``torch_export``. The legacy keys
            ``aten``/``export_raw_ir``/``operator_export_type`` are only
            inspected here (they stay in ``kwargs``);
            ``enable_onnx_checker`` is consumed on PyTorch >= 1.10.

    Returns:
        Whatever ``torch_export`` returns. NOTE(review): on
        PyTorch >= 1.10, when the ONNX checker raises and
        ``enable_onnx_checker`` is falsy the error is swallowed and the
        function implicitly returns ``None`` — confirm this is intended.
    """
    # Read-only peeks at the legacy export flags; the keys remain in kwargs.
    aten = kwargs.get('aten', False)
    export_raw_ir = kwargs.get('export_raw_ir', False)
    operator_export_type = kwargs.get('operator_export_type', None)

    if aten or export_raw_ir:
        # The two flags are mutually exclusive and may not be combined
        # with an explicit operator_export_type.
        assert operator_export_type is None
        assert aten ^ export_raw_ir
        # Note: OperatorExportTypes.RAW unavailable in PyTorch 1.10+
        operator_export_type = OperatorExportTypes.ONNX_ATEN if\
            aten else OperatorExportTypes.RAW  # type: ignore
    elif operator_export_type is None:
        if torch.onnx.PYTORCH_ONNX_CAFFE2_BUNDLE:
            operator_export_type = OperatorExportTypes.ONNX_ATEN_FALLBACK
        else:
            operator_export_type = OperatorExportTypes.ONNX
    # NOTE(review): the operator_export_type resolved above is never
    # passed to torch_export explicitly — presumably the original key
    # still present in kwargs is what the exporter sees; confirm.

    old_model_to_graph = torch.onnx.utils._model_to_graph
    # TODO(ecastill) _model_to_graph shouldn't be direclty overriden
    # This is a temporal workaround until a fix is introduced in PyTorch.
    try:
        torch.onnx.utils._model_to_graph = _model_to_graph_with_value_names
        if pytorch_pfn_extras.requires('1.10.0'):
            # The checker exception class moved between releases: prefer
            # torch.onnx.CheckerError, fall back to the older name.
            checker_error = getattr(torch.onnx, "CheckerError", None)
            if checker_error is None:
                checker_error = torch.onnx.utils.ONNXCheckerError  # type: ignore[attr-defined]
            try:
                # PyTorch 1.10+ dropped the enable_onnx_checker kwarg;
                # pop it and emulate it by re-raising only when truthy.
                enable_onnx_checker = kwargs.pop('enable_onnx_checker', None)
                return torch_export(  # type: ignore[no-untyped-call]
                    model, args, f, **kwargs)
            except checker_error:
                if enable_onnx_checker:
                    raise
        else:
            # Pre-1.10 exporters accept _retain_param_name directly.
            kwargs['_retain_param_name'] = True
            return torch_export(  # type: ignore[no-untyped-call]
                model, args, f, **kwargs)
    finally:
        # Always restore the original hook, even if the export failed.
        torch.onnx.utils._model_to_graph = old_model_to_graph
def _export_util(model, args, f, **kwargs):
    """Export ``model`` to ONNX while retaining value names.

    Adapted from ``torch.onnx.utils.export`` so that output values can
    be recovered: ``torch.onnx.utils._model_to_graph`` is swapped for
    ``_model_to_graph_with_value_names`` for the duration of the call.
    """
    # Peek at the legacy export flags; the keys are left in kwargs.
    use_aten = kwargs.get('aten', False)
    use_raw_ir = kwargs.get('export_raw_ir', False)
    export_type = kwargs.get('operator_export_type', None)

    if use_aten or use_raw_ir:
        # Exactly one of the two flags may be set, and neither may be
        # combined with an explicit operator_export_type.
        assert export_type is None
        assert use_aten ^ use_raw_ir
        if use_aten:
            export_type = OperatorExportTypes.ATEN
        else:
            export_type = OperatorExportTypes.RAW
    elif export_type is None:
        if torch.onnx.PYTORCH_ONNX_CAFFE2_BUNDLE:
            export_type = OperatorExportTypes.ONNX_ATEN_FALLBACK
        else:
            export_type = OperatorExportTypes.ONNX

    saved_model_to_graph = torch.onnx.utils._model_to_graph
    # Temporary workaround: patch _model_to_graph so exported values
    # keep their names; the original hook is restored unconditionally.
    torch.onnx.utils._model_to_graph = _model_to_graph_with_value_names
    try:
        return torch_export(model, args, f, _retain_param_name=True, **kwargs)
    finally:
        torch.onnx.utils._model_to_graph = saved_model_to_graph
def _export_util(model, args, f, **kwargs):
    """Export ``model`` to ONNX, resolving the legacy operator flags.

    Adapted from ``torch.onnx.utils.export``; delegates to
    ``torch_export`` with ``_retain_param_name=True`` so parameter
    names survive the export.
    """
    # Inspect (but do not remove) the legacy export flags in kwargs.
    wants_aten = kwargs.get('aten', False)
    wants_raw_ir = kwargs.get('export_raw_ir', False)
    resolved_type = kwargs.get('operator_export_type', None)

    if wants_aten or wants_raw_ir:
        # Only one flag may be active, and not together with an
        # explicit operator_export_type.
        assert resolved_type is None
        assert wants_aten ^ wants_raw_ir
        resolved_type = (
            OperatorExportTypes.ATEN if wants_aten
            else OperatorExportTypes.RAW)
    elif resolved_type is None:
        resolved_type = (
            OperatorExportTypes.ONNX_ATEN_FALLBACK
            if torch.onnx.PYTORCH_ONNX_CAFFE2_BUNDLE
            else OperatorExportTypes.ONNX)

    return torch_export(model, args, f, _retain_param_name=True, **kwargs)