def test_25d_init(self):
        """Aggregate 2.5D resize-sampled windows and verify the output volume shape."""
        reader = get_25d_reader()
        sampler = ResizeSampler(
            reader=reader,
            window_sizes=SINGLE_25D_DATA,
            batch_size=1,
            shuffle=False,
            queue_length=50)
        aggregator = ResizeSamplesAggregator(
            image_reader=reader,
            name='image',
            output_path=os.path.join('testing_data', 'aggregated'),
            interp_order=3)

        keep_decoding = True
        with self.test_session() as sess:
            coordinator = tf.train.Coordinator()
            sampler.run_threads(sess, coordinator, num_threads=2)
            while keep_decoding:
                # stop either when the queue is exhausted or when the
                # aggregator reports that all windows have been decoded
                try:
                    batch = sess.run(sampler.pop_batch_op())
                except tf.errors.OutOfRangeError:
                    break
                keep_decoding = aggregator.decode_batch(
                    batch['image'], batch['image_location'])
        expected_name = '{}_niftynet_out.nii.gz'.format(
            sampler.reader.get_subject_id(0))
        result_path = os.path.join(
            'testing_data', 'aggregated', expected_name)
        self.assertAllClose(
            nib.load(result_path).shape, [255, 168, 256, 1, 1],
            rtol=1e-03, atol=1e-03)
        sampler.close_all()
    def test_inverse_mapping(self):
        reader = get_label_reader()
        sampler = ResizeSampler(reader=reader,
                                data_param=MOD_LABEL_DATA,
                                batch_size=1,
                                shuffle_buffer=False,
                                queue_length=50)
        aggregator = ResizeSamplesAggregator(
            image_reader=reader,
            name='label',
            output_path=os.path.join('testing_data', 'aggregated'),
            interp_order=0)
        more_batch = True

        with self.test_session() as sess:
            coordinator = tf.train.Coordinator()
            sampler.run_threads(sess, coordinator, num_threads=2)
            while more_batch:
                out = sess.run(sampler.pop_batch_op())
                more_batch = aggregator.decode_batch(
                    out['label'], out['label_location'])
        output_filename = '{}_niftynet_out.nii.gz'.format(
            sampler.reader.get_subject_id(0))
        output_file = os.path.join(
            'testing_data', 'aggregated', output_filename)
        self.assertAllClose(
            nib.load(output_file).shape, [256, 168, 256, 1, 1])
        sampler.close_all()
# Esempio n. 3
# 0
    def test_inverse_mapping(self):
        reader = get_label_reader()
        sampler = ResizeSampler(reader=reader,
                                data_param=MOD_LABEL_DATA,
                                batch_size=1,
                                shuffle_buffer=False,
                                queue_length=50)
        aggregator = ResizeSamplesAggregator(image_reader=reader,
                                             name='label',
                                             output_path=os.path.join(
                                                 'testing_data', 'aggregated'),
                                             interp_order=0)
        more_batch = True

        with self.test_session() as sess:
            coordinator = tf.train.Coordinator()
            sampler.run_threads(sess, coordinator, num_threads=2)
            while more_batch:
                out = sess.run(sampler.pop_batch_op())
                more_batch = aggregator.decode_batch(out['label'],
                                                     out['label_location'])
        output_filename = '{}_niftynet_out.nii.gz'.format(
            sampler.reader.get_subject_id(0))
        output_file = os.path.join('testing_data', 'aggregated',
                                   output_filename)
        self.assertAllClose(nib.load(output_file).shape, [256, 168, 256, 1, 1])
        sampler.close_all()
    def test_25d_init(self):
        reader = get_25d_reader()
        sampler = ResizeSampler(reader=reader,
                                window_sizes=SINGLE_25D_DATA,
                                batch_size=1,
                                shuffle=False,
                                queue_length=50)
        aggregator = ResizeSamplesAggregator(
            image_reader=reader,
            name='image',
            output_path=os.path.join('testing_data', 'aggregated'),
            interp_order=3)
        more_batch = True

        with self.test_session() as sess:
            sampler.set_num_threads(2)
            while more_batch:
                try:
                    out = sess.run(sampler.pop_batch_op())
                except tf.errors.OutOfRangeError:
                    break
                more_batch = aggregator.decode_batch(
                    out['image'], out['image_location'])
        output_filename = '{}_niftynet_out.nii.gz'.format(
            sampler.reader.get_subject_id(0))
        output_file = os.path.join('testing_data',
                                   'aggregated',
                                   output_filename)
        self.assertAllClose(
            nib.load(output_file).shape, [255, 168, 256, 1, 1],
            rtol=1e-03, atol=1e-03)
        sampler.close_all()
# Esempio n. 5
# 0
    def test_inverse_mapping(self):
        reader = get_label_reader()
        sampler = ResizeSampler(reader=reader,
                                window_sizes=MOD_LABEL_DATA,
                                batch_size=1,
                                shuffle=False,
                                queue_length=50)
        aggregator = ResizeSamplesAggregator(image_reader=reader,
                                             name='label',
                                             output_path=os.path.join(
                                                 'testing_data', 'aggregated'),
                                             interp_order=0)
        more_batch = True

        with self.cached_session() as sess:
            sampler.set_num_threads(2)
            while more_batch:
                try:
                    out = sess.run(sampler.pop_batch_op())
                except tf.errors.OutOfRangeError:
                    break
                more_batch = aggregator.decode_batch(
                    {'window_label': out['label']}, out['label_location'])
        output_filename = 'window_label_{}_niftynet_out.nii.gz'.format(
            sampler.reader.get_subject_id(0))
        output_file = os.path.join('testing_data', 'aggregated',
                                   output_filename)
        self.assertAllClose(nib.load(output_file).shape, [256, 168, 256])
        sampler.close_all()
 def initialise_resize_aggregator(self):
     '''
     Define the resize aggregator used for decoding network output
     windows, configured from the action parameters (output path,
     window border, interpolation order and filename postfix).

     :return: None; the aggregator is stored in ``self.output_decoder``
     '''
     self.output_decoder = ResizeSamplesAggregator(
         image_reader=self.readers[0],
         output_path=self.action_param.save_seg_dir,
         window_border=self.action_param.border,
         interp_order=self.action_param.output_interp_order,
         postfix=self.action_param.output_postfix)
# Esempio n. 7
# 0
    def test_init_2d_mo_bidimcsv(self):
        reader = get_2d_reader()
        sampler = ResizeSampler(reader=reader,
                                window_sizes=MOD_2D_DATA,
                                batch_size=1,
                                shuffle=False,
                                queue_length=50)
        aggregator = ResizeSamplesAggregator(image_reader=reader,
                                             name='image',
                                             output_path=os.path.join(
                                                 'testing_data', 'aggregated'),
                                             interp_order=3)
        more_batch = True

        with self.cached_session() as sess:
            sampler.set_num_threads(2)
            while more_batch:
                try:
                    out = sess.run(sampler.pop_batch_op())
                except tf.errors.OutOfRangeError:
                    break
                min_val = np.sum((np.asarray(out['image']).flatten()))
                stats_val = [
                    np.min(out['image']),
                    np.max(out['image']),
                    np.sum(out['image'])
                ]
                stats_val = np.expand_dims(stats_val, 0)
                stats_val = np.concatenate([stats_val, stats_val], axis=0)
                more_batch = aggregator.decode_batch(
                    {
                        'window_image': out['image'],
                        'csv_sum': min_val,
                        'csv_stats_2d': stats_val
                    }, out['image_location'])
        output_filename = 'window_image_{}_niftynet_out.nii.gz'.format(
            sampler.reader.get_subject_id(0))
        sum_filename = os.path.join(
            'testing_data', 'aggregated', 'csv_sum_{}_niftynet_out.csv'.format(
                sampler.reader.get_subject_id(0)))
        stats_filename = os.path.join(
            'testing_data', 'aggregated',
            'csv_stats_2d_{}_niftynet_out.csv'.format(
                sampler.reader.get_subject_id(0)))
        output_file = os.path.join('testing_data', 'aggregated',
                                   output_filename)

        self.assertAllClose(nib.load(output_file).shape, (128, 128))
        min_pd = pd.read_csv(sum_filename)
        self.assertAllClose(min_pd.shape, [1, 2])
        stats_pd = pd.read_csv(stats_filename)
        self.assertAllClose(stats_pd.shape, [1, 7])
        sampler.close_all()
# Esempio n. 8
# 0
 def initialise_resize_aggregator(self):
     '''
     Define the resize aggregator used for decoding network output
     windows, configured from the action parameters (output path,
     window border, interpolation order and filename postfix).

     :return: None; the aggregator is stored in ``self.output_decoder``
     '''
     self.output_decoder = ResizeSamplesAggregator(
         image_reader=self.readers[0],
         output_path=self.action_param.save_seg_dir,
         window_border=self.action_param.border,
         interp_order=self.action_param.output_interp_order,
         postfix=self.action_param.output_postfix)
# Esempio n. 9
# 0
    def test_3d_init_mo_3out(self):
        reader = get_3d_reader()
        sampler = ResizeSampler(reader=reader,
                                window_sizes=MULTI_MOD_DATA,
                                batch_size=1,
                                shuffle=False,
                                queue_length=50)
        aggregator = ResizeSamplesAggregator(image_reader=reader,
                                             name='image',
                                             output_path=os.path.join(
                                                 'testing_data', 'aggregated'),
                                             interp_order=3)
        more_batch = True

        with self.cached_session() as sess:
            sampler.set_num_threads(2)
            while more_batch:
                try:
                    out = sess.run(sampler.pop_batch_op())
                except tf.errors.OutOfRangeError:
                    break
                sum_val = np.sum(out['image'])
                stats_val = [
                    np.sum(out['image']),
                    np.min(out['image']),
                    np.max(out['image'])
                ]
                more_batch = aggregator.decode_batch(
                    {
                        'window_image': out['image'],
                        'csv_sum': sum_val,
                        'csv_stats': stats_val
                    }, out['image_location'])
        output_filename = 'window_image_{}_niftynet_out.nii.gz'.format(
            sampler.reader.get_subject_id(0))
        sum_filename = os.path.join(
            'testing_data', 'aggregated', 'csv_sum_{}_niftynet_out.csv'.format(
                sampler.reader.get_subject_id(0)))
        stats_filename = os.path.join(
            'testing_data', 'aggregated',
            'csv_stats_{}_niftynet_out.csv'.format(
                sampler.reader.get_subject_id(0)))
        output_file = os.path.join('testing_data', 'aggregated',
                                   output_filename)
        self.assertAllClose(nib.load(output_file).shape, (256, 168, 256, 1, 2))
        sum_pd = pd.read_csv(sum_filename)
        self.assertAllClose(sum_pd.shape, [1, 2])
        stats_pd = pd.read_csv(stats_filename)
        self.assertAllClose(stats_pd.shape, [1, 4])
        sampler.close_all()
# Esempio n. 10
# 0
 def initialise_grid_aggregator(self):
     '''
     Define the grid aggregator used for decoding network output
     windows, configured from the action parameters (output path,
     window border, interpolation order, filename postfix and
     padding fill value).

     :return: None; the aggregator is stored in ``self.output_decoder``
     '''
     self.output_decoder = GridSamplesAggregator(
         image_reader=self.readers[0],
         output_path=self.action_param.save_seg_dir,
         window_border=self.action_param.border,
         interp_order=self.action_param.output_interp_order,
         postfix=self.action_param.output_postfix,
         fill_constant=self.action_param.fill_constant)
# Esempio n. 11
# 0
 def initialise_grid_aggregator(self):
     '''
     Define the grid aggregator used for decoding network output
     windows, configured from the action parameters (output path,
     window border, interpolation order, filename postfix and
     padding fill value).

     :return: None; the aggregator is stored in ``self.output_decoder``
     '''
     self.output_decoder = GridSamplesAggregator(
         image_reader=self.readers[0],
         output_path=self.action_param.save_seg_dir,
         window_border=self.action_param.border,
         interp_order=self.action_param.output_interp_order,
         postfix=self.action_param.output_postfix,
         fill_constant=self.action_param.fill_constant)
# Esempio n. 12
# 0
class SegmentationApplication(BaseApplication):
    """NiftyNet application for image segmentation.

    Wires together image readers, window samplers, the segmentation
    network, the loss/optimiser (training) and window aggregators
    (inference), driven by the SEGMENTATION configuration section.
    """

    REQUIRED_CONFIG_SECTION = "SEGMENTATION"

    def __init__(self, net_param, action_param, action):
        """Store configuration and register the supported sampling modes.

        :param net_param: network-level configuration parameters
        :param action_param: parameters of the current action
        :param action: requested action name
        """
        super(SegmentationApplication, self).__init__()
        tf.logging.info('starting segmentation application')
        self.action = action

        self.net_param = net_param
        self.action_param = action_param

        self.data_param = None
        self.segmentation_param = None
        # each value: (training sampler, inference sampler, aggregator)
        self.SUPPORTED_SAMPLING = {
            'uniform': (self.initialise_uniform_sampler,
                        self.initialise_grid_sampler,
                        self.initialise_grid_aggregator),
            'weighted': (self.initialise_weighted_sampler,
                         self.initialise_grid_sampler,
                         self.initialise_grid_aggregator),
            'resize': (self.initialise_resize_sampler,
                       self.initialise_resize_sampler,
                       self.initialise_resize_aggregator),
            'balanced': (self.initialise_balanced_sampler,
                         self.initialise_grid_sampler,
                         self.initialise_grid_aggregator),
        }

    def initialise_dataset_loader(
            self, data_param=None, task_param=None, data_partitioner=None):
        """Create image readers and attach their preprocessing layers.

        Training uses image/label/weight/sampler inputs, inference uses
        images only, and evaluation additionally reads previously
        inferred outputs.

        :param data_param: input data source specifications
        :param task_param: segmentation task parameters
        :param data_partitioner: provides the train/validation/inference
            file lists
        :raises ValueError: if the current action is not supported
        """

        self.data_param = data_param
        self.segmentation_param = task_param

        file_lists = self.get_file_lists(data_partitioner)
        # read each line of csv files into an instance of Subject
        if self.is_training:
            self.readers = []
            for file_list in file_lists:
                reader = ImageReader({'image', 'label', 'weight', 'sampler'})
                reader.initialise(data_param, task_param, file_list)
                self.readers.append(reader)

        elif self.is_inference:
            # in the inference process use image input only
            inference_reader = ImageReader({'image'})
            file_list = data_partitioner.inference_files
            inference_reader.initialise(data_param, task_param, file_list)
            self.readers = [inference_reader]
        elif self.is_evaluation:
            file_list = data_partitioner.inference_files
            reader = ImageReader({'image', 'label', 'inferred'})
            reader.initialise(data_param, task_param, file_list)
            self.readers = [reader]
        else:
            raise ValueError('Action `{}` not supported. Expected one of {}'
                             .format(self.action, self.SUPPORTED_ACTIONS))

        # optional foreground mask restricting normalisation to foreground
        foreground_masking_layer = None
        if self.net_param.normalise_foreground_only:
            foreground_masking_layer = BinaryMaskingLayer(
                type_str=self.net_param.foreground_type,
                multimod_fusion=self.net_param.multimod_foreground_type,
                threshold=0.0)

        mean_var_normaliser = MeanVarNormalisationLayer(
            image_name='image', binary_masking_func=foreground_masking_layer)
        histogram_normaliser = None
        if self.net_param.histogram_ref_file:
            histogram_normaliser = HistogramNormalisationLayer(
                image_name='image',
                modalities=vars(task_param).get('image'),
                model_filename=self.net_param.histogram_ref_file,
                binary_masking_func=foreground_masking_layer,
                norm_type=self.net_param.norm_type,
                cutoff=self.net_param.cutoff,
                name='hist_norm_layer')

        label_normalisers = None
        if self.net_param.histogram_ref_file and \
                task_param.label_normalisation:
            label_normalisers = [DiscreteLabelNormalisationLayer(
                image_name='label',
                modalities=vars(task_param).get('label'),
                model_filename=self.net_param.histogram_ref_file)]
            if self.is_evaluation:
                label_normalisers.append(
                    DiscreteLabelNormalisationLayer(
                        image_name='inferred',
                        modalities=vars(task_param).get('inferred'),
                        model_filename=self.net_param.histogram_ref_file))
                # share the label mapping between the 'label' and
                # 'inferred' normalisers
                label_normalisers[-1].key = label_normalisers[0].key

        normalisation_layers = []
        if self.net_param.normalisation:
            normalisation_layers.append(histogram_normaliser)
        if self.net_param.whitening:
            normalisation_layers.append(mean_var_normaliser)
        if task_param.label_normalisation and \
                (self.is_training or not task_param.output_prob):
            normalisation_layers.extend(label_normalisers)

        # data augmentation is applied at training time only
        augmentation_layers = []
        if self.is_training:
            if self.action_param.random_flipping_axes != -1:
                augmentation_layers.append(RandomFlipLayer(
                    flip_axes=self.action_param.random_flipping_axes))
            if self.action_param.scaling_percentage:
                augmentation_layers.append(RandomSpatialScalingLayer(
                    min_percentage=self.action_param.scaling_percentage[0],
                    max_percentage=self.action_param.scaling_percentage[1]))
            if self.action_param.rotation_angle or \
                    self.action_param.rotation_angle_x or \
                    self.action_param.rotation_angle_y or \
                    self.action_param.rotation_angle_z:
                rotation_layer = RandomRotationLayer()
                if self.action_param.rotation_angle:
                    rotation_layer.init_uniform_angle(
                        self.action_param.rotation_angle)
                else:
                    rotation_layer.init_non_uniform_angle(
                        self.action_param.rotation_angle_x,
                        self.action_param.rotation_angle_y,
                        self.action_param.rotation_angle_z)
                augmentation_layers.append(rotation_layer)

            # add deformation layer
            if self.action_param.do_elastic_deformation:
                spatial_rank = list(self.readers[0].spatial_ranks.values())[0]
                augmentation_layers.append(RandomElasticDeformationLayer(
                    spatial_rank=spatial_rank,
                    num_controlpoints=self.action_param.num_ctrl_points,
                    std_deformation_sigma=self.action_param.deformation_sigma,
                    proportion_to_augment=self.action_param.proportion_to_deform))

        volume_padding_layer = []
        if self.net_param.volume_padding_size:
            volume_padding_layer.append(PadLayer(
                image_name=SUPPORTED_INPUT,
                border=self.net_param.volume_padding_size))

        # only add augmentation to first reader (not validation reader)
        self.readers[0].add_preprocessing_layers(
            volume_padding_layer +
            normalisation_layers +
            augmentation_layers)

        for reader in self.readers[1:]:
            reader.add_preprocessing_layers(
                volume_padding_layer +
                normalisation_layers)

    def initialise_uniform_sampler(self):
        """Create one UniformSampler per reader (training)."""
        self.sampler = [[UniformSampler(
            reader=reader,
            data_param=self.data_param,
            batch_size=self.net_param.batch_size,
            windows_per_image=self.action_param.sample_per_volume,
            queue_length=self.net_param.queue_length) for reader in
            self.readers]]

    def initialise_weighted_sampler(self):
        """Create one WeightedSampler per reader (training)."""
        self.sampler = [[WeightedSampler(
            reader=reader,
            data_param=self.data_param,
            batch_size=self.net_param.batch_size,
            windows_per_image=self.action_param.sample_per_volume,
            queue_length=self.net_param.queue_length) for reader in
            self.readers]]

    def initialise_resize_sampler(self):
        """Create one ResizeSampler per reader; shuffles only in training."""
        self.sampler = [[ResizeSampler(
            reader=reader,
            data_param=self.data_param,
            batch_size=self.net_param.batch_size,
            shuffle_buffer=self.is_training,
            queue_length=self.net_param.queue_length) for reader in
            self.readers]]

    def initialise_grid_sampler(self):
        """Create one GridSampler per reader (inference)."""
        self.sampler = [[GridSampler(
            reader=reader,
            data_param=self.data_param,
            batch_size=self.net_param.batch_size,
            spatial_window_size=self.action_param.spatial_window_size,
            window_border=self.action_param.border,
            queue_length=self.net_param.queue_length) for reader in
            self.readers]]

    def initialise_balanced_sampler(self):
        """Create one BalancedSampler per reader (training)."""
        self.sampler = [[BalancedSampler(
            reader=reader,
            data_param=self.data_param,
            batch_size=self.net_param.batch_size,
            windows_per_image=self.action_param.sample_per_volume,
            queue_length=self.net_param.queue_length) for reader in
            self.readers]]

    def initialise_grid_aggregator(self):
        """Create the GridSamplesAggregator decoding inference windows."""
        self.output_decoder = GridSamplesAggregator(
            image_reader=self.readers[0],
            output_path=self.action_param.save_seg_dir,
            window_border=self.action_param.border,
            interp_order=self.action_param.output_interp_order)

    def initialise_resize_aggregator(self):
        """Create the ResizeSamplesAggregator decoding inference windows."""
        self.output_decoder = ResizeSamplesAggregator(
            image_reader=self.readers[0],
            output_path=self.action_param.save_seg_dir,
            window_border=self.action_param.border,
            interp_order=self.action_param.output_interp_order)

    def initialise_sampler(self):
        """Create the sampler for the current action: SUPPORTED_SAMPLING
        index 0 is the training sampler, index 1 the inference sampler."""
        if self.is_training:
            self.SUPPORTED_SAMPLING[self.net_param.window_sampling][0]()
        elif self.is_inference:
            self.SUPPORTED_SAMPLING[self.net_param.window_sampling][1]()

    def initialise_aggregator(self):
        """Create the output aggregator (SUPPORTED_SAMPLING index 2)."""
        self.SUPPORTED_SAMPLING[self.net_param.window_sampling][2]()

    def initialise_network(self):
        """Instantiate the network with initialisers and optional L1/L2
        regularisers taken from the configuration."""
        w_regularizer = None
        b_regularizer = None
        reg_type = self.net_param.reg_type.lower()
        decay = self.net_param.decay
        if reg_type == 'l2' and decay > 0:
            from tensorflow.contrib.layers.python.layers import regularizers
            w_regularizer = regularizers.l2_regularizer(decay)
            b_regularizer = regularizers.l2_regularizer(decay)
        elif reg_type == 'l1' and decay > 0:
            from tensorflow.contrib.layers.python.layers import regularizers
            w_regularizer = regularizers.l1_regularizer(decay)
            b_regularizer = regularizers.l1_regularizer(decay)

        self.net = ApplicationNetFactory.create(self.net_param.name)(
            num_classes=self.segmentation_param.num_classes,
            w_initializer=InitializerFactory.get_initializer(
                name=self.net_param.weight_initializer),
            b_initializer=InitializerFactory.get_initializer(
                name=self.net_param.bias_initializer),
            w_regularizer=w_regularizer,
            b_regularizer=b_regularizer,
            acti_func=self.net_param.activation_function)

    def connect_data_and_network(self,
                                 outputs_collector=None,
                                 gradients_collector=None):
        """Build the computation graph for the current action.

        Training: pop a batch (switching between the train and validation
        samplers when validation is enabled), compute the segmentation
        loss plus optional regularisation, and collect gradients and
        console/summary outputs.

        Inference: run the network, post-process the logits
        (softmax / argmax / identity) and collect the output windows and
        their locations for the aggregator.
        """
        # def data_net(for_training):
        #    with tf.name_scope('train' if for_training else 'validation'):
        #        sampler = self.get_sampler()[0][0 if for_training else -1]
        #        data_dict = sampler.pop_batch_op()
        #        image = tf.cast(data_dict['image'], tf.float32)
        #        return data_dict, self.net(image, is_training=for_training)

        def switch_sampler(for_training):
            # pop a batch from the train (index 0) or validation (index -1)
            # sampler under the matching name scope
            with tf.name_scope('train' if for_training else 'validation'):
                sampler = self.get_sampler()[0][0 if for_training else -1]
                return sampler.pop_batch_op()

        if self.is_training:
            # if self.action_param.validation_every_n > 0:
            #    data_dict, net_out = tf.cond(tf.logical_not(self.is_validation),
            #                                 lambda: data_net(True),
            #                                 lambda: data_net(False))
            # else:
            #    data_dict, net_out = data_net(True)
            if self.action_param.validation_every_n > 0:
                data_dict = tf.cond(tf.logical_not(self.is_validation),
                                    lambda: switch_sampler(for_training=True),
                                    lambda: switch_sampler(for_training=False))
            else:
                data_dict = switch_sampler(for_training=True)

            image = tf.cast(data_dict['image'], tf.float32)
            net_out = self.net(image, is_training=self.is_training)

            with tf.name_scope('Optimiser'):
                optimiser_class = OptimiserFactory.create(
                    name=self.action_param.optimiser)
                self.optimiser = optimiser_class.get_instance(
                    learning_rate=self.action_param.lr)
            loss_func = LossFunction(
                n_class=self.segmentation_param.num_classes,
                loss_type=self.action_param.loss_type,
                softmax=self.segmentation_param.softmax)
            data_loss = loss_func(
                prediction=net_out,
                ground_truth=data_dict.get('label', None),
                weight_map=data_dict.get('weight', None))
            reg_losses = tf.get_collection(
                tf.GraphKeys.REGULARIZATION_LOSSES)
            if self.net_param.decay > 0.0 and reg_losses:
                reg_loss = tf.reduce_mean(
                    [tf.reduce_mean(reg_loss) for reg_loss in reg_losses])
                loss = data_loss + reg_loss
            else:
                loss = data_loss
            grads = self.optimiser.compute_gradients(loss)
            # collecting gradients variables
            gradients_collector.add_to_collection([grads])
            # collecting output variables
            outputs_collector.add_to_collection(
                var=data_loss, name='loss',
                average_over_devices=False, collection=CONSOLE)
            outputs_collector.add_to_collection(
                var=data_loss, name='loss',
                average_over_devices=True, summary_type='scalar',
                collection=TF_SUMMARIES)

            # outputs_collector.add_to_collection(
            #    var=image*180.0, name='image',
            #    average_over_devices=False, summary_type='image3_sagittal',
            #    collection=TF_SUMMARIES)

            # outputs_collector.add_to_collection(
            #    var=image, name='image',
            #    average_over_devices=False,
            #    collection=NETWORK_OUTPUT)

            # outputs_collector.add_to_collection(
            #    var=tf.reduce_mean(image), name='mean_image',
            #    average_over_devices=False, summary_type='scalar',
            #    collection=CONSOLE)
        elif self.is_inference:
            # converting logits into final output for
            # classification probabilities or argmax classification labels
            data_dict = switch_sampler(for_training=False)
            image = tf.cast(data_dict['image'], tf.float32)
            net_out = self.net(image, is_training=self.is_training)

            output_prob = self.segmentation_param.output_prob
            num_classes = self.segmentation_param.num_classes
            if output_prob and num_classes > 1:
                post_process_layer = PostProcessingLayer(
                    'SOFTMAX', num_classes=num_classes)
            elif not output_prob and num_classes > 1:
                post_process_layer = PostProcessingLayer(
                    'ARGMAX', num_classes=num_classes)
            else:
                post_process_layer = PostProcessingLayer(
                    'IDENTITY', num_classes=num_classes)
            net_out = post_process_layer(net_out)

            outputs_collector.add_to_collection(
                var=net_out, name='window',
                average_over_devices=False, collection=NETWORK_OUTPUT)
            outputs_collector.add_to_collection(
                var=data_dict['image_location'], name='location',
                average_over_devices=False, collection=NETWORK_OUTPUT)
            self.initialise_aggregator()

    def interpret_output(self, batch_output):
        """During inference, pass network windows to the aggregator and
        return its continue flag; always True for other actions."""
        if self.is_inference:
            return self.output_decoder.decode_batch(
                batch_output['window'], batch_output['location'])
        return True

    def initialise_evaluator(self, eval_param):
        """Create the SegmentationEvaluator for the evaluation action."""
        self.eval_param = eval_param
        self.evaluator = SegmentationEvaluator(self.readers[0],
                                               self.segmentation_param,
                                               eval_param)

    def add_inferred_output(self, data_param, task_param):
        """Add an 'inferred' input section modelled on the 'label' section."""
        return self.add_inferred_output_like(data_param, task_param, 'label')
