def test_lasso_alpha_warning():
    # Fitting Lasso with alpha=0 is expected to emit a UserWarning.
    # The data is just a straight line.
    features = [[-1], [0], [1]]
    targets = [-1, 0, 1]
    unregularized = Lasso(alpha=0)
    assert_warns(UserWarning, unregularized.fit, features, targets)
def test_alpha():
    # Setting alpha=0 should not output nan results when p(x_i|y_j)=0 is a case
    data = np.array([[1, 0], [1, 1]])
    labels = np.array([0, 1])

    bernoulli = BernoulliNB(alpha=0.)
    assert_warns(UserWarning, bernoulli.partial_fit, data, labels,
                 classes=[0, 1])
    assert_warns(UserWarning, bernoulli.fit, data, labels)
    expected = np.array([[1, 0], [0, 1]])
    assert_array_almost_equal(bernoulli.predict_proba(data), expected)

    multinomial = MultinomialNB(alpha=0.)
    assert_warns(UserWarning, multinomial.partial_fit, data, labels,
                 classes=[0, 1])
    assert_warns(UserWarning, multinomial.fit, data, labels)
    expected = np.array([[2. / 3, 1. / 3], [0, 1]])
    assert_array_almost_equal(multinomial.predict_proba(data), expected)

    # CategoricalNB is only exercised through fit here.
    categorical = CategoricalNB(alpha=0.)
    assert_warns(UserWarning, categorical.fit, data, labels)
    expected = np.array([[1., 0.], [0., 1.]])
    assert_array_almost_equal(categorical.predict_proba(data), expected)

    # Test sparse X
    data = scipy.sparse.csr_matrix(data)

    bernoulli = BernoulliNB(alpha=0.)
    assert_warns(UserWarning, bernoulli.fit, data, labels)
    expected = np.array([[1, 0], [0, 1]])
    assert_array_almost_equal(bernoulli.predict_proba(data), expected)

    multinomial = MultinomialNB(alpha=0.)
    assert_warns(UserWarning, multinomial.fit, data, labels)
    expected = np.array([[2. / 3, 1. / 3], [0, 1]])
    assert_array_almost_equal(multinomial.predict_proba(data), expected)

    # Test for alpha < 0
    data = np.array([[1, 0], [1, 1]])
    labels = np.array([0, 1])
    expected_msg = ('Smoothing parameter alpha = -1.0e-01. '
                    'alpha should be > 0.')

    # fit must reject a negative alpha for every estimator ...
    negative_models = [
        BernoulliNB(alpha=-0.1),
        MultinomialNB(alpha=-0.1),
        CategoricalNB(alpha=-0.1),
    ]
    for model in negative_models:
        assert_raise_message(ValueError, expected_msg, model.fit,
                             data, labels)

    # ... and so must partial_fit, checked on fresh instances.
    negative_models = [BernoulliNB(alpha=-0.1), MultinomialNB(alpha=-0.1)]
    for model in negative_models:
        assert_raise_message(ValueError, expected_msg, model.partial_fit,
                             data, labels, classes=[0, 1])
