コード例 #1
0
 def test_model_ridge_classifier_cv_multilabel(self):
     """Convert a multilabel RidgeClassifierCV and dump it for comparison.

     ``RidgeClassifierCV`` has no ``random_state`` constructor parameter,
     so the previous ``RidgeClassifierCV(random_state=42)`` raised
     ``TypeError`` before the model was ever fitted. The estimator is now
     built with its defaults.
     """
     model, X_test = fit_multilabel_classification_model(
         linear_model.RidgeClassifierCV())
     model_onnx = convert_sklearn(
         model,
         "scikit-learn RidgeClassifierCV",
         # Single float input whose width matches the test features.
         [("input", FloatTensorType([None, X_test.shape[1]]))],
         target_opset=TARGET_OPSET)
     self.assertTrue(model_onnx is not None)
     # NOTE(review): helper defined elsewhere; presumably runs model_onnx on
     # X_test and compares against the scikit-learn model's outputs.
     dump_data_and_model(
         X_test, model, model_onnx,
         basename="SklearnRidgeClassifierCVMultiLabel")
コード例 #2
0
 def test_model_mlp_classifier_multilabel_tanh(self):
     """Convert a tanh-activated multilabel MLPClassifier to ONNX.

     The fitting helper is called with ``n_labels=3``; the ONNX graph gets a
     single float input sized to the test features and ``TARGET_OPSET``.
     """
     estimator = MLPClassifier(random_state=42, activation="tanh")
     clf, X_test = fit_multilabel_classification_model(estimator, n_labels=3)
     input_spec = [("input", FloatTensorType([None, X_test.shape[1]]))]
     onx = convert_sklearn(
         clf, "scikit-learn MLPClassifier", input_spec,
         target_opset=TARGET_OPSET)
     self.assertTrue(onx is not None)
     dump_data_and_model(
         X_test, clf, onx,
         basename="SklearnMLPClassifierMultiLabelTanhActivation")
コード例 #3
0
 def test_model_mlp_classifier_multilabel_identity(self):
     """Convert an identity-activation multilabel MLPClassifier to ONNX.

     The helper is called with ``is_int=True`` and the ONNX input is
     declared as ``Int64TensorType`` (presumably to match integer test
     features — confirm against the helper).
     """
     estimator = MLPClassifier(random_state=42, activation="identity")
     clf, X_test = fit_multilabel_classification_model(
         estimator, is_int=True)
     input_spec = [("input", Int64TensorType([None, X_test.shape[1]]))]
     onx = convert_sklearn(
         clf, "scikit-learn MLPClassifier", input_spec,
         target_opset=TARGET_OPSET)
     self.assertTrue(onx is not None)
     dump_data_and_model(
         X_test, clf, onx,
         basename="SklearnMLPClassifierMultiLabelIdentityActivation")
 def test_model_ridge_classifier_cv_multilabel(self):
     """Convert a multilabel RidgeClassifierCV and dump it for comparison.

     ``RidgeClassifierCV`` has no ``random_state`` constructor parameter,
     so the previous ``RidgeClassifierCV(random_state=42)`` raised
     ``TypeError`` before the model was ever fitted. The estimator is now
     built with its defaults.
     """
     model, X_test = fit_multilabel_classification_model(
         linear_model.RidgeClassifierCV())
     # NOTE(review): unlike most other tests in this file, no target_opset
     # is passed here, so presumably the converter's default opset is used;
     # consider target_opset=TARGET_OPSET for consistency.
     model_onnx = convert_sklearn(
         model,
         "scikit-learn RidgeClassifierCV",
         [("input", FloatTensorType([None, X_test.shape[1]]))],
     )
     self.assertTrue(model_onnx is not None)
     # NOTE(review): allow_failure looks like an expression string that the
     # helper evaluates; it relies on StrictVersion (distutils) — verify.
     dump_data_and_model(
         X_test,
         model,
         model_onnx,
         basename="SklearnRidgeClassifierCVMultiLabel",
         allow_failure="StrictVersion("
         "onnxruntime.__version__)<= StrictVersion('0.2.1')",
     )
コード例 #5
0
 def test_model_extra_trees_classifier_multilabel(self):
     """Convert a 5-tree multilabel ExtraTreesClassifier with ZipMap disabled.

     The converter receives ``{'zipmap': False}`` for the fitted model, and
     the serialized graph is checked to contain no ``zipmap`` text.
     """
     forest = ExtraTreesClassifier(random_state=42, n_estimators=5)
     clf, X_test = fit_multilabel_classification_model(forest)
     # Options are keyed by the id() of the object handed to the converter.
     no_zipmap = {id(clf): {'zipmap': False}}
     onx = convert_sklearn(
         clf,
         "scikit-learn ExtraTreesClassifier",
         [("input", FloatTensorType([None, X_test.shape[1]]))],
         options=no_zipmap,
         target_opset=TARGET_OPSET,
     )
     self.assertTrue(onx is not None)
     assert 'zipmap' not in str(onx).lower()
     dump_data_and_model(
         X_test, clf, onx,
         basename="SklearnExtraTreesClassifierMultiLabel-Out0")
コード例 #6
0
 def test_model_mlp_classifier_multilabel_tanh(self):
     """Convert a tanh-activated multilabel MLPClassifier to ONNX.

     No ``target_opset`` is passed to ``convert_sklearn``. The
     ``allow_failure`` expression presumably tolerates mismatches on
     onnxruntime <= 0.2.1 (the helper that evaluates it is defined elsewhere).
     """
     mlp = MLPClassifier(random_state=42, activation="tanh")
     clf, X_test = fit_multilabel_classification_model(mlp, n_labels=3)
     initial_types = [("input", FloatTensorType([None, X_test.shape[1]]))]
     onx = convert_sklearn(clf, "scikit-learn MLPClassifier", initial_types)
     self.assertTrue(onx is not None)
     ort_guard = ("StrictVersion("
                  "onnxruntime.__version__)<= StrictVersion('0.2.1')")
     dump_data_and_model(
         X_test, clf, onx,
         basename="SklearnMLPClassifierMultiLabelTanhActivation",
         allow_failure=ort_guard,
     )
