Example #1
    def initialise_resize_aggregator(self):
        self.output_decoder = ResizeSamplesAggregator(
            image_reader=self.readers[0],
            output_path=self.action_param.save_seg_dir,
            window_border=self.action_param.border,
            interp_order=self.action_param.output_interp_order,
            postfix=self.action_param.output_postfix)
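
All of the test examples below share the same drive loop: a ResizeSampler pops resized windows, and ResizeSamplesAggregator.decode_batch writes each decoded window back to disk, returning False once every subject has been handled. The following is a minimal sketch of that loop, assuming a hypothetical get_reader() helper and DATA_PARAM window specification in place of the test fixtures used below:

import os
import tensorflow as tf

reader = get_reader()  # hypothetical: a configured ImageReader
sampler = ResizeSampler(reader=reader,
                        window_sizes=DATA_PARAM,  # hypothetical window spec
                        batch_size=1,
                        shuffle=False,
                        queue_length=50)
aggregator = ResizeSamplesAggregator(
    image_reader=reader,
    name='image',
    output_path=os.path.join('testing_data', 'aggregated'),
    interp_order=3)

with tf.Session() as sess:
    sampler.set_num_threads(1)
    more_batch = True
    while more_batch:
        try:
            out = sess.run(sampler.pop_batch_op())
        except tf.errors.OutOfRangeError:
            break
        # decode_batch returns False once all subjects are written out
        more_batch = aggregator.decode_batch(out['image'],
                                             out['image_location'])
sampler.close_all()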
Example #2
    def test_25d_init(self):
        reader = get_25d_reader()
        sampler = ResizeSampler(reader=reader,
                                window_sizes=SINGLE_25D_DATA,
                                batch_size=1,
                                shuffle=False,
                                queue_length=50)
        aggregator = ResizeSamplesAggregator(
            image_reader=reader,
            name='image',
            output_path=os.path.join('testing_data', 'aggregated'),
            interp_order=3)
        more_batch = True

        with self.test_session() as sess:
            sampler.set_num_threads(2)
            while more_batch:
                try:
                    out = sess.run(sampler.pop_batch_op())
                except tf.errors.OutOfRangeError:
                    break
                more_batch = aggregator.decode_batch(
                    out['image'], out['image_location'])
        output_filename = '{}_niftynet_out.nii.gz'.format(
            sampler.reader.get_subject_id(0))
        output_file = os.path.join('testing_data',
                                   'aggregated',
                                   output_filename)
        self.assertAllClose(
            nib.load(output_file).shape, [255, 168, 256, 1, 1],
            rtol=1e-03, atol=1e-03)
        sampler.close_all()
Example #3
    def test_inverse_mapping(self):
        reader = get_label_reader()
        sampler = ResizeSampler(reader=reader,
                                data_param=MOD_LABEL_DATA,
                                batch_size=1,
                                shuffle_buffer=False,
                                queue_length=50)
        aggregator = ResizeSamplesAggregator(image_reader=reader,
                                             name='label',
                                             output_path=os.path.join(
                                                 'testing_data', 'aggregated'),
                                             interp_order=0)
        more_batch = True

        with self.test_session() as sess:
            coordinator = tf.train.Coordinator()
            sampler.run_threads(sess, coordinator, num_threads=2)
            while more_batch:
                out = sess.run(sampler.pop_batch_op())
                more_batch = aggregator.decode_batch(out['label'],
                                                     out['label_location'])
        output_filename = '{}_niftynet_out.nii.gz'.format(
            sampler.reader.get_subject_id(0))
        output_file = os.path.join('testing_data', 'aggregated',
                                   output_filename)
        self.assertAllClose(nib.load(output_file).shape, [256, 168, 256, 1, 1])
        sampler.close_all()
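
This variant of test_inverse_mapping appears to target an older NiftyNet release: the sampler is configured with data_param/shuffle_buffer and driven through a tf.train.Coordinator via run_threads, with no OutOfRangeError handling. Example #4 below is the same test ported to the newer tf.data-based sampler API (window_sizes/shuffle, set_num_threads, and a loop terminated by OutOfRangeError).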
Example #4
    def test_inverse_mapping(self):
        reader = get_label_reader()
        sampler = ResizeSampler(reader=reader,
                                window_sizes=MOD_LABEL_DATA,
                                batch_size=1,
                                shuffle=False,
                                queue_length=50)
        aggregator = ResizeSamplesAggregator(image_reader=reader,
                                             name='label',
                                             output_path=os.path.join(
                                                 'testing_data', 'aggregated'),
                                             interp_order=0)
        more_batch = True

        with self.cached_session() as sess:
            sampler.set_num_threads(2)
            while more_batch:
                try:
                    out = sess.run(sampler.pop_batch_op())
                except tf.errors.OutOfRangeError:
                    break
                more_batch = aggregator.decode_batch(
                    {'window_label': out['label']}, out['label_location'])
        output_filename = 'window_label_{}_niftynet_out.nii.gz'.format(
            sampler.reader.get_subject_id(0))
        output_file = os.path.join('testing_data', 'aggregated',
                                   output_filename)
        self.assertAllClose(nib.load(output_file).shape, [256, 168, 256])
        sampler.close_all()
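
When decode_batch is given a dict rather than a bare array, the dict key becomes the output filename prefix: {'window_label': ...} above produces window_label_<subject>_niftynet_out.nii.gz, whereas the bare-array call in Example #2 produced an unprefixed filename. A hypothetical illustration of the naming pattern implied by these tests (the 'niftynet_out' postfix is inferred from the assertions, not from the aggregator's signature):

# pattern: '<dict key>_<subject id>_<postfix>.nii.gz'
subject_id = sampler.reader.get_subject_id(0)
expected = 'window_label_{}_niftynet_out.nii.gz'.format(subject_id)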
Example #5
    def test_init_2d_mo_bidimcsv(self):
        reader = get_2d_reader()
        sampler = ResizeSampler(reader=reader,
                                window_sizes=MOD_2D_DATA,
                                batch_size=1,
                                shuffle=False,
                                queue_length=50)
        aggregator = ResizeSamplesAggregator(image_reader=reader,
                                             name='image',
                                             output_path=os.path.join(
                                                 'testing_data', 'aggregated'),
                                             interp_order=3)
        more_batch = True

        with self.cached_session() as sess:
            sampler.set_num_threads(2)
            while more_batch:
                try:
                    out = sess.run(sampler.pop_batch_op())
                except tf.errors.OutOfRangeError:
                    break
                sum_val = np.sum(out['image'])
                stats_val = [
                    np.min(out['image']),
                    np.max(out['image']),
                    np.sum(out['image'])
                ]
                stats_val = np.expand_dims(stats_val, 0)
                stats_val = np.concatenate([stats_val, stats_val], axis=0)
                more_batch = aggregator.decode_batch(
                    {
                        'window_image': out['image'],
                        'csv_sum': sum_val,
                        'csv_stats_2d': stats_val
                    }, out['image_location'])
        output_filename = 'window_image_{}_niftynet_out.nii.gz'.format(
            sampler.reader.get_subject_id(0))
        sum_filename = os.path.join(
            'testing_data', 'aggregated', 'csv_sum_{}_niftynet_out.csv'.format(
                sampler.reader.get_subject_id(0)))
        stats_filename = os.path.join(
            'testing_data', 'aggregated',
            'csv_stats_2d_{}_niftynet_out.csv'.format(
                sampler.reader.get_subject_id(0)))
        output_file = os.path.join('testing_data', 'aggregated',
                                   output_filename)

        self.assertAllClose(nib.load(output_file).shape, (128, 128))
        sum_pd = pd.read_csv(sum_filename)
        self.assertAllClose(sum_pd.shape, [1, 2])
        stats_pd = pd.read_csv(stats_filename)
        self.assertAllClose(stats_pd.shape, [1, 7])
        sampler.close_all()
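
Across this example and Example #7 below, the CSV shape checks follow a consistent pattern: the scalar 'csv_sum' yields a (1, 2) frame, three stats yield (1, 4), and the (2, 3) stats array here yields (1, 7), i.e. one row per subject with 1 + values.size columns. This suggests, though the snippets do not show it, an index column followed by the flattened values.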
Example #6
    def initialise_resize_aggregator(self):
        '''
        Define the resize aggregator used for decoding, based on the
        configuration parameters.
        :return:
        '''
        self.output_decoder = ResizeSamplesAggregator(
            image_reader=self.readers[0],
            output_path=self.action_param.save_seg_dir,
            window_border=self.action_param.border,
            interp_order=self.action_param.output_interp_order,
            postfix=self.action_param.output_postfix)
Example #7
    def test_3d_init_mo_3out(self):
        reader = get_3d_reader()
        sampler = ResizeSampler(reader=reader,
                                window_sizes=MULTI_MOD_DATA,
                                batch_size=1,
                                shuffle=False,
                                queue_length=50)
        aggregator = ResizeSamplesAggregator(image_reader=reader,
                                             name='image',
                                             output_path=os.path.join(
                                                 'testing_data', 'aggregated'),
                                             interp_order=3)
        more_batch = True

        with self.cached_session() as sess:
            sampler.set_num_threads(2)
            while more_batch:
                try:
                    out = sess.run(sampler.pop_batch_op())
                except tf.errors.OutOfRangeError:
                    break
                sum_val = np.sum(out['image'])
                stats_val = [
                    np.sum(out['image']),
                    np.min(out['image']),
                    np.max(out['image'])
                ]
                more_batch = aggregator.decode_batch(
                    {
                        'window_image': out['image'],
                        'csv_sum': sum_val,
                        'csv_stats': stats_val
                    }, out['image_location'])
        output_filename = 'window_image_{}_niftynet_out.nii.gz'.format(
            sampler.reader.get_subject_id(0))
        sum_filename = os.path.join(
            'testing_data', 'aggregated', 'csv_sum_{}_niftynet_out.csv'.format(
                sampler.reader.get_subject_id(0)))
        stats_filename = os.path.join(
            'testing_data', 'aggregated',
            'csv_stats_{}_niftynet_out.csv'.format(
                sampler.reader.get_subject_id(0)))
        output_file = os.path.join('testing_data', 'aggregated',
                                   output_filename)
        self.assertAllClose(nib.load(output_file).shape, (256, 168, 256, 1, 2))
        sum_pd = pd.read_csv(sum_filename)
        self.assertAllClose(sum_pd.shape, [1, 2])
        stats_pd = pd.read_csv(stats_filename)
        self.assertAllClose(stats_pd.shape, [1, 4])
        sampler.close_all()
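
The trailing dimension of 2 in the shape assertion presumably reflects the two modalities of MULTI_MOD_DATA being concatenated along the channel axis by the reader; the aggregator resizes the spatial dimensions back to the original (256, 168, 256) grid but leaves channels untouched.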
Example #8
    def connect_data_and_network(self,
                                 outputs_collector=None,
                                 gradients_collector=None):
        def switch_samplers(for_training):
            with tf.name_scope('train' if for_training else 'validation'):
                sampler = self.get_sampler()[0 if for_training else -1]
                return sampler()  # returns (image_windows, locations)

        if self.is_training:
            self.patience = self.action_param.patience
            self.mode = self.action_param.early_stopping_mode
            if self.action_param.validation_every_n > 0:
                sampler_window = \
                    tf.cond(tf.logical_not(self.is_validation),
                            lambda: switch_samplers(True),
                            lambda: switch_samplers(False))
            else:
                sampler_window = switch_samplers(True)

            image_windows, _ = sampler_window
            # image_windows, locations = sampler_window

            # decode channels for moving and fixed images
            image_windows_list = [
                tf.expand_dims(img, axis=-1)
                for img in tf.unstack(image_windows, axis=-1)
            ]
            fixed_image, fixed_label, moving_image, moving_label = \
                image_windows_list

            # estimate ddf
            dense_field = self.net(fixed_image, moving_image)
            if isinstance(dense_field, tuple):
                dense_field = dense_field[0]

            # transform the moving labels
            resampler = ResamplerLayer(interpolation='linear',
                                       boundary='replicate')
            resampled_moving_label = resampler(moving_label, dense_field)

            # compute label loss (foreground only)
            loss_func = LossFunction(n_class=1,
                                     loss_type=self.action_param.loss_type,
                                     softmax=False)
            label_loss = loss_func(prediction=resampled_moving_label,
                                   ground_truth=fixed_label)

            dice_fg = 1.0 - label_loss
            # appending regularisation loss
            total_loss = label_loss
            reg_loss = tf.get_collection('bending_energy')
            if reg_loss:
                total_loss = total_loss + \
                    self.net_param.decay * tf.reduce_mean(reg_loss)

            self.total_loss = total_loss

            # compute training gradients
            with tf.name_scope('Optimiser'):
                optimiser_class = OptimiserFactory.create(
                    name=self.action_param.optimiser)
                self.optimiser = optimiser_class.get_instance(
                    learning_rate=self.action_param.lr)
            grads = self.optimiser.compute_gradients(
                total_loss, colocate_gradients_with_ops=True)
            gradients_collector.add_to_collection(grads)

            metrics_dice = loss_func(
                prediction=tf.to_float(resampled_moving_label >= 0.5),
                ground_truth=tf.to_float(fixed_label >= 0.5))
            metrics_dice = 1.0 - metrics_dice

            # command line output
            outputs_collector.add_to_collection(var=dice_fg,
                                                name='one_minus_data_loss',
                                                collection=CONSOLE)
            # assumes the 'bending_energy' collection is non-empty;
            # total_loss above guards the empty case but this output does not
            outputs_collector.add_to_collection(var=tf.reduce_mean(reg_loss),
                                                name='bending_energy',
                                                collection=CONSOLE)
            outputs_collector.add_to_collection(var=total_loss,
                                                name='total_loss',
                                                collection=CONSOLE)
            outputs_collector.add_to_collection(var=metrics_dice,
                                                name='ave_fg_dice',
                                                collection=CONSOLE)

            # for tensorboard
            outputs_collector.add_to_collection(var=dice_fg,
                                                name='data_loss',
                                                average_over_devices=True,
                                                summary_type='scalar',
                                                collection=TF_SUMMARIES)
            outputs_collector.add_to_collection(var=total_loss,
                                                name='total_loss',
                                                average_over_devices=True,
                                                summary_type='scalar',
                                                collection=TF_SUMMARIES)
            outputs_collector.add_to_collection(
                var=metrics_dice,
                name='averaged_foreground_Dice',
                average_over_devices=True,
                summary_type='scalar',
                collection=TF_SUMMARIES)

            # for visualisation debugging
            # resampled_moving_image = resampler(moving_image, dense_field)
            # outputs_collector.add_to_collection(
            #     var=fixed_image, name='fixed_image',
            #     collection=NETWORK_OUTPUT)
            # outputs_collector.add_to_collection(
            #     var=fixed_label, name='fixed_label',
            #     collection=NETWORK_OUTPUT)
            # outputs_collector.add_to_collection(
            #     var=moving_image, name='moving_image',
            #     collection=NETWORK_OUTPUT)
            # outputs_collector.add_to_collection(
            #     var=moving_label, name='moving_label',
            #     collection=NETWORK_OUTPUT)
            # outputs_collector.add_to_collection(
            #     var=resampled_moving_image, name='resampled_image',
            #     collection=NETWORK_OUTPUT)
            # outputs_collector.add_to_collection(
            #     var=resampled_moving_label, name='resampled_label',
            #     collection=NETWORK_OUTPUT)
            # outputs_collector.add_to_collection(
            #     var=dense_field, name='ddf', collection=NETWORK_OUTPUT)
            # outputs_collector.add_to_collection(
            #     var=locations, name='locations', collection=NETWORK_OUTPUT)

            # outputs_collector.add_to_collection(
            #     var=shift[0], name='a', collection=CONSOLE)
            # outputs_collector.add_to_collection(
            #     var=shift[1], name='b', collection=CONSOLE)
        else:
            image_windows, locations = self.sampler()
            image_windows_list = [
                tf.expand_dims(img, axis=-1)
                for img in tf.unstack(image_windows, axis=-1)
            ]
            fixed_image, fixed_label, moving_image, moving_label = \
                image_windows_list

            dense_field = self.net(fixed_image, moving_image)
            if isinstance(dense_field, tuple):
                dense_field = dense_field[0]

            # transform the moving labels
            resampler = ResamplerLayer(interpolation='linear',
                                       boundary='replicate')
            resampled_moving_image = resampler(moving_image, dense_field)
            resampled_moving_label = resampler(moving_label, dense_field)

            outputs_collector.add_to_collection(var=fixed_image,
                                                name='fixed_image',
                                                collection=NETWORK_OUTPUT)
            outputs_collector.add_to_collection(var=moving_image,
                                                name='moving_image',
                                                collection=NETWORK_OUTPUT)
            outputs_collector.add_to_collection(var=resampled_moving_image,
                                                name='resampled_moving_image',
                                                collection=NETWORK_OUTPUT)
            outputs_collector.add_to_collection(var=resampled_moving_label,
                                                name='resampled_moving_label',
                                                collection=NETWORK_OUTPUT)

            outputs_collector.add_to_collection(var=fixed_label,
                                                name='fixed_label',
                                                collection=NETWORK_OUTPUT)
            outputs_collector.add_to_collection(var=moving_label,
                                                name='moving_label',
                                                collection=NETWORK_OUTPUT)
            #outputs_collector.add_to_collection(
            #    var=dense_field, name='field',
            #    collection=NETWORK_OUTPUT)
            outputs_collector.add_to_collection(var=locations,
                                                name='locations',
                                                collection=NETWORK_OUTPUT)

            self.output_decoder = ResizeSamplesAggregator(
                image_reader=self.readers[0],  # fixed image reader
                name='fixed_image',
                output_path=self.action_param.save_seg_dir,
                interp_order=self.action_param.output_interp_order)
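
In inference mode the aggregator built above is not driven directly; the NiftyNet engine fetches the NETWORK_OUTPUT tensors and hands them back to the application, which forwards them to the decoder. A minimal sketch of what that hook could look like for this application, assuming the engine delivers the collected outputs as a dict keyed by the names registered above (the exact interpret_output keys are an assumption, not taken from this snippet):

    def interpret_output(self, batch_output):
        # During inference, hand each fetched batch to the aggregator;
        # decode_batch returns False once every subject has been resized
        # back to its original grid and written to save_seg_dir.
        if not self.is_training:
            return self.output_decoder.decode_batch(
                batch_output['fixed_image'], batch_output['locations'])
        return True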