Example #1
0
    def prepare(cls, model, device=None, **kwargs):
        """
        Turn *model* into an ``OnnxRuntimeBackendRep`` ready to be used
        as a backend.

        :param model: ModelProto (returned by `onnx.load`),
            string for a filename or bytes for a serialized model;
            an ``InferenceSession`` is wrapped as is and an
            ``OnnxRuntimeBackendRep`` is returned unchanged
        :param device: requested device for the computation,
            None means the default one which depends on
            the compilation settings
        :param kwargs: every keyword whose name is an attribute of
            :class:`onnxruntime.SessionOptions` is set on the options,
            the others are ignored
        :return: ``OnnxRuntimeBackendRep`` (wrapping an
            :class:`onnxruntime.InferenceSession` when one had to be created)
        :raises RuntimeError: if *device* is given and
            ``cls.supports_device(device)`` rejects it
        """
        # Already a backend representation: hand it back untouched.
        if isinstance(model, OnnxRuntimeBackendRep):
            return model

        # A live session only needs to be wrapped.
        if isinstance(model, InferenceSession):
            return OnnxRuntimeBackendRep(model)

        if not isinstance(model, (str, bytes)):
            # Anything else is handled as a ModelProto: validate it, then
            # go through the bytes path with its serialized form.
            check_model(model)
            serialized_model = model.SerializeToString()
            return cls.prepare(serialized_model, device, **kwargs)

        # File name or serialized model: build the session options from the
        # keyword arguments SessionOptions knows about.
        session_options = SessionOptions()
        for option_name, option_value in kwargs.items():
            if hasattr(session_options, option_name):
                setattr(session_options, option_name, option_value)

        session = InferenceSession(model, session_options)
        # The backend API is primarily used for ONNX test/validation, so the
        # session.run() fallback is disabled: it may hide test failures.
        session.disable_fallback()

        device_rejected = device is not None and not cls.supports_device(device)
        if device_rejected:
            raise RuntimeError("Incompatible device expected '{0}', got '{1}'".format(device, get_device()))

        # Wrap the freshly created session through the InferenceSession path.
        return cls.prepare(session, device, **kwargs)
Example #2
0
    def prepare(cls, model, device=None, **kwargs):
        """
        Load the model and create an :class:`onnxruntime.InferenceSession`
        wrapped in an ``OnnxRuntimeBackendRep`` ready to be used as a backend.

        :param model: ModelProto (returned by `onnx.load`),
            string for a filename or bytes for a serialized model;
            an ``InferenceSession`` is wrapped as is and an
            ``OnnxRuntimeBackendRep`` is returned unchanged
        :param device: requested device for the computation,
            None means the default one which depends on
            the compilation settings
        :param kwargs: every keyword whose name is an attribute of
            :class:`onnxruntime.SessionOptions` is set on the options,
            the others are ignored
        :return: ``OnnxRuntimeBackendRep``
        :raises RuntimeError: if *device* is given and
            ``cls.supports_device(device)`` rejects it
        :raises unittest.SkipTest: if ``cls.is_opset_supported`` reports the
            ModelProto as unsupported

        The environment variable ``ORT_ONNX_BACKEND_EXCLUDE_PROVIDERS`` may hold
        a comma separated list of execution providers to leave out of the session.
        """
        import re

        if isinstance(model, OnnxRuntimeBackendRep):
            return model
        elif isinstance(model, InferenceSession):
            return OnnxRuntimeBackendRep(model)
        elif isinstance(model, (str, bytes)):
            options = SessionOptions()
            for k, v in kwargs.items():
                if hasattr(options, k):
                    setattr(options, k, v)

            # Strip whitespace and drop empty entries so that a value such as
            # "A, B" or a trailing comma still excludes exactly A and B.
            raw_excluded = os.getenv("ORT_ONNX_BACKEND_EXCLUDE_PROVIDERS", default="")
            excluded_providers = {name.strip() for name in raw_excluded.split(",") if name.strip()}
            providers = [x for x in get_available_providers() if x not in excluded_providers]

            inf = InferenceSession(model, sess_options=options, providers=providers)
            # backend API is primarily used for ONNX test/validation. As such, we should disable session.run() fallback
            # which may hide test failures.
            inf.disable_fallback()
            if device is not None and not cls.supports_device(device):
                raise RuntimeError("Incompatible device expected '{0}', got '{1}'".format(device, get_device()))
            return cls.prepare(inf, device, **kwargs)
        else:
            # type: ModelProto
            # check_model serializes the model anyways, so serialize the model once here
            # and reuse it below in the cls.prepare call to avoid an additional serialization
            # only works with onnx >= 1.10.0 hence the version check
            #
            # Only the leading digits of each of the first three components are used, so
            # suffixed versions such as "1.14.0rc1" do not make int() fail. Missing
            # components are padded with 0 so that "1.10" compares equal to (1, 10, 0).
            release = []
            for component in version.version.split(".")[:3]:
                leading_digits = re.match(r"\d+", component)
                if leading_digits is None:
                    break
                release.append(int(leading_digits.group()))
            release.extend([0] * (3 - len(release)))
            onnx_version = tuple(release)
            onnx_supports_serialized_model_check = onnx_version >= (1, 10, 0)

            serialized_model = None
            if onnx_supports_serialized_model_check:
                serialized_model = model.SerializeToString()
                check_model(serialized_model)
            else:
                check_model(model)

            opset_supported, error_message = cls.is_opset_supported(model)
            if not opset_supported:
                raise unittest.SkipTest(error_message)

            # The recursive call must receive the serialized model: passing the ModelProto
            # again would land in this branch forever. Serialize now if it was not done above.
            if serialized_model is None:
                serialized_model = model.SerializeToString()
            return cls.prepare(serialized_model, device, **kwargs)