コード例 #1
0
 def test_25d_init(self):
     """Grid-sample a 2.5D volume, aggregate all windows, then check the
     shape of the reassembled NIfTI file written to disk."""
     reader = get_25d_reader()
     sampler = GridSampler(reader=reader,
                           window_sizes=SINGLE_25D_DATA,
                           batch_size=10,
                           spatial_window_size=None,
                           window_border=(3, 4, 5),
                           queue_length=50)
     aggregator = GridSamplesAggregator(
         image_reader=reader,
         name='image',
         output_path=os.path.join('testing_data', 'aggregated'),
         window_border=(3, 4, 5),
         interp_order=0)
     more_batch = True
     with self.test_session() as sess:
         coordinator = tf.train.Coordinator()
         sampler.run_threads(sess, coordinator, num_threads=2)
         # decode_batch returns False once every window of every subject
         # has been stitched back into the output volume.
         while more_batch:
             out = sess.run(sampler.pop_batch_op())
             more_batch = aggregator.decode_batch(
                 out['image'], out['image_location'])
     output_filename = '{}_niftynet_out.nii.gz'.format(
         sampler.reader.get_subject_id(0))
     output_file = os.path.join('testing_data',
                                'aggregated',
                                output_filename)
     # shape check only; rtol/atol are harmless here since shapes are ints
     self.assertAllClose(
         nib.load(output_file).shape, [255, 168, 256, 1, 1],
         rtol=1e-03, atol=1e-03)
     sampler.close_all()
コード例 #2
0
 def test_inverse_mapping(self):
     """Aggregate label windows and check the stitched output volume is
     voxel-wise identical to the source parcellation (with interp_order=0
     the sample/aggregate round trip must be exactly invertible)."""
     reader = get_label_reader()
     data_param = MOD_LABEL_DATA
     sampler = GridSampler(reader=reader,
                           window_sizes=data_param,
                           batch_size=10,
                           spatial_window_size=None,
                           window_border=(3, 4, 5),
                           queue_length=50)
     aggregator = GridSamplesAggregator(
         image_reader=reader,
         name='label',
         output_path=os.path.join('testing_data', 'aggregated'),
         window_border=(3, 4, 5),
         interp_order=0)
     more_batch = True
     with self.test_session() as sess:
         coordinator = tf.train.Coordinator()
         sampler.run_threads(sess, coordinator, num_threads=2)
         # decode_batch returns False when all windows have been consumed
         while more_batch:
             out = sess.run(sampler.pop_batch_op())
             more_batch = aggregator.decode_batch(
                 out['label'], out['label_location'])
     output_filename = '{}_niftynet_out.nii.gz'.format(
         sampler.reader.get_subject_id(0))
     output_file = os.path.join(
         'testing_data', 'aggregated', output_filename)
     self.assertAllClose(
         nib.load(output_file).shape, [256, 168, 256, 1, 1])
     sampler.close_all()
     # [..., 0, 0] drops the time and modality axes of the 5D NIfTI
     output_data = nib.load(output_file).get_data()[..., 0, 0]
     expected_data = nib.load(
         'testing_data/T1_1023_NeuroMorph_Parcellation.nii.gz').get_data()
     self.assertAllClose(output_data, expected_data)
コード例 #3
0
 def test_inverse_mapping(self):
     """Sample/aggregate round trip on label data: the reassembled volume
     must equal the original parcellation exactly (interp_order=0).

     NOTE(review): this variant passes the window sizes via the
     ``data_param`` keyword rather than ``window_sizes`` — presumably an
     older GridSampler API; confirm against the installed NiftyNet version.
     """
     reader = get_label_reader()
     data_param = MOD_LABEL_DATA
     sampler = GridSampler(reader=reader,
                           data_param=data_param,
                           batch_size=10,
                           spatial_window_size=None,
                           window_border=(3, 4, 5),
                           queue_length=50)
     aggregator = GridSamplesAggregator(
         image_reader=reader,
         name='label',
         output_path=os.path.join('testing_data', 'aggregated'),
         window_border=(3, 4, 5),
         interp_order=0)
     more_batch = True
     with self.test_session() as sess:
         coordinator = tf.train.Coordinator()
         sampler.run_threads(sess, coordinator, num_threads=2)
         # decode_batch returns False when all windows have been consumed
         while more_batch:
             out = sess.run(sampler.pop_batch_op())
             more_batch = aggregator.decode_batch(
                 out['label'], out['label_location'])
     output_filename = '{}_niftynet_out.nii.gz'.format(
         sampler.reader.get_subject_id(0))
     output_file = os.path.join(
         'testing_data', 'aggregated', output_filename)
     self.assertAllClose(
         nib.load(output_file).shape, [256, 168, 256, 1, 1])
     sampler.close_all()
     # [..., 0, 0] drops the time and modality axes of the 5D NIfTI
     output_data = nib.load(output_file).get_data()[..., 0, 0]
     expected_data = nib.load(
         'testing_data/T1_1023_NeuroMorph_Parcellation.nii.gz').get_data()
     self.assertAllClose(output_data, expected_data)
コード例 #4
0
 def test_25d_init(self):
     """Grid-sample a 2.5D volume, aggregate the windows, and check the
     shape of the NIfTI output written to disk.

     NOTE(review): uses the ``data_param`` keyword where other examples use
     ``window_sizes`` — verify which GridSampler signature is in use.
     """
     reader = get_25d_reader()
     sampler = GridSampler(reader=reader,
                           data_param=SINGLE_25D_DATA,
                           batch_size=10,
                           spatial_window_size=None,
                           window_border=(3, 4, 5),
                           queue_length=50)
     aggregator = GridSamplesAggregator(
         image_reader=reader,
         name='image',
         output_path=os.path.join('testing_data', 'aggregated'),
         window_border=(3, 4, 5),
         interp_order=0)
     more_batch = True
     with self.test_session() as sess:
         coordinator = tf.train.Coordinator()
         sampler.run_threads(sess, coordinator, num_threads=2)
         # decode_batch returns False once all windows have been aggregated
         while more_batch:
             out = sess.run(sampler.pop_batch_op())
             more_batch = aggregator.decode_batch(
                 out['image'], out['image_location'])
     output_filename = '{}_niftynet_out.nii.gz'.format(
         sampler.reader.get_subject_id(0))
     output_file = os.path.join('testing_data',
                                'aggregated',
                                output_filename)
     self.assertAllClose(
         nib.load(output_file).shape, [255, 168, 256, 1, 1],
         rtol=1e-03, atol=1e-03)
     sampler.close_all()