# Esempio n. 13
# 0
class MultiClassifSegApplication(BaseApplication):
    """This class defines an application for image-level classification
    problems mapping from images to scalar labels.

    This is the application class to be instantiated by the driver
    and referred to in configuration files.

    Although structurally similar to segmentation, this application
    supports different samplers/aggregators (because patch-based
    processing is not appropriate), and monitoring metrics."""

    REQUIRED_CONFIG_SECTION = "SEGMENTATION"

    def __init__(self, net_param, action_param, action):
        """
        Store the configuration objects and register the supported
        sampler/aggregator factory triples.

        :param net_param: network configuration parameters
        :param action_param: parameters of the requested action
        :param action: action string (train / inference / evaluation)
        """
        super(MultiClassifSegApplication, self).__init__()
        tf.logging.info('starting classification application')
        self.action = action

        self.net_param = net_param
        self.eval_param = None
        self.evaluator = None
        self.action_param = action_param
        self.net_multi = None
        self.data_param = None
        self.segmentation_param = None
        self.csv_readers = None
        # window_sampling name -> (training sampler initialiser,
        # inference sampler initialiser, aggregator initialiser);
        # indices 0/1/2 are used by initialise_sampler/initialise_aggregator
        self.SUPPORTED_SAMPLING = {
            'uniform':
            (self.initialise_uniform_sampler, self.initialise_grid_sampler,
             self.initialise_grid_aggregator),
            'weighted':
            (self.initialise_weighted_sampler, self.initialise_grid_sampler,
             self.initialise_grid_aggregator),
            'resize':
            (self.initialise_resize_sampler, self.initialise_resize_sampler,
             self.initialise_resize_aggregator),
            'balanced':
            (self.initialise_balanced_sampler, self.initialise_grid_sampler,
             self.initialise_grid_aggregator),
        }

    def initialise_dataset_loader(self,
                                  data_param=None,
                                  task_param=None,
                                  data_partitioner=None):
        '''
        Initialise the data loader both csv readers and image readers and
        specify preprocessing layers
        :param data_param:
        :param task_param:
        :param data_partitioner:
        :return:
        '''

        self.data_param = data_param
        self.segmentation_param = task_param

        # choose which image/csv sources are needed for the current action
        if self.is_training:
            image_reader_names = ('image', 'sampler', 'label')
            csv_reader_names = ('value', )
        elif self.is_inference:
            image_reader_names = ('image', )
            csv_reader_names = ()
        elif self.is_evaluation:
            image_reader_names = ('image', 'inferred', 'label')
            csv_reader_names = ('value', )
        else:
            tf.logging.fatal('Action `%s` not supported. Expected one of %s',
                             self.action, self.SUPPORTED_PHASES)
            raise ValueError
        try:
            reader_phase = self.action_param.dataset_to_infer
        except AttributeError:
            # dataset_to_infer is optional; fall back to the default phase
            reader_phase = None
        file_lists = data_partitioner.get_file_lists_by(phase=reader_phase,
                                                        action=self.action)
        # one image reader per file list (e.g. train and validation splits)
        self.readers = [
            ImageReader(image_reader_names).initialise(data_param, task_param,
                                                       file_list)
            for file_list in file_lists
        ]
        # inference draws a single window per volume
        if self.is_inference:
            self.action_param.sample_per_volume = 1
        if csv_reader_names is not None and list(csv_reader_names):
            self.csv_readers = [
                CSVReader(csv_reader_names).initialise(
                    data_param,
                    task_param,
                    file_list,
                    sample_per_volume=self.action_param.sample_per_volume)
                for file_list in file_lists
            ]
        else:
            # keep csv_readers aligned with readers even when unused
            self.csv_readers = [None for file_list in file_lists]

        # optional foreground mask restricts normalisation to foreground
        foreground_masking_layer = BinaryMaskingLayer(
            type_str=self.net_param.foreground_type,
            multimod_fusion=self.net_param.multimod_foreground_type,
            threshold=0.0) \
            if self.net_param.normalise_foreground_only else None

        mean_var_normaliser = MeanVarNormalisationLayer(
            image_name='image', binary_masking_func=foreground_masking_layer) \
            if self.net_param.whitening else None
        histogram_normaliser = HistogramNormalisationLayer(
            image_name='image',
            modalities=vars(task_param).get('image'),
            model_filename=self.net_param.histogram_ref_file,
            binary_masking_func=foreground_masking_layer,
            norm_type=self.net_param.norm_type,
            cutoff=self.net_param.cutoff,
            name='hist_norm_layer') \
            if (self.net_param.histogram_ref_file and
                self.net_param.normalisation) else None

        label_normaliser = DiscreteLabelNormalisationLayer(
            image_name='label',
            modalities=vars(task_param).get('label'),
            model_filename=self.net_param.histogram_ref_file) \
            if (self.net_param.histogram_ref_file and
                task_param.label_normalisation) else None

        normalisation_layers = []
        if histogram_normaliser is not None:
            normalisation_layers.append(histogram_normaliser)
        if mean_var_normaliser is not None:
            normalisation_layers.append(mean_var_normaliser)
        if label_normaliser is not None:
            normalisation_layers.append(label_normaliser)

        # random augmentations apply at training time only
        augmentation_layers = []
        if self.is_training:
            train_param = self.action_param
            if train_param.random_flipping_axes != -1:
                augmentation_layers.append(
                    RandomFlipLayer(
                        flip_axes=train_param.random_flipping_axes))
            if train_param.scaling_percentage:
                augmentation_layers.append(
                    RandomSpatialScalingLayer(
                        min_percentage=train_param.scaling_percentage[0],
                        max_percentage=train_param.scaling_percentage[1]))
            if train_param.rotation_angle or \
                    self.action_param.rotation_angle_x or \
                    self.action_param.rotation_angle_y or \
                    self.action_param.rotation_angle_z:
                rotation_layer = RandomRotationLayer()
                if train_param.rotation_angle:
                    # a single angle range applies uniformly to all axes
                    rotation_layer.init_uniform_angle(
                        train_param.rotation_angle)
                else:
                    rotation_layer.init_non_uniform_angle(
                        self.action_param.rotation_angle_x,
                        self.action_param.rotation_angle_y,
                        self.action_param.rotation_angle_z)
                augmentation_layers.append(rotation_layer)

        # only add augmentation to first reader (not validation reader)
        self.readers[0].add_preprocessing_layers(normalisation_layers +
                                                 augmentation_layers)
        for reader in self.readers[1:]:
            reader.add_preprocessing_layers(normalisation_layers)

    def initialise_uniform_sampler(self):
        '''
        Create the uniform sampler using information from readers
        :return:
        '''
        # nested list: outer index is the GPU/tower, inner is the reader split
        self.sampler = [[
            UniformSampler(
                reader=reader,
                csv_reader=csv_reader,
                window_sizes=self.data_param,
                batch_size=self.net_param.batch_size,
                windows_per_image=self.action_param.sample_per_volume,
                queue_length=self.net_param.queue_length)
            for reader, csv_reader in zip(self.readers, self.csv_readers)
        ]]

    def initialise_weighted_sampler(self):
        '''
        Create the weighted sampler using the info from the csv_readers and
        image_readers and the configuration parameters
        :return:
        '''
        self.sampler = [[
            WeightedSampler(
                reader=reader,
                csv_reader=csv_reader,
                window_sizes=self.data_param,
                batch_size=self.net_param.batch_size,
                windows_per_image=self.action_param.sample_per_volume,
                queue_length=self.net_param.queue_length)
            for reader, csv_reader in zip(self.readers, self.csv_readers)
        ]]

    def initialise_resize_sampler(self):
        '''
        Define the resize sampler using the information from the
        configuration parameters, csv_readers and image_readers
        :return:
        '''
        self.sampler = [[
            ResizeSampler(reader=reader,
                          csv_reader=csv_reader,
                          window_sizes=self.data_param,
                          batch_size=self.net_param.batch_size,
                          shuffle=self.is_training,
                          smaller_final_batch_mode=self.net_param.
                          smaller_final_batch_mode,
                          queue_length=self.net_param.queue_length)
            for reader, csv_reader in zip(self.readers, self.csv_readers)
        ]]

    def initialise_grid_sampler(self):
        '''
        Define the grid sampler based on the information from configuration
        and the csv_readers and image_readers specifications
        :return:
        '''
        self.sampler = [[
            GridSampler(
                reader=reader,
                csv_reader=csv_reader,
                window_sizes=self.data_param,
                batch_size=self.net_param.batch_size,
                spatial_window_size=self.action_param.spatial_window_size,
                window_border=self.action_param.border,
                smaller_final_batch_mode=self.net_param.
                smaller_final_batch_mode,
                queue_length=self.net_param.queue_length)
            for reader, csv_reader in zip(self.readers, self.csv_readers)
        ]]

    def initialise_balanced_sampler(self):
        '''
        Define the balanced sampler based on the information from configuration
        and the csv_readers and image_readers specifications
        :return:
        '''
        self.sampler = [[
            BalancedSampler(
                reader=reader,
                csv_reader=csv_reader,
                window_sizes=self.data_param,
                batch_size=self.net_param.batch_size,
                windows_per_image=self.action_param.sample_per_volume,
                queue_length=self.net_param.queue_length)
            for reader, csv_reader in zip(self.readers, self.csv_readers)
        ]]

    def initialise_grid_aggregator(self):
        '''
        Define the grid aggregator used for decoding using configuration
        parameters
        :return:
        '''
        self.output_decoder = GridSamplesAggregator(
            image_reader=self.readers[0],
            output_path=self.action_param.save_seg_dir,
            window_border=self.action_param.border,
            interp_order=self.action_param.output_interp_order,
            postfix=self.action_param.output_postfix,
            fill_constant=self.action_param.fill_constant)

    def initialise_resize_aggregator(self):
        '''
        Define the resize aggregator used for decoding using the
        configuration parameters
        :return:
        '''
        self.output_decoder = ResizeSamplesAggregator(
            image_reader=self.readers[0],
            output_path=self.action_param.save_seg_dir,
            window_border=self.action_param.border,
            interp_order=self.action_param.output_interp_order,
            postfix=self.action_param.output_postfix)

    def initialise_sampler(self):
        '''
        Specifies the sampler used among those previously defined based on
        the sampling choice
        :return:
        '''
        # index 0 = training sampler, index 1 = inference sampler
        if self.is_training:
            self.SUPPORTED_SAMPLING[self.net_param.window_sampling][0]()
        elif self.is_inference:
            self.SUPPORTED_SAMPLING[self.net_param.window_sampling][1]()

    def initialise_aggregator(self):
        '''
        Specifies the aggregator used based on the sampling choice
        :return:
        '''
        self.SUPPORTED_SAMPLING[self.net_param.window_sampling][2]()

    def initialise_network(self):
        '''
        Initialise the network and specifies the ordering of elements
        :return:
        '''
        w_regularizer = None
        b_regularizer = None
        reg_type = self.net_param.reg_type.lower()
        decay = self.net_param.decay
        if reg_type == 'l2' and decay > 0:
            from tensorflow.contrib.layers.python.layers import regularizers
            w_regularizer = regularizers.l2_regularizer(decay)
            b_regularizer = regularizers.l2_regularizer(decay)
        elif reg_type == 'l1' and decay > 0:
            from tensorflow.contrib.layers.python.layers import regularizers
            w_regularizer = regularizers.l1_regularizer(decay)
            b_regularizer = regularizers.l1_regularizer(decay)

        # NOTE(review): network classes are hard-coded to contrib toy
        # networks rather than taken from self.net_param.name — confirm
        # this is intentional for this contrib application
        self.net = ApplicationNetFactory.create(
            'niftynet.contrib.csv_reader.toynet_features.ToyNetFeat')(
                num_classes=self.segmentation_param.num_classes,
                w_initializer=InitializerFactory.get_initializer(
                    name=self.net_param.weight_initializer),
                b_initializer=InitializerFactory.get_initializer(
                    name=self.net_param.bias_initializer),
                w_regularizer=w_regularizer,
                b_regularizer=b_regularizer,
                acti_func=self.net_param.activation_function)
        # second network maps shared features to (segmentation, class) heads
        self.net_multi = ApplicationNetFactory.create(
            'niftynet.contrib.csv_reader.class_seg_finnet.ClassSegFinnet')(
                num_classes=self.segmentation_param.num_classes,
                w_initializer=InitializerFactory.get_initializer(
                    name=self.net_param.weight_initializer),
                b_initializer=InitializerFactory.get_initializer(
                    name=self.net_param.bias_initializer),
                w_regularizer=w_regularizer,
                b_regularizer=b_regularizer,
                acti_func=self.net_param.activation_function)

    def add_confusion_matrix_summaries_(self, outputs_collector, net_out,
                                        data_dict):
        """ This method defines several monitoring metrics that
        are derived from the confusion matrix """
        labels = tf.reshape(tf.cast(data_dict['label'], tf.int64), [-1])
        prediction = tf.reshape(tf.argmax(net_out, -1), [-1])
        # NOTE(review): num_classes is hard-coded to 2 here, shadowing the
        # configured class count, yet the else-branch below renders a
        # multi-class matrix — the matrix will always be 2x2; confirm.
        num_classes = 2
        conf_mat = tf.confusion_matrix(labels, prediction, num_classes)
        conf_mat = tf.to_float(conf_mat)
        if self.segmentation_param.num_classes == 2:
            # binary case: expose the four confusion-matrix cells as scalars
            outputs_collector.add_to_collection(var=conf_mat[1][1],
                                                name='true_positives',
                                                average_over_devices=True,
                                                summary_type='scalar',
                                                collection=TF_SUMMARIES)
            outputs_collector.add_to_collection(var=conf_mat[1][0],
                                                name='false_negatives',
                                                average_over_devices=True,
                                                summary_type='scalar',
                                                collection=TF_SUMMARIES)
            outputs_collector.add_to_collection(var=conf_mat[0][1],
                                                name='false_positives',
                                                average_over_devices=True,
                                                summary_type='scalar',
                                                collection=TF_SUMMARIES)
            outputs_collector.add_to_collection(var=conf_mat[0][0],
                                                name='true_negatives',
                                                average_over_devices=True,
                                                summary_type='scalar',
                                                collection=TF_SUMMARIES)
        else:
            # multi-class case: render the whole matrix as an image summary
            outputs_collector.add_to_collection(var=conf_mat[tf.newaxis, :, :,
                                                             tf.newaxis],
                                                name='confusion_matrix',
                                                average_over_devices=True,
                                                summary_type='image',
                                                collection=TF_SUMMARIES)

        # NOTE(review): trace is the raw count of correct predictions,
        # not a normalised accuracy — confirm this naming is intended
        outputs_collector.add_to_collection(var=tf.trace(conf_mat),
                                            name='accuracy',
                                            average_over_devices=True,
                                            summary_type='scalar',
                                            collection=TF_SUMMARIES)

    def connect_data_and_network(self,
                                 outputs_collector=None,
                                 gradients_collector=None):
        """Build the TF graph: sample windows, run both networks, and
        register losses/gradients (training) or decoded outputs (inference).

        :param outputs_collector: collector for console/summary/output vars
        :param gradients_collector: collector for training gradients
        """
        def switch_sampler(for_training):
            # pick the training or validation sampler and pop one batch
            with tf.name_scope('train' if for_training else 'validation'):
                sampler = self.get_sampler()[0][0 if for_training else -1]
                return sampler.pop_batch_op()

        if self.is_training:
            if self.action_param.validation_every_n > 0:
                # graph-level switch between train and validation batches
                data_dict = tf.cond(tf.logical_not(self.is_validation),
                                    lambda: switch_sampler(for_training=True),
                                    lambda: switch_sampler(for_training=False))
            else:
                data_dict = switch_sampler(for_training=True)

            image = tf.cast(data_dict['image'], tf.float32)
            net_args = {
                'is_training': self.is_training,
                'keep_prob': self.net_param.keep_prob
            }
            net_out = self.net(image, **net_args)
            # shared features feed two heads: segmentation and classification
            net_out_seg, net_out_class = self.net_multi(
                net_out, self.is_training)

            with tf.name_scope('Optimiser'):
                optimiser_class = OptimiserFactory.create(
                    name=self.action_param.optimiser)
                self.optimiser = optimiser_class.get_instance(
                    learning_rate=self.action_param.lr)
            # classification loss is fixed to 2 classes; segmentation loss
            # uses the configured class count
            loss_func_class = LossFunctionClassification(
                n_class=2, loss_type='CrossEntropy')
            loss_func_seg = LossFunctionSegmentation(
                n_class=self.segmentation_param.num_classes,
                loss_type=self.action_param.loss_type)
            data_loss_seg = loss_func_seg(prediction=net_out_seg,
                                          ground_truth=data_dict.get(
                                              'label', None))
            data_loss_class = loss_func_class(prediction=net_out_class,
                                              ground_truth=data_dict.get(
                                                  'value', None))
            reg_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
            if self.net_param.decay > 0.0 and reg_losses:
                reg_loss = tf.reduce_mean(
                    [tf.reduce_mean(reg_loss) for reg_loss in reg_losses])
                loss = data_loss_seg + data_loss_class + reg_loss
            else:
                loss = data_loss_seg + data_loss_class
            self.total_loss = loss
            # NOTE(review): tf.Print is a leftover debug hook (message='test')
            # that logs loss and head shapes on every step — confirm whether
            # it should remain in the training graph
            self.total_loss = tf.Print(
                tf.cast(self.total_loss, tf.float32),
                [loss, tf.shape(net_out_seg),
                 tf.shape(net_out_class)],
                message='test')
            grads = self.optimiser.compute_gradients(
                loss, colocate_gradients_with_ops=True)
            # collecting gradients variables
            gradients_collector.add_to_collection([grads])
            # collecting output variables
            # NOTE(review): both the classification loss (console) and the
            # segmentation loss (summary) are registered under the same
            # name 'data_loss' — confirm this is intended
            outputs_collector.add_to_collection(var=data_loss_class,
                                                name='data_loss',
                                                average_over_devices=False,
                                                collection=CONSOLE)
            outputs_collector.add_to_collection(var=data_loss_seg,
                                                name='data_loss',
                                                average_over_devices=True,
                                                summary_type='scalar',
                                                collection=TF_SUMMARIES)
            # self.add_confusion_matrix_summaries_(outputs_collector,
            #                                      net_out_class,
            #                                      data_dict)
        else:
            # converting logits into final output for
            # classification probabilities or argmax classification labels
            data_dict = switch_sampler(for_training=False)
            image = tf.cast(data_dict['image'], tf.float32)
            net_args = {
                'is_training': self.is_training,
                'keep_prob': self.net_param.keep_prob
            }
            net_out = self.net(image, **net_args)
            net_out_seg, net_out_class = self.net_multi(
                net_out, self.is_training)
            tf.logging.info('net_out.shape may need to be resized: %s',
                            net_out.shape)
            output_prob = self.segmentation_param.output_prob
            num_classes = self.segmentation_param.num_classes
            # NOTE(review): the class head is post-processed with the
            # segmentation num_classes and the seg head with 2 classes,
            # while training uses num_classes for seg and 2 for
            # classification — the two look swapped; confirm.
            if output_prob and num_classes > 1:
                post_process_layer_class = PostProcessingLayer(
                    'SOFTMAX', num_classes=num_classes)
                post_process_layer_seg = PostProcessingLayer('SOFTMAX',
                                                             num_classes=2)
            elif not output_prob and num_classes > 1:
                post_process_layer_class = PostProcessingLayer(
                    'ARGMAX', num_classes=num_classes)
                post_process_layer_seg = PostProcessingLayer('ARGMAX',
                                                             num_classes=2)
            else:
                post_process_layer_class = PostProcessingLayer(
                    'IDENTITY', num_classes=num_classes)
                post_process_layer_seg = PostProcessingLayer('IDENTITY',
                                                             num_classes=2)

            net_out_class = post_process_layer_class(net_out_class)
            net_out_seg = post_process_layer_seg(net_out_seg)

            outputs_collector.add_to_collection(var=net_out_seg,
                                                name='seg',
                                                average_over_devices=False,
                                                collection=NETWORK_OUTPUT)
            outputs_collector.add_to_collection(var=net_out_class,
                                                name='value',
                                                average_over_devices=False,
                                                collection=NETWORK_OUTPUT)
            outputs_collector.add_to_collection(
                var=data_dict['image_location'],
                name='location',
                average_over_devices=False,
                collection=NETWORK_OUTPUT)
            self.initialise_aggregator()

    def interpret_output(self, batch_output):
        '''
        Specifies how the output should be decoded
        :param batch_output:
        :return:
        '''
        if not self.is_training:
            return self.output_decoder.decode_batch(
                {
                    'window_seg': batch_output['seg'],
                    'csv_class': batch_output['value']
                }, batch_output['location'])
        return True

    def initialise_evaluator(self, eval_param):
        '''
        Define the evaluator
        :param eval_param:
        :return:
        '''
        self.eval_param = eval_param
        self.evaluator = ClassificationEvaluator(self.readers[0],
                                                 self.segmentation_param,
                                                 eval_param)

    def add_inferred_output(self, data_param, task_param):
        '''
        Define how to treat added inferred output
        :param data_param:
        :param task_param:
        :return:
        '''
        return self.add_inferred_output_like(data_param, task_param, 'label')
# Esempio n. 14
# 0
class RegApp(BaseApplication):

    REQUIRED_CONFIG_SECTION = "REGISTRATION"

    def __init__(self, net_param, action_param, action):
        BaseApplication.__init__(self)
        tf.logging.info('starting label-driven registration')
        self.action = action

        self.net_param = net_param
        self.action_param = action_param

        self.registration_param = None
        self.data_param = None

    def initialise_dataset_loader(self,
                                  data_param=None,
                                  task_param=None,
                                  data_partitioner=None):
        self.data_param = data_param
        self.registration_param = task_param

        if self.is_evaluation:
            NotImplementedError('Evaluation is not yet '
                                'supported in this application.')
        try:
            reader_phase = self.action_param.dataset_to_infer
        except AttributeError:
            reader_phase = None
        file_lists = data_partitioner.get_file_lists_by(phase=reader_phase,
                                                        action=self.action)

        self.readers = []
        for file_list in file_lists:
            fixed_reader = ImageReader({'fixed_image', 'fixed_label'})
            fixed_reader.initialise(data_param, task_param, file_list)
            self.readers.append(fixed_reader)

            moving_reader = ImageReader({'moving_image', 'moving_label'})
            moving_reader.initialise(data_param, task_param, file_list)
            self.readers.append(moving_reader)

        # pad the fixed target only
        # moving image will be resampled to match the targets
        #volume_padding_layer = []
        #if self.net_param.volume_padding_size:
        #    volume_padding_layer.append(PadLayer(
        #        image_name=('fixed_image', 'fixed_label'),
        #        border=self.net_param.volume_padding_size))

        #for reader in self.readers:
        #    reader.add_preprocessing_layers(volume_padding_layer)

    def initialise_sampler(self):
        if self.is_training:
            self.sampler = []
            assert len(self.readers) >= 2, 'at least two readers are required'
            training_sampler = PairwiseUniformSampler(
                reader_0=self.readers[0],
                reader_1=self.readers[1],
                data_param=self.data_param,
                batch_size=self.net_param.batch_size)
            self.sampler.append(training_sampler)
            # adding validation readers if possible
            if len(self.readers) >= 4:
                validation_sampler = PairwiseUniformSampler(
                    reader_0=self.readers[2],
                    reader_1=self.readers[3],
                    data_param=self.data_param,
                    batch_size=self.net_param.batch_size)
                self.sampler.append(validation_sampler)
        else:
            self.sampler = PairwiseResizeSampler(
                reader_0=self.readers[0],
                reader_1=self.readers[1],
                data_param=self.data_param,
                batch_size=self.net_param.batch_size)

    def initialise_network(self):
        decay = self.net_param.decay
        self.net = ApplicationNetFactory.create(self.net_param.name)(decay)

    def connect_data_and_network(self,
                                 outputs_collector=None,
                                 gradients_collector=None):
        def switch_samplers(for_training):
            with tf.name_scope('train' if for_training else 'validation'):
                sampler = self.get_sampler()[0 if for_training else -1]
                return sampler()  # returns image only

        if self.is_training:
            self.patience = self.action_param.patience
            self.mode = self.action_param.early_stopping_mode
            if self.action_param.validation_every_n > 0:
                sampler_window = \
                    tf.cond(tf.logical_not(self.is_validation),
                            lambda: switch_samplers(True),
                            lambda: switch_samplers(False))
            else:
                sampler_window = switch_samplers(True)

            image_windows, _ = sampler_window
            # image_windows, locations = sampler_window

            # decode channels for moving and fixed images
            image_windows_list = [
                tf.expand_dims(img, axis=-1)
                for img in tf.unstack(image_windows, axis=-1)
            ]
            fixed_image, fixed_label, moving_image, moving_label = \
                image_windows_list

            # estimate ddf
            dense_field = self.net(fixed_image, moving_image)
            if isinstance(dense_field, tuple):
                dense_field = dense_field[0]

            # transform the moving labels
            resampler = ResamplerLayer(interpolation='linear',
                                       boundary='replicate')
            resampled_moving_label = resampler(moving_label, dense_field)

            # compute label loss (foreground only)
            loss_func = LossFunction(n_class=1,
                                     loss_type=self.action_param.loss_type,
                                     softmax=False)
            label_loss = loss_func(prediction=resampled_moving_label,
                                   ground_truth=fixed_label)

            dice_fg = 1.0 - label_loss
            # appending regularisation loss
            total_loss = label_loss
            reg_loss = tf.get_collection('bending_energy')
            if reg_loss:
                total_loss = total_loss + \
                    self.net_param.decay * tf.reduce_mean(reg_loss)

            self.total_loss = total_loss

            # compute training gradients
            with tf.name_scope('Optimiser'):
                optimiser_class = OptimiserFactory.create(
                    name=self.action_param.optimiser)
                self.optimiser = optimiser_class.get_instance(
                    learning_rate=self.action_param.lr)
            grads = self.optimiser.compute_gradients(
                total_loss, colocate_gradients_with_ops=True)
            gradients_collector.add_to_collection(grads)

            metrics_dice = loss_func(
                prediction=tf.to_float(resampled_moving_label >= 0.5),
                ground_truth=tf.to_float(fixed_label >= 0.5))
            metrics_dice = 1.0 - metrics_dice

            # command line output
            outputs_collector.add_to_collection(var=dice_fg,
                                                name='one_minus_data_loss',
                                                collection=CONSOLE)
            outputs_collector.add_to_collection(var=tf.reduce_mean(reg_loss),
                                                name='bending_energy',
                                                collection=CONSOLE)
            outputs_collector.add_to_collection(var=total_loss,
                                                name='total_loss',
                                                collection=CONSOLE)
            outputs_collector.add_to_collection(var=metrics_dice,
                                                name='ave_fg_dice',
                                                collection=CONSOLE)

            # for tensorboard
            outputs_collector.add_to_collection(var=dice_fg,
                                                name='data_loss',
                                                average_over_devices=True,
                                                summary_type='scalar',
                                                collection=TF_SUMMARIES)
            outputs_collector.add_to_collection(var=total_loss,
                                                name='total_loss',
                                                average_over_devices=True,
                                                summary_type='scalar',
                                                collection=TF_SUMMARIES)
            outputs_collector.add_to_collection(
                var=metrics_dice,
                name='averaged_foreground_Dice',
                average_over_devices=True,
                summary_type='scalar',
                collection=TF_SUMMARIES)

            # for visualisation debugging
            # resampled_moving_image = resampler(moving_image, dense_field)
            # outputs_collector.add_to_collection(
            #     var=fixed_image, name='fixed_image',
            #     collection=NETWORK_OUTPUT)
            # outputs_collector.add_to_collection(
            #     var=fixed_label, name='fixed_label',
            #     collection=NETWORK_OUTPUT)
            # outputs_collector.add_to_collection(
            #     var=moving_image, name='moving_image',
            #     collection=NETWORK_OUTPUT)
            # outputs_collector.add_to_collection(
            #     var=moving_label, name='moving_label',
            #     collection=NETWORK_OUTPUT)
            # outputs_collector.add_to_collection(
            #     var=resampled_moving_image, name='resampled_image',
            #     collection=NETWORK_OUTPUT)
            # outputs_collector.add_to_collection(
            #     var=resampled_moving_label, name='resampled_label',
            #     collection=NETWORK_OUTPUT)
            # outputs_collector.add_to_collection(
            #     var=dense_field, name='ddf', collection=NETWORK_OUTPUT)
            # outputs_collector.add_to_collection(
            #     var=locations, name='locations', collection=NETWORK_OUTPUT)

            # outputs_collector.add_to_collection(
            #     var=shift[0], name='a', collection=CONSOLE)
            # outputs_collector.add_to_collection(
            #     var=shift[1], name='b', collection=CONSOLE)
        else:
            image_windows, locations = self.sampler()
            image_windows_list = [
                tf.expand_dims(img, axis=-1)
                for img in tf.unstack(image_windows, axis=-1)
            ]
            fixed_image, fixed_label, moving_image, moving_label = \
                image_windows_list

            dense_field = self.net(fixed_image, moving_image)
            if isinstance(dense_field, tuple):
                dense_field = dense_field[0]

            # transform the moving labels
            resampler = ResamplerLayer(interpolation='linear',
                                       boundary='replicate')
            resampled_moving_image = resampler(moving_image, dense_field)
            resampled_moving_label = resampler(moving_label, dense_field)

            outputs_collector.add_to_collection(var=fixed_image,
                                                name='fixed_image',
                                                collection=NETWORK_OUTPUT)
            outputs_collector.add_to_collection(var=moving_image,
                                                name='moving_image',
                                                collection=NETWORK_OUTPUT)
            outputs_collector.add_to_collection(var=resampled_moving_image,
                                                name='resampled_moving_image',
                                                collection=NETWORK_OUTPUT)
            outputs_collector.add_to_collection(var=resampled_moving_label,
                                                name='resampled_moving_label',
                                                collection=NETWORK_OUTPUT)

            outputs_collector.add_to_collection(var=fixed_label,
                                                name='fixed_label',
                                                collection=NETWORK_OUTPUT)
            outputs_collector.add_to_collection(var=moving_label,
                                                name='moving_label',
                                                collection=NETWORK_OUTPUT)
            #outputs_collector.add_to_collection(
            #    var=dense_field, name='field',
            #    collection=NETWORK_OUTPUT)
            outputs_collector.add_to_collection(var=locations,
                                                name='locations',
                                                collection=NETWORK_OUTPUT)

            self.output_decoder = ResizeSamplesAggregator(
                image_reader=self.readers[0],  # fixed image reader
                name='fixed_image',
                output_path=self.action_param.save_seg_dir,
                interp_order=self.action_param.output_interp_order)

    def interpret_output(self, batch_output):
        """Forward inference windows to the aggregator; no-op in training.

        Returns ``True`` while training; otherwise hands the resampled
        moving-image windows and their locations to the output decoder
        and returns its "more batches expected" flag.
        """
        if not self.is_training:
            window = {
                'window_resampled': batch_output['resampled_moving_image']}
            return self.output_decoder.decode_batch(
                window, batch_output['locations'])
        return True
