Example #1
0
    def initialise_dataset_loader(self, data_param=None, task_param=None):
        """
        Create ``self.reader`` and attach its preprocessing layers.

        In training the reader is built for the ``image`` and
        ``conditioning`` sections, otherwise for ``conditioning`` only.
        Normalisation layers are added first, followed (training only) by
        the random flipping, scaling and rotation layers.

        :param data_param: input specification; stored as
            ``self.data_param`` and passed to ``initialise_reader``
        :param task_param: task specification; stored as ``self.gan_param``
            and passed to ``initialise_reader``.  Its ``image`` entry is
            used as the modality list for histogram normalisation.
        """
        self.data_param = data_param
        self.gan_param = task_param

        # build the reader for the sections needed by the current phase
        if self.is_training:
            self.reader = ImageReader(['image', 'conditioning'])
        else:  # in the inference process use the conditioning input only
            self.reader = ImageReader(['conditioning'])
        if self.reader:
            self.reader.initialise_reader(data_param, task_param)

        foreground_masking_layer = None
        if self.net_param.normalise_foreground_only:
            foreground_masking_layer = BinaryMaskingLayer(
                type_str=self.net_param.foreground_type,
                multimod_fusion=self.net_param.multimod_foreground_type,
                threshold=0.0)

        mean_var_normaliser = MeanVarNormalisationLayer(
            image_name='image', binary_masking_func=foreground_masking_layer)
        histogram_normaliser = None
        if self.net_param.histogram_ref_file:
            histogram_normaliser = HistogramNormalisationLayer(
                image_name='image',
                modalities=vars(task_param).get('image'),
                model_filename=self.net_param.histogram_ref_file,
                binary_masking_func=foreground_masking_layer,
                norm_type=self.net_param.norm_type,
                cutoff=self.net_param.cutoff,
                name='hist_norm_layer')

        normalisation_layers = []
        # The histogram layer only exists when a reference file was given;
        # never put ``None`` into the list handed to the reader.
        if self.net_param.normalisation and histogram_normaliser is not None:
            normalisation_layers.append(histogram_normaliser)
        if self.net_param.whitening:
            normalisation_layers.append(mean_var_normaliser)

        augmentation_layers = []
        if self.is_training:
            # -1 is the value that switches random flipping off
            if self.action_param.random_flipping_axes != -1:
                augmentation_layers.append(
                    RandomFlipLayer(
                        flip_axes=self.action_param.random_flipping_axes))
            if self.action_param.scaling_percentage:
                augmentation_layers.append(
                    RandomSpatialScalingLayer(
                        min_percentage=self.action_param.scaling_percentage[0],
                        max_percentage=self.action_param.scaling_percentage[1])
                )
            if self.action_param.rotation_angle:
                rotation_layer = RandomRotationLayer()
                rotation_layer.init_uniform_angle(
                    self.action_param.rotation_angle)
                augmentation_layers.append(rotation_layer)

        if self.reader:
            self.reader.add_preprocessing_layers(normalisation_layers +
                                                 augmentation_layers)
    def prepare_normalisation_layers(self):
        """
        Build the intensity normalisation layers selected in ``net_param``.

        The histogram layer is included when ``net_param.normalisation`` is
        set and a ``histogram_ref_file`` is configured; the mean/variance
        layer is included when ``net_param.whitening`` is set.  Both layers
        share the foreground mask when ``normalise_foreground_only`` is set.

        :return: list of normalisation layers, histogram layer first;
            the list never contains ``None``
        """
        foreground_masking_layer = None
        if self.net_param.normalise_foreground_only:
            foreground_masking_layer = BinaryMaskingLayer(
                type_str=self.net_param.foreground_type,
                multimod_fusion=self.net_param.multimod_foreground_type,
                threshold=0.0)

        mean_var_normaliser = MeanVarNormalisationLayer(
            image_name='image', binary_masking_func=foreground_masking_layer)
        histogram_normaliser = None
        if self.net_param.histogram_ref_file:
            histogram_normaliser = HistogramNormalisationLayer(
                image_name='image',
                modalities=vars(self.task_param).get('image'),
                model_filename=self.net_param.histogram_ref_file,
                binary_masking_func=foreground_masking_layer,
                norm_type=self.net_param.norm_type,
                cutoff=self.net_param.cutoff,
                name='hist_norm_layer')

        normalisation_layers = []
        # Without a reference file there is no histogram layer to add;
        # skip it instead of returning ``None`` inside the list.
        if self.net_param.normalisation and histogram_normaliser is not None:
            normalisation_layers.append(histogram_normaliser)
        if self.net_param.whitening:
            normalisation_layers.append(mean_var_normaliser)
        return normalisation_layers
    def initialise_dataset_loader(self, data_param=None, task_param=None):
        """
        Create ``self.reader`` and attach its preprocessing layers.

        In training the reader is built for every section listed in
        ``SUPPORTED_INPUT``, otherwise for ``image`` only.  Layers are
        attached in the order: volume padding, normalisation, and
        (training only) random augmentation.

        :param data_param: input specification; stored as
            ``self.data_param`` and passed to ``initialise_reader``
        :param task_param: task specification; stored as
            ``self.regression_param`` and passed to ``initialise_reader``.
            Its ``image`` entry is used as the modality list for
            histogram normalisation.
        """
        self.data_param = data_param
        self.regression_param = task_param

        # build the reader for the sections needed by the current phase
        if self.is_training:
            self.reader = ImageReader(SUPPORTED_INPUT)
        else:  # in the inference process use image input only
            self.reader = ImageReader(['image'])
        self.reader.initialise_reader(data_param, task_param)

        mean_var_normaliser = MeanVarNormalisationLayer(image_name='image')
        histogram_normaliser = None
        if self.net_param.histogram_ref_file:
            histogram_normaliser = HistogramNormalisationLayer(
                image_name='image',
                modalities=vars(task_param).get('image'),
                model_filename=self.net_param.histogram_ref_file,
                norm_type=self.net_param.norm_type,
                cutoff=self.net_param.cutoff,
                name='hist_norm_layer')

        normalisation_layers = []
        # The histogram layer only exists when a reference file was given;
        # never put ``None`` into the list handed to the reader.
        if self.net_param.normalisation and histogram_normaliser is not None:
            normalisation_layers.append(histogram_normaliser)
        if self.net_param.whitening:
            normalisation_layers.append(mean_var_normaliser)

        augmentation_layers = []
        if self.is_training:
            # -1 is the value that switches random flipping off
            if self.action_param.random_flipping_axes != -1:
                augmentation_layers.append(
                    RandomFlipLayer(
                        flip_axes=self.action_param.random_flipping_axes))
            if self.action_param.scaling_percentage:
                augmentation_layers.append(
                    RandomSpatialScalingLayer(
                        min_percentage=self.action_param.scaling_percentage[0],
                        max_percentage=self.action_param.scaling_percentage[1])
                )
            if self.action_param.rotation_angle:
                rotation_layer = RandomRotationLayer()
                rotation_layer.init_uniform_angle(
                    self.action_param.rotation_angle)
                augmentation_layers.append(rotation_layer)

        volume_padding_layer = []
        if self.net_param.volume_padding_size:
            volume_padding_layer.append(
                PadLayer(image_name=SUPPORTED_INPUT,
                         border=self.net_param.volume_padding_size))
        self.reader.add_preprocessing_layers(volume_padding_layer +
                                             normalisation_layers +
                                             augmentation_layers)
    def initialise_dataset_loader(self,
                                  data_param=None,
                                  task_param=None,
                                  data_partitioner=None):
        """
        Initialise the readers, rebuilding them when error maps are wanted.

        The ``RegressionApplication`` implementation always runs first.
        Nothing more is done in training or when ``task_param.error_map``
        is falsy.  Otherwise ``self.readers`` is replaced by readers that
        load both ``image`` and ``output``, and histogram, mean/variance
        and padding layers are attached to the first reader.

        :param data_param: input specification passed to each reader
        :param task_param: task specification; ``error_map`` selects the
            error-map path and ``image`` gives the histogram modalities
        :param data_partitioner: object providing ``get_file_lists_by``
        """
        RegressionApplication.initialise_dataset_loader(
            self, data_param, task_param, data_partitioner)
        if self.is_training:
            return
        if not task_param.error_map:
            # use the regression application implementation
            return

        try:
            reader_phase = self.action_param.dataset_to_infer
        except AttributeError:
            reader_phase = None
        file_lists = data_partitioner.get_file_lists_by(phase=reader_phase,
                                                        action=self.action)
        # modifying the original readers in regression application
        # as we need ground truth labels to generate error maps
        self.readers = [
            ImageReader(['image',
                         'output']).initialise(data_param, task_param,
                                               file_list)
            for file_list in file_lists
        ]

        mean_var_normaliser = MeanVarNormalisationLayer(image_name='image')
        histogram_normaliser = None
        if self.net_param.histogram_ref_file:
            histogram_normaliser = HistogramNormalisationLayer(
                image_name='image',
                modalities=vars(task_param).get('image'),
                model_filename=self.net_param.histogram_ref_file,
                norm_type=self.net_param.norm_type,
                cutoff=self.net_param.cutoff,
                name='hist_norm_layer')

        preprocessors = []
        # The histogram layer only exists when a reference file was given;
        # never put ``None`` into the list handed to the reader.
        if self.net_param.normalisation and histogram_normaliser is not None:
            preprocessors.append(histogram_normaliser)
        if self.net_param.whitening:
            preprocessors.append(mean_var_normaliser)
        if self.net_param.volume_padding_size:
            preprocessors.append(
                PadLayer(image_name=SUPPORTED_INPUT,
                         border=self.net_param.volume_padding_size))
        self.readers[0].add_preprocessing_layers(preprocessors)
