def check_target_type(name, Estimator):
    """Check that fitting ``Estimator`` on a continuous target warns.

    A 20x2 random feature matrix is paired with a continuous target in
    [0, 1]; ``fit`` is expected to emit a ``UserWarning`` whose message
    contains "should be of types". ``name`` is not used by this check.
    """
    features = np.random.random((20, 2))
    target = np.linspace(0, 1, 20)

    model = Estimator()
    set_random_state(model)

    with warns(UserWarning, match='should be of types'):
        model.fit(features, target)
def test_sampling_strategy_dict_over_sampling():
    """Warn when a dict strategy over-samples past the majority class.

    Class 2 is the majority class with 100 samples; asking for 140 samples
    of it must trigger a ``UserWarning`` describing the overshoot.
    """
    y = np.array([1] * 50 + [2] * 100 + [3] * 25)
    sampling_strategy = {1: 70, 2: 140, 3: 70}
    expected_msg = (r"After over-sampling, the number of samples \(140\) in"
                    r" class 2 will be larger than the number of samples in"
                    r" the majority class \(class #2 -> 100\)")
    # Pass the pattern by keyword, as the rest of the module does:
    # ``pytest.warns`` interprets a second positional argument as a callable.
    with warns(UserWarning, match=expected_msg):
        check_sampling_strategy(sampling_strategy, y, 'over-sampling')
def test_ratio_dict_over_sampling():
    """Check ``check_ratio`` with a dict ratio for over-sampling.

    A ratio matching the majority class count yields the per-class number
    of samples to generate. A ratio exceeding the majority class count
    (100) must trigger a ``UserWarning``.
    """
    y = np.array([1] * 50 + [2] * 100 + [3] * 25)

    ratio = {1: 70, 2: 100, 3: 70}
    ratio_ = check_ratio(ratio, y, 'over-sampling')
    # Targets minus current counts: 70-50, 100-100, 70-25.
    assert ratio_ == {1: 20, 2: 0, 3: 45}

    ratio = {1: 70, 2: 140, 3: 70}
    # Raw strings: ``\(``/``\)`` are regex escapes; in a plain literal they
    # are invalid Python escapes and emit Deprecation/SyntaxWarning.
    expected_msg = (r"After over-sampling, the number of samples \(140\) in"
                    r" class 2 will be larger than the number of samples in the"
                    r" majority class \(class #2 -> 100\)")
    with warns(UserWarning, match=expected_msg):
        check_ratio(ratio, y, 'over-sampling')
def test_warns():
    """Exercise ``warns``: matching, mismatch failure and multiple warnings."""
    import warnings

    # A single warning matching the pattern passes.
    with warns(UserWarning, match=r'must be \d+$'):
        warnings.warn("value must be 42", UserWarning)

    # A non-matching warning makes ``warns`` raise an AssertionError.
    with raises(AssertionError, match='pattern not found'):
        with warns(UserWarning, match=r'must be \d+$'):
            warnings.warn("this is not here", UserWarning)

    # One matching warning among several is enough.
    with warns(UserWarning, match=r'aaa'):
        warnings.warn("cccccccccc", UserWarning)
        warnings.warn("bbbbbbbbbb", UserWarning)
        warnings.warn("aaaaaaaaaa", UserWarning)

    # The failure message names the pattern and lists the messages seen.
    a, b, c = ('aaa', 'bbbbbbbbbb', 'cccccccccc')
    # Raw string: ``\[``/``\]`` are meant for the regex engine; in a plain
    # literal they are invalid Python escapes (Deprecation/SyntaxWarning).
    expected_msg = r"'{}' pattern not found in \['{}', '{}'\]".format(a, b, c)
    with raises(AssertionError, match=expected_msg):
        with warns(UserWarning, match=r'aaa'):
            warnings.warn("bbbbbbbbbb", UserWarning)
            warnings.warn("cccccccccc", UserWarning)
def test_sensitivity_specificity_unused_pos_label():
    """Warn that ``pos_label`` is unused when ``average != 'binary'``.

    The warning is expected even though the data is binary.
    """
    # Raw string so ``\[``/``\]`` reach the regex engine without an
    # invalid-escape warning; pattern passed by keyword like other tests.
    with warns(UserWarning,
               match=r"use labels=\[pos_label\] to specify a single"):
        sensitivity_specificity_support(
            [1, 2, 1], [1, 2, 2], pos_label=2, average='macro')
