def test_base_estimator_meta_estimator():
    # Non-regression test for
    # https://github.com/scikit-learn/scikit-learn/issues/19119
    #
    # SelfTrainingClassifier needs `predict_proba` from its base estimator.
    # Here the base estimator is a meta-estimator (StackingClassifier) whose
    # `predict_proba` availability is toggled by the `probability` flag of
    # the SVCs it wraps:
    #   * probability=True  -> `predict_proba` is exposed before fitting, and
    #     fitting then calling `predict_proba` on the wrapper must succeed;
    #   * probability=False -> `predict_proba` is not exposed, and fitting the
    #     self-training wrapper must raise AttributeError.

    def make_stack(probability):
        # Two SVC base learners plus an SVC final estimator, all sharing the
        # same `probability` flag, combined with cv=2.
        return StackingClassifier(
            estimators=[
                (name, SVC(probability=probability))
                for name in ("svc_1", "svc_2")
            ],
            final_estimator=SVC(probability=probability),
            cv=2,
        )

    with_proba = make_stack(probability=True)
    assert hasattr(with_proba, "predict_proba")
    self_trainer = SelfTrainingClassifier(base_estimator=with_proba)
    self_trainer.fit(X_train, y_train_missing_labels)
    self_trainer.predict_proba(X_test)

    without_proba = make_stack(probability=False)
    assert not hasattr(without_proba, "predict_proba")
    self_trainer = SelfTrainingClassifier(base_estimator=without_proba)
    with pytest.raises(AttributeError):
        self_trainer.fit(X_train, y_train_missing_labels)
Example #2
0
    def self_training_clf(self, base_classifier, X_train, y_train,
                          threshold=None, max_iter=None, verbose=None):
        """
        Train a scikit-learn (>= 0.24.1) SelfTrainingClassifier and predict
        on the training set.

        Parameters
        ----------
        base_classifier : estimator
            Supervised classifier implementing both ``fit`` and
            ``predict_proba``.
        X_train : array-like
            Scaled feature matrix of the training set.
        y_train : array-like
            Class labels of the training set. SelfTrainingClassifier expects
            unlabeled samples to be marked with ``-1``.
        threshold : float or None, optional
            Decision threshold used with ``criterion='threshold'``; should be
            in [0, 1). When None, SelfTrainingClassifier's own default is
            used.
        max_iter : int or None, optional
            Maximum number of self-training iterations; should be >= 0.
            Forwarded unchanged, so None means "iterate until no new
            pseudo-labels are added".
        verbose : bool or None, optional
            Enable verbose output. When None, SelfTrainingClassifier's own
            default is used.

        Returns
        -------
        predicted_labels : ndarray
            Labels predicted for every row of ``X_train``.
        predicted_proba : ndarray
            Class probabilities predicted for every row of ``X_train``.
        """
        # Forward only the options the caller actually set. Passing
        # ``threshold=None`` makes SelfTrainingClassifier fail its parameter
        # check (and recent releases also reject ``verbose=None``), so leaving
        # them out lets the estimator fall back to its own defaults.
        # ``max_iter`` is always forwarded because None is a valid value for
        # it and changing that would alter existing behavior.
        options = {"max_iter": max_iter}
        if threshold is not None:
            options["threshold"] = threshold
        if verbose is not None:
            options["verbose"] = verbose

        # base_classifier is passed positionally so this works whether the
        # installed scikit-learn calls the parameter `base_estimator` or
        # `estimator`.
        model = SelfTrainingClassifier(base_classifier, **options)
        model.fit(X_train, y_train)

        # Predictions cover every training row, labeled and unlabeled alike.
        predicted_labels = model.predict(X_train)
        predicted_proba = model.predict_proba(X_train)
        return predicted_labels, predicted_proba
def test_classification(base_estimator, selection_crit):
    # Fit SelfTrainingClassifier twice on the same features: once with
    # integer labels and once with the equivalent string labels. Both runs
    # must agree on predictions, probabilities and termination condition.
    # Afterwards, sanity-check the fitted bookkeeping attributes
    # (labeled_iter_, transduction_, n_iter_) against max_iter.
    max_iter = 10
    settings = dict(max_iter=max_iter, threshold=0.75,
                    criterion=selection_crit)

    def fit_and_predict(targets):
        model = SelfTrainingClassifier(base_estimator, **settings)
        model.fit(X_train, targets)
        return model, model.predict(X_test), model.predict_proba(X_test)

    st, pred, proba = fit_and_predict(y_train_missing_labels)
    st_string, pred_string, proba_string = fit_and_predict(
        y_train_missing_strings)

    # String predictions must equal the integer predictions after mapping
    # each integer label to its string counterpart.
    assert_array_equal(np.vectorize(mapping.get)(pred), pred_string)
    assert_array_equal(proba, proba_string)
    assert st.termination_condition_ == st_string.termination_condition_

    is_labeled = y_train_missing_labels != -1
    # labeled_iter_ is 0 exactly for the samples that started out labeled...
    assert_array_equal(st.labeled_iter_ == 0, is_labeled)
    # ...and those samples keep their original label in transduction_.
    assert_array_equal(y_train_missing_labels[is_labeled],
                       st.transduction_[is_labeled])

    for model in (st, st_string):
        # No sample is labeled later than the last iteration that ran, and
        # the number of iterations run never exceeds max_iter.
        assert np.max(model.labeled_iter_) <= model.n_iter_ <= max_iter
        # One labeled_iter_ entry per transduced sample.
        assert model.labeled_iter_.shape == model.transduction_.shape