Example #5
0
def preprocess(
    input_path,
    model_path,
    output_path,
    cutoff,
):
    """Histogram-normalise one image and save it as a NIfTI file.

    A reader is set up over the parent directory of ``input_path`` with
    ``filename_contains=('nii.gz', )``; the first entry of the reader's
    ``output_list`` is normalised with the mapping in ``model_path`` and
    written to ``output_path``.

    NOTE(review): only ``input_path``'s parent directory is used, so the
    processed image is presumably the first match in that directory rather
    than ``input_path`` itself -- confirm against the reader's behaviour.

    :param input_path: path whose parent directory is searched for images
    :param model_path: histogram mapping file given to the normaliser
    :param output_path: destination filename for the normalised image
    :param cutoff: cutoff passed through to ``HistogramNormalisationLayer``
    """
    source_dir = Path(input_path).parent
    destination = Path(output_path)

    modality_spec = ParserNamespace(
        path_to_search=str(source_dir),
        filename_contains=('nii.gz', ),
        interp_order=0,
        pixdim=None,
        axcodes='RAS',
        loader=None,
    )
    data_param = {'Modality0': modality_spec}
    task_param = ParserNamespace(image=('Modality0', ))

    # collect the matching files and load them through an image reader
    partitioner = ImageSetsPartitioner()
    found_files = partitioner.initialise(data_param).get_file_list()
    reader = ImageReader(['image'])
    reader.initialise(data_param, task_param, found_files)

    foreground_mask = BinaryMaskingLayer(type_str='mean_plus')
    normaliser = HistogramNormalisationLayer(
        image_name='image',
        modalities=['Modality0'],
        model_filename=str(model_path),
        binary_masking_func=foreground_mask,
        cutoff=cutoff,
        name='hist_norm_layer',
    )

    # normalise the first loaded image; the returned mask is not needed
    first_image = reader.output_list[0]['image']
    normalised, _ = normaliser({'image': first_image.get_data()})
    nifti = nib.Nifti1Image(normalised['image'].squeeze(),
                            first_image.original_affine[0])
    nifti.to_filename(str(destination))
    def initialise_dataset_loader(
            self, data_param=None, task_param=None, data_partitioner=None):
        """
        Build one reader per file list and attach preprocessing layers.

        The sections loaded depend on the current action (training,
        inference or evaluation).  The first reader receives padding,
        normalisation and augmentation layers; every further reader
        receives padding and normalisation only.

        :param data_param: input specification passed to each reader
        :param task_param: task specification; stored as
            ``self.segmentation_param``.  ``label_normalisation`` and
            ``output_prob`` control whether label normalisers are used.
        :param data_partitioner: object providing ``get_file_lists_by``
        :raises ValueError: if the action is none of training, inference
            or evaluation
        """
        self.data_param = data_param
        self.segmentation_param = task_param

        # initialise input image readers
        if self.is_training:
            reader_names = ('image', 'label', 'weight_map', 'sampler')
        elif self.is_inference:
            # in the inference process use `image` input only
            reader_names = ('image',)
        elif self.is_evaluation:
            reader_names = ('image', 'label', 'inferred')
        else:
            tf.logging.fatal(
                'Action `%s` not supported. Expected one of %s',
                self.action, self.SUPPORTED_PHASES)
            raise ValueError
        try:
            reader_phase = self.action_param.dataset_to_infer
        except AttributeError:
            reader_phase = None
        file_lists = data_partitioner.get_file_lists_by(
            phase=reader_phase, action=self.action)
        self.readers = [
            ImageReader(reader_names).initialise(
                data_param, task_param, file_list) for file_list in file_lists]

        # initialise input preprocessing layers
        foreground_masking_layer = None
        if self.net_param.normalise_foreground_only:
            foreground_masking_layer = BinaryMaskingLayer(
                type_str=self.net_param.foreground_type,
                multimod_fusion=self.net_param.multimod_foreground_type,
                threshold=0.0)
        mean_var_normaliser = None
        if self.net_param.whitening:
            mean_var_normaliser = MeanVarNormalisationLayer(
                image_name='image',
                binary_masking_func=foreground_masking_layer)
        percentile_normaliser = None
        if self.net_param.percentile_normalisation:
            percentile_normaliser = PercentileNormalisationLayer(
                image_name='image',
                binary_masking_func=foreground_masking_layer,
                cutoff=self.net_param.cutoff)
        histogram_normaliser = None
        if (self.net_param.histogram_ref_file and
                self.net_param.normalisation):
            histogram_normaliser = HistogramNormalisationLayer(
                image_name='image',
                modalities=vars(task_param).get('image'),
                model_filename=self.net_param.histogram_ref_file,
                binary_masking_func=foreground_masking_layer,
                norm_type=self.net_param.norm_type,
                cutoff=self.net_param.cutoff,
                name='hist_norm_layer')
        # Start from an empty list: with label normalisation requested but
        # no reference file, the later ``extend`` must be a no-op rather
        # than a ``TypeError`` on ``None``.
        label_normalisers = []
        if self.net_param.histogram_ref_file and \
                task_param.label_normalisation:
            label_normalisers.append(DiscreteLabelNormalisationLayer(
                image_name='label',
                modalities=vars(task_param).get('label'),
                model_filename=self.net_param.histogram_ref_file))
            if self.is_evaluation:
                label_normalisers.append(
                    DiscreteLabelNormalisationLayer(
                        image_name='inferred',
                        modalities=vars(task_param).get('inferred'),
                        model_filename=self.net_param.histogram_ref_file))
                # the `inferred` normaliser reuses the key of the `label` one
                label_normalisers[-1].key = label_normalisers[0].key

        normalisation_layers = []
        if histogram_normaliser is not None:
            normalisation_layers.append(histogram_normaliser)
        if mean_var_normaliser is not None:
            normalisation_layers.append(mean_var_normaliser)
        if percentile_normaliser is not None:
            normalisation_layers.append(percentile_normaliser)
        if task_param.label_normalisation and \
                (self.is_training or not task_param.output_prob):
            normalisation_layers.extend(label_normalisers)

        volume_padding_layer = []
        if self.net_param.volume_padding_size:
            volume_padding_layer.append(PadLayer(
                image_name=SUPPORTED_INPUT,
                border=self.net_param.volume_padding_size,
                mode=self.net_param.volume_padding_mode))

        # initialise training data augmentation layers
        augmentation_layers = []
        if self.is_training:
            train_param = self.action_param
            # -1 is the value that switches random flipping off
            if train_param.random_flipping_axes != -1:
                augmentation_layers.append(RandomFlipLayer(
                    flip_axes=train_param.random_flipping_axes))
            if train_param.scaling_percentage:
                augmentation_layers.append(RandomSpatialScalingLayer(
                    min_percentage=train_param.scaling_percentage[0],
                    max_percentage=train_param.scaling_percentage[1],
                    antialiasing=train_param.antialiasing))
            if train_param.rotation_angle or \
                    train_param.rotation_angle_x or \
                    train_param.rotation_angle_y or \
                    train_param.rotation_angle_z:
                rotation_layer = RandomRotationLayer()
                # a single uniform angle takes precedence over per-axis ones
                if train_param.rotation_angle:
                    rotation_layer.init_uniform_angle(
                        train_param.rotation_angle)
                else:
                    rotation_layer.init_non_uniform_angle(
                        train_param.rotation_angle_x,
                        train_param.rotation_angle_y,
                        train_param.rotation_angle_z)
                augmentation_layers.append(rotation_layer)
            if train_param.do_elastic_deformation:
                spatial_rank = list(self.readers[0].spatial_ranks.values())[0]
                augmentation_layers.append(RandomElasticDeformationLayer(
                    spatial_rank=spatial_rank,
                    num_controlpoints=train_param.num_ctrl_points,
                    std_deformation_sigma=train_param.deformation_sigma,
                    proportion_to_augment=train_param.proportion_to_deform))

        # only add augmentation to first reader (not validation reader)
        self.readers[0].add_preprocessing_layers(
            volume_padding_layer + normalisation_layers + augmentation_layers)

        for reader in self.readers[1:]:
            reader.add_preprocessing_layers(
                volume_padding_layer + normalisation_layers)
    def test_volume_loader(self):
        """
        Train a histogram normaliser and check its mapping and output.

        The trained ``T1`` and ``FLAIR`` mappings are compared with
        reference landmarks.  A random volume is then normalised through
        both the dictionary and the array call forms, whitened, and
        checked for zero mean, unit standard deviation and unchanged
        ordering of neighbouring values.
        """
        reference_t1 = np.array([
            0.0, 8.24277910972, 21.4917343731, 27.0551695202,
            32.6186046672, 43.5081573038, 53.3535675285, 61.9058849776,
            70.0929786194, 73.9944243858, 77.7437509974, 88.5331971492,
            100.0])
        reference_flair = np.array([
            0.0, 5.36540863446, 15.5386130103, 20.7431912042,
            26.1536608309, 36.669150376, 44.7821246138, 50.7930589961,
            56.1703089214, 59.2393548654, 63.1565641037, 78.7271261392,
            100.0])

        reader = ImageReader(['image'])
        reader.initialise(DATA_PARAM, TASK_PARAM, file_list)
        self.assertAllClose(len(reader._file_list), 4)

        masking_layer = BinaryMaskingLayer(
            type_str='otsu_plus', multimod_fusion='or')
        hist_norm = HistogramNormalisationLayer(
            image_name='image',
            modalities=vars(TASK_PARAM).get('image'),
            model_filename=MODEL_FILE,
            binary_masking_func=masking_layer,
            cutoff=(0.05, 0.95),
            name='hist_norm_layer')
        # drop any mapping file left on disk before training a new one
        if os.path.exists(MODEL_FILE):
            os.remove(MODEL_FILE)
        hist_norm.train(reader.output_list)
        trained_mapping = hist_norm.mapping

        self.assertAllClose(trained_mapping['T1'], reference_t1)
        self.assertAllClose(trained_mapping['FLAIR'], reference_flair)

        # normalise a uniformly sampled random image, once via the
        # dictionary interface and once via the array interface
        raw_volume = np.random.uniform(
            low=-10.0, high=10.0, size=(20, 20, 20, 3, 2))
        volume = np.copy(raw_volume)
        dict_result, dict_masks = hist_norm({'image': volume})
        volume, array_mask = hist_norm(volume, dict_masks)
        self.assertAllClose(dict_result['image'], volume)
        self.assertAllClose(dict_masks['image'], array_mask)

        # apply mean std normalisation, first masked then unmasked
        masked_whitening = MeanVarNormalisationLayer(
            image_name='image', binary_masking_func=masking_layer)
        volume, _ = masked_whitening(volume, array_mask)
        self.assertAllClose(volume.shape, array_mask.shape)

        plain_whitening = MeanVarNormalisationLayer(
            image_name='image', binary_masking_func=None)
        volume, _ = plain_whitening(volume)

        # mapping should keep at least the order of the images
        raw_values = raw_volume[:, :, :, 1, 1].flatten()
        normalised_values = volume[:, :, :, 1, 1].flatten()

        rising_before = raw_values[1:] > raw_values[:-1]
        rising_after = normalised_values[1:] > normalised_values[:-1]
        self.assertAllClose(np.mean(normalised_values), 0.0)
        self.assertAllClose(np.std(normalised_values), 1.0)
        self.assertAllClose(rising_before, rising_after)
        if os.path.exists(MODEL_FILE):
            os.remove(MODEL_FILE)
    def initialise_dataset_loader(self,
                                  data_param=None,
                                  task_param=None,
                                  data_partitioner=None):
        """
        Build one reader per file list and attach preprocessing layers.

        The sections loaded depend on the current action (training,
        inference or evaluation).  The first reader receives padding,
        normalisation and augmentation layers; every further reader
        receives padding and normalisation only.

        :param data_param: input specification passed to each reader
        :param task_param: task specification; stored as
            ``self.segmentation_param``.  ``label_normalisation`` and
            ``output_prob`` control whether the label normaliser is used.
        :param data_partitioner: object providing ``get_file_lists_by``
        :raises ValueError: if the action is none of training, inference
            or evaluation
        """

        self.data_param = data_param
        self.segmentation_param = task_param

        # initialise input image readers
        if self.is_training:
            reader_names = ('image', 'label', 'weight', 'sampler')
        elif self.is_inference:
            # in the inference process use `image` input only
            reader_names = ('image', )
        elif self.is_evaluation:
            reader_names = ('image', 'label', 'inferred')
        else:
            tf.logging.fatal('Action `%s` not supported. Expected one of %s',
                             self.action, self.SUPPORTED_PHASES)
            raise ValueError
        try:
            reader_phase = self.action_param.dataset_to_infer
        except AttributeError:
            reader_phase = None
        file_lists = data_partitioner.get_file_lists_by(phase=reader_phase,
                                                        action=self.action)
        self.readers = [
            ImageReader(reader_names).initialise(data_param, task_param,
                                                 file_list)
            for file_list in file_lists
        ]

        foreground_masking_layer = None
        if self.net_param.normalise_foreground_only:
            foreground_masking_layer = BinaryMaskingLayer(
                type_str=self.net_param.foreground_type,
                multimod_fusion=self.net_param.multimod_foreground_type,
                threshold=0.0)

        mean_var_normaliser = MeanVarNormalisationLayer(
            image_name='image', binary_masking_func=foreground_masking_layer)
        histogram_normaliser = None
        if self.net_param.histogram_ref_file:
            histogram_normaliser = HistogramNormalisationLayer(
                image_name='image',
                modalities=vars(task_param).get('image'),
                model_filename=self.net_param.histogram_ref_file,
                binary_masking_func=foreground_masking_layer,
                norm_type=self.net_param.norm_type,
                cutoff=self.net_param.cutoff,
                name='hist_norm_layer')

        label_normaliser = None
        if self.net_param.histogram_ref_file:
            label_normaliser = DiscreteLabelNormalisationLayer(
                image_name='label',
                modalities=vars(task_param).get('label'),
                model_filename=self.net_param.histogram_ref_file)

        normalisation_layers = []
        # Both file-based normalisers exist only when a reference file was
        # given; never put ``None`` into the list handed to the readers.
        if self.net_param.normalisation and histogram_normaliser is not None:
            normalisation_layers.append(histogram_normaliser)
        if self.net_param.whitening:
            normalisation_layers.append(mean_var_normaliser)
        if task_param.label_normalisation and \
                (self.is_training or not task_param.output_prob) and \
                label_normaliser is not None:
            normalisation_layers.append(label_normaliser)

        augmentation_layers = []
        if self.is_training:
            train_param = self.action_param
            # -1 is the value that switches random flipping off
            if train_param.random_flipping_axes != -1:
                augmentation_layers.append(
                    RandomFlipLayer(
                        flip_axes=train_param.random_flipping_axes))
            if train_param.scaling_percentage:
                augmentation_layers.append(
                    RandomSpatialScalingLayer(
                        min_percentage=train_param.scaling_percentage[0],
                        max_percentage=train_param.scaling_percentage[1])
                )
            if train_param.rotation_angle or \
                    train_param.rotation_angle_x or \
                    train_param.rotation_angle_y or \
                    train_param.rotation_angle_z:
                rotation_layer = RandomRotationLayer()
                # a single uniform angle takes precedence over per-axis ones
                if train_param.rotation_angle:
                    rotation_layer.init_uniform_angle(
                        train_param.rotation_angle)
                else:
                    rotation_layer.init_non_uniform_angle(
                        train_param.rotation_angle_x,
                        train_param.rotation_angle_y,
                        train_param.rotation_angle_z)
                augmentation_layers.append(rotation_layer)
            if train_param.bias_field_range:
                bias_field_layer = RandomBiasFieldLayer()
                bias_field_layer.init_order(train_param.bf_order)
                bias_field_layer.init_uniform_coeff(
                    train_param.bias_field_range)
                augmentation_layers.append(bias_field_layer)

        # padding is always added here, unlike the optional layers above
        volume_padding_layer = [
            PadLayer(image_name=SUPPORTED_INPUT,
                     border=self.net_param.volume_padding_size,
                     mode=self.net_param.volume_padding_mode,
                     pad_to=self.net_param.volume_padding_to_size)
        ]

        # only the first reader gets the augmentation layers
        self.readers[0].add_preprocessing_layers(volume_padding_layer +
                                                 normalisation_layers +
                                                 augmentation_layers)

        for reader in self.readers[1:]:
            reader.add_preprocessing_layers(volume_padding_layer +
                                            normalisation_layers)
    def initialise_dataset_loader(self,
                                  data_param=None,
                                  task_param=None,
                                  data_partitioner=None):
        """
        Build one reader per file list and attach preprocessing layers.

        The sections loaded depend on the current action (training,
        inference or evaluation).  The first reader receives the
        normalisation and augmentation layers; every further reader
        receives the normalisation layers only.

        :param data_param: input specification passed to each reader
        :param task_param: task specification; stored as
            ``self.classification_param``
        :param data_partitioner: object providing ``get_file_lists_by``
        :raises ValueError: if the action is none of training, inference
            or evaluation
        """

        self.data_param = data_param
        self.classification_param = task_param

        # choose the input sections required by the current action
        if self.is_training:
            reader_names = ('image', 'label', 'sampler')
        elif self.is_inference:
            reader_names = ('image', )
        elif self.is_evaluation:
            reader_names = ('image', 'label', 'inferred')
        else:
            tf.logging.fatal('Action `%s` not supported. Expected one of %s',
                             self.action, self.SUPPORTED_PHASES)
            raise ValueError

        # `dataset_to_infer` is optional on the action parameters
        try:
            reader_phase = self.action_param.dataset_to_infer
        except AttributeError:
            reader_phase = None
        file_lists = data_partitioner.get_file_lists_by(phase=reader_phase,
                                                        action=self.action)
        readers = []
        for file_list in file_lists:
            new_reader = ImageReader(reader_names)
            readers.append(
                new_reader.initialise(data_param, task_param, file_list))
        self.readers = readers

        net_param = self.net_param
        ref_file = net_param.histogram_ref_file

        foreground_masking_layer = None
        if net_param.normalise_foreground_only:
            foreground_masking_layer = BinaryMaskingLayer(
                type_str=net_param.foreground_type,
                multimod_fusion=net_param.multimod_foreground_type,
                threshold=0.0)

        mean_var_normaliser = None
        if net_param.whitening:
            mean_var_normaliser = MeanVarNormalisationLayer(
                image_name='image',
                binary_masking_func=foreground_masking_layer)

        histogram_normaliser = None
        if ref_file and net_param.normalisation:
            histogram_normaliser = HistogramNormalisationLayer(
                image_name='image',
                modalities=vars(task_param).get('image'),
                model_filename=net_param.histogram_ref_file,
                binary_masking_func=foreground_masking_layer,
                norm_type=net_param.norm_type,
                cutoff=net_param.cutoff,
                name='hist_norm_layer')

        label_normaliser = None
        if ref_file and task_param.label_normalisation:
            label_normaliser = DiscreteLabelNormalisationLayer(
                image_name='label',
                modalities=vars(task_param).get('label'),
                model_filename=net_param.histogram_ref_file)

        # keep the enabled layers, in the order they are applied
        candidate_layers = (
            histogram_normaliser, mean_var_normaliser, label_normaliser)
        normalisation_layers = [
            layer for layer in candidate_layers if layer is not None]

        augmentation_layers = []
        if self.is_training:
            train_param = self.action_param
            # -1 is the value that switches random flipping off
            if train_param.random_flipping_axes != -1:
                flip_layer = RandomFlipLayer(
                    flip_axes=train_param.random_flipping_axes)
                augmentation_layers.append(flip_layer)
            if train_param.scaling_percentage:
                scaling_layer = RandomSpatialScalingLayer(
                    min_percentage=train_param.scaling_percentage[0],
                    max_percentage=train_param.scaling_percentage[1],
                    antialiasing=train_param.antialiasing,
                    isotropic=train_param.isotropic_scaling)
                augmentation_layers.append(scaling_layer)
            # a single uniform angle takes precedence over per-axis angles
            if train_param.rotation_angle:
                rotation_layer = RandomRotationLayer()
                rotation_layer.init_uniform_angle(train_param.rotation_angle)
                augmentation_layers.append(rotation_layer)
            elif (train_param.rotation_angle_x or
                  train_param.rotation_angle_y or
                  train_param.rotation_angle_z):
                rotation_layer = RandomRotationLayer()
                rotation_layer.init_non_uniform_angle(
                    train_param.rotation_angle_x,
                    train_param.rotation_angle_y,
                    train_param.rotation_angle_z)
                augmentation_layers.append(rotation_layer)

        # only add augmentation to first reader (not validation reader)
        first_reader_layers = normalisation_layers + augmentation_layers
        self.readers[0].add_preprocessing_layers(first_reader_layers)

        for other_reader in self.readers[1:]:
            other_reader.add_preprocessing_layers(normalisation_layers)