def test_ledoit_wolf():
    # Tests LedoitWolf module on a simple dataset.
    # X and n_samples are read from module scope.

    # ---- Part 1: data centered by hand, assume_centered=True ----
    centered = X - X.mean(axis=0)
    estimator = LedoitWolf(assume_centered=True)
    estimator.fit(centered)
    # Reference values reused by the later checks.
    ref_shrinkage = estimator.shrinkage_
    ref_score = estimator.score(centered)
    assert_almost_equal(
        ledoit_wolf_shrinkage(centered, assume_centered=True),
        ref_shrinkage)
    assert_almost_equal(
        ledoit_wolf_shrinkage(centered, assume_centered=True, block_size=6),
        ref_shrinkage)

    # compare shrunk covariance obtained from data and from MLE estimate
    mle_cov, mle_shrinkage = ledoit_wolf(centered, assume_centered=True)
    assert_array_almost_equal(mle_cov, estimator.covariance_, 4)
    assert_almost_equal(mle_shrinkage, estimator.shrinkage_)

    # compare estimates given by LW and ShrunkCovariance
    shrunk = ShrunkCovariance(shrinkage=estimator.shrinkage_,
                              assume_centered=True)
    shrunk.fit(centered)
    assert_array_almost_equal(shrunk.covariance_, estimator.covariance_, 4)

    # test with n_features = 1
    single_feature = X[:, 0].reshape((-1, 1))
    estimator = LedoitWolf(assume_centered=True)
    estimator.fit(single_feature)
    mle_cov, mle_shrinkage = ledoit_wolf(single_feature,
                                         assume_centered=True)
    assert_array_almost_equal(mle_cov, estimator.covariance_, 4)
    assert_almost_equal(mle_shrinkage, estimator.shrinkage_)
    assert_array_almost_equal((single_feature**2).sum() / n_samples,
                              estimator.covariance_, 4)

    # test shrinkage coeff on a simple data set (without saving precision)
    estimator = LedoitWolf(store_precision=False, assume_centered=True)
    estimator.fit(centered)
    assert_almost_equal(estimator.score(centered), ref_score, 4)
    assert estimator.precision_ is None

    # ---- Part 2: same tests without assuming centered data ----
    estimator = LedoitWolf()
    estimator.fit(X)
    assert_almost_equal(estimator.shrinkage_, ref_shrinkage, 4)
    assert_almost_equal(estimator.shrinkage_, ledoit_wolf_shrinkage(X))
    assert_almost_equal(estimator.shrinkage_, ledoit_wolf(X)[1])
    assert_almost_equal(estimator.score(X), ref_score, 4)

    # compare shrunk covariance obtained from data and from MLE estimate
    mle_cov, mle_shrinkage = ledoit_wolf(X)
    assert_array_almost_equal(mle_cov, estimator.covariance_, 4)
    assert_almost_equal(mle_shrinkage, estimator.shrinkage_)

    # compare estimates given by LW and ShrunkCovariance
    shrunk = ShrunkCovariance(shrinkage=estimator.shrinkage_)
    shrunk.fit(X)
    assert_array_almost_equal(shrunk.covariance_, estimator.covariance_, 4)

    # test with n_features = 1
    single_feature = X[:, 0].reshape((-1, 1))
    estimator = LedoitWolf()
    estimator.fit(single_feature)
    mle_cov, mle_shrinkage = ledoit_wolf(single_feature)
    assert_array_almost_equal(mle_cov, estimator.covariance_, 4)
    assert_almost_equal(mle_shrinkage, estimator.shrinkage_)
    assert_array_almost_equal(empirical_covariance(single_feature),
                              estimator.covariance_, 4)

    # test with one sample
    # warning should be raised when using only 1 sample
    one_sample = np.arange(5).reshape(1, 5)
    estimator = LedoitWolf()
    assert_warns(UserWarning, estimator.fit, one_sample)
    all_zeros = np.zeros(shape=(5, 5), dtype=np.float64)
    assert_array_almost_equal(estimator.covariance_, all_zeros)

    # test shrinkage coeff on a simple data set (without saving precision)
    estimator = LedoitWolf(store_precision=False)
    estimator.fit(X)
    assert_almost_equal(estimator.score(X), ref_score, 4)
    assert estimator.precision_ is None
