def run_dataset(self, n_iters, n_threads, **kwargs):
    """
    Build an ``ImageWindowDataset`` from ``kwargs`` and drain its batches.

    At most ``min(n_iters, 100)`` batches are fetched from the op returned
    by ``pop_batch_op()``. Fetching stops early, without failing, when the
    session raises ``tf.errors.OutOfRangeError`` or ``EOFError``.

    :param n_iters: requested number of batches (silently capped at 100)
    :param n_threads: value passed to ``set_num_threads`` on the dataset
    :param kwargs: forwarded unchanged to the ``ImageWindowDataset``
        constructor
    :return: ``(n_fetched, locations)`` where ``n_fetched`` is the number
        of batches actually fetched and ``locations`` is the
        ``'mr_location'`` entry of every fetched batch concatenated
        along axis 0
    """
    dataset = ImageWindowDataset(**kwargs)
    dataset.set_num_threads(n_threads)
    locations = []
    with self.cached_session() as sess:
        batch_op = dataset.pop_batch_op()
        try:
            max_iters = min(n_iters, 100)
            for _ in range(max_iters):
                batch = sess.run(batch_op)
                locations.append(batch['mr_location'])
        except (tf.errors.OutOfRangeError, EOFError):
            # The dataset ran out of data before the requested count.
            pass
        # One entry is stored per successful fetch, so the list length
        # is the number of completed iterations.
        n_fetched = len(locations)
        assert n_fetched <= 100, 'keep the test smaller than 100 iters'
    return n_fetched, np.concatenate(locations, 0)
def test_epoch(self):
    """
    Check the number of batches produced by a single-epoch dataset.

    An ``ImageWindowDataset`` is built over the 2D reader with
    ``batch_size=3`` and ``epoch=1``; batches are fetched until the
    session raises ``tf.errors.OutOfRangeError``. The number of fetched
    batches must equal ``ceil(reader.num_subjects / batch_size)``.
    """
    reader = get_2d_reader()
    batch_size = 3
    sampler = ImageWindowDataset(
        reader=reader, batch_size=batch_size, epoch=1)
    with self.cached_session() as sess:
        next_element = sampler.pop_batch_op()
        iters = 0
        try:
            # 400 is an upper bound on the number of fetches; the
            # assertion below expects the loop to stop earlier via
            # OutOfRangeError.
            for _ in range(400):
                sess.run(next_element)
                iters = iters + 1
        except tf.errors.OutOfRangeError:
            pass
        # batch size 3, 40 images in total (per the original note;
        # the image count comes from the reader fixture)
        # The builtin ``float`` replaces the ``np.float`` alias, which
        # was removed from NumPy and raises AttributeError there.
        self.assertEqual(
            np.ceil(reader.num_subjects / float(batch_size)), iters)