def test_deprecation_random_state():
    """Fitting TomekLinks built with ``random_state`` warns of deprecation."""
    deprecation_pattern = "'random_state' is deprecated from 0.4"
    tomek = TomekLinks(random_state=0)
    with warns(DeprecationWarning, match=deprecation_pattern):
        tomek.fit_resample(X, Y)
def test_deprecation_random_state():
    """Fitting RepeatedEditedNearestNeighbours with ``random_state`` warns.

    The DeprecationWarning about ``random_state`` is expected during
    resampling.
    """
    renn = RepeatedEditedNearestNeighbours(random_state=0)
    # Use ``fit_resample`` like every sibling deprecation test rather than
    # the older ``fit_sample`` alias.
    with warns(DeprecationWarning,
               match="'random_state' is deprecated from 0.4"):
        renn.fit_resample(X, Y)
def test_deprecation_random_state():
    """Fitting NearMiss built with ``random_state`` warns of deprecation."""
    sampler = NearMiss(random_state=0)
    expected = "'random_state' is deprecated from 0.4"
    with warns(DeprecationWarning, match=expected):
        sampler.fit_resample(X, Y)
def test_deprecation_random_state():
    """Fitting AllKNN built with ``random_state`` warns of deprecation."""
    msg = "'random_state' is deprecated from 0.4"
    sampler = AllKNN(random_state=0)
    with warns(DeprecationWarning, match=msg):
        sampler.fit_resample(X, Y)
def test_deprecation_random_state():
    """AllKNN with ``random_state`` emits a DeprecationWarning on resample."""
    knn_cleaner = AllKNN(random_state=0)
    pattern = "'random_state' is deprecated from 0.4"
    with warns(DeprecationWarning, match=pattern):
        knn_cleaner.fit_resample(X, Y)
def test_deprecation_random_state():
    """NeighbourhoodCleaningRule with ``random_state`` warns on resample."""
    cleaning_rule = NeighbourhoodCleaningRule(random_state=0)
    expected_warning_msg = "'random_state' is deprecated from 0.4"
    with warns(DeprecationWarning, match=expected_warning_msg):
        cleaning_rule.fit_resample(X, Y)
def test_deprecation_random_state():
    """RepeatedEditedNearestNeighbours with ``random_state`` warns on fit."""
    pattern = "'random_state' is deprecated from 0.4"
    sampler = RepeatedEditedNearestNeighbours(random_state=0)
    with warns(DeprecationWarning, match=pattern):
        sampler.fit_resample(X, Y)
def test_deprecation_random_state():
    """TomekLinks with ``random_state`` emits a DeprecationWarning on fit."""
    links = TomekLinks(random_state=0)
    deprecation_msg = "'random_state' is deprecated from 0.4"
    with warns(DeprecationWarning, match=deprecation_msg):
        links.fit_resample(X, Y)
def test_sensitivity_specificity_unused_pos_label():
    """Warn that ``pos_label`` is unused when ``average != 'binary'``.

    The warning is expected even though the data is binary.
    """
    # Raw string so ``\[``/``\]`` reach the regex engine without an
    # invalid-escape warning; pattern passed by keyword like other tests.
    expected_pattern = r"use labels=\[pos_label\] to specify a single"
    with warns(UserWarning, match=expected_pattern):
        sensitivity_specificity_support([1, 2, 1], [1, 2, 2], pos_label=2,
                                        average='macro')
def test_deprecate_parameter():
    """Check the DeprecationWarning messages emitted by deprecate_parameter.

    Without a replacement name the message says the parameter "is
    deprecated from"; with replacement ``'b'`` it suggests using it.
    """
    cases = (
        ("is deprecated from", ()),
        ("Use 'b' instead.", ('b',)),
    )
    for pattern, extra_args in cases:
        with warns(DeprecationWarning, match=pattern):
            deprecate_parameter(Sampler(), '0.2', 'a', *extra_args)