예제 #1
0
def test_check_sampling_strategy_error_wrong_string(sampling_strategy,
                                                    sampling_type, err_msg):
    """Check the ValueError text names the strategy and the given ``err_msg``."""
    expected_pattern = "'{}' cannot be used with {}".format(
        sampling_strategy, err_msg)
    target = np.array([1, 2, 3])
    with pytest.raises(ValueError, match=expected_pattern):
        check_sampling_strategy(sampling_strategy, target, sampling_type)
예제 #2
0
def test_check_sampling_strategy_error_wrong_string(sampling_strategy,
                                                    sampling_type, err_msg):
    """Expect a ValueError whose message combines the strategy and ``err_msg``."""
    message_regex = "'{}' cannot be used with {}".format(
        sampling_strategy, err_msg)
    labels = np.array([1, 2, 3])
    expect_failure = pytest.raises(ValueError, match=message_regex)
    with expect_failure:
        check_sampling_strategy(sampling_strategy, labels, sampling_type)
예제 #3
0
def test_sampling_strategy_dict_over_sampling():
    """Expect a UserWarning when the dict asks for 140 samples of class 2."""
    target = np.array([1] * 50 + [2] * 100 + [3] * 25)
    requested_counts = {1: 70, 2: 140, 3: 70}
    # Raw strings: the parentheses are escaped because the text is a regex.
    warning_regex = (
        r"After over-sampling, the number of samples \(140\) in class 2 will"
        r" be larger than the number of samples in the majority class"
        r" \(class #2 -> 100\)"
    )
    with warns(UserWarning, warning_regex):
        check_sampling_strategy(requested_counts, target, 'over-sampling')
예제 #4
0
def test_sampling_strategy_dict_over_sampling():
    """Expect a UserWarning when the dict asks for 140 samples of class 2.

    The expected message is a regular expression, so its literal parentheses
    are escaped. Raw string literals are used so that ``\\(`` and ``\\)`` are
    not treated as (invalid) Python escape sequences, which would emit a
    ``DeprecationWarning``/``SyntaxWarning`` at compile time.
    """
    y = np.array([1] * 50 + [2] * 100 + [3] * 25)
    sampling_strategy = {1: 70, 2: 140, 3: 70}
    expected_msg = (r"After over-sampling, the number of samples \(140\) in"
                    r" class 2 will be larger than the number of samples in the"
                    r" majority class \(class #2 -> 100\)")
    with warns(UserWarning, expected_msg):
        check_sampling_strategy(sampling_strategy, y, 'over-sampling')
예제 #5
0
def test_check_sampling_strategy_warning():
    """Passing a dict with 'clean-sampling' must emit a DeprecationWarning."""
    cleaning_dict = {1: 0, 2: 0, 3: 0}
    with pytest.warns(DeprecationWarning,
                      match='dict for cleaning methods is deprecated'):
        check_sampling_strategy(
            cleaning_dict, multiclass_target, 'clean-sampling')
예제 #6
0
def test_check_sampling_strategy_warning():
    """A dict strategy used with 'clean-sampling' should warn as deprecated."""
    deprecation_regex = 'dict for cleaning methods is deprecated'
    per_class_counts = dict.fromkeys((1, 2, 3), 0)
    expect_warning = pytest.warns(DeprecationWarning, match=deprecation_regex)
    with expect_warning:
        check_sampling_strategy(
            per_class_counts, multiclass_target, 'clean-sampling')
예제 #7
0
def test_check_sampling_strategy_warning():
    """A dict strategy used with "clean-sampling" must raise a ValueError."""
    cleaning_dict = {1: 0, 2: 0, 3: 0}
    with pytest.raises(
        ValueError, match="dict for cleaning methods is not supported"
    ):
        check_sampling_strategy(
            cleaning_dict, multiclass_target, "clean-sampling"
        )
예제 #8
0
def test_check_sampling_strategy(
    sampling_strategy, sampling_type, expected_sampling_strategy, target
):
    """Compare the validated strategy with the expected one for each case."""
    # Single-expression form: the helper result is compared directly.
    assert (
        check_sampling_strategy(sampling_strategy, target, sampling_type)
        == expected_sampling_strategy
    )
예제 #9
0
def test_sampling_strategy_check_order(sampling_strategy, sampling_type,
                                       expected_result):
    """Check the validated strategy for an unsorted input (issue #428).

    The dictionary handed in is deliberately not sorted; the value returned
    by ``check_sampling_strategy`` is compared with ``expected_result``.
    """
    class_counts = [50, 100, 25]
    target = np.array(
        [1] * class_counts[0] + [2] * class_counts[1] + [3] * class_counts[2])
    validated = check_sampling_strategy(
        sampling_strategy, target, sampling_type)
    assert validated == expected_result
