def prepare(cls, model, device=None, **kwargs):
    """
    Load the model and wrap an :class:`onnxruntime.InferenceSession`
    in a backend representation ready to be used as a backend.

    :param model: ModelProto (returned by `onnx.load`),
        string for a filename or bytes for a serialized model;
        an existing `InferenceSession` or `OnnxRuntimeBackendRep`
        is accepted as well
    :param device: requested device for the computation,
        None means the default one which depends on
        the compilation settings
    :param kwargs: see :class:`onnxruntime.SessionOptions`; keys that
        are not attributes of the options object are silently ignored
    :return: `OnnxRuntimeBackendRep` (the input itself when it already
        is one, otherwise a new one wrapping the inference session)
    :raises RuntimeError: if *device* is given and is not supported
    """
    if isinstance(model, OnnxRuntimeBackendRep):
        return model

    if isinstance(model, InferenceSession):
        return OnnxRuntimeBackendRep(model)

    if isinstance(model, (str, bytes)):
        # Reject an unsupported device before paying for session creation.
        if device is not None and not cls.supports_device(device):
            raise RuntimeError("Incompatible device expected '{0}', got '{1}'".format(device, get_device()))

        options = SessionOptions()
        for k, v in kwargs.items():
            # Only forward keyword arguments that SessionOptions knows about.
            if hasattr(options, k):
                setattr(options, k, v)

        inf = InferenceSession(model, options)
        # backend API is primarily used for ONNX test/validation. As such, we should disable
        # session.run() fallback which may hide test failures.
        inf.disable_fallback()
        return cls.prepare(inf, device, **kwargs)

    # Anything else is treated as a ModelProto: validate it, then go through
    # the bytes branch above with its serialized form.
    check_model(model)
    serialized_model = model.SerializeToString()
    return cls.prepare(serialized_model, device, **kwargs)
def prepare(cls, model, device=None, **kwargs):
    """
    Load the model and wrap an :class:`onnxruntime.InferenceSession`
    in a backend representation ready to be used as a backend.

    :param model: ModelProto (returned by `onnx.load`),
        string for a filename or bytes for a serialized model;
        an existing `InferenceSession` or `OnnxRuntimeBackendRep`
        is accepted as well
    :param device: requested device for the computation,
        None means the default one which depends on
        the compilation settings
    :param kwargs: see :class:`onnxruntime.SessionOptions`; keys that
        are not attributes of the options object are silently ignored
    :return: `OnnxRuntimeBackendRep` (the input itself when it already
        is one, otherwise a new one wrapping the inference session)
    :raises RuntimeError: if *device* is given and is not supported
    :raises unittest.SkipTest: if `cls.is_opset_supported` reports that
        the ModelProto's opset is not supported

    Execution providers listed (comma separated) in the environment
    variable ``ORT_ONNX_BACKEND_EXCLUDE_PROVIDERS`` are not passed to
    the inference session.
    """
    if isinstance(model, OnnxRuntimeBackendRep):
        return model

    if isinstance(model, InferenceSession):
        return OnnxRuntimeBackendRep(model)

    if isinstance(model, (str, bytes)):
        # Reject an unsupported device before paying for session creation.
        if device is not None and not cls.supports_device(device):
            raise RuntimeError("Incompatible device expected '{0}', got '{1}'".format(device, get_device()))

        options = SessionOptions()
        for k, v in kwargs.items():
            # Only forward keyword arguments that SessionOptions knows about.
            if hasattr(options, k):
                setattr(options, k, v)

        # Tolerate whitespace around the commas and ignore empty entries so
        # that e.g. "A, B" excludes both providers.
        excluded_providers = {
            name.strip()
            for name in os.getenv("ORT_ONNX_BACKEND_EXCLUDE_PROVIDERS", default="").split(",")
            if name.strip()
        }
        providers = [x for x in get_available_providers() if x not in excluded_providers]

        inf = InferenceSession(model, sess_options=options, providers=providers)
        # backend API is primarily used for ONNX test/validation. As such, we should disable
        # session.run() fallback which may hide test failures.
        inf.disable_fallback()
        return cls.prepare(inf, device, **kwargs)

    # Anything else is treated as a ModelProto.
    import re

    # check_model serializes the model anyways, so serialize the model once here
    # and reuse it below in the cls.prepare call to avoid an additional serialization.
    # Only works with onnx >= 1.10.0 hence the version check.
    # Keep only the leading digits of each component so that suffixed versions
    # such as "1.10.0rc1" do not break the integer conversion, and pad to three
    # components so that "1.10" compares equal to (1, 10, 0).
    release = []
    for component in version.version.split(".")[:3]:
        leading_digits = re.match(r"\d+", component)
        if leading_digits is None:
            break
        release.append(int(leading_digits.group()))
    release.extend([0] * (3 - len(release)))
    onnx_supports_serialized_model_check = tuple(release) >= (1, 10, 0)

    serialized_model = model.SerializeToString() if onnx_supports_serialized_model_check else None
    check_model(model if serialized_model is None else serialized_model)

    opset_supported, error_message = cls.is_opset_supported(model)
    if not opset_supported:
        raise unittest.SkipTest(error_message)

    # The recursive call must receive bytes, otherwise this branch would be
    # entered again and recurse forever.
    if serialized_model is None:
        serialized_model = model.SerializeToString()
    return cls.prepare(serialized_model, device, **kwargs)