Example #10
0
    def initialise_dataset_loader(self,
                                  data_param=None,
                                  task_param=None,
                                  data_partitioner=None):
        """
        Build the readers and attach their preprocessing layers.

        Training creates one reader per file list, loading ``image`` and
        ``conditioning``; inference creates a single ``conditioning``
        reader from the first file list.  Every reader receives the
        normalisation layers followed by the (training only) augmentation
        layers.

        :param data_param: input specification passed to each reader
        :param task_param: task specification; stored as ``self.gan_param``
        :param data_partitioner: passed to ``self.get_file_lists``
        :raises NotImplementedError: if the action is evaluation
        """
        self.data_param = data_param
        self.gan_param = task_param

        file_lists = self.get_file_lists(data_partitioner)
        # build the readers for the sections needed by the current phase
        if self.is_training:
            self.readers = []
            for file_list in file_lists:
                reader = ImageReader(['image', 'conditioning'])
                reader.initialise(data_param, task_param, file_list)
                self.readers.append(reader)
        elif self.is_inference:
            inference_reader = ImageReader(['conditioning'])
            inference_reader.initialise(data_param, task_param, file_lists[0])
            self.readers = [inference_reader]
        elif self.is_evaluation:
            # The exception has to be raised, not merely constructed;
            # otherwise execution continues with ``self.readers`` unset.
            raise NotImplementedError('Evaluation is not yet '
                                      'supported in this application.')

        foreground_masking_layer = None
        if self.net_param.normalise_foreground_only:
            foreground_masking_layer = BinaryMaskingLayer(
                type_str=self.net_param.foreground_type,
                multimod_fusion=self.net_param.multimod_foreground_type,
                threshold=0.0)

        mean_var_normaliser = MeanVarNormalisationLayer(
            image_name='image', binary_masking_func=foreground_masking_layer)
        histogram_normaliser = None
        if self.net_param.histogram_ref_file:
            histogram_normaliser = HistogramNormalisationLayer(
                image_name='image',
                modalities=vars(task_param).get('image'),
                model_filename=self.net_param.histogram_ref_file,
                binary_masking_func=foreground_masking_layer,
                norm_type=self.net_param.norm_type,
                cutoff=self.net_param.cutoff,
                name='hist_norm_layer')

        normalisation_layers = []
        # The histogram layer only exists when a reference file was given;
        # never put ``None`` into the list handed to the readers.
        if self.net_param.normalisation and histogram_normaliser is not None:
            normalisation_layers.append(histogram_normaliser)
        if self.net_param.whitening:
            normalisation_layers.append(mean_var_normaliser)

        augmentation_layers = []
        if self.is_training:
            # -1 is the value that switches random flipping off
            if self.action_param.random_flipping_axes != -1:
                augmentation_layers.append(
                    RandomFlipLayer(
                        flip_axes=self.action_param.random_flipping_axes))
            if self.action_param.scaling_percentage:
                augmentation_layers.append(
                    RandomSpatialScalingLayer(
                        min_percentage=self.action_param.scaling_percentage[0],
                        max_percentage=self.action_param.scaling_percentage[1])
                )
            if self.action_param.rotation_angle:
                rotation_layer = RandomRotationLayer()
                rotation_layer.init_uniform_angle(
                    self.action_param.rotation_angle)
                augmentation_layers.append(rotation_layer)

        for reader in self.readers:
            reader.add_preprocessing_layers(normalisation_layers +
                                            augmentation_layers)
    def initialise_dataset_loader(self,
                                  data_param=None,
                                  task_param=None,
                                  data_partitioner=None):
        """
        Build one image reader per file list and attach preprocessing.

        The sections read depend on the current phase: ``image`` and
        ``conditioning`` for training, ``conditioning`` only for
        inference.  Normalisation layers are attached to every reader;
        the training-time augmentation layers are attached to the first
        reader only.

        :param data_param: input data specification; stored on
            ``self.data_param`` and forwarded to each reader
        :param task_param: task specification; stored on
            ``self.gan_param`` and forwarded to each reader
        :param data_partitioner: object providing ``get_file_lists_by``
        :raises NotImplementedError: in the evaluation phase
        :raises ValueError: when none of the phase flags is set
        """
        self.data_param = data_param
        self.gan_param = task_param

        # choose the input sections to read for the current phase
        if self.is_training:
            reader_names = ('image', 'conditioning')
        elif self.is_inference:
            # in the inference process use `conditioning` input only
            reader_names = ('conditioning', )
        else:
            if self.is_evaluation:
                tf.logging.fatal(
                    'Evaluation is not yet supported in this application.')
                raise NotImplementedError
            tf.logging.fatal('Action `%s` not supported. Expected one of %s',
                             self.action, self.SUPPORTED_PHASES)
            raise ValueError

        # `dataset_to_infer` may be absent from the action parameters
        try:
            reader_phase = self.action_param.dataset_to_infer
        except AttributeError:
            reader_phase = None

        readers = []
        for file_list in data_partitioner.get_file_lists_by(
                phase=reader_phase, action=self.action):
            readers.append(
                ImageReader(reader_names).initialise(
                    data_param, task_param, file_list))
        self.readers = readers

        # initialise input preprocessing layers
        net_param = self.net_param

        # optional foreground mask shared by both intensity normalisers
        foreground_masking_layer = None
        if net_param.normalise_foreground_only:
            foreground_masking_layer = BinaryMaskingLayer(
                type_str=net_param.foreground_type,
                multimod_fusion=net_param.multimod_foreground_type,
                threshold=0.0)

        mean_var_normaliser = None
        if net_param.whitening:
            mean_var_normaliser = MeanVarNormalisationLayer(
                image_name='image',
                binary_masking_func=foreground_masking_layer)

        histogram_normaliser = None
        if net_param.histogram_ref_file and net_param.normalisation:
            histogram_normaliser = HistogramNormalisationLayer(
                image_name='image',
                modalities=vars(task_param).get('image'),
                model_filename=net_param.histogram_ref_file,
                binary_masking_func=foreground_masking_layer,
                norm_type=net_param.norm_type,
                cutoff=net_param.cutoff,
                name='hist_norm_layer')

        # histogram normalisation goes first, then mean/variance
        normalisation_layers = [
            layer
            for layer in (histogram_normaliser, mean_var_normaliser)
            if layer is not None
        ]

        # initialise training data augmentation layers
        augmentation_layers = []
        if self.is_training:
            train_param = self.action_param
            if train_param.random_flipping_axes != -1:
                augmentation_layers.append(
                    RandomFlipLayer(
                        flip_axes=train_param.random_flipping_axes))
            scaling = train_param.scaling_percentage
            if scaling:
                augmentation_layers.append(
                    RandomSpatialScalingLayer(
                        min_percentage=scaling[0],
                        max_percentage=scaling[1],
                        antialiasing=train_param.antialiasing))
            angle = train_param.rotation_angle
            if angle:
                rotation_layer = RandomRotationLayer()
                rotation_layer.init_uniform_angle(angle)
                augmentation_layers.append(rotation_layer)

        # only add augmentation to first reader (not validation reader)
        first_reader = self.readers[0]
        first_reader.add_preprocessing_layers(normalisation_layers +
                                              augmentation_layers)
        for other_reader in self.readers[1:]:
            other_reader.add_preprocessing_layers(normalisation_layers)
