コード例 #1
0
 def initialise_sampler(self):
     """Populate ``self.sampler`` for the current action.

     ``self.sampler`` is reset to ``[]`` and then receives at most one
     entry: a list holding one sampler per reader in ``self.readers``.

     * training: ``ResizeSampler`` with one window per image and
       ``shuffle_buffer=True``;
     * ``encode`` / ``encode-decode`` inference: ``ResizeSampler`` with one
       window per image and ``shuffle_buffer=False``;
     * ``linear_interpolation`` inference: ``LinearInterpolateSampler``
       driven by ``self.autoencoder_param.n_interpolations``;
     * any other inference type: ``self.sampler`` is left empty.
     """
     self.sampler = []
     training = self.is_training
     if training or self._infer_type in ('encode', 'encode-decode'):
         # Training and encoding share the same resize sampler; only the
         # buffer-shuffling flag differs between them.
         shuffle = bool(training)
         self.sampler.append([
             ResizeSampler(reader=image_reader,
                           data_param=self.data_param,
                           batch_size=self.net_param.batch_size,
                           windows_per_image=1,
                           shuffle_buffer=shuffle,
                           queue_length=self.net_param.queue_length)
             for image_reader in self.readers
         ])
     elif self._infer_type == 'linear_interpolation':
         self.sampler.append([
             LinearInterpolateSampler(
                 reader=image_reader,
                 data_param=self.data_param,
                 batch_size=self.net_param.batch_size,
                 n_interpolations=self.autoencoder_param.n_interpolations,
                 queue_length=self.net_param.queue_length)
             for image_reader in self.readers
         ])
コード例 #2
0
 def initialise_sampler(self):
     """Rebuild ``self.sampler`` for the GAN application.

     While training, the list holds a single ``ResizeSampler`` over
     ``self.reader`` (one window per image, ``shuffle_buffer=True``).
     Otherwise it holds, in this order, a ``RandomVectorSampler`` producing
     vectors of size ``(self.gan_param.noise_size,)`` and a ``ResizeSampler``
     emitting ``self.gan_param.n_interpolations`` windows per image with
     ``shuffle_buffer=False``.
     """
     self.sampler = []
     training = self.is_training
     if not training:
         # Inference: the random-vector sampler is placed first in the list.
         self.sampler.append(RandomVectorSampler(
             names=('vector',),
             vector_size=(self.gan_param.noise_size,),
             batch_size=self.net_param.batch_size,
             n_interpolations=self.gan_param.n_interpolations,
             repeat=None,
             queue_length=self.net_param.queue_length))
     # At inference each resized image is repeated n times
     # (n = self.gan_param.n_interpolations) so that every image lines up
     # with one random vector; training uses a single window per image.
     if training:
         windows = 1
     else:
         windows = self.gan_param.n_interpolations
     self.sampler.append(ResizeSampler(
         reader=self.reader,
         data_param=self.data_param,
         batch_size=self.net_param.batch_size,
         windows_per_image=windows,
         shuffle_buffer=bool(training),
         queue_length=self.net_param.queue_length))
コード例 #3
0
 def initialise_resize_sampler(self):
     """Set ``self.sampler`` to one ``ResizeSampler`` per reader.

     The result is a single-element list wrapping the per-reader samplers,
     i.e. ``[[sampler_for_reader_0, sampler_for_reader_1, ...]]``.  Each
     sampler's ``shuffle_buffer`` follows ``self.is_training``.
     """
     resize_samplers = [
         ResizeSampler(
             reader=image_reader,
             data_param=self.data_param,
             batch_size=self.net_param.batch_size,
             shuffle_buffer=self.is_training,
             queue_length=self.net_param.queue_length)
         for image_reader in self.readers
     ]
     self.sampler = [resize_samplers]