コード例 #5
0
 def test_2d_init(self):
     """Grid-sample a 2D image, aggregate windows under the 'window_image'
     key, and check the shape of the output NIfTI file."""
     reader = get_2d_reader()
     sampler = GridSampler(reader=reader,
                           window_sizes=MOD_2D_DATA,
                           batch_size=10,
                           spatial_window_size=None,
                           window_border=(3, 4, 5),
                           queue_length=50)
     aggregator = GridSamplesAggregator(image_reader=reader,
                                        name='image',
                                        output_path=os.path.join(
                                            'testing_data', 'aggregated'),
                                        window_border=(3, 4, 5),
                                        interp_order=0)
     more_batch = True
     with self.cached_session() as sess:
         sampler.set_num_threads(2)
         # the dict key 'window_image' becomes the output filename prefix
         while more_batch:
             out = sess.run(sampler.pop_batch_op())
             more_batch = aggregator.decode_batch(
                 {'window_image': out['image']}, out['image_location'])
     output_filename = 'window_image_{}_niftynet_out.nii.gz'.format(
         sampler.reader.get_subject_id(0))
     output_file = os.path.join('testing_data', 'aggregated',
                                output_filename)
     self.assertAllClose(nib.load(output_file).shape, [128, 128])
     sampler.close_all()
コード例 #6
0
    def test_init_2d_mo_bidimcsv(self):
        """Aggregate 2D windows plus two per-window CSV outputs (a scalar
        sum and a duplicated min/max/sum stats table), then check the image
        and both CSV files on disk."""
        reader = get_2d_reader()
        sampler = GridSampler(reader=reader,
                              window_sizes=MOD_2D_DATA,
                              batch_size=10,
                              spatial_window_size=None,
                              window_border=(3, 4, 5),
                              queue_length=50)
        aggregator = GridSamplesAggregator(image_reader=reader,
                                           name='image',
                                           output_path=os.path.join(
                                               'testing_data', 'aggregated'),
                                           window_border=(3, 4, 5),
                                           interp_order=0)
        more_batch = True

        with self.cached_session() as sess:
            sampler.set_num_threads(2)
            while more_batch:
                out = sess.run(sampler.pop_batch_op())
                # Flatten each of the 10 windows once and reuse it for every
                # derived statistic (the original recomputed the reshape for
                # the per-window sums; the name 'min_val' also hid that they
                # were sums).
                out_flatten = np.reshape(np.asarray(out['image']), [10, -1])
                window_sums = np.sum(out_flatten, 1)
                stats_val = np.concatenate([
                    np.min(out_flatten, 1, keepdims=True),
                    np.max(out_flatten, 1, keepdims=True),
                    np.sum(out_flatten, 1, keepdims=True)
                ], 1)
                # duplicate the stats row -> shape [10, 2, 3] to exercise a
                # two-dimensional per-window CSV output
                stats_val = np.expand_dims(stats_val, 1)
                stats_val = np.concatenate([stats_val, stats_val], axis=1)
                more_batch = aggregator.decode_batch(
                    {
                        'window_image': out['image'],
                        'csv_sum': window_sums,
                        'csv_stats_2d': stats_val
                    }, out['image_location'])
        output_filename = 'window_image_{}_niftynet_out.nii.gz'.format(
            sampler.reader.get_subject_id(0))
        sum_filename = os.path.join(
            'testing_data', 'aggregated', 'csv_sum_{}_niftynet_out.csv'.format(
                sampler.reader.get_subject_id(0)))
        stats_filename = os.path.join(
            'testing_data', 'aggregated',
            'csv_stats_2d_{}_niftynet_out.csv'.format(
                sampler.reader.get_subject_id(0)))
        output_file = os.path.join('testing_data', 'aggregated',
                                   output_filename)

        self.assertAllClose(nib.load(output_file).shape, (128, 128))
        min_pd = pd.read_csv(sum_filename)
        self.assertAllClose(min_pd.shape, [10, 9])
        stats_pd = pd.read_csv(stats_filename)
        self.assertAllClose(stats_pd.shape, [10, 14])
        sampler.close_all()
コード例 #7
0
    def run_inference(self, window_border, inference_path, checkpoint_path):
        """Run sliding-window inference with the PyTorch model and stitch
        per-window predictions into whole volumes on disk.

        Args:
            window_border: voxels cropped from each window border before
                aggregation.
            inference_path: output directory for aggregated predictions.
            checkpoint_path: file holding the serialized model state dict.
        """

        output = GridSamplesAggregator(
            image_reader=self.samplers[INFERENCE].reader,
            window_border=window_border,
            interp_order=3,
            output_path=inference_path)

        self.model.load_state_dict(torch.load(checkpoint_path))
        self.model.to(self.device)
        self.model.eval()
        for batch_output in self.samplers[INFERENCE]():
            window = batch_output['image']
            # [...,0,:] eliminates time coordinate from NiftyNet Volume
            window = window[..., 0, :]
            # NiftyNet windows are channels-last; PyTorch expects NCDHW
            window = np.transpose(window, (0, 4, 1, 2, 3))
            window = torch.Tensor(window).to(self.device)
            with torch.no_grad():
                outputs = self.model(window)
            outputs = outputs.cpu().numpy()
            # back to channels-last for the aggregator
            outputs = np.transpose(outputs, (0, 2, 3, 4, 1))
            output.decode_batch(outputs, batch_output['image_location'])