def test_oas():
    # Tests OAS module on a simple dataset.
    # X and n_samples are read from module scope.

    # ---- Part 1: data centered by hand, assume_centered=True ----
    centered = X - X.mean(axis=0)
    estimator = OAS(assume_centered=True)
    estimator.fit(centered)
    # Reference values reused by the later checks.
    ref_shrinkage = estimator.shrinkage_
    ref_score = estimator.score(centered)

    # compare shrunk covariance obtained from data and from MLE estimate
    mle_cov, mle_shrinkage = oas(centered, assume_centered=True)
    assert_array_almost_equal(mle_cov, estimator.covariance_, 4)
    assert_almost_equal(mle_shrinkage, estimator.shrinkage_)

    # compare estimates given by OAS and ShrunkCovariance
    shrunk = ShrunkCovariance(shrinkage=estimator.shrinkage_,
                              assume_centered=True)
    shrunk.fit(centered)
    assert_array_almost_equal(shrunk.covariance_, estimator.covariance_, 4)

    # test with n_features = 1
    single_feature = X[:, 0:1]
    estimator = OAS(assume_centered=True)
    estimator.fit(single_feature)
    mle_cov, mle_shrinkage = oas(single_feature, assume_centered=True)
    assert_array_almost_equal(mle_cov, estimator.covariance_, 4)
    assert_almost_equal(mle_shrinkage, estimator.shrinkage_)
    assert_array_almost_equal((single_feature**2).sum() / n_samples,
                              estimator.covariance_, 4)

    # test shrinkage coeff on a simple data set (without saving precision)
    estimator = OAS(store_precision=False, assume_centered=True)
    estimator.fit(centered)
    assert_almost_equal(estimator.score(centered), ref_score, 4)
    assert estimator.precision_ is None

    # ---- Part 2: same tests without assuming centered data ----
    estimator = OAS()
    estimator.fit(X)
    assert_almost_equal(estimator.shrinkage_, ref_shrinkage, 4)
    assert_almost_equal(estimator.score(X), ref_score, 4)

    # compare shrunk covariance obtained from data and from MLE estimate
    mle_cov, mle_shrinkage = oas(X)
    assert_array_almost_equal(mle_cov, estimator.covariance_, 4)
    assert_almost_equal(mle_shrinkage, estimator.shrinkage_)

    # compare estimates given by OAS and ShrunkCovariance
    shrunk = ShrunkCovariance(shrinkage=estimator.shrinkage_)
    shrunk.fit(X)
    assert_array_almost_equal(shrunk.covariance_, estimator.covariance_, 4)

    # test with n_features = 1
    single_feature = X[:, 0].reshape((-1, 1))
    estimator = OAS()
    estimator.fit(single_feature)
    mle_cov, mle_shrinkage = oas(single_feature)
    assert_array_almost_equal(mle_cov, estimator.covariance_, 4)
    assert_almost_equal(mle_shrinkage, estimator.shrinkage_)
    assert_array_almost_equal(empirical_covariance(single_feature),
                              estimator.covariance_, 4)

    # test with one sample
    # warning should be raised when using only 1 sample
    one_sample = np.arange(5).reshape(1, 5)
    estimator = OAS()
    assert_warns(UserWarning, estimator.fit, one_sample)
    all_zeros = np.zeros(shape=(5, 5), dtype=np.float64)
    assert_array_almost_equal(estimator.covariance_, all_zeros)

    # test shrinkage coeff on a simple data set (without saving precision)
    estimator = OAS(store_precision=False)
    estimator.fit(X)
    assert_almost_equal(estimator.score(X), ref_score, 4)
    assert estimator.precision_ is None
def test_timeout():
    # An SVC limited to max_iter=1 is expected to emit a ConvergenceWarning
    # when fitted. X and Y are read from module scope.
    def dot_kernel(x, y):
        # Callable kernel: plain dot product between the two sample sets.
        return np.dot(x, y.T)

    classifier = svm.SVC(kernel=dot_kernel, probability=True,
                         random_state=0, max_iter=1)
    assert_warns(ConvergenceWarning, classifier.fit, np.array(X), Y)
