示例#1
0
 def run_dataset(self, n_iters, n_threads, **kwargs):
     """Pull batches from an ImageWindowDataset and gather their locations.

     Builds an ``ImageWindowDataset`` from ``kwargs`` and sets its thread
     count. It then fetches at most ``min(n_iters, 100)`` batches inside a
     cached test session. Fetching stops early when the session raises
     ``tf.errors.OutOfRangeError`` or ``EOFError``.

     Args:
         n_iters: requested number of batches (capped at 100).
         n_threads: value passed to ``set_num_threads`` on the dataset.
         **kwargs: forwarded unchanged to ``ImageWindowDataset``.

     Returns:
         A ``(count, locations)`` pair. ``count`` is how many batches were
         fetched. ``locations`` is every batch's ``'mr_location'`` entry
         concatenated along axis 0.
     """
     dataset = ImageWindowDataset(**kwargs)
     dataset.set_num_threads(n_threads)
     locations = []
     with self.cached_session() as sess:
         batch_op = dataset.pop_batch_op()
         max_steps = min(n_iters, 100)
         try:
             for _ in range(max_steps):
                 batch = sess.run(batch_op)
                 locations.append(batch['mr_location'])
         except (tf.errors.OutOfRangeError, EOFError):
             # The dataset ran dry before max_steps; keep what was fetched.
             pass
         n_done = len(locations)
         # NOTE(review): the loop is already capped at 100 steps, so this
         # assertion cannot fail as written.
         assert n_done <= 100, 'keep the test smaller than 100 iters'
     # NOTE(review): np.concatenate raises ValueError if no batch was fetched.
     return n_done, np.concatenate(locations, 0)
示例#2
0
 def test_epoch(self):
     """A one-epoch dataset yields exactly ceil(num_subjects / batch_size) batches.

     Builds a 2-D reader and wraps it in an ``ImageWindowDataset`` with
     ``batch_size=3`` and ``epoch=1``. It then pops batches until the session
     raises ``tf.errors.OutOfRangeError``, with a safety cap of 400 pops.
     Finally it checks the number of batches obtained.
     """
     reader = get_2d_reader()
     batch_size = 3
     sampler = ImageWindowDataset(reader=reader,
                                  batch_size=batch_size,
                                  epoch=1)
     with self.cached_session() as sess:
         next_element = sampler.pop_batch_op()
         iters = 0
         try:
             for _ in range(400):
                 # Only the number of batches matters; the batch content
                 # itself is discarded.
                 sess.run(next_element)
                 iters = iters + 1
         except tf.errors.OutOfRangeError:
             pass
         # The last batch may be partial, hence the ceiling. The original
         # note says the 2-D reader has 40 images (-> 14 batches of 3);
         # TODO confirm.
         # The builtin float replaces np.float, which was removed in
         # NumPy 1.24.
         self.assertEqual(
             np.ceil(reader.num_subjects / float(batch_size)), iters)