コード例 #8
0
def inference(sampler, model, device, pred_path, cp_path):
    """Run sliding-window inference and aggregate window predictions.

    Args:
        sampler: NiftyNet grid sampler; calling it yields window batches.
        model: PyTorch module producing per-window predictions.
        device: torch device to run the model on.
        pred_path: output directory for the aggregated volumes.
        cp_path: checkpoint file holding the model state dict.
    """
    output = GridSamplesAggregator(image_reader=sampler.reader,
                                   window_border=(8, 8, 8),
                                   output_path=pred_path)

    # Checkpoint loading and device placement are loop-invariant: do them
    # once instead of re-reading the same weights for every subject.
    model.load_state_dict(torch.load(cp_path))
    model.to(device)
    model.eval()

    # NOTE(review): iterating sampler() twice in a nested fashion looks
    # suspicious — confirm the sampler is restartable per subject.
    for _ in sampler():  # for each subject
        for batch_output in sampler():  # for each sliding window step
            window = batch_output['image']
            # [...,0,:] eliminates time coordinate from NiftyNet Volume
            window = window[..., 0, :]
            # NiftyNet windows are channels-last; PyTorch expects NCDHW
            window = np.transpose(window, (0, 4, 1, 2, 3))
            window = torch.Tensor(window).to(device)

            with torch.no_grad():
                outputs = model(window)

            outputs = outputs.cpu().numpy()
            # back to channels-last for the aggregator
            outputs = np.transpose(outputs, (0, 2, 3, 4, 1))
            output.decode_batch(outputs.astype(np.float32),
                                batch_output['image_location'])
コード例 #9
0
    def test_filling(self):
        """Check that GridSamplesAggregator's fill_constant paints the
        cropped border of the aggregated volume with the given constant."""
        reader = get_nonnormalising_label_reader()
        test_constant = 0.5731
        postfix = '_niftynet_out_background'
        test_border = (10, 7, 8)
        data_param = MOD_LABEL_DATA
        sampler = GridSampler(reader=reader,
                              window_sizes=data_param,
                              batch_size=10,
                              spatial_window_size=None,
                              window_border=test_border,
                              queue_length=50)
        aggregator = GridSamplesAggregator(image_reader=reader,
                                           name='label',
                                           output_path=os.path.join(
                                               'testing_data', 'aggregated'),
                                           window_border=test_border,
                                           interp_order=0,
                                           postfix=postfix,
                                           fill_constant=test_constant)
        more_batch = True
        with self.test_session() as sess:
            sampler.set_num_threads(2)
            # decode_batch returns False once all windows are aggregated
            while more_batch:
                out = sess.run(sampler.pop_batch_op())
                more_batch = aggregator.decode_batch(out['label'],
                                                     out['label_location'])
        output_filename = '{}{}.nii.gz'.format(
            sampler.reader.get_subject_id(0), postfix)
        output_file = os.path.join('testing_data', 'aggregated',
                                   output_filename)
        # [..., 0, 0] drops the time and modality axes of the 5D NIfTI
        output_data = nib.load(output_file).get_data()[..., 0, 0]
        output_shape = output_data.shape
        # For each spatial axis, check both the leading and the trailing
        # border slab consist entirely of the fill constant.
        for i in range(3):

            def _test_background(idcs):
                # every voxel of the extracted slab equals the constant
                extract = output_data[idcs]
                self.assertTrue(
                    (extract == test_constant).sum() == extract.size)

            extract_idcs = [slice(None)] * 3

            # leading border slab along axis i
            extract_idcs[i] = slice(0, test_border[i])
            _test_background(tuple(extract_idcs))

            # trailing border slab along axis i
            extract_idcs[i] = slice(output_shape[i] - test_border[i],
                                    output_shape[i])
            _test_background(tuple(extract_idcs))
コード例 #10
0
    def test_3d_init_mo(self):
        """Aggregate multi-modal 3D windows plus a per-window CSV sum, then
        check the aggregated NIfTI shape and the CSV table shape."""
        reader = get_3d_reader()
        sampler = GridSampler(reader=reader,
                              window_sizes=MULTI_MOD_DATA,
                              batch_size=10,
                              spatial_window_size=None,
                              window_border=(3, 4, 5),
                              queue_length=50)
        aggregator = GridSamplesAggregator(image_reader=reader,
                                           name='image',
                                           output_path=os.path.join(
                                               'testing_data', 'aggregated'),
                                           window_border=(3, 4, 5),
                                           interp_order=0)
        more_batch = True

        with self.cached_session() as sess:
            sampler.set_num_threads(2)
            while more_batch:
                out = sess.run(sampler.pop_batch_op())
                # Flatten each of the 10 windows once and compute the
                # per-window sums from it (the original recomputed the
                # reshape, and the misleading name 'min_val' held sums).
                out_flatten = np.reshape(np.asarray(out['image']), [10, -1])
                window_sums = np.sum(out_flatten, 1)
                more_batch = aggregator.decode_batch(
                    {
                        'window_image': out['image'],
                        'csv_sum': window_sums
                    }, out['image_location'])
        output_filename = 'window_image_{}_niftynet_out.nii.gz'.format(
            sampler.reader.get_subject_id(0))
        sum_filename = os.path.join(
            'testing_data', 'aggregated', 'csv_sum_{}_niftynet_out.csv'.format(
                sampler.reader.get_subject_id(0)))
        output_file = os.path.join('testing_data', 'aggregated',
                                   output_filename)

        self.assertAllClose(nib.load(output_file).shape, (256, 168, 256, 1, 2))
        min_pd = pd.read_csv(sum_filename)
        self.assertAllClose(min_pd.shape, [420, 9])
        sampler.close_all()