class RegressionApplication(BaseApplication):
    """NiftyNet application for image regression.

    Wires together an ``ImageReader`` (with padding/normalisation/
    augmentation layers), a window sampler chosen by
    ``net_param.window_sampling``, the regression network, the training
    loss/gradient graph, and a window aggregator that reassembles
    inference outputs into full volumes.
    """

    REQUIRED_CONFIG_SECTION = "REGRESSION"

    def __init__(self, net_param, action_param, is_training):
        """Store configuration and the sampler/aggregator dispatch table.

        Args:
            net_param: network-level configuration (batch size, queue
                length, decay, window sampling mode, ...).
            action_param: action-level configuration (optimiser, lr,
                loss type, border, save_seg_dir, ...).
            is_training: True for the training action, False for inference.
        """
        BaseApplication.__init__(self)
        tf.logging.info('starting regression application')
        self.is_training = is_training

        self.net_param = net_param
        self.action_param = action_param
        # filled in later by initialise_dataset_loader
        self.regression_param = None

        self.data_param = None
        # window_sampling mode -> (training sampler initialiser,
        # inference sampler initialiser, inference aggregator initialiser)
        self.SUPPORTED_SAMPLING = {
            'uniform':
            (self.initialise_uniform_sampler, self.initialise_grid_sampler,
             self.initialise_grid_aggregator),
            'weighted':
            (self.initialise_weighted_sampler, self.initialise_grid_sampler,
             self.initialise_grid_aggregator),
            'resize':
            (self.initialise_resize_sampler, self.initialise_resize_sampler,
             self.initialise_resize_aggregator),
        }

    def initialise_dataset_loader(self, data_param=None, task_param=None):
        """Create the image reader and its preprocessing pipeline.

        Training readers load every supported input section; inference
        readers load the 'image' section only.  Preprocessing layers are
        applied in the order: padding, normalisation (histogram and/or
        whitening), then augmentation.
        """
        self.data_param = data_param
        self.regression_param = task_param

        # read each line of csv files into an instance of Subject
        if self.is_training:
            self.reader = ImageReader(SUPPORTED_INPUT)
        else:  # in the inference process use image input only
            self.reader = ImageReader(['image'])
        self.reader.initialise_reader(data_param, task_param)

        mean_var_normaliser = MeanVarNormalisationLayer(image_name='image')
        if self.net_param.histogram_ref_file:
            histogram_normaliser = HistogramNormalisationLayer(
                image_name='image',
                modalities=vars(task_param).get('image'),
                model_filename=self.net_param.histogram_ref_file,
                norm_type=self.net_param.norm_type,
                cutoff=self.net_param.cutoff,
                name='hist_norm_layer')
        else:
            histogram_normaliser = None

        normalisation_layers = []
        if self.net_param.normalisation:
            # NOTE(review): if `normalisation` is set without a
            # histogram_ref_file this appends None to the layer list --
            # confirm upstream config validation forbids that combination.
            normalisation_layers.append(histogram_normaliser)
        if self.net_param.whitening:
            normalisation_layers.append(mean_var_normaliser)

        augmentation_layers = []
        if self.is_training:
            # -1 flip axes means "no flipping requested"
            if self.action_param.random_flipping_axes != -1:
                augmentation_layers.append(
                    RandomFlipLayer(
                        flip_axes=self.action_param.random_flipping_axes))
            if self.action_param.scaling_percentage:
                augmentation_layers.append(
                    RandomSpatialScalingLayer(
                        min_percentage=self.action_param.scaling_percentage[0],
                        max_percentage=self.action_param.scaling_percentage[1])
                )
            if self.action_param.rotation_angle:
                augmentation_layers.append(RandomRotationLayer())
                augmentation_layers[-1].init_uniform_angle(
                    self.action_param.rotation_angle)

        volume_padding_layer = []
        if self.net_param.volume_padding_size:
            volume_padding_layer.append(
                PadLayer(image_name=SUPPORTED_INPUT,
                         border=self.net_param.volume_padding_size))
        self.reader.add_preprocessing_layers(volume_padding_layer +
                                             normalisation_layers +
                                             augmentation_layers)

    def initialise_uniform_sampler(self):
        """Training sampler: windows drawn uniformly from each volume."""
        self.sampler = [
            UniformSampler(
                reader=self.reader,
                data_param=self.data_param,
                batch_size=self.net_param.batch_size,
                windows_per_image=self.action_param.sample_per_volume,
                queue_length=self.net_param.queue_length)
        ]

    def initialise_weighted_sampler(self):
        """Training sampler: windows drawn according to a weight map."""
        self.sampler = [
            WeightedSampler(
                reader=self.reader,
                data_param=self.data_param,
                batch_size=self.net_param.batch_size,
                windows_per_image=self.action_param.sample_per_volume,
                queue_length=self.net_param.queue_length)
        ]

    def initialise_resize_sampler(self):
        """Sampler that resizes each whole volume to the window size.

        Used for both training and inference in 'resize' mode; shuffling
        is enabled only while training.
        """
        self.sampler = [
            ResizeSampler(reader=self.reader,
                          data_param=self.data_param,
                          batch_size=self.net_param.batch_size,
                          shuffle_buffer=self.is_training,
                          queue_length=self.net_param.queue_length)
        ]

    def initialise_grid_sampler(self):
        """Inference sampler: sliding-window grid over each volume."""
        self.sampler = [
            GridSampler(
                reader=self.reader,
                data_param=self.data_param,
                batch_size=self.net_param.batch_size,
                spatial_window_size=self.action_param.spatial_window_size,
                window_border=self.action_param.border,
                queue_length=self.net_param.queue_length)
        ]

    def initialise_grid_aggregator(self):
        """Aggregator matching the grid sampler: stitches windows back."""
        self.output_decoder = GridSamplesAggregator(
            image_reader=self.reader,
            output_path=self.action_param.save_seg_dir,
            window_border=self.action_param.border,
            interp_order=self.action_param.output_interp_order)

    def initialise_resize_aggregator(self):
        """Aggregator matching the resize sampler: resizes outputs back."""
        # NOTE(review): `window_border` is passed here but not to the
        # ResizeSamplesAggregator used by the registration code above --
        # confirm the aggregator accepts this keyword.
        self.output_decoder = ResizeSamplesAggregator(
            image_reader=self.reader,
            output_path=self.action_param.save_seg_dir,
            window_border=self.action_param.border,
            interp_order=self.action_param.output_interp_order)

    def initialise_sampler(self):
        """Pick the sampler initialiser for the current action/mode."""
        if self.is_training:
            self.SUPPORTED_SAMPLING[self.net_param.window_sampling][0]()
        else:
            self.SUPPORTED_SAMPLING[self.net_param.window_sampling][1]()

    def initialise_network(self):
        """Instantiate the single-output regression network.

        L1/L2 weight and bias regularisers are attached only when both a
        regularisation type and a positive decay are configured.
        """
        w_regularizer = None
        b_regularizer = None
        reg_type = self.net_param.reg_type.lower()
        decay = self.net_param.decay
        if reg_type == 'l2' and decay > 0:
            from tensorflow.contrib.layers.python.layers import regularizers
            w_regularizer = regularizers.l2_regularizer(decay)
            b_regularizer = regularizers.l2_regularizer(decay)
        elif reg_type == 'l1' and decay > 0:
            from tensorflow.contrib.layers.python.layers import regularizers
            w_regularizer = regularizers.l1_regularizer(decay)
            b_regularizer = regularizers.l1_regularizer(decay)

        self.net = ApplicationNetFactory.create(self.net_param.name)(
            num_classes=1,
            w_regularizer=w_regularizer,
            b_regularizer=b_regularizer,
            acti_func=self.net_param.activation_function)

    def connect_data_and_network(self,
                                 outputs_collector=None,
                                 gradients_collector=None):
        """Build the TF graph.

        Training: crops prediction/target/weights by ``loss_border``,
        computes the data loss (+ mean regularisation loss when decay is
        positive), and registers gradients and console/summary outputs.
        Inference: applies identity post-processing and registers the
        output windows and their locations for aggregation.
        """
        data_dict = self.get_sampler()[0].pop_batch_op()
        image = tf.cast(data_dict['image'], tf.float32)
        net_out = self.net(image, self.is_training)

        if self.is_training:
            crop_layer = CropLayer(border=self.regression_param.loss_border,
                                   name='crop-88')
            with tf.name_scope('Optimiser'):
                optimiser_class = OptimiserFactory.create(
                    name=self.action_param.optimiser)
                self.optimiser = optimiser_class.get_instance(
                    learning_rate=self.action_param.lr)
            loss_func = LossFunction(loss_type=self.action_param.loss_type)

            # loss is computed on border-trimmed windows only
            prediction = crop_layer(net_out)
            ground_truth = crop_layer(data_dict.get('output', None))
            weight_map = None if data_dict.get('weight', None) is None \
                else crop_layer(data_dict.get('weight', None))
            data_loss = loss_func(prediction=prediction,
                                  ground_truth=ground_truth,
                                  weight_map=weight_map)

            reg_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
            if self.net_param.decay > 0.0 and reg_losses:
                # mean over the per-variable mean regularisation losses
                reg_loss = tf.reduce_mean(
                    [tf.reduce_mean(reg_loss) for reg_loss in reg_losses])
                loss = data_loss + reg_loss
            else:
                loss = data_loss
            grads = self.optimiser.compute_gradients(loss)
            # collecting gradients variables
            gradients_collector.add_to_collection([grads])
            # collecting output variables
            outputs_collector.add_to_collection(var=data_loss,
                                                name='Loss',
                                                average_over_devices=False,
                                                collection=CONSOLE)
            outputs_collector.add_to_collection(var=data_loss,
                                                name='Loss',
                                                average_over_devices=True,
                                                summary_type='scalar',
                                                collection=TF_SUMMARIES)
        else:
            # inference: no cropping (border=0), identity post-processing
            crop_layer = CropLayer(border=0, name='crop-88')
            post_process_layer = PostProcessingLayer('IDENTITY')
            net_out = post_process_layer(crop_layer(net_out))

            outputs_collector.add_to_collection(var=net_out,
                                                name='window',
                                                average_over_devices=False,
                                                collection=NETWORK_OUTPUT)
            outputs_collector.add_to_collection(
                var=data_dict['image_location'],
                name='location',
                average_over_devices=False,
                collection=NETWORK_OUTPUT)
            init_aggregator = \
                self.SUPPORTED_SAMPLING[self.net_param.window_sampling][2]
            init_aggregator()

    def interpret_output(self, batch_output):
        """Decode inference windows via the aggregator; training is a no-op."""
        if not self.is_training:
            return self.output_decoder.decode_batch(batch_output['window'],
                                                    batch_output['location'])
        else:
            return True
 # NOTE(review): stray duplicate of RegressionApplication's
 # initialise_grid_aggregator at a broken (1-space) indentation level --
 # this looks like a copy/paste or scraping artefact and is a syntax
 # error as it stands; confirm whether this fragment should be removed.
 def initialise_grid_aggregator(self):
     self.output_decoder = GridSamplesAggregator(
         image_reader=self.reader,
         output_path=self.action_param.save_seg_dir,
         window_border=self.action_param.border,
         interp_order=self.action_param.output_interp_order)
# Esempio n. 17 ("Example no. 17") -- snippet-separator artefact from scraping
# 0
    def connect_data_and_network(self,
                                 outputs_collector=None,
                                 gradients_collector=None):
        """Build the registration TF graph.

        Training: samples fixed/moving image-label window pairs (switching
        between train and validation samplers with ``tf.cond`` when
        validation is enabled), estimates a dense displacement field,
        resamples the moving label, and registers the Dice-based label
        loss (plus optional bending-energy regularisation), gradients,
        and console/summary outputs.
        Inference: resamples both moving image and label and registers
        them, the fixed data, and window locations for aggregation.
        """
        def switch_samplers(for_training):
            # pick sampler [0] for training, [-1] for validation,
            # inside a matching name scope
            with tf.name_scope('train' if for_training else 'validation'):
                sampler = self.get_sampler()[0 if for_training else -1]
                return sampler()  # returns image only

        if self.is_training:
            # early-stopping configuration consumed by the driver
            self.patience = self.action_param.patience
            self.mode = self.action_param.early_stopping_mode
            if self.action_param.validation_every_n > 0:
                sampler_window = \
                    tf.cond(tf.logical_not(self.is_validation),
                            lambda: switch_samplers(True),
                            lambda: switch_samplers(False))
            else:
                sampler_window = switch_samplers(True)

            image_windows, _ = sampler_window
            # image_windows, locations = sampler_window

            # decode channels for moving and fixed images
            image_windows_list = [
                tf.expand_dims(img, axis=-1)
                for img in tf.unstack(image_windows, axis=-1)
            ]
            # assumes the sampler stacks channels in this order --
            # TODO(review) confirm against the pairwise sampler
            fixed_image, fixed_label, moving_image, moving_label = \
                image_windows_list

            # estimate ddf
            dense_field = self.net(fixed_image, moving_image)
            if isinstance(dense_field, tuple):
                dense_field = dense_field[0]

            # transform the moving labels
            resampler = ResamplerLayer(interpolation='linear',
                                       boundary='replicate')
            resampled_moving_label = resampler(moving_label, dense_field)

            # compute label loss (foreground only)
            loss_func = LossFunction(n_class=1,
                                     loss_type=self.action_param.loss_type,
                                     softmax=False)
            label_loss = loss_func(prediction=resampled_moving_label,
                                   ground_truth=fixed_label)

            dice_fg = 1.0 - label_loss
            # appending regularisation loss
            total_loss = label_loss
            reg_loss = tf.get_collection('bending_energy')
            if reg_loss:
                total_loss = total_loss + \
                    self.net_param.decay * tf.reduce_mean(reg_loss)

            self.total_loss = total_loss

            # compute training gradients
            with tf.name_scope('Optimiser'):
                optimiser_class = OptimiserFactory.create(
                    name=self.action_param.optimiser)
                self.optimiser = optimiser_class.get_instance(
                    learning_rate=self.action_param.lr)
            grads = self.optimiser.compute_gradients(
                total_loss, colocate_gradients_with_ops=True)
            gradients_collector.add_to_collection(grads)

            # binarised-at-0.5 Dice reported as a metric (not trained on)
            metrics_dice = loss_func(
                prediction=tf.to_float(resampled_moving_label >= 0.5),
                ground_truth=tf.to_float(fixed_label >= 0.5))
            metrics_dice = 1.0 - metrics_dice

            # command line output
            outputs_collector.add_to_collection(var=dice_fg,
                                                name='one_minus_data_loss',
                                                collection=CONSOLE)
            # NOTE(review): reduce_mean over an empty 'bending_energy'
            # collection yields NaN when the net registers no such loss --
            # confirm every supported network populates this collection.
            outputs_collector.add_to_collection(var=tf.reduce_mean(reg_loss),
                                                name='bending_energy',
                                                collection=CONSOLE)
            outputs_collector.add_to_collection(var=total_loss,
                                                name='total_loss',
                                                collection=CONSOLE)
            outputs_collector.add_to_collection(var=metrics_dice,
                                                name='ave_fg_dice',
                                                collection=CONSOLE)

            # for tensorboard
            outputs_collector.add_to_collection(var=dice_fg,
                                                name='data_loss',
                                                average_over_devices=True,
                                                summary_type='scalar',
                                                collection=TF_SUMMARIES)
            outputs_collector.add_to_collection(var=total_loss,
                                                name='total_loss',
                                                average_over_devices=True,
                                                summary_type='scalar',
                                                collection=TF_SUMMARIES)
            outputs_collector.add_to_collection(
                var=metrics_dice,
                name='averaged_foreground_Dice',
                average_over_devices=True,
                summary_type='scalar',
                collection=TF_SUMMARIES)

            # for visualisation debugging
            # resampled_moving_image = resampler(moving_image, dense_field)
            # outputs_collector.add_to_collection(
            #     var=fixed_image, name='fixed_image',
            #     collection=NETWORK_OUTPUT)
            # outputs_collector.add_to_collection(
            #     var=fixed_label, name='fixed_label',
            #     collection=NETWORK_OUTPUT)
            # outputs_collector.add_to_collection(
            #     var=moving_image, name='moving_image',
            #     collection=NETWORK_OUTPUT)
            # outputs_collector.add_to_collection(
            #     var=moving_label, name='moving_label',
            #     collection=NETWORK_OUTPUT)
            # outputs_collector.add_to_collection(
            #     var=resampled_moving_image, name='resampled_image',
            #     collection=NETWORK_OUTPUT)
            # outputs_collector.add_to_collection(
            #     var=resampled_moving_label, name='resampled_label',
            #     collection=NETWORK_OUTPUT)
            # outputs_collector.add_to_collection(
            #     var=dense_field, name='ddf', collection=NETWORK_OUTPUT)
            # outputs_collector.add_to_collection(
            #     var=locations, name='locations', collection=NETWORK_OUTPUT)

            # outputs_collector.add_to_collection(
            #     var=shift[0], name='a', collection=CONSOLE)
            # outputs_collector.add_to_collection(
            #     var=shift[1], name='b', collection=CONSOLE)
        else:
            image_windows, locations = self.sampler()
            image_windows_list = [
                tf.expand_dims(img, axis=-1)
                for img in tf.unstack(image_windows, axis=-1)
            ]
            fixed_image, fixed_label, moving_image, moving_label = \
                image_windows_list

            dense_field = self.net(fixed_image, moving_image)
            if isinstance(dense_field, tuple):
                dense_field = dense_field[0]

            # transform the moving labels
            resampler = ResamplerLayer(interpolation='linear',
                                       boundary='replicate')
            resampled_moving_image = resampler(moving_image, dense_field)
            resampled_moving_label = resampler(moving_label, dense_field)

            # expose everything needed to write out / inspect the result
            outputs_collector.add_to_collection(var=fixed_image,
                                                name='fixed_image',
                                                collection=NETWORK_OUTPUT)
            outputs_collector.add_to_collection(var=moving_image,
                                                name='moving_image',
                                                collection=NETWORK_OUTPUT)
            outputs_collector.add_to_collection(var=resampled_moving_image,
                                                name='resampled_moving_image',
                                                collection=NETWORK_OUTPUT)
            outputs_collector.add_to_collection(var=resampled_moving_label,
                                                name='resampled_moving_label',
                                                collection=NETWORK_OUTPUT)

            outputs_collector.add_to_collection(var=fixed_label,
                                                name='fixed_label',
                                                collection=NETWORK_OUTPUT)
            outputs_collector.add_to_collection(var=moving_label,
                                                name='moving_label',
                                                collection=NETWORK_OUTPUT)
            #outputs_collector.add_to_collection(
            #    var=dense_field, name='field',
            #    collection=NETWORK_OUTPUT)
            outputs_collector.add_to_collection(var=locations,
                                                name='locations',
                                                collection=NETWORK_OUTPUT)

            self.output_decoder = ResizeSamplesAggregator(
                image_reader=self.readers[0],  # fixed image reader
                name='fixed_image',
                output_path=self.action_param.save_seg_dir,
                interp_order=self.action_param.output_interp_order)
class RegApp(BaseApplication):

    REQUIRED_CONFIG_SECTION = "REGISTRATION"

    def __init__(self, net_param, action_param, action):
        """Store configuration for the label-driven registration app.

        Readers, samplers and task parameters are created later by the
        ``initialise_*`` methods.
        """
        BaseApplication.__init__(self)
        tf.logging.info('starting label-driven registration')
        self.action = action
        self.net_param = net_param
        self.action_param = action_param
        # populated by initialise_dataset_loader
        self.registration_param = None
        self.data_param = None

    def initialise_dataset_loader(
            self, data_param=None, task_param=None, data_partitioner=None):
        """Create paired fixed/moving readers from the partition file lists.

        For each file list two readers are appended to ``self.readers``:
        one for the fixed image/label pair and one for the moving pair,
        so readers come in (fixed, moving) couples.

        Args:
            data_param: per-section data configuration.
            task_param: registration task configuration.
            data_partitioner: provides the train/validation file lists.

        Raises:
            NotImplementedError: when running in evaluation mode, which
                this application does not support.
        """
        self.data_param = data_param
        self.registration_param = task_param

        file_lists = self.get_file_lists(data_partitioner)

        if self.is_evaluation:
            # BUG FIX: the original built this exception without raising
            # it, so evaluation mode silently continued into an
            # unsupported code path.
            raise NotImplementedError('Evaluation is not yet '
                                      'supported in this application.')

        self.readers = []
        for file_list in file_lists:
            fixed_reader = ImageReader({'fixed_image', 'fixed_label'})
            fixed_reader.initialise(data_param, task_param, file_list)
            self.readers.append(fixed_reader)

            moving_reader = ImageReader({'moving_image', 'moving_label'})
            moving_reader.initialise(data_param, task_param, file_list)
            self.readers.append(moving_reader)

        # pad the fixed target only
        # moving image will be resampled to match the targets
        #volume_padding_layer = []
        #if self.net_param.volume_padding_size:
        #    volume_padding_layer.append(PadLayer(
        #        image_name=('fixed_image', 'fixed_label'),
        #        border=self.net_param.volume_padding_size))

        #for reader in self.readers:
        #    reader.add_preprocessing_layers(volume_padding_layer)

    def initialise_sampler(self):
        """Build pairwise samplers.

        Training uses uniform pairwise sampling over the first reader
        pair, appending a validation sampler when a second reader pair
        exists; inference uses a single pairwise resize sampler.
        """
        if not self.is_training:
            self.sampler = PairwiseResizeSampler(
                reader_0=self.readers[0],
                reader_1=self.readers[1],
                data_param=self.data_param,
                batch_size=self.net_param.batch_size)
            return

        self.sampler = []
        assert len(self.readers) >= 2, 'at least two readers are required'
        self.sampler.append(PairwiseUniformSampler(
            reader_0=self.readers[0],
            reader_1=self.readers[1],
            data_param=self.data_param,
            batch_size=self.net_param.batch_size))
        # adding validation readers if possible
        if len(self.readers) >= 4:
            self.sampler.append(PairwiseUniformSampler(
                reader_0=self.readers[2],
                reader_1=self.readers[3],
                data_param=self.data_param,
                batch_size=self.net_param.batch_size))

    def initialise_network(self):
        """Instantiate the registration network, passing the weight decay."""
        net_class = ApplicationNetFactory.create(self.net_param.name)
        self.net = net_class(self.net_param.decay)

    def connect_data_and_network(self,
                                 outputs_collector=None,
                                 gradients_collector=None):

        def switch_samplers(for_training):
            with tf.name_scope('train' if for_training else 'validation'):
                sampler = self.get_sampler()[0 if for_training else -1]
                return sampler()  # returns image only

        if self.is_training:
            if self.action_param.validation_every_n > 0:
                sampler_window = \
                    tf.cond(tf.logical_not(self.is_validation),
                            lambda: switch_samplers(True),
                            lambda: switch_samplers(False))
            else:
                sampler_window = switch_samplers(True)

            image_windows, _ = sampler_window
            # image_windows, locations = sampler_window

            # decode channels for moving and fixed images
            image_windows_list = [
                tf.expand_dims(img, axis=-1)
                for img in tf.unstack(image_windows, axis=-1)]
            fixed_image, fixed_label, moving_image, moving_label = \
                image_windows_list

            # estimate ddf
            dense_field = self.net(fixed_image, moving_image)
            if isinstance(dense_field, tuple):
                dense_field = dense_field[0]

            # transform the moving labels
            resampler = ResamplerLayer(
                interpolation='linear', boundary='replicate')
            resampled_moving_label = resampler(moving_label, dense_field)

            # compute label loss (foreground only)
            loss_func = LossFunction(
                n_class=1,
                loss_type=self.action_param.loss_type,
                softmax=False)
            label_loss = loss_func(prediction=resampled_moving_label,
                                   ground_truth=fixed_label)

            dice_fg = 1.0 - label_loss
            # appending regularisation loss
            total_loss = label_loss
            reg_loss = tf.get_collection('bending_energy')
            if reg_loss:
                total_loss = total_loss + \
                    self.net_param.decay * tf.reduce_mean(reg_loss)

            # compute training gradients
            with tf.name_scope('Optimiser'):
                optimiser_class = OptimiserFactory.create(
                    name=self.action_param.optimiser)
                self.optimiser = optimiser_class.get_instance(
                    learning_rate=self.action_param.lr)
            grads = self.optimiser.compute_gradients(total_loss)
            gradients_collector.add_to_collection(grads)

            metrics_dice = loss_func(
                prediction=tf.to_float(resampled_moving_label >= 0.5),
                ground_truth=tf.to_float(fixed_label >= 0.5))
            metrics_dice = 1.0 - metrics_dice

            # command line output
            outputs_collector.add_to_collection(
                var=dice_fg, name='one_minus_data_loss',
                collection=CONSOLE)
            outputs_collector.add_to_collection(
                var=tf.reduce_mean(reg_loss), name='bending_energy',
                collection=CONSOLE)
            outputs_collector.add_to_collection(
                var=total_loss, name='total_loss', collection=CONSOLE)
            outputs_collector.add_to_collection(
                var=metrics_dice, name='ave_fg_dice', collection=CONSOLE)

            # for tensorboard
            outputs_collector.add_to_collection(
                var=dice_fg,
                name='data_loss',
                average_over_devices=True,
                summary_type='scalar',
                collection=TF_SUMMARIES)
            outputs_collector.add_to_collection(
                var=total_loss,
                name='averaged_total_loss',
                average_over_devices=True,
                summary_type='scalar',
                collection=TF_SUMMARIES)
            outputs_collector.add_to_collection(
                var=metrics_dice,
                name='averaged_foreground_Dice',
                average_over_devices=True,
                summary_type='scalar',
                collection=TF_SUMMARIES)

            # for visualisation debugging
            # resampled_moving_image = resampler(moving_image, dense_field)
            # outputs_collector.add_to_collection(
            #     var=fixed_image, name='fixed_image',
            #     collection=NETWORK_OUTPUT)
            # outputs_collector.add_to_collection(
            #     var=fixed_label, name='fixed_label',
            #     collection=NETWORK_OUTPUT)
            # outputs_collector.add_to_collection(
            #     var=moving_image, name='moving_image',
            #     collection=NETWORK_OUTPUT)
            # outputs_collector.add_to_collection(
            #     var=moving_label, name='moving_label',
            #     collection=NETWORK_OUTPUT)
            # outputs_collector.add_to_collection(
            #     var=resampled_moving_image, name='resampled_image',
            #     collection=NETWORK_OUTPUT)
            # outputs_collector.add_to_collection(
            #     var=resampled_moving_label, name='resampled_label',
            #     collection=NETWORK_OUTPUT)
            # outputs_collector.add_to_collection(
            #     var=dense_field, name='ddf', collection=NETWORK_OUTPUT)
            # outputs_collector.add_to_collection(
            #     var=locations, name='locations', collection=NETWORK_OUTPUT)

            # outputs_collector.add_to_collection(
            #     var=shift[0], name='a', collection=CONSOLE)
            # outputs_collector.add_to_collection(
            #     var=shift[1], name='b', collection=CONSOLE)
        else:
            image_windows, locations = self.sampler()
            image_windows_list = [
                tf.expand_dims(img, axis=-1)
                for img in tf.unstack(image_windows, axis=-1)]
            fixed_image, fixed_label, moving_image, moving_label = \
                image_windows_list

            dense_field = self.net(fixed_image, moving_image)
            if isinstance(dense_field, tuple):
                dense_field = dense_field[0]

            # transform the moving labels
            resampler = ResamplerLayer(
                interpolation='linear', boundary='replicate')
            resampled_moving_image = resampler(moving_image, dense_field)
            resampled_moving_label = resampler(moving_label, dense_field)

            outputs_collector.add_to_collection(
                var=fixed_image, name='fixed_image',
                collection=NETWORK_OUTPUT)
            outputs_collector.add_to_collection(
                var=moving_image, name='moving_image',
                collection=NETWORK_OUTPUT)
            outputs_collector.add_to_collection(
                var=resampled_moving_image,
                name='resampled_moving_image',
                collection=NETWORK_OUTPUT)
            outputs_collector.add_to_collection(
                var=resampled_moving_label,
                name='resampled_moving_label',
                collection=NETWORK_OUTPUT)

            outputs_collector.add_to_collection(
                var=fixed_label, name='fixed_label',
                collection=NETWORK_OUTPUT)
            outputs_collector.add_to_collection(
                var=moving_label, name='moving_label',
                collection=NETWORK_OUTPUT)
            #outputs_collector.add_to_collection(
            #    var=dense_field, name='field',
            #    collection=NETWORK_OUTPUT)
            outputs_collector.add_to_collection(
                var=locations, name='locations',
                collection=NETWORK_OUTPUT)

            self.output_decoder = ResizeSamplesAggregator(
                image_reader=self.readers[0], # fixed image reader
                name='fixed_image',
                output_path=self.action_param.save_seg_dir,
                interp_order=self.action_param.output_interp_order)

    def interpret_output(self, batch_output):
        """Forward inference windows to the output aggregator.

        During training there is nothing to decode, so simply report
        success; at inference time hand the resampled moving image and
        its window locations to the output decoder.
        """
        if not self.is_training:
            return self.output_decoder.decode_batch(
                batch_output['resampled_moving_image'],
                batch_output['locations'])
        return True
