def _iter_indices(self, X, y, groups=None):  # type: ignore
    """Yield ``(train, test)`` index arrays for stratified shuffle splits.

    The labels in ``y`` are grouped into classes, a per-class train and test
    quota is requested from ``_approximate_mode``, and for each of the
    ``self.n_splits`` iterations the members of every class are shuffled and
    dealt out to the two sets. Any class whose train quota comes back as
    zero is forced to contribute one sample to train, taken from its test
    quota.

    Parameters
    ----------
    X : array-like
        Only used to determine the number of samples.
    y : array-like
        Labels to stratify on. A 2-D ``y`` is treated as multi-label: each
        distinct row becomes one class.
    groups : object, default=None
        Not used by this method.

    Yields
    ------
    train : ndarray
        Shuffled training indices for the split.
    test : ndarray
        Shuffled test indices for the split.

    Raises
    ------
    ValueError
        If the resolved train or test size is smaller than the number of
        classes.
    """
    n_samples = _num_samples(X)
    labels = check_array(y, ensure_2d=False, dtype=None)
    n_train, n_test = _validate_shuffle_split(
        n_samples,
        self.test_size,
        self.train_size,
        default_test_size=self._default_test_size,
    )

    if labels.ndim == 2:
        # Multi-label y: collapse every row into a single string key.
        # ' '.join is used rather than str(row) because str(row) inserts an
        # ellipsis once a row has more than 1000 entries.
        labels = np.array([' '.join(row.astype('str')) for row in labels])

    classes, y_indices = np.unique(labels, return_inverse=True)
    n_classes = classes.shape[0]
    class_counts = np.bincount(y_indices)

    if n_train < n_classes:
        raise ValueError('The train_size = %d should be greater or '
                         'equal to the number of classes = %d' %
                         (n_train, n_classes))
    if n_test < n_classes:
        raise ValueError('The test_size = %d should be greater or '
                         'equal to the number of classes = %d' %
                         (n_test, n_classes))

    # Sample positions grouped by class. A stable sort of the class codes
    # followed by a split at the cumulative counts gives one index array
    # per class (np.unique already sorted, so this stays O(n log n)).
    sorted_positions = np.argsort(y_indices, kind='mergesort')
    class_boundaries = np.cumsum(class_counts)[:-1]
    class_indices = np.split(sorted_positions, class_boundaries)

    rng = check_random_state(self.random_state)

    for _ in range(self.n_splits):
        # Quotas are recomputed on every iteration so that ties between
        # class counts are broken anew each time.
        train_quota = _approximate_mode(class_counts, n_train, rng)
        leftover_counts = class_counts - train_quota
        test_quota = _approximate_mode(leftover_counts, n_test, rng)

        train = []
        test = []
        per_class = zip(class_indices, class_counts, train_quota, test_quota)
        for members, count, take_train, take_test in per_class:
            shuffled = members.take(rng.permutation(count), mode='clip')

            if take_train == 0:
                # Guarantee every class is represented in the training set
                # by moving one sample over from the test quota. If the test
                # quota was already 0 it becomes -1, which makes the test
                # slice below empty.
                take_train = 1
                take_test = take_test - 1

            train.extend(shuffled[:take_train])
            test.extend(shuffled[take_train:take_train + take_test])

        train = rng.permutation(train)
        test = rng.permutation(test)

        yield train, test
def test_approximate_mode():
    """Check ``_approximate_mode`` when the intermediate product is large.

    The inputs are chosen so that "class_counts * n_draws" is big enough to
    overflow a 32-bit signed integer; the result must still be valid.

    Non-regression test for:
    https://github.com/scikit-learn/scikit-learn/issues/20774
    """
    population = np.array([99000, 1000], dtype=np.int32)
    drawn = _approximate_mode(class_counts=population, n_draws=25000, rng=0)

    # 25000 draws is 25% of the 100000-strong population, so a fair draw
    # takes 25% of each class:
    #   25% * 99000 = 24750
    #   25% *  1000 =   250
    expected = [24750, 250]
    assert_array_equal(drawn, expected)