Ejemplo n.º 1
0
def test_next_batch_drop_last_true():
    """ Every epoch yields the same batches in the same order.

    With 'shuffle' disabled, the trailing item that does not fit into a full
    batch is always the same one, so it is dropped identically each epoch.
    """
    dsi = DatasetIndex(5)
    expected_slices = (slice(0, 2), slice(2, 4))
    for _ in range(10):
        # Fetch both batches of the epoch first, then check them.
        batches = [dsi.next_batch(batch_size=2,
                                  n_epochs=None,
                                  drop_last=True,
                                  shuffle=False)
                   for _ in expected_slices]
        for batch, expected in zip(batches, expected_slices):
            assert (batch.index == dsi.index[expected]).all()
Ejemplo n.º 2
0
def test_next_batch_bigger():
    """ A 'batch_size' larger than the DatasetIndex length is not honored reliably.

    Over repeated calls, an AssertionError is expected. It may come from
    'next_batch' itself or from the length check below.
    """
    dsi = DatasetIndex(5)
    requested = 7
    with pytest.raises(AssertionError):
        for _ in range(10):
            batch = dsi.next_batch(batch_size=requested, n_epochs=None, drop_last=True)
            assert len(batch) == requested
Ejemplo n.º 3
0
def test_next_batch_drop_last_false_1():
    """ With 'drop_last' disabled, 'next_batch' wraps around the index.

    Batches of 3, 3 and 4 items cover exactly two epochs over 5 items, so
    their concatenation must equal the index repeated twice.
    """
    dsi = DatasetIndex(5)
    expected = list(np.concatenate([dsi.index, dsi.index]))
    received = [item
                for size in (3, 3, 4)
                for item in dsi.next_batch(batch_size=size, n_epochs=2, drop_last=False).index]
    assert received == expected
Ejemplo n.º 4
0
def test_next_batch_drop_last_false_2():
    """ With 'drop_last' disabled, only the last batch of the last epoch may be short.

    Three epochs over 5 items give 15 items: seven full batches of 2,
    followed by one batch holding the single remaining item.
    """
    dsi = DatasetIndex(5)
    expected = [2] * 7 + [1]
    received = [len(dsi.next_batch(batch_size=2, n_epochs=3, drop_last=False))
                for _ in range(8)]
    assert received == expected
Ejemplo n.º 5
0
def test_next_batch_drop_last_true_2():
    """ With 'shuffle' enabled, the dropped item may differ from epoch to epoch.

    Each epoch over 5 items yields two full batches of 2 and drops one item.
    Because shuffling changes which item is dropped, every item of the index
    should eventually appear in some batch.
    """
    dsi = DatasetIndex(5)
    seen = set()
    # 10 calls covered only 5 epochs (2 batches per epoch). Assuming a uniform
    # shuffle, the same item is dropped in all of them with probability
    # 5 * (1/5)**5 = 1/625, which made the test flaky. 50 calls (25 epochs)
    # push that probability below 1e-16.
    for _ in range(50):
        batch = dsi.next_batch(batch_size=2,
                               n_epochs=None,
                               drop_last=True,
                               shuffle=True)
        seen |= set(batch.index)
    assert seen == set(dsi.index)
Ejemplo n.º 6
0
def test_next_batch_smaller():
    """ A 'batch_size' smaller than the DatasetIndex length yields full batches.

    With 'drop_last' enabled, every returned batch holds exactly 'batch_size' items.
    """
    dsi = DatasetIndex(5)
    size = 2
    for _ in range(10):
        assert len(dsi.next_batch(batch_size=size, n_epochs=None, drop_last=True)) == size
Ejemplo n.º 7
0
def test_next_batch_stopiter_pass():
    """ With 'n_epochs' set to None, 'next_batch' can be called past the end of an epoch. """
    dsi = DatasetIndex(5)
    calls = 10  # ten single-item batches = two full passes over the 5-item index
    while calls:
        dsi.next_batch(1, n_epochs=None)
        calls -= 1
Ejemplo n.º 8
0
def test_next_batch_stopiter_raise():
    """ Once the single allowed epoch is exhausted, 'next_batch' raises StopIteration. """
    dsi = DatasetIndex(5)
    full_size = 5
    # The first call consumes the whole index, i.e. the only permitted epoch.
    dsi.next_batch(full_size, n_epochs=1)
    with pytest.raises(StopIteration):
        dsi.next_batch(full_size, n_epochs=1)