class SegmentationApplication(BaseApplication):
    """Application mapping images to voxel-wise label maps (segmentation).

    Configures the image reader with normalisation/augmentation layers,
    selects window samplers and aggregators according to
    ``net_param.window_sampling``, builds the network and loss for
    training, and decodes network output windows at inference time.
    """

    REQUIRED_CONFIG_SECTION = "SEGMENTATION"

    def __init__(self, net_param, action_param, is_training):
        BaseApplication.__init__(self)
        tf.logging.info('starting segmentation application')
        self.is_training = is_training

        self.net_param = net_param
        self.action_param = action_param

        self.data_param = None
        self.segmentation_param = None
        # window_sampling name -> (training sampler, inference sampler,
        # inference aggregator) initialiser triple
        self.SUPPORTED_SAMPLING = {
            'uniform':
            (self.initialise_uniform_sampler, self.initialise_grid_sampler,
             self.initialise_grid_aggregator),
            'resize':
            (self.initialise_resize_sampler, self.initialise_resize_sampler,
             self.initialise_resize_aggregator),
        }

    def initialise_dataset_loader(self, data_param=None, task_param=None):
        """Create the image reader and attach its preprocessing pipeline."""
        self.data_param = data_param
        self.segmentation_param = task_param

        # read each line of csv files into an instance of Subject
        if self.is_training:
            self.reader = ImageReader(SUPPORTED_INPUT)
        else:  # in the inference process use image input only
            self.reader = ImageReader(['image'])
        self.reader.initialise_reader(data_param, task_param)

        foreground_masking_layer = None
        if self.net_param.normalise_foreground_only:
            foreground_masking_layer = BinaryMaskingLayer(
                type_str=self.net_param.foreground_type,
                multimod_fusion=self.net_param.multimod_foreground_type,
                threshold=0.0)

        mean_var_normaliser = MeanVarNormalisationLayer(
            image_name='image', binary_masking_func=foreground_masking_layer)

        # histogram-based normalisers require a reference histogram file
        histogram_normaliser = None
        label_normaliser = None
        if self.net_param.histogram_ref_file:
            histogram_normaliser = HistogramNormalisationLayer(
                image_name='image',
                modalities=vars(task_param).get('image'),
                model_filename=self.net_param.histogram_ref_file,
                binary_masking_func=foreground_masking_layer,
                norm_type=self.net_param.norm_type,
                cutoff=self.net_param.cutoff,
                name='hist_norm_layer')
            label_normaliser = DiscreteLabelNormalisationLayer(
                image_name='label',
                modalities=vars(task_param).get('label'),
                model_filename=self.net_param.histogram_ref_file)

        normalisation_layers = []
        if self.net_param.normalisation:
            normalisation_layers.append(histogram_normaliser)
        if self.net_param.whitening:
            normalisation_layers.append(mean_var_normaliser)
        if task_param.label_normalisation:
            normalisation_layers.append(label_normaliser)

        augmentation_layers = []
        if self.is_training:
            if self.action_param.random_flipping_axes != -1:
                augmentation_layers.append(
                    RandomFlipLayer(
                        flip_axes=self.action_param.random_flipping_axes))
            # Fixed: the original tested `rotation_angle` twice, making the
            # non-uniform (per-axis) branch unreachable; the condition now
            # mirrors ClassificationApplication in this module.
            if self.action_param.rotation_angle or \
                    self.action_param.rotation_angle_x or \
                    self.action_param.rotation_angle_y or \
                    self.action_param.rotation_angle_z:
                rotation_layer = RandomRotationLayer()
                if self.action_param.rotation_angle:
                    rotation_layer.init_uniform_angle(
                        self.action_param.rotation_angle)
                else:
                    rotation_layer.init_non_uniform_angle(
                        self.action_param.rotation_angle_x,
                        self.action_param.rotation_angle_y,
                        self.action_param.rotation_angle_z)
                augmentation_layers.append(rotation_layer)
            # NOTE: random spatial scaling augmentation is intentionally
            # disabled for this application.

        volume_padding_layer = []
        if self.net_param.volume_padding_size:
            volume_padding_layer.append(
                PadLayer(image_name=SUPPORTED_INPUT,
                         border=self.net_param.volume_padding_size))
        # order matters: padding first, then intensity normalisation,
        # then (training-only) augmentation
        self.reader.add_preprocessing_layers(volume_padding_layer +
                                             normalisation_layers +
                                             augmentation_layers)

    def initialise_uniform_sampler(self):
        """Sample windows uniformly at random (training)."""
        self.sampler = [
            UniformSampler(
                reader=self.reader,
                data_param=self.data_param,
                batch_size=self.net_param.batch_size,
                windows_per_image=self.action_param.sample_per_volume,
                queue_length=self.net_param.queue_length)
        ]

    def initialise_resize_sampler(self):
        """Resize each whole volume to the window size (train/infer)."""
        self.sampler = [
            ResizeSampler(reader=self.reader,
                          data_param=self.data_param,
                          batch_size=self.net_param.batch_size,
                          shuffle_buffer=self.is_training,
                          queue_length=self.net_param.queue_length)
        ]

    def initialise_grid_sampler(self):
        """Sliding-window grid sampling over each volume (inference)."""
        self.sampler = [
            GridSampler(
                reader=self.reader,
                data_param=self.data_param,
                batch_size=self.net_param.batch_size,
                spatial_window_size=self.action_param.spatial_window_size,
                window_border=self.action_param.border,
                queue_length=self.net_param.queue_length)
        ]

    def initialise_grid_aggregator(self):
        """Stitch grid-sampled output windows back into full volumes."""
        self.output_decoder = GridSamplesAggregator(
            image_reader=self.reader,
            output_path=self.action_param.save_seg_dir,
            window_border=self.action_param.border,
            interp_order=self.action_param.output_interp_order)

    def initialise_resize_aggregator(self):
        """Resize output windows back to the original volume shape."""
        # NOTE(review): ResizeSamplesAggregator is constructed without
        # `window_border` elsewhere in this file -- confirm the argument
        # is accepted by the aggregator in use.
        self.output_decoder = ResizeSamplesAggregator(
            image_reader=self.reader,
            output_path=self.action_param.save_seg_dir,
            window_border=self.action_param.border,
            interp_order=self.action_param.output_interp_order)

    def initialise_sampler(self):
        """Instantiate the sampler matching the configured sampling mode."""
        if self.is_training:
            self.SUPPORTED_SAMPLING[self.net_param.window_sampling][0]()
        else:
            self.SUPPORTED_SAMPLING[self.net_param.window_sampling][1]()

    def initialise_network(self):
        """Create the network with optional L1/L2 weight regularisers."""
        num_classes = self.segmentation_param.num_classes
        w_regularizer = None
        b_regularizer = None
        reg_type = self.net_param.reg_type.lower()
        decay = self.net_param.decay
        if reg_type == 'l2' and decay > 0:
            from tensorflow.contrib.layers.python.layers import regularizers
            w_regularizer = regularizers.l2_regularizer(decay)
            b_regularizer = regularizers.l2_regularizer(decay)
        elif reg_type == 'l1' and decay > 0:
            from tensorflow.contrib.layers.python.layers import regularizers
            w_regularizer = regularizers.l1_regularizer(decay)
            b_regularizer = regularizers.l1_regularizer(decay)

        self.net = ApplicationNetFactory.create(self.net_param.name)(
            num_classes=num_classes,
            w_regularizer=w_regularizer,
            b_regularizer=b_regularizer,
            acti_func=self.net_param.activation_function)

    def connect_data_and_network(self,
                                 outputs_collector=None,
                                 gradients_collector=None):
        """Build the TF graph from sampler output to loss/inference output."""
        data_dict = self.get_sampler()[0].pop_batch_op()
        image = tf.cast(data_dict['image'], tf.float32)
        net_out = self.net(image, self.is_training)

        if self.is_training:

            label = data_dict.get('label', None)

            # Changed label on 11/29/2017: This will generate a 2D label
            # from the 3D label provided in the input. Only suitable for
            # STNeuroNet: max-pooling with a kernel spanning the full
            # k[3] axis collapses the label along that dimension.
            k = label.get_shape().as_list()
            label = tf.nn.max_pool3d(label, [1, 1, 1, k[3], 1],
                                     [1, 1, 1, 1, 1],
                                     'VALID',
                                     data_format='NDHWC')
            print('label shape is{}'.format(label.get_shape()))
            print('Image shape is{}'.format(image.get_shape()))
            print('Out shape is{}'.format(net_out.get_shape()))
            ####
            with tf.name_scope('Optimiser'):
                optimiser_class = OptimiserFactory.create(
                    name=self.action_param.optimiser)
                self.optimiser = optimiser_class.get_instance(
                    learning_rate=self.action_param.lr)
            loss_func = LossFunction(
                n_class=self.segmentation_param.num_classes,
                loss_type=self.action_param.loss_type)
            data_loss = loss_func(prediction=net_out,
                                  ground_truth=label,
                                  weight_map=data_dict.get('weight', None))
            # Fixed: `loss` could be unbound when decay > 0 but no
            # regularisation losses had been collected; default to the
            # data term and add the mean regularisation term if present.
            loss = data_loss
            if self.net_param.decay > 0.0:
                reg_losses = tf.get_collection(
                    tf.GraphKeys.REGULARIZATION_LOSSES)
                if reg_losses:
                    reg_loss = tf.reduce_mean(
                        [tf.reduce_mean(reg_loss) for reg_loss in reg_losses])
                    loss = data_loss + reg_loss
            grads = self.optimiser.compute_gradients(loss)
            # collecting gradients variables
            gradients_collector.add_to_collection([grads])
            # collecting output variables
            outputs_collector.add_to_collection(var=data_loss,
                                                name='dice_loss',
                                                average_over_devices=False,
                                                collection=CONSOLE)
            outputs_collector.add_to_collection(var=data_loss,
                                                name='dice_loss',
                                                average_over_devices=True,
                                                summary_type='scalar',
                                                collection=TF_SUMMARIES)
            # ADDED on 10/30 by Soltanian-Zadeh for tensorboard visualization
            # NOTE(review): `255. / num_classes - 1` divides before
            # subtracting; if `255. / (num_classes - 1)` was intended the
            # display scaling is off -- confirm before changing.
            seg_summary = tf.to_float(
                tf.expand_dims(tf.argmax(net_out, -1), -1)) * (
                    255. / self.segmentation_param.num_classes - 1)
            label_summary = tf.to_float(tf.expand_dims(
                label, -1)) * (255. / self.segmentation_param.num_classes - 1)
            m, v = tf.nn.moments(image, axes=[1, 2, 3], keep_dims=True)
            img_summary = tf.minimum(
                255.,
                tf.maximum(0., (tf.to_float(image - m) /
                                (tf.sqrt(v) * 2.) + 1.) * 127.))
            image3_axial('img', img_summary, 50, [tf.GraphKeys.SUMMARIES])
            image3_axial('seg', seg_summary, 5, [tf.GraphKeys.SUMMARIES])
            image3_axial('label', label_summary, 5, [tf.GraphKeys.SUMMARIES])
        else:
            # converting logits into final output for
            # classification probabilities or argmax classification labels
            output_prob = self.segmentation_param.output_prob
            num_classes = self.segmentation_param.num_classes
            if output_prob and num_classes > 1:
                post_process_layer = PostProcessingLayer(
                    'SOFTMAX', num_classes=num_classes)
            elif not output_prob and num_classes > 1:
                post_process_layer = PostProcessingLayer(
                    'ARGMAX', num_classes=num_classes)
            else:
                post_process_layer = PostProcessingLayer(
                    'IDENTITY', num_classes=num_classes)
            net_out = post_process_layer(net_out)
            print('output shape is{}'.format(net_out.get_shape()))

            # Fixed: `NETORK_OUTPUT` was a NameError at inference time;
            # the constant used elsewhere in this module is NETWORK_OUTPUT.
            outputs_collector.add_to_collection(var=net_out,
                                                name='window',
                                                average_over_devices=False,
                                                collection=NETWORK_OUTPUT)
            outputs_collector.add_to_collection(
                var=data_dict['image_location'],
                name='location',
                average_over_devices=False,
                collection=NETWORK_OUTPUT)
            init_aggregator = \
                self.SUPPORTED_SAMPLING[self.net_param.window_sampling][2]
            init_aggregator()

    def interpret_output(self, batch_output):
        """Decode inference windows via the aggregator; no-op in training."""
        if not self.is_training:
            return self.output_decoder.decode_batch(batch_output['window'],
                                                    batch_output['location'])
        return True
# Esempio n. 20 (paste artifact; kept as a comment so the module parses)
class ClassificationApplication(BaseApplication):
    """This class defines an application for image-level classification
    problems mapping from images to scalar labels.

    This is the application class to be instantiated by the driver
    and referred to in configuration files.

    Although structurally similar to segmentation, this application
    supports different samplers/aggregators (because patch-based
    processing is not appropriate), and monitoring metrics."""

    REQUIRED_CONFIG_SECTION = "CLASSIFICATION"

    def __init__(self, net_param, action_param, action):
        """Store configuration; readers/samplers/network are built later."""
        super(ClassificationApplication, self).__init__()
        tf.logging.info('starting classification application')
        self.action = action

        self.net_param = net_param
        self.action_param = action_param

        self.data_param = None
        self.classification_param = None
        # window_sampling name -> (training sampler, inference sampler);
        # whole-image 'resize' is the only sampling scheme offered here
        self.SUPPORTED_SAMPLING = {
            'resize':
            (self.initialise_resize_sampler, self.initialise_resize_sampler),
        }

    def initialise_dataset_loader(self,
                                  data_param=None,
                                  task_param=None,
                                  data_partitioner=None):
        """Create per-partition image readers and preprocessing layers.

        Normalisation layers are attached to every reader; augmentation
        is attached to the first (training) reader only.  Raises
        ValueError when the action is unsupported or num_classes < 2.
        """

        self.data_param = data_param
        self.classification_param = task_param

        # input sources depend on the phase: training needs labels (and
        # an optional sampling prior); evaluation compares against the
        # previously inferred output
        if self.is_training:
            reader_names = ('image', 'label', 'sampler')
        elif self.is_inference:
            reader_names = ('image', )
        elif self.is_evaluation:
            reader_names = ('image', 'label', 'inferred')
        else:
            tf.logging.fatal('Action `%s` not supported. Expected one of %s',
                             self.action, self.SUPPORTED_PHASES)
            raise ValueError
        # `dataset_to_infer` is only present on inference action params
        try:
            reader_phase = self.action_param.dataset_to_infer
        except AttributeError:
            reader_phase = None
        file_lists = data_partitioner.get_file_lists_by(phase=reader_phase,
                                                        action=self.action)
        # one reader per file list (e.g. training and validation splits)
        self.readers = [
            ImageReader(reader_names).initialise(data_param, task_param,
                                                 file_list)
            for file_list in file_lists
        ]

        # optional binary mask restricting normalisation to foreground
        foreground_masking_layer = BinaryMaskingLayer(
            type_str=self.net_param.foreground_type,
            multimod_fusion=self.net_param.multimod_foreground_type,
            threshold=0.0) \
            if self.net_param.normalise_foreground_only else None

        mean_var_normaliser = MeanVarNormalisationLayer(
            image_name='image', binary_masking_func=foreground_masking_layer) \
            if self.net_param.whitening else None
        # histogram normalisation needs both the flag and a reference file
        histogram_normaliser = HistogramNormalisationLayer(
            image_name='image',
            modalities=vars(task_param).get('image'),
            model_filename=self.net_param.histogram_ref_file,
            binary_masking_func=foreground_masking_layer,
            norm_type=self.net_param.norm_type,
            cutoff=self.net_param.cutoff,
            name='hist_norm_layer') \
            if (self.net_param.histogram_ref_file and
                self.net_param.normalisation) else None

        label_normaliser = DiscreteLabelNormalisationLayer(
            image_name='label',
            modalities=vars(task_param).get('label'),
            model_filename=self.net_param.histogram_ref_file) \
            if (self.net_param.histogram_ref_file and
                task_param.label_normalisation) else None

        normalisation_layers = []
        if histogram_normaliser is not None:
            normalisation_layers.append(histogram_normaliser)
        if mean_var_normaliser is not None:
            normalisation_layers.append(mean_var_normaliser)
        if label_normaliser is not None:
            normalisation_layers.append(label_normaliser)

        augmentation_layers = []
        if self.is_training:
            train_param = self.action_param
            if train_param.random_flipping_axes != -1:
                augmentation_layers.append(
                    RandomFlipLayer(
                        flip_axes=train_param.random_flipping_axes))
            if train_param.scaling_percentage:
                augmentation_layers.append(
                    RandomSpatialScalingLayer(
                        min_percentage=train_param.scaling_percentage[0],
                        max_percentage=train_param.scaling_percentage[1],
                        antialiasing=train_param.antialiasing,
                        isotropic=train_param.isotropic_scaling))
            # a single uniform angle takes precedence over per-axis angles
            if train_param.rotation_angle or \
                    self.action_param.rotation_angle_x or \
                    self.action_param.rotation_angle_y or \
                    self.action_param.rotation_angle_z:
                rotation_layer = RandomRotationLayer()
                if train_param.rotation_angle:
                    rotation_layer.init_uniform_angle(
                        train_param.rotation_angle)
                else:
                    rotation_layer.init_non_uniform_angle(
                        self.action_param.rotation_angle_x,
                        self.action_param.rotation_angle_y,
                        self.action_param.rotation_angle_z)
                augmentation_layers.append(rotation_layer)

        # only add augmentation to first reader (not validation reader)
        self.readers[0].add_preprocessing_layers(normalisation_layers +
                                                 augmentation_layers)

        for reader in self.readers[1:]:
            reader.add_preprocessing_layers(normalisation_layers)

        # Checking num_classes is set correctly
        if self.classification_param.num_classes <= 1:
            raise ValueError(
                "Number of classes must be at least 2 for classification")
        # the label-normalisation map must enumerate exactly num_classes
        # distinct labels, otherwise the histogram_ref file is stale
        for preprocessor in self.readers[0].preprocessors:
            if preprocessor.name == 'label_norm':
                if len(preprocessor.label_map[preprocessor.key[0]]
                       ) != self.classification_param.num_classes:
                    raise ValueError(
                        "Number of unique labels must be equal to "
                        "number of classes (check histogram_ref file)")

    def initialise_resize_sampler(self):
        """Build one whole-image resize sampler per reader.

        Note the nested list: ``self.sampler[0]`` holds the per-reader
        samplers (training first, validation last).
        """
        self.sampler = [[
            ResizeSampler(reader=reader,
                          window_sizes=self.data_param,
                          batch_size=self.net_param.batch_size,
                          shuffle=self.is_training,
                          queue_length=self.net_param.queue_length)
            for reader in self.readers
        ]]

    def initialise_aggregator(self):
        """Create the aggregator that writes inference outputs to disk."""
        self.output_decoder = ResizeSamplesAggregator(
            image_reader=self.readers[0],
            output_path=self.action_param.save_seg_dir,
            postfix=self.action_param.output_postfix)

    def initialise_sampler(self):
        """Instantiate the sampler for the current phase (train/infer)."""
        if self.is_training:
            self.SUPPORTED_SAMPLING[self.net_param.window_sampling][0]()
        else:
            self.SUPPORTED_SAMPLING[self.net_param.window_sampling][1]()

    def initialise_network(self):
        """Create the classification network with optional regularisers."""
        w_regularizer = None
        b_regularizer = None
        reg_type = self.net_param.reg_type.lower()
        decay = self.net_param.decay
        if reg_type == 'l2' and decay > 0:
            from tensorflow.contrib.layers.python.layers import regularizers
            w_regularizer = regularizers.l2_regularizer(decay)
            b_regularizer = regularizers.l2_regularizer(decay)
        elif reg_type == 'l1' and decay > 0:
            from tensorflow.contrib.layers.python.layers import regularizers
            w_regularizer = regularizers.l1_regularizer(decay)
            b_regularizer = regularizers.l1_regularizer(decay)

        self.net = ApplicationNetFactory.create(self.net_param.name)(
            num_classes=self.classification_param.num_classes,
            w_initializer=InitializerFactory.get_initializer(
                name=self.net_param.weight_initializer),
            b_initializer=InitializerFactory.get_initializer(
                name=self.net_param.bias_initializer),
            w_regularizer=w_regularizer,
            b_regularizer=b_regularizer,
            acti_func=self.net_param.activation_function)

    def add_confusion_matrix_summaries_(self, outputs_collector, net_out,
                                        data_dict):
        """ This method defines several monitoring metrics that
        are derived from the confusion matrix """
        labels = tf.reshape(tf.cast(data_dict['label'], tf.int64), [-1])
        prediction = tf.reshape(tf.argmax(net_out, -1), [-1])
        num_classes = self.classification_param.num_classes
        conf_mat = tf.contrib.metrics.confusion_matrix(labels, prediction,
                                                       num_classes)
        conf_mat = tf.to_float(conf_mat)
        if self.classification_param.num_classes == 2:
            # binary case: report the four confusion-matrix cells
            # as individual scalar summaries
            outputs_collector.add_to_collection(var=conf_mat[1][1],
                                                name='true_positives',
                                                average_over_devices=True,
                                                summary_type='scalar',
                                                collection=TF_SUMMARIES)
            outputs_collector.add_to_collection(var=conf_mat[1][0],
                                                name='false_negatives',
                                                average_over_devices=True,
                                                summary_type='scalar',
                                                collection=TF_SUMMARIES)
            outputs_collector.add_to_collection(var=conf_mat[0][1],
                                                name='false_positives',
                                                average_over_devices=True,
                                                summary_type='scalar',
                                                collection=TF_SUMMARIES)
            outputs_collector.add_to_collection(var=conf_mat[0][0],
                                                name='true_negatives',
                                                average_over_devices=True,
                                                summary_type='scalar',
                                                collection=TF_SUMMARIES)
        else:
            # multi-class: log the whole matrix as an image summary
            # (reshaped to batch x H x W x channel)
            outputs_collector.add_to_collection(var=conf_mat[tf.newaxis, :, :,
                                                             tf.newaxis],
                                                name='confusion_matrix',
                                                average_over_devices=True,
                                                summary_type='image',
                                                collection=TF_SUMMARIES)

        # NOTE(review): tf.trace(conf_mat) is the raw count of correct
        # predictions, not a normalised accuracy -- confirm intent.
        outputs_collector.add_to_collection(var=tf.trace(conf_mat),
                                            name='accuracy',
                                            average_over_devices=True,
                                            summary_type='scalar',
                                            collection=TF_SUMMARIES)

    def connect_data_and_network(self,
                                 outputs_collector=None,
                                 gradients_collector=None):
        """Build the TF graph from sampler output to loss or predictions."""
        def switch_sampler(for_training):
            # pick the training (first) or validation (last) sampler
            with tf.name_scope('train' if for_training else 'validation'):
                sampler = self.get_sampler()[0][0 if for_training else -1]
                return sampler.pop_batch_op()

        if self.is_training:
            self.patience = self.action_param.patience
            self.mode = self.action_param.early_stopping_mode
            # alternate between training and validation batches when
            # periodic validation is enabled
            if self.action_param.validation_every_n > 0:
                data_dict = tf.cond(tf.logical_not(self.is_validation),
                                    lambda: switch_sampler(for_training=True),
                                    lambda: switch_sampler(for_training=False))
            else:
                data_dict = switch_sampler(for_training=True)

            image = tf.cast(data_dict['image'], tf.float32)
            net_args = {
                'is_training': self.is_training,
                'keep_prob': self.net_param.keep_prob
            }
            net_out = self.net(image, **net_args)

            with tf.name_scope('Optimiser'):
                optimiser_class = OptimiserFactory.create(
                    name=self.action_param.optimiser)
                self.optimiser = optimiser_class.get_instance(
                    learning_rate=self.action_param.lr)
            loss_func = LossFunction(
                n_class=self.classification_param.num_classes,
                loss_type=self.action_param.loss_type)
            data_loss = loss_func(prediction=net_out,
                                  ground_truth=data_dict.get('label', None))
            # add the mean regularisation term when weight decay is on
            reg_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
            if self.net_param.decay > 0.0 and reg_losses:
                reg_loss = tf.reduce_mean(
                    [tf.reduce_mean(reg_loss) for reg_loss in reg_losses])
                loss = data_loss + reg_loss
            else:
                loss = data_loss

            self.total_loss = loss

            grads = self.optimiser.compute_gradients(
                loss, colocate_gradients_with_ops=True)

            outputs_collector.add_to_collection(var=self.total_loss,
                                                name='total_loss',
                                                average_over_devices=True,
                                                collection=CONSOLE)
            outputs_collector.add_to_collection(var=self.total_loss,
                                                name='total_loss',
                                                average_over_devices=True,
                                                summary_type='scalar',
                                                collection=TF_SUMMARIES)
            # collecting gradients variables
            gradients_collector.add_to_collection([grads])
            # collecting output variables
            outputs_collector.add_to_collection(var=data_loss,
                                                name='data_loss',
                                                average_over_devices=False,
                                                collection=CONSOLE)
            outputs_collector.add_to_collection(var=data_loss,
                                                name='data_loss',
                                                average_over_devices=True,
                                                summary_type='scalar',
                                                collection=TF_SUMMARIES)
            self.add_confusion_matrix_summaries_(outputs_collector, net_out,
                                                 data_dict)
        else:
            # converting logits into final output for
            # classification probabilities or argmax classification labels
            data_dict = switch_sampler(for_training=False)
            image = tf.cast(data_dict['image'], tf.float32)
            net_args = {
                'is_training': self.is_training,
                'keep_prob': self.net_param.keep_prob
            }
            net_out = self.net(image, **net_args)
            tf.logging.info('net_out.shape may need to be resized: %s',
                            net_out.shape)
            output_prob = self.classification_param.output_prob
            num_classes = self.classification_param.num_classes
            if output_prob and num_classes > 1:
                post_process_layer = PostProcessingLayer(
                    'SOFTMAX', num_classes=num_classes)
            elif not output_prob and num_classes > 1:
                post_process_layer = PostProcessingLayer(
                    'ARGMAX', num_classes=num_classes)
            else:
                post_process_layer = PostProcessingLayer(
                    'IDENTITY', num_classes=num_classes)
            net_out = post_process_layer(net_out)

            outputs_collector.add_to_collection(var=net_out,
                                                name='window',
                                                average_over_devices=False,
                                                collection=NETWORK_OUTPUT)
            outputs_collector.add_to_collection(
                var=data_dict['image_location'],
                name='location',
                average_over_devices=False,
                collection=NETWORK_OUTPUT)
            self.initialise_aggregator()

    def interpret_output(self, batch_output):
        """Write inference predictions to csv; no-op during training."""
        if not self.is_training:
            return self.output_decoder.decode_batch(
                {'csv': batch_output['window']}, batch_output['location'])
        return True

    def initialise_evaluator(self, eval_param):
        """Create the evaluator used by the evaluation action."""
        self.eval_param = eval_param
        self.evaluator = ClassificationEvaluator(self.readers[0],
                                                 self.classification_param,
                                                 eval_param)

    def add_inferred_output(self, data_param, task_param):
        """Register the inferred output as a 'label'-like input source."""
        return self.add_inferred_output_like(data_param, task_param, 'label')
class RegressionApplication(BaseApplication):
    """Image regression application.

    Wires image readers, preprocessing/augmentation layers, window
    samplers, the regression network (single output channel) and its
    loss into a NiftyNet-style application.  The sampling strategy
    (`uniform`, `weighted`, `resize` or `balanced`) is selected via
    ``net_param.window_sampling``.
    """
    REQUIRED_CONFIG_SECTION = "REGRESSION"

    def __init__(self, net_param, action_param, action):
        """Store configuration and register supported sampling strategies.

        :param net_param: network-level configuration object
        :param action_param: parameters of the requested action
        :param action: name of the action (train / inference / evaluation)
        """
        BaseApplication.__init__(self)
        tf.logging.info('starting regression application')
        self.action = action

        self.net_param = net_param
        self.action_param = action_param

        self.data_param = None
        self.regression_param = None
        # Each entry maps a sampling strategy name to a triple of
        # factories: (training sampler, inference sampler, aggregator).
        # The triple is indexed by initialise_sampler() (slots 0/1) and
        # initialise_aggregator() (slot 2) below.
        self.SUPPORTED_SAMPLING = {
            'uniform':
            (self.initialise_uniform_sampler, self.initialise_grid_sampler,
             self.initialise_grid_aggregator),
            'weighted':
            (self.initialise_weighted_sampler, self.initialise_grid_sampler,
             self.initialise_grid_aggregator),
            'resize':
            (self.initialise_resize_sampler, self.initialise_resize_sampler,
             self.initialise_resize_aggregator),
            'balanced':
            (self.initialise_balanced_sampler, self.initialise_grid_sampler,
             self.initialise_grid_aggregator),
        }

    def initialise_dataset_loader(self,
                                  data_param=None,
                                  task_param=None,
                                  data_partitioner=None):
        """Create image readers and attach preprocessing/augmentation layers.

        The set of input sources read depends on the action: training reads
        image/output/weight/sampler, inference reads the image only, and
        evaluation reads image/output/inferred.  Augmentation layers are
        attached to the first (training) reader only.
        """

        self.data_param = data_param
        self.regression_param = task_param

        # initialise input image readers
        if self.is_training:
            reader_names = ('image', 'output', 'weight', 'sampler')
        elif self.is_inference:
            # in the inference process use `image` input only
            reader_names = ('image', )
        elif self.is_evaluation:
            reader_names = ('image', 'output', 'inferred')
        else:
            tf.logging.fatal('Action `%s` not supported. Expected one of %s',
                             self.action, self.SUPPORTED_PHASES)
            # NOTE(review): bare ValueError with no message -- consider
            # adding the action name to the exception text
            raise ValueError
        try:
            reader_phase = self.action_param.dataset_to_infer
        except AttributeError:
            # `dataset_to_infer` is optional; fall back to the default phase
            reader_phase = None
        file_lists = data_partitioner.get_file_lists_by(phase=reader_phase,
                                                        action=self.action)
        # one reader per file list (e.g. training and validation splits)
        self.readers = [
            ImageReader(reader_names).initialise(data_param, task_param,
                                                 file_list)
            for file_list in file_lists
        ]

        # initialise input preprocessing layers
        mean_var_normaliser = MeanVarNormalisationLayer(image_name='image') \
            if self.net_param.whitening else None
        histogram_normaliser = HistogramNormalisationLayer(
            image_name='image',
            modalities=vars(task_param).get('image'),
            model_filename=self.net_param.histogram_ref_file,
            norm_type=self.net_param.norm_type,
            cutoff=self.net_param.cutoff,
            name='hist_norm_layer') \
            if (self.net_param.histogram_ref_file and
                self.net_param.normalisation) else None

        normalisation_layers = []
        if histogram_normaliser is not None:
            normalisation_layers.append(histogram_normaliser)
        if mean_var_normaliser is not None:
            normalisation_layers.append(mean_var_normaliser)

        volume_padding_layer = []
        if self.net_param.volume_padding_size:
            volume_padding_layer.append(
                PadLayer(image_name=SUPPORTED_INPUT,
                         border=self.net_param.volume_padding_size,
                         mode=self.net_param.volume_padding_mode))

        # initialise training data augmentation layers
        augmentation_layers = []
        if self.is_training:
            train_param = self.action_param
            # -1 is the sentinel for "no flipping"
            if train_param.random_flipping_axes != -1:
                augmentation_layers.append(
                    RandomFlipLayer(
                        flip_axes=train_param.random_flipping_axes))
            if train_param.scaling_percentage:
                augmentation_layers.append(
                    RandomSpatialScalingLayer(
                        min_percentage=train_param.scaling_percentage[0],
                        max_percentage=train_param.scaling_percentage[1]))
            if train_param.rotation_angle:
                rotation_layer = RandomRotationLayer()
                # NOTE(review): this inner check repeats the outer
                # `if train_param.rotation_angle` and is always true here
                if train_param.rotation_angle:
                    rotation_layer.init_uniform_angle(
                        train_param.rotation_angle)
                augmentation_layers.append(rotation_layer)
            if train_param.do_elastic_deformation:
                # use the spatial rank of the first input source
                spatial_rank = list(self.readers[0].spatial_ranks.values())[0]
                augmentation_layers.append(
                    RandomElasticDeformationLayer(
                        spatial_rank=spatial_rank,
                        num_controlpoints=train_param.num_ctrl_points,
                        std_deformation_sigma=train_param.deformation_sigma,
                        proportion_to_augment=train_param.proportion_to_deform)
                )

        # only add augmentation to first reader (not validation reader)
        self.readers[0].add_preprocessing_layers(volume_padding_layer +
                                                 normalisation_layers +
                                                 augmentation_layers)

        for reader in self.readers[1:]:
            reader.add_preprocessing_layers(volume_padding_layer +
                                            normalisation_layers)

    def initialise_uniform_sampler(self):
        """Create a uniform window sampler for each reader."""
        self.sampler = [[
            UniformSampler(
                reader=reader,
                window_sizes=self.data_param,
                batch_size=self.net_param.batch_size,
                windows_per_image=self.action_param.sample_per_volume,
                queue_length=self.net_param.queue_length)
            for reader in self.readers
        ]]

    def initialise_weighted_sampler(self):
        """Create a frequency-weighted window sampler for each reader."""
        self.sampler = [[
            WeightedSampler(
                reader=reader,
                window_sizes=self.data_param,
                batch_size=self.net_param.batch_size,
                windows_per_image=self.action_param.sample_per_volume,
                queue_length=self.net_param.queue_length)
            for reader in self.readers
        ]]

    def initialise_resize_sampler(self):
        """Create a whole-volume resize sampler for each reader.

        Windows are shuffled only during training.
        """
        self.sampler = [[
            ResizeSampler(reader=reader,
                          window_sizes=self.data_param,
                          batch_size=self.net_param.batch_size,
                          shuffle=self.is_training,
                          smaller_final_batch_mode=self.net_param.
                          smaller_final_batch_mode,
                          queue_length=self.net_param.queue_length)
            for reader in self.readers
        ]]

    def initialise_grid_sampler(self):
        """Create a sliding-window grid sampler (used for inference)."""
        self.sampler = [[
            GridSampler(
                reader=reader,
                window_sizes=self.data_param,
                batch_size=self.net_param.batch_size,
                spatial_window_size=self.action_param.spatial_window_size,
                window_border=self.action_param.border,
                smaller_final_batch_mode=self.net_param.
                smaller_final_batch_mode,
                queue_length=self.net_param.queue_length)
            for reader in self.readers
        ]]

    def initialise_balanced_sampler(self):
        """Create a class-balanced window sampler for each reader."""
        self.sampler = [[
            BalancedSampler(
                reader=reader,
                window_sizes=self.data_param,
                batch_size=self.net_param.batch_size,
                windows_per_image=self.action_param.sample_per_volume,
                queue_length=self.net_param.queue_length)
            for reader in self.readers
        ]]

    def initialise_grid_aggregator(self):
        """Create the aggregator that stitches grid windows back together."""
        self.output_decoder = GridSamplesAggregator(
            image_reader=self.readers[0],
            output_path=self.action_param.save_seg_dir,
            window_border=self.action_param.border,
            interp_order=self.action_param.output_interp_order,
            postfix=self.action_param.output_postfix)

    def initialise_resize_aggregator(self):
        """Create the aggregator that resizes outputs back to volume size."""
        self.output_decoder = ResizeSamplesAggregator(
            image_reader=self.readers[0],
            output_path=self.action_param.save_seg_dir,
            window_border=self.action_param.border,
            interp_order=self.action_param.output_interp_order,
            postfix=self.action_param.output_postfix)

    def initialise_sampler(self):
        """Instantiate the sampler matching the configured strategy and phase."""
        if self.is_training:
            self.SUPPORTED_SAMPLING[self.net_param.window_sampling][0]()
        elif self.is_inference:
            self.SUPPORTED_SAMPLING[self.net_param.window_sampling][1]()

    def initialise_aggregator(self):
        """Instantiate the output aggregator matching the configured strategy."""
        self.SUPPORTED_SAMPLING[self.net_param.window_sampling][2]()

    def initialise_network(self):
        """Build the regression network with optional L1/L2 regularisers."""
        w_regularizer = None
        b_regularizer = None
        reg_type = self.net_param.reg_type.lower()
        decay = self.net_param.decay
        if reg_type == 'l2' and decay > 0:
            from tensorflow.contrib.layers.python.layers import regularizers
            w_regularizer = regularizers.l2_regularizer(decay)
            b_regularizer = regularizers.l2_regularizer(decay)
        elif reg_type == 'l1' and decay > 0:
            from tensorflow.contrib.layers.python.layers import regularizers
            w_regularizer = regularizers.l1_regularizer(decay)
            b_regularizer = regularizers.l1_regularizer(decay)

        # regression predicts a single output channel per voxel
        self.net = ApplicationNetFactory.create(self.net_param.name)(
            num_classes=1,
            w_initializer=InitializerFactory.get_initializer(
                name=self.net_param.weight_initializer),
            b_initializer=InitializerFactory.get_initializer(
                name=self.net_param.bias_initializer),
            w_regularizer=w_regularizer,
            b_regularizer=b_regularizer,
            acti_func=self.net_param.activation_function)

    def connect_data_and_network(self,
                                 outputs_collector=None,
                                 gradients_collector=None):
        """Build the TF graph: sampler -> network -> loss/aggregation.

        During training, computes the (optionally cropped and weighted)
        regression loss plus regularisation, and registers gradients and
        console/summary outputs.  During inference, registers the network
        output window and its location for the aggregator.
        """
        def switch_sampler(for_training):
            # pick the training sampler (index 0) or validation sampler (-1)
            with tf.name_scope('train' if for_training else 'validation'):
                sampler = self.get_sampler()[0][0 if for_training else -1]
                return sampler.pop_batch_op()

        if self.is_training:
            if self.action_param.validation_every_n > 0:
                # switch between training/validation batches at run time
                data_dict = tf.cond(tf.logical_not(self.is_validation),
                                    lambda: switch_sampler(for_training=True),
                                    lambda: switch_sampler(for_training=False))
            else:
                data_dict = switch_sampler(for_training=True)

            image = tf.cast(data_dict['image'], tf.float32)
            net_args = {
                'is_training': self.is_training,
                'keep_prob': self.net_param.keep_prob
            }
            net_out = self.net(image, **net_args)

            with tf.name_scope('Optimiser'):
                optimiser_class = OptimiserFactory.create(
                    name=self.action_param.optimiser)
                self.optimiser = optimiser_class.get_instance(
                    learning_rate=self.action_param.lr)
            loss_func = LossFunction(loss_type=self.action_param.loss_type)

            # crop the loss border so boundary voxels do not contribute
            crop_layer = CropLayer(border=self.regression_param.loss_border)
            weight_map = data_dict.get('weight', None)
            weight_map = None if weight_map is None else crop_layer(weight_map)
            data_loss = loss_func(prediction=crop_layer(net_out),
                                  ground_truth=crop_layer(data_dict['output']),
                                  weight_map=weight_map)
            reg_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
            if self.net_param.decay > 0.0 and reg_losses:
                reg_loss = tf.reduce_mean(
                    [tf.reduce_mean(reg_loss) for reg_loss in reg_losses])
                loss = data_loss + reg_loss
            else:
                loss = data_loss
            grads = self.optimiser.compute_gradients(
                loss, colocate_gradients_with_ops=True)
            # collecting gradients variables
            gradients_collector.add_to_collection([grads])
            # collecting output variables
            outputs_collector.add_to_collection(var=data_loss,
                                                name='loss',
                                                average_over_devices=False,
                                                collection=CONSOLE)
            outputs_collector.add_to_collection(var=data_loss,
                                                name='loss',
                                                average_over_devices=True,
                                                summary_type='scalar',
                                                collection=TF_SUMMARIES)
        elif self.is_inference:
            data_dict = switch_sampler(for_training=False)
            image = tf.cast(data_dict['image'], tf.float32)
            net_args = {
                'is_training': self.is_training,
                'keep_prob': self.net_param.keep_prob
            }
            net_out = self.net(image, **net_args)
            # regression outputs are passed through unchanged
            net_out = PostProcessingLayer('IDENTITY')(net_out)

            outputs_collector.add_to_collection(var=net_out,
                                                name='window',
                                                average_over_devices=False,
                                                collection=NETWORK_OUTPUT)
            outputs_collector.add_to_collection(
                var=data_dict['image_location'],
                name='location',
                average_over_devices=False,
                collection=NETWORK_OUTPUT)
            self.initialise_aggregator()

    def interpret_output(self, batch_output):
        """Hand inference windows to the aggregator; no-op during training."""
        if self.is_inference:
            return self.output_decoder.decode_batch(batch_output['window'],
                                                    batch_output['location'])
        return True

    def initialise_evaluator(self, eval_param):
        """Store evaluation parameters and build the regression evaluator."""
        self.eval_param = eval_param
        self.evaluator = RegressionEvaluator(self.readers[0],
                                             self.regression_param, eval_param)

    def add_inferred_output(self, data_param, task_param):
        """Register an `inferred` input source modelled on the `output` section."""
        return self.add_inferred_output_like(data_param, task_param, 'output')