def test_f_classif_constant_feature():
    # Test that f_classif warns if a feature is constant throughout.
    features, labels = make_classification(n_samples=10, n_features=5)
    # Overwrite the first column so that it holds a single repeated value.
    features[:, 0] = 2.0
    assert_warns(UserWarning, f_classif, features, labels)
def test_factor_analysis():
    # Test FactorAnalysis ability to recover the data covariance structure
    rng = np.random.RandomState(0)
    n_samples, n_features, n_components = 20, 5, 3

    # Some random settings for the generative model
    W = rng.randn(n_components, n_features)
    # latent variable of dim 3, 20 of it
    h = rng.randn(n_samples, n_components)
    # using gamma to model different noise variance
    # per component
    noise = rng.gamma(1, size=n_features) * rng.randn(n_samples, n_features)

    # generate observations
    # wlog, mean is 0
    X = np.dot(h, W) + noise

    # An unknown svd_method must be rejected both at construction time and
    # when it is set on an existing instance before fit.
    with pytest.raises(ValueError):
        FactorAnalysis(svd_method='foo')
    fa_fail = FactorAnalysis()
    fa_fail.svd_method = 'foo'
    with pytest.raises(ValueError):
        fa_fail.fit(X)

    fas = []
    for method in ['randomized', 'lapack']:
        fa = FactorAnalysis(n_components=n_components, svd_method=method)
        fa.fit(X)
        fas.append(fa)

        X_t = fa.transform(X)
        assert X_t.shape == (n_samples, n_components)

        assert_almost_equal(fa.loglike_[-1], fa.score_samples(X).sum())
        assert_almost_equal(fa.score_samples(X).mean(), fa.score(X))

        # Compare every increment against zero. np.all(np.diff(...)) on its
        # own only checks that the increments are non-zero, so a decreasing
        # log-likelihood would have passed unnoticed.
        increasing = np.all(np.diff(fa.loglike_) > 0.)
        assert increasing, 'Log likelihood dif not increase'

        # Sample Covariance
        scov = np.cov(X, rowvar=False, bias=True)

        # Model Covariance
        mcov = fa.get_covariance()
        diff = np.sum(np.abs(scov - mcov)) / W.size
        assert diff < 0.1, "Mean absolute difference is %f" % diff

        # noise_variance_init has n_features entries, so fitting on only two
        # columns must fail.
        fa = FactorAnalysis(n_components=n_components,
                            noise_variance_init=np.ones(n_features))
        with pytest.raises(ValueError):
            fa.fit(X[:, :2])

    def abs_attr(estimator, name):
        # sign will not be equal, so compare absolute values
        return np.abs(getattr(estimator, name))

    fa1, fa2 = fas
    for attr in ['loglike_', 'components_', 'noise_variance_']:
        assert_almost_equal(abs_attr(fa1, attr), abs_attr(fa2, attr))

    # A single iteration is expected to trigger a ConvergenceWarning.
    fa1.max_iter = 1
    fa1.verbose = True
    assert_warns(ConvergenceWarning, fa1.fit, X)

    # Test get_covariance and get_precision with n_components == n_features
    # with n_components < n_features and with n_components == 0
    for n_components in [0, 2, X.shape[1]]:
        fa.n_components = n_components
        fa.fit(X)
        cov = fa.get_covariance()
        precision = fa.get_precision()
        assert_array_almost_equal(np.dot(cov, precision),
                                  np.eye(X.shape[1]), 12)

    # test rotation
    n_components = 2

    results, projections = {}, {}
    for method in (None, "varimax", 'quartimax'):
        fa_var = FactorAnalysis(n_components=n_components,
                                rotation=method)
        results[method] = fa_var.fit_transform(X)
        projections[method] = fa_var.get_covariance()
    # Different rotations must give different transformed data while the
    # model covariances stay close (within atol=3).
    for rot1, rot2 in combinations([None, 'varimax', 'quartimax'], 2):
        assert not np.allclose(results[rot1], results[rot2])
        assert np.allclose(projections[rot1], projections[rot2], atol=3)

    assert_raises(ValueError,
                  FactorAnalysis(rotation='not_implemented').fit_transform, X)

    # test against R's psych::principal with rotate="varimax"
    # (i.e., the values below stem from rotating the components in R)
    # R's factor analysis returns quite different values; therefore, we only
    # test the rotation itself
    factors = np.array(
        [[0.89421016, -0.35854928, -0.27770122, 0.03773647],
         [-0.45081822, -0.89132754, 0.0932195, -0.01787973],
         [0.99500666, -0.02031465, 0.05426497, -0.11539407],
         [0.96822861, -0.06299656, 0.24411001, 0.07540887]])
    r_solution = np.array([[0.962, 0.052], [-0.141, 0.989],
                           [0.949, -0.300], [0.937, -0.251]])
    rotated = _ortho_rotation(factors[:, :n_components], method='varimax').T
    assert_array_almost_equal(np.abs(rotated), np.abs(r_solution), decimal=3)
