def test_minibatch_idx_iterator():
    '''
    Check the simplest use of
    minibatch_idx_iterator(Y, minibatch_size, randomise, balanced):
    the classes are already balanced and minibatch_size == len(Y), so a
    single minibatch holding every index is expected, with or without
    randomisation.
    '''
    Y = [0, 0, 1, 1, 2, 2]
    expected_idxs = {0, 1, 2, 3, 4, 5}

    # Unrandomised first, then randomised; ordering inside the minibatch
    # is irrelevant, so only set membership is compared.
    for shuffle in (False, True):
        minibatches = list(mbg.minibatch_idx_iterator(
            Y, 6, randomise=shuffle, balanced=False))

        # A list of lists is produced; the one and only entry should
        # contain everything.
        assert len(minibatches) == 1
        assert set(minibatches[0]) == expected_idxs
def test_imbalanced_minibatch_idx_iterator():
    '''
    Check minibatch_idx_iterator on imbalanced class labels, both with
    balanced=False (one minibatch holding every index) and with
    balanced=True (approximately equal class counts per minibatch, every
    index returned at least once).
    '''
    Y = [0, 1, 1, 2, 2, 2, 3, 3, 3, 3]
    expected_idxs = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}

    # Without balancing: a single minibatch covering all of Y.
    unbalanced = list(mbg.minibatch_idx_iterator(
        Y, 10, randomise=False, balanced=False))
    assert len(unbalanced) == 1
    assert set(unbalanced[0]) == expected_idxs

    # With balancing.
    balanced = list(mbg.minibatch_idx_iterator(
        Y, 10, randomise=False, balanced=True))

    # Simple minibatch check via the shared helper.
    check_minibatches(balanced, 10)

    # Class counts within each minibatch may differ by at most one.
    # NOTE(review): np.bincount only counts up to the largest label
    # present, so a minibatch lacking the highest class entirely would
    # not show a zero here -- consider passing minlength; confirm against
    # the iterator's contract before tightening.
    for minibatch in balanced:
        labels = [Y[idx] for idx in minibatch]
        class_counts = np.bincount(labels)
        assert class_counts.max() - class_counts.min() <= 1

    # Every index must be returned at least once across the minibatches.
    seen_idxs = set()
    for minibatch in balanced:
        seen_idxs.update(minibatch)
    assert seen_idxs == expected_idxs