def test_minibatch_idx_iterator():
    """Check minibatch_idx_iterator on labels whose classes are already balanced.

    Exercises ``minibatch_idx_iterator(Y, minibatch_size, randomise, balanced)``
    with ``minibatch_size == len(Y)`` and ``balanced=False``, both without and
    with randomisation. In that configuration exactly one minibatch is
    expected, and it must contain every index of ``Y`` exactly once.
    """
    # simplest case - classes are already balanced and minibatch_size == len(Y)
    Y = [0, 0, 1, 1, 2, 2]
    expected = list(range(len(Y)))

    for randomise in (False, True):
        ans = list(mbg.minibatch_idx_iterator(
            Y, 6, randomise=randomise, balanced=False))

        # ans is a list of minibatches; with minibatch_size == len(Y) there
        # should be only one, holding everything.
        assert len(ans) == 1
        # Compare sorted contents rather than a set: a set comparison would
        # silently accept duplicated indices inside the minibatch.
        assert sorted(ans[0]) == expected


def test_imbalanced_minibatch_idx_iterator():
    """Check minibatch_idx_iterator on labels with imbalanced class sizes.

    Without balancing, a single minibatch of size ``len(Y)`` must contain every
    index exactly once. With ``balanced=True``, every minibatch must have class
    counts that differ by at most one, and every index of ``Y`` must appear at
    least once across all minibatches.
    """
    # imbalanced classes: class k appears k + 1 times
    Y = [0, 1, 1, 2, 2, 2, 3, 3, 3, 3]
    n_classes = max(Y) + 1
    expected = list(range(len(Y)))

    ans = list(mbg.minibatch_idx_iterator(
        Y, 10, randomise=False, balanced=False))
    assert len(ans) == 1
    # sorted comparison (not a set) so duplicated indices are caught
    assert sorted(ans[0]) == expected

    ans = list(mbg.minibatch_idx_iterator(
        Y, 10, randomise=False, balanced=True))

    # simple minibatch check (check_minibatches is defined elsewhere in this
    # file; presumably it validates minibatch sizes against 10 - confirm)
    check_minibatches(ans, 10)

    # classes should be approximately balanced. minlength forces a count for
    # every class, so a minibatch missing one of the highest-numbered classes
    # is not hidden by bincount truncating trailing zeros.
    for an in ans:
        class_counts = np.bincount(
            [Y[xx] for xx in an], minlength=n_classes)
        assert class_counts.max() - class_counts.min() <= 1

    # all idxs should be returned at least once (repeats are allowed)
    all_idxs = [yy for xx in ans for yy in xx]
    assert set(all_idxs) == set(expected)