コード例 #11
0
class BRATSApp(BaseApplication):
    """NiftyNet segmentation application for the BRATS challenge.

    Implements the BaseApplication hooks: dataset loading, window
    sampling, network construction, and wiring data into the training or
    inference TensorFlow graph.
    """
    REQUIRED_CONFIG_SECTION = "SEGMENTATION"

    def __init__(self, net_param, action_param, is_training):
        """Store parsed config objects; readers/samplers/network are
        created later by the initialise_* hooks the framework calls."""
        BaseApplication.__init__(self)
        tf.logging.info('starting BRATS segmentation app')
        self.is_training = is_training

        self.net_param = net_param
        self.action_param = action_param

        # populated by initialise_dataset_loader
        self.data_param = None
        self.segmentation_param = None

    def initialise_dataset_loader(
            self, data_param=None, task_param=None, data_partitioner=None):
        """Build ImageReaders for the current phase and attach
        normalisation / padding preprocessing layers to each."""
        self.data_param = data_param
        self.segmentation_param = task_param

        # read each line of csv files into an instance of Subject
        if self.is_training:
            file_lists = []
            if self.action_param.validation_every_n > 0:
                file_lists.append(data_partitioner.train_files)
                file_lists.append(data_partitioner.validation_files)
            else:
                file_lists.append(data_partitioner.all_files)
            self.readers = []
            for file_list in file_lists:
                reader = ImageReader(SUPPORTED_INPUT)
                reader.initialise(data_param, task_param, file_list)
                self.readers.append(reader)
        else:  # in the inference process use image input only
            inference_reader = ImageReader(['image'])
            file_list = data_partitioner.inference_files
            inference_reader.initialise(data_param, task_param, file_list)
            self.readers = [inference_reader]

        # optional mask so normalisation statistics ignore background voxels
        foreground_masking_layer = None
        if self.net_param.normalise_foreground_only:
            foreground_masking_layer = BinaryMaskingLayer(
                type_str=self.net_param.foreground_type,
                multimod_fusion=self.net_param.multimod_foreground_type,
                threshold=0.0)

        mean_var_normaliser = MeanVarNormalisationLayer(
            image_name='image', binary_masking_func=foreground_masking_layer)

        # maps raw label values onto a contiguous id range
        label_normaliser = DiscreteLabelNormalisationLayer(
            image_name='label',
            modalities=vars(task_param).get('label'),
            model_filename=self.net_param.histogram_ref_file)

        normalisation_layers = []
        normalisation_layers.append(mean_var_normaliser)
        if task_param.label_normalisation:
            normalisation_layers.append(label_normaliser)

        volume_padding_layer = []
        if self.net_param.volume_padding_size:
            volume_padding_layer.append(PadLayer(
                image_name=SUPPORTED_INPUT,
                border=self.net_param.volume_padding_size))
        for reader in self.readers:
            reader.add_preprocessing_layers(
                normalisation_layers + volume_padding_layer)

    def initialise_sampler(self):
        """Uniform random window sampling for training; exhaustive grid
        sampling (with a cropped border) for inference."""
        if self.is_training:
            self.sampler = [[UniformSampler(
                reader=reader,
                data_param=self.data_param,
                batch_size=self.net_param.batch_size,
                windows_per_image=self.action_param.sample_per_volume,
                queue_length=self.net_param.queue_length) for reader in
                self.readers]]
        else:
            self.sampler = [[GridSampler(
                reader=reader,
                data_param=self.data_param,
                batch_size=self.net_param.batch_size,
                spatial_window_size=self.action_param.spatial_window_size,
                window_border=self.action_param.border,
                queue_length=self.net_param.queue_length) for reader in
                self.readers]]

    def initialise_network(self):
        """Instantiate the network from the factory, with optional L1/L2
        weight and bias regularisers controlled by reg_type and decay."""
        w_regularizer = None
        b_regularizer = None
        reg_type = self.net_param.reg_type.lower()
        decay = self.net_param.decay
        if reg_type == 'l2' and decay > 0:
            from tensorflow.contrib.layers.python.layers import regularizers
            w_regularizer = regularizers.l2_regularizer(decay)
            b_regularizer = regularizers.l2_regularizer(decay)
        elif reg_type == 'l1' and decay > 0:
            from tensorflow.contrib.layers.python.layers import regularizers
            w_regularizer = regularizers.l1_regularizer(decay)
            b_regularizer = regularizers.l1_regularizer(decay)

        self.net = ApplicationNetFactory.create(self.net_param.name)(
            num_classes=self.segmentation_param.num_classes,
            w_regularizer=w_regularizer,
            b_regularizer=b_regularizer,
            acti_func=self.net_param.activation_function)

    def connect_data_and_network(self,
                                 outputs_collector=None,
                                 gradients_collector=None):
        """Build the TF graph: loss and gradients when training, or the
        post-processed network output plus a grid aggregator for
        inference."""

        def switch_sampler(for_training):
            # choose the train (index 0) or validation (index -1) sampler
            with tf.name_scope('train' if for_training else 'validation'):
                sampler = self.get_sampler()[0][0 if for_training else -1]
                return sampler.pop_batch_op()

        if self.is_training:
            if self.action_param.validation_every_n > 0:
                data_dict = tf.cond(tf.logical_not(self.is_validation),
                                    lambda: switch_sampler(True),
                                    lambda: switch_sampler(False))
            else:
                data_dict = switch_sampler(for_training=True)

            image = tf.cast(data_dict['image'], tf.float32)
            net_out = self.net(image, is_training=self.is_training)

            with tf.name_scope('Optimiser'):
                optimiser_class = OptimiserFactory.create(
                    name=self.action_param.optimiser)
                self.optimiser = optimiser_class.get_instance(
                    learning_rate=self.action_param.lr)
            loss_func = LossFunction(
                n_class=self.segmentation_param.num_classes,
                loss_type=self.action_param.loss_type)
            data_loss = loss_func(
                prediction=net_out,
                ground_truth=data_dict.get('label', None),
                weight_map=data_dict.get('weight', None))
            reg_losses = tf.get_collection(
                tf.GraphKeys.REGULARIZATION_LOSSES)
            if self.net_param.decay > 0.0 and reg_losses:
                reg_loss = tf.reduce_mean(
                    [tf.reduce_mean(reg_loss) for reg_loss in reg_losses])
                loss = data_loss + reg_loss
            else:
                loss = data_loss
            grads = self.optimiser.compute_gradients(loss)
            # collecting gradients variables
            gradients_collector.add_to_collection([grads])
            # collecting output variables
            outputs_collector.add_to_collection(
                var=data_loss, name='dice_loss',
                average_over_devices=False, collection=CONSOLE)
            outputs_collector.add_to_collection(
                var=data_loss, name='dice_loss',
                average_over_devices=True, summary_type='scalar',
                collection=TF_SUMMARIES)
        else:
            # converting logits into final output for
            # classification probabilities or argmax classification labels
            data_dict = switch_sampler(for_training=False)
            image = tf.cast(data_dict['image'], tf.float32)
            net_out = self.net(image, is_training=self.is_training)

            output_prob = self.segmentation_param.output_prob
            num_classes = self.segmentation_param.num_classes
            if output_prob and num_classes > 1:
                post_process_layer = PostProcessingLayer(
                    'SOFTMAX', num_classes=num_classes)
            elif not output_prob and num_classes > 1:
                post_process_layer = PostProcessingLayer(
                    'ARGMAX', num_classes=num_classes)
            else:
                post_process_layer = PostProcessingLayer(
                    'IDENTITY', num_classes=num_classes)
            net_out = post_process_layer(net_out)

            outputs_collector.add_to_collection(
                var=net_out, name='window',
                average_over_devices=False, collection=NETWORK_OUTPUT)
            outputs_collector.add_to_collection(
                var=data_dict['image_location'], name='location',
                average_over_devices=False, collection=NETWORK_OUTPUT)

            # stitches per-window predictions back into full volumes
            self.output_decoder = GridSamplesAggregator(
                image_reader=self.readers[0],
                output_path=self.action_param.save_seg_dir,
                window_border=self.action_param.border,
                interp_order=self.action_param.output_interp_order)

    def interpret_output(self, batch_output):
        """Feed inference windows to the aggregator; the return value
        signals whether more batches are expected."""
        if not self.is_training:
            return self.output_decoder.decode_batch(
                batch_output['window'], batch_output['location'])
        return True
