def test_balanced_idxs():
    """Balanced labels: every index position should be returned exactly once.

    Each class here appears the same number of times, so the test expects
    the iterator's output to contain each input position once and only once.
    """
    labels = [0, 0, 1, 1, 2, 2]
    idxs = list(mbg.balanced_idxs_iterator(labels))

    # Same length and same set of distinct indices together imply that each
    # position appears exactly once (no duplicates, nothing dropped).
    assert len(idxs) == len(labels)
    assert set(idxs) == set(range(len(labels)))


def test_imbalanced_idxs():
    """Imbalanced labels: indices should be returned so that classes are equal.

    The class sizes here are 1, 2, 2 and 5 (labels 0-3). The test expects
    three things from the iterator's output:
    - every original index appears at least once;
    - each class is emitted equally often;
    - the total is (size of largest class) * (number of classes).
    """
    labels = [0, 1, 1, 2, 2, 3, 3, 3, 3, 3]
    idxs = list(mbg.balanced_idxs_iterator(labels))

    # Coverage: no original position may be missing from the output.
    assert set(idxs) == set(range(len(labels)))

    # Balance: tally how many times each class label was emitted.
    # NOTE: np.bincount expects non-negative integer labels; any label value
    # between 0 and the max that never appears would count as 0 here.
    per_class = np.bincount([labels[i] for i in idxs])
    assert per_class.min() == per_class.max()

    # Size: every class is brought up to the size of the largest one.
    largest = np.bincount(labels).max()
    n_classes = np.unique(labels).shape[0]
    assert len(idxs) == largest * n_classes