def test_sampling_strategy_check_order(
    sampling_strategy, sampling_type, expected_result
):
    """Validate an intentionally unsorted dictionary (see issue #428).

    The output of ``check_sampling_strategy`` must equal ``expected_result``.
    """
    target = np.array([1] * 50 + [2] * 100 + [3] * 25)
    assert (
        check_sampling_strategy(sampling_strategy, target, sampling_type)
        == expected_result
    )
예제 #11
0
 def _check_sampling_strategy(self, y):
     """Validate the wrapped over-sampler's strategy against ``y``.

     Stores the result of ``check_sampling_strategy`` in
     ``self.sampling_strategy_`` and returns ``self``.
     """
     requested = self.oversampler_.sampling_strategy
     validated = check_sampling_strategy(requested, y, self._sampling_type)
     self.sampling_strategy_ = validated
     return self
예제 #12
0
    def _fit(self, X, y=None):
        """Build the dataset and GAN components and train the generator.

        Validates the sampling strategy, splits ``X`` into continuous and
        categorical column blocks, wraps them in a ``TabularDataset`` and
        runs training epochs until the generator reports at least
        ``self.n_iter`` training iterations.

        Parameters
        ----------
        X : array-like supporting ``X[:, idx]`` column indexing
            Input samples.
        y : array-like, optional
            Target labels, passed to ``check_sampling_strategy`` and to the
            dataset.

        Returns
        -------
        self
            The instance, with ``sampling_strategy_``, ``no_aux`` and
            ``generator`` set.
        """
        self.sampling_strategy_ = check_sampling_strategy(
            self.sampling_strategy, y, self._sampling_type)

        # A column block stays ``None`` when no indices are configured for it.
        X_cont = X[:, self.idx_cont] if self.idx_cont is not None else None
        X_cat = X[:, self.idx_cat] if self.idx_cat is not None else None

        # Allow training only on single class
        group_filter = None
        if isinstance(self.auxiliary, list):
            group_filter = self.auxiliary

        dataset = TabularDataset(X=X_cont,
                                 X_cat=X_cat,
                                 y=y,
                                 cat_levels=self.cat_levels,
                                 group_filter=group_filter)
        self.no_aux = dataset.no_aux
        if self.auxiliary is False:
            self.no_aux = 0

        generator, critic, trainer = make_GANbalancer(
            dataset=dataset,
            generator_input=self.generator_input,
            generator_layers=self.generator_layers,
            critic_layers=self.critic_layers,
            emb_sizes=self.emb_sizes,
            no_aux=self.no_aux,
            learning_rate=self.learning_rate,
            critic_iterations=self.critic_iterations)

        train_loader = DataLoader(dataset,
                                  batch_size=self.batch_size,
                                  shuffle=True)

        # Train for generator update iterations instead of epochs, because this is
        # clearer to specify w.r.t to batch size
        pbar = tqdm(total=self.n_iter) if self.verbose > 0 else None
        try:
            while generator.training_iterations < self.n_iter:
                iterations_before = generator.training_iterations
                trainer._train_epoch(train_loader)

                if pbar is not None:
                    pbar.update(
                        generator.training_iterations - iterations_before)
        finally:
            # Close the progress bar even when a training epoch raises, so
            # the bar is not left open on the terminal.
            if pbar is not None:
                pbar.close()

        self.generator = generator
        return self
예제 #13
0
def test_sampling_strategy_callable_args():
    """Check that extra keyword arguments reach a callable strategy."""
    target = np.array([1] * 50 + [2] * 100 + [3] * 25)
    factors = {1: 1.5, 2: 1, 3: 3}

    def sampling_strategy_func(y, multiplier):
        """Scale each class count by its entry in ``multiplier``."""
        scaled = {}
        for label, count in Counter(y).items():
            scaled[label] = int(count * multiplier[label])
        return scaled

    validated = check_sampling_strategy(
        sampling_strategy_func, target, 'over-sampling', multiplier=factors)
    assert validated == {1: 25, 2: 0, 3: 50}
