def test_sampling_strategy_dict_over_sampling():
    """Requesting more samples than the majority class has triggers a warning.

    Class 2 holds 100 samples (the largest count in the target). Asking for
    140 samples of it when over-sampling must emit a ``UserWarning``.
    """
    target = np.array([1] * 50 + [2] * 100 + [3] * 25)
    requested = {1: 70, 2: 140, 3: 70}
    warning_pattern = (
        r"After over-sampling, the number of samples \(140\) in"
        r" class 2 will be larger than the number of samples in"
        r" the majority class \(class #2 -> 100\)"
    )
    with warns(UserWarning, match=warning_pattern):
        check_sampling_strategy(requested, target, "over-sampling")
def test_check_sampling_strategy_warning():
    """A dict ``sampling_strategy`` is rejected for cleaning samplers.

    Despite its name, this test checks that a ``ValueError`` is raised (not a
    warning) when a dict is combined with ``"clean-sampling"``.
    """
    # NOTE(review): ``multiclass_target`` is a module-level fixture defined
    # outside this chunk; presumably it contains classes 1, 2 and 3.
    per_class = {1: 0, 2: 0, 3: 0}
    with pytest.raises(ValueError,
                       match="dict for cleaning methods is not supported"):
        check_sampling_strategy(per_class, multiclass_target,
                                "clean-sampling")
def test_check_sampling_strategy_error_wrong_string(sampling_strategy,
                                                    sampling_type, err_msg):
    """A string strategy that is unsupported for the sampler type raises.

    NOTE(review): the arguments are presumably supplied by a
    ``pytest.mark.parametrize`` decorator that is not visible in this chunk.
    """
    expected = "'{}' cannot be used with {}".format(sampling_strategy,
                                                    err_msg)
    target = np.array([1, 2, 3])
    with pytest.raises(ValueError, match=expected):
        check_sampling_strategy(sampling_strategy, target, sampling_type)
def test_sampling_strategy_check_order(sampling_strategy, sampling_type,
                                       expected_result):
    """The returned strategy mapping has its keys in sorted order.

    The parametrized ``sampling_strategy`` is deliberately passed as a
    non-sorted dictionary; the result must come back sorted. Refer to
    issue #428.
    """
    y = np.array([1] * 50 + [2] * 100 + [3] * 25)
    sampling_strategy_ = check_sampling_strategy(sampling_strategy, y,
                                                 sampling_type)
    assert sampling_strategy_ == expected_result
    # Dict ``==`` ignores key order unless both sides are OrderedDicts, so
    # check the ordering (the actual point of this test) explicitly.
    assert list(sampling_strategy_) == sorted(sampling_strategy_)
def test_check_sampling_strategy_error():
    """Invalid sampler type, single-class target and unknown string raise.

    Each case is a pair of (positional arguments, expected error regex).
    """
    cases = [
        (("auto", np.array([1, 2, 3]), "rnd"),
         "'sampling_type' should be one of"),
        (("auto", np.ones((10, )), "over-sampling"),
         "The target 'y' needs to have more than 1 class."),
        (("rnd", np.array([1, 2, 3]), "over-sampling"),
         "When 'sampling_strategy' is a string, it needs to be one of"),
    ]
    for call_args, pattern in cases:
        with pytest.raises(ValueError, match=pattern):
            check_sampling_strategy(*call_args)
def test_sampling_strategy_callable_args():
    """Extra keyword arguments are forwarded to a callable strategy."""
    y = np.array([1] * 50 + [2] * 100 + [3] * 25)
    multiplier = {1: 1.5, 2: 1, 3: 3}

    def sampling_strategy_func(y, multiplier):
        """Scale every class count in ``y`` by its factor in ``multiplier``."""
        scaled = {}
        for label, count in Counter(y).items():
            scaled[label] = int(count * multiplier[label])
        return scaled

    # The callable requests {1: 75, 2: 100, 3: 75}. The expected result below
    # equals those requests minus the original class counts (50, 100, 25).
    result = check_sampling_strategy(sampling_strategy_func, y,
                                     "over-sampling", multiplier=multiplier)
    assert result == {1: 25, 2: 0, 3: 50}