Example #12
0
    def test_volume_loader(self):
        """
        Train a histogram normaliser on the reader's images, compare the
        learnt mappings with reference values, then apply histogram and
        mean/variance normalisation to a random image.

        ``MODEL_FILE`` is removed in a ``finally`` clause so that a
        failing assertion does not leave the file behind.
        """
        expected_T1 = np.array([
            0.0, 8.24277910972, 21.4917343731, 27.0551695202, 32.6186046672,
            43.5081573038, 53.3535675285, 61.9058849776, 70.0929786194,
            73.9944243858, 77.7437509974, 88.5331971492, 100.0
        ])
        expected_FLAIR = np.array([
            0.0, 5.36540863446, 15.5386130103, 20.7431912042, 26.1536608309,
            36.669150376, 44.7821246138, 50.7930589961, 56.1703089214,
            59.2393548654, 63.1565641037, 78.7271261392, 100.0
        ])

        reader = ImageReader(['image'])
        reader.initialise(DATA_PARAM, TASK_PARAM, file_list)
        self.assertAllClose(len(reader._file_list), 4)

        foreground_masking_layer = BinaryMaskingLayer(type_str='otsu_plus',
                                                      multimod_fusion='or')
        hist_norm = HistogramNormalisationLayer(
            image_name='image',
            modalities=vars(TASK_PARAM).get('image'),
            model_filename=MODEL_FILE,
            binary_masking_func=foreground_masking_layer,
            cutoff=(0.05, 0.95),
            name='hist_norm_layer')
        # remove any existing model file before training
        if os.path.exists(MODEL_FILE):
            os.remove(MODEL_FILE)
        try:
            hist_norm.train(reader.output_list)
            out_map = hist_norm.mapping

            self.assertAllClose(out_map['T1'], expected_T1)
            self.assertAllClose(out_map['FLAIR'], expected_FLAIR)

            # normalise a uniformly sampled random image
            test_shape = (20, 20, 20, 3, 2)
            rand_image = np.random.uniform(
                low=-10.0, high=10.0, size=test_shape)
            norm_image = np.copy(rand_image)
            # call the layer with a dict input and with a bare array
            # input; both call styles must give the same image and mask
            norm_image_dict, mask_dict = hist_norm({'image': norm_image})
            norm_image, mask = hist_norm(norm_image, mask_dict)
            self.assertAllClose(norm_image_dict['image'], norm_image)
            self.assertAllClose(mask_dict['image'], mask)

            # apply mean std normalisation
            mv_norm = MeanVarNormalisationLayer(
                image_name='image',
                binary_masking_func=foreground_masking_layer)
            norm_image, _ = mv_norm(norm_image, mask)
            self.assertAllClose(norm_image.shape, mask.shape)

            mv_norm = MeanVarNormalisationLayer(image_name='image',
                                                binary_masking_func=None)
            norm_image, _ = mv_norm(norm_image)

            # mapping should keep at least the order of the images
            rand_image = rand_image[:, :, :, 1, 1].flatten()
            norm_image = norm_image[:, :, :, 1, 1].flatten()

            order_before = rand_image[1:] > rand_image[:-1]
            order_after = norm_image[1:] > norm_image[:-1]
            self.assertAllClose(np.mean(norm_image), 0.0)
            self.assertAllClose(np.std(norm_image), 1.0)
            self.assertAllClose(order_before, order_after)
        finally:
            # clean up even when one of the assertions above fails
            if os.path.exists(MODEL_FILE):
                os.remove(MODEL_FILE)
