def test_lightgbm_classifier(self):
    """Run a small LGBMClassifier through the binary and multiclass dump helpers.

    Both helpers receive the same ``allow_failure`` expression string, which
    they presumably evaluate to decide whether a failure is tolerated (here:
    when the installed onnx is older than 1.3.0) — confirm against the helper.
    """
    onnx_too_old = "StrictVersion(onnx.__version__) < StrictVersion('1.3.0')"
    # min_child_samples=1 keeps LightGBM willing to split on tiny test sets.
    booster = LGBMClassifier(n_estimators=3, min_child_samples=1)
    for dump in (dump_binary_classification, dump_multiple_classification):
        dump(booster, allow_failure=onnx_too_old)
def test_catboost_multi_classifier(self):
    """Convert a 3-class CatBoostClassifier to ONNX and dump the model and data.

    The unfitted model is first passed to ``dump_multiple_classification``.
    It is then fitted on a small synthetic 3-class problem (10 samples,
    8 informative features), converted explicitly with ``convert_catboost``,
    and the result is dumped alongside the input data.
    """
    X, y = make_classification(n_samples=10, n_informative=8, n_classes=3,
                               random_state=0)
    # Cast once; the same float32 matrix feeds both fitting and the dump.
    X32 = X.astype(numpy.float32)
    catboost_model = catboost.CatBoostClassifier(
        task_type='CPU', loss_function='MultiClass', n_estimators=100,
        verbose=0)
    dump_multiple_classification(catboost_model)
    catboost_model.fit(X32, y)
    catboost_onnx = convert_catboost(
        catboost_model, name='CatBoostMultiClassification',
        doc_string='test multiclass classification')
    # assertIsNotNone reports the offending value on failure, unlike
    # assertTrue(x is not None) which only reports "False is not true".
    self.assertIsNotNone(catboost_onnx)
    dump_data_and_model(X32, catboost_model, catboost_onnx,
                        basename="CatBoostMultiClass")
def test_xgb_classifier_multi(self):
    """Convert an XGBClassifier trained on 2 iris features (3 classes) to ONNX.

    The explicit conversion declares the input as float32 with an unknown
    batch dimension and a feature dimension equal to the training column
    count. The model is then also run through ``dump_multiple_classification``
    with an ``allow_failure`` expression (presumably tolerating failures on
    onnx < 1.3.0 — confirm against the helper).
    """
    iris = load_iris()
    X = iris.data[:, :2]
    y = iris.target
    xgb = XGBClassifier()
    xgb.fit(X, y)
    # The previous shape [1, 'None'] fixed the batch size at 1 and made the
    # feature dimension a symbolic name. The model has X.shape[1] features,
    # and the batch size should be left unspecified.
    conv_model = convert_xgboost(
        xgb,
        initial_types=[('input', FloatTensorType(shape=[None, X.shape[1]]))])
    self.assertIsNotNone(conv_model)
    dump_multiple_classification(
        xgb,
        allow_failure="StrictVersion(onnx.__version__) < StrictVersion('1.3.0')")
def test_extra_trees_classifier(self):
    """Run a 3-tree ExtraTreesClassifier through the one-class, binary and
    multiclass dump helpers, in that order."""
    forest = ExtraTreesClassifier(n_estimators=3)
    for dump in (dump_one_class_classification,
                 dump_binary_classification,
                 dump_multiple_classification):
        dump(forest)
def test_random_forest_classifier(self):
    """Run a 3-tree RandomForestClassifier through the one-class, binary and
    multiclass dump helpers, in that order."""
    forest = RandomForestClassifier(n_estimators=3)
    for dump in (dump_one_class_classification,
                 dump_binary_classification,
                 dump_multiple_classification):
        dump(forest)
def test_decision_tree_classifier(self):
    """Run a default DecisionTreeClassifier through the one-class, binary and
    multiclass dump helpers, in that order."""
    tree = DecisionTreeClassifier()
    for dump in (dump_one_class_classification,
                 dump_binary_classification,
                 dump_multiple_classification):
        dump(tree)