def test_sampling_strategy_callable_args():
    """A callable strategy must receive the forwarded ``multiplier`` kwarg."""
    labels = np.array([1] * 50 + [2] * 100 + [3] * 25)
    per_class_factor = {1: 1.5, 2: 1, 3: 3}

    def sampling_strategy_func(y, multiplier):
        """Return each class count multiplied by its factor, as an int."""
        class_counts = Counter(y)
        result = {}
        for cls, n_in_class in class_counts.items():
            result[cls] = int(n_in_class * multiplier[cls])
        return result

    expected = {1: 25, 2: 0, 3: 50}
    assert (
        check_sampling_strategy(
            sampling_strategy_func,
            labels,
            "over-sampling",
            multiplier=per_class_factor,
        )
        == expected
    )
예제 #15
0
def test_check_sampling_strategy_error():
    """Exercise three invalid-argument paths; each must raise ValueError."""
    three_classes = np.array([1, 2, 3])

    # Unknown sampling type.
    with pytest.raises(ValueError, match="'sampling_type' should be one of"):
        check_sampling_strategy('auto', three_classes, 'rnd')

    # Target holding a single class.
    single_class_regex = "The target 'y' needs to have more than 1 class."
    with pytest.raises(ValueError, match=single_class_regex):
        check_sampling_strategy('auto', np.ones((10, )), 'over-sampling')

    # Unknown string strategy.
    bad_string_regex = (
        "When 'sampling_strategy' is a string, it needs to be one of")
    with pytest.raises(ValueError, match=bad_string_regex):
        check_sampling_strategy('rnd', np.array([1, 2, 3]), 'over-sampling')
예제 #16
0
def test_check_sampling_strategy_error():
    """Each invalid call below must raise ValueError with a matching message."""
    # Run in the listed order: (strategy, target, sampling type, regex).
    failing_calls = [
        ("auto", np.array([1, 2, 3]), "rnd",
         "'sampling_type' should be one of"),
        ("auto", np.ones((10,)), "over-sampling",
         "The target 'y' needs to have more than 1 class."),
        ("rnd", np.array([1, 2, 3]), "over-sampling",
         "When 'sampling_strategy' is a string, it needs to be one of"),
    ]
    for strategy, target, kind, message_regex in failing_calls:
        with pytest.raises(ValueError, match=message_regex):
            check_sampling_strategy(strategy, target, kind)
예제 #17
0
    def _fit_resample(self, X, y):
        """Fit on ``(X, y)`` and append newly sampled rows to the input.

        Returns the original data stacked with the output of
        ``self._sample`` as an ``(X_resampled, y_resampled)`` tuple.
        """
        self.sampling_strategy_ = check_sampling_strategy(
            self.sampling_strategy, y, self._sampling_type)

        # Copy the inputs before fitting, as the original ordering does.
        X_original = X.copy()
        y_original = y.copy()

        self._fit(X, y)
        X_generated, y_generated = self._sample(
            X, y, random_state=self.random_state)

        # Original rows come first, followed by the generated ones.
        return (np.vstack((X_original, X_generated)),
                np.hstack((y_original, y_generated)))
예제 #18
0
def test_sampling_strategy_dict_error():
    """Invalid per-class counts in a dict strategy must raise ValueError."""
    target = np.array([1] * 50 + [2] * 100 + [3] * 25)

    # A negative count is rejected.
    negative_counts = {1: -100, 2: 50, 3: 25}
    with pytest.raises(ValueError, match="in a class cannot be negative."):
        check_sampling_strategy(negative_counts, target, 'under-sampling')

    # The same dict is rejected for both sampling types: class 1 asks for
    # fewer samples than present, class 3 for more.
    mixed_counts = {1: 45, 2: 100, 3: 70}
    over_sampling_regex = (
        "With over-sampling methods, the number of samples in a class should"
        " be greater or equal to the original number of samples. Originally,"
        " there is 50 samples and 45 samples are asked.")
    with pytest.raises(ValueError, match=over_sampling_regex):
        check_sampling_strategy(mixed_counts, target, 'over-sampling')

    under_sampling_regex = (
        "With under-sampling methods, the number of samples in a class should"
        " be less or equal to the original number of samples. Originally,"
        " there is 25 samples and 70 samples are asked.")
    with pytest.raises(ValueError, match=under_sampling_regex):
        check_sampling_strategy(mixed_counts, target, 'under-sampling')