コード例 #12
0
class BRATSApp(BaseApplication):
    """NiftyNet segmentation application for the BRATS challenge
    (reformatted variant of the same application).

    Provides the BaseApplication hooks for dataset loading, window
    sampling, network construction, and graph wiring.
    """
    REQUIRED_CONFIG_SECTION = "SEGMENTATION"

    def __init__(self, net_param, action_param, is_training):
        """Store parsed config objects; everything else is created later
        by the initialise_* hooks the framework calls."""
        BaseApplication.__init__(self)
        tf.logging.info('starting BRATS segmentation app')
        self.is_training = is_training

        self.net_param = net_param
        self.action_param = action_param

        # populated by initialise_dataset_loader
        self.data_param = None
        self.segmentation_param = None

    def initialise_dataset_loader(self,
                                  data_param=None,
                                  task_param=None,
                                  data_partitioner=None):
        """Build ImageReaders for the current phase and attach
        normalisation / padding preprocessing layers to each."""
        self.data_param = data_param
        self.segmentation_param = task_param

        # read each line of csv files into an instance of Subject
        if self.is_training:
            file_lists = []
            if self.action_param.validation_every_n > 0:
                file_lists.append(data_partitioner.train_files)
                file_lists.append(data_partitioner.validation_files)
            else:
                file_lists.append(data_partitioner.all_files)
            self.readers = []
            for file_list in file_lists:
                reader = ImageReader(SUPPORTED_INPUT)
                reader.initialise(data_param, task_param, file_list)
                self.readers.append(reader)
        else:  # in the inference process use image input only
            inference_reader = ImageReader(['image'])
            file_list = data_partitioner.inference_files
            inference_reader.initialise(data_param, task_param, file_list)
            self.readers = [inference_reader]

        # optional mask so normalisation statistics ignore background voxels
        foreground_masking_layer = None
        if self.net_param.normalise_foreground_only:
            foreground_masking_layer = BinaryMaskingLayer(
                type_str=self.net_param.foreground_type,
                multimod_fusion=self.net_param.multimod_foreground_type,
                threshold=0.0)

        mean_var_normaliser = MeanVarNormalisationLayer(
            image_name='image', binary_masking_func=foreground_masking_layer)

        # maps raw label values onto a contiguous id range
        label_normaliser = DiscreteLabelNormalisationLayer(
            image_name='label',
            modalities=vars(task_param).get('label'),
            model_filename=self.net_param.histogram_ref_file)

        normalisation_layers = []
        normalisation_layers.append(mean_var_normaliser)
        if task_param.label_normalisation:
            normalisation_layers.append(label_normaliser)

        volume_padding_layer = []
        if self.net_param.volume_padding_size:
            volume_padding_layer.append(
                PadLayer(image_name=SUPPORTED_INPUT,
                         border=self.net_param.volume_padding_size))
        for reader in self.readers:
            reader.add_preprocessing_layers(normalisation_layers +
                                            volume_padding_layer)

    def initialise_sampler(self):
        """Uniform random window sampling for training; exhaustive grid
        sampling (with a cropped border) for inference."""
        if self.is_training:
            self.sampler = [[
                UniformSampler(
                    reader=reader,
                    data_param=self.data_param,
                    batch_size=self.net_param.batch_size,
                    windows_per_image=self.action_param.sample_per_volume,
                    queue_length=self.net_param.queue_length)
                for reader in self.readers
            ]]
        else:
            self.sampler = [[
                GridSampler(
                    reader=reader,
                    data_param=self.data_param,
                    batch_size=self.net_param.batch_size,
                    spatial_window_size=self.action_param.spatial_window_size,
                    window_border=self.action_param.border,
                    queue_length=self.net_param.queue_length)
                for reader in self.readers
            ]]

    def initialise_network(self):
        """Instantiate the network from the factory, with optional L1/L2
        weight and bias regularisers controlled by reg_type and decay."""
        w_regularizer = None
        b_regularizer = None
        reg_type = self.net_param.reg_type.lower()
        decay = self.net_param.decay
        if reg_type == 'l2' and decay > 0:
            from tensorflow.contrib.layers.python.layers import regularizers
            w_regularizer = regularizers.l2_regularizer(decay)
            b_regularizer = regularizers.l2_regularizer(decay)
        elif reg_type == 'l1' and decay > 0:
            from tensorflow.contrib.layers.python.layers import regularizers
            w_regularizer = regularizers.l1_regularizer(decay)
            b_regularizer = regularizers.l1_regularizer(decay)

        self.net = ApplicationNetFactory.create(self.net_param.name)(
            num_classes=self.segmentation_param.num_classes,
            w_regularizer=w_regularizer,
            b_regularizer=b_regularizer,
            acti_func=self.net_param.activation_function)

    def connect_data_and_network(self,
                                 outputs_collector=None,
                                 gradients_collector=None):
        """Build the TF graph: loss and gradients when training, or the
        post-processed network output plus a grid aggregator for
        inference."""
        def switch_sampler(for_training):
            # choose the train (index 0) or validation (index -1) sampler
            with tf.name_scope('train' if for_training else 'validation'):
                sampler = self.get_sampler()[0][0 if for_training else -1]
                return sampler.pop_batch_op()

        if self.is_training:
            if self.action_param.validation_every_n > 0:
                data_dict = tf.cond(tf.logical_not(self.is_validation),
                                    lambda: switch_sampler(True),
                                    lambda: switch_sampler(False))
            else:
                data_dict = switch_sampler(for_training=True)

            image = tf.cast(data_dict['image'], tf.float32)
            net_out = self.net(image, is_training=self.is_training)

            with tf.name_scope('Optimiser'):
                optimiser_class = OptimiserFactory.create(
                    name=self.action_param.optimiser)
                self.optimiser = optimiser_class.get_instance(
                    learning_rate=self.action_param.lr)
            loss_func = LossFunction(
                n_class=self.segmentation_param.num_classes,
                loss_type=self.action_param.loss_type)
            data_loss = loss_func(prediction=net_out,
                                  ground_truth=data_dict.get('label', None),
                                  weight_map=data_dict.get('weight', None))
            reg_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
            if self.net_param.decay > 0.0 and reg_losses:
                reg_loss = tf.reduce_mean(
                    [tf.reduce_mean(reg_loss) for reg_loss in reg_losses])
                loss = data_loss + reg_loss
            else:
                loss = data_loss
            grads = self.optimiser.compute_gradients(loss)
            # collecting gradients variables
            gradients_collector.add_to_collection([grads])
            # collecting output variables
            outputs_collector.add_to_collection(var=data_loss,
                                                name='dice_loss',
                                                average_over_devices=False,
                                                collection=CONSOLE)
            outputs_collector.add_to_collection(var=data_loss,
                                                name='dice_loss',
                                                average_over_devices=True,
                                                summary_type='scalar',
                                                collection=TF_SUMMARIES)
        else:
            # converting logits into final output for
            # classification probabilities or argmax classification labels
            data_dict = switch_sampler(for_training=False)
            image = tf.cast(data_dict['image'], tf.float32)
            net_out = self.net(image, is_training=self.is_training)

            output_prob = self.segmentation_param.output_prob
            num_classes = self.segmentation_param.num_classes
            if output_prob and num_classes > 1:
                post_process_layer = PostProcessingLayer(
                    'SOFTMAX', num_classes=num_classes)
            elif not output_prob and num_classes > 1:
                post_process_layer = PostProcessingLayer(
                    'ARGMAX', num_classes=num_classes)
            else:
                post_process_layer = PostProcessingLayer(
                    'IDENTITY', num_classes=num_classes)
            net_out = post_process_layer(net_out)

            outputs_collector.add_to_collection(var=net_out,
                                                name='window',
                                                average_over_devices=False,
                                                collection=NETWORK_OUTPUT)
            outputs_collector.add_to_collection(
                var=data_dict['image_location'],
                name='location',
                average_over_devices=False,
                collection=NETWORK_OUTPUT)

            # stitches per-window predictions back into full volumes
            self.output_decoder = GridSamplesAggregator(
                image_reader=self.readers[0],
                output_path=self.action_param.save_seg_dir,
                window_border=self.action_param.border,
                interp_order=self.action_param.output_interp_order)

    def interpret_output(self, batch_output):
        """Feed inference windows to the aggregator; the return value
        signals whether more batches are expected."""
        if not self.is_training:
            return self.output_decoder.decode_batch(batch_output['window'],
                                                    batch_output['location'])
        return True