# Esempio n. 22 ("Example no. 22" -- scraper artifact separating code examples; vote count 0)
class SegmentationApplication(BaseApplication):
    """Image segmentation application for a two-modality network.

    Wires image readers, normalisation/augmentation layers, window
    samplers and the segmentation network into a NiftyNet-style
    application.  `connect_data_and_network` unstacks the input's last
    axis and feeds exactly two modalities (keyed by ``MODALITIES``) to
    the network.
    """
    REQUIRED_CONFIG_SECTION = "SEGMENTATION"

    def __init__(self, net_param, action_param, action):
        """Store configuration and register supported sampling strategies.

        :param net_param: network-level configuration object
        :param action_param: parameters of the requested action
        :param action: name of the action (train / inference / evaluation)
        """
        super(SegmentationApplication, self).__init__()
        tf.logging.info('starting segmentation application')
        self.action = action

        self.net_param = net_param
        self.action_param = action_param

        self.data_param = None
        self.segmentation_param = None
        # Each entry maps a sampling strategy name to a triple of
        # factories: (training sampler, inference sampler, aggregator),
        # indexed by initialise_sampler()/initialise_aggregator() below.
        self.SUPPORTED_SAMPLING = {
            'uniform':
            (self.initialise_uniform_sampler, self.initialise_grid_sampler,
             self.initialise_grid_aggregator),
            'weighted':
            (self.initialise_weighted_sampler, self.initialise_grid_sampler,
             self.initialise_grid_aggregator),
            'resize':
            (self.initialise_resize_sampler, self.initialise_resize_sampler,
             self.initialise_resize_aggregator),
            'balanced':
            (self.initialise_balanced_sampler, self.initialise_grid_sampler,
             self.initialise_grid_aggregator),
        }

    def initialise_dataset_loader(self,
                                  data_param=None,
                                  task_param=None,
                                  data_partitioner=None):
        """Create image readers and attach preprocessing/augmentation layers.

        Training reads image/label/weight/sampler; inference reads the
        image only (over inference + validation files); evaluation reads
        image/label/inferred.  Augmentation is attached to the first
        (training) reader only.
        """

        self.data_param = data_param
        self.segmentation_param = task_param

        # NOTE(review): file_lists computed here is only used by the
        # training branch; the inference branch rebuilds its own list
        # from the partitioner directly -- confirm this is intended
        file_lists = self.get_file_lists(data_partitioner)
        # read each line of csv files into an instance of Subject
        if self.is_training:
            self.readers = []
            for file_list in file_lists:
                reader = ImageReader({'image', 'label', 'weight', 'sampler'})
                reader.initialise(data_param, task_param, file_list)
                self.readers.append(reader)

        elif self.is_inference:
            # in the inference process use image input only
            inference_reader = ImageReader({'image'})
            # infer over both the inference and the validation partitions
            file_list = pd.concat([
                data_partitioner.inference_files,
                data_partitioner.validation_files
            ],
                                  axis=0)
            file_list.index = range(file_list.shape[0])
            inference_reader.initialise(data_param, task_param, file_list)
            self.readers = [inference_reader]
        elif self.is_evaluation:
            file_list = data_partitioner.inference_files
            reader = ImageReader({'image', 'label', 'inferred'})
            reader.initialise(data_param, task_param, file_list)
            self.readers = [reader]
        else:
            raise ValueError(
                'Action `{}` not supported. Expected one of {}'.format(
                    self.action, self.SUPPORTED_ACTIONS))

        # optional foreground mask restricts normalisation to foreground
        foreground_masking_layer = None
        if self.net_param.normalise_foreground_only:
            foreground_masking_layer = BinaryMaskingLayer(
                type_str=self.net_param.foreground_type,
                multimod_fusion=self.net_param.multimod_foreground_type,
                threshold=0.0)

        mean_var_normaliser = MeanVarNormalisationLayer(
            image_name='image', binary_masking_func=foreground_masking_layer)
        histogram_normaliser = None
        if self.net_param.histogram_ref_file:
            histogram_normaliser = HistogramNormalisationLayer(
                image_name='image',
                modalities=vars(task_param).get('image'),
                model_filename=self.net_param.histogram_ref_file,
                binary_masking_func=foreground_masking_layer,
                norm_type=self.net_param.norm_type,
                cutoff=self.net_param.cutoff,
                name='hist_norm_layer')

        label_normalisers = None
        if self.net_param.histogram_ref_file and \
                task_param.label_normalisation:
            label_normalisers = [
                DiscreteLabelNormalisationLayer(
                    image_name='label',
                    modalities=vars(task_param).get('label'),
                    model_filename=self.net_param.histogram_ref_file)
            ]
            if self.is_evaluation:
                label_normalisers.append(
                    DiscreteLabelNormalisationLayer(
                        image_name='inferred',
                        modalities=vars(task_param).get('inferred'),
                        model_filename=self.net_param.histogram_ref_file))
                # share the label mapping between `label` and `inferred`
                label_normalisers[-1].key = label_normalisers[0].key

        normalisation_layers = []
        # NOTE(review): if `normalisation` is set without a
        # histogram_ref_file, histogram_normaliser is None and a None is
        # appended here -- confirm downstream tolerates that
        if self.net_param.normalisation:
            normalisation_layers.append(histogram_normaliser)
        if self.net_param.whitening:
            normalisation_layers.append(mean_var_normaliser)
        if task_param.label_normalisation and \
                (self.is_training or not task_param.output_prob):
            normalisation_layers.extend(label_normalisers)

        augmentation_layers = []
        if self.is_training:
            # -1 is the sentinel for "no flipping"
            if self.action_param.random_flipping_axes != -1:
                augmentation_layers.append(
                    RandomFlipLayer(
                        flip_axes=self.action_param.random_flipping_axes))
            if self.action_param.scaling_percentage:
                augmentation_layers.append(
                    RandomSpatialScalingLayer(
                        min_percentage=self.action_param.scaling_percentage[0],
                        max_percentage=self.action_param.scaling_percentage[1])
                )
            if self.action_param.rotation_angle or \
                    self.action_param.rotation_angle_x or \
                    self.action_param.rotation_angle_y or \
                    self.action_param.rotation_angle_z:
                rotation_layer = RandomRotationLayer()
                if self.action_param.rotation_angle:
                    rotation_layer.init_uniform_angle(
                        self.action_param.rotation_angle)
                else:
                    # per-axis rotation ranges when no uniform angle given
                    rotation_layer.init_non_uniform_angle(
                        self.action_param.rotation_angle_x,
                        self.action_param.rotation_angle_y,
                        self.action_param.rotation_angle_z)
                augmentation_layers.append(rotation_layer)

            # add deformation layer
            if self.action_param.do_elastic_deformation:
                # use the spatial rank of the first input source
                spatial_rank = list(self.readers[0].spatial_ranks.values())[0]
                augmentation_layers.append(
                    RandomElasticDeformationLayer(
                        spatial_rank=spatial_rank,
                        num_controlpoints=self.action_param.num_ctrl_points,
                        std_deformation_sigma=self.action_param.
                        deformation_sigma,
                        proportion_to_augment=self.action_param.
                        proportion_to_deform))

        volume_padding_layer = []
        if self.net_param.volume_padding_size:
            volume_padding_layer.append(
                PadLayer(image_name=SUPPORTED_INPUT,
                         border=self.net_param.volume_padding_size))

        # only add augmentation to first reader (not validation reader)
        self.readers[0].add_preprocessing_layers(volume_padding_layer +
                                                 normalisation_layers +
                                                 augmentation_layers)

        for reader in self.readers[1:]:
            reader.add_preprocessing_layers(volume_padding_layer +
                                            normalisation_layers)

    def initialise_uniform_sampler(self):
        """Create a uniform window sampler for each reader."""
        self.sampler = [[
            UniformSampler(
                reader=reader,
                data_param=self.data_param,
                batch_size=self.net_param.batch_size,
                windows_per_image=self.action_param.sample_per_volume,
                queue_length=self.net_param.queue_length)
            for reader in self.readers
        ]]

    def initialise_weighted_sampler(self):
        """Create a frequency-weighted window sampler for each reader."""
        self.sampler = [[
            WeightedSampler(
                reader=reader,
                data_param=self.data_param,
                batch_size=self.net_param.batch_size,
                windows_per_image=self.action_param.sample_per_volume,
                queue_length=self.net_param.queue_length)
            for reader in self.readers
        ]]

    def initialise_resize_sampler(self):
        """Create a whole-volume resize sampler for each reader.

        Windows are shuffled only during training.
        """
        self.sampler = [[
            ResizeSampler(reader=reader,
                          data_param=self.data_param,
                          batch_size=self.net_param.batch_size,
                          shuffle_buffer=self.is_training,
                          queue_length=self.net_param.queue_length)
            for reader in self.readers
        ]]

    def initialise_grid_sampler(self):
        """Create a sliding-window grid sampler (used for inference)."""
        self.sampler = [[
            GridSampler(
                reader=reader,
                data_param=self.data_param,
                batch_size=self.net_param.batch_size,
                spatial_window_size=self.action_param.spatial_window_size,
                window_border=self.action_param.border,
                queue_length=self.net_param.queue_length)
            for reader in self.readers
        ]]

    def initialise_balanced_sampler(self):
        """Create a class-balanced window sampler for each reader."""
        self.sampler = [[
            BalancedSampler(
                reader=reader,
                data_param=self.data_param,
                batch_size=self.net_param.batch_size,
                windows_per_image=self.action_param.sample_per_volume,
                queue_length=self.net_param.queue_length)
            for reader in self.readers
        ]]

    def initialise_grid_aggregator(self):
        """Create the aggregator that stitches grid windows back together."""
        self.output_decoder = GridSamplesAggregator(
            image_reader=self.readers[0],
            output_path=self.action_param.save_seg_dir,
            window_border=self.action_param.border,
            interp_order=self.action_param.output_interp_order)

    def initialise_resize_aggregator(self):
        """Create the aggregator that resizes outputs back to volume size."""
        self.output_decoder = ResizeSamplesAggregator(
            image_reader=self.readers[0],
            output_path=self.action_param.save_seg_dir,
            window_border=self.action_param.border,
            interp_order=self.action_param.output_interp_order)

    def initialise_sampler(self):
        """Instantiate the sampler matching the configured strategy and phase."""
        if self.is_training:
            self.SUPPORTED_SAMPLING[self.net_param.window_sampling][0]()
        elif self.is_inference:
            self.SUPPORTED_SAMPLING[self.net_param.window_sampling][1]()

    def initialise_aggregator(self):
        """Instantiate the output aggregator matching the configured strategy."""
        self.SUPPORTED_SAMPLING[self.net_param.window_sampling][2]()

    def initialise_network(self):
        """Build the segmentation network with optional L1/L2 regularisers."""
        w_regularizer = None
        b_regularizer = None
        reg_type = self.net_param.reg_type.lower()
        decay = self.net_param.decay
        if reg_type == 'l2' and decay > 0:
            from tensorflow.contrib.layers.python.layers import regularizers
            w_regularizer = regularizers.l2_regularizer(decay)
            b_regularizer = regularizers.l2_regularizer(decay)
        elif reg_type == 'l1' and decay > 0:
            from tensorflow.contrib.layers.python.layers import regularizers
            w_regularizer = regularizers.l1_regularizer(decay)
            b_regularizer = regularizers.l1_regularizer(decay)

        self.net = ApplicationNetFactory.create(self.net_param.name)(
            num_classes=self.segmentation_param.num_classes,
            w_initializer=InitializerFactory.get_initializer(
                name=self.net_param.weight_initializer),
            b_initializer=InitializerFactory.get_initializer(
                name=self.net_param.bias_initializer),
            w_regularizer=w_regularizer,
            b_regularizer=b_regularizer,
            acti_func=self.net_param.activation_function)

    def connect_data_and_network(self,
                                 outputs_collector=None,
                                 gradients_collector=None):
        """Build the TF graph: sampler -> two-modality network -> loss/output.

        The input's last axis is unstacked and exactly two modalities
        (``MODALITIES[0]``, ``MODALITIES[1]``) are fed to the network as
        a dict of single-channel tensors.
        """
        # def data_net(for_training):
        #    with tf.name_scope('train' if for_training else 'validation'):
        #        sampler = self.get_sampler()[0][0 if for_training else -1]
        #        data_dict = sampler.pop_batch_op()
        #        image = tf.cast(data_dict['image'], tf.float32)
        #        return data_dict, self.net(image, is_training=for_training)

        def switch_sampler(for_training):
            # pick the training sampler (index 0) or validation sampler (-1)
            with tf.name_scope('train' if for_training else 'validation'):
                sampler = self.get_sampler()[0][0 if for_training else -1]
                return sampler.pop_batch_op()

        if self.is_training:
            # if self.action_param.validation_every_n > 0:
            #    data_dict, net_out = tf.cond(tf.logical_not(self.is_validation),
            #                                 lambda: data_net(True),
            #                                 lambda: data_net(False))
            # else:
            #    data_dict, net_out = data_net(True)
            if self.action_param.validation_every_n > 0:
                data_dict = tf.cond(tf.logical_not(self.is_validation),
                                    lambda: switch_sampler(for_training=True),
                                    lambda: switch_sampler(for_training=False))
            else:
                data_dict = switch_sampler(for_training=True)

            image = tf.cast(data_dict['image'], tf.float32)
            # split modalities stacked on the last axis; exactly two
            # modalities are expected here (range(2))
            image = tf.unstack(image, axis=-1)
            net_out = self.net(
                {
                    MODALITIES[k]: tf.expand_dims(image[k], -1)
                    for k in range(2)
                },
                is_training=True)

            with tf.name_scope('Optimiser'):
                optimiser_class = OptimiserFactory.create(
                    name=self.action_param.optimiser)
                self.optimiser = optimiser_class.get_instance(
                    learning_rate=self.action_param.lr)
            loss_func = LossFunction(
                n_class=self.segmentation_param.num_classes,
                loss_type=self.action_param.loss_type,
                softmax=self.segmentation_param.softmax)
            data_loss = loss_func(prediction=net_out,
                                  ground_truth=data_dict.get('label', None),
                                  weight_map=None)
            reg_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
            if self.net_param.decay > 0.0 and reg_losses:
                reg_loss = tf.reduce_mean(
                    [tf.reduce_mean(reg_loss) for reg_loss in reg_losses])
                loss = data_loss + reg_loss
            else:
                loss = data_loss
            grads = self.optimiser.compute_gradients(loss)
            # collecting gradients variables
            gradients_collector.add_to_collection([grads])
            # collecting output variables
            outputs_collector.add_to_collection(var=data_loss,
                                                name='loss',
                                                average_over_devices=False,
                                                collection=CONSOLE)
            outputs_collector.add_to_collection(var=data_loss,
                                                name='loss',
                                                average_over_devices=True,
                                                summary_type='scalar',
                                                collection=TF_SUMMARIES)

            # outputs_collector.add_to_collection(
            #    var=image*180.0, name='image',
            #    average_over_devices=False, summary_type='image3_sagittal',
            #    collection=TF_SUMMARIES)

            # outputs_collector.add_to_collection(
            #    var=image, name='image',
            #    average_over_devices=False,
            #    collection=NETWORK_OUTPUT)

            # outputs_collector.add_to_collection(
            #    var=tf.reduce_mean(image), name='mean_image',
            #    average_over_devices=False, summary_type='scalar',
            #    collection=CONSOLE)
        elif self.is_inference:
            # converting logits into final output for
            # classification probabilities or argmax classification labels
            data_dict = switch_sampler(for_training=False)
            image = tf.cast(data_dict['image'], tf.float32)
            image = tf.unstack(image, axis=-1)
            # NOTE(review): is_training=True during inference looks
            # suspicious (e.g. batch-norm/dropout in training mode) --
            # confirm this is intentional
            net_out = self.net(
                {
                    MODALITIES[k]: tf.expand_dims(image[k], -1)
                    for k in range(2)
                },
                is_training=True)

            output_prob = self.segmentation_param.output_prob
            num_classes = self.segmentation_param.num_classes
            if output_prob and num_classes > 1:
                post_process_layer = PostProcessingLayer(
                    'SOFTMAX', num_classes=num_classes)
            elif not output_prob and num_classes > 1:
                post_process_layer = PostProcessingLayer(
                    'ARGMAX', num_classes=num_classes)
            else:
                post_process_layer = PostProcessingLayer(
                    'IDENTITY', num_classes=num_classes)
            net_out = post_process_layer(net_out)

            outputs_collector.add_to_collection(var=net_out,
                                                name='window',
                                                average_over_devices=False,
                                                collection=NETWORK_OUTPUT)
            outputs_collector.add_to_collection(
                var=data_dict['image_location'],
                name='location',
                average_over_devices=False,
                collection=NETWORK_OUTPUT)
            self.initialise_aggregator()

    def interpret_output(self, batch_output):
        """Hand inference windows to the aggregator; no-op during training."""
        if self.is_inference:
            return self.output_decoder.decode_batch(batch_output['window'],
                                                    batch_output['location'])
        return True

    def initialise_evaluator(self, eval_param):
        """Store evaluation parameters and build the segmentation evaluator."""
        self.eval_param = eval_param
        self.evaluator = SegmentationEvaluator(self.readers[0],
                                               self.segmentation_param,
                                               eval_param)

    def add_inferred_output(self, data_param, task_param):
        """Register an `inferred` input source modelled on the `label` section."""
        return self.add_inferred_output_like(data_param, task_param, 'label')
