def __setstate__(self, state):
    """
    Restores a pickled instance and rebuilds its ONNX Runtime session.

    ``state['onnx_']`` holds the serialized model as a bytes-like object.
    It is parsed back into a model object, replacing the entry in *state*
    in place. The remaining state is then restored through
    ``BaseEstimator.__setstate__``. Finally, a fresh ``InferenceSession``
    is created with the onnxruntime_extensions custom-op library registered.

    :param state: pickled state dictionary (its ``'onnx_'`` entry is
        overwritten)
    :return: self
    :raises ImportError: if onnxruntime_extensions is not installed
    """
    if get_library_path is None:
        raise ImportError("onnxruntime_extensions is not installed.")
    serialized_model = state['onnx_']
    state['onnx_'] = load(BytesIO(serialized_model))
    BaseEstimator.__setstate__(self, state)
    session_options = SessionOptions()
    # The model may contain contrib operators, so the extensions library
    # must be registered before the session is created.
    session_options.register_custom_ops_library(get_library_path())
    model_bytes = self.onnx_.SerializeToString()
    self.sess_ = InferenceSession(model_bytes, session_options)
    return self
def fit(self, X, y=None, sample_weight=None):
    """
    The model is not trained; this method is still needed to set the
    instance up so that it is ready to transform.

    The ONNX model is built from ``self.model_b64`` and ``self.opset``.
    An ONNX Runtime session is then created with the onnxruntime_extensions
    custom-op library registered.

    :param X: array of strings (not used by this method)
    :param y: unused
    :param sample_weight: unused
    :return: self
    :raises ImportError: if onnxruntime_extensions is not installed
    """
    if get_library_path is None:
        # Same guard as in __setstate__. Without it, calling
        # get_library_path() below fails with an opaque
        # "'NoneType' object is not callable" TypeError.
        raise ImportError("onnxruntime_extensions is not installed.")
    self.onnx_ = self._create_model(self.model_b64, opset=self.opset)
    so = SessionOptions()
    so.register_custom_ops_library(get_library_path())
    self.sess_ = InferenceSession(self.onnx_.SerializeToString(), so)
    return self
def common_test_gpc(self, dtype=np.float32, n_classes=2):
    """
    Converts a fitted GaussianProcessClassifier to ONNX and checks it.

    The classifier is converted with ``zipmap`` disabled and the ``cdist``
    optimisation.

    If onnxruntime cannot load the converted model (``OrtFail``), there are
    two cases:

    * ``self.path`` is not defined: the helper returns without further
      checks.
    * Otherwise the ``Solve`` node is remapped to the matching custom
      operator in the ``ai.onnx.contrib`` domain. The model is run with the
      custom-op library at ``self.path``, and its labels and probabilities
      are compared against scikit-learn.

    If the model loads, it is handed to ``dump_data_and_model`` instead.

    :param dtype: ``np.float32`` selects float input tensors; any other
        value selects double tensors
    :param n_classes: number of classes in the generated dataset
    """
    gp, X = self.fit_classification_model(
        GaussianProcessClassifier(), n_classes=n_classes)
    # return_cov=False, return_std=False
    tensor_type = FloatTensorType if dtype == np.float32 else DoubleTensorType
    converter_options = {
        GaussianProcessClassifier: {'zipmap': False, 'optim': 'cdist'},
    }
    model_onnx = to_onnx(
        gp,
        initial_types=[('X', tensor_type([None, None]))],
        target_opset=TARGET_OPSET,
        options=converter_options)
    self.assertTrue(model_onnx is not None)

    try:
        InferenceSession(model_onnx.SerializeToString())
    except OrtFail:
        # NOTE(review): without a custom-op library path the helper
        # returns silently, so nothing is verified in that case.
        if hasattr(self, 'path'):
            # Operator Solve is missing: route it to the contrib domain.
            precision = 'Double' if dtype == np.float64 else 'Float'
            patched_model = change_onnx_domain(
                model_onnx,
                {'Solve': ('Solve%s' % precision, 'ai.onnx.contrib')})
            session_options = SessionOptions()
            session_options.register_custom_ops_library(self.path)
            custom_sess = InferenceSession(
                patched_model.SerializeToString(), session_options)
            outputs = custom_sess.run(None, {'X': X.astype(dtype)})
            assert_almost_equal(outputs[0].ravel(), gp.predict(X).ravel())
            assert_almost_equal(outputs[1], gp.predict_proba(X), decimal=3)
    else:
        bit_width = 32 if dtype == np.float32 else 64
        dump_data_and_model(
            X.astype(dtype), gp, model_onnx, verbose=False,
            basename="SklearnGaussianProcessRBFT%d%d" % (n_classes, bit_width))