コード例 #7
0
 def test_model_mlp_classifier_multilabel_default(self):
     """Convert a multilabel MLPClassifier with default hyper-parameters.

     Only ``random_state`` is fixed; conversion targets ``TARGET_OPSET``.
     """
     clf, X_test = fit_multilabel_classification_model(
         MLPClassifier(random_state=42))
     n_features = X_test.shape[1]
     onx = convert_sklearn(
         clf,
         "scikit-learn MLPClassifier",
         [("input", FloatTensorType([None, n_features]))],
         target_opset=TARGET_OPSET,
     )
     self.assertTrue(onx is not None)
     dump_data_and_model(
         X_test, clf, onx,
         basename="SklearnMLPClassifierMultiLabel",
         allow_failure=(
             "StrictVersion("
             "onnxruntime.__version__)<= StrictVersion('0.2.1')"),
     )
コード例 #8
0
 def test_model_extra_trees_classifier_multilabel_low_samples(self):
     """Convert a multilabel ExtraTreesClassifier fitted with n_samples=10.

     ZipMap is disabled through converter options, and the serialized
     model is checked to contain no ``zipmap`` text.
     """
     clf, X_test = fit_multilabel_classification_model(
         ExtraTreesClassifier(random_state=42), n_samples=10)
     converter_options = {id(clf): {'zipmap': False}}
     initial_types = [("input", FloatTensorType([None, X_test.shape[1]]))]
     onx = convert_sklearn(
         clf, "scikit-learn ExtraTreesClassifier", initial_types,
         options=converter_options, target_opset=TARGET_OPSET)
     self.assertTrue(onx is not None)
     assert 'zipmap' not in str(onx).lower()
     ort_guard = ("StrictVersion("
                  "onnxruntime.__version__) <= StrictVersion('0.2.1')")
     dump_data_and_model(
         X_test, clf, onx,
         basename="SklearnExtraTreesClassifierMultiLabelLowSamples-Out0",
         allow_failure=ort_guard,
     )
コード例 #9
0
 def test_model_knn_classifier_multilabel(self):
     """Convert a multilabel KNeighborsClassifier with ZipMap disabled.

     The fitting helper is asked for n_classes=7, n_labels=3,
     n_samples=100 and n_features=10.
     """
     knn, X_test = fit_multilabel_classification_model(
         KNeighborsClassifier(),
         n_classes=7, n_labels=3, n_samples=100, n_features=10)
     opts = {id(knn): {'zipmap': False}}
     onx = convert_sklearn(
         knn,
         "scikit-learn KNN Classifier",
         [("input", FloatTensorType([None, X_test.shape[1]]))],
         options=opts,
         target_opset=TARGET_OPSET,
     )
     self.assertTrue(onx is not None)
     assert 'zipmap' not in str(onx).lower()
     dump_data_and_model(
         X_test, knn, onx, basename="SklearnKNNClassifierMultiLabel-Out0")
コード例 #10
0
 def test_ovr_rf_multilabel_float(self):
     """Convert a OneVsRest random forest (float input) at several opsets.

     Opset 12 and ``TARGET_OPSET`` are tried; any candidate above
     ``TARGET_OPSET`` is skipped. Candidates are de-duplicated so that when
     ``TARGET_OPSET == 12`` the same subtest is not run, and the same
     basename dumped, twice.
     """
     # dict.fromkeys drops duplicates while keeping first-seen order.
     for opset in dict.fromkeys([12, TARGET_OPSET]):
         if opset > TARGET_OPSET:
             continue
         with self.subTest(opset=opset):
             model = OneVsRestClassifier(
                 RandomForestClassifier(n_estimators=2, max_depth=3))
             # NOTE(review): the positional 3 is presumably n_classes.
             model, X = fit_multilabel_classification_model(
                 model, 3, is_int=False, n_features=5)
             model_onnx = convert_sklearn(
                 model,
                 initial_types=[('input',
                                 FloatTensorType([None, X.shape[1]]))],
                 target_opset=opset)
             self.assertTrue(model_onnx is not None)
             dump_data_and_model(X.astype(np.float32),
                                 model,
                                 model_onnx,
                                 basename="SklearnOVRRFMultiLabelFloat%d" %
                                 opset)
コード例 #11
0
 def test_ovr_rf_multilabel_int_11(self):
     """Convert a OneVsRest random forest with int64 input at opsets 9-11.

     Opsets greater than ``TARGET_OPSET`` are skipped. Each converted model
     is checked not to contain the string ``"Clip"`` in its text form.
     """
     for opset in (9, 10, 11):
         if opset > TARGET_OPSET:
             continue
         with self.subTest(opset=opset):
             forest = RandomForestClassifier(n_estimators=2, max_depth=3)
             ovr, X = fit_multilabel_classification_model(
                 OneVsRestClassifier(forest), 3, is_int=True, n_features=5)
             onx = convert_sklearn(
                 ovr,
                 initial_types=[
                     ('input', Int64TensorType([None, X.shape[1]]))],
                 target_opset=opset)
             self.assertNotIn('"Clip"', str(onx))
             basename = "SklearnOVRRFMultiLabelInt64%d" % opset
             dump_data_and_model(X.astype(np.int64), ovr, onx,
                                 basename=basename)