示例#1
0
 def test_name_mismatch(self):
     """GridSampler must raise KeyError when ``data_param`` does not match.

     ``MOD_2D_DATA`` is combined with readers obtained from
     ``get_dynamic_window_reader`` and ``get_3d_reader``; constructing the
     sampler is expected to fail with ``KeyError`` in both cases.
     """
     # The original check used assertRaisesRegexp with the regex "", which
     # matches any message. Plain assertRaises is therefore equivalent. It
     # also avoids the deprecated alias, which was removed in Python 3.12.
     for make_reader in (get_dynamic_window_reader, get_3d_reader):
         # The reader is built inside the context manager, as before, so a
         # KeyError raised while building it still satisfies the check.
         with self.assertRaises(KeyError):
             GridSampler(reader=make_reader(),
                         data_param=MOD_2D_DATA,
                         batch_size=10,
                         spatial_window_size=None,
                         window_border=(0, 0, 0),
                         queue_length=10)
 def test_inverse_mapping(self):
     """Grid-sample a label volume, re-aggregate it, and compare to the source.

     Windows are drawn with ``GridSampler`` and fed to
     ``GridSamplesAggregator``, both using border ``(3, 4, 5)``. The
     aggregated image is read back from ``testing_data/aggregated``. It
     must have shape ``[256, 168, 256, 1, 1]``, and its data must match
     ``T1_1023_NeuroMorph_Parcellation.nii.gz``.
     """
     import numpy as np  # local import: only needed to read voxel arrays

     reader = get_label_reader()
     data_param = MOD_LABEL_DATA
     output_dir = os.path.join('testing_data', 'aggregated')
     sampler = GridSampler(reader=reader,
                           data_param=data_param,
                           batch_size=10,
                           spatial_window_size=None,
                           window_border=(3, 4, 5),
                           queue_length=50)
     aggregator = GridSamplesAggregator(image_reader=reader,
                                        name='label',
                                        output_path=output_dir,
                                        window_border=(3, 4, 5),
                                        interp_order=0)
     try:
         more_batch = True
         with self.test_session() as sess:
             coordinator = tf.train.Coordinator()
             sampler.run_threads(sess, coordinator, num_threads=2)
             # Keep popping batches until decode_batch returns a falsy value.
             while more_batch:
                 out = sess.run(sampler.pop_batch_op())
                 more_batch = aggregator.decode_batch(out['label'],
                                                      out['label_location'])
         output_filename = '{}_niftynet_out.nii.gz'.format(
             sampler.reader.get_subject_id(0))
         output_image = nib.load(os.path.join(output_dir, output_filename))
         self.assertAllClose(output_image.shape, [256, 168, 256, 1, 1])
         # np.asanyarray(img.dataobj) is nibabel's documented replacement for
         # the deprecated get_data(), which was removed in nibabel 5.0.
         output_data = np.asanyarray(output_image.dataobj)[..., 0, 0]
         expected_data = np.asanyarray(nib.load(
             'testing_data/T1_1023_NeuroMorph_Parcellation.nii.gz').dataobj)
         self.assertAllClose(output_data, expected_data)
     finally:
         # Stop the sampler threads even when an assertion above fails.
         sampler.close_all()
 def test_25d_init(self):
     """Grid-sample and re-aggregate a 2.5D image, then check the output shape.

     Windows from ``SINGLE_25D_DATA`` (border ``(3, 4, 5)``) are passed to
     ``GridSamplesAggregator``. The aggregated file in
     ``testing_data/aggregated`` must have shape ``[255, 168, 256, 1, 1]``.
     """
     reader = get_25d_reader()
     output_dir = os.path.join('testing_data', 'aggregated')
     sampler = GridSampler(reader=reader,
                           data_param=SINGLE_25D_DATA,
                           batch_size=10,
                           spatial_window_size=None,
                           window_border=(3, 4, 5),
                           queue_length=50)
     aggregator = GridSamplesAggregator(image_reader=reader,
                                        name='image',
                                        output_path=output_dir,
                                        window_border=(3, 4, 5),
                                        interp_order=0)
     try:
         more_batch = True
         with self.test_session() as sess:
             coordinator = tf.train.Coordinator()
             sampler.run_threads(sess, coordinator, num_threads=2)
             # Keep popping batches until decode_batch returns a falsy value.
             while more_batch:
                 out = sess.run(sampler.pop_batch_op())
                 more_batch = aggregator.decode_batch(out['image'],
                                                      out['image_location'])
         output_filename = '{}_niftynet_out.nii.gz'.format(
             sampler.reader.get_subject_id(0))
         output_file = os.path.join(output_dir, output_filename)
         # These tolerances are far below one voxel, so the integer shape
         # comparison is effectively exact.
         self.assertAllClose(nib.load(output_file).shape,
                             [255, 168, 256, 1, 1],
                             rtol=1e-03,
                             atol=1e-03)
     finally:
         # Stop the sampler threads even when the shape check fails.
         sampler.close_all()