# Esempio n. 23 ("Example no. 23" -- scraper artifact separating code examples; vote count 0)
class RegressionRecApplication(BaseApplication):
    """Recursive (cascaded) regression application.

    A first network (``self.net``) produces an initial regression
    estimate; a second, residual network (``self.net2``) is applied
    repeatedly, each pass predicting a correction that is added to the
    running estimate.  Training supervises all three intermediate
    estimates; inference runs one extra refinement pass (four in total).
    """

    REQUIRED_CONFIG_SECTION = "REGRESSION"

    def __init__(self, net_param, action_param, action):
        """Store configuration and register sampler/aggregator factories.

        :param net_param: network configuration parameters
        :param action_param: parameters of the current action
        :param action: action name (train / inference / ...)
        """
        BaseApplication.__init__(self)
        tf.logging.info('starting recursive regression application')
        self.action = action

        self.net_param = net_param
        # the residual network is configured identically to the first
        # network; deep copy so later mutations cannot alias net_param
        self.net2_param = copy.deepcopy(net_param)
        self.action_param = action_param
        self.regression_param = None

        self.data_param = None
        # each entry maps a window-sampling mode to the triple
        # (training sampler, inference sampler, aggregator) factory
        self.SUPPORTED_SAMPLING = {
            'uniform':
            (self.initialise_uniform_sampler, self.initialise_grid_sampler,
             self.initialise_grid_aggregator),
            'weighted':
            (self.initialise_weighted_sampler, self.initialise_grid_sampler,
             self.initialise_grid_aggregator),
            'resize':
            (self.initialise_resize_sampler, self.initialise_resize_sampler,
             self.initialise_resize_aggregator),
        }

    def initialise_dataset_loader(self,
                                  data_param=None,
                                  task_param=None,
                                  data_partitioner=None):
        """Build image readers and their preprocessing/augmentation layers.

        :param data_param: per-modality input data configuration
        :param task_param: regression task configuration
        :param data_partitioner: provides train/validation/inference
            file lists
        """
        self.data_param = data_param
        self.regression_param = task_param

        # read each line of csv files into an instance of Subject
        if self.is_training:
            file_lists = []
            if self.action_param.validation_every_n > 0:
                file_lists.append(data_partitioner.train_files)
                file_lists.append(data_partitioner.validation_files)
            else:
                file_lists.append(data_partitioner.train_files)

            self.readers = []
            for file_list in file_lists:
                reader = ImageReader(SUPPORTED_INPUT)
                reader.initialise(data_param, task_param, file_list)
                self.readers.append(reader)
        else:
            # inference only needs the 'image' input section
            inference_reader = ImageReader(['image'])
            file_list = data_partitioner.inference_files
            inference_reader.initialise(data_param, task_param, file_list)
            self.readers = [inference_reader]

        mean_var_normaliser = MeanVarNormalisationLayer(image_name='image')
        histogram_normaliser = None
        if self.net_param.histogram_ref_file:
            histogram_normaliser = HistogramNormalisationLayer(
                image_name='image',
                modalities=vars(task_param).get('image'),
                model_filename=self.net_param.histogram_ref_file,
                norm_type=self.net_param.norm_type,
                cutoff=self.net_param.cutoff,
                name='hist_norm_layer')

        normalisation_layers = []
        if self.net_param.normalisation:
            # NOTE(review): if `normalisation` is enabled without a
            # histogram_ref_file this appends None -- confirm config
            # validation rules this combination out upstream.
            normalisation_layers.append(histogram_normaliser)
        if self.net_param.whitening:
            normalisation_layers.append(mean_var_normaliser)

        # data augmentation is applied at training time only
        augmentation_layers = []
        if self.is_training:
            if self.action_param.random_flipping_axes != -1:
                augmentation_layers.append(
                    RandomFlipLayer(
                        flip_axes=self.action_param.random_flipping_axes))
            if self.action_param.scaling_percentage:
                augmentation_layers.append(
                    RandomSpatialScalingLayer(
                        min_percentage=self.action_param.scaling_percentage[0],
                        max_percentage=self.action_param.scaling_percentage[1])
                )
            if self.action_param.rotation_angle:
                augmentation_layers.append(RandomRotationLayer())
                augmentation_layers[-1].init_uniform_angle(
                    self.action_param.rotation_angle)

        volume_padding_layer = []
        if self.net_param.volume_padding_size:
            volume_padding_layer.append(
                PadLayer(image_name=SUPPORTED_INPUT,
                         border=self.net_param.volume_padding_size))
        # padding first, then normalisation, then augmentation
        for reader in self.readers:
            reader.add_preprocessing_layers(volume_padding_layer +
                                            normalisation_layers +
                                            augmentation_layers)

    def initialise_uniform_sampler(self):
        """Create one uniform window sampler per reader (training)."""
        self.sampler = [[
            UniformSampler(
                reader=reader,
                window_sizes=self.data_param,
                batch_size=self.net_param.batch_size,
                windows_per_image=self.action_param.sample_per_volume,
                queue_length=self.net_param.queue_length)
            for reader in self.readers
        ]]

    def initialise_weighted_sampler(self):
        """Create one weighted window sampler per reader (training)."""
        self.sampler = [[
            WeightedSampler(
                reader=reader,
                window_sizes=self.data_param,
                batch_size=self.net_param.batch_size,
                windows_per_image=self.action_param.sample_per_volume,
                queue_length=self.net_param.queue_length)
            for reader in self.readers
        ]]

    def initialise_resize_sampler(self):
        """Create one resize sampler per reader (training and inference)."""
        self.sampler = [[
            ResizeSampler(reader=reader,
                          window_sizes=self.data_param,
                          batch_size=self.net_param.batch_size,
                          shuffle=self.is_training,
                          smaller_final_batch_mode=self.net_param.
                          smaller_final_batch_mode,
                          queue_length=self.net_param.queue_length)
            for reader in self.readers
        ]]

    def initialise_grid_sampler(self):
        """Create one grid sampler per reader (inference)."""
        self.sampler = [[
            GridSampler(
                reader=reader,
                window_sizes=self.data_param,
                batch_size=self.net_param.batch_size,
                spatial_window_size=self.action_param.spatial_window_size,
                window_border=self.action_param.border,
                smaller_final_batch_mode=self.net_param.
                smaller_final_batch_mode,
                queue_length=self.net_param.queue_length)
            for reader in self.readers
        ]]

    def initialise_balanced_sampler(self):
        """Create one balanced sampler per reader.

        NOTE(review): 'balanced' is not registered in SUPPORTED_SAMPLING
        for this application, so this factory appears unused here.
        """
        self.sampler = [[
            BalancedSampler(
                reader=reader,
                window_sizes=self.data_param,
                batch_size=self.net_param.batch_size,
                windows_per_image=self.action_param.sample_per_volume,
                queue_length=self.net_param.queue_length)
            for reader in self.readers
        ]]

    def initialise_grid_aggregator(self):
        """Create the aggregator that stitches grid windows back together."""
        self.output_decoder = GridSamplesAggregator(
            image_reader=self.readers[0],
            output_path=self.action_param.save_seg_dir,
            window_border=self.action_param.border,
            interp_order=self.action_param.output_interp_order)

    def initialise_resize_aggregator(self):
        """Create the aggregator that resizes windows back to volume size."""
        self.output_decoder = ResizeSamplesAggregator(
            image_reader=self.readers[0],
            output_path=self.action_param.save_seg_dir,
            window_border=self.action_param.border,
            interp_order=self.action_param.output_interp_order)

    def initialise_sampler(self):
        """Instantiate the sampler matching the phase and sampling mode."""
        if self.is_training:
            self.SUPPORTED_SAMPLING[self.net_param.window_sampling][0]()
        else:
            self.SUPPORTED_SAMPLING[self.net_param.window_sampling][1]()

    def initialise_network(self):
        """Create the two networks (initial + residual), with optional
        L1/L2 weight and bias regularisation from the config."""
        w_regularizer = None
        b_regularizer = None
        reg_type = self.net_param.reg_type.lower()
        decay = self.net_param.decay
        if reg_type == 'l2' and decay > 0:
            from tensorflow.contrib.layers.python.layers import regularizers
            w_regularizer = regularizers.l2_regularizer(decay)
            b_regularizer = regularizers.l2_regularizer(decay)
        elif reg_type == 'l1' and decay > 0:
            from tensorflow.contrib.layers.python.layers import regularizers
            w_regularizer = regularizers.l1_regularizer(decay)
            b_regularizer = regularizers.l1_regularizer(decay)

        # single-channel regression output for both networks
        self.net = ApplicationNetFactory.create(self.net_param.name)(
            num_classes=1,
            w_regularizer=w_regularizer,
            b_regularizer=b_regularizer,
            acti_func=self.net_param.activation_function)

        self.net2 = ApplicationNetFactory.create(self.net2_param.name)(
            num_classes=1,
            w_regularizer=w_regularizer,
            b_regularizer=b_regularizer,
            acti_func=self.net2_param.activation_function)

    def connect_data_and_network(self,
                                 outputs_collector=None,
                                 gradients_collector=None):
        """Wire samplers, the recursive network cascade, losses and
        output/gradient collectors into the TF graph.

        Training: three cascade stages (pct1..pct3), each supervised by
        the same regression loss; the total loss is their sum (plus
        regularisation).  Inference: a fourth refinement stage (pct4)
        is run and its output sent to the aggregator.
        """
        def switch_sampler(for_training):
            # pops a batch from the train or validation sampler
            with tf.name_scope('train' if for_training else 'validation'):
                sampler = self.get_sampler()[0][0 if for_training else -1]
                return sampler.pop_batch_op()

        if self.is_training:
            if self.action_param.validation_every_n > 0:
                data_dict = tf.cond(tf.logical_not(self.is_validation),
                                    lambda: switch_sampler(True),
                                    lambda: switch_sampler(False))
            else:
                data_dict = switch_sampler(for_training=True)

            image = tf.cast(data_dict['image'], tf.float32)
            # stage 1: initial estimate; stages 2-3: additive residual
            # refinements conditioned on the image and previous estimate
            pct1_out = self.net(image, self.is_training)
            res2_out = self.net2(tf.concat([image, pct1_out], 4),
                                 self.is_training)
            pct2_out = tf.add(pct1_out, res2_out)
            res3_out = self.net2(tf.concat([image, pct2_out], 4),
                                 self.is_training)
            pct3_out = tf.add(pct2_out, res3_out)
            #res4_out = self.net2(tf.concat([image, pct3_out],4), self.is_training)
            #pct4_out = tf.add(pct3_out,res4_out)
            #net_out = self.net(image, is_training=self.is_training)
            with tf.name_scope('Optimiser'):
                optimiser_class = OptimiserFactory.create(
                    name=self.action_param.optimiser)
                self.optimiser = optimiser_class.get_instance(
                    learning_rate=self.action_param.lr)
            loss_func = LossFunction(loss_type=self.action_param.loss_type)

            # crop away the loss border before comparing to ground truth
            crop_layer = CropLayer(border=self.regression_param.loss_border,
                                   name='crop-88')

            data_loss1 = loss_func(
                prediction=crop_layer(pct1_out),
                ground_truth=crop_layer(data_dict.get('output', None)),
                weight_map=None if data_dict.get('weight', None) is None else
                crop_layer(data_dict.get('weight', None)))
            data_loss2 = loss_func(
                prediction=crop_layer(pct2_out),
                ground_truth=crop_layer(data_dict.get('output', None)),
                weight_map=None if data_dict.get('weight', None) is None else
                crop_layer(data_dict.get('weight', None)))
            data_loss3 = loss_func(
                prediction=crop_layer(pct3_out),
                ground_truth=crop_layer(data_dict.get('output', None)),
                weight_map=None if data_dict.get('weight', None) is None else
                crop_layer(data_dict.get('weight', None)))

            #prediction = crop_layer(net_out)
            #ground_truth = crop_layer(data_dict.get('output', None))
            #weight_map = None if data_dict.get('weight', None) is None \
            #else crop_layer(data_dict.get('weight', None))
            #data_loss = loss_func(prediction=prediction,
            #ground_truth=ground_truth,
            #weight_map=weight_map)

            reg_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
            if self.net_param.decay > 0.0 and reg_losses:
                reg_loss = tf.reduce_mean(
                    [tf.reduce_mean(reg_loss) for reg_loss in reg_losses])
                loss = reg_loss + data_loss1 + data_loss2 + data_loss3
            else:
                loss = data_loss1 + data_loss2 + data_loss3
            grads = self.optimiser.compute_gradients(loss)
            # collecting gradients variables
            gradients_collector.add_to_collection([grads])
            # collecting output variables
            outputs_collector.add_to_collection(var=loss,
                                                name='Loss',
                                                average_over_devices=False,
                                                collection=CONSOLE)
            outputs_collector.add_to_collection(var=data_loss1,
                                                name='data_loss1',
                                                average_over_devices=True,
                                                collection=CONSOLE)
            outputs_collector.add_to_collection(var=data_loss2,
                                                name='data_loss2',
                                                average_over_devices=True,
                                                collection=CONSOLE)
            outputs_collector.add_to_collection(var=data_loss3,
                                                name='data_loss3',
                                                average_over_devices=True,
                                                collection=CONSOLE)
            # per-stage losses are also exported as TF scalar summaries
            outputs_collector.add_to_collection(var=data_loss1,
                                                name='data_loss1',
                                                average_over_devices=True,
                                                summary_type='scalar',
                                                collection=TF_SUMMARIES)
            outputs_collector.add_to_collection(var=data_loss2,
                                                name='data_loss2',
                                                average_over_devices=True,
                                                summary_type='scalar',
                                                collection=TF_SUMMARIES)
            outputs_collector.add_to_collection(var=data_loss3,
                                                name='data_loss3',
                                                average_over_devices=True,
                                                summary_type='scalar',
                                                collection=TF_SUMMARIES)
            outputs_collector.add_to_collection(var=loss,
                                                name='LossSum',
                                                average_over_devices=True,
                                                summary_type='scalar',
                                                collection=TF_SUMMARIES)
#            outputs_collector.add_to_collection(
#                var=pct3_out, name="pct3_out",
#                average_over_devices=True, summary_type="image3_axial",
#                collection=TF_SUMMARIES)

        else:
            data_dict = switch_sampler(for_training=False)
            image = tf.cast(data_dict['image'], tf.float32)
            #net_out = self.net(image, is_training=self.is_training)
            # inference runs one more refinement pass than training
            pct1_out = self.net(image, self.is_training)
            res2_out = self.net2(tf.concat([image, pct1_out], 4),
                                 self.is_training)
            pct2_out = tf.add(pct1_out, res2_out)
            res3_out = self.net2(tf.concat([image, pct2_out], 4),
                                 self.is_training)
            pct3_out = tf.add(pct2_out, res3_out)
            res4_out = self.net2(tf.concat([image, pct3_out], 4),
                                 self.is_training)
            pct4_out = tf.add(pct3_out, res4_out)
            # border=0: no cropping at inference time
            crop_layer = CropLayer(border=0, name='crop-88')
            post_process_layer = PostProcessingLayer('IDENTITY')
            net_out = post_process_layer(crop_layer(pct4_out))

            outputs_collector.add_to_collection(var=net_out,
                                                name='window',
                                                average_over_devices=False,
                                                collection=NETWORK_OUTPUT)
            outputs_collector.add_to_collection(
                var=data_dict['image_location'],
                name='location',
                average_over_devices=False,
                collection=NETWORK_OUTPUT)
            init_aggregator = \
                self.SUPPORTED_SAMPLING[self.net_param.window_sampling][2]
            init_aggregator()

    def interpret_output(self, batch_output):
        """Route inference windows to the aggregator; no-op in training."""
        if not self.is_training:
            # NOTE(review): the aggregator here receives a dict keyed by
            # 'window_image' while other applications pass the tensor
            # directly -- confirm this matches the aggregator API in use.
            return self.output_decoder.decode_batch(
                {'window_image': batch_output['window']},
                batch_output['location'])
        else:
            return True
