def test_basic(X, y, refit):
    """Fit SimpleClassifier and check ``predict`` depends on ``refit``.

    With ``refit=True`` predicting on the training data must simply work
    (smoke test).  With ``refit=False`` ``predict`` must raise a
    ``ValueError`` whose message mentions ``refit``.
    """
    # NOTE(review): the original comment says this runs on iris; X and y come
    # from fixtures/parametrization not visible here — confirm.
    clf = SimpleClassifier(refit=refit)
    clf.fit(X, y)
    if not refit:
        with pytest.raises(ValueError, match="refit"):
            clf.predict(X)
        return
    # Smoke test: only verifies that predict runs without error.
    clf.predict(X)
def test_deletation_simple(monkeypatch, model):
    """Check SimpleClassifier mirrors the prediction API of its only candidate.

    ``SimpleClassifier._get_estimators`` is monkeypatched to return just
    ``model``.  After fitting, the selected estimator must have ``model``'s
    type, and the classifier must expose ``decision_function`` /
    ``predict_proba`` exactly when ``model`` does.
    """
    # NOTE(review): the name looks like a misspelling of "delegation"; it is
    # kept unchanged because pytest selects/reports tests by function name.

    def only_candidate(self):
        # Replaces the candidate-estimator list with the single fixture model.
        return [model]

    monkeypatch.setattr(SimpleClassifier, '_get_estimators', only_candidate)
    clf = SimpleClassifier(random_state=0)
    clf.fit(X_blobs, y_blobs)

    # presumably est_ is a pipeline-like sequence whose step 1 is the model
    # — TODO confirm against SimpleClassifier.
    assert isinstance(clf.est_[1], type(model))
    clf_has_decision = hasattr(clf, 'decision_function')
    assert clf_has_decision == hasattr(model, 'decision_function')
    clf_has_proba = hasattr(clf, 'predict_proba')
    assert clf_has_proba == hasattr(model, 'predict_proba')
def test_delegation_simple(monkeypatch, model):
    """Check SimpleClassifier delegates prediction methods to its estimator.

    ``SimpleClassifier._get_estimators`` is monkeypatched so ``model`` is
    the only candidate.  After fitting on the blobs data, the classifier
    must expose ``decision_function`` / ``predict_proba`` exactly when
    ``model`` does, and any exposed method must return output of the
    expected shape: ``(n_samples, 2)`` for probabilities and
    ``(n_samples,)`` for decision values.
    """

    def only_candidate(self):
        # Replaces the candidate-estimator list with the single fixture model.
        return [model]

    monkeypatch.setattr(SimpleClassifier, '_get_estimators', only_candidate)
    clf = SimpleClassifier(random_state=0)
    clf.fit(X_blobs, y_blobs)

    # presumably est_ is a pipeline-like sequence whose step 1 is the model
    # — TODO confirm against SimpleClassifier.
    assert isinstance(clf.est_[1], type(model))

    clf_has_decision = hasattr(clf, 'decision_function')
    assert clf_has_decision == hasattr(model, 'decision_function')
    clf_has_proba = hasattr(clf, 'predict_proba')
    assert clf_has_proba == hasattr(model, 'predict_proba')

    n_samples = X_blobs.shape[0]
    if clf_has_proba:
        assert clf.predict_proba(X_blobs).shape == (n_samples, 2)
    if clf_has_decision:
        assert clf.decision_function(X_blobs).shape == (n_samples, )
def test_classification_of_string_targets():
    """SimpleClassifier must keep string class labels end to end.

    A single feature perfectly separates the classes (1 -> 'a', 2 -> 'b').
    After fitting, ``classes_`` must be the original string labels and
    predictions for 1 and 2 must be 'a' and 'b' respectively.
    """
    # Use the same 2-D (n_samples, 1) layout for fit and predict. The original
    # fit on a 1-D array but predicted on a reshaped 2-D column, so training
    # and prediction inputs did not match.
    X = np.array([1, 2, 1, 2, 1, 2, 1, 2, 1, 2]).reshape(-1, 1)
    y = np.array(['a', 'b'] * 5)
    clf = SimpleClassifier()
    fitted = clf.fit(X, y)
    pred = fitted.predict(np.array([1, 2]).reshape(-1, 1))
    np.testing.assert_array_equal(clf.classes_, np.array(['a', 'b']))
    np.testing.assert_array_equal(pred, np.array(['a', 'b']))
def test_dataframe():
    """Smoke-test fitting SimpleClassifier on a DataFrame via ``target_col``.

    Loads ``../datasets/titanic.csv`` relative to this test file and keeps
    every 10th row so the fit stays fast.  Only checks that ``fit`` runs.
    """
    test_dir = os.path.dirname(__file__)
    csv_path = os.path.join(test_dir, '../datasets/titanic.csv')
    sampled = pd.read_csv(csv_path)[::10]
    clf = SimpleClassifier()
    clf.fit(sampled, target_col='survived')
def test_simple_classifier_titanic():
    """Smoke-test fit/predict of SimpleClassifier on the titanic data.

    Fits on the unsubsampled output of ``load_titanic()`` using the
    ``survived`` column as target, then predicts on the same rows with the
    target column removed.  No prediction values are asserted.
    """
    # NOTE(review): another top-level `test_simple_classifier_titanic` is
    # defined later in this SOURCE; if both live in one module, the later
    # definition shadows this one and pytest never runs it — consider renaming.
    data = load_titanic()
    clf = SimpleClassifier()
    clf.fit(data, target_col='survived')
    features_only = data.drop('survived', axis=1)
    clf.predict(features_only)
""" Model Explanation ================= """ from dabl.models import SimpleClassifier from dabl.explain import explain from sklearn.datasets import load_wine from sklearn.model_selection import train_test_split wine = load_wine() X_train, X_test, y_train, y_test = train_test_split(wine.data, wine.target) sc = SimpleClassifier() sc.fit(X_train, y_train) explain(sc, X_test, y_test)
def test_simple_classifier_titanic():
    """Smoke-test fitting SimpleClassifier on every 10th titanic row.

    Uses ``survived`` as the target column; only checks that ``fit`` runs.
    """
    # NOTE(review): a top-level function with this same name appears earlier
    # in this SOURCE; if both share a module, this one shadows the earlier
    # one — consider giving them distinct names.
    titanic_sample = load_titanic()[::10]
    SimpleClassifier().fit(titanic_sample, target_col='survived')