Exemple #1
0
    def get_training_transforms(self):
        assert self.params.get(
            'mirror'
        ) is None, "old version of params, use new keyword do_mirror"

        tr_transforms = []

        if self.params.get("selected_data_channels"):
            tr_transforms.append(
                DataChannelSelectionTransform(
                    self.params.get("selected_data_channels")))
        if self.params.get("selected_seg_channels"):
            tr_transforms.append(
                SegChannelSelectionTransform(
                    self.params.get("selected_seg_channels")))

        # don't do color augmentations while in 2d mode with 3d data because the color channel is overloaded!!
        if self.params.get("dummy_2D", False):
            ignore_axes = (0, )
            tr_transforms.append(Convert3DTo2DTransform())
        else:
            ignore_axes = None

        tr_transforms.append(
            SpatialTransform(
                self._spatial_transform_patch_size,
                patch_center_dist_from_border=None,
                do_elastic_deform=self.params.get("do_elastic"),
                alpha=self.params.get("elastic_deform_alpha"),
                sigma=self.params.get("elastic_deform_sigma"),
                do_rotation=self.params.get("do_rotation"),
                angle_x=self.params.get("rotation_x"),
                angle_y=self.params.get("rotation_y"),
                angle_z=self.params.get("rotation_z"),
                do_scale=self.params.get("do_scaling"),
                scale=self.params.get("scale_range"),
                order_data=self.params.get("order_data"),
                border_mode_data=self.params.get("border_mode_data"),
                border_cval_data=self.params.get("border_cval_data"),
                order_seg=self.params.get("order_seg"),
                border_mode_seg=self.params.get("border_mode_seg"),
                border_cval_seg=self.params.get("border_cval_seg"),
                random_crop=self.params.get("random_crop"),
                p_el_per_sample=self.params.get("p_eldef"),
                p_scale_per_sample=self.params.get("p_scale"),
                p_rot_per_sample=self.params.get("p_rot"),
                independent_scale_for_each_axis=self.params.get(
                    "independent_scale_factor_for_each_axis"),
            ))

        if self.params.get("dummy_2D"):
            tr_transforms.append(Convert2DTo3DTransform())

        # we need to put the color augmentations after the dummy 2d part (if applicable). Otherwise the overloaded color
        # channel gets in the way
        tr_transforms.append(GaussianNoiseTransform(p_per_sample=0.15))
        tr_transforms.append(
            GaussianBlurTransform((0.5, 1.5),
                                  different_sigma_per_channel=True,
                                  p_per_sample=0.2,
                                  p_per_channel=0.5), )
        tr_transforms.append(
            BrightnessMultiplicativeTransform(multiplier_range=(0.75, 1.3),
                                              p_per_sample=0.15))
        if self.params.get("do_additive_brightness"):
            tr_transforms.append(
                BrightnessTransform(
                    self.params.get("additive_brightness_mu"),
                    self.params.get("additive_brightness_sigma"),
                    True,
                    p_per_sample=self.params.get(
                        "additive_brightness_p_per_sample"),
                    p_per_channel=self.params.get(
                        "additive_brightness_p_per_channel")))
        tr_transforms.append(
            ContrastAugmentationTransform(contrast_range=(0.65, 1.5),
                                          p_per_sample=0.15))
        tr_transforms.append(
            SimulateLowResolutionTransform(zoom_range=(0.5, 1),
                                           per_channel=True,
                                           p_per_channel=0.5,
                                           order_downsample=0,
                                           order_upsample=3,
                                           p_per_sample=0.25,
                                           ignore_axes=ignore_axes), )
        tr_transforms.append(
            GammaTransform(self.params.get("gamma_range"),
                           True,
                           True,
                           retain_stats=self.params.get("gamma_retain_stats"),
                           p_per_sample=0.15))  # inverted gamma

        if self.params.get("do_gamma"):
            tr_transforms.append(
                GammaTransform(
                    self.params.get("gamma_range"),
                    False,
                    True,
                    retain_stats=self.params.get("gamma_retain_stats"),
                    p_per_sample=self.params["p_gamma"]))
        if self.params.get("do_mirror") or self.params.get("mirror"):
            tr_transforms.append(
                MirrorTransform(self.params.get("mirror_axes")))
        if self.params.get("use_mask_for_norm"):
            use_mask_for_norm = self.params.get("use_mask_for_norm")
            tr_transforms.append(
                MaskTransform(use_mask_for_norm,
                              mask_idx_in_seg=0,
                              set_outside_to=0))

        tr_transforms.append(RemoveLabelTransform(-1, 0))
        tr_transforms.append(RenameTransform('seg', 'target', True))
        tr_transforms.append(NumpyToTensor(['data', 'target'], 'float'))
        return Compose(tr_transforms)