示例#4
0
 def initialise_grid_sampler(self):
     """Create one GridSampler per entry of ``self.readers``.

     Each sampler uses the instance's ``data_param``, the batch size and
     queue length from ``net_param``, and the window size and border from
     ``action_param``. ``self.sampler`` is set to a one-element list that
     wraps the list of per-reader samplers.
     """
     grid_samplers = []
     for image_reader in self.readers:
         grid_samplers.append(
             GridSampler(reader=image_reader,
                         data_param=self.data_param,
                         batch_size=self.net_param.batch_size,
                         spatial_window_size=(
                             self.action_param.spatial_window_size),
                         window_border=self.action_param.border,
                         queue_length=self.net_param.queue_length))
     # NOTE(review): presumably the outer list groups samplers per
     # pipeline - confirm against consumers of self.sampler.
     self.sampler = [grid_samplers]
示例#5
0
 def test_3d_initialising(self):
     """Pop a single batch from a 3D GridSampler and check the image shape.

     The sampler reads ``MULTI_MOD_DATA`` through ``get_3d_reader`` with
     ``batch_size=10``. The popped ``'image'`` batch must have shape
     ``(10, 8, 10, 2, 2)``.
     """
     sampler = GridSampler(reader=get_3d_reader(),
                           data_param=MULTI_MOD_DATA,
                           batch_size=10,
                           spatial_window_size=None,
                           window_border=(0, 0, 0),
                           queue_length=10)
     try:
         with self.test_session() as sess:
             coordinator = tf.train.Coordinator()
             sampler.run_threads(sess, coordinator, num_threads=2)
             out = sess.run(sampler.pop_batch_op())
             self.assertAllClose(out['image'].shape, (10, 8, 10, 2, 2))
     finally:
         # Stop the sampler threads even if sess.run or the assertion raises.
         sampler.close_all()
示例#6
0
 def initialise_sampler(self):
     """Create the per-reader samplers appropriate to the current mode.

     When ``self.is_training`` is falsy, a ``GridSampler`` is created per
     reader using the configured window size and border. Otherwise a
     ``UniformSampler`` is created per reader, drawing
     ``action_param.sample_per_volume`` windows per image. In both cases
     ``self.sampler`` becomes a one-element list wrapping the list of
     per-reader samplers.
     """
     if not self.is_training:
         grid_samplers = [
             GridSampler(reader=image_reader,
                         data_param=self.data_param,
                         batch_size=self.net_param.batch_size,
                         spatial_window_size=(
                             self.action_param.spatial_window_size),
                         window_border=self.action_param.border,
                         queue_length=self.net_param.queue_length)
             for image_reader in self.readers]
         self.sampler = [grid_samplers]
         return

     uniform_samplers = [
         UniformSampler(reader=image_reader,
                        data_param=self.data_param,
                        batch_size=self.net_param.batch_size,
                        windows_per_image=(
                            self.action_param.sample_per_volume),
                        queue_length=self.net_param.queue_length)
         for image_reader in self.readers]
     self.sampler = [uniform_samplers]