def test_sampling_strategy_dict_error():
    """Invalid per-class counts in a dict strategy raise ``ValueError``."""
    y = np.array([1] * 50 + [2] * 100 + [3] * 25)

    # A negative requested count is rejected outright.
    negative = {1: -100, 2: 50, 3: 25}
    with pytest.raises(ValueError, match="in a class cannot be negative."):
        check_sampling_strategy(negative, y, "under-sampling")

    # Class 1 asks for 45 (< 50 available), which is invalid when
    # over-sampling. Class 3 asks for 70 (> 25 available), which is invalid
    # when under-sampling.
    mixed = {1: 45, 2: 100, 3: 70}

    over_pattern = ("With over-sampling methods, the number of samples in a"
                    " class should be greater or equal to the original number"
                    " of samples. Originally, there is 50 samples and 45"
                    " samples are asked.")
    with pytest.raises(ValueError, match=over_pattern):
        check_sampling_strategy(mixed, y, "over-sampling")

    under_pattern = ("With under-sampling methods, the number of samples in a"
                     " class should be less or equal to the original number of"
                     " samples. Originally, there is 25 samples and 70 samples"
                     " are asked.")
    with pytest.raises(ValueError, match=under_pattern):
        check_sampling_strategy(mixed, y, "under-sampling")
def test_check_sampling_strategy(sampling_strategy, sampling_type,
                                 expected_sampling_strategy, target):
    """``check_sampling_strategy`` returns the expected per-class mapping.

    NOTE(review): the arguments are presumably parametrized outside this
    chunk.
    """
    assert (check_sampling_strategy(sampling_strategy, target, sampling_type)
            == expected_sampling_strategy)
def test_sampling_strategy_list_error_not_clean_sampling(sampling_method):
    """A list ``sampling_strategy`` is rejected for the given sampler type.

    NOTE(review): ``sampling_method`` is presumably parametrized (outside
    this chunk) over sampler types for which a list is not allowed.
    """
    y = np.array([1] * 50 + [2] * 100 + [3] * 25)
    sampling_strategy = [1, 2, 3]
    # Only the call under test sits inside ``pytest.raises``. That way a
    # failure in the setup cannot be mistaken for the expected error.
    with pytest.raises(ValueError, match="cannot be a list for samplers"):
        check_sampling_strategy(sampling_strategy, y, sampling_method)
def test_sampling_strategy_float_error_not_binary():
    """A float ``sampling_strategy`` is rejected for a multiclass target.

    The target below has three classes, so a float ratio must raise a
    ``ValueError`` mentioning that the target needs to be binary.
    """
    y = np.array([1] * 50 + [2] * 100 + [3] * 25)
    sampling_strategy = 0.5
    # Only the call under test sits inside ``pytest.raises``. That way a
    # failure in the setup cannot be mistaken for the expected error.
    with pytest.raises(ValueError, match="the type of target is binary"):
        check_sampling_strategy(sampling_strategy, y, "under-sampling")
def test_sampling_strategy_float_error_not_in_range(sampling_strategy):
    """A float strategy outside the accepted range raises ``ValueError``.

    NOTE(review): the out-of-range values are presumably parametrized
    outside this chunk.
    """
    binary_target = np.array([1] * 50 + [2] * 100)
    with pytest.raises(ValueError, match="it should be in the range"):
        check_sampling_strategy(sampling_strategy, binary_target,
                                "under-sampling")
def test_sampling_strategy_class_target_unknown(sampling_strategy,
                                                sampling_method):
    """Strategies naming classes absent from the target raise ``ValueError``.

    The target holds only classes 1, 2 and 3. NOTE(review): the strategies
    are presumably parametrized outside this chunk and reference some other
    label.
    """
    labels = np.array([1] * 50 + [2] * 100 + [3] * 25)
    with pytest.raises(ValueError, match="are not present in the data."):
        check_sampling_strategy(sampling_strategy, labels, sampling_method)
def test_check_sampling_strategy_float_error(ratio, y, type, err_msg):
    """Float strategies that are invalid for the given target raise."""
    # ``type`` shadows the builtin, but the parameter name is presumably bound
    # by a ``pytest.mark.parametrize`` id string outside this chunk, so it is
    # kept and aliased to a clearer local name here.
    sampling_strategy, sampling_type = ratio, type
    with pytest.raises(ValueError, match=err_msg):
        check_sampling_strategy(sampling_strategy, y, sampling_type)