def test_xfail_ignored_in_check_estimator():
    # Make sure checks marked as xfail are just ignored and not run by
    # check_estimator(), but still raise a warning.
    estimator = NuSVC()
    assert_warns(SkipTestWarning, check_estimator, estimator)
def test_check_estimator():
    # tests that the estimator actually fails on "bad" estimators.
    # not a complete test of all checks, which are very extensive.

    # check that we have a set_params and can clone
    msg = "Passing a class was deprecated"
    assert_raises_regex(TypeError, msg, check_estimator, object)
    msg = ("Parameter 'p' of estimator 'HasMutableParameters' is of type "
           "object which is not allowed")
    # check that the "default_constructible" test checks for mutable parameters
    check_estimator(HasImmutableParameters())  # should pass
    assert_raises_regex(AssertionError, msg, check_estimator,
                        HasMutableParameters())
    # check that values returned by get_params match set_params
    msg = "get_params result does not match what was passed to set_params"
    assert_raises_regex(AssertionError, msg, check_estimator,
                        ModifiesValueInsteadOfRaisingError())
    assert_warns(UserWarning, check_estimator, RaisesErrorInSetParams())
    assert_raises_regex(AssertionError, msg, check_estimator,
                        ModifiesAnotherValue())
    # check that we have a fit method
    msg = "object has no attribute 'fit'"
    assert_raises_regex(AttributeError, msg, check_estimator, BaseEstimator())
    # check that fit does input validation
    msg = "Did not raise"
    assert_raises_regex(AssertionError, msg, check_estimator,
                        BaseBadClassifier())
    # check that sample_weights in fit accepts pandas.Series type
    try:
        from pandas import Series  # noqa
        msg = ("Estimator NoSampleWeightPandasSeriesType raises error if "
               "'sample_weight' parameter is of type pandas.Series")
        assert_raises_regex(ValueError, msg, check_estimator,
                            NoSampleWeightPandasSeriesType())
    except ImportError:
        # pandas is optional; skip this check when it is not installed.
        pass
    # check that predict does input validation (doesn't accept dicts in input)
    msg = "Estimator doesn't check for NaN and inf in predict"
    assert_raises_regex(AssertionError, msg, check_estimator,
                        NoCheckinPredict())
    # check that estimator state does not change
    # at transform/predict/predict_proba time
    msg = 'Estimator changes __dict__ during predict'
    assert_raises_regex(AssertionError, msg, check_estimator, ChangesDict())
    # check that `fit` only changes attributes that
    # are private (start with an _ or end with a _).
    msg = ('Estimator ChangesWrongAttribute should not change or mutate '
           'the parameter wrong_attribute from 0 to 1 during fit.')
    assert_raises_regex(AssertionError, msg, check_estimator,
                        ChangesWrongAttribute())
    check_estimator(ChangesUnderscoreAttribute())
    # check that `fit` doesn't add any public attribute
    msg = (r'Estimator adds public attribute\(s\) during the fit method.'
           ' Estimators are only allowed to add private attributes'
           ' either started with _ or ended'
           ' with _ but wrong_attribute added')
    assert_raises_regex(AssertionError, msg, check_estimator,
                        SetsWrongAttribute())
    # check for sample order invariance
    name = NotInvariantSampleOrder.__name__
    method = 'predict'
    msg = ("{method} of {name} is not invariant when applied to a dataset"
           "with different sample order.").format(method=method, name=name)
    assert_raises_regex(AssertionError, msg, check_estimator,
                        NotInvariantSampleOrder())
    # check for invariant method
    name = NotInvariantPredict.__name__
    method = 'predict'
    msg = ("{method} of {name} is not invariant when applied "
           "to a subset.").format(method=method, name=name)
    assert_raises_regex(AssertionError, msg, check_estimator,
                        NotInvariantPredict())
    # check for sparse matrix input handling
    name = NoSparseClassifier.__name__
    # The expected text must read "sparse data"; the previous
    # "sparse area_data" wording looks like a stray search/replace and would
    # not match a message that says "sparse data".
    msg = ("Estimator %s doesn't seem to fail gracefully on sparse data"
           % name)
    assert_raises_regex(AssertionError, msg, check_estimator,
                        NoSparseClassifier())
    # Large indices test on bad estimator
    msg = ('Estimator LargeSparseNotSupportedClassifier doesn\'t seem to '
           r'support \S{3}_64 matrix, and is not failing gracefully.*')
    assert_raises_regex(AssertionError, msg, check_estimator,
                        LargeSparseNotSupportedClassifier())
    # does error on binary_only untagged estimator
    msg = 'Only 2 classes are supported'
    assert_raises_regex(ValueError, msg, check_estimator,
                        UntaggedBinaryClassifier())
    # non-regression test for estimators transforming to sparse data
    check_estimator(SparseTransformer())
    # doesn't error on actual estimator
    check_estimator(LogisticRegression())
    check_estimator(LogisticRegression(C=0.01))
    check_estimator(MultiTaskElasticNet())
    # doesn't error on binary_only tagged estimator
    check_estimator(TaggedBinaryClassifier())
    # Check regressor with requires_positive_y estimator tag
    msg = 'negative y values not supported!'
    assert_raises_regex(ValueError, msg, check_estimator,
                        RequiresPositiveYRegressor())
    # Does not raise error on classifier with poor_score tag
    check_estimator(PoorScoreLogisticRegression())