class SegmentationApplication(BaseApplication):
    REQUIRED_CONFIG_SECTION = "SEGMENTATION"

    def __init__(self, net_param, action_param, action):
        """Store configuration and register, per window-sampling mode,
        the (training sampler, inference sampler, aggregator) factories.

        :param net_param: network configuration parameters
        :param action_param: parameters of the current action
        :param action: action name (train / inference / evaluation)
        """
        super(SegmentationApplication, self).__init__()
        tf.logging.info('starting segmentation application')

        self.net_param = net_param
        self.action_param = action_param
        self.action = action

        self.data_param = None
        self.segmentation_param = None

        # factories are looked up by self.net_param.window_sampling
        self.SUPPORTED_SAMPLING = {
            'uniform': (
                self.initialise_uniform_sampler,
                self.initialise_grid_sampler,
                self.initialise_grid_aggregator,
            ),
            'weighted': (
                self.initialise_weighted_sampler,
                self.initialise_grid_sampler,
                self.initialise_grid_aggregator,
            ),
            'resize': (
                self.initialise_resize_sampler,
                self.initialise_resize_sampler,
                self.initialise_resize_aggregator,
            ),
            'balanced': (
                self.initialise_balanced_sampler,
                self.initialise_grid_sampler,
                self.initialise_grid_aggregator,
            ),
        }

        self.learning_rate = None
        self.current_lr = tf.constant(0)

    def initialise_dataset_loader(
            self, data_param=None, task_param=None, data_partitioner=None):
        """Build per-phase image readers and preprocessing pipelines.

        :param data_param: per-modality input data configuration
        :param task_param: segmentation task configuration
        :param data_partitioner: provides the file lists for the
            current action/phase
        :raises ValueError: if the current action is not one of
            training / inference / evaluation
        """
        self.data_param = data_param
        self.segmentation_param = task_param

        # initialise input image readers
        if self.is_training:
            reader_names = ('image', 'label', 'weight_map', 'sampler')
        elif self.is_inference:
            # in the inference process use `image` input only
            reader_names = ('image',)
        elif self.is_evaluation:
            reader_names = ('image', 'label', 'inferred')
        else:
            tf.logging.fatal(
                'Action `%s` not supported. Expected one of %s',
                self.action, self.SUPPORTED_PHASES)
            raise ValueError
        try:
            reader_phase = self.action_param.dataset_to_infer
        except AttributeError:
            reader_phase = None
        file_lists = data_partitioner.get_file_lists_by(
            phase=reader_phase, action=self.action)
        self.readers = [
            ImageReader(reader_names).initialise(
                data_param, task_param, file_list) for file_list in file_lists]

        # initialise input preprocessing layers
        # optional foreground mask restricts normalisation statistics
        foreground_masking_layer = BinaryMaskingLayer(
            type_str=self.net_param.foreground_type,
            multimod_fusion=self.net_param.multimod_foreground_type,
            threshold=0.0) \
            if self.net_param.normalise_foreground_only else None
        mean_var_normaliser = MeanVarNormalisationLayer(
            image_name='image', binary_masking_func=foreground_masking_layer) \
            if self.net_param.whitening else None
        # NOTE(review): percentile_normalisation is a non-standard config
        # flag -- confirm PercentileNormalisationLayer's cutoff semantics.
        percentile_normaliser = PercentileNormalisationLayer(
            image_name='image', binary_masking_func=foreground_masking_layer, cutoff=self.net_param.cutoff) \
            if self.net_param.percentile_normalisation else None
        histogram_normaliser = HistogramNormalisationLayer(
            image_name='image',
            modalities=vars(task_param).get('image'),
            model_filename=self.net_param.histogram_ref_file,
            binary_masking_func=foreground_masking_layer,
            norm_type=self.net_param.norm_type,
            cutoff=self.net_param.cutoff,
            name='hist_norm_layer') \
            if (self.net_param.histogram_ref_file and
                self.net_param.normalisation) else None
        label_normalisers = None
        if self.net_param.histogram_ref_file and \
                task_param.label_normalisation:
            label_normalisers = [DiscreteLabelNormalisationLayer(
                image_name='label',
                modalities=vars(task_param).get('label'),
                model_filename=self.net_param.histogram_ref_file)]
            if self.is_evaluation:
                label_normalisers.append(
                    DiscreteLabelNormalisationLayer(
                        image_name='inferred',
                        modalities=vars(task_param).get('inferred'),
                        model_filename=self.net_param.histogram_ref_file))
                # share the mapping so 'inferred' uses the 'label' lookup
                label_normalisers[-1].key = label_normalisers[0].key

        normalisation_layers = []
        if histogram_normaliser is not None:
            normalisation_layers.append(histogram_normaliser)
        if mean_var_normaliser is not None:
            normalisation_layers.append(mean_var_normaliser)
        if percentile_normaliser is not None:
            normalisation_layers.append(percentile_normaliser)
        if task_param.label_normalisation and \
                (self.is_training or not task_param.output_prob):
            normalisation_layers.extend(label_normalisers)

        volume_padding_layer = []
        if self.net_param.volume_padding_size:
            volume_padding_layer.append(PadLayer(
                image_name=SUPPORTED_INPUT,
                border=self.net_param.volume_padding_size,
                mode=self.net_param.volume_padding_mode))

        # initialise training data augmentation layers
        augmentation_layers = []
        if self.is_training:
            train_param = self.action_param
            if train_param.random_flipping_axes != -1:
                augmentation_layers.append(RandomFlipLayer(
                    flip_axes=train_param.random_flipping_axes))
            if train_param.scaling_percentage:
                augmentation_layers.append(RandomSpatialScalingLayer(
                    min_percentage=train_param.scaling_percentage[0],
                    max_percentage=train_param.scaling_percentage[1],
                    antialiasing=train_param.antialiasing))
            if train_param.rotation_angle or \
                    train_param.rotation_angle_x or \
                    train_param.rotation_angle_y or \
                    train_param.rotation_angle_z:
                rotation_layer = RandomRotationLayer()
                if train_param.rotation_angle:
                    rotation_layer.init_uniform_angle(
                        train_param.rotation_angle)
                else:
                    # per-axis angle ranges
                    rotation_layer.init_non_uniform_angle(
                        train_param.rotation_angle_x,
                        train_param.rotation_angle_y,
                        train_param.rotation_angle_z)
                augmentation_layers.append(rotation_layer)
            if train_param.do_elastic_deformation:
                spatial_rank = list(self.readers[0].spatial_ranks.values())[0]
                augmentation_layers.append(RandomElasticDeformationLayer(
                    spatial_rank=spatial_rank,
                    num_controlpoints=train_param.num_ctrl_points,
                    std_deformation_sigma=train_param.deformation_sigma,
                    proportion_to_augment=train_param.proportion_to_deform))

        # only add augmentation to first reader (not validation reader)
        self.readers[0].add_preprocessing_layers(
            volume_padding_layer + normalisation_layers + augmentation_layers)

        for reader in self.readers[1:]:
            reader.add_preprocessing_layers(
                volume_padding_layer + normalisation_layers)

    def initialise_uniform_sampler(self):
        """Create one uniform window sampler per reader for training."""
        samplers = []
        for reader in self.readers:
            samplers.append(UniformSampler(
                reader=reader,
                window_sizes=self.data_param,
                batch_size=self.net_param.batch_size,
                windows_per_image=self.action_param.sample_per_volume,
                queue_length=self.net_param.queue_length))
        self.sampler = [samplers]

    def initialise_weighted_sampler(self):
        """Create one frequency-weighted window sampler per reader."""
        samplers = []
        for reader in self.readers:
            samplers.append(WeightedSampler(
                reader=reader,
                window_sizes=self.data_param,
                batch_size=self.net_param.batch_size,
                windows_per_image=self.action_param.sample_per_volume,
                queue_length=self.net_param.queue_length))
        self.sampler = [samplers]

    def initialise_resize_sampler(self):
        """Create one resize sampler per reader; shuffling only while
        training."""
        samplers = []
        for reader in self.readers:
            samplers.append(ResizeSampler(
                reader=reader,
                window_sizes=self.data_param,
                batch_size=self.net_param.batch_size,
                shuffle=self.is_training,
                smaller_final_batch_mode=self.net_param.smaller_final_batch_mode,
                queue_length=self.net_param.queue_length))
        self.sampler = [samplers]

    def initialise_grid_sampler(self):
        """Create one sliding-grid sampler per reader for inference."""
        samplers = []
        for reader in self.readers:
            samplers.append(GridSampler(
                reader=reader,
                window_sizes=self.data_param,
                batch_size=self.net_param.batch_size,
                spatial_window_size=self.action_param.spatial_window_size,
                window_border=self.action_param.border,
                smaller_final_batch_mode=self.net_param.smaller_final_batch_mode,
                queue_length=self.net_param.queue_length))
        self.sampler = [samplers]

    def initialise_balanced_sampler(self):
        """Create one class-balanced window sampler per reader."""
        samplers = []
        for reader in self.readers:
            samplers.append(BalancedSampler(
                reader=reader,
                window_sizes=self.data_param,
                batch_size=self.net_param.batch_size,
                windows_per_image=self.action_param.sample_per_volume,
                queue_length=self.net_param.queue_length))
        self.sampler = [samplers]

    def initialise_grid_aggregator(self):
        """Create the aggregator that stitches grid windows into volumes."""
        aggregator_kwargs = dict(
            image_reader=self.readers[0],
            output_path=self.action_param.save_seg_dir,
            window_border=self.action_param.border,
            interp_order=self.action_param.output_interp_order,
            postfix=self.action_param.output_postfix)
        self.output_decoder = GridSamplesAggregator(**aggregator_kwargs)

    def initialise_resize_aggregator(self):
        """Create the aggregator that resizes windows back to volume size."""
        aggregator_kwargs = dict(
            image_reader=self.readers[0],
            output_path=self.action_param.save_seg_dir,
            window_border=self.action_param.border,
            interp_order=self.action_param.output_interp_order,
            postfix=self.action_param.output_postfix)
        self.output_decoder = ResizeSamplesAggregator(**aggregator_kwargs)

    def initialise_sampler(self):
        """Instantiate the sampler matching the current phase; no sampler
        is built for the evaluation phase."""
        if self.is_training:
            phase_index = 0
        elif self.is_inference:
            phase_index = 1
        else:
            return
        self.SUPPORTED_SAMPLING[self.net_param.window_sampling][phase_index]()

    def initialise_aggregator(self):
        """Instantiate the aggregator for the configured sampling mode."""
        sampling_mode = self.net_param.window_sampling
        self.SUPPORTED_SAMPLING[sampling_mode][2]()

    def initialise_network(self):
        """Create the segmentation network, wiring in the configured
        initialisers and optional L1/L2 weight/bias regularisation."""
        w_regularizer = None
        b_regularizer = None
        reg_type = self.net_param.reg_type.lower()
        decay = self.net_param.decay
        # regularisation only takes effect with a positive decay
        if decay > 0 and reg_type in ('l1', 'l2'):
            from tensorflow.contrib.layers.python.layers import regularizers
            make_regularizer = (regularizers.l2_regularizer
                                if reg_type == 'l2'
                                else regularizers.l1_regularizer)
            w_regularizer = make_regularizer(decay)
            b_regularizer = make_regularizer(decay)

        self.net = ApplicationNetFactory.create(self.net_param.name)(
            num_classes=self.segmentation_param.num_classes,
            w_initializer=InitializerFactory.get_initializer(
                name=self.net_param.weight_initializer),
            b_initializer=InitializerFactory.get_initializer(
                name=self.net_param.bias_initializer),
            w_regularizer=w_regularizer,
            b_regularizer=b_regularizer,
            acti_func=self.net_param.activation_function)

    def connect_data_and_network(self,
                                 outputs_collector=None,
                                 gradients_collector=None):

        def switch_sampler(for_training):
            with tf.name_scope('train' if for_training else 'validation'):
                sampler = self.get_sampler()[0][0 if for_training else -1]
                return sampler.pop_batch_op()

        if self.is_training:
            if self.action_param.validation_every_n > 0:
                data_dict = tf.cond(tf.logical_not(self.is_validation),
                                    lambda: switch_sampler(for_training=True),
                                    lambda: switch_sampler(for_training=False))
            else:
                data_dict = switch_sampler(for_training=True)

            image = tf.cast(data_dict['image'], tf.float32)
            net_args = {'is_training': self.is_training,
                        'keep_prob': self.net_param.keep_prob}
            net_out = self.net(image, **net_args)

            with tf.name_scope('Optimiser'):
                self.learning_rate = tf.placeholder(tf.float32, shape=[])
                optimiser_class = OptimiserFactory.create(
                    name=self.action_param.optimiser)
                self.optimiser = optimiser_class.get_instance(
                    learning_rate=self.learning_rate)
            loss_func = LossFunction(
                n_class=self.segmentation_param.num_classes,
                loss_type=self.action_param.loss_type,
                softmax=self.segmentation_param.softmax)
            data_loss = loss_func(
                prediction=net_out,
                ground_truth=data_dict.get('label', None),
                weight_map=data_dict.get('weight_map', None))
            reg_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
            if self.net_param.decay > 0.0 and reg_losses:
                reg_loss = tf.reduce_mean(
                    [tf.reduce_mean(reg_loss) for reg_loss in reg_losses])
                loss = data_loss + reg_loss
            else:
                loss = data_loss

            # Get all vars
            to_optimise = tf.trainable_variables()
            vars_to_freeze = \
                self.action_param.vars_to_freeze or \
                self.action_param.vars_to_restore
            if vars_to_freeze:
                import re
                var_regex = re.compile(vars_to_freeze)
                # Only optimise vars that are not frozen
                to_optimise = \
                    [v for v in to_optimise if not var_regex.search(v.name)]
                tf.logging.info(
                    "Optimizing %d out of %d trainable variables, "
                    "the other variables fixed (--vars_to_freeze %s)",
                    len(to_optimise),
                    len(tf.trainable_variables()),
                    vars_to_freeze)

            grads = self.optimiser.compute_gradients(
                loss, var_list=to_optimise, colocate_gradients_with_ops=True)
             # clip gradients
            gradients, variables = zip(*grads)
            gradients, _ = tf.clip_by_global_norm(gradients, self.action_param.gradient_clipping_value)
            grads = list(zip(gradients, variables))
            gnorm = tf.global_norm(list(gradients))


            # collecting gradients variables
            gradients_collector.add_to_collection([grads])
            # collecting output variables
            outputs_collector.add_to_collection(
                var=data_loss, name='loss',
                average_over_devices=False, collection=CONSOLE)
            outputs_collector.add_to_collection(
                var=gnorm, name='gnorm',
                average_over_devices=False, collection=CONSOLE)
            outputs_collector.add_to_collection(
                var=gnorm, name='gnorm',
                average_over_devices=False, summary_type='scalar',
                collection=TF_SUMMARIES)
            outputs_collector.add_to_collection(
                var=data_loss, name='loss',
                average_over_devices=True, summary_type='scalar',
                collection=TF_SUMMARIES)
            outputs_collector.add_to_collection(
                var=self.learning_rate, name='lr',
                average_over_devices=True, summary_type='scalar',
                collection=TF_SUMMARIES)

            #outputs_collector.add_to_collection(
            #    var=image, name='image',
            #    average_over_devices=False, summary_type='image3_sagittal',
            #    collection=TF_SUMMARIES)
            #outputs_collector.add_to_collection(
            #    var=net_out, name='output',
            #    average_over_devices=False, summary_type='image3_sagittal',
            #    collection=TF_SUMMARIES)

            # outputs_collector.add_to_collection(
            #    var=image*180.0, name='image',
            #    average_over_devices=False, summary_type='image3_sagittal',
            #    collection=TF_SUMMARIES)

            # outputs_collector.add_to_collection(
            #    var=image, name='image',
            #    average_over_devices=False,
            #    collection=NETWORK_OUTPUT)

            # outputs_collector.add_to_collection(
            #    var=tf.reduce_mean(image), name='mean_image',
            #    average_over_devices=False, summary_type='scalar',
            #    collection=CONSOLE)
        elif self.is_inference:
            # converting logits into final output for
            # classification probabilities or argmax classification labels
            data_dict = switch_sampler(for_training=False)
            image = tf.cast(data_dict['image'], tf.float32)
            net_args = {'is_training': self.is_training,
                        'keep_prob': self.net_param.keep_prob}
            net_out = self.net(image, **net_args)

            output_prob = self.segmentation_param.output_prob
            num_classes = self.segmentation_param.num_classes
            if output_prob and num_classes > 1:
                post_process_layer = PostProcessingLayer(
                    'SOFTMAX', num_classes=num_classes)
            elif not output_prob and num_classes > 1:
                post_process_layer = PostProcessingLayer(
                    'ARGMAX', num_classes=num_classes)
            else:
                post_process_layer = PostProcessingLayer(
                    'IDENTITY', num_classes=num_classes)
            net_out = post_process_layer(net_out)

            outputs_collector.add_to_collection(
                var=net_out, name='window',
                average_over_devices=False, collection=NETWORK_OUTPUT)
            outputs_collector.add_to_collection(
                var=data_dict['image_location'], name='location',
                average_over_devices=False, collection=NETWORK_OUTPUT)
            self.initialise_aggregator()

    def interpret_output(self, batch_output):
        if self.is_inference:
            return self.output_decoder.decode_batch(
                batch_output['window'], batch_output['location'])
        return True

    def initialise_evaluator(self, eval_param):
        self.eval_param = eval_param
        self.evaluator = SegmentationEvaluator(self.readers[0],
                                               self.segmentation_param,
                                               eval_param)

    def add_inferred_output(self, data_param, task_param):
        return self.add_inferred_output_like(data_param, task_param, 'label')

    def set_iteration_update(self, iteration_message):
        """
        This function will be called by the application engine at each
        iteration.
        """
        current_iter = iteration_message.current_iter
        if iteration_message.is_training:
            if current_iter < self.action_param.warmup:
                self.current_lr = self.action_param.lr/(1. + math.exp(10 * (-current_iter/self.action_param.warmup+0.5)))
            else:
                self.current_lr = self.action_param.lr * pow(
                    self.action_param.lr_gamma,
                    ((current_iter - self.action_param.warmup) // self.action_param.lr_step_size))

            iteration_message.data_feed_dict[self.is_validation] = False
            iteration_message.data_feed_dict[self.learning_rate] = self.current_lr
        elif iteration_message.is_validation:
            iteration_message.data_feed_dict[self.is_validation] = True
            iteration_message.data_feed_dict[self.learning_rate] = self.current_lr
Esempio n. 25
0
class SegmentationApplication(BaseApplication):
    REQUIRED_CONFIG_SECTION = "SEGMENTATION"

    def __init__(self, net_param, action_param, action):
        super(SegmentationApplication, self).__init__()
        tf.logging.info('starting segmentation application')
        self.action = action

        self.net_param = net_param
        self.action_param = action_param

        self.data_param = None
        self.segmentation_param = None
        self.SUPPORTED_SAMPLING = {
            'uniform':
            (self.initialise_uniform_sampler, self.initialise_grid_sampler,
             self.initialise_grid_aggregator),
            'weighted':
            (self.initialise_weighted_sampler, self.initialise_grid_sampler,
             self.initialise_grid_aggregator),
            'resize':
            (self.initialise_resize_sampler, self.initialise_resize_sampler,
             self.initialise_resize_aggregator),
            'balanced':
            (self.initialise_balanced_sampler, self.initialise_grid_sampler,
             self.initialise_grid_aggregator),
        }

    def initialise_dataset_loader(self,
                                  data_param=None,
                                  task_param=None,
                                  data_partitioner=None):

        self.data_param = data_param
        self.segmentation_param = task_param

        # initialise input image readers
        if self.is_training:
            reader_names = ('image', 'label', 'weight', 'sampler')
        elif self.is_inference:
            # in the inference process use `image` input only
            reader_names = ('image', )
        elif self.is_evaluation:
            reader_names = ('image', 'label', 'inferred')
        else:
            tf.logging.fatal('Action `%s` not supported. Expected one of %s',
                             self.action, self.SUPPORTED_PHASES)
            raise ValueError
        try:
            reader_phase = self.action_param.dataset_to_infer
        except AttributeError:
            reader_phase = None
        file_lists = data_partitioner.get_file_lists_by(phase=reader_phase,
                                                        action=self.action)
        self.readers = [
            ImageReader(reader_names).initialise(data_param, task_param,
                                                 file_list)
            for file_list in file_lists
        ]

        # initialise input preprocessing layers
        foreground_masking_layer = BinaryMaskingLayer(
            type_str=self.net_param.foreground_type,
            multimod_fusion=self.net_param.multimod_foreground_type,
            threshold=0.0) \
            if self.net_param.normalise_foreground_only else None
        mean_var_normaliser = MeanVarNormalisationLayer(
            image_name='image', binary_masking_func=foreground_masking_layer) \
            if self.net_param.whitening else None
        histogram_normaliser = HistogramNormalisationLayer(
            image_name='image',
            modalities=vars(task_param).get('image'),
            model_filename=self.net_param.histogram_ref_file,
            binary_masking_func=foreground_masking_layer,
            norm_type=self.net_param.norm_type,
            cutoff=self.net_param.cutoff,
            name='hist_norm_layer') \
            if (self.net_param.histogram_ref_file and
                self.net_param.normalisation) else None
        rgb_normaliser = RGBHistogramEquilisationLayer(
            image_name='image', name='rbg_norm_layer'
        ) if self.net_param.rgb_normalisation else None
        label_normalisers = None
        if self.net_param.histogram_ref_file and \
                task_param.label_normalisation:
            label_normalisers = [
                DiscreteLabelNormalisationLayer(
                    image_name='label',
                    modalities=vars(task_param).get('label'),
                    model_filename=self.net_param.histogram_ref_file)
            ]
            if self.is_evaluation:
                label_normalisers.append(
                    DiscreteLabelNormalisationLayer(
                        image_name='inferred',
                        modalities=vars(task_param).get('inferred'),
                        model_filename=self.net_param.histogram_ref_file))
                label_normalisers[-1].key = label_normalisers[0].key

        normalisation_layers = []
        if histogram_normaliser is not None:
            normalisation_layers.append(histogram_normaliser)
        if rgb_normaliser is not None:
            normalisation_layers.append(rgb_normaliser)
        if mean_var_normaliser is not None:
            normalisation_layers.append(mean_var_normaliser)
        if task_param.label_normalisation and \
                (self.is_training or not task_param.output_prob):
            normalisation_layers.extend(label_normalisers)

        volume_padding_layer = [
            PadLayer(image_name=SUPPORTED_INPUT,
                     border=self.net_param.volume_padding_size,
                     mode=self.net_param.volume_padding_mode,
                     pad_to=self.net_param.volume_padding_to_size)
        ]
        # initialise training data augmentation layers
        augmentation_layers = []
        if self.is_training:
            train_param = self.action_param
            self.patience = train_param.patience
            self.mode = self.action_param.early_stopping_mode
            if train_param.random_flipping_axes != -1:
                augmentation_layers.append(
                    RandomFlipLayer(
                        flip_axes=train_param.random_flipping_axes))
            if train_param.scaling_percentage:
                augmentation_layers.append(
                    RandomSpatialScalingLayer(
                        min_percentage=train_param.scaling_percentage[0],
                        max_percentage=train_param.scaling_percentage[1],
                        antialiasing=train_param.antialiasing,
                        isotropic=train_param.isotropic_scaling))
            if train_param.rotation_angle or \
                    train_param.rotation_angle_x or \
                    train_param.rotation_angle_y or \
                    train_param.rotation_angle_z:
                rotation_layer = RandomRotationLayer()
                if train_param.rotation_angle:
                    rotation_layer.init_uniform_angle(
                        train_param.rotation_angle)
                else:
                    rotation_layer.init_non_uniform_angle(
                        train_param.rotation_angle_x,
                        train_param.rotation_angle_y,
                        train_param.rotation_angle_z)
                augmentation_layers.append(rotation_layer)
            if train_param.do_elastic_deformation:
                spatial_rank = list(self.readers[0].spatial_ranks.values())[0]
                augmentation_layers.append(
                    RandomElasticDeformationLayer(
                        spatial_rank=spatial_rank,
                        num_controlpoints=train_param.num_ctrl_points,
                        std_deformation_sigma=train_param.deformation_sigma,
                        proportion_to_augment=train_param.proportion_to_deform)
                )

        # only add augmentation to first reader (not validation reader)
        self.readers[0].add_preprocessing_layers(volume_padding_layer +
                                                 normalisation_layers +
                                                 augmentation_layers)

        for reader in self.readers[1:]:
            reader.add_preprocessing_layers(volume_padding_layer +
                                            normalisation_layers)

        # Checking num_classes is set correctly
        if self.segmentation_param.num_classes <= 1:
            raise ValueError(
                "Number of classes must be at least 2 for segmentation")
        for preprocessor in self.readers[0].preprocessors:
            if preprocessor.name == 'label_norm':
                if len(preprocessor.label_map[preprocessor.key[0]]
                       ) != self.segmentation_param.num_classes:
                    raise ValueError(
                        "Number of unique labels must be equal to "
                        "number of classes (check histogram_ref file)")

    def initialise_uniform_sampler(self):
        self.sampler = [[
            UniformSampler(
                reader=reader,
                window_sizes=self.data_param,
                batch_size=self.net_param.batch_size,
                windows_per_image=self.action_param.sample_per_volume,
                queue_length=self.net_param.queue_length)
            for reader in self.readers
        ]]

    def initialise_weighted_sampler(self):
        self.sampler = [[
            WeightedSampler(
                reader=reader,
                window_sizes=self.data_param,
                batch_size=self.net_param.batch_size,
                windows_per_image=self.action_param.sample_per_volume,
                queue_length=self.net_param.queue_length)
            for reader in self.readers
        ]]

    def initialise_resize_sampler(self):
        self.sampler = [[
            ResizeSampler(reader=reader,
                          window_sizes=self.data_param,
                          batch_size=self.net_param.batch_size,
                          shuffle=self.is_training,
                          smaller_final_batch_mode=self.net_param.
                          smaller_final_batch_mode,
                          queue_length=self.net_param.queue_length)
            for reader in self.readers
        ]]

    def initialise_grid_sampler(self):
        self.sampler = [[
            GridSampler(
                reader=reader,
                window_sizes=self.data_param,
                batch_size=self.net_param.batch_size,
                spatial_window_size=self.action_param.spatial_window_size,
                window_border=self.action_param.border,
                smaller_final_batch_mode=self.net_param.
                smaller_final_batch_mode,
                queue_length=self.net_param.queue_length)
            for reader in self.readers
        ]]

    def initialise_balanced_sampler(self):
        self.sampler = [[
            BalancedSampler(
                reader=reader,
                window_sizes=self.data_param,
                batch_size=self.net_param.batch_size,
                windows_per_image=self.action_param.sample_per_volume,
                queue_length=self.net_param.queue_length)
            for reader in self.readers
        ]]

    def initialise_grid_aggregator(self):
        self.output_decoder = GridSamplesAggregator(
            image_reader=self.readers[0],
            output_path=self.action_param.save_seg_dir,
            window_border=self.action_param.border,
            interp_order=self.action_param.output_interp_order,
            postfix=self.action_param.output_postfix,
            fill_constant=self.action_param.fill_constant)

    def initialise_resize_aggregator(self):
        self.output_decoder = ResizeSamplesAggregator(
            image_reader=self.readers[0],
            output_path=self.action_param.save_seg_dir,
            window_border=self.action_param.border,
            interp_order=self.action_param.output_interp_order,
            postfix=self.action_param.output_postfix)

    def initialise_sampler(self):
        if self.is_training:
            self.SUPPORTED_SAMPLING[self.net_param.window_sampling][0]()
        elif self.is_inference:
            self.SUPPORTED_SAMPLING[self.net_param.window_sampling][1]()

    def initialise_aggregator(self):
        self.SUPPORTED_SAMPLING[self.net_param.window_sampling][2]()

    def initialise_network(self):
        w_regularizer = None
        b_regularizer = None
        reg_type = self.net_param.reg_type.lower()
        decay = self.net_param.decay
        if reg_type == 'l2' and decay > 0:
            from tensorflow.contrib.layers.python.layers import regularizers
            w_regularizer = regularizers.l2_regularizer(decay)
            b_regularizer = regularizers.l2_regularizer(decay)
        elif reg_type == 'l1' and decay > 0:
            from tensorflow.contrib.layers.python.layers import regularizers
            w_regularizer = regularizers.l1_regularizer(decay)
            b_regularizer = regularizers.l1_regularizer(decay)

        self.net = ApplicationNetFactory.create(self.net_param.name)(
            num_classes=self.segmentation_param.num_classes,
            w_initializer=InitializerFactory.get_initializer(
                name=self.net_param.weight_initializer),
            b_initializer=InitializerFactory.get_initializer(
                name=self.net_param.bias_initializer),
            w_regularizer=w_regularizer,
            b_regularizer=b_regularizer,
            acti_func=self.net_param.activation_function)

    def connect_data_and_network(self,
                                 outputs_collector=None,
                                 gradients_collector=None):
        def switch_sampler(for_training):
            with tf.name_scope('train' if for_training else 'validation'):
                sampler = self.get_sampler()[0][0 if for_training else -1]
                return sampler.pop_batch_op()

        def mixup_switch_sampler(for_training):
            # get first set of samples
            d_dict = switch_sampler(for_training=for_training)

            mix_fields = ('image', 'weight', 'label')

            if not for_training:
                with tf.name_scope('nomix'):
                    # ensure label is appropriate for dense loss functions
                    ground_truth = tf.cast(d_dict['label'], tf.int32)
                    one_hot = tf.one_hot(
                        tf.squeeze(ground_truth, axis=-1),
                        depth=self.segmentation_param.num_classes)
                    d_dict['label'] = one_hot
            else:
                with tf.name_scope('mixup'):
                    # get the mixing parameter from the Beta distribution
                    alpha = self.segmentation_param.mixup_alpha
                    beta = tf.distributions.Beta(alpha,
                                                 alpha)  # 1, 1: uniform:
                    rand_frac = beta.sample()

                    # get another minibatch
                    d_dict_to_mix = switch_sampler(for_training=True)

                    # look at binarised labels: sort them
                    if self.segmentation_param.mix_match:
                        # sum up the positive labels to sort by their volumes
                        inds1 = tf.argsort(
                            tf.map_fn(tf.reduce_sum,
                                      tf.cast(d_dict['label'], tf.int64)))
                        inds2 = tf.argsort(
                            tf.map_fn(
                                tf.reduce_sum,
                                tf.cast(d_dict_to_mix['label'] > 0, tf.int64)))
                        for field in [
                                field for field in mix_fields
                                if field in d_dict
                        ]:
                            d_dict[field] = tf.gather(d_dict[field],
                                                      indices=inds1)
                            # note: sorted for opposite directions for d_dict_to_mix
                            d_dict_to_mix[field] = tf.gather(
                                d_dict_to_mix[field], indices=inds2[::-1])

                    # making the labels dense and one-hot
                    for d in (d_dict, d_dict_to_mix):
                        ground_truth = tf.cast(d['label'], tf.int32)
                        one_hot = tf.one_hot(
                            tf.squeeze(ground_truth, axis=-1),
                            depth=self.segmentation_param.num_classes)
                        d['label'] = one_hot

                    # do the mixing for any fields that are relevant and present
                    mixed_up = {
                        field: d_dict[field] * rand_frac +
                        d_dict_to_mix[field] * (1 - rand_frac)
                        for field in mix_fields if field in d_dict
                    }
                    # reassign all relevant values in d_dict
                    d_dict.update(mixed_up)

            return d_dict

        if self.is_training:
            if not self.segmentation_param.do_mixup:
                data_dict = tf.cond(tf.logical_not(self.is_validation),
                                    lambda: switch_sampler(for_training=True),
                                    lambda: switch_sampler(for_training=False))
            else:
                # mix up the samples if not in validation phase
                data_dict = tf.cond(
                    tf.logical_not(self.is_validation),
                    lambda: mixup_switch_sampler(for_training=True),
                    lambda: mixup_switch_sampler(for_training=False
                                                 ))  # don't mix the validation

            image = tf.cast(data_dict['image'], tf.float32)
            net_args = {
                'is_training': self.is_training,
                'keep_prob': self.net_param.keep_prob
            }
            net_out = self.net(image, **net_args)

            with tf.name_scope('Optimiser'):
                optimiser_class = OptimiserFactory.create(
                    name=self.action_param.optimiser)
                self.optimiser = optimiser_class.get_instance(
                    learning_rate=self.action_param.lr)
            loss_func = LossFunction(
                n_class=self.segmentation_param.num_classes,
                loss_type=self.action_param.loss_type,
                softmax=self.segmentation_param.softmax)
            data_loss = loss_func(prediction=net_out,
                                  ground_truth=data_dict.get('label', None),
                                  weight_map=data_dict.get('weight', None))
            reg_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
            if self.net_param.decay > 0.0 and reg_losses:
                reg_loss = tf.reduce_mean(
                    [tf.reduce_mean(reg_loss) for reg_loss in reg_losses])
                loss = data_loss + reg_loss
            else:
                loss = data_loss

            # Get all vars
            to_optimise = tf.trainable_variables()
            vars_to_freeze = \
                self.action_param.vars_to_freeze or \
                self.action_param.vars_to_restore
            if vars_to_freeze:
                import re
                var_regex = re.compile(vars_to_freeze)
                # Only optimise vars that are not frozen
                to_optimise = \
                    [v for v in to_optimise if not var_regex.search(v.name)]
                tf.logging.info(
                    "Optimizing %d out of %d trainable variables, "
                    "the other variables fixed (--vars_to_freeze %s)",
                    len(to_optimise), len(tf.trainable_variables()),
                    vars_to_freeze)

            grads = self.optimiser.compute_gradients(
                loss, var_list=to_optimise, colocate_gradients_with_ops=True)

            self.total_loss = loss

            # collecting gradients variables
            gradients_collector.add_to_collection([grads])

            # collecting output variables
            outputs_collector.add_to_collection(var=self.total_loss,
                                                name='total_loss',
                                                average_over_devices=True,
                                                collection=CONSOLE)
            outputs_collector.add_to_collection(var=self.total_loss,
                                                name='total_loss',
                                                average_over_devices=True,
                                                summary_type='scalar',
                                                collection=TF_SUMMARIES)
            outputs_collector.add_to_collection(var=data_loss,
                                                name='loss',
                                                average_over_devices=False,
                                                collection=CONSOLE)
            outputs_collector.add_to_collection(var=data_loss,
                                                name='loss',
                                                average_over_devices=True,
                                                summary_type='scalar',
                                                collection=TF_SUMMARIES)

            # outputs_collector.add_to_collection(
            #    var=image*180.0, name='image',
            #    average_over_devices=False, summary_type='image3_sagittal',
            #    collection=TF_SUMMARIES)

            # outputs_collector.add_to_collection(
            #    var=image, name='image',
            #    average_over_devices=False,
            #    collection=NETWORK_OUTPUT)

            # outputs_collector.add_to_collection(
            #    var=tf.reduce_mean(image), name='mean_image',
            #    average_over_devices=False, summary_type='scalar',
            #    collection=CONSOLE)
        elif self.is_inference:
            # converting logits into final output for
            # classification probabilities or argmax classification labels
            data_dict = switch_sampler(for_training=False)
            image = tf.cast(data_dict['image'], tf.float32)
            net_args = {
                'is_training': self.is_training,
                'keep_prob': self.net_param.keep_prob
            }
            net_out = self.net(image, **net_args)

            output_prob = self.segmentation_param.output_prob
            num_classes = self.segmentation_param.num_classes
            if output_prob and num_classes > 1:
                post_process_layer = PostProcessingLayer(
                    'SOFTMAX', num_classes=num_classes)
            elif not output_prob and num_classes > 1:
                post_process_layer = PostProcessingLayer(
                    'ARGMAX', num_classes=num_classes)
            else:
                post_process_layer = PostProcessingLayer(
                    'IDENTITY', num_classes=num_classes)
            net_out = post_process_layer(net_out)

            outputs_collector.add_to_collection(var=net_out,
                                                name='window',
                                                average_over_devices=False,
                                                collection=NETWORK_OUTPUT)
            outputs_collector.add_to_collection(
                var=data_dict['image_location'],
                name='location',
                average_over_devices=False,
                collection=NETWORK_OUTPUT)
            self.initialise_aggregator()

    def interpret_output(self, batch_output):
        if self.is_inference:
            return self.output_decoder.decode_batch(
                {'window_seg': batch_output['window']},
                batch_output['location'])
        return True

    def initialise_evaluator(self, eval_param):
        self.eval_param = eval_param
        self.evaluator = SegmentationEvaluator(self.readers[0],
                                               self.segmentation_param,
                                               eval_param)

    def add_inferred_output(self, data_param, task_param):
        return self.add_inferred_output_like(data_param, task_param, 'label')
Esempio n. 26
0
 def initialise_resize_aggregator(self):
     self.output_decoder = ResizeSamplesAggregator(
         image_reader=self.readers[0],
         output_path=self.action_param.save_seg_dir,
         window_border=self.action_param.border,
         interp_order=self.action_param.output_interp_order)
Esempio n. 27
0
class SegmentationApplication(BaseApplication):
    """NiftyNet-style application for image segmentation.

    Wires together image readers, window samplers, the network, the
    loss/optimiser (training) and a window aggregator (inference).
    This variant additionally keeps references to several training
    tensors (loss, ground truth, prediction, gradients) so they can be
    inspected externally for debugging.
    """

    # configuration section this application reads its task parameters from
    REQUIRED_CONFIG_SECTION = "SEGMENTATION"

    def __init__(self, net_param, action_param, is_training):
        """Store configuration and register the supported sampling modes.

        :param net_param: network-level configuration (batch size, queue
            length, window sampling, ...)
        :param action_param: action-level configuration (learning rate,
            borders, augmentation settings, ...)
        :param is_training: True for the training action, False otherwise
        """
        super(SegmentationApplication, self).__init__()
        tf.logging.info('starting segmentation application')
        self.is_training = is_training

        self.net_param = net_param
        self.action_param = action_param

        self.data_param = None
        self.segmentation_param = None
        # maps window_sampling name to a triple of initialisers:
        # (training sampler, inference sampler, inference aggregator)
        self.SUPPORTED_SAMPLING = {
            'uniform': (self.initialise_uniform_sampler,
                        self.initialise_grid_sampler,
                        self.initialise_grid_aggregator),
            'weighted': (self.initialise_weighted_sampler,
                         self.initialise_grid_sampler,
                         self.initialise_grid_aggregator),
            'resize': (self.initialise_resize_sampler,
                       self.initialise_resize_sampler,
                       self.initialise_resize_aggregator),
        }
        # debugging handles; populated in connect_data_and_network
        self.loss_variable = None
        self.first_slice = None
        self.netOut = None
        self.GROUNDTRUTH = None
        self.PREDICTION = None
        self.CONT = 0
        self.SUMA = None
        self.GRADS = None
        self.CONV_KERNEL = None
        #self.IDS = None

    def initialise_dataset_loader(self, data_param=None, task_param=None, data_partitioner=None):
        """Build one ImageReader per file list and attach preprocessing.

        Training uses train/validation (or all) file lists; inference
        uses only the 'image' input. Preprocessing layers (padding,
        normalisation, augmentation) are appended to every reader.
        """

        self.data_param = data_param
        self.segmentation_param = task_param

        # read each line of csv files into an instance of Subject
        if self.is_training:
            file_lists = []
            if self.action_param.validation_every_n > 0:
                file_lists.append(data_partitioner.train_files)
                file_lists.append(data_partitioner.validation_files)
            else:
                file_lists.append(data_partitioner.all_files)

            self.readers = []
            for file_list in file_lists:
                reader = ImageReader(SUPPORTED_INPUT)
                reader.initialise(data_param, task_param, file_list)
                self.readers.append(reader)

        else:  # in the inference process use image input only
            inference_reader = ImageReader(['image'])
            file_list = data_partitioner.inference_files
            inference_reader.initialise(data_param, task_param, file_list)
            self.readers = [inference_reader]

        # optional foreground mask so normalisation ignores background
        foreground_masking_layer = None
        if self.net_param.normalise_foreground_only:
            foreground_masking_layer = BinaryMaskingLayer(
                type_str=self.net_param.foreground_type,
                multimod_fusion=self.net_param.multimod_foreground_type,
                threshold=0.0)

        mean_var_normaliser = MeanVarNormalisationLayer(
            image_name='image', binary_masking_func=foreground_masking_layer)
        histogram_normaliser = None
        if self.net_param.histogram_ref_file:
            histogram_normaliser = HistogramNormalisationLayer(
                image_name='image',
                modalities=vars(task_param).get('image'),
                model_filename=self.net_param.histogram_ref_file,
                binary_masking_func=foreground_masking_layer,
                norm_type=self.net_param.norm_type,
                cutoff=self.net_param.cutoff,
                name='hist_norm_layer')

        label_normaliser = None
        if self.net_param.histogram_ref_file:
            label_normaliser = DiscreteLabelNormalisationLayer(
                image_name='label',
                modalities=vars(task_param).get('label'),
                model_filename=self.net_param.histogram_ref_file)

        normalisation_layers = []
        if self.net_param.normalisation:
            normalisation_layers.append(histogram_normaliser)
        if self.net_param.whitening:
            normalisation_layers.append(mean_var_normaliser)
        if task_param.label_normalisation:
            normalisation_layers.append(label_normaliser)

        # augmentation only applies at training time
        augmentation_layers = []
        if self.is_training:
            if self.action_param.random_flipping_axes != -1:
                augmentation_layers.append(RandomFlipLayer(
                    flip_axes=self.action_param.random_flipping_axes))
            if self.action_param.scaling_percentage:
                augmentation_layers.append(RandomSpatialScalingLayer(
                    min_percentage=self.action_param.scaling_percentage[0],
                    max_percentage=self.action_param.scaling_percentage[1]))
            if self.action_param.rotation_angle or \
                    self.action_param.rotation_angle_x or \
                    self.action_param.rotation_angle_y or \
                    self.action_param.rotation_angle_z:
                rotation_layer = RandomRotationLayer()
                if self.action_param.rotation_angle:
                    rotation_layer.init_uniform_angle(
                        self.action_param.rotation_angle)
                else:
                    rotation_layer.init_non_uniform_angle(
                        self.action_param.rotation_angle_x,
                        self.action_param.rotation_angle_y,
                        self.action_param.rotation_angle_z)
                augmentation_layers.append(rotation_layer)

        volume_padding_layer = []
        if self.net_param.volume_padding_size:
            volume_padding_layer.append(PadLayer(
                image_name=SUPPORTED_INPUT,
                border=self.net_param.volume_padding_size))

        # padding must run before normalisation/augmentation
        for reader in self.readers:
            reader.add_preprocessing_layers(
                volume_padding_layer +
                normalisation_layers +
                augmentation_layers)

    def initialise_uniform_sampler(self):
        """One UniformSampler per reader (uniform random windows)."""
        self.sampler = [[UniformSampler(
            reader=reader,
            data_param=self.data_param,
            batch_size=self.net_param.batch_size,
            windows_per_image=self.action_param.sample_per_volume,
            queue_length=self.net_param.queue_length) for reader in
            self.readers]]

    def initialise_weighted_sampler(self):
        """One WeightedSampler per reader (frequency-map weighted windows)."""
        self.sampler = [[WeightedSampler(
            reader=reader,
            data_param=self.data_param,
            batch_size=self.net_param.batch_size,
            windows_per_image=self.action_param.sample_per_volume,
            queue_length=self.net_param.queue_length) for reader in
            self.readers]]

    def initialise_resize_sampler(self):
        """One ResizeSampler per reader; shuffling only while training."""
        self.sampler = [[ResizeSampler(
            reader=reader,
            data_param=self.data_param,
            batch_size=self.net_param.batch_size,
            shuffle_buffer=self.is_training,
            queue_length=self.net_param.queue_length) for reader in
            self.readers]]

    def initialise_grid_sampler(self):
        """One GridSampler per reader (dense sliding-window inference)."""
        self.sampler = [[GridSampler(
            reader=reader,
            data_param=self.data_param,
            batch_size=self.net_param.batch_size,
            spatial_window_size=self.action_param.spatial_window_size,
            window_border=self.action_param.border,
            queue_length=self.net_param.queue_length) for reader in
            self.readers]]

    def initialise_grid_aggregator(self):
        """Aggregator matching the grid sampler; writes segmentations."""
        self.output_decoder = GridSamplesAggregator(
            image_reader=self.readers[0],
            output_path=self.action_param.save_seg_dir,
            window_border=self.action_param.border,
            interp_order=self.action_param.output_interp_order)

    def initialise_resize_aggregator(self):
        """Aggregator matching the resize sampler; writes segmentations."""
        self.output_decoder = ResizeSamplesAggregator(
            image_reader=self.readers[0],
            output_path=self.action_param.save_seg_dir,
            window_border=self.action_param.border,
            interp_order=self.action_param.output_interp_order)

    def initialise_sampler(self):
        """Pick the train (index 0) or inference (index 1) sampler."""
        if self.is_training:
            self.SUPPORTED_SAMPLING[self.net_param.window_sampling][0]()
        else:
            self.SUPPORTED_SAMPLING[self.net_param.window_sampling][1]()

    def initialise_network(self):
        """Instantiate the network with regularisers and initialisers."""
        print("Initializing network")
        #IMPORTING REGULARIZERS w AND b
        w_regularizer = None
        b_regularizer = None
        reg_type = self.net_param.reg_type.lower()
        decay = self.net_param.decay
        if reg_type == 'l2' and decay > 0:
            from tensorflow.contrib.layers.python.layers import regularizers
            w_regularizer = regularizers.l2_regularizer(decay)
            b_regularizer = regularizers.l2_regularizer(decay)
        elif reg_type == 'l1' and decay > 0:
            from tensorflow.contrib.layers.python.layers import regularizers
            w_regularizer = regularizers.l1_regularizer(decay)
            b_regularizer = regularizers.l1_regularizer(decay)

        #---------W_INI = "he_normal", -- application_factory.py
        w_ini= InitializerFactory.get_initializer(name=self.net_param.weight_initializer)

        print("wWwWwWwWwWWWWWWwWWWWWWWWWweight_initializer; ", self.net_param.weight_initializer)
        print("NNNNNNNNname of application: ", self.net_param.name)

        #SELF.NET_PARAM.NAME = DENSE_VET 
        #Create dense_vnet and initialize with regularizers and activ funcs.
        self.net = ApplicationNetFactory.create(self.net_param.name)(
            num_classes=self.segmentation_param.num_classes,
            w_initializer=w_ini,
            b_initializer=InitializerFactory.get_initializer(
                name=self.net_param.bias_initializer),
            w_regularizer=w_regularizer,
            b_regularizer=b_regularizer,
            acti_func=self.net_param.activation_function)
        #print("Wwwwwwwwwwwwwwwww_INITIALIZER :", w_ini)
        #print("BBBBBBBBBBBBBBBBB_INITIALIZER :", b_initializer)


    def connect_data_and_network(self,outputs_collector=None, gradients_collector=None):
        """Build the TF graph: sampler -> net -> loss/optimiser (training)
        or sampler -> net -> post-processing -> aggregator (inference).

        Also stashes loss/ground-truth/prediction/gradient tensors on
        ``self`` for external debugging.
        """
        #def data_net(for_training):
        #    with tf.name_scope('train' if for_training else 'validation'):
        #        sampler = self.get_sampler()[0][0 if for_training else -1]
        #        data_dict = sampler.pop_batch_op()
        #        image = tf.cast(data_dict['image'], tf.float32)
        #        return data_dict, self.net(image, is_training=for_training)

        def switch_sampler(for_training):
            # pops a batch from the train (0) or validation (-1) sampler
            with tf.name_scope('train' if for_training else 'validation'):
                sampler = self.get_sampler()[0][0 if for_training else -1]
                return sampler.pop_batch_op()

        if self.is_training:

            print("-CONNECT DATA AND NETWORK -TRAINING---------------")
            #if self.action_param.validation_every_n > 0:
            #    data_dict, net_out = tf.cond(tf.logical_not(self.is_validation),
            #                                 lambda: data_net(True),
            #                                 lambda: data_net(False))
            #else:
            #    data_dict, net_out = data_net(True)
            # switch between train/validation batches inside the graph
            if self.action_param.validation_every_n > 0:
                data_dict = tf.cond(tf.logical_not(self.is_validation),
                                    lambda: switch_sampler(for_training=True),
                                    lambda: switch_sampler(for_training=False))
            else:
                data_dict = switch_sampler(for_training=True)

            image = tf.cast(data_dict['image'], tf.float32)
            net_out = self.net(image, is_training=self.is_training)


            with tf.name_scope('Optimiser'):
                optimiser_class = OptimiserFactory.create(
                    name=self.action_param.optimiser)       #ADAM OPTIMISER
                self.optimiser = optimiser_class.get_instance(
                    learning_rate=self.action_param.lr)



            print("####################################nombre del optimiser: ",self.action_param.optimiser)
            print("##############################3learning rate: ", self.action_param.lr)

            #loss func
            loss_func = LossFunction(
                n_class=self.segmentation_param.num_classes,
                loss_type=self.action_param.loss_type)

            # label/weight may be absent depending on the configuration
            ground_truth=data_dict.get('label', None)
            weight_map=data_dict.get('weight', None)

            #data_loss, ONEHOT, IDS= loss_func(
            data_loss = loss_func(
                prediction=net_out,
                ground_truth=data_dict.get('label', None),
                weight_map=data_dict.get('weight', None))

            ################################################################
            ################################################################
            #setting up printing variables

            self.loss_variable = data_loss
            firstSlice = ground_truth
            self.first_slice = firstSlice
            #self.first_slice = tf.squeeze(tf.slice(firstSlice, [0,0,0,60,0], [1,103,103,1,1]))
            #self.first_slice_cut = tf.slice(firstSlice, [0,52,52,60,1], [1,30,30,1,1])
            netOut = tf.nn.softmax(net_out)
            self.netOut = netOut
            #self.netOut = tf.squeeze(netOut[0,50,50,1,:])

            # NOTE(review): assumes loss_func exposes return_loss_args()
            # and a SUMA attribute (custom LossFunction) — confirm.
            GROUNDTRUTH, PREDICTION, CONT = loss_func.return_loss_args()
            self.GROUNDTRUTH = GROUNDTRUTH
            self.PREDICTION = PREDICTION
            self.CONT = CONT
            self.SUMA = loss_func.SUMA

            print("Salio del seteo de variable en connect data and net")
            ################################################################
            ################################################################
            #calculating regularizers
            reg_losses = tf.get_collection(
                tf.GraphKeys.REGULARIZATION_LOSSES)

            print("############## que que e isso: ", tf.GraphKeys.REGULARIZATION_LOSSES)
            if self.net_param.decay > 0.0 and reg_losses:
                reg_loss = tf.reduce_mean(
                    [tf.reduce_mean(reg_loss) for reg_loss in reg_losses])
                loss = data_loss + reg_loss
            else:
                loss = data_loss


            grads = self.optimiser.compute_gradients(loss)


            #grads2 = self.optimiser.compute_gradients(loss,[prediction])
            self.GRADS = grads
            #print("#############GRADIENDSSSSSSSSSS", grads)

            # collecting gradients variables
            gradients_collector.add_to_collection([grads])
            # collecting output variables
            outputs_collector.add_to_collection(
                var=data_loss, name='dice_loss',
                average_over_devices=False, collection=CONSOLE)
            outputs_collector.add_to_collection(
                var=data_loss, name='dice_loss',
                average_over_devices=True, summary_type='scalar',
                collection=TF_SUMMARIES)

            #####################
            # scaled copy of the input image for the TF summary view
            outputs_collector.add_to_collection(
                var=image*180.0, name='image',
                average_over_devices=False, summary_type='image3_sagittal',
                collection=TF_SUMMARIES)

            outputs_collector.add_to_collection(
                var=image, name='image',
                average_over_devices=False,
                collection=NETWORK_OUTPUT)

            outputs_collector.add_to_collection(
                var=tf.reduce_mean(image), name='mean_image',
                average_over_devices=False, summary_type='scalar',
                collection=CONSOLE)
        else:
            # converting logits into final output for
            # classification probabilities or argmax classification labels
            data_dict = switch_sampler(for_training=False)
            image = tf.cast(data_dict['image'], tf.float32)
            net_out = self.net(image, is_training=self.is_training)

            output_prob = self.segmentation_param.output_prob
            num_classes = self.segmentation_param.num_classes
            if output_prob and num_classes > 1:
                post_process_layer = PostProcessingLayer(
                    'SOFTMAX', num_classes=num_classes)
            elif not output_prob and num_classes > 1:
                post_process_layer = PostProcessingLayer(
                    'ARGMAX', num_classes=num_classes)
            else:
                post_process_layer = PostProcessingLayer(
                    'IDENTITY', num_classes=num_classes)
            net_out = post_process_layer(net_out)

            outputs_collector.add_to_collection(
                var=net_out, name='window',
                average_over_devices=False, collection=NETWORK_OUTPUT)
            outputs_collector.add_to_collection(
                var=data_dict['image_location'], name='location',
                average_over_devices=False, collection=NETWORK_OUTPUT)
            # build the aggregator matching the configured sampling mode
            init_aggregator = \
                self.SUPPORTED_SAMPLING[self.net_param.window_sampling][2]
            init_aggregator()

    def return_loss_variable(self):
        """Return the loss tensor captured during graph construction."""
        return self.loss_variable

    def return_first_slice(self):
        """Return ground-truth and softmax tensors plus their shapes."""
        return self.first_slice, self.first_slice.get_shape(), self.netOut, self.netOut.get_shape()

    def return_seg_args(self):
        """Return the (ground truth, prediction, counter) debug tensors."""
        return self.GROUNDTRUTH, self.PREDICTION, self.CONT

    def interpret_output(self, batch_output):
        """Decode an inference batch via the aggregator; no-op in training."""
        if not self.is_training:
            return self.output_decoder.decode_batch(
                batch_output['window'], batch_output['location'])
        return True
# Esempio n. 28
# 0
class RegressionApplication(BaseApplication):
    """NiftyNet-style application for image regression.

    Mirrors the segmentation application but trains a single-output
    network against the 'output' target with an optional crop border
    applied before the loss.
    """

    # configuration section this application reads its task parameters from
    REQUIRED_CONFIG_SECTION = "REGRESSION"

    def __init__(self, net_param, action_param, is_training):
        """Store configuration and register the supported sampling modes."""
        BaseApplication.__init__(self)
        tf.logging.info('starting regression application')
        self.is_training = is_training

        self.net_param = net_param
        self.action_param = action_param
        self.regression_param = None

        self.data_param = None
        # maps window_sampling name to a triple of initialisers:
        # (training sampler, inference sampler, inference aggregator)
        self.SUPPORTED_SAMPLING = {
            'uniform': (self.initialise_uniform_sampler,
                        self.initialise_grid_sampler,
                        self.initialise_grid_aggregator),
            'weighted': (self.initialise_weighted_sampler,
                         self.initialise_grid_sampler,
                         self.initialise_grid_aggregator),
            'resize': (self.initialise_resize_sampler,
                       self.initialise_resize_sampler,
                       self.initialise_resize_aggregator),
        }

    def initialise_dataset_loader(
            self, data_param=None, task_param=None, data_partitioner=None):
        """Build one ImageReader per file list and attach preprocessing."""
        self.data_param = data_param
        self.regression_param = task_param

        # read each line of csv files into an instance of Subject
        if self.is_training:
            file_lists = []
            if self.action_param.validation_every_n > 0:
                file_lists.append(data_partitioner.train_files)
                file_lists.append(data_partitioner.validation_files)
            else:
                # NOTE(review): the segmentation app uses all_files in
                # this branch; train_files here may be intentional — confirm.
                file_lists.append(data_partitioner.train_files)

            self.readers = []
            for file_list in file_lists:
                reader = ImageReader(SUPPORTED_INPUT)
                reader.initialise(data_param, task_param, file_list)
                self.readers.append(reader)
        else:
            # inference only needs the 'image' input
            inference_reader = ImageReader(['image'])
            file_list = data_partitioner.inference_files
            inference_reader.initialise(data_param, task_param, file_list)
            self.readers = [inference_reader]

        mean_var_normaliser = MeanVarNormalisationLayer(
            image_name='image')
        histogram_normaliser = None
        if self.net_param.histogram_ref_file:
            histogram_normaliser = HistogramNormalisationLayer(
                image_name='image',
                modalities=vars(task_param).get('image'),
                model_filename=self.net_param.histogram_ref_file,
                norm_type=self.net_param.norm_type,
                cutoff=self.net_param.cutoff,
                name='hist_norm_layer')

        normalisation_layers = []
        if self.net_param.normalisation:
            normalisation_layers.append(histogram_normaliser)
        if self.net_param.whitening:
            normalisation_layers.append(mean_var_normaliser)

        # augmentation only applies at training time
        augmentation_layers = []
        if self.is_training:
            if self.action_param.random_flipping_axes != -1:
                augmentation_layers.append(RandomFlipLayer(
                    flip_axes=self.action_param.random_flipping_axes))
            if self.action_param.scaling_percentage:
                augmentation_layers.append(RandomSpatialScalingLayer(
                    min_percentage=self.action_param.scaling_percentage[0],
                    max_percentage=self.action_param.scaling_percentage[1]))
            if self.action_param.rotation_angle:
                augmentation_layers.append(RandomRotationLayer())
                augmentation_layers[-1].init_uniform_angle(
                    self.action_param.rotation_angle)

        volume_padding_layer = []
        if self.net_param.volume_padding_size:
            volume_padding_layer.append(PadLayer(
                image_name=SUPPORTED_INPUT,
                border=self.net_param.volume_padding_size))
        # padding must run before normalisation/augmentation
        for reader in self.readers:
            reader.add_preprocessing_layers(volume_padding_layer +
                                            normalisation_layers +
                                            augmentation_layers)

    def initialise_uniform_sampler(self):
        """One UniformSampler per reader (uniform random windows)."""
        self.sampler = [[UniformSampler(
            reader=reader,
            data_param=self.data_param,
            batch_size=self.net_param.batch_size,
            windows_per_image=self.action_param.sample_per_volume,
            queue_length=self.net_param.queue_length) for reader in
            self.readers]]

    def initialise_weighted_sampler(self):
        """One WeightedSampler per reader (frequency-map weighted windows)."""
        self.sampler = [[WeightedSampler(
            reader=reader,
            data_param=self.data_param,
            batch_size=self.net_param.batch_size,
            windows_per_image=self.action_param.sample_per_volume,
            queue_length=self.net_param.queue_length) for reader in
            self.readers]]

    def initialise_resize_sampler(self):
        """One ResizeSampler per reader; shuffling only while training."""
        self.sampler = [[ResizeSampler(
            reader=reader,
            data_param=self.data_param,
            batch_size=self.net_param.batch_size,
            shuffle_buffer=self.is_training,
            queue_length=self.net_param.queue_length) for reader in
            self.readers]]

    def initialise_grid_sampler(self):
        """One GridSampler per reader (dense sliding-window inference)."""
        self.sampler = [[GridSampler(
            reader=reader,
            data_param=self.data_param,
            batch_size=self.net_param.batch_size,
            spatial_window_size=self.action_param.spatial_window_size,
            window_border=self.action_param.border,
            queue_length=self.net_param.queue_length) for reader in
            self.readers]]

    def initialise_grid_aggregator(self):
        """Aggregator matching the grid sampler; writes regression maps."""
        self.output_decoder = GridSamplesAggregator(
            image_reader=self.readers[0],
            output_path=self.action_param.save_seg_dir,
            window_border=self.action_param.border,
            interp_order=self.action_param.output_interp_order)

    def initialise_resize_aggregator(self):
        """Aggregator matching the resize sampler; writes regression maps."""
        self.output_decoder = ResizeSamplesAggregator(
            image_reader=self.readers[0],
            output_path=self.action_param.save_seg_dir,
            window_border=self.action_param.border,
            interp_order=self.action_param.output_interp_order)

    def initialise_sampler(self):
        """Pick the train (index 0) or inference (index 1) sampler."""
        if self.is_training:
            self.SUPPORTED_SAMPLING[self.net_param.window_sampling][0]()
        else:
            self.SUPPORTED_SAMPLING[self.net_param.window_sampling][1]()

    def initialise_network(self):
        """Instantiate a single-output network with optional regularisers."""
        w_regularizer = None
        b_regularizer = None
        reg_type = self.net_param.reg_type.lower()
        decay = self.net_param.decay
        if reg_type == 'l2' and decay > 0:
            from tensorflow.contrib.layers.python.layers import regularizers
            w_regularizer = regularizers.l2_regularizer(decay)
            b_regularizer = regularizers.l2_regularizer(decay)
        elif reg_type == 'l1' and decay > 0:
            from tensorflow.contrib.layers.python.layers import regularizers
            w_regularizer = regularizers.l1_regularizer(decay)
            b_regularizer = regularizers.l1_regularizer(decay)

        # regression always predicts a single channel
        self.net = ApplicationNetFactory.create(self.net_param.name)(
            num_classes=1,
            w_regularizer=w_regularizer,
            b_regularizer=b_regularizer,
            acti_func=self.net_param.activation_function)

    def connect_data_and_network(self,
                                 outputs_collector=None,
                                 gradients_collector=None):
        """Build the TF graph: sampler -> net -> loss/optimiser (training)
        or sampler -> net -> identity post-processing -> aggregator.
        """

        def switch_sampler(for_training):
            # pops a batch from the train (0) or validation (-1) sampler
            with tf.name_scope('train' if for_training else 'validation'):
                sampler = self.get_sampler()[0][0 if for_training else -1]
                return sampler.pop_batch_op()

        if self.is_training:
            # switch between train/validation batches inside the graph
            if self.action_param.validation_every_n > 0:
                data_dict = tf.cond(tf.logical_not(self.is_validation),
                                    lambda: switch_sampler(True),
                                    lambda: switch_sampler(False))
            else:
                data_dict = switch_sampler(for_training=True)

            image = tf.cast(data_dict['image'], tf.float32)
            net_out = self.net(image, is_training=self.is_training)
            with tf.name_scope('Optimiser'):
                optimiser_class = OptimiserFactory.create(
                    name=self.action_param.optimiser)
                self.optimiser = optimiser_class.get_instance(
                    learning_rate=self.action_param.lr)
            loss_func = LossFunction(
                loss_type=self.action_param.loss_type)

            # crop the loss border off prediction, target and weight map
            crop_layer = CropLayer(
                border=self.regression_param.loss_border, name='crop-88')
            prediction = crop_layer(net_out)
            ground_truth = crop_layer(data_dict.get('output', None))
            weight_map = None if data_dict.get('weight', None) is None \
                else crop_layer(data_dict.get('weight', None))
            data_loss = loss_func(prediction=prediction,
                                  ground_truth=ground_truth,
                                  weight_map=weight_map)

            reg_losses = tf.get_collection(
                tf.GraphKeys.REGULARIZATION_LOSSES)
            if self.net_param.decay > 0.0 and reg_losses:
                reg_loss = tf.reduce_mean(
                    [tf.reduce_mean(reg_loss) for reg_loss in reg_losses])
                loss = data_loss + reg_loss
            else:
                loss = data_loss
            grads = self.optimiser.compute_gradients(loss)
            # collecting gradients variables
            gradients_collector.add_to_collection([grads])
            # collecting output variables
            outputs_collector.add_to_collection(
                var=data_loss, name='Loss',
                average_over_devices=False, collection=CONSOLE)
            outputs_collector.add_to_collection(
                var=data_loss, name='Loss',
                average_over_devices=True, summary_type='scalar',
                collection=TF_SUMMARIES)
        else:
            data_dict = switch_sampler(for_training=False)
            image = tf.cast(data_dict['image'], tf.float32)
            net_out = self.net(image, is_training=self.is_training)

            # no cropping at inference; identity keeps raw regression values
            crop_layer = CropLayer(border=0, name='crop-88')
            post_process_layer = PostProcessingLayer('IDENTITY')
            net_out = post_process_layer(crop_layer(net_out))

            outputs_collector.add_to_collection(
                var=net_out, name='window',
                average_over_devices=False, collection=NETWORK_OUTPUT)
            outputs_collector.add_to_collection(
                var=data_dict['image_location'], name='location',
                average_over_devices=False, collection=NETWORK_OUTPUT)
            # build the aggregator matching the configured sampling mode
            init_aggregator = \
                self.SUPPORTED_SAMPLING[self.net_param.window_sampling][2]
            init_aggregator()

    def interpret_output(self, batch_output):
        """Decode an inference batch via the aggregator; no-op in training."""
        if not self.is_training:
            return self.output_decoder.decode_batch(
                batch_output['window'], batch_output['location'])
        else:
            return True
class RegressionApplication(BaseApplication):
    REQUIRED_CONFIG_SECTION = "REGRESSION"

    def __init__(self, net_param, action_param, action):
        BaseApplication.__init__(self)
        tf.logging.info('starting regression application')
        self.action = action

        self.net_param = net_param
        self.action_param = action_param
        self.regression_param = None

        self.data_param = None
        self.SUPPORTED_SAMPLING = {
            'uniform':
            (self.initialise_uniform_sampler, self.initialise_grid_sampler,
             self.initialise_grid_aggregator),
            'weighted':
            (self.initialise_weighted_sampler, self.initialise_grid_sampler,
             self.initialise_grid_aggregator),
            'resize':
            (self.initialise_resize_sampler, self.initialise_resize_sampler,
             self.initialise_resize_aggregator),
            'balanced':
            (self.initialise_balanced_sampler, self.initialise_grid_sampler,
             self.initialise_grid_aggregator),
        }

    def initialise_dataset_loader(self,
                                  data_param=None,
                                  task_param=None,
                                  data_partitioner=None):
        self.data_param = data_param
        self.regression_param = task_param

        file_lists = self.get_file_lists(data_partitioner)
        # read each line of csv files into an instance of Subject
        if self.is_training:
            self.readers = []
            for file_list in file_lists:
                reader = ImageReader({'image', 'output', 'weight', 'sampler'})
                reader.initialise(data_param, task_param, file_list)
                self.readers.append(reader)
        elif self.is_inference:
            inference_reader = ImageReader(['image'])
            file_list = data_partitioner.inference_files
            inference_reader.initialise(data_param, task_param, file_lists[0])
            self.readers = [inference_reader]
        elif self.is_evaluation:
            file_list = data_partitioner.inference_files
            reader = ImageReader({'image', 'output', 'inferred'})
            reader.initialise(data_param, task_param, file_lists[0])
            self.readers = [reader]
        else:
            raise ValueError(
                'Action `{}` not supported. Expected one of {}'.format(
                    self.action, self.SUPPORTED_ACTIONS))

        mean_var_normaliser = MeanVarNormalisationLayer(image_name='image')
        histogram_normaliser = None
        if self.net_param.histogram_ref_file:
            histogram_normaliser = HistogramNormalisationLayer(
                image_name='image',
                modalities=vars(task_param).get('image'),
                model_filename=self.net_param.histogram_ref_file,
                norm_type=self.net_param.norm_type,
                cutoff=self.net_param.cutoff,
                name='hist_norm_layer')

        normalisation_layers = []
        if self.net_param.normalisation:
            normalisation_layers.append(histogram_normaliser)
        if self.net_param.whitening:
            normalisation_layers.append(mean_var_normaliser)

        augmentation_layers = []
        if self.is_training:
            if self.action_param.random_flipping_axes != -1:
                augmentation_layers.append(
                    RandomFlipLayer(
                        flip_axes=self.action_param.random_flipping_axes))
            if self.action_param.scaling_percentage:
                augmentation_layers.append(
                    RandomSpatialScalingLayer(
                        min_percentage=self.action_param.scaling_percentage[0],
                        max_percentage=self.action_param.scaling_percentage[1])
                )
            if self.action_param.rotation_angle:
                augmentation_layers.append(RandomRotationLayer())
                augmentation_layers[-1].init_uniform_angle(
                    self.action_param.rotation_angle)

        volume_padding_layer = []
        if self.net_param.volume_padding_size:
            volume_padding_layer.append(
                PadLayer(image_name=SUPPORTED_INPUT,
                         border=self.net_param.volume_padding_size,
                         mode=self.net_param.volume_padding_mode))
        for reader in self.readers:
            reader.add_preprocessing_layers(volume_padding_layer +
                                            normalisation_layers +
                                            augmentation_layers)

    def initialise_uniform_sampler(self):
        self.sampler = [[
            UniformSampler(
                reader=reader,
                data_param=self.data_param,
                batch_size=self.net_param.batch_size,
                windows_per_image=self.action_param.sample_per_volume,
                queue_length=self.net_param.queue_length)
            for reader in self.readers
        ]]

    def initialise_weighted_sampler(self):
        self.sampler = [[
            WeightedSampler(
                reader=reader,
                data_param=self.data_param,
                batch_size=self.net_param.batch_size,
                windows_per_image=self.action_param.sample_per_volume,
                queue_length=self.net_param.queue_length)
            for reader in self.readers
        ]]

    def initialise_resize_sampler(self):
        self.sampler = [[
            ResizeSampler(reader=reader,
                          data_param=self.data_param,
                          batch_size=self.net_param.batch_size,
                          shuffle_buffer=self.is_training,
                          queue_length=self.net_param.queue_length)
            for reader in self.readers
        ]]

    def initialise_grid_sampler(self):
        self.sampler = [[
            GridSampler(
                reader=reader,
                data_param=self.data_param,
                batch_size=self.net_param.batch_size,
                spatial_window_size=self.action_param.spatial_window_size,
                window_border=self.action_param.border,
                queue_length=self.net_param.queue_length)
            for reader in self.readers
        ]]

    def initialise_balanced_sampler(self):
        self.sampler = [[
            BalancedSampler(
                reader=reader,
                data_param=self.data_param,
                batch_size=self.net_param.batch_size,
                windows_per_image=self.action_param.sample_per_volume,
                queue_length=self.net_param.queue_length)
            for reader in self.readers
        ]]

    def initialise_grid_aggregator(self):
        self.output_decoder = GridSamplesAggregator(
            image_reader=self.readers[0],
            output_path=self.action_param.save_seg_dir,
            window_border=self.action_param.border,
            interp_order=self.action_param.output_interp_order,
            postfix=self.action_param.output_postfix)

    def initialise_resize_aggregator(self):
        self.output_decoder = ResizeSamplesAggregator(
            image_reader=self.readers[0],
            output_path=self.action_param.save_seg_dir,
            window_border=self.action_param.border,
            interp_order=self.action_param.output_interp_order,
            postfix=self.action_param.output_postfix)

    def initialise_sampler(self):
        if self.is_training:
            self.SUPPORTED_SAMPLING[self.net_param.window_sampling][0]()
        elif self.is_inference:
            self.SUPPORTED_SAMPLING[self.net_param.window_sampling][1]()

    def initialise_aggregator(self):
        self.SUPPORTED_SAMPLING[self.net_param.window_sampling][2]()

    def initialise_network(self):
        w_regularizer = None
        b_regularizer = None
        reg_type = self.net_param.reg_type.lower()
        decay = self.net_param.decay
        if reg_type == 'l2' and decay > 0:
            from tensorflow.contrib.layers.python.layers import regularizers
            w_regularizer = regularizers.l2_regularizer(decay)
            b_regularizer = regularizers.l2_regularizer(decay)
        elif reg_type == 'l1' and decay > 0:
            from tensorflow.contrib.layers.python.layers import regularizers
            w_regularizer = regularizers.l1_regularizer(decay)
            b_regularizer = regularizers.l1_regularizer(decay)

        self.net = ApplicationNetFactory.create(self.net_param.name)(
            num_classes=1,
            w_regularizer=w_regularizer,
            b_regularizer=b_regularizer,
            acti_func=self.net_param.activation_function)

    def connect_data_and_network(self,
                                 outputs_collector=None,
                                 gradients_collector=None):
        def switch_sampler(for_training):
            with tf.name_scope('train' if for_training else 'validation'):
                sampler = self.get_sampler()[0][0 if for_training else -1]
                return sampler.pop_batch_op()

        if self.is_training:
            if self.action_param.validation_every_n > 0:
                data_dict = tf.cond(tf.logical_not(self.is_validation),
                                    lambda: switch_sampler(True),
                                    lambda: switch_sampler(False))
            else:
                data_dict = switch_sampler(for_training=True)

            image = tf.cast(data_dict['image'], tf.float32)
            net_args = {
                'is_training': self.is_training,
                'keep_prob': self.net_param.keep_prob
            }
            net_out = self.net(image, **net_args)
            with tf.name_scope('Optimiser'):
                optimiser_class = OptimiserFactory.create(
                    name=self.action_param.optimiser)
                self.optimiser = optimiser_class.get_instance(
                    learning_rate=self.action_param.lr)
            loss_func = LossFunction(loss_type=self.action_param.loss_type)

            crop_layer = CropLayer(border=self.regression_param.loss_border,
                                   name='crop-88')
            prediction = crop_layer(net_out)
            ground_truth = crop_layer(data_dict.get('output', None))
            weight_map = None if data_dict.get('weight', None) is None \
                else crop_layer(data_dict.get('weight', None))
            data_loss = loss_func(prediction=prediction,
                                  ground_truth=ground_truth,
                                  weight_map=weight_map)

            reg_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
            if self.net_param.decay > 0.0 and reg_losses:
                reg_loss = tf.reduce_mean(
                    [tf.reduce_mean(reg_loss) for reg_loss in reg_losses])
                loss = data_loss + reg_loss
            else:
                loss = data_loss
            grads = self.optimiser.compute_gradients(loss)

            # Gradient Clipping associated with VDSR3D
            # Gradients are clipped by value, instead of clipping by global norm.
            # The authors of VDSR do not specify a threshold for the clipping process.
            # grads2, vars2 = zip(*grads)
            # grads2, _ = tf.clip_by_global_norm(grads2, 5.0)
            # grads = zip(grads2, vars2)
            grads = [(tf.clip_by_value(grad, -0.00001 / self.action_param.lr,
                                       +0.00001 / self.action_param.lr), val)
                     for grad, val in grads if grad is not None]

            # collecting gradients variables
            gradients_collector.add_to_collection([grads])
            # collecting output variables
            outputs_collector.add_to_collection(var=data_loss,
                                                name='Loss',
                                                average_over_devices=False,
                                                collection=CONSOLE)
            outputs_collector.add_to_collection(var=data_loss,
                                                name='Loss',
                                                average_over_devices=True,
                                                summary_type='scalar',
                                                collection=TF_SUMMARIES)
        elif self.is_inference:
            data_dict = switch_sampler(for_training=False)
            image = tf.cast(data_dict['image'], tf.float32)
            net_args = {
                'is_training': self.is_training,
                'keep_prob': self.net_param.keep_prob
            }
            net_out = self.net(image, **net_args)

            crop_layer = CropLayer(border=0, name='crop-88')
            post_process_layer = PostProcessingLayer('IDENTITY')
            net_out = post_process_layer(crop_layer(net_out))

            outputs_collector.add_to_collection(var=net_out,
                                                name='window',
                                                average_over_devices=False,
                                                collection=NETWORK_OUTPUT)
            outputs_collector.add_to_collection(
                var=data_dict['image_location'],
                name='location',
                average_over_devices=False,
                collection=NETWORK_OUTPUT)
            self.initialise_aggregator()

    def interpret_output(self, batch_output):
        if self.is_inference:
            return self.output_decoder.decode_batch(batch_output['window'],
                                                    batch_output['location'])
        else:
            return True

    def initialise_evaluator(self, eval_param):
        self.eval_param = eval_param
        self.evaluator = RegressionEvaluator(self.readers[0],
                                             self.regression_param, eval_param)

    def add_inferred_output(self, data_param, task_param):
        return self.add_inferred_output_like(data_param, task_param, 'output')
    def connect_data_and_network(self,
                                 outputs_collector=None,
                                 gradients_collector=None):

        def switch_samplers(for_training):
            with tf.name_scope('train' if for_training else 'validation'):
                sampler = self.get_sampler()[0 if for_training else -1]
                return sampler()  # returns image only

        if self.is_training:
            if self.action_param.validation_every_n > 0:
                sampler_window = \
                    tf.cond(tf.logical_not(self.is_validation),
                            lambda: switch_samplers(True),
                            lambda: switch_samplers(False))
            else:
                sampler_window = switch_samplers(True)

            image_windows, _ = sampler_window
            # image_windows, locations = sampler_window

            # decode channels for moving and fixed images
            image_windows_list = [
                tf.expand_dims(img, axis=-1)
                for img in tf.unstack(image_windows, axis=-1)]
            fixed_image, fixed_label, moving_image, moving_label = \
                image_windows_list

            # estimate ddf
            dense_field = self.net(fixed_image, moving_image)
            if isinstance(dense_field, tuple):
                dense_field = dense_field[0]

            # transform the moving labels
            resampler = ResamplerLayer(
                interpolation='linear', boundary='replicate')
            resampled_moving_label = resampler(moving_label, dense_field)

            # compute label loss (foreground only)
            loss_func = LossFunction(
                n_class=1,
                loss_type=self.action_param.loss_type,
                softmax=False)
            label_loss = loss_func(prediction=resampled_moving_label,
                                   ground_truth=fixed_label)

            dice_fg = 1.0 - label_loss
            # appending regularisation loss
            total_loss = label_loss
            reg_loss = tf.get_collection('bending_energy')
            if reg_loss:
                total_loss = total_loss + \
                    self.net_param.decay * tf.reduce_mean(reg_loss)

            # compute training gradients
            with tf.name_scope('Optimiser'):
                optimiser_class = OptimiserFactory.create(
                    name=self.action_param.optimiser)
                self.optimiser = optimiser_class.get_instance(
                    learning_rate=self.action_param.lr)
            grads = self.optimiser.compute_gradients(total_loss)
            gradients_collector.add_to_collection(grads)

            metrics_dice = loss_func(
                prediction=tf.to_float(resampled_moving_label >= 0.5),
                ground_truth=tf.to_float(fixed_label >= 0.5))
            metrics_dice = 1.0 - metrics_dice

            # command line output
            outputs_collector.add_to_collection(
                var=dice_fg, name='one_minus_data_loss',
                collection=CONSOLE)
            outputs_collector.add_to_collection(
                var=tf.reduce_mean(reg_loss), name='bending_energy',
                collection=CONSOLE)
            outputs_collector.add_to_collection(
                var=total_loss, name='total_loss', collection=CONSOLE)
            outputs_collector.add_to_collection(
                var=metrics_dice, name='ave_fg_dice', collection=CONSOLE)

            # for tensorboard
            outputs_collector.add_to_collection(
                var=dice_fg,
                name='data_loss',
                average_over_devices=True,
                summary_type='scalar',
                collection=TF_SUMMARIES)
            outputs_collector.add_to_collection(
                var=total_loss,
                name='averaged_total_loss',
                average_over_devices=True,
                summary_type='scalar',
                collection=TF_SUMMARIES)
            outputs_collector.add_to_collection(
                var=metrics_dice,
                name='averaged_foreground_Dice',
                average_over_devices=True,
                summary_type='scalar',
                collection=TF_SUMMARIES)

            # for visualisation debugging
            # resampled_moving_image = resampler(moving_image, dense_field)
            # outputs_collector.add_to_collection(
            #     var=fixed_image, name='fixed_image',
            #     collection=NETWORK_OUTPUT)
            # outputs_collector.add_to_collection(
            #     var=fixed_label, name='fixed_label',
            #     collection=NETWORK_OUTPUT)
            # outputs_collector.add_to_collection(
            #     var=moving_image, name='moving_image',
            #     collection=NETWORK_OUTPUT)
            # outputs_collector.add_to_collection(
            #     var=moving_label, name='moving_label',
            #     collection=NETWORK_OUTPUT)
            # outputs_collector.add_to_collection(
            #     var=resampled_moving_image, name='resampled_image',
            #     collection=NETWORK_OUTPUT)
            # outputs_collector.add_to_collection(
            #     var=resampled_moving_label, name='resampled_label',
            #     collection=NETWORK_OUTPUT)
            # outputs_collector.add_to_collection(
            #     var=dense_field, name='ddf', collection=NETWORK_OUTPUT)
            # outputs_collector.add_to_collection(
            #     var=locations, name='locations', collection=NETWORK_OUTPUT)

            # outputs_collector.add_to_collection(
            #     var=shift[0], name='a', collection=CONSOLE)
            # outputs_collector.add_to_collection(
            #     var=shift[1], name='b', collection=CONSOLE)
        else:
            image_windows, locations = self.sampler()
            image_windows_list = [
                tf.expand_dims(img, axis=-1)
                for img in tf.unstack(image_windows, axis=-1)]
            fixed_image, fixed_label, moving_image, moving_label = \
                image_windows_list

            dense_field = self.net(fixed_image, moving_image)
            if isinstance(dense_field, tuple):
                dense_field = dense_field[0]

            # transform the moving labels
            resampler = ResamplerLayer(
                interpolation='linear', boundary='replicate')
            resampled_moving_image = resampler(moving_image, dense_field)
            resampled_moving_label = resampler(moving_label, dense_field)

            outputs_collector.add_to_collection(
                var=fixed_image, name='fixed_image',
                collection=NETWORK_OUTPUT)
            outputs_collector.add_to_collection(
                var=moving_image, name='moving_image',
                collection=NETWORK_OUTPUT)
            outputs_collector.add_to_collection(
                var=resampled_moving_image,
                name='resampled_moving_image',
                collection=NETWORK_OUTPUT)
            outputs_collector.add_to_collection(
                var=resampled_moving_label,
                name='resampled_moving_label',
                collection=NETWORK_OUTPUT)

            outputs_collector.add_to_collection(
                var=fixed_label, name='fixed_label',
                collection=NETWORK_OUTPUT)
            outputs_collector.add_to_collection(
                var=moving_label, name='moving_label',
                collection=NETWORK_OUTPUT)
            #outputs_collector.add_to_collection(
            #    var=dense_field, name='field',
            #    collection=NETWORK_OUTPUT)
            outputs_collector.add_to_collection(
                var=locations, name='locations',
                collection=NETWORK_OUTPUT)

            self.output_decoder = ResizeSamplesAggregator(
                image_reader=self.readers[0], # fixed image reader
                name='fixed_image',
                output_path=self.action_param.save_seg_dir,
                interp_order=self.action_param.output_interp_order)