コード例 #1
0
 def test_model_extra_trees_classifier_multilabel(self):
     """Convert a multilabel ExtraTreesClassifier to ONNX without ZipMap.

     Fits a small (10-tree, seeded) extra-trees forest through
     ``fit_multilabel_classification_model`` and converts it with the
     ``zipmap`` option turned off. The test checks that a model is
     returned and that its textual form never mentions ``zipmap``. It then
     passes the test data, the fitted estimator and the ONNX graph to
     ``dump_data_and_model`` under a fixed basename.
     """
     estimator = ExtraTreesClassifier(random_state=42, n_estimators=10)
     model, X_test = fit_multilabel_classification_model(estimator)

     n_features = X_test.shape[1]
     input_spec = [("input", FloatTensorType([None, n_features]))]
     # Converter options are keyed by the estimator's id(); turning ZipMap
     # off for this model is what the assertNotIn check below relies on.
     converter_options = {id(model): {'zipmap': False}}
     opset = get_opset_number_from_onnx()

     model_onnx = convert_sklearn(
         model,
         "scikit-learn ExtraTreesClassifier",
         input_spec,
         options=converter_options,
         target_opset=opset)

     self.assertIsNotNone(model_onnx)
     serialized = str(model_onnx).lower()
     self.assertNotIn('zipmap', serialized)

     dump_data_and_model(
         X_test, model, model_onnx,
         basename="SklearnExtraTreesClassifierMultiLabel-Out0",
         folder=self.folder)
コード例 #2
0
 def test_model_random_forest_classifier_multilabel_low_samples(self):
     """Convert a multilabel RandomForestClassifier fitted with n_samples=4.

     Same flow as the other multilabel converter tests, but the fitting
     helper is called with ``n_samples=4``. This presumably yields a very
     small training set; confirm against ``fit_multilabel_classification_model``.
     ZipMap is disabled through converter options. The test asserts that a
     model is produced and that its text does not mention ``zipmap``, then
     hands the data, estimator and ONNX graph to ``dump_data_and_model``.
     """
     forest = RandomForestClassifier(random_state=42, n_estimators=10)
     model, X_test = fit_multilabel_classification_model(forest, n_samples=4)

     input_spec = [("input", FloatTensorType([None, X_test.shape[1]]))]
     # Options are looked up by the estimator's id(); disable ZipMap here.
     converter_options = {id(model): {'zipmap': False}}

     model_onnx = convert_sklearn(
         model,
         "scikit-learn RandomForestClassifier",
         input_spec,
         options=converter_options,
         target_opset=TARGET_OPSET)

     self.assertIsNotNone(model_onnx)
     onnx_text = str(model_onnx).lower()
     self.assertNotIn('zipmap', onnx_text)

     dump_data_and_model(
         X_test, model, model_onnx,
         basename="SklearnRandomForestClassifierMultiLabelLowSamples-Out0",
         folder=self.folder)
コード例 #3
0
 def test_random_forest_classifier_fit_simple_multi(self):
     """Check that the multilabel fitting helper returns exactly two items.

     Other multilabel tests unpack the helper's result as ``(model, X_test)``.
     This test guards that the result has length 2.

     NOTE(review): despite the test name, the estimator fitted here is an
     ``ExtraTreesClassifier``, not a ``RandomForestClassifier``. Confirm
     which one is intended.
     """
     fitted = fit_multilabel_classification_model(ExtraTreesClassifier())
     self.assertEqual(2, len(fitted))