Esempio n. 1
0
def test_next_batch_drop_last_false_1():
    """ With 'drop_last' disabled, consecutive batches must walk through the index repeatedly.

    Batches of sizes 3, 3 and 4 are requested with ``n_epochs=2``; their indices,
    joined in the order received, are compared with the index repeated twice.
    """
    dataset_index = DatasetIndex(5)
    expected = list(np.concatenate((dataset_index.index, dataset_index.index)))
    collected = []
    for batch_size in (3, 3, 4):
        current = dataset_index.next_batch(batch_size=batch_size,
                                           n_epochs=2,
                                           drop_last=False)
        collected += list(current.index)
    assert collected == expected
Esempio n. 2
0
def test_next_batch_drop_last_false_2():
    """ With 'drop_last' disabled, the final batch of the final epoch may be shorter. """
    dataset_index = DatasetIndex(5)
    # Seven batches of two items each, followed by one single-item batch.
    expected_sizes = [2, 2, 2, 2, 2, 2, 2, 1]
    observed_sizes = [
        len(dataset_index.next_batch(batch_size=2, n_epochs=3, drop_last=False))
        for _ in range(8)
    ]
    assert observed_sizes == expected_sizes
Esempio n. 3
0
def test_split_correctness():
    """ Splitting must not lose any element of 'index'.

    The values in 'shares' are picked so that the test does not raise errors.
    """
    dataset_index = DatasetIndex(5)
    jitter = np.random.random(3) * .05
    dataset_index.split(shares=.3 - jitter)

    whole = set(dataset_index.index)
    # Gather the items of every subset produced by the split.
    covered = set()
    for part in (dataset_index.train, dataset_index.test, dataset_index.validation):
        covered |= set(part.index)
    assert whole == covered
Esempio n. 4
0
def test_next_batch_drop_last_true_2():
    """ Batch order and contents may vary from epoch to epoch.

    Since 'shuffle' is True, the indices that get dropped differ between epochs,
    so the batches gathered over ten calls are expected to cover the whole index.
    """
    dataset_index = DatasetIndex(5)
    seen = set()
    expected = set(dataset_index.index)
    batch_kwargs = dict(batch_size=2, n_epochs=None, drop_last=True, shuffle=True)
    for _ in range(10):
        seen.update(dataset_index.next_batch(**batch_kwargs).index)
    assert seen == expected
Esempio n. 5
0
def test_next_batch_drop_last_true():
    """ Batch order and contents are identical in every epoch.

    Since 'shuffle' is False, the same indices are dropped each time: every pair
    of consecutive batches must match the first two and the next two index items.
    """
    dataset_index = DatasetIndex(5)
    batch_kwargs = dict(batch_size=2, n_epochs=None, drop_last=True, shuffle=False)
    for _ in range(10):
        first = dataset_index.next_batch(**batch_kwargs)
        second = dataset_index.next_batch(**batch_kwargs)
        assert (first.index == dataset_index.index[:2]).all()
        assert (second.index == dataset_index.index[2:4]).all()
Esempio n. 6
0
def test_calc_split_raise():
    """ 'calc_split' must raise ValueError for each of the share lists checked here.

    A 'DatasetIndex(5)' is checked with three and with four shares of 0.5,
    and a 'DatasetIndex(2)' with three shares of 0.5.
    """
    dataset_index = DatasetIndex(5)
    for bad_shares in ([0.5, 0.5, 0.5], [0.5, 0.5, 0.5, 0.5]):
        with pytest.raises(ValueError):
            dataset_index.calc_split(shares=bad_shares)

    # The smaller index is built inside the context manager, as a single expression.
    with pytest.raises(ValueError):
        DatasetIndex(2).calc_split(shares=[0.5, 0.5, 0.5])