Example #1
 def test_name_mismatch(self):
     with self.assertRaisesRegexp(ValueError, ""):
         sampler = GridSampler(reader=get_dynamic_window_reader(),
                               window_sizes=MOD_2D_DATA,
                               batch_size=10,
                               spatial_window_size=None,
                               window_border=(0, 0, 0),
                               queue_length=10)
     with self.assertRaisesRegexp(ValueError, ""):
         sampler = GridSampler(reader=get_3d_reader(),
                               window_sizes=MOD_2D_DATA,
                               batch_size=10,
                               spatial_window_size=None,
                               window_border=(0, 0, 0),
                               queue_length=10)
Example #2
    def test_filling(self):
        reader = get_nonnormalising_label_reader()
        test_constant = 0.5731
        postfix = '_niftynet_out_background'
        test_border = (10, 7, 8)
        data_param = MOD_LABEL_DATA
        sampler = GridSampler(reader=reader,
                              window_sizes=data_param,
                              batch_size=10,
                              spatial_window_size=None,
                              window_border=test_border,
                              queue_length=50)
        aggregator = GridSamplesAggregator(image_reader=reader,
                                           name='label',
                                           output_path=os.path.join(
                                               'testing_data', 'aggregated'),
                                           window_border=test_border,
                                           interp_order=0,
                                           postfix=postfix,
                                           fill_constant=test_constant)
        more_batch = True
        with self.test_session() as sess:
            sampler.set_num_threads(2)
            while more_batch:
                out = sess.run(sampler.pop_batch_op())
                more_batch = aggregator.decode_batch(out['label'],
                                                     out['label_location'])
        output_filename = '{}{}.nii.gz'.format(
            sampler.reader.get_subject_id(0), postfix)
        output_file = os.path.join('testing_data', 'aggregated',
                                   output_filename)
        output_data = nib.load(output_file).get_data()[..., 0, 0]
        output_shape = output_data.shape
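        # descriptive note: the border of width test_border[i] at both ends of each
        # spatial axis should contain only the fill_constant value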
        for i in range(3):

            def _test_background(idcs):
                extract = output_data[idcs]
                self.assertTrue(
                    (extract == test_constant).sum() == extract.size)

            extract_idcs = [slice(None)] * 3

            extract_idcs[i] = slice(0, test_border[i])
            _test_background(tuple(extract_idcs))

            extract_idcs[i] = slice(output_shape[i] - test_border[i],
                                    output_shape[i])
            _test_background(tuple(extract_idcs))
Example #3
def get_sampler(image_reader,
                patch_size,
                phase,
                windows_per_image=None,
                window_border=None):
    if phase in ('training', 'validation'):
        if windows_per_image:
            sampler = WeightedSampler(image_reader,
                                      window_sizes=patch_size,
                                      windows_per_image=windows_per_image)
        else:
            raise Exception('Invalid windows per image!')

    elif phase == 'inference':
        if window_border:
            sampler = GridSampler(image_reader,
                                  window_sizes=patch_size,
                                  window_border=window_border)
        else:
            raise Exception('Invalid window border!')

    else:
        raise Exception('Invalid phase choice: {}'.format(
            {'phase': ['training', 'validation', 'inference']}))
    return sampler
Example #4
 def test_inverse_mapping(self):
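     # grid windows re-assembled by the aggregator (interp_order=0) should
     # reconstruct the original label volume, compared against the reference file below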
     reader = get_label_reader()
     data_param = MOD_LABEL_DATA
     sampler = GridSampler(reader=reader,
                           window_sizes=data_param,
                           batch_size=10,
                           spatial_window_size=None,
                           window_border=(3, 4, 5),
                           queue_length=50)
     aggregator = GridSamplesAggregator(image_reader=reader,
                                        name='label',
                                        output_path=os.path.join(
                                            'testing_data', 'aggregated'),
                                        window_border=(3, 4, 5),
                                        interp_order=0)
     more_batch = True
     with self.test_session() as sess:
         sampler.set_num_threads(2)
         while more_batch:
             out = sess.run(sampler.pop_batch_op())
             more_batch = aggregator.decode_batch(out['label'],
                                                  out['label_location'])
     output_filename = '{}_niftynet_out.nii.gz'.format(
         sampler.reader.get_subject_id(0))
     output_file = os.path.join('testing_data', 'aggregated',
                                output_filename)
     self.assertAllClose(nib.load(output_file).shape, [256, 168, 256, 1, 1])
     sampler.close_all()
     output_data = nib.load(output_file).get_data()[..., 0, 0]
     expected_data = nib.load(
         'testing_data/T1_1023_NeuroMorph_Parcellation.nii.gz').get_data()
     self.assertAllClose(output_data, expected_data)
Example #5
 def test_25d_init(self):
     reader = get_25d_reader()
     sampler = GridSampler(reader=reader,
                           window_sizes=SINGLE_25D_DATA,
                           batch_size=10,
                           spatial_window_size=None,
                           window_border=(3, 4, 5),
                           queue_length=50)
     aggregator = GridSamplesAggregator(image_reader=reader,
                                        name='image',
                                        output_path=os.path.join(
                                            'testing_data', 'aggregated'),
                                        window_border=(3, 4, 5),
                                        interp_order=0)
     more_batch = True
     with self.test_session() as sess:
         sampler.set_num_threads(2)
         while more_batch:
             out = sess.run(sampler.pop_batch_op())
             more_batch = aggregator.decode_batch(out['image'],
                                                  out['image_location'])
     output_filename = '{}_niftynet_out.nii.gz'.format(
         sampler.reader.get_subject_id(0))
     output_file = os.path.join('testing_data', 'aggregated',
                                output_filename)
     self.assertAllClose(nib.load(output_file).shape, [256, 168, 256, 1, 1],
                         rtol=1e-03,
                         atol=1e-03)
     sampler.close_all()
Example #6
 def initialise_grid_sampler(self):
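     # build one GridSampler per image reader, using the configured window,
     # batch-size, border and queue-length parameters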
     self.sampler = [[GridSampler(
         reader=reader,
         window_sizes=self.data_param,
         batch_size=self.net_param.batch_size,
         spatial_window_size=self.action_param.spatial_window_size,
         window_border=self.action_param.border,
         queue_length=self.net_param.queue_length) for reader in
         self.readers]]
Example #7
    def test_init_2d_mo_bidimcsv(self):
        reader = get_2d_reader()
        sampler = GridSampler(reader=reader,
                              window_sizes=MOD_2D_DATA,
                              batch_size=10,
                              spatial_window_size=None,
                              window_border=(3, 4, 5),
                              queue_length=50)
        aggregator = GridSamplesAggregator(image_reader=reader,
                                           name='image',
                                           output_path=os.path.join(
                                               'testing_data', 'aggregated'),
                                           window_border=(3, 4, 5),
                                           interp_order=0)
        more_batch = True

        with self.cached_session() as sess:
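            # alongside the image windows, a per-window sum ('csv_sum') and per-window
            # min/max/sum statistics ('csv_stats_2d') are passed to the aggregator,
            # which writes them out as CSV files next to the NIfTI output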
            sampler.set_num_threads(2)
            while more_batch:
                out = sess.run(sampler.pop_batch_op())
                out_flatten = np.reshape(np.asarray(out['image']), [10, -1])
                min_val = np.sum(
                    np.reshape(np.asarray(out['image']), [10, -1]), 1)
                stats_val = np.concatenate([
                    np.min(out_flatten, 1, keepdims=True),
                    np.max(out_flatten, 1, keepdims=True),
                    np.sum(out_flatten, 1, keepdims=True)
                ], 1)
                stats_val = np.expand_dims(stats_val, 1)
                stats_val = np.concatenate([stats_val, stats_val], axis=1)
                more_batch = aggregator.decode_batch(
                    {
                        'window_image': out['image'],
                        'csv_sum': min_val,
                        'csv_stats_2d': stats_val
                    }, out['image_location'])
        output_filename = 'window_image_{}_niftynet_out.nii.gz'.format(
            sampler.reader.get_subject_id(0))
        sum_filename = os.path.join(
            'testing_data', 'aggregated', 'csv_sum_{}_niftynet_out.csv'.format(
                sampler.reader.get_subject_id(0)))
        stats_filename = os.path.join(
            'testing_data', 'aggregated',
            'csv_stats_2d_{}_niftynet_out.csv'.format(
                sampler.reader.get_subject_id(0)))
        output_file = os.path.join('testing_data', 'aggregated',
                                   output_filename)

        self.assertAllClose(nib.load(output_file).shape, (128, 128))
        min_pd = pd.read_csv(sum_filename)
        self.assertAllClose(min_pd.shape, [10, 9])
        stats_pd = pd.read_csv(stats_filename)
        self.assertAllClose(stats_pd.shape, [10, 14])
        sampler.close_all()
Example #8
 def test_dynamic_window_initialising(self):
     sampler = GridSampler(reader=get_dynamic_window_reader(),
                           window_sizes=DYNAMIC_MOD_DATA,
                           batch_size=10,
                           spatial_window_size=None,
                           window_border=(0, 0, 0),
                           queue_length=10)
     with self.cached_session() as sess:
         sampler.set_num_threads(1)
         out = sess.run(sampler.pop_batch_op())
         self.assertAllClose(out['image'].shape, (10, 8, 2, 256, 2))
     sampler.close_all()
Example #9
 def test_25d_initialising(self):
     sampler = GridSampler(reader=get_3d_reader(),
                           window_sizes=MULTI_MOD_DATA,
                           batch_size=10,
                           spatial_window_size=(1, 20, 15),
                           window_border=(0, 0, 0),
                           queue_length=10)
     with self.cached_session() as sess:
         sampler.set_num_threads(2)
         out = sess.run(sampler.pop_batch_op())
         self.assertAllClose(out['image'].shape, (10, 20, 15, 2))
     sampler.close_all()
Example #10
 def test_2d_initialising(self):
     sampler = GridSampler(reader=get_2d_reader(),
                           window_sizes=MOD_2D_DATA,
                           batch_size=10,
                           spatial_window_size=None,
                           window_border=(0, 0, 0),
                           queue_length=10)
     with self.test_session() as sess:
         sampler.set_num_threads(1)
         out = sess.run(sampler.pop_batch_op())
         self.assertAllClose(out['image'].shape, (10, 10, 7, 1))
     sampler.close_all()
Example #11
def get_sampler(image_reader, patch_size, phase):
    if phase in ('training', 'validation'):
        sampler = UniformSampler(image_reader,
                                 window_sizes=patch_size,
                                 windows_per_image=2)
    elif phase == 'inference':
        sampler = GridSampler(image_reader,
                              window_sizes=patch_size,
                              window_border=(8, 8, 8),
                              batch_size=1)
    else:
        raise Exception('Invalid phase choice: {}'.format(
            {'phase': ['train', 'validation', 'inference']}))

    return sampler
Example #12
    def test_3d_init_mo(self):
        reader = get_3d_reader()
        sampler = GridSampler(reader=reader,
                              window_sizes=MULTI_MOD_DATA,
                              batch_size=10,
                              spatial_window_size=None,
                              window_border=(3, 4, 5),
                              queue_length=50)
        aggregator = GridSamplesAggregator(image_reader=reader,
                                           name='image',
                                           output_path=os.path.join(
                                               'testing_data', 'aggregated'),
                                           window_border=(3, 4, 5),
                                           interp_order=0)
        more_batch = True

        with self.cached_session() as sess:
            sampler.set_num_threads(2)
            while more_batch:
                out = sess.run(sampler.pop_batch_op())
                out_flatten = np.reshape(np.asarray(out['image']), [10, -1])
                min_val = np.sum(
                    np.reshape(np.asarray(out['image']), [10, -1]), 1)
                more_batch = aggregator.decode_batch(
                    {
                        'window_image': out['image'],
                        'csv_sum': min_val
                    }, out['image_location'])
        output_filename = 'window_image_{}_niftynet_out.nii.gz'.format(
            sampler.reader.get_subject_id(0))
        sum_filename = os.path.join(
            'testing_data', 'aggregated', 'csv_sum_{}_niftynet_out.csv'.format(
                sampler.reader.get_subject_id(0)))
        output_file = os.path.join('testing_data', 'aggregated',
                                   output_filename)

        self.assertAllClose(nib.load(output_file).shape, (256, 168, 256, 1, 2))
        min_pd = pd.read_csv(sum_filename)
        self.assertAllClose(min_pd.shape, [420, 9])
        sampler.close_all()
Example #13
uniform_sampler = UniformSampler(reader,
                                 spatial_window_size,
                                 windows_per_image=100)
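# each sess.run of pop_batch_op() returns a dict with the sampled windows and
# their spatial coordinates (here under 'MR_location')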
next_window = uniform_sampler.pop_batch_op()
coords = []
with tf.Session() as sess:
    for _ in range(20):
        uniform_windows = sess.run(next_window)
        coords.append(uniform_windows['MR_location'])
coords = np.concatenate(coords, axis=0)
vis_coordinates(image_2d, coords, 'output/uniform.png')

###
# create & show all grid samples
###
grid_sampler = GridSampler(reader, spatial_window_size, window_border=border)
next_grid = grid_sampler.pop_batch_op()
coords = []
with tf.Session() as sess:
    while True:
        window = sess.run(next_grid)
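        # the sampler marks the end of the grid with a window location of -1;
        # stop collecting coordinates once it appears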
        if window['MR_location'][0, 0] == -1:
            break
        coords.append(window['MR_location'])
coords = np.concatenate(coords, axis=0)
vis_coordinates(image_2d, coords, 'output/grid.png')

###
# create & show cropped grid samples (in aggregator)
###
n_window = coords.shape[0]