def test_sampling_strategy_dict_over_sampling():
    """Expect a ``UserWarning`` when a dict asks for more than the majority.

    The target holds 50/100/25 samples for classes 1/2/3, and the strategy
    requests 140 samples for class 2, the largest class.
    """
    target = np.array([1] * 50 + [2] * 100 + [3] * 25)
    requested_counts = {1: 70, 2: 140, 3: 70}
    # Parentheses are escaped; the message is presumably matched as a regex
    # by ``warns`` (helper defined outside this block).
    warning_pattern = (
        r"After over-sampling, the number of samples \(140\) in"
        r" class 2 will be larger than the number of samples in"
        r" the majority class \(class #2 -> 100\)"
    )
    with warns(UserWarning, warning_pattern):
        check_sampling_strategy(requested_counts, target, "over-sampling")
def test_check_sampling_strategy_warning():
    """A dict strategy with ``"clean-sampling"`` must raise ``ValueError``."""
    dict_strategy = {1: 0, 2: 0, 3: 0}
    expected_pattern = "dict for cleaning methods is not supported"
    with pytest.raises(ValueError, match=expected_pattern):
        # ``multiclass_target`` is a module-level fixture array defined
        # outside this block.
        check_sampling_strategy(
            dict_strategy, multiclass_target, "clean-sampling"
        )
def test_check_sampling_strategy_error_wrong_string(
    sampling_strategy, sampling_type, err_msg
):
    """A string strategy unsupported by the sampler type raises ``ValueError``.

    The arguments are presumably supplied by a parametrize decorator that is
    not part of this block; ``err_msg`` is the tail of the expected message.
    """
    expected_pattern = "'{}' cannot be used with {}".format(
        sampling_strategy, err_msg
    )
    target = np.array([1, 2, 3])
    with pytest.raises(ValueError, match=expected_pattern):
        check_sampling_strategy(sampling_strategy, target, sampling_type)
def test_sampling_strategy_check_order(
    sampling_strategy, sampling_type, expected_result
):
    """Compare the validated strategy against ``expected_result``.

    The arguments are presumably supplied by a parametrize decorator that is
    not part of this block.
    """
    # We pass on purpose a non sorted dictionary and check that the resulting
    # dictionary is sorted. Refer to issue #428.
    # NOTE(review): ``==`` only checks ordering when both sides are ordered
    # mappings such as OrderedDict -- confirm against the parametrized values.
    target = np.array([1] * 50 + [2] * 100 + [3] * 25)
    validated = check_sampling_strategy(
        sampling_strategy, target, sampling_type
    )
    assert validated == expected_result
def test_check_sampling_strategy_error():
    """Check three ``ValueError`` cases raised for invalid arguments.

    In order: an unknown sampling type, a target with a single class, and an
    unknown string strategy.
    """
    three_class_target = np.array([1, 2, 3])

    # Unknown sampling type.
    with pytest.raises(ValueError, match="'sampling_type' should be one of"):
        check_sampling_strategy("auto", three_class_target, "rnd")

    # Target made of a single class.
    single_class_pattern = "The target 'y' needs to have more than 1 class."
    with pytest.raises(ValueError, match=single_class_pattern):
        check_sampling_strategy("auto", np.ones((10, )), "over-sampling")

    # Unknown string strategy.
    unknown_string_pattern = (
        "When 'sampling_strategy' is a string, it needs to be one of"
    )
    with pytest.raises(ValueError, match=unknown_string_pattern):
        check_sampling_strategy("rnd", np.array([1, 2, 3]), "over-sampling")
def test_sampling_strategy_callable_args():
    """Check that extra keyword arguments reach a callable strategy.

    The callable scales each class count (50/100/25) by its multiplier, giving
    75/100/75. The asserted result equals those totals minus the original
    counts.
    """
    y = np.array([1] * 50 + [2] * 100 + [3] * 25)
    multiplier = {1: 1.5, 2: 1, 3: 3}

    def sampling_strategy_func(y, multiplier):
        """Return, per class, its sample count scaled by the multiplier."""
        scaled_counts = {}
        for label, count in Counter(y).items():
            scaled_counts[label] = int(count * multiplier[label])
        return scaled_counts

    validated = check_sampling_strategy(
        sampling_strategy_func, y, "over-sampling", multiplier=multiplier
    )
    assert validated == {1: 25, 2: 0, 3: 50}