コード例 #13
0
class SelectiveSampling(BaseApplication):
    REQUIRED_CONFIG_SECTION = "SEGMENTATION"

    def __init__(self, net_param, action_param, is_training):
        """Set up the selective-sampling segmentation application.

        :param net_param: network-level configuration parameters
        :param action_param: train/inference action parameters
        :param is_training: ``True`` for training, ``False`` for inference
        """
        super(SelectiveSampling, self).__init__()
        tf.logging.info('starting segmentation application')

        self.is_training = is_training
        self.net_param, self.action_param = net_param, action_param

        # populated later by initialise_dataset_loader
        self.data_param = self.segmentation_param = None

        # (training sampler factory, inference sampler factory,
        #  inference aggregator factory)
        self.SUPPORTED_SAMPLING = {
            'selective': (self.initialise_selective_sampler,
                          self.initialise_grid_sampler,
                          self.initialise_grid_aggregator),
        }

    def initialise_dataset_loader(
            self, data_param=None, task_param=None, data_partitioner=None):
        """Create image readers and attach their preprocessing layers.

        Builds one ``ImageReader`` per file list when training (all
        configured input sections), or a single image-only reader for
        inference, then attaches volume padding, normalisation and
        (training only) augmentation layers to every reader.

        :param data_param: per-modality input section parameters
        :param task_param: task-level parameters (image/label/weight
            section names, label normalisation flag, ...)
        :param data_partitioner: provides train/validation/inference
            file lists
        """
        self.data_param = data_param
        self.segmentation_param = task_param

        # dataset_to_infer only exists on inference action parameters;
        # fall back to the partitioner's default phase otherwise
        try:
            reader_phase = self.action_param.dataset_to_infer
        except AttributeError:
            reader_phase = None
        file_lists = data_partitioner.get_file_lists_by(
            phase=reader_phase, action=self.action)
        # read each line of csv files into an instance of Subject
        if self.is_training:
            self.readers = []
            for file_list in file_lists:
                reader = ImageReader(SUPPORTED_INPUT)
                reader.initialise(data_param, task_param, file_list)
                self.readers.append(reader)

        else:  # in the inference process use image input only
            inference_reader = ImageReader(['image'])
            inference_reader.initialise(data_param, task_param, file_lists[0])
            self.readers = [inference_reader]

        # optional mask restricting normalisation to foreground voxels
        foreground_masking_layer = None
        if self.net_param.normalise_foreground_only:
            foreground_masking_layer = BinaryMaskingLayer(
                type_str=self.net_param.foreground_type,
                multimod_fusion=self.net_param.multimod_foreground_type,
                threshold=0.0)

        mean_var_normaliser = MeanVarNormalisationLayer(
            image_name='image', binary_masking_func=foreground_masking_layer)
        # histogram-based intensity normalisation needs a reference file
        histogram_normaliser = None
        if self.net_param.histogram_ref_file:
            histogram_normaliser = HistogramNormalisationLayer(
                image_name='image',
                modalities=vars(task_param).get('image'),
                model_filename=self.net_param.histogram_ref_file,
                binary_masking_func=foreground_masking_layer,
                norm_type=self.net_param.norm_type,
                cutoff=self.net_param.cutoff,
                name='hist_norm_layer')

        # maps discrete label values to a contiguous range (and back)
        label_normaliser = None
        if self.net_param.histogram_ref_file:
            label_normaliser = DiscreteLabelNormalisationLayer(
                image_name='label',
                modalities=vars(task_param).get('label'),
                model_filename=self.net_param.histogram_ref_file)

        # NOTE(review): if normalisation/label_normalisation is enabled
        # without histogram_ref_file, a None layer is appended here —
        # presumably the config validation prevents that; verify upstream.
        normalisation_layers = []
        if self.net_param.normalisation:
            normalisation_layers.append(histogram_normaliser)
        if self.net_param.whitening:
            normalisation_layers.append(mean_var_normaliser)
        if task_param.label_normalisation:
            normalisation_layers.append(label_normaliser)

        # random flip / scaling / rotation are training-only augmentations
        augmentation_layers = []
        if self.is_training:
            if self.action_param.random_flipping_axes != -1:
                augmentation_layers.append(RandomFlipLayer(
                    flip_axes=self.action_param.random_flipping_axes))
            if self.action_param.scaling_percentage:
                augmentation_layers.append(RandomSpatialScalingLayer(
                    min_percentage=self.action_param.scaling_percentage[0],
                    max_percentage=self.action_param.scaling_percentage[1]))
            if self.action_param.rotation_angle or \
                    self.action_param.rotation_angle_x or \
                    self.action_param.rotation_angle_y or \
                    self.action_param.rotation_angle_z:
                rotation_layer = RandomRotationLayer()
                if self.action_param.rotation_angle:
                    # single angle range applied uniformly to all axes
                    rotation_layer.init_uniform_angle(
                        self.action_param.rotation_angle)
                else:
                    rotation_layer.init_non_uniform_angle(
                        self.action_param.rotation_angle_x,
                        self.action_param.rotation_angle_y,
                        self.action_param.rotation_angle_z)
                augmentation_layers.append(rotation_layer)

        volume_padding_layer = [PadLayer(
            image_name=SUPPORTED_INPUT,
            border=self.net_param.volume_padding_size,
            mode=self.net_param.volume_padding_mode,
            pad_to=self.net_param.volume_padding_to_size)
        ]

        # order matters: pad first, then normalise, then augment
        for reader in self.readers:
            reader.add_preprocessing_layers(
                volume_padding_layer +
                normalisation_layers +
                augmentation_layers)

    def initialise_selective_sampler(self):
        """Create one ``SelectiveSampler`` per reader for training.

        All samplers share a single ``Constraint`` built from the
        label-sampling settings in the segmentation parameters; the
        resulting list is stored as ``self.sampler``.
        """
        shared_constraint = Constraint(
            self.segmentation_param.compulsory_labels,
            self.segmentation_param.min_sampling_ratio,
            self.segmentation_param.min_numb_labels,
            self.segmentation_param.proba_connect)

        samplers = []
        for reader in self.readers:
            samplers.append(SelectiveSampler(
                reader=reader,
                data_param=self.data_param,
                batch_size=self.net_param.batch_size,
                windows_per_image=self.action_param.sample_per_volume,
                constraint=shared_constraint,
                random_windows_per_image=self.segmentation_param.rand_samples,
                queue_length=self.net_param.queue_length))
        self.sampler = [samplers]

    def initialise_grid_sampler(self):
        """Create one ``GridSampler`` per reader for sliding-window
        inference and store the list as ``self.sampler``."""
        grid_samplers = [
            GridSampler(
                reader=reader,
                window_sizes=self.data_param,
                batch_size=self.net_param.batch_size,
                spatial_window_size=self.action_param.spatial_window_size,
                window_border=self.action_param.border,
                queue_length=self.net_param.queue_length)
            for reader in self.readers
        ]
        self.sampler = [grid_samplers]

    def initialise_grid_aggregator(self):
        """Create the aggregator that stitches inference windows back
        into full volumes and writes them to ``save_seg_dir``."""
        action = self.action_param
        self.output_decoder = GridSamplesAggregator(
            image_reader=self.readers[0],
            output_path=action.save_seg_dir,
            window_border=action.border,
            interp_order=action.output_interp_order)

    def initialise_sampler(self):
        """Invoke the training (index 0) or inference (index 1) sampler
        factory registered under the 'selective' sampling scheme."""
        factory_index = 0 if self.is_training else 1
        self.SUPPORTED_SAMPLING['selective'][factory_index]()

    def initialise_network(self):
        """Instantiate ``self.net`` from the factory, with optional
        L1/L2 weight and bias regularisers when decay is positive."""
        reg_type = self.net_param.reg_type.lower()
        decay = self.net_param.decay

        w_regularizer = None
        b_regularizer = None
        if decay > 0 and reg_type in ('l1', 'l2'):
            from tensorflow.contrib.layers.python.layers import regularizers
            make_regularizer = {
                'l1': regularizers.l1_regularizer,
                'l2': regularizers.l2_regularizer}[reg_type]
            w_regularizer = make_regularizer(decay)
            b_regularizer = make_regularizer(decay)

        self.net = ApplicationNetFactory.create(self.net_param.name)(
            num_classes=self.segmentation_param.num_classes,
            w_initializer=InitializerFactory.get_initializer(
                name=self.net_param.weight_initializer),
            b_initializer=InitializerFactory.get_initializer(
                name=self.net_param.bias_initializer),
            w_regularizer=w_regularizer,
            b_regularizer=b_regularizer,
            acti_func=self.net_param.activation_function)

    def connect_data_and_network(self,
                                 outputs_collector=None,
                                 gradients_collector=None):
        """Wire sampler outputs into the network and the collectors.

        Training: builds the loss (plus optional regularisation),
        computes gradients and registers loss summaries.
        Inference: post-processes logits and registers the window and
        its location for aggregation.

        :param outputs_collector: collects tensors for console/summary/
            network-output consumers
        :param gradients_collector: collects gradient ops (training only)
        """

        def switch_sampler(for_training):
            # pop a batch from the train (index 0) or validation
            # (index -1) sampler, under a matching name scope
            with tf.name_scope('train' if for_training else 'validation'):
                sampler = self.get_sampler()[0][0 if for_training else -1]
                return sampler.pop_batch_op()

        if self.is_training:

            # alternate train/validation batches via tf.cond when
            # periodic validation is enabled
            if self.action_param.validation_every_n > 0:
                data_dict = tf.cond(tf.logical_not(self.is_validation),
                                    lambda: switch_sampler(for_training=True),
                                    lambda: switch_sampler(for_training=False))
            else:
                data_dict = switch_sampler(for_training=True)
            image = tf.cast(data_dict['image'], tf.float32)
            net_out = self.net(image, is_training=self.is_training)

            with tf.name_scope('Optimiser'):
                optimiser_class = OptimiserFactory.create(
                    name=self.action_param.optimiser)
                self.optimiser = optimiser_class.get_instance(
                    learning_rate=self.action_param.lr)
            loss_func = LossFunction(
                n_class=self.segmentation_param.num_classes,
                loss_type=self.action_param.loss_type)
            # label/weight sections may be absent from the batch dict
            data_loss = loss_func(
                prediction=net_out,
                ground_truth=data_dict.get('label', None),
                weight_map=data_dict.get('weight', None))
            reg_losses = tf.get_collection(
                tf.GraphKeys.REGULARIZATION_LOSSES)
            # add mean regularisation loss only when decay is active
            if self.net_param.decay > 0.0 and reg_losses:
                reg_loss = tf.reduce_mean(
                    [tf.reduce_mean(reg_loss) for reg_loss in reg_losses])
                loss = data_loss + reg_loss
            else:
                loss = data_loss
            grads = self.optimiser.compute_gradients(loss)
            # collecting gradients variables
            gradients_collector.add_to_collection([grads])
            # collecting output variables
            outputs_collector.add_to_collection(
                var=data_loss, name='dice_loss',
                average_over_devices=False, collection=CONSOLE)
            outputs_collector.add_to_collection(
                var=data_loss, name='dice_loss',
                average_over_devices=True, summary_type='scalar',
                collection=TF_SUMMARIES)

            # outputs_collector.add_to_collection(
            #    var=image*180.0, name='image',
            #    average_over_devices=False, summary_type='image3_sagittal',
            #    collection=TF_SUMMARIES)

            # outputs_collector.add_to_collection(
            #    var=image, name='image',
            #    average_over_devices=False,
            #    collection=NETWORK_OUTPUT)

            # outputs_collector.add_to_collection(
            #    var=tf.reduce_mean(image), name='mean_image',
            #    average_over_devices=False, summary_type='scalar',
            #    collection=CONSOLE)
        else:
            # converting logits into final output for
            # classification probabilities or argmax classification labels
            data_dict = switch_sampler(for_training=False)
            image = tf.cast(data_dict['image'], tf.float32)
            net_out = self.net(image, is_training=self.is_training)

            output_prob = self.segmentation_param.output_prob
            num_classes = self.segmentation_param.num_classes
            if output_prob and num_classes > 1:
                post_process_layer = PostProcessingLayer(
                    'SOFTMAX', num_classes=num_classes)
            elif not output_prob and num_classes > 1:
                post_process_layer = PostProcessingLayer(
                    'ARGMAX', num_classes=num_classes)
            else:
                # single class: pass logits through unchanged
                post_process_layer = PostProcessingLayer(
                    'IDENTITY', num_classes=num_classes)
            net_out = post_process_layer(net_out)

            outputs_collector.add_to_collection(
                var=net_out, name='window',
                average_over_devices=False, collection=NETWORK_OUTPUT)
            outputs_collector.add_to_collection(
                var=data_dict['image_location'], name='location',
                average_over_devices=False, collection=NETWORK_OUTPUT)
            # NOTE(review): keyed by net_param.window_sampling although
            # initialise_sampler hard-codes 'selective' — any other
            # configured value would raise KeyError here; confirm config
            # restricts window_sampling to 'selective'.
            init_aggregator = \
                self.SUPPORTED_SAMPLING[self.net_param.window_sampling][2]
            init_aggregator()

    def interpret_output(self, batch_output):
        """Feed inference windows to the aggregator.

        Training has nothing to decode, so return ``True`` to continue.
        Otherwise wrap the network window under the 'window_image' key
        and hand it to the aggregator together with its location; the
        aggregator's return value signals whether more batches remain.
        """
        if self.is_training:
            return True
        named_windows = {'window_image': batch_output['window']}
        return self.output_decoder.decode_batch(
            named_windows, batch_output['location'])