Example #13
0
    def initialise_dataset_loader(self,
                                  data_param=None,
                                  task_param=None,
                                  data_partitioner=None):
        """
        Create the image readers for the current action and attach the
        padding, normalisation and augmentation layers to them.

        Training builds one reader per file list returned by
        ``self.get_file_lists``; inference reads ``image`` only, from the
        partitioner's inference and validation files; evaluation reads
        ``image``, ``label`` and ``inferred`` from the inference files.
        Augmentation layers are attached to the first reader only.

        Normalisers that were not built because no ``histogram_ref_file``
        is configured are left out of the preprocessing list instead of
        being passed on as ``None``.

        :param data_param: input data specification; stored on
            ``self.data_param`` and forwarded to each reader
        :param task_param: task specification; stored on
            ``self.segmentation_param`` and forwarded to each reader
        :param data_partitioner: source of the file lists
        :raises ValueError: if the action is not training, inference or
            evaluation
        """

        self.data_param = data_param
        self.segmentation_param = task_param

        file_lists = self.get_file_lists(data_partitioner)
        # read each line of csv files into an instance of Subject
        if self.is_training:
            self.readers = []
            for file_list in file_lists:
                reader = ImageReader({'image', 'label', 'weight', 'sampler'})
                reader.initialise(data_param, task_param, file_list)
                self.readers.append(reader)

        elif self.is_inference:
            # in the inference process use image input only
            inference_reader = ImageReader({'image'})
            file_list = pd.concat([
                data_partitioner.inference_files,
                data_partitioner.validation_files
            ], axis=0)
            # re-number the rows of the concatenated table from zero
            file_list.index = range(file_list.shape[0])
            inference_reader.initialise(data_param, task_param, file_list)
            self.readers = [inference_reader]
        elif self.is_evaluation:
            file_list = data_partitioner.inference_files
            reader = ImageReader({'image', 'label', 'inferred'})
            reader.initialise(data_param, task_param, file_list)
            self.readers = [reader]
        else:
            raise ValueError(
                'Action `{}` not supported. Expected one of {}'.format(
                    self.action, self.SUPPORTED_ACTIONS))

        foreground_masking_layer = None
        if self.net_param.normalise_foreground_only:
            foreground_masking_layer = BinaryMaskingLayer(
                type_str=self.net_param.foreground_type,
                multimod_fusion=self.net_param.multimod_foreground_type,
                threshold=0.0)

        mean_var_normaliser = MeanVarNormalisationLayer(
            image_name='image', binary_masking_func=foreground_masking_layer)
        histogram_normaliser = None
        if self.net_param.histogram_ref_file:
            histogram_normaliser = HistogramNormalisationLayer(
                image_name='image',
                modalities=vars(task_param).get('image'),
                model_filename=self.net_param.histogram_ref_file,
                binary_masking_func=foreground_masking_layer,
                norm_type=self.net_param.norm_type,
                cutoff=self.net_param.cutoff,
                name='hist_norm_layer')

        label_normalisers = None
        if self.net_param.histogram_ref_file and \
                task_param.label_normalisation:
            label_normalisers = [
                DiscreteLabelNormalisationLayer(
                    image_name='label',
                    modalities=vars(task_param).get('label'),
                    model_filename=self.net_param.histogram_ref_file)
            ]
            if self.is_evaluation:
                label_normalisers.append(
                    DiscreteLabelNormalisationLayer(
                        image_name='inferred',
                        modalities=vars(task_param).get('inferred'),
                        model_filename=self.net_param.histogram_ref_file))
                # the `inferred` normaliser reuses the key of the
                # `label` normaliser
                label_normalisers[-1].key = label_normalisers[0].key

        normalisation_layers = []
        # `histogram_normaliser` is None when no reference file is set;
        # never hand a None "layer" to the readers.
        if self.net_param.normalisation and histogram_normaliser is not None:
            normalisation_layers.append(histogram_normaliser)
        if self.net_param.whitening:
            normalisation_layers.append(mean_var_normaliser)
        # `label_normalisers` is None when no reference file is set;
        # extending with None would raise TypeError.
        if task_param.label_normalisation and \
                (self.is_training or not task_param.output_prob) and \
                label_normalisers is not None:
            normalisation_layers.extend(label_normalisers)

        augmentation_layers = []
        if self.is_training:
            if self.action_param.random_flipping_axes != -1:
                augmentation_layers.append(
                    RandomFlipLayer(
                        flip_axes=self.action_param.random_flipping_axes))
            if self.action_param.scaling_percentage:
                augmentation_layers.append(
                    RandomSpatialScalingLayer(
                        min_percentage=self.action_param.scaling_percentage[0],
                        max_percentage=self.action_param.scaling_percentage[1])
                )
            if self.action_param.rotation_angle or \
                    self.action_param.rotation_angle_x or \
                    self.action_param.rotation_angle_y or \
                    self.action_param.rotation_angle_z:
                rotation_layer = RandomRotationLayer()
                # a single `rotation_angle` takes precedence over the
                # per-axis settings
                if self.action_param.rotation_angle:
                    rotation_layer.init_uniform_angle(
                        self.action_param.rotation_angle)
                else:
                    rotation_layer.init_non_uniform_angle(
                        self.action_param.rotation_angle_x,
                        self.action_param.rotation_angle_y,
                        self.action_param.rotation_angle_z)
                augmentation_layers.append(rotation_layer)

            # add deformation layer
            if self.action_param.do_elastic_deformation:
                # use the first spatial rank reported by the first reader
                spatial_rank = list(self.readers[0].spatial_ranks.values())[0]
                augmentation_layers.append(
                    RandomElasticDeformationLayer(
                        spatial_rank=spatial_rank,
                        num_controlpoints=self.action_param.num_ctrl_points,
                        std_deformation_sigma=self.action_param.
                        deformation_sigma,
                        proportion_to_augment=self.action_param.
                        proportion_to_deform))

        volume_padding_layer = []
        if self.net_param.volume_padding_size:
            volume_padding_layer.append(
                PadLayer(image_name=SUPPORTED_INPUT,
                         border=self.net_param.volume_padding_size))

        # only add augmentation to first reader (not validation reader)
        self.readers[0].add_preprocessing_layers(volume_padding_layer +
                                                 normalisation_layers +
                                                 augmentation_layers)

        for reader in self.readers[1:]:
            reader.add_preprocessing_layers(volume_padding_layer +
                                            normalisation_layers)
    def initialise_dataset_loader(
            self, data_param=None, task_param=None, data_partitioner=None):
        """
        Create the image readers for the current action and attach the
        normalisation and augmentation layers to them.

        Training builds one reader per file list returned by
        ``self.get_file_lists``; inference and evaluation build a single
        reader from the first file list.

        Normalisers that were not built because no ``histogram_ref_file``
        is configured are left out of the preprocessing list instead of
        being passed on as ``None``.  Augmentation layers are attached to
        the first reader only; any further readers (presumably the
        validation readers -- confirm against ``get_file_lists``)
        receive the normalisation layers only.

        :param data_param: input data specification; stored on
            ``self.data_param`` and forwarded to each reader
        :param task_param: task specification; stored on
            ``self.classification_param`` and forwarded to each reader
        :param data_partitioner: passed to ``self.get_file_lists``
        :raises ValueError: if the action is not training, inference or
            evaluation
        """

        self.data_param = data_param
        self.classification_param = task_param

        file_lists = self.get_file_lists(data_partitioner)
        # read each line of csv files into an instance of Subject
        if self.is_training:
            self.readers = []
            for file_list in file_lists:
                reader = ImageReader(['image', 'label', 'sampler'])
                reader.initialise(data_param, task_param, file_list)
                self.readers.append(reader)

        elif self.is_inference:
            # in the inference process use image input only
            inference_reader = ImageReader(['image'])
            inference_reader.initialise(data_param, task_param, file_lists[0])
            self.readers = [inference_reader]
        elif self.is_evaluation:
            reader = ImageReader({'image', 'label', 'inferred'})
            reader.initialise(data_param, task_param, file_lists[0])
            self.readers = [reader]
        else:
            raise ValueError('Action `{}` not supported. Expected one of {}'
                             .format(self.action, self.SUPPORTED_ACTIONS))

        foreground_masking_layer = None
        if self.net_param.normalise_foreground_only:
            foreground_masking_layer = BinaryMaskingLayer(
                type_str=self.net_param.foreground_type,
                multimod_fusion=self.net_param.multimod_foreground_type,
                threshold=0.0)

        mean_var_normaliser = MeanVarNormalisationLayer(
            image_name='image', binary_masking_func=foreground_masking_layer)
        histogram_normaliser = None
        if self.net_param.histogram_ref_file:
            histogram_normaliser = HistogramNormalisationLayer(
                image_name='image',
                modalities=vars(task_param).get('image'),
                model_filename=self.net_param.histogram_ref_file,
                binary_masking_func=foreground_masking_layer,
                norm_type=self.net_param.norm_type,
                cutoff=self.net_param.cutoff,
                name='hist_norm_layer')

        label_normaliser = None
        if self.net_param.histogram_ref_file:
            label_normaliser = DiscreteLabelNormalisationLayer(
                image_name='label',
                modalities=vars(task_param).get('label'),
                model_filename=self.net_param.histogram_ref_file)

        normalisation_layers = []
        # both reference-file based normalisers are None when no
        # `histogram_ref_file` is set; never pass None on as a layer
        if self.net_param.normalisation and histogram_normaliser is not None:
            normalisation_layers.append(histogram_normaliser)
        if self.net_param.whitening:
            normalisation_layers.append(mean_var_normaliser)
        if task_param.label_normalisation and label_normaliser is not None:
            normalisation_layers.append(label_normaliser)

        augmentation_layers = []
        if self.is_training:
            if self.action_param.random_flipping_axes != -1:
                augmentation_layers.append(RandomFlipLayer(
                    flip_axes=self.action_param.random_flipping_axes))
            if self.action_param.scaling_percentage:
                augmentation_layers.append(RandomSpatialScalingLayer(
                    min_percentage=self.action_param.scaling_percentage[0],
                    max_percentage=self.action_param.scaling_percentage[1]))
            if self.action_param.rotation_angle or \
                    self.action_param.rotation_angle_x or \
                    self.action_param.rotation_angle_y or \
                    self.action_param.rotation_angle_z:
                rotation_layer = RandomRotationLayer()
                # a single `rotation_angle` takes precedence over the
                # per-axis settings
                if self.action_param.rotation_angle:
                    rotation_layer.init_uniform_angle(
                        self.action_param.rotation_angle)
                else:
                    rotation_layer.init_non_uniform_angle(
                        self.action_param.rotation_angle_x,
                        self.action_param.rotation_angle_y,
                        self.action_param.rotation_angle_z)
                augmentation_layers.append(rotation_layer)

        # only add augmentation to the first reader; `enumerate` keeps
        # an empty reader list a no-op, as a plain loop over it would be
        for index, reader in enumerate(self.readers):
            layers = list(normalisation_layers)
            if index == 0:
                layers = layers + augmentation_layers
            reader.add_preprocessing_layers(layers)
