def test_balanced_idxs():
    """Balanced labels: expect exactly one entry per index position.

    Every class here has the same number of samples, so the iterator is
    expected to hand back each position of the label list once and
    nothing more.
    """
    labels = [0, 0, 1, 1, 2, 2]
    yielded = list(mbg.balanced_idxs_iterator(labels))

    # As many indices come back as there are labels ...
    assert len(labels) == len(yielded)
    # ... and together they cover every position.
    assert set(yielded) == {0, 1, 2, 3, 4, 5}
def test_imbalanced_idxs():
    """Imbalanced labels: expect full coverage with equal class counts.

    The label list holds classes of sizes 1, 2, 2 and 5. The indices
    produced by the iterator are checked for three properties: every
    position shows up, every class is drawn equally often, and the total
    count equals (largest class size) x (number of classes).
    """
    labels = [0, 1, 1, 2, 2, 3, 3, 3, 3, 3]
    yielded = list(mbg.balanced_idxs_iterator(labels))

    # Each index must appear at least once.
    assert set(yielded) == set(range(len(labels)))

    # Map the yielded indices back to their classes; all classes must
    # have been seen the same number of times.
    per_class_hits = np.bincount([labels[i] for i in yielded])
    assert per_class_hits.min() == per_class_hits.max()

    # Overall length is the largest class size times the class count.
    largest_class_size = np.bincount(labels).max()
    n_classes = np.unique(labels).shape[0]
    assert len(yielded) == largest_class_size * n_classes