def test_sampling_strategy_dict_error():
    """Check ``ValueError`` cases raised for invalid dict strategies.

    The target holds 50/100/25 samples for classes 1/2/3. The cases are a
    negative count, an over-sampling request below the original count, and an
    under-sampling request above the original count.
    """
    target = np.array([1] * 50 + [2] * 100 + [3] * 25)

    # Negative number of samples requested for class 1.
    negative_counts = {1: -100, 2: 50, 3: 25}
    with pytest.raises(ValueError, match="in a class cannot be negative."):
        check_sampling_strategy(negative_counts, target, "under-sampling")

    # The same dict is invalid for both sampler types: class 1 asks for
    # fewer samples than it has (45 < 50), class 3 for more (70 > 25).
    mixed_counts = {1: 45, 2: 100, 3: 70}

    over_sampling_pattern = (
        "With over-sampling methods, the number of samples in a"
        " class should be greater or equal to the original number"
        " of samples. Originally, there is 50 samples and 45"
        " samples are asked."
    )
    with pytest.raises(ValueError, match=over_sampling_pattern):
        check_sampling_strategy(mixed_counts, target, "over-sampling")

    under_sampling_pattern = (
        "With under-sampling methods, the number of samples in a"
        " class should be less or equal to the original number of"
        " samples. Originally, there is 25 samples and 70 samples"
        " are asked."
    )
    with pytest.raises(ValueError, match=under_sampling_pattern):
        check_sampling_strategy(mixed_counts, target, "under-sampling")
def test_check_sampling_strategy(
    sampling_strategy, sampling_type, expected_sampling_strategy, target
):
    """Compare the validated strategy with the expected mapping.

    The arguments are presumably supplied by a parametrize decorator that is
    not part of this block.
    """
    validated = check_sampling_strategy(
        sampling_strategy, target, sampling_type
    )
    assert validated == expected_sampling_strategy
def test_sampling_strategy_list_error_not_clean_sampling(sampling_method):
    """A list strategy with ``sampling_method`` must raise ``ValueError``.

    ``sampling_method`` is presumably supplied by a parametrize decorator
    that is not part of this block.
    """
    target = np.array([1] * 50 + [2] * 100 + [3] * 25)
    list_strategy = [1, 2, 3]
    with pytest.raises(ValueError, match="cannot be a list for samplers"):
        check_sampling_strategy(list_strategy, target, sampling_method)
def test_sampling_strategy_float_error_not_binary():
    """A float strategy on a three-class target must raise ``ValueError``."""
    three_class_target = np.array([1] * 50 + [2] * 100 + [3] * 25)
    float_strategy = 0.5
    with pytest.raises(ValueError, match="the type of target is binary"):
        check_sampling_strategy(
            float_strategy, three_class_target, "under-sampling"
        )
def test_sampling_strategy_float_error_not_in_range(sampling_strategy):
    """An out-of-range float strategy must raise ``ValueError``.

    The target has two classes (50 and 100 samples). ``sampling_strategy`` is
    presumably supplied by a parametrize decorator that is not part of this
    block.
    """
    two_class_target = np.array([1] * 50 + [2] * 100)
    with pytest.raises(ValueError, match="it should be in the range"):
        check_sampling_strategy(
            sampling_strategy, two_class_target, "under-sampling"
        )
def test_sampling_strategy_class_target_unknown(
    sampling_strategy, sampling_method
):
    """A strategy naming classes absent from the target raises ``ValueError``.

    Both arguments are presumably supplied by a parametrize decorator that is
    not part of this block.
    """
    target = np.array([1] * 50 + [2] * 100 + [3] * 25)
    with pytest.raises(ValueError, match="are not present in the data."):
        check_sampling_strategy(sampling_strategy, target, sampling_method)
def test_check_sampling_strategy_float_error(ratio, y, type, err_msg):
    """Expect a ``ValueError`` matching ``err_msg`` for the given inputs.

    The arguments are presumably supplied by a parametrize decorator that is
    not part of this block.
    """
    # ``type`` shadows the builtin, but it is part of the test signature, so
    # the name is kept.
    with pytest.raises(ValueError, match=err_msg):
        check_sampling_strategy(ratio, y, type)