Example #1
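All of these snippets come from NiftyNet test and application code. To run them standalone you would need roughly the following imports; the niftynet module paths are assumptions based on the NiftyNet source tree, and the get_*_reader helpers plus the *_DATA window-size dictionaries are test fixtures not shown here:

    import os

    import nibabel as nib
    import numpy as np
    import pandas as pd
    import tensorflow as tf

    from niftynet.engine.sampler_resize_v2 import ResizeSampler
    from niftynet.engine.windows_aggregator_identity import WindowAsImageAggregator
    from niftynet.engine.windows_aggregator_resize import ResizeSamplesAggregator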
    def test_2d_init(self):
        reader = get_2d_reader()
        sampler = ResizeSampler(reader=reader,
                                window_sizes=MOD_2D_DATA,
                                batch_size=1,
                                shuffle=False,
                                queue_length=50)
        aggregator = ResizeSamplesAggregator(
            image_reader=reader,
            name='image',
            output_path=os.path.join('testing_data', 'aggregated'),
            interp_order=3)
        more_batch = True

        with self.test_session() as sess:
            sampler.set_num_threads(2)
            while more_batch:
                try:
                    out = sess.run(sampler.pop_batch_op())
                except tf.errors.OutOfRangeError:
                    break
                more_batch = aggregator.decode_batch(
                    out['image'], out['image_location'])
        output_filename = '{}_niftynet_out.nii.gz'.format(
            sampler.reader.get_subject_id(0))
        output_file = os.path.join('testing_data',
                                   'aggregated',
                                   output_filename)
        self.assertAllClose(
            nib.load(output_file).shape, [128, 128, 1, 1, 1])
        sampler.close_all()
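The control flow above is the pattern every aggregator example below follows: pop_batch_op() is drained inside the session until the queue raises tf.errors.OutOfRangeError, and decode_batch() returns a flag that turns False once the aggregator has received its last window, ending the while loop.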
Example #2
    def test_25d_init(self):
        reader = get_25d_reader()
        sampler = ResizeSampler(reader=reader,
                                window_sizes=SINGLE_25D_DATA,
                                batch_size=1,
                                shuffle=False,
                                queue_length=50)
        aggregator = WindowAsImageAggregator(
            image_reader=reader,
            output_path=os.path.join('testing_data', 'aggregated_identity'))
        more_batch = True
        out_shape = []
        with self.cached_session() as sess:
            sampler.set_num_threads(2)
            while more_batch:
                try:
                    out = sess.run(sampler.pop_batch_op())
                    out_shape = out['image'].shape[1:] + (1,)
                except tf.errors.OutOfRangeError:
                    break
                more_batch = aggregator.decode_batch(
                    {'window_image': out['image']}, out['image_location'])
        output_filename = '{}_window_image_niftynet_generated.nii.gz'.format(
            sampler.reader.get_subject_id(0))
        output_file = os.path.join('testing_data', 'aggregated_identity',
                                   output_filename)
        out_shape = [out_shape[i] for i in NEW_ORDER_2D] + [1]
        self.assertAllClose(nib.load(output_file).shape, out_shape[:2])
        sampler.close_all()
Example #3
    def initialise_resize_sampler(self):
        self.sampler = [[
            ResizeSampler(
                reader=reader,
                window_sizes=self.data_param,
                batch_size=self.net_param.batch_size,
                shuffle=self.is_training,
                smaller_final_batch_mode=self.net_param.smaller_final_batch_mode,
                queue_length=self.net_param.queue_length)
            for reader in self.readers
        ]]
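Note the double brackets: self.sampler ends up as a list containing one inner list with a ResizeSampler per reader, the same nested structure the initialise_sampler variants in Examples #9 and #10 build with append.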
Example #4
    def test_init_2d_mo_bidimcsv(self):
        reader = get_2d_reader()
        sampler = ResizeSampler(reader=reader,
                                window_sizes=MOD_2D_DATA,
                                batch_size=1,
                                shuffle=False,
                                queue_length=50)
        aggregator = WindowAsImageAggregator(
            image_reader=reader,
            output_path=os.path.join('testing_data', 'aggregated_identity'),
        )
        more_batch = True
        out_shape = []
        with self.cached_session() as sess:
            sampler.set_num_threads(2)
            while more_batch:
                try:
                    out = sess.run(sampler.pop_batch_op())
                    out_shape = out['image'].shape[1:] + (1, )
                except tf.errors.OutOfRangeError:
                    break
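                # note: despite its name, min_val holds the sum of all
                # window values; it is written out under the 'csv_sum' key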
                min_val = np.sum((np.asarray(out['image']).flatten()))
                stats_val = [
                    np.min(out['image']),
                    np.max(out['image']),
                    np.sum(out['image'])
                ]
                stats_val = np.expand_dims(stats_val, 0)
                stats_val = np.concatenate([stats_val, stats_val], axis=0)
                more_batch = aggregator.decode_batch(
                    {
                        'window_image': out['image'],
                        'csv_sum': min_val,
                        'csv_stats2d': stats_val
                    }, out['image_location'])
        output_filename = '{}_window_image_niftynet_generated.nii.gz'.format(
            sampler.reader.get_subject_id(0))
        sum_filename = os.path.join(
            'testing_data', 'aggregated_identity',
            '{}_csv_sum_niftynet_generated.csv'.format(
                sampler.reader.get_subject_id(0)))
        stats_filename = os.path.join(
            'testing_data', 'aggregated_identity',
            '{}_csv_stats2d_niftynet_generated.csv'.format(
                sampler.reader.get_subject_id(0)))
        output_file = os.path.join('testing_data', 'aggregated_identity',
                                   output_filename)

        out_shape = [out_shape[i] for i in NEW_ORDER_2D] + [1]
        self.assertAllClose(nib.load(output_file).shape, out_shape[:2])
        min_pd = pd.read_csv(sum_filename)
        self.assertAllClose(min_pd.shape, [1, 2])
        stats_pd = pd.read_csv(stats_filename)
        self.assertAllClose(stats_pd.shape, [1, 7])
        sampler.close_all()
Example #5
    def test_2d_init(self):
        sampler = ResizeSampler(reader=get_2d_reader(),
                                window_sizes=MOD_2D_DATA,
                                batch_size=1,
                                shuffle=True,
                                queue_length=1)
        with self.cached_session() as sess:
            sampler.set_num_threads(2)
            out = sess.run(sampler.pop_batch_op())
            self.assertAllClose(out['image'].shape, [1, 10, 9, 1])
        sampler.close_all()
Example #6
    def test_dynamic_init(self):
        sampler = ResizeSampler(reader=get_dynamic_window_reader(),
                                window_sizes=DYNAMIC_MOD_DATA,
                                batch_size=1,
                                shuffle=False,
                                queue_length=1)
        with self.cached_session() as sess:
            sampler.set_num_threads(2)
            out = sess.run(sampler.pop_batch_op())
            self.assertAllClose(out['image'].shape, [1, 8, 2, 256, 2])
        sampler.close_all()
Example #7
    def test_3d_init(self):
        sampler = ResizeSampler(reader=get_3d_reader(),
                                window_sizes=MULTI_MOD_DATA,
                                batch_size=1,
                                shuffle=False,
                                queue_length=1)
        with self.test_session() as sess:
            sampler.set_num_threads(2)
            out = sess.run(sampler.pop_batch_op())
            self.assertAllClose(out['image'].shape, [1, 7, 10, 2, 2])
        sampler.close_all()
Example #8
    def test_init_3d_mo_bidimcsv(self):
        reader = get_3d_reader()
        sampler = ResizeSampler(reader=reader,
                                window_sizes=MULTI_MOD_DATA,
                                batch_size=1,
                                shuffle=False,
                                queue_length=50)
        aggregator = ResizeSamplesAggregator(image_reader=reader,
                                             name='image',
                                             output_path=os.path.join(
                                                 'testing_data', 'aggregated'),
                                             interp_order=3)
        more_batch = True

        with self.cached_session() as sess:
            sampler.set_num_threads(2)
            while more_batch:
                try:
                    out = sess.run(sampler.pop_batch_op())
                except tf.errors.OutOfRangeError:
                    break
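                # note: despite its name, min_val holds the sum of all
                # window values; it is written out under the 'csv_sum' key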
                min_val = np.sum((np.asarray(out['image']).flatten()))
                stats_val = [
                    np.min(out['image']),
                    np.max(out['image']),
                    np.sum(out['image'])
                ]
                stats_val = np.expand_dims(stats_val, 0)
                stats_val = np.concatenate([stats_val, stats_val], axis=0)
                more_batch = aggregator.decode_batch(
                    {
                        'window_image': out['image'],
                        'csv_sum': min_val,
                        'csv_stats_2d': stats_val
                    }, out['image_location'])
        output_filename = 'window_image_{}_niftynet_out.nii.gz'.format(
            sampler.reader.get_subject_id(0))
        sum_filename = os.path.join(
            'testing_data', 'aggregated', 'csv_sum_{}_niftynet_out.csv'.format(
                sampler.reader.get_subject_id(0)))
        stats_filename = os.path.join(
            'testing_data', 'aggregated',
            'csv_stats_2d_{}_niftynet_out.csv'.format(
                sampler.reader.get_subject_id(0)))
        output_file = os.path.join('testing_data', 'aggregated',
                                   output_filename)

        self.assertAllClose(nib.load(output_file).shape, (256, 168, 256, 1, 2))
        min_pd = pd.read_csv(sum_filename)
        self.assertAllClose(min_pd.shape, [1, 2])
        stats_pd = pd.read_csv(stats_filename)
        self.assertAllClose(stats_pd.shape, [1, 7])
        sampler.close_all()
Example #9
    def initialise_sampler(self):
        self.sampler = []
        if self.is_training:
            self.sampler.append([
                ResizeSampler(reader=reader,
                              csv_reader=None,
                              window_sizes=self.data_param,
                              batch_size=self.net_param.batch_size,
                              windows_per_image=1,
                              shuffle=True,
                              queue_length=self.net_param.queue_length)
                for reader in self.readers
            ])
            return
        if self._infer_type in ('encode', 'encode-decode'):
            self.sampler.append([
                ResizeSampler(reader=reader,
                              csv_reader=None,
                              window_sizes=self.data_param,
                              batch_size=self.net_param.batch_size,
                              windows_per_image=1,
                              shuffle=False,
                              queue_length=self.net_param.queue_length)
                for reader in self.readers
            ])
            return
        if self._infer_type == 'linear_interpolation':
            self.sampler.append([
                LinearInterpolateSampler(
                    reader=reader,
                    csv_reader=None,
                    window_sizes=self.data_param,
                    batch_size=self.net_param.batch_size,
                    n_interpolations=self.autoencoder_param.n_interpolations,
                    queue_length=self.net_param.queue_length)
                for reader in self.readers
            ])
            return
Example #10
    def initialise_sampler(self):
        self.sampler = []
        if self.is_training:
            self.sampler.append([
                ResizeSampler(reader=reader,
                              window_sizes=self.data_param,
                              batch_size=self.net_param.batch_size,
                              windows_per_image=1,
                              shuffle=True,
                              queue_length=self.net_param.queue_length)
                for reader in self.readers
            ])
        else:
            self.sampler.append([
                RandomVectorSampler(
                    names=('vector',),
                    vector_size=(self.gan_param.noise_size,),
                    batch_size=self.net_param.batch_size,
                    n_interpolations=self.gan_param.n_interpolations,
                    repeat=None,
                    queue_length=self.net_param.queue_length)
                for _ in self.readers
            ])
            # repeat each resized image n times, so that each
            # image matches one random vector
            # (n = self.gan_param.n_interpolations)
            self.sampler.append([
                ResizeSampler(
                    reader=reader,
                    window_sizes=self.data_param,
                    batch_size=self.net_param.batch_size,
                    windows_per_image=self.gan_param.n_interpolations,
                    shuffle=False,
                    queue_length=self.net_param.queue_length)
                for reader in self.readers
            ])
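A minimal sketch of how a nested sampler list like the ones built above could be drained; this is hypothetical driver code (drain_samplers and the num_threads choice are my own), not NiftyNet's actual engine loop, and it uses only the sampler calls the examples above rely on:

    def drain_samplers(sampler_groups, sess):
        # sampler_groups mirrors self.sampler: a list of lists,
        # with one sampler per image reader in each inner list
        for group in sampler_groups:
            for sampler in group:
                sampler.set_num_threads(2)
                try:
                    while True:
                        batch = sess.run(sampler.pop_batch_op())
                        # ... feed `batch` to a network or an aggregator here
                except tf.errors.OutOfRangeError:
                    pass  # queue exhausted: every window has been consumed
                finally:
                    sampler.close_all()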