def test_next_batch_drop_last_true():
    """ With 'shuffle' off and 'drop_last' on, every pair of consecutive
    batches is expected to be the first two and the next two items of the index.
    """
    index = DatasetIndex(5)
    # Identical arguments are used for every request in this test.
    request = dict(batch_size=2, n_epochs=None, drop_last=True, shuffle=False)
    for _ in range(10):
        first = index.next_batch(**request)
        second = index.next_batch(**request)
        assert (first.index == index.index[:2]).all()
        assert (second.index == index.index[2:4]).all()
def test_next_batch_bigger():
    """ Requesting batches larger than the DatasetIndex itself is not reliable:
    an AssertionError is expected to surface within ten requests.
    """
    index = DatasetIndex(5)
    oversized = 7  # deliberately exceeds the 5 items available
    with pytest.raises(AssertionError):
        for _ in range(10):
            batch = index.next_batch(batch_size=oversized, n_epochs=None, drop_last=True)
            assert len(batch) == oversized
def test_next_batch_drop_last_false_1():
    """ With 'drop_last' off, batches of sizes 3, 3 and 4 taken over two epochs
    should together reproduce the index repeated twice.
    """
    index = DatasetIndex(5)
    expected = list(np.concatenate([index.index, index.index]))
    gathered = []
    for size in (3, 3, 4):
        batch = index.next_batch(batch_size=size, n_epochs=2, drop_last=False)
        gathered += list(batch.index)
    assert gathered == expected
def test_next_batch_drop_last_false_2():
    """ With 'drop_last' off, only the final batch of the final epoch is
    expected to be shorter than 'batch_size'.
    """
    index = DatasetIndex(5)
    observed = [
        len(index.next_batch(batch_size=2, n_epochs=3, drop_last=False))
        for _ in range(8)
    ]
    # first seven batches have length of 2, last contains one item
    expected = [2] * 7 + [1]
    assert observed == expected
def test_next_batch_drop_last_true_2():
    """ With 'shuffle' on and 'drop_last' on, the items seen over ten batches
    should cover the whole index, since dropped items vary between epochs.

    NOTE(review): the outcome relies on random shuffling; presumably a rare
    unlucky sequence could leave an item unseen - confirm whether a seed is needed.
    """
    index = DatasetIndex(5)
    expected = set(index.index)
    seen = set()
    for _ in range(10):
        batch = index.next_batch(batch_size=2, n_epochs=None, drop_last=True, shuffle=True)
        seen.update(batch.index)
    assert seen == expected
def test_next_batch_smaller():
    """ A 'batch_size' of 2 on a DatasetIndex of 5 items with 'drop_last' on
    should always produce batches of exactly 2 items.
    """
    index = DatasetIndex(5)
    size = 2
    for _ in range(10):
        assert len(index.next_batch(batch_size=size, n_epochs=None, drop_last=True)) == size
def test_next_batch_stopiter_pass():
    """ With 'n_epochs' set to None, more batches can be requested than the
    DatasetIndex has items without any exception being raised.
    """
    index = DatasetIndex(5)
    requests = 10  # twice the number of items, so the index is exhausted at least once
    for _ in range(requests):
        index.next_batch(1, n_epochs=None)
def test_next_batch_stopiter_raise():
    """ Once the single allowed epoch has been consumed, the next request
    is expected to raise StopIteration.
    """
    index = DatasetIndex(5)
    # One batch of 5 items takes the whole of the only epoch.
    index.next_batch(5, n_epochs=1)
    with pytest.raises(StopIteration):
        index.next_batch(5, n_epochs=1)