def test_check_estimator():
    # tests that the estimator actually fails on "bad" estimators.
    # not a complete test of all checks, which are very extensive.

    # check that we have a set_params and can clone
    msg = "it does not implement a 'get_params' method"
    assert_raises_regex(TypeError, msg, check_estimator, object)
    msg = "object has no attribute '_get_tags'"
    assert_raises_regex(AttributeError, msg, check_estimator, object())
    # check that values returned by get_params match set_params
    msg = "get_params result does not match what was passed to set_params"
    assert_raises_regex(AssertionError, msg, check_estimator,
                        ModifiesValueInsteadOfRaisingError())
    assert_warns(UserWarning, check_estimator, RaisesErrorInSetParams())
    assert_raises_regex(AssertionError, msg, check_estimator,
                        ModifiesAnotherValue())
    # check that we have a fit method
    msg = "object has no attribute 'fit'"
    assert_raises_regex(AttributeError, msg, check_estimator, BaseEstimator)
    assert_raises_regex(AttributeError, msg, check_estimator, BaseEstimator())
    # check that fit does input validation
    msg = "ValueError not raised"
    assert_raises_regex(AssertionError, msg, check_estimator,
                        BaseBadClassifier)
    assert_raises_regex(AssertionError, msg, check_estimator,
                        BaseBadClassifier())
    # check that sample_weights in fit accepts pandas.Series type
    try:
        from pandas import Series  # noqa
        msg = ("Estimator NoSampleWeightPandasSeriesType raises error if "
               "'sample_weight' parameter is of type pandas.Series")
        assert_raises_regex(
            ValueError, msg, check_estimator, NoSampleWeightPandasSeriesType)
    except ImportError:
        # pandas is optional; skip this check when it is not installed.
        pass
    # check that predict does input validation (doesn't accept dicts in input)
    msg = "Estimator doesn't check for NaN and inf in predict"
    assert_raises_regex(AssertionError, msg, check_estimator, NoCheckinPredict)
    assert_raises_regex(AssertionError, msg, check_estimator,
                        NoCheckinPredict())
    # check that estimator state does not change
    # at transform/predict/predict_proba time
    msg = 'Estimator changes __dict__ during predict'
    assert_raises_regex(AssertionError, msg, check_estimator, ChangesDict)
    # check that `fit` only changes attributes that
    # are private (start with an _ or end with a _).
    msg = ('Estimator ChangesWrongAttribute should not change or mutate '
           'the parameter wrong_attribute from 0 to 1 during fit.')
    assert_raises_regex(AssertionError, msg,
                        check_estimator, ChangesWrongAttribute)
    check_estimator(ChangesUnderscoreAttribute)
    # check that `fit` doesn't add any public attribute
    msg = (r'Estimator adds public attribute\(s\) during the fit method.'
           ' Estimators are only allowed to add private attributes'
           ' either started with _ or ended'
           ' with _ but wrong_attribute added')
    assert_raises_regex(AssertionError, msg,
                        check_estimator, SetsWrongAttribute)
    # check for invariant method
    name = NotInvariantPredict.__name__
    method = 'predict'
    msg = ("{method} of {name} is not invariant when applied "
           "to a subset.").format(method=method, name=name)
    assert_raises_regex(AssertionError, msg,
                        check_estimator, NotInvariantPredict)
    # check for sparse matrix input handling
    name = NoSparseClassifier.__name__
    msg = "Estimator %s doesn't seem to fail gracefully on sparse data" % name
    # the check for sparse input handling prints to the stdout,
    # instead of raising an error, so as not to remove the original traceback.
    # that means we need to jump through some hoops to catch it.
    old_stdout = sys.stdout
    string_buffer = StringIO()
    sys.stdout = string_buffer
    try:
        check_estimator(NoSparseClassifier)
    except Exception:
        # Only the printed message matters here, so errors raised by the
        # check are deliberately ignored. KeyboardInterrupt and SystemExit
        # are no longer swallowed, as they were by the bare ``except:``.
        pass
    finally:
        # Always restore the real stdout, even if the check blew up.
        sys.stdout = old_stdout
    assert msg in string_buffer.getvalue()
    # Large indices test on bad estimator
    msg = ('Estimator LargeSparseNotSupportedClassifier doesn\'t seem to '
           r'support \S{3}_64 matrix, and is not failing gracefully.*')
    assert_raises_regex(AssertionError, msg, check_estimator,
                        LargeSparseNotSupportedClassifier)
    # does error on binary_only untagged estimator
    msg = 'Only 2 classes are supported'
    assert_raises_regex(ValueError, msg, check_estimator,
                        UntaggedBinaryClassifier)
    # non-regression test for estimators transforming to sparse data
    check_estimator(SparseTransformer())
    # doesn't error on actual estimator
    check_estimator(LogisticRegression)
    check_estimator(LogisticRegression(C=0.01))
    check_estimator(MultiTaskElasticNet)
    check_estimator(MultiTaskElasticNet())
    # doesn't error on binary_only tagged estimator
    check_estimator(TaggedBinaryClassifier)
    # Check regressor with requires_positive_y estimator tag
    msg = 'negative y values not supported!'
    assert_raises_regex(ValueError, msg, check_estimator,
                        RequiresPositiveYRegressor)
