예제 #1
0
 def test_pass_predict_proba_multiclass_3class(self):
     """Check predict_proba and classes_ on the string-labelled 3-class data.

     The summed probability matrix must match the expected total, and
     ``classes_`` must contain the original string labels.
     """
     clf = FastLinearClassifier(number_of_threads=1)
     clf.fit(X_train_3class, y_train_3class)
     s = clf.predict_proba(X_test_3class).sum()
     # NOTE(review): 38.0 is presumably the number of rows in X_test_3class
     # (each probability row should sum to 1) -- confirm against the fixture.
     # This assertion checks predict_proba output, so it reports the
     # predict_proba message, matching the sibling predict_proba tests.
     assert_almost_equal(s,
                         38.0,
                         decimal=4,
                         err_msg=invalid_predict_proba_output)
     assert_equal(set(clf.classes_), {'Blue', 'Green', 'Red'})
예제 #2
0
 def test_pass_predict_proba_multiclass_3class_retains_classes_type(self):
     """Check that integer labels stay integers in ``classes_`` after fit.

     Also checks that the summed predict_proba output matches the expected
     total for the integer-labelled 3-class data.
     """
     import numbers

     clf = FastLinearClassifier(number_of_threads=1)
     clf.fit(X_train_3class_int, y_train_3class_int)
     s = clf.predict_proba(X_test_3class_int).sum()
     # NOTE(review): 38.0 is presumably the number of test rows
     # (each probability row should sum to 1) -- confirm against the fixture.
     assert_almost_equal(s,
                         38.0,
                         decimal=4,
                         err_msg=invalid_predict_proba_output)
     assert_equal(set(clf.classes_), {0, 1, 2})
     # Set equality alone cannot detect a type change: {0.0, 1.0, 2.0} equals
     # {0, 1, 2} because 0.0 == 0 and hash(0.0) == hash(0). Check the element
     # type explicitly. numbers.Integral accepts both Python ints and numpy
     # integer scalars.
     for label in clf.classes_:
         self.assertIsInstance(label, numbers.Integral)
예제 #3
0
    def test_predict_proba_multiclass_3class_no_y_input_implies_no_classes_attribute(
            self):
        """Check that classes_ stays absent when fit() gets no separate y.

        The label column is joined into the feature frame instead of being
        passed as y. In that case ``classes_`` must be absent both after
        ``fit`` and after ``predict_proba``, while ``predict_proba`` must
        still produce the expected total.
        """
        def check_no_classes_attribute(model):
            # The classes_ attribute is currently not supported when fitting
            # without a y input, so its presence at any stage fails the test.
            if hasattr(model, 'classes_'):
                self.fail("classes_ attribute not expected.")

        train_frame = X_train_3class_int.join(y_train_3class_int)
        test_frame = X_test_3class_int.join(y_test_3class_int)

        clf = FastLinearClassifier(number_of_threads=1, label='Label')
        clf.fit(train_frame)
        # After fitting without y.
        check_no_classes_attribute(clf)

        proba_total = clf.predict_proba(test_frame).sum()
        assert_almost_equal(proba_total,
                            38.0,
                            decimal=4,
                            err_msg=invalid_predict_proba_output)

        # Predicting must not introduce classes_ either.
        check_no_classes_attribute(clf)