def test_close_early(self):
    """Close a 2-D weighted sampler before any batch has been requested.

    The test has no explicit assertion: it passes as long as building the
    sampler and calling close_all() both complete without raising.
    """
    reader = get_2d_reader()
    sampler = WeightedSampler(
        reader=reader,
        window_sizes=MOD_2D_DATA,
        batch_size=2,
        windows_per_image=10,
        queue_length=10,
    )
    sampler.close_all()
 def test_dynamic_init(self):
     """Pull one batch from a weighted sampler built on a dynamic-window reader.

     Only the per-window part of the image shape (``shape[1:]``) is checked.
     The leading (batch) dimension is not asserted here.
     """
     sampler = WeightedSampler(reader=get_dynamic_window_reader(),
                               window_sizes=DYNAMIC_MOD_DATA,
                               batch_size=2,
                               windows_per_image=10,
                               queue_length=10)
     try:
         with self.cached_session() as sess:
             sampler.set_num_threads(2)
             out = sess.run(sampler.pop_batch_op())
             self.assertAllClose(out['image'].shape[1:], (8, 2, 256, 2))
     finally:
         # The other sampler tests in this suite call close_all(); do the
         # same here, and do it even if the run or the assertion fails.
         sampler.close_all()
 def test_ill_init(self):
     """Building a sampler from a 3-D reader with MOD_2D_DATA window sizes must raise ValueError."""
     # assertRaisesRegexp is a deprecated alias that was removed from unittest
     # in Python 3.12. The empty pattern it was given matched any message, so
     # a plain assertRaises checks exactly the same thing.
     with self.assertRaises(ValueError):
         WeightedSampler(reader=get_3d_reader(),
                         window_sizes=MOD_2D_DATA,
                         batch_size=2,
                         windows_per_image=10,
                         queue_length=10)
def get_sampler(image_reader,
                patch_size,
                phase,
                windows_per_image=None,
                window_border=None):
    """Build the window sampler that matches a pipeline phase.

    :param image_reader: reader forwarded unchanged to the sampler.
    :param patch_size: forwarded to the sampler as ``window_sizes``.
    :param phase: one of ``'training'``, ``'validation'`` or ``'inference'``.
    :param windows_per_image: windows drawn per image; must be truthy for
        ``'training'`` and ``'validation'``.
    :param window_border: border forwarded to ``GridSampler``; must be truthy
        for ``'inference'``.
    :return: a ``WeightedSampler`` for training/validation, or a
        ``GridSampler`` for inference.
    :raises ValueError: for an unknown phase, or when the argument required
        by the chosen phase is missing or falsy. ``ValueError`` subclasses
        ``Exception``, so existing ``except Exception`` handlers still work.
    """
    sampling_phases = ('training', 'validation')
    valid_phases = sampling_phases + ('inference',)

    if phase in sampling_phases:
        if not windows_per_image:
            raise ValueError('Invalid windows per image!')
        return WeightedSampler(image_reader,
                               window_sizes=patch_size,
                               windows_per_image=windows_per_image)

    if phase == 'inference':
        if not window_border:
            raise ValueError('Invalid window border!')
        return GridSampler(image_reader,
                           window_sizes=patch_size,
                           window_border=window_border)

    # Report the rejected value as well as the accepted ones. The previous
    # message only listed the valid choices.
    raise ValueError('Invalid phase choice: {!r}; expected one of {}'.format(
        phase, list(valid_phases)))
 def initialise_weighted_sampler(self):
     """Create one WeightedSampler for each reader in ``self.readers``.

     The result is stored as ``self.sampler``: an outer list holding a single
     inner list, with one sampler per reader. Every sampler uses the same
     ``self.data_param`` window sizes, the batch size and queue length from
     ``self.net_param``, and ``self.action_param.sample_per_volume`` windows
     per image.
     """
     samplers = []
     for reader in self.readers:
         sampler = WeightedSampler(
             reader=reader,
             window_sizes=self.data_param,
             batch_size=self.net_param.batch_size,
             windows_per_image=self.action_param.sample_per_volume,
             queue_length=self.net_param.queue_length)
         samplers.append(sampler)
     self.sampler = [samplers]
 def test_2d_init(self):
     """Pull one batch from a 2-D weighted sampler and check the image shape.

     With ``batch_size=2`` the image batch is expected to be (2, 10, 9, 1).
     """
     sampler = WeightedSampler(reader=get_2d_reader(),
                               window_sizes=MOD_2D_DATA,
                               batch_size=2,
                               windows_per_image=10,
                               queue_length=10)
     try:
         with self.cached_session() as sess:
             sampler.set_num_threads(2)
             out = sess.run(sampler.pop_batch_op())
             self.assertAllClose(out['image'].shape, (2, 10, 9, 1))
     finally:
         # Run the cleanup even when sess.run or the assertion fails.
         # Otherwise close_all() would be skipped on the failure path.
         sampler.close_all()
Example #7
0
#             'filename_contains': ['middleSlice','mhd'], 'pixdim': (1,1), 'axcodes': ['L','P','S'],'interp_order': 3},
#             'sampler': {'path_to_search': './sample_images', 'spatial_window_size': (48,48,1),
#             'filename_contains': ['middleLabels','mhd'],'pixdim': (1,1), 'axcodes': ['L','P','S'],'interp_order': 0},
#               }

# Build the image reader and pad the 'MR' and 'sampler' images with a
# 'constant' (value 0) border of 20 along each axis. The padding lets
# windows still be sampled at the first and last slices.
import tensorflow as tf

reader = ImageReader().initialise(data_param)
pad_layer = PadLayer(image_name=['MR', 'sampler'],
                     border=(20, 20, 20),
                     mode='constant')
reader.add_preprocessing_layers(pad_layer)
_, img, _ = reader(idx=0)

# One sampler of each type, all using the same window size.
sample_window = (48, 48, 48)
weighted_sampler = WeightedSampler(reader, window_sizes=sample_window)
balanced_sampler = BalancedSampler(reader, window_sizes=sample_window)
uniform_sampler = UniformSampler(reader, window_sizes=sample_window)

# Number of samples to generate for each sampler type.
N = 30

# Tensor op that yields the next window batch from the weighted sampler.
next_window = weighted_sampler.pop_batch_op()
# run the tensors
with tf.Session() as sess:
    weighted_sampler.run_threads(sess)  #initialise the iterator
    w_coords = []
    for _ in range(N):