コード例 #4
0
    def test_inverse_mapping(self):
        """Resize-sample a label volume and aggregate it back to disk.

        Batches are popped until ``aggregator.decode_batch`` returns a falsy
        value. Then the shape of the written
        ``<subject_id>_niftynet_out.nii.gz`` is checked. ``close_all()``
        runs in a ``finally`` clause, so the sampler is closed even when
        the session loop or the assertion fails.
        """
        reader = get_label_reader()
        sampler = ResizeSampler(reader=reader,
                                data_param=MOD_LABEL_DATA,
                                batch_size=1,
                                shuffle_buffer=False,
                                queue_length=50)
        try:
            aggregator = ResizeSamplesAggregator(image_reader=reader,
                                                 name='label',
                                                 output_path=os.path.join(
                                                     'testing_data',
                                                     'aggregated'),
                                                 interp_order=0)
            more_batch = True

            with self.test_session() as sess:
                coordinator = tf.train.Coordinator()
                sampler.run_threads(sess, coordinator, num_threads=2)
                # Keep popping batches until decode_batch returns falsy.
                while more_batch:
                    out = sess.run(sampler.pop_batch_op())
                    more_batch = aggregator.decode_batch(
                        out['label'], out['label_location'])
            output_filename = '{}_niftynet_out.nii.gz'.format(
                sampler.reader.get_subject_id(0))
            output_file = os.path.join('testing_data', 'aggregated',
                                       output_filename)
            self.assertAllClose(nib.load(output_file).shape,
                                [256, 168, 256, 1, 1])
        finally:
            # Close the sampler on both the success and failure paths.
            sampler.close_all()
コード例 #5
0
    def test_inverse_mapping(self):
        """Resize-sample a label volume and aggregate it back to disk.

        Batches are popped until ``aggregator.decode_batch`` returns a falsy
        value. Then the shape of the written
        ``<subject_id>_niftynet_out.nii.gz`` is checked. ``close_all()``
        runs in a ``finally`` clause, so the sampler is closed even when
        the session loop or the assertion fails.
        """
        reader = get_label_reader()
        sampler = ResizeSampler(reader=reader,
                                data_param=MOD_LABEL_DATA,
                                batch_size=1,
                                shuffle_buffer=False,
                                queue_length=50)
        try:
            aggregator = ResizeSamplesAggregator(
                image_reader=reader,
                name='label',
                output_path=os.path.join('testing_data', 'aggregated'),
                interp_order=0)
            more_batch = True

            with self.test_session() as sess:
                coordinator = tf.train.Coordinator()
                sampler.run_threads(sess, coordinator, num_threads=2)
                # Keep popping batches until decode_batch returns falsy.
                while more_batch:
                    out = sess.run(sampler.pop_batch_op())
                    more_batch = aggregator.decode_batch(
                        out['label'], out['label_location'])
            output_filename = '{}_niftynet_out.nii.gz'.format(
                sampler.reader.get_subject_id(0))
            output_file = os.path.join(
                'testing_data', 'aggregated', output_filename)
            self.assertAllClose(
                nib.load(output_file).shape, [256, 168, 256, 1, 1])
        finally:
            # Close the sampler on both the success and failure paths.
            sampler.close_all()
コード例 #6
0
 def test_3d_init(self):
     """Pop one batch from a ResizeSampler over ``get_3d_reader()``.

     The popped ``'image'`` batch must have shape ``[1, 7, 10, 2, 2]``.
     ``close_all()`` runs in a ``finally`` clause, so the sampler is
     closed even when the session or the shape assertion fails.
     """
     sampler = ResizeSampler(reader=get_3d_reader(),
                             data_param=MULTI_MOD_DATA,
                             batch_size=1,
                             shuffle_buffer=False,
                             queue_length=1)
     try:
         with self.test_session() as sess:
             coordinator = tf.train.Coordinator()
             sampler.run_threads(sess, coordinator, num_threads=2)
             out = sess.run(sampler.pop_batch_op())
             self.assertAllClose(out['image'].shape, [1, 7, 10, 2, 2])
     finally:
         sampler.close_all()
コード例 #7
0
 def test_3d_init(self):
     """Pop one batch from a ResizeSampler over ``get_3d_reader()``.

     The popped ``'image'`` batch must have shape ``[1, 7, 10, 2, 2]``.
     ``close_all()`` runs in a ``finally`` clause, so the sampler is
     closed even when the session or the shape assertion fails.
     """
     sampler = ResizeSampler(
         reader=get_3d_reader(),
         data_param=MULTI_MOD_DATA,
         batch_size=1,
         shuffle_buffer=False,
         queue_length=1)
     try:
         with self.test_session() as sess:
             coordinator = tf.train.Coordinator()
             sampler.run_threads(sess, coordinator, num_threads=2)
             out = sess.run(sampler.pop_batch_op())
             self.assertAllClose(out['image'].shape, [1, 7, 10, 2, 2])
     finally:
         sampler.close_all()