示例#1
0
def test_minibatch_index_generator():
    """Check minibatch_index_generator over 1.5 epochs of 48 samples, batch size 5.

    Both index modes (slice_when_possible True and False) are exercised.
    With final_treatment='truncate' all 72 samples must be visited, the last
    batch being short. With final_treatment='stop' the test expects only 70
    samples, i.e. 14 full batches. Every batch must continue the wrap-around
    sequence 0, 1, ..., 47, 0, 1, ...
    """
    n_samples = 48
    n_epochs = 1.5
    minibatch_size = 5
    data = np.arange(n_samples)
    expected_total_samples = int(len(data) * n_epochs)

    def _walk(final_treatment, slice_when_possible):
        # Iterate the generator, checking each batch against the expected
        # wrap-around indices; return the number of samples seen.
        seen = 0
        for ix in minibatch_index_generator(
                n_samples=n_samples,
                n_epochs=n_epochs,
                final_treatment=final_treatment,
                slice_when_possible=slice_when_possible,
                minibatch_size=minibatch_size):
            batch = data[ix]
            upper = min(expected_total_samples, seen + minibatch_size)
            assert np.array_equal(batch, np.arange(seen, upper) % n_samples)
            seen += len(batch)
        return seen

    for slice_when_possible in (True, False):
        truncated_count = _walk('truncate', slice_when_possible)
        assert truncated_count == expected_total_samples == 72

        stopped_count = _walk('stop', slice_when_possible)
        assert stopped_count == (expected_total_samples // minibatch_size) * minibatch_size == 70
示例#2
0
def test_minibatch_index_generator():
    """Exercise both final_treatment modes of minibatch_index_generator.

    48 samples are drawn for 1.5 epochs in batches of 5, once with slicing
    allowed and once without. 'truncate' must yield all 72 samples.
    'stop' is expected to yield only the 70 samples that fill whole batches.
    Each batch must match the wrap-around index sequence.
    """
    n_samples, n_epochs, minibatch_size = 48, 1.5, 5
    data = np.arange(n_samples)
    expected_total_samples = int(len(data) * n_epochs)
    # Sample count once the trailing partial batch is dropped.
    whole_batch_total = (expected_total_samples // minibatch_size) * minibatch_size

    cases = (
        ('truncate', expected_total_samples, 72),
        ('stop', whole_batch_total, 70),
    )
    for slice_when_possible in (True, False):
        for final_treatment, expected_count, literal_count in cases:
            position = 0
            batches = minibatch_index_generator(
                n_samples=n_samples,
                n_epochs=n_epochs,
                final_treatment=final_treatment,
                slice_when_possible=slice_when_possible,
                minibatch_size=minibatch_size,
            )
            for ix in batches:
                chunk = data[ix]
                wanted = np.arange(position, min(expected_total_samples, position + minibatch_size)) % n_samples
                assert np.array_equal(chunk, wanted)
                position += len(chunk)
            assert position == expected_count == literal_count
示例#3
0
def test_minibatch_index_even():
    """Single-sample batches over 2 epochs of 5 samples, without slicing.

    The yielded indices must compare equal to [[0], [1], ..., [4], [0], ..., [4]].
    """
    n_samples = 5
    n_epochs = 2

    index_stream = minibatch_index_generator(
        n_samples=n_samples,
        n_epochs=n_epochs,
        minibatch_size=1,
        slice_when_possible=False,
        final_treatment='truncate',
    )
    collected = list(index_stream)

    expected = [[step % n_samples] for step in range(n_samples * n_epochs)]
    assert collected == expected
示例#4
0
def test_minibatch_index_even():
    """Check that batch size 1 without slicing walks each epoch in order.

    With 5 samples and 2 epochs the generator must produce one one-element
    index per step: sample 0..4, then 0..4 again.
    """
    n_samples, n_epochs = 5, 2
    # Epoch-major enumeration; equivalent to [[i % n_samples] for i in range(n_samples * n_epochs)].
    expected = [[sample] for _epoch in range(n_epochs) for sample in range(n_samples)]
    produced = list(minibatch_index_generator(
        n_samples=n_samples,
        n_epochs=n_epochs,
        minibatch_size=1,
        slice_when_possible=False,
        final_treatment='truncate',
    ))
    assert produced == expected