def get_insaneDA_augmentation2(dataloader_train,
                               dataloader_val,
                               patch_size,
                               params=default_3D_augmentation_params,
                               border_val_seg=-1,
                               seeds_train=None,
                               seeds_val=None,
                               order_seg=1,
                               order_data=3,
                               deep_supervision_scales=None,
                               soft_ds=False,
                               classes=None,
                               pin_memory=True,
                               regions=None):
    assert params.get(
        'mirror') is None, "old version of params, use new keyword do_mirror"

    tr_transforms = []

    if params.get("selected_data_channels") is not None:
        tr_transforms.append(
            DataChannelSelectionTransform(
                params.get("selected_data_channels")))

    if params.get("selected_seg_channels") is not None:
        tr_transforms.append(
            SegChannelSelectionTransform(params.get("selected_seg_channels")))

    # don't do color augmentations while in 2d mode with 3d data because the color channel is overloaded!!
    if params.get("dummy_2D") is not None and params.get("dummy_2D"):
        ignore_axes = (0, )
        tr_transforms.append(Convert3DTo2DTransform())
    else:
        ignore_axes = None

    tr_transforms.append(
        SpatialTransform_2(
            patch_size,
            patch_center_dist_from_border=None,
            do_elastic_deform=params.get("do_elastic"),
            deformation_scale=params.get("eldef_deformation_scale"),
            do_rotation=params.get("do_rotation"),
            angle_x=params.get("rotation_x"),
            angle_y=params.get("rotation_y"),
            angle_z=params.get("rotation_z"),
            do_scale=params.get("do_scaling"),
            scale=params.get("scale_range"),
            border_mode_data=params.get("border_mode_data"),
            border_cval_data=0,
            order_data=order_data,
            border_mode_seg="constant",
            border_cval_seg=border_val_seg,
            order_seg=order_seg,
            random_crop=params.get("random_crop"),
            p_el_per_sample=params.get("p_eldef"),
            p_scale_per_sample=params.get("p_scale"),
            p_rot_per_sample=params.get("p_rot"),
            independent_scale_for_each_axis=params.get(
                "independent_scale_factor_for_each_axis"),
            p_independent_scale_per_axis=params.get(
                "p_independent_scale_per_axis")))

    if params.get("dummy_2D"):
        tr_transforms.append(Convert2DTo3DTransform())

    # we need to put the color augmentations after the dummy 2d part (if applicable). Otherwise the overloaded color
    # channel gets in the way
    tr_transforms.append(GaussianNoiseTransform(p_per_sample=0.15))
    tr_transforms.append(
        GaussianBlurTransform((0.5, 1.5),
                              different_sigma_per_channel=True,
                              p_per_sample=0.2,
                              p_per_channel=0.5))
    tr_transforms.append(
        BrightnessMultiplicativeTransform(multiplier_range=(0.70, 1.3),
                                          p_per_sample=0.15))
    tr_transforms.append(
        ContrastAugmentationTransform(contrast_range=(0.65, 1.5),
                                      p_per_sample=0.15))
    tr_transforms.append(
        SimulateLowResolutionTransform(zoom_range=(0.5, 1),
                                       per_channel=True,
                                       p_per_channel=0.5,
                                       order_downsample=0,
                                       order_upsample=3,
                                       p_per_sample=0.25,
                                       ignore_axes=ignore_axes))
    tr_transforms.append(
        GammaTransform(params.get("gamma_range"),
                       True,
                       True,
                       retain_stats=params.get("gamma_retain_stats"),
                       p_per_sample=0.15))  # inverted gamma

    if params.get("do_additive_brightness"):
        tr_transforms.append(
            BrightnessTransform(
                params.get("additive_brightness_mu"),
                params.get("additive_brightness_sigma"),
                True,
                p_per_sample=params.get("additive_brightness_p_per_sample"),
                p_per_channel=params.get("additive_brightness_p_per_channel")))

    if params.get("do_gamma"):
        tr_transforms.append(
            GammaTransform(params.get("gamma_range"),
                           False,
                           True,
                           retain_stats=params.get("gamma_retain_stats"),
                           p_per_sample=params["p_gamma"]))

    if params.get("do_mirror") or params.get("mirror"):
        tr_transforms.append(MirrorTransform(params.get("mirror_axes")))

    if params.get("mask_was_used_for_normalization") is not None:
        mask_was_used_for_normalization = params.get(
            "mask_was_used_for_normalization")
        tr_transforms.append(
            MaskTransform(mask_was_used_for_normalization,
                          mask_idx_in_seg=0,
                          set_outside_to=0))

    tr_transforms.append(RemoveLabelTransform(-1, 0))

    if params.get("move_last_seg_chanel_to_data") is not None and params.get(
            "move_last_seg_chanel_to_data"):
        tr_transforms.append(
            MoveSegAsOneHotToData(1, params.get("all_segmentation_labels"),
                                  'seg', 'data'))
        if params.get("cascade_do_cascade_augmentations"
                      ) and not None and params.get(
                          "cascade_do_cascade_augmentations"):
            if params.get("cascade_random_binary_transform_p") > 0:
                tr_transforms.append(
                    ApplyRandomBinaryOperatorTransform(
                        channel_idx=list(
                            range(-len(params.get("all_segmentation_labels")),
                                  0)),
                        p_per_sample=params.get(
                            "cascade_random_binary_transform_p"),
                        key="data",
                        strel_size=params.get(
                            "cascade_random_binary_transform_size")))
            if params.get("cascade_remove_conn_comp_p") > 0:
                tr_transforms.append(
                    RemoveRandomConnectedComponentFromOneHotEncodingTransform(
                        channel_idx=list(
                            range(-len(params.get("all_segmentation_labels")),
                                  0)),
                        key="data",
                        p_per_sample=params.get("cascade_remove_conn_comp_p"),
                        fill_with_other_class_p=params.get(
                            "cascade_remove_conn_comp_max_size_percent_threshold"
                        ),
                        dont_do_if_covers_more_than_X_percent=params.get(
                            "cascade_remove_conn_comp_fill_with_other_class_p")
                    ))

    tr_transforms.append(RenameTransform('seg', 'target', True))

    if regions is not None:
        tr_transforms.append(
            ConvertSegmentationToRegionsTransform(regions, 'target', 'target'))

    if deep_supervision_scales is not None:
        if soft_ds:
            assert classes is not None
            tr_transforms.append(
                DownsampleSegForDSTransform3(deep_supervision_scales, 'target',
                                             'target', classes))
        else:
            tr_transforms.append(
                DownsampleSegForDSTransform2(deep_supervision_scales,
                                             0,
                                             0,
                                             input_key='target',
                                             output_key='target'))

    tr_transforms.append(NumpyToTensor(['data', 'target'], 'float'))
    tr_transforms = Compose(tr_transforms)

    batchgenerator_train = MultiThreadedAugmenter(
        dataloader_train,
        tr_transforms,
        params.get('num_threads'),
        params.get("num_cached_per_thread"),
        seeds=seeds_train,
        pin_memory=pin_memory)
    #batchgenerator_train = SingleThreadedAugmenter(dataloader_train, tr_transforms)

    val_transforms = []
    val_transforms.append(RemoveLabelTransform(-1, 0))
    if params.get("selected_data_channels") is not None:
        val_transforms.append(
            DataChannelSelectionTransform(
                params.get("selected_data_channels")))
    if params.get("selected_seg_channels") is not None:
        val_transforms.append(
            SegChannelSelectionTransform(params.get("selected_seg_channels")))

    if params.get("move_last_seg_chanel_to_data") is not None and params.get(
            "move_last_seg_chanel_to_data"):
        val_transforms.append(
            MoveSegAsOneHotToData(1, params.get("all_segmentation_labels"),
                                  'seg', 'data'))

    val_transforms.append(RenameTransform('seg', 'target', True))

    if regions is not None:
        val_transforms.append(
            ConvertSegmentationToRegionsTransform(regions, 'target', 'target'))

    if deep_supervision_scales is not None:
        if soft_ds:
            assert classes is not None
            val_transforms.append(
                DownsampleSegForDSTransform3(deep_supervision_scales, 'target',
                                             'target', classes))
        else:
            val_transforms.append(
                DownsampleSegForDSTransform2(deep_supervision_scales,
                                             0,
                                             0,
                                             input_key='target',
                                             output_key='target'))

    val_transforms.append(NumpyToTensor(['data', 'target'], 'float'))
    val_transforms = Compose(val_transforms)

    batchgenerator_val = MultiThreadedAugmenter(
        dataloader_val,
        val_transforms,
        max(params.get('num_threads') // 2, 1),
        params.get("num_cached_per_thread"),
        seeds=seeds_val,
        pin_memory=pin_memory)
    return batchgenerator_train, batchgenerator_val
Exemple #3
0
 def run(self, img_data, seg_data):
     # Create a parser for the batchgenerators module
     data_generator = DataParser(img_data, seg_data)
     # Initialize empty transform list
     transforms = []
     # Add mirror augmentation
     if self.mirror:
         aug_mirror = MirrorTransform(axes=self.config_mirror_axes)
         transforms.append(aug_mirror)
     # Add contrast augmentation
     if self.contrast:
         aug_contrast = ContrastAugmentationTransform(
             self.config_contrast_range,
             preserve_range=True,
             per_channel=True,
             p_per_sample=self.config_p_per_sample)
         transforms.append(aug_contrast)
     # Add brightness augmentation
     if self.brightness:
         aug_brightness = BrightnessMultiplicativeTransform(
             self.config_brightness_range,
             per_channel=True,
             p_per_sample=self.config_p_per_sample)
         transforms.append(aug_brightness)
     # Add gamma augmentation
     if self.gamma:
         aug_gamma = GammaTransform(self.config_gamma_range,
                                    invert_image=False,
                                    per_channel=True,
                                    retain_stats=True,
                                    p_per_sample=self.config_p_per_sample)
         transforms.append(aug_gamma)
     # Add gaussian noise augmentation
     if self.gaussian_noise:
         aug_gaussian_noise = GaussianNoiseTransform(
             self.config_gaussian_noise_range,
             p_per_sample=self.config_p_per_sample)
         transforms.append(aug_gaussian_noise)
     # Add spatial transformations as augmentation
     # (rotation, scaling, elastic deformation)
     if self.rotations or self.scaling or self.elastic_deform or \
         self.cropping:
         # Identify patch shape (full image or cropping)
         if self.cropping: patch_shape = self.cropping_patch_shape
         else: patch_shape = img_data[0].shape[0:-1]
         # Assembling the spatial transformation
         aug_spatial_transform = SpatialTransform(
             patch_shape, [i // 2 for i in patch_shape],
             do_elastic_deform=self.elastic_deform,
             alpha=self.config_elastic_deform_alpha,
             sigma=self.config_elastic_deform_sigma,
             do_rotation=self.rotations,
             angle_x=self.config_rotations_angleX,
             angle_y=self.config_rotations_angleY,
             angle_z=self.config_rotations_angleZ,
             do_scale=self.scaling,
             scale=self.config_scaling_range,
             border_mode_data='constant',
             border_cval_data=0,
             border_mode_seg='constant',
             border_cval_seg=0,
             order_data=3,
             order_seg=0,
             p_el_per_sample=self.config_p_per_sample,
             p_rot_per_sample=self.config_p_per_sample,
             p_scale_per_sample=self.config_p_per_sample,
             random_crop=self.cropping)
         # Append spatial transformation to transformation list
         transforms.append(aug_spatial_transform)
     # Compose the batchgenerators transforms
     all_transforms = Compose(transforms)
     # Assemble transforms into a augmentation generator
     augmentation_generator = SingleThreadedAugmenter(
         data_generator, all_transforms)
     # Perform the data augmentation x times (x = cycles)
     aug_img_data = None
     aug_seg_data = None
     for i in range(0, self.cycles):
         # Run the computation process for the data augmentations
         augmentation = next(augmentation_generator)
         # Access augmentated data from the batchgenerators data structure
         if aug_img_data is None and aug_seg_data is None:
             aug_img_data = augmentation["data"]
             aug_seg_data = augmentation["seg"]
         # Concatenate the new data augmentated data with the cached data
         else:
             aug_img_data = np.concatenate(
                 (augmentation["data"], aug_img_data), axis=0)
             aug_seg_data = np.concatenate(
                 (augmentation["seg"], aug_seg_data), axis=0)
     # Transform data from channel-first back to channel-last structure
     # Data structure channel-first 3D:  (batch, channel, x, y, z)
     # Data structure channel-last 3D:   (batch, x, y, z, channel)
     aug_img_data = np.moveaxis(aug_img_data, 1, -1)
     aug_seg_data = np.moveaxis(aug_seg_data, 1, -1)
     # Return augmentated image and segmentation data
     return aug_img_data, aug_seg_data
Exemple #4
0
def get_transforms(mode="train",
                   n_channels=1,
                   target_size=128,
                   add_resize=False,
                   add_noise=False,
                   mask_type="",
                   batch_size=16,
                   rotate=True,
                   elastic_deform=True,
                   rnd_crop=False,
                   color_augment=True):
    tranform_list = []
    noise_list = []

    if mode == "train":

        tranform_list = [
            FillupPadTransform(min_size=(n_channels, target_size + 5,
                                         target_size + 5)),
            ResizeTransform(target_size=(target_size + 1, target_size + 1),
                            order=1,
                            concatenate_list=True),

            # RandomCropTransform(crop_size=(target_size + 5, target_size + 5)),
            MirrorTransform(axes=(2, )),
            ReshapeTransform(new_shape=(1, -1, "h", "w")),
            SpatialTransform(patch_size=(target_size, target_size),
                             random_crop=rnd_crop,
                             patch_center_dist_from_border=target_size // 2,
                             do_elastic_deform=elastic_deform,
                             alpha=(0., 100.),
                             sigma=(10., 13.),
                             do_rotation=rotate,
                             angle_x=(-0.1, 0.1),
                             angle_y=(0, 1e-8),
                             angle_z=(0, 1e-8),
                             scale=(0.9, 1.2),
                             border_mode_data="nearest",
                             border_mode_seg="nearest"),
            ReshapeTransform(new_shape=(batch_size, -1, "h", "w"))
        ]
        if color_augment:
            tranform_list += [  # BrightnessTransform(mu=0, sigma=0.2),
                BrightnessMultiplicativeTransform(multiplier_range=(0.95, 1.1))
            ]

        tranform_list += [
            GaussianNoiseTransform(noise_variance=(0., 0.05)),
            ClipValueRange(min=-1.5, max=1.5),
        ]

        noise_list = []
        if mask_type == "gaussian":
            noise_list += [GaussianNoiseTransform(noise_variance=(0., 0.2))]

    elif mode == "val":
        tranform_list = [
            FillupPadTransform(min_size=(n_channels, target_size + 5,
                                         target_size + 5)),
            ResizeTransform(target_size=(target_size + 1, target_size + 1),
                            order=1,
                            concatenate_list=True),
            CenterCropTransform(crop_size=(target_size, target_size)),
            ClipValueRange(min=-1.5, max=1.5),
            # BrightnessTransform(mu=0, sigma=0.2),
            # BrightnessMultiplicativeTransform(multiplier_range=(0.95, 1.1)),
            CopyTransform({"data": "data_clean"}, copy=True)
        ]

        noise_list += []

    if add_noise:
        tranform_list = tranform_list + noise_list

    tranform_list.append(NumpyToTensor())

    return Compose(tranform_list)