Example 1
    def test_chain_shuffle(self):
        n_samples = 4
        repeat = 2
        v = np.arange(n_samples)

        def chunk_fn(x): return chunk_it(x, chunk_size=2)

        # chain two repeated streams: the first stays in order,
        # the second is a shuffled copy of the same two repetitions
        data_it = repeat_apply(chunk_fn, v, repeat)

        data_it = chain_it(data_it, shuffle_it(repeat_apply(chunk_fn, v, repeat), buffer_size=8))

        data = list(data_it)

        unique_data = np.unique(data)
        counts = np.unique(np.bincount(data))

        self.assertEqual(len(unique_data), 4)
        self.assertEqual(len(counts), 1)
        self.assertEqual(counts[0], 4)
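Note on the shuffled branch: shuffle_it holds items in a bounded buffer, and with buffer_size=8 the buffer covers the whole 8-item repeated stream, so the second pass is a full permutation; a permutation never changes counts, which is why each of the 4 values still appears exactly 4 times across both chains. A minimal sketch of how such a buffer-based shuffle typically works (buffered_shuffle and its internals are illustrative assumptions, not the library's code):

import random

def buffered_shuffle(iterable, buffer_size):
    # Fill a bounded buffer, then emit a random buffered item each time
    # a new one arrives; shuffle and drain the buffer at the end.
    it = iter(iterable)
    buffer = []
    for item in it:
        buffer.append(item)
        if len(buffer) >= buffer_size:
            break
    for item in it:
        i = random.randrange(len(buffer))
        yield buffer[i]
        buffer[i] = item
    random.shuffle(buffer)
    yield from buffer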
Example 2
    def test_repeat_fn_exhaust(self):
        n_samples = 4
        repeat = 2
        v = np.random.uniform(0, 1, [n_samples, 1])
        data_it = chunk_it(v, chunk_size=2)

        def it_fn(x): return iter(x)

        # data_it is an iterator, so it is exhausted after the first pass and does not repeat
        data_it = repeat_apply(it_fn, data_it, repeat)

        # only n_samples items are returned
        self.assertEqual(len(list(data_it)), n_samples)
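The behaviour hinges on repeat_apply calling the given function on the same argument once per repetition: a re-iterable source restarts every pass, while an iterator is consumed by the first pass and yields nothing afterwards. A sketch of the assumed semantics (consistent with this test, not the library's actual implementation):

def repeat_apply_sketch(fn, data, n):
    # apply fn to the same `data` object n times, chaining the results
    for _ in range(n):
        yield from fn(data)

items = [0, 1, 2, 3]
# a re-iterable source restarts on every repetition: 8 items
assert len(list(repeat_apply_sketch(iter, items, 2))) == 8
# an iterator is consumed by the first pass: only 4 items
assert len(list(repeat_apply_sketch(iter, iter(items), 2))) == 4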
Example 3
def data_pipeline(data, epochs=1, batch_size=args.batch_size, shuffle=False):
    # read the source in large blocks (1000 batches' worth at a time)
    def chunk_fn(x):
        return chunk_it(x, chunk_size=batch_size * 1000)

    if epochs > 1:
        # re-apply chunk_fn to the (re-iterable) source once per epoch
        data = repeat_apply(chunk_fn, data, epochs)
    else:
        data = chunk_fn(data)

    if shuffle:
        data = shuffle_it(data, args.shuffle_buffer_size)

    data = batch_it(data, size=batch_size, padding=False)
    return data
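data_pipeline reads args at module level (for the batch_size default and the shuffle buffer), so an argparse-style namespace must exist before the function is defined. A hypothetical usage, with SimpleNamespace standing in for the real argparse result:

import numpy as np
from types import SimpleNamespace

# stand-in for the argparse namespace the module is assumed to define
args = SimpleNamespace(batch_size=32, shuffle_buffer_size=10_000)

data = np.arange(1_000).reshape(-1, 1)
for batch in data_pipeline(data, epochs=2, shuffle=True):
    pass  # each batch has at most 32 rows; with padding=False the last may be shorter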
Example 4
    def test_repeat_chunk_it(self):
        n_samples = 4
        repeat = 2
        v = np.random.uniform(0, 1, [n_samples, 1])
        def chunk_fn(x): return chunk_it(x, chunk_size=2)

        # chunk_it yields individual rows, so two passes over the
        # 4-row array produce n_samples * repeat items
        data_it = repeat_apply(chunk_fn, v, repeat)

        self.assertEqual(len(list(data_it)), n_samples * repeat)
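Note that the assertion counts n_samples * repeat items, not chunks: chunk_it evidently reads its input in slices of chunk_size but yields one row at a time. A sketch of that assumed behaviour, inferred from the test:

def chunk_sketch(data, chunk_size):
    # read `data` in slices of chunk_size but yield individual rows,
    # so the item count equals len(data), not the number of chunks
    for start in range(0, len(data), chunk_size):
        for row in data[start:start + chunk_size]:
            yield row

assert sum(1 for _ in chunk_sketch(list(range(4)), 2)) == 4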
Example 5
def data_pipeline(hdf5_dataset, epochs=1, batch_size=args.batch_size, shuffle=args.shuffle):
    def chunk_fn(x):
        return chunk_it(x, chunk_size=batch_size * 1000)

    if epochs > 1:
        dataset = repeat_apply(chunk_fn, hdf5_dataset, epochs)
    else:
        dataset = chunk_fn(hdf5_dataset)

    if shuffle:
        dataset = shuffle_it(dataset, args.shuffle_buffer_size)

    # cannot pad with zeros: 0 might be a valid index, and padding would corrupt our evaluation
    # padding = np.zeros([args.ngram_size], dtype=np.int64)
    # dataset = batch_it(dataset, size=batch_size, padding=True, padding_elem=padding)
    dataset = batch_it(dataset, size=batch_size, padding=False)
    return dataset
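For reference, a minimal sketch of what batch_it with the padding option is assumed to do; it makes clear why padding with zeros would inject rows that look like valid index 0 into the final batch:

def batch_sketch(iterable, size, padding=False, padding_elem=None):
    # group a flat stream into lists of `size` items; if padding is
    # enabled, fill the last partial batch with copies of padding_elem
    batch = []
    for item in iterable:
        batch.append(item)
        if len(batch) == size:
            yield batch
            batch = []
    if batch:
        if padding:
            batch.extend([padding_elem] * (size - len(batch)))
        yield batch

assert list(batch_sketch(range(5), 2)) == [[0, 1], [2, 3], [4]]
assert list(batch_sketch(range(5), 2, padding=True, padding_elem=0)) == [[0, 1], [2, 3], [4, 0]]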