def test_ignore_warning():
    # Checks that ignore_warnings works as expected when used as a plain
    # wrapper, as a decorator and as a context manager.
    def _emit_deprecation():
        warnings.warn("deprecation warning", DeprecationWarning)

    def _emit_deprecation_then_default():
        warnings.warn("deprecation warning", DeprecationWarning)
        warnings.warn("deprecation warning")

    # --- Wrapping the function directly ---
    assert_no_warnings(ignore_warnings(_emit_deprecation))
    assert_no_warnings(
        ignore_warnings(_emit_deprecation, category=DeprecationWarning))
    assert_warns(
        DeprecationWarning,
        ignore_warnings(_emit_deprecation, category=UserWarning))
    assert_warns(
        UserWarning,
        ignore_warnings(_emit_deprecation_then_default,
                        category=FutureWarning))
    assert_warns(
        DeprecationWarning,
        ignore_warnings(_emit_deprecation_then_default,
                        category=UserWarning))
    assert_no_warnings(
        ignore_warnings(_emit_deprecation,
                        category=(DeprecationWarning, UserWarning)))

    # --- Decorator form ---
    @ignore_warnings
    def silenced_all():
        _emit_deprecation()
        _emit_deprecation_then_default()

    @ignore_warnings(category=(DeprecationWarning, UserWarning))
    def silenced_both_categories():
        _emit_deprecation_then_default()

    @ignore_warnings(category=DeprecationWarning)
    def silenced_deprecation():
        _emit_deprecation()

    @ignore_warnings(category=UserWarning)
    def silenced_user():
        _emit_deprecation()

    @ignore_warnings(category=DeprecationWarning)
    def silenced_deprecation_multi():
        _emit_deprecation_then_default()

    @ignore_warnings(category=UserWarning)
    def silenced_user_multi():
        _emit_deprecation_then_default()

    assert_no_warnings(silenced_all)
    assert_no_warnings(silenced_both_categories)
    assert_no_warnings(silenced_deprecation)
    assert_warns(DeprecationWarning, silenced_user)
    assert_warns(UserWarning, silenced_deprecation_multi)
    assert_warns(DeprecationWarning, silenced_user_multi)

    # --- Context-manager form ---
    def ctx_silenced_all():
        with ignore_warnings():
            _emit_deprecation()

    def ctx_silenced_both_categories():
        with ignore_warnings(category=(DeprecationWarning, UserWarning)):
            _emit_deprecation_then_default()

    def ctx_silenced_deprecation():
        with ignore_warnings(category=DeprecationWarning):
            _emit_deprecation()

    def ctx_silenced_user():
        with ignore_warnings(category=UserWarning):
            _emit_deprecation()

    def ctx_silenced_deprecation_multi():
        with ignore_warnings(category=DeprecationWarning):
            _emit_deprecation_then_default()

    def ctx_silenced_user_multi():
        with ignore_warnings(category=UserWarning):
            _emit_deprecation_then_default()

    assert_no_warnings(ctx_silenced_all)
    assert_no_warnings(ctx_silenced_both_categories)
    assert_no_warnings(ctx_silenced_deprecation)
    assert_warns(DeprecationWarning, ctx_silenced_user)
    assert_warns(UserWarning, ctx_silenced_deprecation_multi)
    assert_warns(DeprecationWarning, ctx_silenced_user_multi)

    # Passing a warning class as the first positional argument is expected
    # to raise a ValueError pointing the caller at ``category=``.
    category_cls = UserWarning
    expected = ("'obj' should be a callable.+"
                "you should use 'category=UserWarning'")

    with pytest.raises(ValueError, match=expected):
        wrapped = ignore_warnings(category_cls)(_emit_deprecation)
        wrapped()

    with pytest.raises(ValueError, match=expected):
        @ignore_warnings(category_cls)
        def test():
            pass