예제 #19
0
def test_sampling_strategy_dict_error():
    """Check the ValueError messages for out-of-range dict strategies."""
    labels = np.array([1] * 50 + [2] * 100 + [3] * 25)

    with pytest.raises(ValueError, match="in a class cannot be negative."):
        check_sampling_strategy(
            {1: -100, 2: 50, 3: 25}, labels, 'under-sampling')

    requested = {1: 45, 2: 100, 3: 70}
    # (sampling type, expected message), checked in this order.
    expectations = (
        ('over-sampling',
         "With over-sampling methods, the number of samples in a class"
         " should be greater or equal to the original number of samples."
         " Originally, there is 50 samples and 45 samples are asked."),
        ('under-sampling',
         "With under-sampling methods, the number of samples in a class"
         " should be less or equal to the original number of samples."
         " Originally, there is 25 samples and 70 samples are asked."),
    )
    for sampling_kind, message_regex in expectations:
        with pytest.raises(ValueError, match=message_regex):
            check_sampling_strategy(requested, labels, sampling_kind)
예제 #20
0
def test_sampling_strategy_float_error_not_binary():
    """A float strategy with a three-class target must raise ValueError."""
    three_class_target = np.array([1] * 50 + [2] * 100 + [3] * 25)
    ratio = 0.5
    with pytest.raises(ValueError, match="the type of target is binary"):
        check_sampling_strategy(ratio, three_class_target, "under-sampling")
예제 #21
0
def test_sampling_strategy_list_error_not_clean_sampling(sampling_method):
    """A list strategy must raise ValueError for the given sampling method."""
    target = np.array([1] * 50 + [2] * 100 + [3] * 25)
    class_list = [1, 2, 3]
    with pytest.raises(ValueError, match="cannot be a list for samplers"):
        check_sampling_strategy(class_list, target, sampling_method)
예제 #22
0
def test_sampling_strategy_class_target_unknown(
    sampling_strategy, sampling_method
):
    """Expect the "not present in the data" ValueError for each case."""
    target = np.array([1] * 50 + [2] * 100 + [3] * 25)
    expect_failure = pytest.raises(
        ValueError, match="are not present in the data."
    )
    with expect_failure:
        check_sampling_strategy(sampling_strategy, target, sampling_method)
예제 #23
0
def test_sampling_strategy_float_error_not_in_range(sampling_strategy):
    """Expect the range-check ValueError for each float on a binary target."""
    binary_labels = np.array([1] * 50 + [2] * 100)
    range_regex = "it should be in the range"
    with pytest.raises(ValueError, match=range_regex):
        check_sampling_strategy(
            sampling_strategy, binary_labels, "under-sampling"
        )
예제 #24
0
def test_check_sampling_strategy_float_error(ratio, y, type, err_msg):
    """Expect ValueError matching ``err_msg`` for each parametrised case."""
    # Local alias: the ``type`` parameter shadows the builtin of that name.
    sampling_type = type
    expect_failure = pytest.raises(ValueError, match=err_msg)
    with expect_failure:
        check_sampling_strategy(ratio, y, sampling_type)
예제 #25
0
def test_sampling_strategy_float_error_not_in_range(sampling_strategy):
    """Each float strategy must trigger the range-check ValueError."""
    target = np.array([1] * 50 + [2] * 100)
    expect_failure = pytest.raises(
        ValueError, match='it should be in the range')
    with expect_failure:
        check_sampling_strategy(sampling_strategy, target, 'under-sampling')
def test_sampling_strategy_dict_over_sampling():
    """Expect a UserWarning when a dict over-sampling strategy is validated."""
    target = np.array([1] * 50 + [2] * 100 + [3] * 25)
    requested_counts = {1: 70, 2: 140, 3: 70}
    # Only the message prefix is matched (note the trailing space).
    with pytest.warns(
        UserWarning, match="After over-sampling, the number of samples "
    ):
        check_sampling_strategy(requested_counts, target, "over-sampling")
예제 #27
0
def test_sampling_strategy_float_error_not_binary():
    """Validating a float strategy on three classes must raise ValueError."""
    labels = np.array([1] * 50 + [2] * 100 + [3] * 25)
    float_strategy = 0.5
    binary_only_regex = 'the type of target is binary'
    with pytest.raises(ValueError, match=binary_only_regex):
        check_sampling_strategy(float_strategy, labels, 'under-sampling')
예제 #28
0
def test_sampling_strategy_class_target_unknown(sampling_strategy,
                                                sampling_method):
    """Each case must raise the "not present in the data" ValueError."""
    labels = np.array([1] * 50 + [2] * 100 + [3] * 25)
    missing_class_regex = "are not present in the data."
    with pytest.raises(ValueError, match=missing_class_regex):
        check_sampling_strategy(sampling_strategy, labels, sampling_method)
