예제 #1
0
def test_basic(X, y, refit):
    """Fit SimpleClassifier and check how ``predict`` honours ``refit``.

    With ``refit`` false, ``predict`` must raise ``ValueError`` whose
    message matches ``"refit"``. With ``refit`` true, ``predict`` only has
    to run without raising (smoke test).
    """
    # The original note says this runs on iris; X, y, refit are supplied by
    # fixtures/parametrization not visible here.
    clf = SimpleClassifier(refit=refit)
    clf.fit(X, y)
    if not refit:
        with pytest.raises(ValueError, match="refit"):
            clf.predict(X)
        return
    # Smoke test: prediction simply must not raise.
    clf.predict(X)
예제 #2
0
def test_deletation_simple(monkeypatch, model):
    """SimpleClassifier should mirror the methods of its only candidate model.

    ``SimpleClassifier._get_estimators`` is monkeypatched to return just
    ``model``. After fitting on the blob data, ``sc.est_[1]`` must be an
    instance of ``type(model)``. ``sc`` must expose ``decision_function``
    and ``predict_proba`` exactly when ``model`` does.
    """
    # NOTE(review): the name looks like a typo for "delegation"; it is left
    # as-is so selecting this test by name keeps working.
    def only_candidate(self):
        return [model]

    monkeypatch.setattr(SimpleClassifier, '_get_estimators', only_candidate)
    sc = SimpleClassifier(random_state=0)
    sc.fit(X_blobs, y_blobs)

    # presumably est_ is a pipeline whose step 1 is the final model — confirm
    assert isinstance(sc.est_[1], type(model))
    for method_name in ('decision_function', 'predict_proba'):
        assert hasattr(sc, method_name) == hasattr(model, method_name)
예제 #3
0
파일: test_models.py 프로젝트: pdhinwa/dabl
def test_delegation_simple(monkeypatch, model):
    """SimpleClassifier should delegate prediction methods to its model.

    ``SimpleClassifier._get_estimators`` is monkeypatched to return only
    ``model``. The fitted classifier must wrap an instance of
    ``type(model)``, expose ``decision_function`` / ``predict_proba``
    exactly when ``model`` does, and return outputs of the shapes asserted
    below.
    """
    def only_candidate(self):
        return [model]

    monkeypatch.setattr(SimpleClassifier, '_get_estimators', only_candidate)
    sc = SimpleClassifier(random_state=0)
    sc.fit(X_blobs, y_blobs)

    # presumably est_ is a pipeline whose step 1 is the final model — confirm
    assert isinstance(sc.est_[1], type(model))
    for method_name in ('decision_function', 'predict_proba'):
        assert hasattr(sc, method_name) == hasattr(model, method_name)

    n_samples = X_blobs.shape[0]
    # The expected 2 probability columns / 1-D decision values suggest the
    # blob data has two classes — confirm against the X_blobs definition.
    if hasattr(sc, 'predict_proba'):
        assert sc.predict_proba(X_blobs).shape == (n_samples, 2)
    if hasattr(sc, 'decision_function'):
        assert sc.decision_function(X_blobs).shape == (n_samples, )
예제 #4
0
def test_classification_of_string_targets():
    """SimpleClassifier must keep string class labels intact.

    A single alternating feature (1, 2, 1, 2, ...) is paired with
    alternating labels ('a', 'b', ...). The fitted model must report
    ``classes_ == ['a', 'b']`` and predict 'a' for 1 and 'b' for 2.
    """
    X = np.array([1, 2] * 5)
    y = np.array(['a', 'b'] * 5)
    clf = SimpleClassifier()

    # NOTE(review): X is 1-D here but the query below is a 2-D column;
    # presumably SimpleClassifier accepts/reshapes 1-D input — confirm.
    fitted = clf.fit(X, y)
    queries = np.array([1, 2]).reshape(-1, 1)
    predictions = fitted.predict(queries)

    expected_labels = np.array(['a', 'b'])
    np.testing.assert_array_equal(clf.classes_, expected_labels)
    np.testing.assert_array_equal(predictions, expected_labels)
예제 #5
0
def test_dataframe():
    """Smoke test: SimpleClassifier fits a DataFrame given ``target_col``.

    Uses every 10th row of ``../datasets/titanic.csv`` (relative to this
    test module) to keep the fit fast.
    """
    csv_path = os.path.join(os.path.dirname(__file__),
                            '../datasets/titanic.csv')
    every_tenth_row = pd.read_csv(csv_path)[::10]
    SimpleClassifier().fit(every_tenth_row, target_col='survived')
예제 #6
0
파일: test_models.py 프로젝트: pdhinwa/dabl
def test_simple_classifier_titanic():
    """Smoke test: fit on titanic, then predict on the features alone."""
    data = load_titanic()
    clf = SimpleClassifier()
    clf.fit(data, target_col='survived')
    # Prediction input is the same frame with the target column removed.
    features = data.drop('survived', axis=1)
    clf.predict(features)
예제 #7
0
"""
Model Explanation
=================
"""
from dabl.models import SimpleClassifier
from dabl.explain import explain
from sklearn.datasets import load_wine
from sklearn.model_selection import train_test_split

# Example: fit a SimpleClassifier on a training split of the wine dataset,
# then run dabl's ``explain`` on the held-out split.
wine = load_wine()

# Seed the split so every run of this example (e.g. each docs build) uses
# the same train/test partition.
X_train, X_test, y_train, y_test = train_test_split(
    wine.data, wine.target, random_state=0)

# random_state=0 is accepted by SimpleClassifier (the tests use it too);
# passing it presumably makes the classifier's own search reproducible.
sc = SimpleClassifier(random_state=0)

sc.fit(X_train, y_train)

explain(sc, X_test, y_test)
예제 #8
0
def test_simple_classifier_titanic():
    """Smoke test: fit SimpleClassifier on every 10th titanic row."""
    # Slicing with [::10] keeps the fit cheap; load_titanic presumably
    # returns a DataFrame, so this subsamples rows — confirm.
    every_tenth_row = load_titanic()[::10]
    SimpleClassifier().fit(every_tenth_row, target_col='survived')