Example #15
0
    def initialise_dataset_loader(self,
                                  data_param=None,
                                  task_param=None,
                                  data_partitioner=None):
        """
        Create the image readers for the current action and attach the
        padding, normalisation and augmentation layers to them.

        In training, a reader is built for the partitioner's
        ``train_files`` followed by one for ``validation_files`` when
        ``validation_every_n > 0``; otherwise a single reader is built
        for ``all_files``.  Outside training, a single ``image``-only
        reader is built for ``inference_files``.

        Normalisers that were not built because no ``histogram_ref_file``
        is configured are left out of the preprocessing list instead of
        being passed on as ``None``.  Augmentation layers are attached to
        the first (training) reader only, so the validation reader
        receives padding and normalisation but no random augmentation.

        :param data_param: input data specification; stored on
            ``self.data_param`` and forwarded to each reader
        :param task_param: task specification; stored on
            ``self.segmentation_param`` and forwarded to each reader
        :param data_partitioner: provides the file lists
        """

        self.data_param = data_param
        self.segmentation_param = task_param

        # read each line of csv files into an instance of Subject
        if self.is_training:
            file_lists = []
            if self.action_param.validation_every_n > 0:
                file_lists.append(data_partitioner.train_files)
                file_lists.append(data_partitioner.validation_files)
            else:
                file_lists.append(data_partitioner.all_files)

            self.readers = []
            for file_list in file_lists:
                reader = ImageReader(SUPPORTED_INPUT)
                reader.initialise(data_param, task_param, file_list)
                self.readers.append(reader)

        else:  # in the inference process use image input only
            inference_reader = ImageReader(['image'])
            file_list = data_partitioner.inference_files
            inference_reader.initialise(data_param, task_param, file_list)
            self.readers = [inference_reader]

        foreground_masking_layer = None
        if self.net_param.normalise_foreground_only:
            foreground_masking_layer = BinaryMaskingLayer(
                type_str=self.net_param.foreground_type,
                multimod_fusion=self.net_param.multimod_foreground_type,
                threshold=0.0)

        mean_var_normaliser = MeanVarNormalisationLayer(
            image_name='image', binary_masking_func=foreground_masking_layer)
        histogram_normaliser = None
        if self.net_param.histogram_ref_file:
            histogram_normaliser = HistogramNormalisationLayer(
                image_name='image',
                modalities=vars(task_param).get('image'),
                model_filename=self.net_param.histogram_ref_file,
                binary_masking_func=foreground_masking_layer,
                norm_type=self.net_param.norm_type,
                cutoff=self.net_param.cutoff,
                name='hist_norm_layer')

        label_normaliser = None
        if self.net_param.histogram_ref_file:
            label_normaliser = DiscreteLabelNormalisationLayer(
                image_name='label',
                modalities=vars(task_param).get('label'),
                model_filename=self.net_param.histogram_ref_file)

        normalisation_layers = []
        # both reference-file based normalisers are None when no
        # `histogram_ref_file` is set; never pass None on as a layer
        if self.net_param.normalisation and histogram_normaliser is not None:
            normalisation_layers.append(histogram_normaliser)
        if self.net_param.whitening:
            normalisation_layers.append(mean_var_normaliser)
        if task_param.label_normalisation and label_normaliser is not None:
            normalisation_layers.append(label_normaliser)

        augmentation_layers = []
        if self.is_training:
            if self.action_param.random_flipping_axes != -1:
                augmentation_layers.append(RandomFlipLayer(
                    flip_axes=self.action_param.random_flipping_axes))
            if self.action_param.scaling_percentage:
                augmentation_layers.append(RandomSpatialScalingLayer(
                    min_percentage=self.action_param.scaling_percentage[0],
                    max_percentage=self.action_param.scaling_percentage[1]))
            if self.action_param.rotation_angle or \
                    self.action_param.rotation_angle_x or \
                    self.action_param.rotation_angle_y or \
                    self.action_param.rotation_angle_z:
                rotation_layer = RandomRotationLayer()
                # a single `rotation_angle` takes precedence over the
                # per-axis settings
                if self.action_param.rotation_angle:
                    rotation_layer.init_uniform_angle(
                        self.action_param.rotation_angle)
                else:
                    rotation_layer.init_non_uniform_angle(
                        self.action_param.rotation_angle_x,
                        self.action_param.rotation_angle_y,
                        self.action_param.rotation_angle_z)
                augmentation_layers.append(rotation_layer)

        volume_padding_layer = []
        if self.net_param.volume_padding_size:
            volume_padding_layer.append(PadLayer(
                image_name=SUPPORTED_INPUT,
                border=self.net_param.volume_padding_size))

        # only add augmentation to the first reader (the training one);
        # the validation reader must not be randomly augmented
        for index, reader in enumerate(self.readers):
            layers = volume_padding_layer + normalisation_layers
            if index == 0:
                layers = layers + augmentation_layers
            reader.add_preprocessing_layers(layers)
    def initialise_dataset_loader(self,
                                  data_param=None,
                                  task_param=None,
                                  data_partitioner=None):
        """
        Create the image readers for the current action and attach the
        padding, normalisation and augmentation layers to them.

        Training builds one reader per file list returned by
        ``self.get_file_lists``; inference and evaluation build a single
        reader from the first file list.

        The histogram normaliser is left out of the preprocessing list
        when it was not built (no ``histogram_ref_file``) instead of
        being passed on as ``None``.  Augmentation layers are attached to
        the first reader only; any further readers (presumably the
        validation readers -- confirm against ``get_file_lists``)
        receive padding and normalisation only.

        :param data_param: input data specification; stored on
            ``self.data_param`` and forwarded to each reader
        :param task_param: task specification; stored on
            ``self.regression_param`` and forwarded to each reader
        :param data_partitioner: passed to ``self.get_file_lists``
        :raises ValueError: if the action is not training, inference or
            evaluation
        """
        self.data_param = data_param
        self.regression_param = task_param

        file_lists = self.get_file_lists(data_partitioner)
        # read each line of csv files into an instance of Subject
        if self.is_training:
            self.readers = []
            for file_list in file_lists:
                reader = ImageReader({'image', 'output', 'weight', 'sampler'})
                reader.initialise(data_param, task_param, file_list)
                self.readers.append(reader)
        elif self.is_inference:
            # in the inference process use image input only
            inference_reader = ImageReader(['image'])
            inference_reader.initialise(data_param, task_param, file_lists[0])
            self.readers = [inference_reader]
        elif self.is_evaluation:
            reader = ImageReader({'image', 'output', 'inferred'})
            reader.initialise(data_param, task_param, file_lists[0])
            self.readers = [reader]
        else:
            raise ValueError(
                'Action `{}` not supported. Expected one of {}'.format(
                    self.action, self.SUPPORTED_ACTIONS))

        mean_var_normaliser = MeanVarNormalisationLayer(image_name='image')
        histogram_normaliser = None
        if self.net_param.histogram_ref_file:
            histogram_normaliser = HistogramNormalisationLayer(
                image_name='image',
                modalities=vars(task_param).get('image'),
                model_filename=self.net_param.histogram_ref_file,
                norm_type=self.net_param.norm_type,
                cutoff=self.net_param.cutoff,
                name='hist_norm_layer')

        normalisation_layers = []
        # `histogram_normaliser` is None when no reference file is set;
        # never pass None on as a layer
        if self.net_param.normalisation and histogram_normaliser is not None:
            normalisation_layers.append(histogram_normaliser)
        if self.net_param.whitening:
            normalisation_layers.append(mean_var_normaliser)

        augmentation_layers = []
        if self.is_training:
            if self.action_param.random_flipping_axes != -1:
                augmentation_layers.append(
                    RandomFlipLayer(
                        flip_axes=self.action_param.random_flipping_axes))
            if self.action_param.scaling_percentage:
                augmentation_layers.append(
                    RandomSpatialScalingLayer(
                        min_percentage=self.action_param.scaling_percentage[0],
                        max_percentage=self.action_param.scaling_percentage[1])
                )
            if self.action_param.rotation_angle:
                augmentation_layers.append(RandomRotationLayer())
                augmentation_layers[-1].init_uniform_angle(
                    self.action_param.rotation_angle)

        volume_padding_layer = []
        if self.net_param.volume_padding_size:
            volume_padding_layer.append(
                PadLayer(image_name=SUPPORTED_INPUT,
                         border=self.net_param.volume_padding_size,
                         mode=self.net_param.volume_padding_mode))

        # only add augmentation to the first reader; `enumerate` keeps
        # an empty reader list a no-op, as a plain loop over it would be
        for index, reader in enumerate(self.readers):
            layers = volume_padding_layer + normalisation_layers
            if index == 0:
                layers = layers + augmentation_layers
            reader.add_preprocessing_layers(layers)
    def initialise_dataset_loader(self,
                                  data_param=None,
                                  task_param=None,
                                  data_partitioner=None):
        """
        Build one image reader per file list and attach preprocessing.

        The sections read depend on the current phase: ``image``,
        ``output``, ``weight`` and ``sampler`` for training; ``image``
        only for inference; ``image``, ``output`` and ``inferred`` for
        evaluation.  Padding and normalisation layers are attached to
        every reader; the training-time augmentation layers are attached
        to the first reader only.

        :param data_param: input data specification; stored on
            ``self.data_param`` and forwarded to each reader
        :param task_param: task specification; stored on
            ``self.regression_param`` and forwarded to each reader
        :param data_partitioner: object providing ``get_file_lists_by``
        :raises ValueError: when none of the phase flags is set
        """

        self.data_param = data_param
        self.regression_param = task_param

        # initialise input image readers
        if self.is_training:
            reader_names = ('image', 'output', 'weight', 'sampler')
        elif self.is_inference:
            # in the inference process use `image` input only
            reader_names = ('image', )
        elif self.is_evaluation:
            reader_names = ('image', 'output', 'inferred')
        else:
            tf.logging.fatal('Action `%s` not supported. Expected one of %s',
                             self.action, self.SUPPORTED_PHASES)
            raise ValueError

        # `dataset_to_infer` may be absent from the action parameters
        try:
            reader_phase = self.action_param.dataset_to_infer
        except AttributeError:
            reader_phase = None

        readers = []
        for file_list in data_partitioner.get_file_lists_by(
                phase=reader_phase, action=self.action):
            readers.append(
                ImageReader(reader_names).initialise(
                    data_param, task_param, file_list))
        self.readers = readers

        # initialise input preprocessing layers
        net_param = self.net_param

        mean_var_normaliser = None
        if net_param.whitening:
            mean_var_normaliser = MeanVarNormalisationLayer(
                image_name='image')

        histogram_normaliser = None
        if net_param.histogram_ref_file and net_param.normalisation:
            histogram_normaliser = HistogramNormalisationLayer(
                image_name='image',
                modalities=vars(task_param).get('image'),
                model_filename=net_param.histogram_ref_file,
                norm_type=net_param.norm_type,
                cutoff=net_param.cutoff,
                name='hist_norm_layer')

        # histogram normalisation goes first, then mean/variance
        normalisation_layers = [
            layer
            for layer in (histogram_normaliser, mean_var_normaliser)
            if layer is not None
        ]

        volume_padding_layer = []
        if net_param.volume_padding_size:
            volume_padding_layer.append(
                PadLayer(image_name=SUPPORTED_INPUT,
                         border=net_param.volume_padding_size,
                         mode=net_param.volume_padding_mode))

        # initialise training data augmentation layers
        augmentation_layers = []
        if self.is_training:
            train_param = self.action_param
            if train_param.random_flipping_axes != -1:
                augmentation_layers.append(
                    RandomFlipLayer(
                        flip_axes=train_param.random_flipping_axes))
            scaling = train_param.scaling_percentage
            if scaling:
                augmentation_layers.append(
                    RandomSpatialScalingLayer(
                        min_percentage=scaling[0],
                        max_percentage=scaling[1]))
            angle = train_param.rotation_angle
            if angle:
                rotation_layer = RandomRotationLayer()
                rotation_layer.init_uniform_angle(angle)
                augmentation_layers.append(rotation_layer)
            if train_param.do_elastic_deformation:
                # use the first spatial rank reported by the first reader
                ranks = list(self.readers[0].spatial_ranks.values())
                augmentation_layers.append(
                    RandomElasticDeformationLayer(
                        spatial_rank=ranks[0],
                        num_controlpoints=train_param.num_ctrl_points,
                        std_deformation_sigma=train_param.deformation_sigma,
                        proportion_to_augment=train_param.proportion_to_deform)
                )

        # only add augmentation to first reader (not validation reader)
        first_reader = self.readers[0]
        first_reader.add_preprocessing_layers(volume_padding_layer +
                                              normalisation_layers +
                                              augmentation_layers)
        for other_reader in self.readers[1:]:
            other_reader.add_preprocessing_layers(volume_padding_layer +
                                                  normalisation_layers)