def test_minibatch_index_generator():
    """Check that minibatch_index_generator walks the data in order, wrapping at epoch ends.

    48 samples for 1.5 epochs gives 72 samples in total.  Each yielded index
    must select the next run of consecutive (mod n_samples) samples.  With
    final_treatment='truncate' the trailing partial minibatch is still
    yielded, so all 72 samples are consumed.  With 'stop' it is dropped, so
    only 70 (14 full minibatches of 5) are consumed.  Both values of
    slice_when_possible are exercised.
    """
    n_samples = 48
    n_epochs = 1.5
    minibatch_size = 5
    data = np.arange(n_samples)
    expected_total_samples = int(len(data) * n_epochs)
    full_batches_total = (expected_total_samples // minibatch_size) * minibatch_size

    # (final_treatment, samples expected to be consumed, hard-coded sanity value)
    cases = (
        ('truncate', expected_total_samples, 72),
        ('stop', full_batches_total, 70),
    )

    for slice_when_possible in (True, False):
        for final_treatment, expected_consumed, sanity_value in cases:
            consumed = 0
            index_stream = minibatch_index_generator(
                n_samples=n_samples,
                n_epochs=n_epochs,
                final_treatment=final_treatment,
                slice_when_possible=slice_when_possible,
                minibatch_size=minibatch_size,
            )
            for ix in index_stream:
                batch = data[ix]
                # The batch must continue where the previous one stopped,
                # wrapping modulo n_samples across the epoch boundary.
                stop = min(expected_total_samples, consumed + minibatch_size)
                expected_batch = np.arange(consumed, stop) % n_samples
                assert np.array_equal(batch, expected_batch)
                consumed += len(batch)
            assert consumed == expected_consumed == sanity_value
def test_minibatch_index_even():
    """With minibatch_size=1 and no slicing, each index is a one-element list.

    Over 2 epochs of 5 samples, the yielded indices must be
    [[0], [1], ..., [4], [0], ..., [4]].
    """
    n_samples = 5
    n_epochs = 2
    index_stream = minibatch_index_generator(
        n_samples=n_samples,
        n_epochs=n_epochs,
        minibatch_size=1,
        slice_when_possible=False,
        final_treatment='truncate',
    )
    ixs = list(index_stream)
    expected = [[step % n_samples] for step in range(n_samples * n_epochs)]
    assert ixs == expected