def get_label_reader():
    reader = ImageReader(['label'])
    reader.initialise(MOD_LABEL_DATA, MOD_LABEL_TASK, mod_label_list)
    label_normaliser = DiscreteLabelNormalisationLayer(
        image_name='label',
        modalities=vars(SINGLE_25D_TASK).get('label'),
        model_filename=os.path.join('testing_data', 'agg_test.txt'))
    reader.add_preprocessing_layers(label_normaliser)
    pad_layer = PadLayer(image_name=('label',), border=(5, 6, 7))
    reader.add_preprocessing_layers([pad_layer])
    return reader
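
# Hedged usage sketch (not part of the original test module): readers built by
# the function above are callable and return (image_id, data dict, interp_order
# dict), as the other examples on this page show. The constants MOD_LABEL_DATA,
# MOD_LABEL_TASK and mod_label_list are assumed to be defined elsewhere in the
# surrounding test file.
label_reader = get_label_reader()
_, data, interp_order = label_reader(idx=0)
print(data['label'].shape)       # each spatial axis grows by 2 * (5, 6, 7)
print(interp_order['label'][0])  # interpolation order recorded for 'label'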
Example 2
 def test_preprocessing_zero_padding(self):
     reader = ImageReader(['image'])
     reader.initialise(SINGLE_MOD_DATA, SINGLE_MOD_TASK, single_mod_list)
     idx, data, interp_order = reader()
     self.assertEqual(SINGLE_MOD_DATA['lesion'].interp_order,
                      interp_order['image'][0])
     self.assertAllClose(data['image'].shape, (256, 168, 256, 1, 1))
     reader.add_preprocessing_layers(
         [PadLayer(image_name=['image'], border=(0, 0, 0))])
     idx, data, interp_order = reader(idx=2)
     self.assertEqual(idx, 2)
     self.assertAllClose(data['image'].shape, (256, 168, 256, 1, 1))
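     # Hedged extension (not in the original test): a non-zero border grows
     # each spatial axis by twice the border size on top of the zero padding.
     reader.add_preprocessing_layers(
         [PadLayer(image_name=['image'], border=(5, 5, 5))])
     idx, data, interp_order = reader(idx=2)
     self.assertAllClose(data['image'].shape, (266, 178, 266, 1, 1))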
Example 3
    def initialise_dataset_loader(
            self, data_param=None, task_param=None, data_partitioner=None):
        self.data_param = data_param
        self.segmentation_param = task_param

        # read each line of csv files into an instance of Subject
        if self.is_training:
            file_lists = []
            if self.action_param.validation_every_n > 0:
                file_lists.append(data_partitioner.train_files)
                file_lists.append(data_partitioner.validation_files)
            else:
                file_lists.append(data_partitioner.all_files)
            self.readers = []
            for file_list in file_lists:
                reader = ImageReader(SUPPORTED_INPUT)
                reader.initialise(data_param, task_param, file_list)
                self.readers.append(reader)
        else:  # in the inference process use image input only
            inference_reader = ImageReader(['image'])
            file_list = data_partitioner.inference_files
            inference_reader.initialise(data_param, task_param, file_list)
            self.readers = [inference_reader]

        foreground_masking_layer = None
        if self.net_param.normalise_foreground_only:
            foreground_masking_layer = BinaryMaskingLayer(
                type_str=self.net_param.foreground_type,
                multimod_fusion=self.net_param.multimod_foreground_type,
                threshold=0.0)

        mean_var_normaliser = MeanVarNormalisationLayer(
            image_name='image', binary_masking_func=foreground_masking_layer)

        label_normaliser = DiscreteLabelNormalisationLayer(
            image_name='label',
            modalities=vars(task_param).get('label'),
            model_filename=self.net_param.histogram_ref_file)

        normalisation_layers = []
        normalisation_layers.append(mean_var_normaliser)
        if task_param.label_normalisation:
            normalisation_layers.append(label_normaliser)

        volume_padding_layer = []
        if self.net_param.volume_padding_size:
            volume_padding_layer.append(PadLayer(
                image_name=SUPPORTED_INPUT,
                border=self.net_param.volume_padding_size))
        for reader in self.readers:
            reader.add_preprocessing_layers(
                normalisation_layers + volume_padding_layer)
Example 4
 def test_trainable_preprocessing(self):
     label_file = os.path.join('testing_data', 'label_reader.txt')
     if os.path.exists(label_file):
         os.remove(label_file)
     label_normaliser = DiscreteLabelNormalisationLayer(
         image_name='label',
         modalities=vars(LABEL_TASK).get('label'),
         model_filename=os.path.join('testing_data', 'label_reader.txt'))
     reader = ImageReader(['label'])
     with self.assertRaisesRegexp(AssertionError, ''):
         reader.add_preprocessing_layers(label_normaliser)
     reader.initialise(LABEL_DATA, LABEL_TASK, label_list)
     reader.add_preprocessing_layers(label_normaliser)
     reader.add_preprocessing_layers(
         [PadLayer(image_name=['label'], border=(10, 5, 5))])
     idx, data, interp_order = reader(idx=0)
     unique_data = np.unique(data['label'])
     expected_v1 = np.array(
         [0., 1., 2., 3., 4., 5., 6., 7., 8.,
          9., 10., 11., 12., 13., 14., 15., 16., 17.,
          18., 19., 20., 21., 22., 23., 24., 25., 26., 27.,
          28., 29., 30., 31., 32., 33., 34., 35., 36.,
          37., 38., 39., 40., 41., 42., 43., 44., 45.,
          46., 47., 48., 49., 50., 51., 52., 53., 54.,
          55., 56., 57., 58., 59., 60., 61., 62., 63.,
          64., 65., 66., 67., 68., 69., 70., 71., 72.,
          73., 74., 75., 76., 77., 78., 79., 80., 81.,
          82., 83., 84., 85., 86., 87., 88., 89., 90.,
          91., 92., 93., 94., 95., 96., 97., 98., 99.,
          100., 101., 102., 103., 104., 105., 106., 107., 108.,
          109., 110., 111., 112., 113., 114., 115., 116., 117.,
          118., 119., 120., 121., 122., 123., 124., 125., 126.,
          127., 128., 129., 130., 131., 132., 133., 134., 135.,
          136., 137., 138., 139., 140., 141., 142., 143., 144.,
          145., 146., 147., 148., 149., 150., 151., 152., 153.,
          154., 155., 156., 157.], dtype=np.float32)
     expected_v2 = np.array(
         [0., 1., 2., 3., 4., 5., 6., 7., 8.,
          9., 10., 11., 12., 13., 14., 15., 16., 17.,
          18., 20., 21., 22., 23., 24., 25., 26., 27.,
          28., 29., 30., 31., 32., 33., 34., 35., 36.,
          37., 38., 39., 40., 41., 42., 43., 44., 45.,
          46., 47., 48., 49., 50., 51., 52., 53., 54.,
          55., 56., 57., 58., 59., 60., 61., 62., 63.,
          64., 65., 66., 67., 68., 69., 70., 71., 72.,
          73., 74., 75., 76., 77., 78., 79., 80., 81.,
          82., 83., 84., 85., 86., 87., 88., 89., 90.,
          91., 92., 93., 94., 95., 96., 97., 98., 99.,
          100., 101., 102., 103., 104., 105., 106., 107., 108.,
          109., 110., 111., 112., 113., 114., 115., 116., 117.,
          118., 119., 120., 121., 122., 123., 124., 125., 126.,
          127., 128., 129., 130., 131., 132., 133., 134., 135.,
          136., 137., 138., 139., 140., 141., 142., 143., 144.,
          145., 146., 147., 148., 149., 150., 151., 152., 153.,
          154., 155., 156., 157.], dtype=np.float32)
     compatible_assert = \
         np.all(unique_data == expected_v1) or \
         np.all(unique_data == expected_v2)
     self.assertTrue(compatible_assert)
     self.assertAllClose(data['label'].shape, (103, 74, 93, 1, 1))
Example 5
    def initialise_dataset_loader(
            self, data_param=None, task_param=None, data_partitioner=None):
        self.data_param = data_param
        self.regression_param = task_param

        # read each line of csv files into an instance of Subject
        if self.is_training:
            file_lists = []
            if self.action_param.validation_every_n > 0:
                file_lists.append(data_partitioner.train_files)
                file_lists.append(data_partitioner.validation_files)
            else:
                file_lists.append(data_partitioner.train_files)

            self.readers = []
            for file_list in file_lists:
                reader = ImageReader(SUPPORTED_INPUT)
                reader.initialise(data_param, task_param, file_list)
                self.readers.append(reader)
        else:
            inference_reader = ImageReader(['image'])
            file_list = data_partitioner.inference_files
            inference_reader.initialise(data_param, task_param, file_list)
            self.readers = [inference_reader]

        mean_var_normaliser = MeanVarNormalisationLayer(
            image_name='image')
        histogram_normaliser = None
        if self.net_param.histogram_ref_file:
            histogram_normaliser = HistogramNormalisationLayer(
                image_name='image',
                modalities=vars(task_param).get('image'),
                model_filename=self.net_param.histogram_ref_file,
                norm_type=self.net_param.norm_type,
                cutoff=self.net_param.cutoff,
                name='hist_norm_layer')

        normalisation_layers = []
        if self.net_param.normalisation:
            normalisation_layers.append(histogram_normaliser)
        if self.net_param.whitening:
            normalisation_layers.append(mean_var_normaliser)

        augmentation_layers = []
        if self.is_training:
            if self.action_param.random_flipping_axes != -1:
                augmentation_layers.append(RandomFlipLayer(
                    flip_axes=self.action_param.random_flipping_axes))
            if self.action_param.scaling_percentage:
                augmentation_layers.append(RandomSpatialScalingLayer(
                    min_percentage=self.action_param.scaling_percentage[0],
                    max_percentage=self.action_param.scaling_percentage[1]))
            if self.action_param.rotation_angle:
                augmentation_layers.append(RandomRotationLayer())
                augmentation_layers[-1].init_uniform_angle(
                    self.action_param.rotation_angle)

        volume_padding_layer = []
        if self.net_param.volume_padding_size:
            volume_padding_layer.append(PadLayer(
                image_name=SUPPORTED_INPUT,
                border=self.net_param.volume_padding_size))
        for reader in self.readers:
            reader.add_preprocessing_layers(volume_padding_layer +
                                            normalisation_layers +
                                            augmentation_layers)
    def initialise_dataset_loader(
            self, data_param=None, task_param=None, data_partitioner=None):

        self.data_param = data_param
        self.classification_param = task_param

        file_lists = self.get_file_lists(data_partitioner)
        # read each line of csv files into an instance of Subject
        if self.is_training:
            self.readers = []
            for file_list in file_lists:
                reader = ImageReader(['image', 'label', 'sampler'])
                reader.initialise(data_param, task_param, file_list)
                self.readers.append(reader)

        elif self.is_inference:  
            # in the inference process use image input only
            inference_reader = ImageReader(['image'])
            inference_reader.initialise(data_param, task_param, file_lists[0])
            self.readers = [inference_reader]
        elif self.is_evaluation:
            reader = ImageReader({'image', 'label', 'inferred'})
            reader.initialise(data_param, task_param, file_lists[0])
            self.readers = [reader]
        else:
            raise ValueError('Action `{}` not supported. Expected one of {}'
                             .format(self.action, self.SUPPORTED_ACTIONS))

        foreground_masking_layer = None
        if self.net_param.normalise_foreground_only:
            foreground_masking_layer = BinaryMaskingLayer(
                type_str=self.net_param.foreground_type,
                multimod_fusion=self.net_param.multimod_foreground_type,
                threshold=0.0)

        mean_var_normaliser = MeanVarNormalisationLayer(
            image_name='image', binary_masking_func=foreground_masking_layer)
        histogram_normaliser = None
        if self.net_param.histogram_ref_file:
            histogram_normaliser = HistogramNormalisationLayer(
                image_name='image',
                modalities=vars(task_param).get('image'),
                model_filename=self.net_param.histogram_ref_file,
                binary_masking_func=foreground_masking_layer,
                norm_type=self.net_param.norm_type,
                cutoff=self.net_param.cutoff,
                name='hist_norm_layer')

        label_normaliser = None
        if self.net_param.histogram_ref_file:
            label_normaliser = DiscreteLabelNormalisationLayer(
                image_name='label',
                modalities=vars(task_param).get('label'),
                model_filename=self.net_param.histogram_ref_file)

        normalisation_layers = []
        if self.net_param.normalisation:
            normalisation_layers.append(histogram_normaliser)
        if self.net_param.whitening:
            normalisation_layers.append(mean_var_normaliser)
        if task_param.label_normalisation:
            normalisation_layers.append(label_normaliser)

        augmentation_layers = []
        if self.is_training:
            if self.action_param.random_flipping_axes != -1:
                augmentation_layers.append(RandomFlipLayer(
                    flip_axes=self.action_param.random_flipping_axes))
            if self.action_param.scaling_percentage:
                augmentation_layers.append(RandomSpatialScalingLayer(
                    min_percentage=self.action_param.scaling_percentage[0],
                    max_percentage=self.action_param.scaling_percentage[1]))
            if self.action_param.rotation_angle or \
                    self.action_param.rotation_angle_x or \
                    self.action_param.rotation_angle_y or \
                    self.action_param.rotation_angle_z:
                rotation_layer = RandomRotationLayer()
                if self.action_param.rotation_angle:
                    rotation_layer.init_uniform_angle(
                        self.action_param.rotation_angle)
                else:
                    rotation_layer.init_non_uniform_angle(
                        self.action_param.rotation_angle_x,
                        self.action_param.rotation_angle_y,
                        self.action_param.rotation_angle_z)
                augmentation_layers.append(rotation_layer)

        for reader in self.readers:
            reader.add_preprocessing_layers(
                normalisation_layers +
                augmentation_layers)
class SegmentationApplication(BaseApplication):
    REQUIRED_CONFIG_SECTION = "SEGMENTATION"

    def __init__(self, net_param, action_param, is_training):
        BaseApplication.__init__(self)
        tf.logging.info('starting segmentation application')
        self.is_training = is_training

        self.net_param = net_param
        self.action_param = action_param

        self.data_param = None
        self.segmentation_param = None
        self.SUPPORTED_SAMPLING = {
            'uniform':
            (self.initialise_uniform_sampler, self.initialise_grid_sampler,
             self.initialise_grid_aggregator),
            'resize':
            (self.initialise_resize_sampler, self.initialise_resize_sampler,
             self.initialise_resize_aggregator),
        }

    def initialise_dataset_loader(self, data_param=None, task_param=None):
        self.data_param = data_param
        self.segmentation_param = task_param

        # read each line of csv files into an instance of Subject
        if self.is_training:
            self.reader = ImageReader(SUPPORTED_INPUT)
        else:  # in the inference process use image input only
            self.reader = ImageReader(['image'])
        self.reader.initialise_reader(data_param, task_param)

        if self.net_param.normalise_foreground_only:
            foreground_masking_layer = BinaryMaskingLayer(
                type_str=self.net_param.foreground_type,
                multimod_fusion=self.net_param.multimod_foreground_type,
                threshold=0.0)
        else:
            foreground_masking_layer = None

        mean_var_normaliser = MeanVarNormalisationLayer(
            image_name='image', binary_masking_func=foreground_masking_layer)
        if self.net_param.histogram_ref_file:
            histogram_normaliser = HistogramNormalisationLayer(
                image_name='image',
                modalities=vars(task_param).get('image'),
                model_filename=self.net_param.histogram_ref_file,
                binary_masking_func=foreground_masking_layer,
                norm_type=self.net_param.norm_type,
                cutoff=self.net_param.cutoff,
                name='hist_norm_layer')
        else:
            histogram_normaliser = None

        if self.net_param.histogram_ref_file:
            label_normaliser = DiscreteLabelNormalisationLayer(
                image_name='label',
                modalities=vars(task_param).get('label'),
                model_filename=self.net_param.histogram_ref_file)
        else:
            label_normaliser = None

        normalisation_layers = []
        if self.net_param.normalisation:
            normalisation_layers.append(histogram_normaliser)
        if self.net_param.whitening:
            normalisation_layers.append(mean_var_normaliser)
        if task_param.label_normalisation:
            normalisation_layers.append(label_normaliser)

        augmentation_layers = []
        if self.is_training:
            if self.action_param.random_flipping_axes != -1:
                augmentation_layers.append(
                    RandomFlipLayer(
                        flip_axes=self.action_param.random_flipping_axes))
            if self.action_param.rotation_angle:
                rotation_layer = RandomRotationLayer()
                if self.action_param.rotation_angle:
                    rotation_layer.init_uniform_angle(
                        self.action_param.rotation_angle)
                else:
                    rotation_layer.init_non_uniform_angle(
                        self.action_param.rotation_angle_x,
                        self.action_param.rotation_angle_y,
                        self.action_param.rotation_angle_z)
                augmentation_layers.append(rotation_layer)
# ========================== Disable scaling and rotation =====================
#             if self.action_param.scaling_percentage:
#                 augmentation_layers.append(RandomSpatialScalingLayer(
#                     min_percentage=self.action_param.scaling_percentage[0],
#                     max_percentage=self.action_param.scaling_percentage[1]))
# =============================================================================
# =============================================================================
#             if self.action_param.rotation_angle:
#                 augmentation_layers.append(RandomRotationLayer(
#                     min_angle=self.action_param.rotation_angle[0],
#                     max_angle=self.action_param.rotation_angle[1]))
# =============================================================================

        volume_padding_layer = []
        if self.net_param.volume_padding_size:
            volume_padding_layer.append(
                PadLayer(image_name=SUPPORTED_INPUT,
                         border=self.net_param.volume_padding_size))
        self.reader.add_preprocessing_layers(volume_padding_layer +
                                             normalisation_layers +
                                             augmentation_layers)

    def initialise_uniform_sampler(self):
        self.sampler = [
            UniformSampler(
                reader=self.reader,
                data_param=self.data_param,
                batch_size=self.net_param.batch_size,
                windows_per_image=self.action_param.sample_per_volume,
                queue_length=self.net_param.queue_length)
        ]

    def initialise_resize_sampler(self):
        self.sampler = [
            ResizeSampler(reader=self.reader,
                          data_param=self.data_param,
                          batch_size=self.net_param.batch_size,
                          shuffle_buffer=self.is_training,
                          queue_length=self.net_param.queue_length)
        ]

    def initialise_grid_sampler(self):
        self.sampler = [
            GridSampler(
                reader=self.reader,
                data_param=self.data_param,
                batch_size=self.net_param.batch_size,
                spatial_window_size=self.action_param.spatial_window_size,
                window_border=self.action_param.border,
                queue_length=self.net_param.queue_length)
        ]

    def initialise_grid_aggregator(self):
        self.output_decoder = GridSamplesAggregator(
            image_reader=self.reader,
            output_path=self.action_param.save_seg_dir,
            window_border=self.action_param.border,
            interp_order=self.action_param.output_interp_order)

    def initialise_resize_aggregator(self):
        self.output_decoder = ResizeSamplesAggregator(
            image_reader=self.reader,
            output_path=self.action_param.save_seg_dir,
            window_border=self.action_param.border,
            interp_order=self.action_param.output_interp_order)

    def initialise_sampler(self):
        if self.is_training:
            self.SUPPORTED_SAMPLING[self.net_param.window_sampling][0]()
        else:
            self.SUPPORTED_SAMPLING[self.net_param.window_sampling][1]()

    def initialise_network(self):
        num_classes = self.segmentation_param.num_classes
        w_regularizer = None
        b_regularizer = None
        reg_type = self.net_param.reg_type.lower()
        decay = self.net_param.decay
        if reg_type == 'l2' and decay > 0:
            from tensorflow.contrib.layers.python.layers import regularizers
            w_regularizer = regularizers.l2_regularizer(decay)
            b_regularizer = regularizers.l2_regularizer(decay)
        elif reg_type == 'l1' and decay > 0:
            from tensorflow.contrib.layers.python.layers import regularizers
            w_regularizer = regularizers.l1_regularizer(decay)
            b_regularizer = regularizers.l1_regularizer(decay)

        self.net = ApplicationNetFactory.create(self.net_param.name)(
            num_classes=num_classes,
            w_regularizer=w_regularizer,
            b_regularizer=b_regularizer,
            acti_func=self.net_param.activation_function)

    def connect_data_and_network(self,
                                 outputs_collector=None,
                                 gradients_collector=None):
        data_dict = self.get_sampler()[0].pop_batch_op()
        image = tf.cast(data_dict['image'], tf.float32)
        net_out = self.net(image, self.is_training)

        if self.is_training:

            label = data_dict.get('label', None)

            # Changed label on 11/29/2017: This will generate a 2D label
            # from the 3D label provided in the input. Only suitable for STNeuroNet
            k = label.get_shape().as_list()
            label = tf.nn.max_pool3d(label, [1, 1, 1, k[3], 1],
                                     [1, 1, 1, 1, 1],
                                     'VALID',
                                     data_format='NDHWC')
            print('label shape is {}'.format(label.get_shape()))
            print('Image shape is {}'.format(image.get_shape()))
            print('Out shape is {}'.format(net_out.get_shape()))
            ####
            with tf.name_scope('Optimiser'):
                optimiser_class = OptimiserFactory.create(
                    name=self.action_param.optimiser)
                self.optimiser = optimiser_class.get_instance(
                    learning_rate=self.action_param.lr)
            loss_func = LossFunction(
                n_class=self.segmentation_param.num_classes,
                loss_type=self.action_param.loss_type)
            data_loss = loss_func(prediction=net_out,
                                  ground_truth=label,
                                  weight_map=data_dict.get('weight', None))
            reg_losses = tf.get_collection(
                tf.GraphKeys.REGULARIZATION_LOSSES)
            if self.net_param.decay > 0.0 and reg_losses:
                reg_loss = tf.reduce_mean(
                    [tf.reduce_mean(reg_loss) for reg_loss in reg_losses])
                loss = data_loss + reg_loss
            else:
                loss = data_loss
            grads = self.optimiser.compute_gradients(loss)
            # collecting gradients variables
            gradients_collector.add_to_collection([grads])
            # collecting output variables
            outputs_collector.add_to_collection(var=data_loss,
                                                name='dice_loss',
                                                average_over_devices=False,
                                                collection=CONSOLE)
            outputs_collector.add_to_collection(var=data_loss,
                                                name='dice_loss',
                                                average_over_devices=True,
                                                summary_type='scalar',
                                                collection=TF_SUMMARIES)
            # ADDED on 10/30 by Soltanian-Zadeh for tensorboard visualization
            seg_summary = tf.to_float(
                tf.expand_dims(tf.argmax(net_out, -1), -1)) * (
                    255. / self.segmentation_param.num_classes - 1)
            label_summary = tf.to_float(tf.expand_dims(
                label, -1)) * (255. / self.segmentation_param.num_classes - 1)
            m, v = tf.nn.moments(image, axes=[1, 2, 3], keep_dims=True)
            img_summary = tf.minimum(
                255.,
                tf.maximum(0., (tf.to_float(image - m) /
                                (tf.sqrt(v) * 2.) + 1.) * 127.))
            image3_axial('img', img_summary, 50, [tf.GraphKeys.SUMMARIES])
            image3_axial('seg', seg_summary, 5, [tf.GraphKeys.SUMMARIES])
            image3_axial('label', label_summary, 5, [tf.GraphKeys.SUMMARIES])
        else:
            # converting logits into final output for
            # classification probabilities or argmax classification labels
            output_prob = self.segmentation_param.output_prob
            num_classes = self.segmentation_param.num_classes
            if output_prob and num_classes > 1:
                post_process_layer = PostProcessingLayer(
                    'SOFTMAX', num_classes=num_classes)
            elif not output_prob and num_classes > 1:
                post_process_layer = PostProcessingLayer(
                    'ARGMAX', num_classes=num_classes)
            else:
                post_process_layer = PostProcessingLayer(
                    'IDENTITY', num_classes=num_classes)
            net_out = post_process_layer(net_out)
            print('output shape is {}'.format(net_out.get_shape()))

            outputs_collector.add_to_collection(var=net_out,
                                                name='window',
                                                average_over_devices=False,
                                                collection=NETORK_OUTPUT)
            outputs_collector.add_to_collection(
                var=data_dict['image_location'],
                name='location',
                average_over_devices=False,
                collection=NETORK_OUTPUT)
            init_aggregator = \
                self.SUPPORTED_SAMPLING[self.net_param.window_sampling][2]
            init_aggregator()

    def interpret_output(self, batch_output):
        if not self.is_training:
            return self.output_decoder.decode_batch(batch_output['window'],
                                                    batch_output['location'])
        return True
Example 8
    def initialise_dataset_loader(self,
                                  data_param=None,
                                  task_param=None,
                                  data_partitioner=None):

        self.data_param = data_param
        self.segmentation_param = task_param

        # read each line of csv files into an instance of Subject
        if self.is_training:
            file_lists = []
            if self.action_param.validation_every_n > 0:
                file_lists.append(data_partitioner.train_files)
                file_lists.append(data_partitioner.validation_files)
            else:
                file_lists.append(data_partitioner.train_files)

            self.readers = []
            for file_list in file_lists:
                reader = ImageReader(SUPPORTED_INPUT)
                reader.initialise(data_param, task_param, file_list)
                self.readers.append(reader)

        else:  # in the inference process use image input only
            inference_reader = ImageReader(['image'])
            file_list = data_partitioner.inference_files
            inference_reader.initialise(data_param, task_param, file_list)
            self.readers = [inference_reader]

        foreground_masking_layer = None
        if self.net_param.normalise_foreground_only:
            foreground_masking_layer = BinaryMaskingLayer(
                type_str=self.net_param.foreground_type,
                multimod_fusion=self.net_param.multimod_foreground_type,
                threshold=0.0)

        mean_var_normaliser = MeanVarNormalisationLayer(
            image_name='image', binary_masking_func=foreground_masking_layer)
        histogram_normaliser = None
        if self.net_param.histogram_ref_file:
            histogram_normaliser = HistogramNormalisationLayer(
                image_name='image',
                modalities=vars(task_param).get('image'),
                model_filename=self.net_param.histogram_ref_file,
                binary_masking_func=foreground_masking_layer,
                norm_type=self.net_param.norm_type,
                cutoff=self.net_param.cutoff,
                name='hist_norm_layer')

        label_normaliser = None
        if self.net_param.histogram_ref_file:
            label_normaliser = DiscreteLabelNormalisationLayer(
                image_name='label',
                modalities=vars(task_param).get('label'),
                model_filename=self.net_param.histogram_ref_file)

        normalisation_layers = []
        if self.net_param.normalisation:
            normalisation_layers.append(histogram_normaliser)
        if self.net_param.whitening:
            normalisation_layers.append(mean_var_normaliser)
        if task_param.label_normalisation and \
                (self.is_training or not task_param.output_prob):
            normalisation_layers.append(label_normaliser)

        augmentation_layers = []
        if self.is_training:
            if self.action_param.random_flipping_axes != -1:
                augmentation_layers.append(
                    RandomFlipLayer(
                        flip_axes=self.action_param.random_flipping_axes))
            if self.action_param.scaling_percentage:
                augmentation_layers.append(
                    RandomSpatialScalingLayer(
                        min_percentage=self.action_param.scaling_percentage[0],
                        max_percentage=self.action_param.scaling_percentage[1])
                )
            if self.action_param.rotation_angle or \
                    self.action_param.rotation_angle_x or \
                    self.action_param.rotation_angle_y or \
                    self.action_param.rotation_angle_z:
                rotation_layer = RandomRotationLayer()
                if self.action_param.rotation_angle:
                    rotation_layer.init_uniform_angle(
                        self.action_param.rotation_angle)
                else:
                    rotation_layer.init_non_uniform_angle(
                        self.action_param.rotation_angle_x,
                        self.action_param.rotation_angle_y,
                        self.action_param.rotation_angle_z)
                augmentation_layers.append(rotation_layer)

        volume_padding_layer = []
        if self.net_param.volume_padding_size:
            volume_padding_layer.append(
                PadLayer(image_name=SUPPORTED_INPUT,
                         border=self.net_param.volume_padding_size))

        for reader in self.readers:
            reader.add_preprocessing_layers(volume_padding_layer +
                                            normalisation_layers +
                                            augmentation_layers)
Example 9
    def initialise_dataset_loader(self,
                                  data_param=None,
                                  task_param=None,
                                  data_partitioner=None):
        self.data_param = data_param
        self.gan_param = task_param

        # read each line of csv files into an instance of Subject
        if self.is_training:
            file_lists = []
            if self.action_param.validation_every_n > 0:
                file_lists.append(data_partitioner.train_files)
                file_lists.append(data_partitioner.validation_files)
            else:
                file_lists.append(data_partitioner.train_files)
            self.readers = []
            for file_list in file_lists:
                reader = ImageReader(['image', 'conditioning'])
                reader.initialise(data_param, task_param, file_list)
                self.readers.append(reader)
        else:
            inference_reader = ImageReader(['conditioning'])
            file_list = data_partitioner.inference_files
            inference_reader.initialise(data_param, task_param, file_list)
            self.readers = [inference_reader]

        foreground_masking_layer = None
        if self.net_param.normalise_foreground_only:
            foreground_masking_layer = BinaryMaskingLayer(
                type_str=self.net_param.foreground_type,
                multimod_fusion=self.net_param.multimod_foreground_type,
                threshold=0.0)

        mean_var_normaliser = MeanVarNormalisationLayer(
            image_name='image', binary_masking_func=foreground_masking_layer)
        histogram_normaliser = None
        if self.net_param.histogram_ref_file:
            histogram_normaliser = HistogramNormalisationLayer(
                image_name='image',
                modalities=vars(task_param).get('image'),
                model_filename=self.net_param.histogram_ref_file,
                binary_masking_func=foreground_masking_layer,
                norm_type=self.net_param.norm_type,
                cutoff=self.net_param.cutoff,
                name='hist_norm_layer')

        normalisation_layers = []
        if self.net_param.normalisation:
            normalisation_layers.append(histogram_normaliser)
        if self.net_param.whitening:
            normalisation_layers.append(mean_var_normaliser)

        augmentation_layers = []
        if self.is_training:
            if self.action_param.random_flipping_axes != -1:
                augmentation_layers.append(
                    RandomFlipLayer(
                        flip_axes=self.action_param.random_flipping_axes))
            if self.action_param.scaling_percentage:
                augmentation_layers.append(
                    RandomSpatialScalingLayer(
                        min_percentage=self.action_param.scaling_percentage[0],
                        max_percentage=self.action_param.scaling_percentage[1])
                )
            if self.action_param.rotation_angle:
                augmentation_layers.append(RandomRotationLayer())
                augmentation_layers[-1].init_uniform_angle(
                    self.action_param.rotation_angle)

        for reader in self.readers:
            reader.add_preprocessing_layers(normalisation_layers +
                                            augmentation_layers)
Example 10
}

#Read images and plot
#Original images are already interpolated (order 1) to 1mm pixel size.
idx = 0
reader = ImageReader().initialise(data_param)
_, vol, _ = reader(idx=idx)
plot_slices(vol['MR'], vol['CT'], vol['LABELS'], idx, title='Original images')
#sitk.WriteImage(sitk.GetImageFromArray(vol['MR'][:,:,:,0,0]),'./original.mhd') #save to disk
print(vol['MR'].shape)

##Normalization
binary_masking_func = Mask(type_str='mean_plus')
#mask=binary_masking_func(vol['MR'])
reader.add_preprocessing_layers(
    MeanVarNormalisationLayer(image_name='MR',
                              binary_masking_func=binary_masking_func))
_, vol, _ = reader(idx=idx)
plot_slices(vol['MR'],
            vol['CT'],
            vol['LABELS'],
            idx,
            title='Normalized MR image')
#sitk.WriteImage(sitk.GetImageFromArray(vol['MR'][:,:,:,0,0]),'./whitenned.mhd')

# Use border as half of spatial_window_size, so the NN can learn more from the borders
#Use constant-mode padding; I have modified niftynet.layers.pad to use constant_values=np.amin(image)
reader.add_preprocessing_layers(
    PadLayer(image_name=['MR', 'CT', 'LABELS'],
             border=(24, 24, 48),
             mode='constant'))
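
#Hedged check (not part of the original script): PadLayer adds the border on
#both sides of each spatial axis, so the padded MR volume should be
#(48, 48, 96) voxels larger than the shape printed above.
_, vol, _ = reader(idx=idx)
print(vol['MR'].shape)
plot_slices(vol['MR'], vol['CT'], vol['LABELS'], idx, title='Padded images')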
Example 11
        'axcodes': ['L', 'P', 'S'],
        'interp_order': 0
    },
}
#This option is to read 2D slices
# data_param = {'MR': {'path_to_search': './sample_images','spatial_window_size': (48,48,41),
#             'filename_contains': ['middleSlice','mhd'], 'pixdim': (1,1), 'axcodes': ['L','P','S'],'interp_order': 3},
#             'sampler': {'path_to_search': './sample_images', 'spatial_window_size': (48,48,1),
#             'filename_contains': ['middleLabels','mhd'],'pixdim': (1,1), 'axcodes': ['L','P','S'],'interp_order': 0},
#               }

#Create image reader and add padding layer with 'constant' value 0
#This will ensure that samples are taken even at the beginning and ending slice
reader = ImageReader().initialise(data_param)
reader.add_preprocessing_layers(
    PadLayer(image_name=['MR', 'sampler'],
             border=(20, 20, 20),
             mode='constant'))
_, img, _ = reader(idx=0)

#Create samplers with window_size
weighted_sampler = WeightedSampler(reader, window_sizes=(48, 48, 48))

balanced_sampler = BalancedSampler(reader, window_sizes=(48, 48, 48))

uniform_sampler = UniformSampler(reader, window_sizes=(48, 48, 48))

#Generate N samples for each type
N = 30
import tensorflow as tf
# adding the tensorflow tensors
next_window = weighted_sampler.pop_batch_op()
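
#Hedged sketch of actually drawing the N samples (not part of the original
#snippet): this assumes the tf.data-based samplers of recent NiftyNet, where
#the tensors from pop_batch_op() can be evaluated directly in a session, and
#that the window dict is keyed by the data_param section names ('MR',
#'sampler'); older queue-based samplers need their reading threads started
#before sess.run.
with tf.Session() as sess:
    for _ in range(N):
        window = sess.run(next_window)
        print(window['MR'].shape)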
Example 12
class RegressionApplication(BaseApplication):
    REQUIRED_CONFIG_SECTION = "REGRESSION"

    def __init__(self, net_param, action_param, is_training):
        BaseApplication.__init__(self)
        tf.logging.info('starting regression application')
        self.is_training = is_training

        self.net_param = net_param
        self.action_param = action_param
        self.regression_param = None

        self.data_param = None
        self.SUPPORTED_SAMPLING = {
            'uniform':
            (self.initialise_uniform_sampler, self.initialise_grid_sampler,
             self.initialise_grid_aggregator),
            'weighted':
            (self.initialise_weighted_sampler, self.initialise_grid_sampler,
             self.initialise_grid_aggregator),
            'resize':
            (self.initialise_resize_sampler, self.initialise_resize_sampler,
             self.initialise_resize_aggregator),
        }

    def initialise_dataset_loader(self, data_param=None, task_param=None):
        self.data_param = data_param
        self.regression_param = task_param

        # read each line of csv files into an instance of Subject
        if self.is_training:
            self.reader = ImageReader(SUPPORTED_INPUT)
        else:  # in the inference process use image input only
            self.reader = ImageReader(['image'])
        self.reader.initialise_reader(data_param, task_param)

        mean_var_normaliser = MeanVarNormalisationLayer(image_name='image')
        if self.net_param.histogram_ref_file:
            histogram_normaliser = HistogramNormalisationLayer(
                image_name='image',
                modalities=vars(task_param).get('image'),
                model_filename=self.net_param.histogram_ref_file,
                norm_type=self.net_param.norm_type,
                cutoff=self.net_param.cutoff,
                name='hist_norm_layer')
        else:
            histogram_normaliser = None

        normalisation_layers = []
        if self.net_param.normalisation:
            normalisation_layers.append(histogram_normaliser)
        if self.net_param.whitening:
            normalisation_layers.append(mean_var_normaliser)

        augmentation_layers = []
        if self.is_training:
            if self.action_param.random_flipping_axes != -1:
                augmentation_layers.append(
                    RandomFlipLayer(
                        flip_axes=self.action_param.random_flipping_axes))
            if self.action_param.scaling_percentage:
                augmentation_layers.append(
                    RandomSpatialScalingLayer(
                        min_percentage=self.action_param.scaling_percentage[0],
                        max_percentage=self.action_param.scaling_percentage[1])
                )
            if self.action_param.rotation_angle:
                augmentation_layers.append(RandomRotationLayer())
                augmentation_layers[-1].init_uniform_angle(
                    self.action_param.rotation_angle)

        volume_padding_layer = []
        if self.net_param.volume_padding_size:
            volume_padding_layer.append(
                PadLayer(image_name=SUPPORTED_INPUT,
                         border=self.net_param.volume_padding_size))
        self.reader.add_preprocessing_layers(volume_padding_layer +
                                             normalisation_layers +
                                             augmentation_layers)

    def initialise_uniform_sampler(self):
        self.sampler = [
            UniformSampler(
                reader=self.reader,
                data_param=self.data_param,
                batch_size=self.net_param.batch_size,
                windows_per_image=self.action_param.sample_per_volume,
                queue_length=self.net_param.queue_length)
        ]

    def initialise_weighted_sampler(self):
        self.sampler = [
            WeightedSampler(
                reader=self.reader,
                data_param=self.data_param,
                batch_size=self.net_param.batch_size,
                windows_per_image=self.action_param.sample_per_volume,
                queue_length=self.net_param.queue_length)
        ]

    def initialise_resize_sampler(self):
        self.sampler = [
            ResizeSampler(reader=self.reader,
                          data_param=self.data_param,
                          batch_size=self.net_param.batch_size,
                          shuffle_buffer=self.is_training,
                          queue_length=self.net_param.queue_length)
        ]

    def initialise_grid_sampler(self):
        self.sampler = [
            GridSampler(
                reader=self.reader,
                data_param=self.data_param,
                batch_size=self.net_param.batch_size,
                spatial_window_size=self.action_param.spatial_window_size,
                window_border=self.action_param.border,
                queue_length=self.net_param.queue_length)
        ]

    def initialise_grid_aggregator(self):
        self.output_decoder = GridSamplesAggregator(
            image_reader=self.reader,
            output_path=self.action_param.save_seg_dir,
            window_border=self.action_param.border,
            interp_order=self.action_param.output_interp_order)

    def initialise_resize_aggregator(self):
        self.output_decoder = ResizeSamplesAggregator(
            image_reader=self.reader,
            output_path=self.action_param.save_seg_dir,
            window_border=self.action_param.border,
            interp_order=self.action_param.output_interp_order)

    def initialise_sampler(self):
        if self.is_training:
            self.SUPPORTED_SAMPLING[self.net_param.window_sampling][0]()
        else:
            self.SUPPORTED_SAMPLING[self.net_param.window_sampling][1]()

    def initialise_network(self):
        w_regularizer = None
        b_regularizer = None
        reg_type = self.net_param.reg_type.lower()
        decay = self.net_param.decay
        if reg_type == 'l2' and decay > 0:
            from tensorflow.contrib.layers.python.layers import regularizers
            w_regularizer = regularizers.l2_regularizer(decay)
            b_regularizer = regularizers.l2_regularizer(decay)
        elif reg_type == 'l1' and decay > 0:
            from tensorflow.contrib.layers.python.layers import regularizers
            w_regularizer = regularizers.l1_regularizer(decay)
            b_regularizer = regularizers.l1_regularizer(decay)

        self.net = ApplicationNetFactory.create(self.net_param.name)(
            num_classes=1,
            w_regularizer=w_regularizer,
            b_regularizer=b_regularizer,
            acti_func=self.net_param.activation_function)

    def connect_data_and_network(self,
                                 outputs_collector=None,
                                 gradients_collector=None):
        data_dict = self.get_sampler()[0].pop_batch_op()
        image = tf.cast(data_dict['image'], tf.float32)
        net_out = self.net(image, self.is_training)

        if self.is_training:
            crop_layer = CropLayer(border=self.regression_param.loss_border,
                                   name='crop-88')
            with tf.name_scope('Optimiser'):
                optimiser_class = OptimiserFactory.create(
                    name=self.action_param.optimiser)
                self.optimiser = optimiser_class.get_instance(
                    learning_rate=self.action_param.lr)
            loss_func = LossFunction(loss_type=self.action_param.loss_type)

            prediction = crop_layer(net_out)
            ground_truth = crop_layer(data_dict.get('output', None))
            weight_map = None if data_dict.get('weight', None) is None \
                else crop_layer(data_dict.get('weight', None))
            data_loss = loss_func(prediction=prediction,
                                  ground_truth=ground_truth,
                                  weight_map=weight_map)

            reg_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
            if self.net_param.decay > 0.0 and reg_losses:
                reg_loss = tf.reduce_mean(
                    [tf.reduce_mean(reg_loss) for reg_loss in reg_losses])
                loss = data_loss + reg_loss
            else:
                loss = data_loss
            grads = self.optimiser.compute_gradients(loss)
            # collecting gradients variables
            gradients_collector.add_to_collection([grads])
            # collecting output variables
            outputs_collector.add_to_collection(var=data_loss,
                                                name='Loss',
                                                average_over_devices=False,
                                                collection=CONSOLE)
            outputs_collector.add_to_collection(var=data_loss,
                                                name='Loss',
                                                average_over_devices=True,
                                                summary_type='scalar',
                                                collection=TF_SUMMARIES)
        else:
            crop_layer = CropLayer(border=0, name='crop-88')
            post_process_layer = PostProcessingLayer('IDENTITY')
            net_out = post_process_layer(crop_layer(net_out))

            outputs_collector.add_to_collection(var=net_out,
                                                name='window',
                                                average_over_devices=False,
                                                collection=NETWORK_OUTPUT)
            outputs_collector.add_to_collection(
                var=data_dict['image_location'],
                name='location',
                average_over_devices=False,
                collection=NETWORK_OUTPUT)
            init_aggregator = \
                self.SUPPORTED_SAMPLING[self.net_param.window_sampling][2]
            init_aggregator()

    def interpret_output(self, batch_output):
        if not self.is_training:
            return self.output_decoder.decode_batch(batch_output['window'],
                                                    batch_output['location'])
        else:
            return True
Example 13
    def initialise_dataset_loader(
            self, data_param=None, task_param=None, data_partitioner=None):

        self.data_param = data_param
        self.segmentation_param = task_param

        file_lists = self.get_file_lists(data_partitioner)
        # read each line of csv files into an instance of Subject
        if self.is_training:
            self.readers = []
            for file_list in file_lists:
                reader = ImageReader({'image', 'label', 'weight', 'sampler'})
                reader.initialise(data_param, task_param, file_list)
                self.readers.append(reader)

        elif self.is_inference:
            # in the inference process use image input only
            inference_reader = ImageReader({'image'})
            file_list = data_partitioner.inference_files
            inference_reader.initialise(data_param, task_param, file_list)
            self.readers = [inference_reader]
        elif self.is_evaluation:
            file_list = data_partitioner.inference_files
            reader = ImageReader({'image', 'label', 'inferred'})
            reader.initialise(data_param, task_param, file_list)
            self.readers = [reader]
        else:
            raise ValueError('Action `{}` not supported. Expected one of {}'
                             .format(self.action, self.SUPPORTED_ACTIONS))

        foreground_masking_layer = None
        if self.net_param.normalise_foreground_only:
            foreground_masking_layer = BinaryMaskingLayer(
                type_str=self.net_param.foreground_type,
                multimod_fusion=self.net_param.multimod_foreground_type,
                threshold=0.0)

        mean_var_normaliser = MeanVarNormalisationLayer(
            image_name='image', binary_masking_func=foreground_masking_layer)
        histogram_normaliser = None
        if self.net_param.histogram_ref_file:
            histogram_normaliser = HistogramNormalisationLayer(
                image_name='image',
                modalities=vars(task_param).get('image'),
                model_filename=self.net_param.histogram_ref_file,
                binary_masking_func=foreground_masking_layer,
                norm_type=self.net_param.norm_type,
                cutoff=self.net_param.cutoff,
                name='hist_norm_layer')

        label_normalisers = None
        if self.net_param.histogram_ref_file and \
                task_param.label_normalisation:
            label_normalisers = [DiscreteLabelNormalisationLayer(
                image_name='label',
                modalities=vars(task_param).get('label'),
                model_filename=self.net_param.histogram_ref_file)]
            if self.is_evaluation:
                label_normalisers.append(
                    DiscreteLabelNormalisationLayer(
                        image_name='inferred',
                        modalities=vars(task_param).get('inferred'),
                        model_filename=self.net_param.histogram_ref_file))
                label_normalisers[-1].key = label_normalisers[0].key

        normalisation_layers = []
        if self.net_param.normalisation:
            normalisation_layers.append(histogram_normaliser)
        if self.net_param.whitening:
            normalisation_layers.append(mean_var_normaliser)
        if task_param.label_normalisation and \
                (self.is_training or not task_param.output_prob):
            normalisation_layers.extend(label_normalisers)

        augmentation_layers = []
        if self.is_training:
            if self.action_param.random_flipping_axes != -1:
                augmentation_layers.append(RandomFlipLayer(
                    flip_axes=self.action_param.random_flipping_axes))
            if self.action_param.scaling_percentage:
                augmentation_layers.append(RandomSpatialScalingLayer(
                    min_percentage=self.action_param.scaling_percentage[0],
                    max_percentage=self.action_param.scaling_percentage[1]))
            if self.action_param.rotation_angle or \
                    self.action_param.rotation_angle_x or \
                    self.action_param.rotation_angle_y or \
                    self.action_param.rotation_angle_z:
                rotation_layer = RandomRotationLayer()
                if self.action_param.rotation_angle:
                    rotation_layer.init_uniform_angle(
                        self.action_param.rotation_angle)
                else:
                    rotation_layer.init_non_uniform_angle(
                        self.action_param.rotation_angle_x,
                        self.action_param.rotation_angle_y,
                        self.action_param.rotation_angle_z)
                augmentation_layers.append(rotation_layer)

            # add deformation layer
            if self.action_param.do_elastic_deformation:
                spatial_rank = list(self.readers[0].spatial_ranks.values())[0]
                augmentation_layers.append(RandomElasticDeformationLayer(
                    spatial_rank=spatial_rank,
                    num_controlpoints=self.action_param.num_ctrl_points,
                    std_deformation_sigma=self.action_param.deformation_sigma,
                    proportion_to_augment=self.action_param.proportion_to_deform))

        volume_padding_layer = []
        if self.net_param.volume_padding_size:
            volume_padding_layer.append(PadLayer(
                image_name=SUPPORTED_INPUT,
                border=self.net_param.volume_padding_size))

        # only add augmentation to first reader (not validation reader)
        self.readers[0].add_preprocessing_layers(
            volume_padding_layer +
            normalisation_layers +
            augmentation_layers)

        for reader in self.readers[1:]:
            reader.add_preprocessing_layers(
                volume_padding_layer +
                normalisation_layers)
Example 14
    def initialise_dataset_loader(
            self, data_param=None, task_param=None, data_partitioner=None):
        self.data_param = data_param
        self.regression_param = task_param

        file_lists = self.get_file_lists(data_partitioner)
        # read each line of csv files into an instance of Subject
        if self.is_training:
            self.readers = []
            for file_list in file_lists:
                reader = ImageReader({'image', 'output', 'weight', 'sampler'})
                reader.initialise(data_param, task_param, file_list)
                self.readers.append(reader)
        elif self.is_inference:
            inference_reader = ImageReader(['image'])
            file_list = data_partitioner.inference_files
            inference_reader.initialise(data_param, task_param, file_lists[0])
            self.readers = [inference_reader]
        elif self.is_evaluation:
            file_list = data_partitioner.inference_files
            reader = ImageReader({'image', 'output', 'inferred'})
            reader.initialise(data_param, task_param, file_lists[0])
            self.readers = [reader]
        else:
            raise ValueError('Action `{}` not supported. Expected one of {}'
                             .format(self.action, self.SUPPORTED_ACTIONS))

        mean_var_normaliser = MeanVarNormalisationLayer(
            image_name='image')
        histogram_normaliser = None
        if self.net_param.histogram_ref_file:
            histogram_normaliser = HistogramNormalisationLayer(
                image_name='image',
                modalities=vars(task_param).get('image'),
                model_filename=self.net_param.histogram_ref_file,
                norm_type=self.net_param.norm_type,
                cutoff=self.net_param.cutoff,
                name='hist_norm_layer')

        normalisation_layers = []
        if self.net_param.normalisation:
            normalisation_layers.append(histogram_normaliser)
        if self.net_param.whitening:
            normalisation_layers.append(mean_var_normaliser)

        augmentation_layers = []
        if self.is_training:
            if self.action_param.random_flipping_axes != -1:
                augmentation_layers.append(RandomFlipLayer(
                    flip_axes=self.action_param.random_flipping_axes))
            if self.action_param.scaling_percentage:
                augmentation_layers.append(RandomSpatialScalingLayer(
                    min_percentage=self.action_param.scaling_percentage[0],
                    max_percentage=self.action_param.scaling_percentage[1]))
            if self.action_param.rotation_angle:
                augmentation_layers.append(RandomRotationLayer())
                augmentation_layers[-1].init_uniform_angle(
                    self.action_param.rotation_angle)

        volume_padding_layer = []
        if self.net_param.volume_padding_size:
            volume_padding_layer.append(PadLayer(
                image_name=SUPPORTED_INPUT,
                border=self.net_param.volume_padding_size))
        for reader in self.readers:
            reader.add_preprocessing_layers(volume_padding_layer +
                                            normalisation_layers +
                                            augmentation_layers)
Exemplo n.º 15
0
class GANApplication(BaseApplication):
    REQUIRED_CONFIG_SECTION = "GAN"

    def __init__(self, net_param, action_param, is_training):
        BaseApplication.__init__(self)
        tf.logging.info('starting GAN application')
        self.is_training = is_training

        self.net_param = net_param
        self.action_param = action_param

        self.data_param = None
        self.gan_param = None

    def initialise_dataset_loader(self, data_param=None, task_param=None):
        self.data_param = data_param
        self.gan_param = task_param

        # read each line of csv files into an instance of Subject
        if self.is_training:
            self.reader = ImageReader(['image', 'conditioning'])
        else:  # in the inference process use image input only
            self.reader = ImageReader(['conditioning'])
        if self.reader:
            self.reader.initialise_reader(data_param, task_param)

        if self.net_param.normalise_foreground_only:
            foreground_masking_layer = BinaryMaskingLayer(
                type_str=self.net_param.foreground_type,
                multimod_fusion=self.net_param.multimod_foreground_type,
                threshold=0.0)
        else:
            foreground_masking_layer = None

        mean_var_normaliser = MeanVarNormalisationLayer(
            image_name='image', binary_masking_func=foreground_masking_layer)
        if self.net_param.histogram_ref_file:
            histogram_normaliser = HistogramNormalisationLayer(
                image_name='image',
                modalities=vars(task_param).get('image'),
                model_filename=self.net_param.histogram_ref_file,
                binary_masking_func=foreground_masking_layer,
                norm_type=self.net_param.norm_type,
                cutoff=self.net_param.cutoff,
                name='hist_norm_layer')
        else:
            histogram_normaliser = None

        normalisation_layers = []
        if self.net_param.normalisation:
            normalisation_layers.append(histogram_normaliser)
        if self.net_param.whitening:
            normalisation_layers.append(mean_var_normaliser)

        augmentation_layers = []
        if self.is_training:
            if self.action_param.random_flipping_axes != -1:
                augmentation_layers.append(
                    RandomFlipLayer(
                        flip_axes=self.action_param.random_flipping_axes))
            if self.action_param.scaling_percentage:
                augmentation_layers.append(
                    RandomSpatialScalingLayer(
                        min_percentage=self.action_param.scaling_percentage[0],
                        max_percentage=self.action_param.scaling_percentage[1])
                )
            if self.action_param.rotation_angle:
                augmentation_layers.append(RandomRotationLayer())
                augmentation_layers[-1].init_uniform_angle(
                    self.action_param.rotation_angle)

        if self.reader:
            self.reader.add_preprocessing_layers(normalisation_layers +
                                                 augmentation_layers)

    def initialise_sampler(self):
        self.sampler = []
        if self.is_training:
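            # training: a single ResizeSampler rescales each image (and its
            # conditioning map) to the configured window size, drawing one
            # window per image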
            self.sampler.append(
                ResizeSampler(reader=self.reader,
                              data_param=self.data_param,
                              batch_size=self.net_param.batch_size,
                              windows_per_image=1,
                              shuffle_buffer=True,
                              queue_length=self.net_param.queue_length))
        else:
            self.sampler.append(
                RandomVectorSampler(
                    names=('vector', ),
                    vector_size=(self.gan_param.noise_size, ),
                    batch_size=self.net_param.batch_size,
                    n_interpolations=self.gan_param.n_interpolations,
                    repeat=None,
                    queue_length=self.net_param.queue_length))
            # repeat each resized image n times, so that each
            # image matches one random vector,
            # (n = self.gan_param.n_interpolations)
            self.sampler.append(
                ResizeSampler(
                    reader=self.reader,
                    data_param=self.data_param,
                    batch_size=self.net_param.batch_size,
                    windows_per_image=self.gan_param.n_interpolations,
                    shuffle_buffer=False,
                    queue_length=self.net_param.queue_length))

    def initialise_network(self):
        self.net = ApplicationNetFactory.create(self.net_param.name)()

    def connect_data_and_network(self,
                                 outputs_collector=None,
                                 gradients_collector=None):
        if self.is_training:
            with tf.name_scope('Optimiser'):
                optimiser_class = OptimiserFactory.create(
                    name=self.action_param.optimiser)
                self.optimiser = optimiser_class.get_instance(
                    learning_rate=self.action_param.lr)

            # a new pop_batch_op for each gpu tower
            data_dict = self.get_sampler()[0].pop_batch_op()

            images = tf.cast(data_dict['image'], tf.float32)
            noise_shape = [
                self.net_param.batch_size, self.gan_param.noise_size
            ]
            noise = tf.Variable(
                tf.random_normal(shape=noise_shape,
                                 mean=0.0,
                                 stddev=1.0,
                                 dtype=tf.float32),
                trainable=False)
            # tf.stop_gradient returns a new tensor; keep the result so the
            # noise is actually excluded from the gradient computation
            noise = tf.stop_gradient(noise)
            conditioning = data_dict['conditioning']
            net_output = self.net(noise, images, conditioning,
                                  self.is_training)
            loss_func = LossFunction(loss_type=self.action_param.loss_type)
            real_logits = net_output[1]
            fake_logits = net_output[2]
            lossG, lossD = loss_func(real_logits, fake_logits)
            if self.net_param.decay > 0:
                reg_losses = tf.get_collection(
                    tf.GraphKeys.REGULARIZATION_LOSSES)
                if reg_losses:
                    reg_loss = tf.reduce_mean(
                        [tf.reduce_mean(l_reg) for l_reg in reg_losses])
                    lossD = lossD + reg_loss
                    lossG = lossG + reg_loss

            # variables to display in STDOUT
            outputs_collector.add_to_collection(var=lossD,
                                                name='lossD',
                                                average_over_devices=True,
                                                collection=CONSOLE)
            outputs_collector.add_to_collection(var=lossG,
                                                name='lossG',
                                                average_over_devices=False,
                                                collection=CONSOLE)
            # variables to display in tensorboard
            outputs_collector.add_to_collection(var=lossG,
                                                name='lossG',
                                                average_over_devices=False,
                                                collection=TF_SUMMARIES)
            outputs_collector.add_to_collection(var=lossD,
                                                name='lossD',
                                                average_over_devices=True,
                                                collection=TF_SUMMARIES)

            with tf.name_scope('ComputeGradients'):
                # gradients of generator
                generator_variables = tf.get_collection(
                    tf.GraphKeys.TRAINABLE_VARIABLES, scope='generator')
                generator_grads = self.optimiser.compute_gradients(
                    lossG, var_list=generator_variables)

                # gradients of discriminator
                discriminator_variables = tf.get_collection(
                    tf.GraphKeys.TRAINABLE_VARIABLES, scope='discriminator')
                discriminator_grads = self.optimiser.compute_gradients(
                    lossD, var_list=discriminator_variables)
                grads = [generator_grads, discriminator_grads]

                # add the grads back to application_driver's training_grads
                gradients_collector.add_to_collection(grads)
        else:
            data_dict = self.get_sampler()[0].pop_batch_op()
            conditioning_dict = self.get_sampler()[1].pop_batch_op()
            conditioning = conditioning_dict['conditioning']
            image_size = conditioning.shape.as_list()[:-1]
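            # at inference the generator is driven by the random vector and
            # the conditioning map only, so a zero-filled single-channel
            # placeholder stands in for the real image input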
            dummy_image = tf.zeros(image_size + [1])
            net_output = self.net(data_dict['vector'], dummy_image,
                                  conditioning, self.is_training)
            outputs_collector.add_to_collection(var=net_output[0],
                                                name='image',
                                                average_over_devices=False,
                                                collection=NETWORK_OUTPUT)
            outputs_collector.add_to_collection(
                var=conditioning_dict['conditioning_location'],
                name='location',
                average_over_devices=False,
                collection=NETWORK_OUTPUT)

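            # write each generated window back to disk as an image under
            # `save_seg_dir`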
            self.output_decoder = WindowAsImageAggregator(
                image_reader=self.reader,
                output_path=self.action_param.save_seg_dir)

    def interpret_output(self, batch_output):
        if self.is_training:
            return True
        return self.output_decoder.decode_batch(batch_output['image'],
                                                batch_output['location'])
Exemplo n.º 16
0
    def initialise_dataset_loader(
            self, data_param=None, task_param=None, data_partitioner=None):
        self.data_param = data_param
        self.gan_param = task_param

        # read each line of csv files into an instance of Subject
        if self.is_training:
            file_lists = []
            if self.action_param.validation_every_n > 0:
                file_lists.append(data_partitioner.train_files)
                file_lists.append(data_partitioner.validation_files)
            else:
                file_lists.append(data_partitioner.train_files)
            self.readers = []
            for file_list in file_lists:
                reader = ImageReader(['image', 'conditioning'])
                reader.initialise(data_param, task_param, file_list)
                self.readers.append(reader)
        else:
            inference_reader = ImageReader(['conditioning'])
            file_list = data_partitioner.inference_files
            inference_reader.initialise(data_param, task_param, file_list)
            self.readers = [inference_reader]

        foreground_masking_layer = None
        if self.net_param.normalise_foreground_only:
            foreground_masking_layer = BinaryMaskingLayer(
                type_str=self.net_param.foreground_type,
                multimod_fusion=self.net_param.multimod_foreground_type,
                threshold=0.0)

        mean_var_normaliser = MeanVarNormalisationLayer(
            image_name='image',
            binary_masking_func=foreground_masking_layer)
        histogram_normaliser = None
        if self.net_param.histogram_ref_file:
            histogram_normaliser = HistogramNormalisationLayer(
                image_name='image',
                modalities=vars(task_param).get('image'),
                model_filename=self.net_param.histogram_ref_file,
                binary_masking_func=foreground_masking_layer,
                norm_type=self.net_param.norm_type,
                cutoff=self.net_param.cutoff,
                name='hist_norm_layer')

        normalisation_layers = []
        if self.net_param.normalisation:
            normalisation_layers.append(histogram_normaliser)
        if self.net_param.whitening:
            normalisation_layers.append(mean_var_normaliser)

        augmentation_layers = []
        if self.is_training:
            if self.action_param.random_flipping_axes != -1:
                augmentation_layers.append(RandomFlipLayer(
                    flip_axes=self.action_param.random_flipping_axes))
            if self.action_param.scaling_percentage:
                augmentation_layers.append(RandomSpatialScalingLayer(
                    min_percentage=self.action_param.scaling_percentage[0],
                    max_percentage=self.action_param.scaling_percentage[1]))
            if self.action_param.rotation_angle:
                augmentation_layers.append(RandomRotationLayer())
                augmentation_layers[-1].init_uniform_angle(
                    self.action_param.rotation_angle)

        for reader in self.readers:
            reader.add_preprocessing_layers(
                normalisation_layers + augmentation_layers)
Exemplo n.º 17
0
    def initialise_dataset_loader(
            self, data_param=None, task_param=None, data_partitioner=None):
        self.data_param = data_param
        self.regression_param = task_param

        file_lists = self.get_file_lists(data_partitioner)
        # read each line of csv files into an instance of Subject
        if self.is_training:
            self.readers = []
            for file_list in file_lists:
                reader = ImageReader({'image', 'output', 'weight', 'sampler'})
                reader.initialise(data_param, task_param, file_list)
                self.readers.append(reader)
        elif self.is_inference:
            inference_reader = ImageReader(['image'])
            file_list = data_partitioner.inference_files
            inference_reader.initialise(data_param, task_param, file_list)
            self.readers = [inference_reader]
        elif self.is_evaluation:
            file_list = data_partitioner.inference_files
            reader = ImageReader({'image', 'output', 'inferred'})
            reader.initialise(data_param, task_param, file_list)
            self.readers = [reader]
        else:
            raise ValueError('Action `{}` not supported. Expected one of {}'
                             .format(self.action, self.SUPPORTED_ACTIONS))

        mean_var_normaliser = MeanVarNormalisationLayer(
            image_name='image')
        histogram_normaliser = None
        if self.net_param.histogram_ref_file:
            histogram_normaliser = HistogramNormalisationLayer(
                image_name='image',
                modalities=vars(task_param).get('image'),
                model_filename=self.net_param.histogram_ref_file,
                norm_type=self.net_param.norm_type,
                cutoff=self.net_param.cutoff,
                name='hist_norm_layer')

        normalisation_layers = []
        if self.net_param.normalisation:
            normalisation_layers.append(histogram_normaliser)
        if self.net_param.whitening:
            normalisation_layers.append(mean_var_normaliser)

        psf_layers = []
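        # PSFLayer is assumed here to be an application-specific layer that
        # applies a point-spread-function model to the 'image' input; it is
        # appended after the augmentation layers in the loop below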
        psf = PSFLayer(image_name='image')
        psf_layers.append(psf)

        augmentation_layers = []
        if self.is_training:
            if self.action_param.random_flipping_axes != -1:
                augmentation_layers.append(RandomFlipLayer(
                    flip_axes=self.action_param.random_flipping_axes))
            if self.action_param.scaling_percentage:
                augmentation_layers.append(RandomSpatialScalingLayer(
                    min_percentage=self.action_param.scaling_percentage[0],
                    max_percentage=self.action_param.scaling_percentage[1]))
            if self.action_param.rotation_angle:
                augmentation_layers.append(RandomRotationLayer())
                augmentation_layers[-1].init_uniform_angle(
                    self.action_param.rotation_angle)

        volume_padding_layer = []
        if self.net_param.volume_padding_size:
            volume_padding_layer.append(PadLayer(
                image_name=SUPPORTED_INPUT,
                border=self.net_param.volume_padding_size,
                mode=self.net_param.volume_padding_mode))
        for reader in self.readers:
            reader.add_preprocessing_layers(volume_padding_layer +
                                            normalisation_layers +
                                            augmentation_layers +
                                            psf_layers)
Exemplo n.º 18
0
    def initialise_dataset_loader(
            self, data_param=None, task_param=None, data_partitioner=None):

        self.data_param = data_param
        self.segmentation_param = task_param

        # read each line of csv files into an instance of Subject
        if self.is_training:
            file_lists = []
            if self.action_param.validation_every_n > 0:
                file_lists.append(data_partitioner.train_files)
                file_lists.append(data_partitioner.validation_files)
            else:
                file_lists.append(data_partitioner.train_files)

            self.readers = []
            for file_list in file_lists:
                reader = ImageReader(SUPPORTED_INPUT)
                reader.initialise(data_param, task_param, file_list)
                self.readers.append(reader)

        else:  # in the inference process use image input only
            inference_reader = ImageReader(['image'])
            file_list = data_partitioner.inference_files
            inference_reader.initialise(data_param, task_param, file_list)
            self.readers = [inference_reader]

        foreground_masking_layer = None
        if self.net_param.normalise_foreground_only:
            foreground_masking_layer = BinaryMaskingLayer(
                type_str=self.net_param.foreground_type,
                multimod_fusion=self.net_param.multimod_foreground_type,
                threshold=0.0)

        mean_var_normaliser = MeanVarNormalisationLayer(
            image_name='image', binary_masking_func=foreground_masking_layer)
        histogram_normaliser = None
        if self.net_param.histogram_ref_file:
            histogram_normaliser = HistogramNormalisationLayer(
                image_name='image',
                modalities=vars(task_param).get('image'),
                model_filename=self.net_param.histogram_ref_file,
                binary_masking_func=foreground_masking_layer,
                norm_type=self.net_param.norm_type,
                cutoff=self.net_param.cutoff,
                name='hist_norm_layer')

        label_normaliser = None
        if self.net_param.histogram_ref_file:
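            # map the discrete label values found in the data to consecutive
            # integers, reusing the histogram reference file so the mapping
            # stays consistent between runs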
            label_normaliser = DiscreteLabelNormalisationLayer(
                image_name='label',
                modalities=vars(task_param).get('label'),
                model_filename=self.net_param.histogram_ref_file)

        normalisation_layers = []
        if self.net_param.normalisation:
            normalisation_layers.append(histogram_normaliser)
        if self.net_param.whitening:
            normalisation_layers.append(mean_var_normaliser)
        if task_param.label_normalisation:
            normalisation_layers.append(label_normaliser)

        augmentation_layers = []
        if self.is_training:
            if self.action_param.random_flipping_axes != -1:
                augmentation_layers.append(RandomFlipLayer(
                    flip_axes=self.action_param.random_flipping_axes))
            if self.action_param.scaling_percentage:
                augmentation_layers.append(RandomSpatialScalingLayer(
                    min_percentage=self.action_param.scaling_percentage[0],
                    max_percentage=self.action_param.scaling_percentage[1]))
            if self.action_param.rotation_angle or \
                    self.action_param.rotation_angle_x or \
                    self.action_param.rotation_angle_y or \
                    self.action_param.rotation_angle_z:
                rotation_layer = RandomRotationLayer()
                if self.action_param.rotation_angle:
                    rotation_layer.init_uniform_angle(
                        self.action_param.rotation_angle)
                else:
                    rotation_layer.init_non_uniform_angle(
                        self.action_param.rotation_angle_x,
                        self.action_param.rotation_angle_y,
                        self.action_param.rotation_angle_z)
                augmentation_layers.append(rotation_layer)

        volume_padding_layer = []
        if self.net_param.volume_padding_size:
            volume_padding_layer.append(PadLayer(
                image_name=SUPPORTED_INPUT,
                border=self.net_param.volume_padding_size))

        for reader in self.readers:
            reader.add_preprocessing_layers(
                volume_padding_layer +
                normalisation_layers +
                augmentation_layers)