예제 #29
0
def test_check_sampling_strategy_float_error():
    """A float strategy with 'clean-sampling' must raise ValueError."""
    float_strategy = 0.5
    with pytest.raises(
            ValueError,
            match="'clean-sampling' methods do let the user specify the"
                  " sampling ratio"):
        check_sampling_strategy(
            float_strategy, binary_target, 'clean-sampling')
예제 #30
0
def test_check_sampling_strategy_float_error():
    """Validating 0.5 for 'clean-sampling' is expected to raise ValueError."""
    # The text reproduces the library message verbatim; it is used as a regex.
    expected_regex = (
        "'clean-sampling' methods do let the user"
        " specify the sampling ratio")
    expect_failure = pytest.raises(ValueError, match=expected_regex)
    with expect_failure:
        check_sampling_strategy(0.5, binary_target, 'clean-sampling')
예제 #31
0
def test_check_sampling_strategy(sampling_strategy, sampling_type,
                                 expected_sampling_strategy, target):
    """The validated strategy must equal the expected one for each case."""
    validated = check_sampling_strategy(
        sampling_strategy, target, sampling_type)
    assert validated == expected_sampling_strategy
예제 #32
0
    def fit_resample(self, X, X_3d, y, y_org):
        """Over-sample ``X`` and a parallel 3-D array by neighbour interpolation.

        For every class that ``sampling_strategy_`` asks to grow, the number
        of new samples per existing sample is weighted by the share of its
        nearest neighbours belonging to another class. New samples are
        placed on the segment between a sample and one randomly chosen
        same-class neighbour. The same sample pair and step are applied to
        ``X`` and ``X_3d``.

        Parameters
        ----------
        X : 2-D array or sparse matrix
            Samples; only the columns in ``self.variables`` are interpolated.
        X_3d : 3-D array
            Per-sample data indexed as ``[:, self.variables_3d, :]``.
        y : 1-D array
            Labels used to select classes and to compute neighbour ratios.
        y_org : 1-D array
            Secondary labels; new samples copy the value of their source row.
            # presumably the original, un-encoded labels -- TODO confirm

        Returns
        -------
        X_resampled, X_3d_resampled, y_org_resampled
            The inputs followed by the generated samples.

        Raises
        ------
        RuntimeError
            If no neighbour of any sample of a class belongs to another class.
        ValueError
            If rounding leaves zero samples to generate for a class.
        """
        self.sampling_strategy_ = check_sampling_strategy(
            self.sampling_strategy, y, 'over-sampling')
        # presumably sets up ``self.nn_`` used below -- verify in
        # ``_validate_estimator``.
        self._validate_estimator()
        random_state = check_random_state(self.random_state)

        # Each list starts with a copy of the input; generated chunks are
        # appended per class and stacked at the end.
        X_resampled = [X.copy()]
        X_3d_resampled = [X_3d.copy()]
        y_resampled = [y.copy()]
        y_org_resampled = [y_org.copy()]

        for class_sample, n_samples in self.sampling_strategy_.items():
            # Nothing to generate for this class.
            if n_samples == 0:
                continue
            target_class_indices = np.flatnonzero(y == class_sample)
            X_class = _safe_indexing(X, target_class_indices)
            X_class_3d = _safe_indexing(X_3d, target_class_indices)
            y_class_org = _safe_indexing(y_org, target_class_indices)

            # self.nn_.set_params(**{"n_neighbors": self.n_neighbors})
            # Neighbours are searched in the whole data set first; the first
            # returned column is dropped (presumably the query point itself).
            self.nn_.fit(X[:, self.variables])
            nns = self.nn_.kneighbors(X_class[:, self.variables],
                                      return_distance=False)[:, 1:]
            # The ratio is computed using a one-vs-rest manner. Using majority
            # in multi-class would lead to slightly different results at the
            # cost of introducing a new parameter.
            n_neighbors = self.nn_.n_neighbors - 1
            # Share of each sample's neighbours that carry a different label.
            ratio_nn = np.sum(y[nns] != class_sample, axis=1) / n_neighbors
            if not np.sum(ratio_nn):
                raise RuntimeError("Not any neigbours belong to the majority"
                                   " class. This case will induce a NaN case"
                                   " with a division by zero. ADASYN is not"
                                   " suited for this specific dataset."
                                   " Use SMOTE instead.")
            # Normalise so the ratios sum to one, then turn them into
            # integer per-sample generation counts.
            ratio_nn /= np.sum(ratio_nn)
            n_samples_generate = np.rint(ratio_nn * n_samples).astype(int)
            # rounding may cause new amount for n_samples
            n_samples = np.sum(n_samples_generate)
            if not n_samples:
                raise ValueError("No samples will be generated with the"
                                 " provided ratio settings.")

            # the nearest neighbors need to be fitted only on the current class
            # to find the class NN to generate new samples
            # self.nn_.set_params(**{"n_neighbors": np.minimum(int(X_class.shape[0]-1), self.n_neighbors)})
            self.nn_.fit(X_class[:, self.variables])
            nns = self.nn_.kneighbors(X_class[:, self.variables],
                                      return_distance=False)[:, 1:]

            # ``rows`` repeats each class-local index as many times as it
            # must generate samples; ``cols`` picks one neighbour per sample.
            enumerated_class_indices = np.arange(len(target_class_indices))
            rows = np.repeat(enumerated_class_indices, n_samples_generate)
            cols = random_state.choice(n_neighbors, size=n_samples)
            # Neighbour minus source sample, for the 2-D and 3-D data.
            diffs = X_class[nns[
                rows, cols]][:, self.variables] - X_class[rows][:,
                                                                self.variables]
            diffs_3d = X_class_3d[nns[
                rows, cols]][:, self.variables_3d, :] - X_class_3d[
                    rows][:, self.variables_3d, :]
            # One uniform random step per generated sample, shared by the
            # 2-D and 3-D updates.
            steps = random_state.uniform(size=(n_samples, 1))
            X_new = X_class[rows]
            X_new_3d = X_class_3d[rows]
            y_new_org = y_class_org[rows]

            if sparse.issparse(X):
                sparse_func = type(X).__name__
                steps = getattr(sparse, sparse_func)(steps)
                X_new[:, self.variables] = X_class[
                    rows][:, self.variables] + steps.multiply(diffs)
                # NOTE(review): this 3-D update multiplies by ``diffs`` while
                # the dense branch below uses ``diffs_3d``; it also indexes a
                # sparse ``steps`` with ``np.newaxis``. Looks like a defect --
                # confirm whether the sparse path is ever exercised.
                X_new_3d[:, self.variables_3d, :] = X_class_3d[
                    rows][:, self.
                          variables_3d, :] + steps[:, :,
                                                   np.newaxis].multiply(diffs)
            else:
                X_new[:, self.variables] = X_class[
                    rows][:, self.variables] + steps * diffs
                X_new_3d[:, self.variables_3d, :] = X_class_3d[
                    rows][:, self.
                          variables_3d, :] + steps[:, :, np.newaxis] * diffs_3d

            X_new = X_new.astype(X.dtype)
            X_new_3d = X_new_3d.astype(X.dtype)
            y_new = np.full(n_samples, fill_value=class_sample, dtype=y.dtype)
            X_resampled.append(X_new)
            X_3d_resampled.append(X_new_3d)
            y_resampled.append(y_new)
            y_org_resampled.append(y_new_org)

        if sparse.issparse(X):
            X_resampled = sparse.vstack(X_resampled, format=X.format)
            X_3d_resampled = sparse.vstack(X_3d_resampled, format=X.format)
        else:
            X_resampled = np.vstack(X_resampled)
            X_3d_resampled = np.vstack(X_3d_resampled)
        y_resampled = np.hstack(y_resampled)
        y_org_resampled = np.hstack(y_org_resampled)

        # NOTE(review): ``y_resampled`` is built but not returned; only
        # ``y_org_resampled`` is handed back -- confirm this is intended.
        return X_resampled, X_3d_resampled, y_org_resampled
예제 #33
0
def test_sampling_strategy_list_error_not_clean_sampling(sampling_method):
    """Expect ValueError when a list strategy is validated for the method."""
    labels = np.array([1] * 50 + [2] * 100 + [3] * 25)
    listed_classes = [1, 2, 3]
    list_error_regex = 'cannot be a list for samplers'
    with pytest.raises(ValueError, match=list_error_regex):
        check_sampling_strategy(listed_classes, labels, sampling_method)
예제 #34
0
def test_check_sampling_strategy_float_error(ratio, y, type, err_msg):
    """Each parametrised case must raise ValueError matching ``err_msg``."""
    # ``type`` is the parametrised sampling type (it shadows the builtin);
    # bind it to a clearer local name before use.
    kind = type
    with pytest.raises(ValueError, match=err_msg):
        check_sampling_strategy(ratio, y, kind)