def test_chain_shuffle(self):
    """Chaining an in-order pass with a shuffled pass preserves the multiset.

    Each of the 4 distinct values must appear exactly 4 times in total
    (2 repetitions per chain * 2 chains), regardless of the shuffle order.
    """
    n_samples = 4
    repeat = 2
    v = np.arange(0, n_samples, 1)

    def chunk_fn(x):
        return chunk_it(x, chunk_size=2)

    # first chain is normal, second is shuffled from the two repetitions
    # (removed a dead `data_it = chunk_it(v, chunk_size=2)` that was
    # immediately overwritten below)
    data_it = repeat_apply(chunk_fn, v, repeat)
    data_it = chain_it(
        data_it,
        shuffle_it(repeat_apply(chunk_fn, v, repeat), buffer_size=8),
    )

    data = list(data_it)
    unique_data = np.unique(data)
    # every value should occur the same number of times -> a single
    # distinct count, equal to 4
    counts = np.unique(np.bincount(data))
    self.assertEqual(len(unique_data), 4)
    self.assertEqual(len(counts), 1)
    self.assertEqual(counts[0], 4)
def test_repeat_fn_exhaust(self):
    """An exhaustible source cannot be repeated.

    ``repeat_apply`` over a plain iterator consumes it on the first pass,
    so only the original ``n_samples`` items come out.
    """
    num_items = 4
    num_repeats = 2
    values = np.random.uniform(0, 1, [num_items, 1])
    chunks = chunk_it(values, chunk_size=2)

    def make_iter(source):
        return iter(source)

    # the underlying iterator is exhausted after one pass, so the
    # repetition yields nothing new — only num_items items total
    repeated = repeat_apply(make_iter, chunks, num_repeats)
    self.assertEqual(len(list(repeated)), num_items)
def data_pipeline(data, epochs=1, batch_size=args.batch_size, shuffle=False):
    """Build a lazy batching pipeline over ``data``.

    The data is chunked (chunk size = ``batch_size * 1000``), optionally
    repeated for multiple epochs, optionally shuffled through a bounded
    buffer, and finally batched without padding.

    Returns an iterator of batches of size ``batch_size``.
    """
    # chunk size is invariant across calls, so compute it once
    chunk_size = batch_size * 1000

    def make_chunks(source):
        return chunk_it(source, chunk_size=chunk_size)

    # re-create the chunk iterator per epoch when repeating
    stream = repeat_apply(make_chunks, data, epochs) if epochs > 1 else make_chunks(data)

    if shuffle:
        stream = shuffle_it(stream, args.shuffle_buffer_size)

    return batch_it(stream, size=batch_size, padding=False)
def test_reat_chunk_it(self):
    """``repeat_apply`` re-creates the chunk iterator on each repetition,
    so the full dataset is produced ``repeat`` times.

    NOTE(review): the name looks like a typo for ``test_repeat_chunk_it``;
    kept unchanged so test discovery and any references stay stable.
    """
    n_samples = 4
    repeat = 2
    v = np.random.uniform(0, 1, [n_samples, 1])

    def chunk_fn(x):
        return chunk_it(x, chunk_size=2)

    # removed a dead `data_it = chunk_it(v, chunk_size=2)` assignment that
    # was immediately overwritten, plus commented-out debug prints
    data_it = repeat_apply(chunk_fn, v, repeat)
    self.assertEqual(len(list(data_it)), n_samples * repeat)
def data_pipeline(hdf5_dataset, epochs=1, batch_size=args.batch_size, shuffle=args.shuffle):
    """Build a lazy batching pipeline over an HDF5-backed dataset.

    The dataset is chunked (chunk size = ``batch_size * 1000``), repeated
    for ``epochs`` passes when ``epochs > 1``, optionally shuffled through a
    bounded buffer of ``args.shuffle_buffer_size``, and batched.

    Returns an iterator of batches of size ``batch_size``.
    """
    def chunk_fn(x):
        return chunk_it(x, chunk_size=batch_size * 1000)

    if epochs > 1:
        # re-create the chunk iterator per epoch so the source is re-read
        dataset = repeat_apply(chunk_fn, hdf5_dataset, epochs)
    else:
        dataset = chunk_fn(hdf5_dataset)

    if shuffle:
        dataset = shuffle_it(dataset, args.shuffle_buffer_size)

    # cannot pad with zeros: 0 may be a valid index, and padded entries
    # would corrupt the evaluation (removed the commented-out padding code)
    dataset = batch_it(dataset, size=batch_size, padding=False)
    return dataset