def test_check_estimator_clones():
    """Check that ``check_estimator`` leaves the given estimator untouched.

    For each estimator class, a fresh instance is checked twice: first
    without fitting, then after fitting on ``iris.data + 10``. Each time,
    the ``joblib.hash`` taken right before ``check_estimator`` runs must
    equal the hash taken right after it.
    """
    from sklearn.datasets import load_iris

    iris = load_iris()
    estimator_classes = (
        GaussianMixture,
        LinearRegression,
        RandomForestClassifier,
        NMF,
        SGDClassifier,
        MiniBatchKMeans,
    )

    for estimator_cls in estimator_classes:
        # First pass: unfitted estimator; second pass: fitted estimator.
        for fit_first in (False, True):
            with ignore_warnings(category=FutureWarning):
                # when 'est = SGDClassifier()'
                est = estimator_cls()
                _set_checking_parameters(est)
                set_random_state(est)
                if fit_first:
                    est.fit(iris.data + 10, iris.target)
                hash_before = joblib.hash(est)
                check_estimator(est)
            assert hash_before == joblib.hash(est)
def test_estimators(estimator, check, request):
    """Run a single common check against a single estimator instance.

    Parameters
    ----------
    estimator : object
        The estimator instance under test.
    check : callable
        The common check; it is called as ``check(estimator)``.
    request : object
        Kept for signature compatibility (presumably the pytest ``request``
        fixture); not used in this body.
    """
    # Common tests for estimator instances.
    # Each warning category is listed once; the original tuple repeated
    # FutureWarning, which had no effect beyond noise.
    with ignore_warnings(
        category=(FutureWarning, ConvergenceWarning, UserWarning)
    ):
        _set_checking_parameters(estimator)
        check(estimator)
def test_transformers_get_feature_names_out(transformer):
    """Run the ``get_feature_names_out`` checks (plain and pandas) on a transformer.

    Both checks receive the transformer's class name and the instance, and
    both run with ``FutureWarning`` ignored.
    """
    _set_checking_parameters(transformer)
    transformer_name = transformer.__class__.__name__

    with ignore_warnings(category=FutureWarning):
        check_transformer_get_feature_names_out(transformer_name, transformer)
        check_transformer_get_feature_names_out_pandas(
            transformer_name, transformer
        )
def test_check_param_validation(estimator):
    """Run ``check_param_validation`` unless the estimator is on the ignore list.

    Estimators whose class name is in ``PARAM_VALIDATION_ESTIMATORS_TO_IGNORE``
    are skipped via ``pytest.skip``.
    """
    estimator_name = estimator.__class__.__name__

    if estimator_name in PARAM_VALIDATION_ESTIMATORS_TO_IGNORE:
        skip_reason = (
            f"Skipping check_param_validation for {estimator_name}: Does not use the "
            "appropriate API for parameter validation yet."
        )
        pytest.skip(skip_reason)

    _set_checking_parameters(estimator)
    check_param_validation(estimator_name, estimator)
def test_pandas_column_name_consistency(estimator):
    """Check that the estimator never warns it was fitted without feature names.

    Runs ``check_dataframe_column_names_consistency`` while recording every
    warning it emits. The test then asserts that no recorded warning message
    contains "was fitted without feature names".
    """
    import warnings

    _set_checking_parameters(estimator)
    with ignore_warnings(category=FutureWarning):
        # ``pytest.warns(None)`` is deprecated in pytest 7 and rejected in
        # pytest 8. ``catch_warnings(record=True)`` combined with an
        # "always" filter records every warning, which is the behavior
        # ``pytest.warns(None)`` used to provide.
        with warnings.catch_warnings(record=True) as record:
            warnings.simplefilter("always")
            check_dataframe_column_names_consistency(
                estimator.__class__.__name__, estimator
            )
        for warning in record:
            assert "was fitted without feature names" not in str(
                warning.message
            )
def test_estimators(estimator, check, request):
    """Run one common check on one estimator, honoring its expected failures.

    The estimator's ``_xfail_test`` tag is read through ``_safe_tags``. If
    the tag is non-empty and contains the id of ``check``, the test is
    marked ``xfail`` with the mapped reason before the check runs.

    Parameters
    ----------
    estimator : object
        The estimator instance under test.
    check : callable
        The common check; it is called as ``check(estimator)``.
    request : object
        Used to attach the ``xfail`` marker via ``request.applymarker``
        (presumably the pytest ``request`` fixture).
    """
    # Common tests for estimator instances.
    # Each warning category is listed once; the original tuple repeated
    # FutureWarning.
    with ignore_warnings(
        category=(FutureWarning, ConvergenceWarning, UserWarning)
    ):
        _set_checking_parameters(estimator)

        xfail_checks = _safe_tags(estimator, '_xfail_test')
        check_name = _set_check_estimator_ids(check)
        # The truthiness test comes first so that an empty or None tag
        # never reaches the membership test.
        if xfail_checks and check_name in xfail_checks:
            msg = xfail_checks[check_name]
            request.applymarker(pytest.mark.xfail(reason=msg))

        check(estimator)
def test_check_n_features_in_after_fitting(estimator):
    """Run ``check_n_features_in_after_fitting`` on the given estimator."""
    _set_checking_parameters(estimator)
    estimator_name = estimator.__class__.__name__
    check_n_features_in_after_fitting(estimator_name, estimator)
def test_pandas_column_name_consistency(estimator):
    """Run ``check_dataframe_column_names_consistency`` with FutureWarning ignored."""
    _set_checking_parameters(estimator)
    estimator_name = estimator.__class__.__name__

    with ignore_warnings(category=FutureWarning):
        check_dataframe_column_names_consistency(estimator_name, estimator)