def setUp(self):
    """Prepare the fixture: quiet logger, registered converters, random data.

    Disables the ``skl2onnx`` logger, calls ``register_converters()`` and
    stores two float32 arrays on the instance:

    * ``data_X`` -- shape ``(10, 200)``; absolute values of standard normal
      draws, with column ``i`` multiplied by ``(i + 1) * 10``.
    * ``data_y`` -- one value per row: the row sum of ``data_X`` divided by
      ``1e3`` plus standard normal noise.

    The random generator is not seeded here, so the data differs per run.
    """
    getLogger('skl2onnx').disabled = True
    register_converters()

    n_rows, n_cols = 10, 200
    features = numpy.abs(
        numpy.random.randn(n_rows, n_cols)).astype(numpy.float32)

    # Scale every column at once: column i gets the factor (i + 1) * 10.
    # The factors are exact in float32, so this matches a per-column
    # in-place multiply by the integer factor.
    factors = numpy.arange(1, n_cols + 1, dtype=numpy.float32) * 10
    features *= factors

    noise = numpy.random.randn(features.shape[0]).astype(numpy.float32)
    target = features.sum(axis=1) / 1e3 + noise

    self.data_X = features.astype(numpy.float32)
    self.data_y = target.astype(numpy.float32)
def setup(self, runtime, N, nf, opset, dtype, optim):
    """asv API: load the stored model/data and build the runtime objects.

    Steps, in order:

    1. Disable the ``skl2onnx`` logger and register the converters and the
       rewritten operators.
    2. Unpickle the dictionary stored in the file named by
       ``self._name(nf, opset, dtype)``; it must provide the keys
       ``'model'``, ``'X'`` and ``'y'``.
    3. Build ``self.X`` / ``self.y`` with ``make_n_rows`` from the stored
       data and *N*.
    4. Call ``self._create_onnx_and_runtime`` and keep its four results as
       ``self.onx`` and the attributes ``rt_<runtime>``,
       ``rt_fct_<runtime>`` and ``rt_fct_track_<runtime>``.
    5. Call ``set_config(assume_finite=True)``.
    """
    getLogger('skl2onnx').disabled = True
    register_converters()
    register_rewritten_operators()

    # NOTE(review): pickle is only safe on trusted files; this presumably
    # reads a file written by the benchmark itself -- confirm.
    path = self._name(nf, opset, dtype)
    with open(path, "rb") as stream:
        payload = pickle.load(stream)

    self.stored = payload
    self.model = payload['model']
    self.X, self.y = make_n_rows(payload['X'], N, payload['y'])

    created = self._create_onnx_and_runtime(
        runtime, self.model, self.X, opset, dtype, optim)
    self.onx, rt, rt_fct, rt_fct_track = created

    # Attribute names are suffixed with the runtime identifier.
    per_runtime = (
        ("rt_", rt),
        ("rt_fct_", rt_fct),
        ("rt_fct_track_", rt_fct_track),
    )
    for prefix, value in per_runtime:
        setattr(self, prefix + runtime, value)

    set_config(assume_finite=True)
# NOTE(review): the statements down to ``return results`` are the tail of a
# function whose ``def`` line (and the creation of ``results``) lies outside
# this view -- presumably ``create_datasets``, which is called below; confirm
# against the full file.
    # Breast-cancer data, split without a fixed random_state.
    X, y = load_breast_cancer(return_X_y=True)
    X_train, X_test, y_train, y_test = train_test_split(X, y)
    results['breast_cancer'] = [X_train, X_test, y_train, y_test]

    # Digits data; this is the only split with a fixed seed (41).
    X, y = load_digits(return_X_y=True)
    X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=41)
    results['digits'] = [X_train, X_test, y_train, y_test]

    # Synthetic classification data built from the positional arguments
    # (20000, 20), split without a fixed random_state.
    # NOTE(review): the key says "100" while the call passes 20 as its second
    # argument -- confirm the intended naming.
    X, y = make_classification(20000, 20)
    X_train, X_test, y_train, y_test = train_test_split(X, y)
    results['rndbin100'] = [X_train, X_test, y_train, y_test]
    return results


# Executed at import time: converters are registered and the datasets are
# built once and shared through ``common_datasets``.
register_converters()
common_datasets = create_datasets()


def get_model(lib):
    """Return a newly constructed regressor for the library key *lib*.

    Every model is created with ``max_depth=6`` and 100 estimators
    (``max_iter=100`` for the histogram gradient boosting model).

    :param lib: one of ``"sklh"`` (HistGradientBoostingRegressor),
        ``"skl"`` (RandomForestRegressor), ``"xgb"`` (XGBRegressor) or
        ``"lgb"`` (LGBMRegressor)
    :return: the regressor instance; it is not fitted here
    :raises ValueError: if *lib* is none of the keys above
    """
    if lib == "sklh":
        return HistGradientBoostingRegressor(max_depth=6, max_iter=100)
    if lib == "skl":
        return RandomForestRegressor(max_depth=6, n_estimators=100)
    if lib == 'xgb':
        return XGBRegressor(max_depth=6, n_estimators=100)
    if lib == 'lgb':
        return LGBMRegressor(max_depth=6, n_estimators=100)
    raise ValueError("Unknown library '{}'.".format(lib))
def setUp(self):
    """Disable the ``skl2onnx`` logger and register the converters."""
    getLogger('skl2onnx').disabled = True
    register_converters()
def test_register_converters(self):
    """Check that ``register_converters(True)`` returns more than two items.

    ``ResourceWarning`` is ignored while the converters are registered.
    """
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", ResourceWarning)
        registered = register_converters(True)
    n_registered = len(registered)
    self.assertGreater(n_registered, 2)