def get_no_augmentation(dataloader_train, dataloader_val, patch_size, params=default_3D_augmentation_params, border_val_seg=-1):
    """
    Drop-in replacement for get_default_augmentation that performs no data augmentation.

    Both pipelines only select the configured data/seg channels, map label -1 to 0, rename
    'seg' to 'target' and convert 'data' and 'target' to float tensors.

    :param dataloader_train: data loader feeding the training MultiThreadedAugmenter
    :param dataloader_val: data loader feeding the validation MultiThreadedAugmenter
    :param patch_size: unused; kept so the signature matches get_default_augmentation
    :param params: parameter dict; reads "selected_data_channels", "selected_seg_channels",
        "num_threads" and "num_cached_per_thread"
    :param border_val_seg: unused; kept for signature compatibility
    :return: tuple (batchgenerator_train, batchgenerator_val), both already restarted
    """
    data_channels = params.get("selected_data_channels")
    seg_channels = params.get("selected_seg_channels")
    n_threads = params.get('num_threads')
    n_cached = params.get("num_cached_per_thread")

    # training: channel selection first, then label cleanup
    train_steps = []
    if data_channels is not None:
        train_steps.append(DataChannelSelectionTransform(data_channels))
    if seg_channels is not None:
        train_steps.append(SegChannelSelectionTransform(seg_channels))
    train_steps += [
        RemoveLabelTransform(-1, 0),
        RenameTransform('seg', 'target', True),
        NumpyToTensor(['data', 'target'], 'float'),
    ]
    batchgenerator_train = MultiThreadedAugmenter(dataloader_train, Compose(train_steps), n_threads, n_cached,
                                                  seeds=range(n_threads), pin_memory=True)
    batchgenerator_train.restart()

    # validation: label cleanup first, then channel selection
    val_steps = [RemoveLabelTransform(-1, 0)]
    if data_channels is not None:
        val_steps.append(DataChannelSelectionTransform(data_channels))
    if seg_channels is not None:
        val_steps.append(SegChannelSelectionTransform(seg_channels))
    val_steps.extend([RenameTransform('seg', 'target', True), NumpyToTensor(['data', 'target'], 'float')])

    # validation uses half the worker threads, but at least one
    n_val_threads = max(n_threads // 2, 1)
    batchgenerator_val = MultiThreadedAugmenter(dataloader_val, Compose(val_steps), n_val_threads, n_cached,
                                                seeds=range(n_val_threads), pin_memory=True)
    batchgenerator_val.restart()
    return batchgenerator_train, batchgenerator_val
Example #2
0
 def get_validation_transforms(self):
     """Build the validation transform pipeline.

     Selects data/seg channels only when the corresponding ``self.params`` entry is truthy,
     then center-crops to ``self.patch_size``, maps label -1 to 0, renames 'seg' to 'target'
     and converts 'data' and 'target' to float tensors.

     :return: Compose of the validation transforms
     """
     cfg = self.params
     steps = []

     data_channels = cfg.get("selected_data_channels")
     if data_channels:
         steps.append(DataChannelSelectionTransform(data_channels))

     seg_channels = cfg.get("selected_seg_channels")
     if seg_channels:
         steps.append(SegChannelSelectionTransform(seg_channels))

     steps.extend((
         CenterCropTransform(self.patch_size),
         RemoveLabelTransform(-1, 0),
         RenameTransform('seg', 'target', True),
         NumpyToTensor(['data', 'target'], 'float'),
     ))
     return Compose(steps)
def get_default_augmentation_withEDT(dataloader_train,
                                     dataloader_val,
                                     patch_size,
                                     idx_of_edts,
                                     params=default_3D_augmentation_params,
                                     border_val_seg=-1,
                                     pin_memory=True,
                                     seeds_train=None,
                                     seeds_val=None):
    """
    Build train/val MultiThreadedAugmenters whose batches carry the EDT channels under 'bound'.

    The 'data' channels listed in ``idx_of_edts`` are moved out of 'data' into a separate
    'bound' entry (AppendChannelsTransform with remove_from_input=True). This keeps the
    intensity augmentation (gamma) and the one-hot seg handling working only on the remaining
    'data' channels.

    In the training pipeline the move happens after the spatial transform AND after mirroring.
    MirrorTransform only flips the 'data' and 'seg' entries. Extracting the EDT before it would
    leave 'bound' un-flipped, and so misaligned with 'data'/'target', on mirrored samples.

    :param dataloader_train: data loader for the training augmenter
    :param dataloader_val: data loader for the validation augmenter
    :param patch_size: output patch size of the SpatialTransform
    :param idx_of_edts: indexes of the 'data' channels holding the EDT; moved to 'bound'
    :param params: augmentation parameter dict (read with .get; "p_gamma" is read with []
        when "do_gamma" is truthy)
    :param border_val_seg: constant fill value for 'seg' outside the image in the SpatialTransform
    :param pin_memory: forwarded to both MultiThreadedAugmenters
    :param seeds_train: seeds forwarded to the training augmenter
    :param seeds_val: seeds forwarded to the validation augmenter
    :return: tuple (batchgenerator_train, batchgenerator_val); both pipelines convert
        'data', 'target' and 'bound' to float tensors
    """
    tr_transforms = []

    if params.get("selected_data_channels") is not None:
        tr_transforms.append(
            DataChannelSelectionTransform(
                params.get("selected_data_channels")))

    if params.get("selected_seg_channels") is not None:
        tr_transforms.append(
            SegChannelSelectionTransform(params.get("selected_seg_channels")))

    # don't do color augmentations while in 2d mode with 3d data because the color channel is overloaded!!
    if params.get("dummy_2D"):
        tr_transforms.append(Convert3DTo2DTransform())

    tr_transforms.append(
        SpatialTransform(patch_size,
                         patch_center_dist_from_border=None,
                         do_elastic_deform=params.get("do_elastic"),
                         alpha=params.get("elastic_deform_alpha"),
                         sigma=params.get("elastic_deform_sigma"),
                         do_rotation=params.get("do_rotation"),
                         angle_x=params.get("rotation_x"),
                         angle_y=params.get("rotation_y"),
                         angle_z=params.get("rotation_z"),
                         do_scale=params.get("do_scaling"),
                         scale=params.get("scale_range"),
                         border_mode_data=params.get("border_mode_data"),
                         border_cval_data=0,
                         order_data=3,
                         border_mode_seg="constant",
                         border_cval_seg=border_val_seg,
                         order_seg=1,
                         random_crop=params.get("random_crop"),
                         p_el_per_sample=params.get("p_eldef"),
                         p_scale_per_sample=params.get("p_scale"),
                         p_rot_per_sample=params.get("p_rot")))
    if params.get("dummy_2D"):
        tr_transforms.append(Convert2DTo3DTransform())

    # Mirror while the EDT is still part of 'data'. MirrorTransform only flips 'data' and
    # 'seg', so doing this after the EDT is moved to 'bound' would desynchronise 'bound'
    # from 'data'/'target'. Gamma below is a pointwise intensity map, so flipping before it
    # yields the same augmentations; only the order of the random draws differs.
    tr_transforms.append(MirrorTransform(params.get("mirror_axes")))

    # Move the EDT channels to their own key so they are not intensity transformed.
    tr_transforms.append(
        AppendChannelsTransform("data",
                                "bound",
                                idx_of_edts,
                                remove_from_input=True))

    if params.get("do_gamma"):
        tr_transforms.append(
            GammaTransform(params.get("gamma_range"),
                           False,
                           True,
                           retain_stats=params.get("gamma_retain_stats"),
                           p_per_sample=params["p_gamma"]))

    if params.get("mask_was_used_for_normalization") is not None:
        mask_was_used_for_normalization = params.get(
            "mask_was_used_for_normalization")
        tr_transforms.append(
            MaskTransform(mask_was_used_for_normalization,
                          mask_idx_in_seg=0,
                          set_outside_to=0))

    tr_transforms.append(RemoveLabelTransform(-1, 0))

    if params.get("move_last_seg_chanel_to_data"):
        tr_transforms.append(
            MoveSegAsOneHotToData(1, params.get("all_segmentation_labels"),
                                  'seg', 'data'))
        if params.get("advanced_pyramid_augmentations"):
            # the one-hot seg channels were appended last, so address them with negative indexes
            n_labels = len(params.get("all_segmentation_labels"))
            tr_transforms.append(
                ApplyRandomBinaryOperatorTransform(
                    channel_idx=list(range(-n_labels, 0)),
                    p_per_sample=0.4,
                    key="data",
                    strel_size=(1, 8)))
            tr_transforms.append(
                RemoveRandomConnectedComponentFromOneHotEncodingTransform(
                    channel_idx=list(range(-n_labels, 0)),
                    key="data",
                    p_per_sample=0.2,
                    fill_with_other_class_p=0.0,
                    dont_do_if_covers_more_than_X_percent=0.15))

    tr_transforms.append(RenameTransform('seg', 'target', True))
    tr_transforms.append(NumpyToTensor(['data', 'target', 'bound'], 'float'))
    tr_transforms = Compose(tr_transforms)

    batchgenerator_train = MultiThreadedAugmenter(
        dataloader_train,
        tr_transforms,
        params.get('num_threads'),
        params.get("num_cached_per_thread"),
        seeds=seeds_train,
        pin_memory=pin_memory)

    val_transforms = []
    val_transforms.append(RemoveLabelTransform(-1, 0))
    if params.get("selected_data_channels") is not None:
        val_transforms.append(
            DataChannelSelectionTransform(
                params.get("selected_data_channels")))
    if params.get("selected_seg_channels") is not None:
        val_transforms.append(
            SegChannelSelectionTransform(params.get("selected_seg_channels")))

    # Move the EDT channels to their own key, mirroring the training pipeline.
    val_transforms.append(
        AppendChannelsTransform("data",
                                "bound",
                                idx_of_edts,
                                remove_from_input=True))

    if params.get("move_last_seg_chanel_to_data"):
        val_transforms.append(
            MoveSegAsOneHotToData(1, params.get("all_segmentation_labels"),
                                  'seg', 'data'))

    val_transforms.append(RenameTransform('seg', 'target', True))
    val_transforms.append(NumpyToTensor(['data', 'target', 'bound'], 'float'))
    val_transforms = Compose(val_transforms)

    # validation uses half the worker threads, but at least one
    batchgenerator_val = MultiThreadedAugmenter(
        dataloader_val,
        val_transforms,
        max(params.get('num_threads') // 2, 1),
        params.get("num_cached_per_thread"),
        seeds=seeds_val,
        pin_memory=pin_memory)
    return batchgenerator_train, batchgenerator_val
Example #4
0
    def get_training_transforms(self):
        """Build the training augmentation pipeline (variant with extra intensity augmentations).

        Order:
          1. optional channel selection
          2. spatial transform, wrapped in 3D->2D / 2D->3D conversion in dummy-2D mode
          3. Gaussian noise, Gaussian blur, multiplicative brightness, optional additive brightness
          4. contrast, simulated low resolution, inverted gamma, optional gamma
          5. optional mirroring and optional masking
          6. label -1 -> 0, rename 'seg' to 'target', convert to float tensors

        :raises AssertionError: if ``self.params['mirror']`` is set (legacy key; use 'do_mirror')
        :return: Compose of the training transforms
        """
        cfg = self.params
        assert cfg.get('mirror') is None, "old version of params, use new keyword do_mirror"

        steps = []
        if cfg.get("selected_data_channels"):
            steps.append(DataChannelSelectionTransform(cfg.get("selected_data_channels")))
        if cfg.get("selected_seg_channels"):
            steps.append(SegChannelSelectionTransform(cfg.get("selected_seg_channels")))

        # While 3D data is folded into 2D the channel axis is overloaded, so color augmentations must
        # not run inside that window; SimulateLowResolutionTransform is told to ignore axis 0 then.
        dummy_2d = cfg.get("dummy_2D", False)
        ignore_axes = (0, ) if dummy_2d else None
        if dummy_2d:
            steps.append(Convert3DTo2DTransform())

        steps.append(SpatialTransform(
            self._spatial_transform_patch_size,
            patch_center_dist_from_border=None,
            do_elastic_deform=cfg.get("do_elastic"),
            alpha=cfg.get("elastic_deform_alpha"),
            sigma=cfg.get("elastic_deform_sigma"),
            do_rotation=cfg.get("do_rotation"),
            angle_x=cfg.get("rotation_x"),
            angle_y=cfg.get("rotation_y"),
            angle_z=cfg.get("rotation_z"),
            do_scale=cfg.get("do_scaling"),
            scale=cfg.get("scale_range"),
            order_data=cfg.get("order_data"),
            border_mode_data=cfg.get("border_mode_data"),
            border_cval_data=cfg.get("border_cval_data"),
            order_seg=cfg.get("order_seg"),
            border_mode_seg=cfg.get("border_mode_seg"),
            border_cval_seg=cfg.get("border_cval_seg"),
            random_crop=cfg.get("random_crop"),
            p_el_per_sample=cfg.get("p_eldef"),
            p_scale_per_sample=cfg.get("p_scale"),
            p_rot_per_sample=cfg.get("p_rot"),
            independent_scale_for_each_axis=cfg.get("independent_scale_factor_for_each_axis"),
        ))

        if dummy_2d:
            steps.append(Convert2DTo3DTransform())

        # color augmentations go after the dummy-2D round trip so the overloaded channel axis is restored
        steps += [
            GaussianNoiseTransform(p_per_sample=0.15),
            GaussianBlurTransform((0.5, 1.5), different_sigma_per_channel=True,
                                  p_per_sample=0.2, p_per_channel=0.5),
            BrightnessMultiplicativeTransform(multiplier_range=(0.75, 1.3), p_per_sample=0.15),
        ]
        if cfg.get("do_additive_brightness"):
            steps.append(BrightnessTransform(
                cfg.get("additive_brightness_mu"),
                cfg.get("additive_brightness_sigma"),
                True,
                p_per_sample=cfg.get("additive_brightness_p_per_sample"),
                p_per_channel=cfg.get("additive_brightness_p_per_channel")))
        steps += [
            ContrastAugmentationTransform(contrast_range=(0.65, 1.5), p_per_sample=0.15),
            SimulateLowResolutionTransform(zoom_range=(0.5, 1), per_channel=True, p_per_channel=0.5,
                                           order_downsample=0, order_upsample=3, p_per_sample=0.25,
                                           ignore_axes=ignore_axes),
            # inverted gamma; always part of this pipeline, independent of "do_gamma"
            GammaTransform(cfg.get("gamma_range"), True, True,
                           retain_stats=cfg.get("gamma_retain_stats"), p_per_sample=0.15),
        ]

        if cfg.get("do_gamma"):
            steps.append(GammaTransform(cfg.get("gamma_range"), False, True,
                                        retain_stats=cfg.get("gamma_retain_stats"),
                                        p_per_sample=cfg["p_gamma"]))
        # the legacy "mirror" key is only reachable here when asserts are disabled (python -O)
        if cfg.get("do_mirror") or cfg.get("mirror"):
            steps.append(MirrorTransform(cfg.get("mirror_axes")))
        mask_cfg = cfg.get("use_mask_for_norm")
        if mask_cfg:
            steps.append(MaskTransform(mask_cfg, mask_idx_in_seg=0, set_outside_to=0))

        steps.extend([
            RemoveLabelTransform(-1, 0),
            RenameTransform('seg', 'target', True),
            NumpyToTensor(['data', 'target'], 'float'),
        ])
        return Compose(steps)
Example #5
0
    def get_training_transforms(self):
        """Build the training augmentation pipeline (minimal variant, no color augmentations).

        Order:
          1. optional channel selection
          2. spatial transform, wrapped in 3D->2D / 2D->3D conversion in dummy-2D mode
          3. optional gamma, optional mirroring, optional masking
          4. label -1 -> 0, rename 'seg' to 'target', convert to float tensors

        :raises AssertionError: if ``self.params['mirror']`` is set (legacy key; use 'do_mirror')
        :return: Compose of the training transforms
        """
        cfg = self.params
        assert cfg.get('mirror') is None, "old version of params, use new keyword do_mirror"

        data_channels = cfg.get("selected_data_channels")
        seg_channels = cfg.get("selected_seg_channels")
        dummy_2d = cfg.get("dummy_2D", False)

        pipeline = []
        if data_channels:
            pipeline.append(DataChannelSelectionTransform(data_channels))
        if seg_channels:
            pipeline.append(SegChannelSelectionTransform(seg_channels))

        # while 3D data is folded into 2D the channel axis is overloaded, so no color augmentations then
        if dummy_2d:
            pipeline.append(Convert3DTo2DTransform())

        pipeline.append(SpatialTransform(
            self._spatial_transform_patch_size,
            patch_center_dist_from_border=None,
            do_elastic_deform=cfg.get("do_elastic"),
            alpha=cfg.get("elastic_deform_alpha"),
            sigma=cfg.get("elastic_deform_sigma"),
            do_rotation=cfg.get("do_rotation"),
            angle_x=cfg.get("rotation_x"),
            angle_y=cfg.get("rotation_y"),
            angle_z=cfg.get("rotation_z"),
            do_scale=cfg.get("do_scaling"),
            scale=cfg.get("scale_range"),
            order_data=cfg.get("order_data"),
            border_mode_data=cfg.get("border_mode_data"),
            border_cval_data=cfg.get("border_cval_data"),
            order_seg=cfg.get("order_seg"),
            border_mode_seg=cfg.get("border_mode_seg"),
            border_cval_seg=cfg.get("border_cval_seg"),
            random_crop=cfg.get("random_crop"),
            p_el_per_sample=cfg.get("p_eldef"),
            p_scale_per_sample=cfg.get("p_scale"),
            p_rot_per_sample=cfg.get("p_rot"),
            independent_scale_for_each_axis=cfg.get("independent_scale_factor_for_each_axis"),
        ))

        if dummy_2d:
            pipeline.append(Convert2DTo3DTransform())

        if cfg.get("do_gamma", False):
            pipeline.append(GammaTransform(cfg.get("gamma_range"), False, True,
                                           retain_stats=cfg.get("gamma_retain_stats"),
                                           p_per_sample=cfg["p_gamma"]))

        if cfg.get("do_mirror", False):
            pipeline.append(MirrorTransform(cfg.get("mirror_axes")))

        mask_cfg = cfg.get("use_mask_for_norm")
        if mask_cfg:
            pipeline.append(MaskTransform(mask_cfg, mask_idx_in_seg=0, set_outside_to=0))

        pipeline.extend([
            RemoveLabelTransform(-1, 0),
            RenameTransform('seg', 'target', True),
            NumpyToTensor(['data', 'target'], 'float'),
        ])
        return Compose(pipeline)
Example #6
0
def Transforms(patch_size,
               params=default_3D_augmentation_params,
               border_val_seg=-1):
    """
    Build the training augmentation pipeline as a single Compose (no augmenter is created).

    Order:
      1. optional channel selection
      2. spatial transform, wrapped in 3D->2D / 2D->3D conversion in dummy-2D mode
      3. optional gamma, mirroring, optional masking
      4. label -1 -> 0
      5. optional one-hot seg-to-data move, plus optional pyramid augmentations
      6. rename 'seg' to 'target', convert 'data' and 'target' to float tensors

    :param patch_size: output patch size of the SpatialTransform
    :param params: augmentation parameter dict (read with .get; "p_gamma" is read with []
        when "do_gamma" is truthy)
    :param border_val_seg: constant fill value for 'seg' outside the image in the SpatialTransform
    :return: Compose of the transforms
    """
    tr_transforms = []
    if params.get("selected_data_channels") is not None:
        tr_transforms.append(
            DataChannelSelectionTransform(params.get("selected_data_channels"),
                                          data_key="data"))

    if params.get("selected_seg_channels") is not None:
        tr_transforms.append(
            SegChannelSelectionTransform(params.get("selected_seg_channels")))

    # don't do color augmentations while in 2d mode with 3d data because the color channel is overloaded!!
    if params.get("dummy_2D"):
        tr_transforms.append(Convert3DTo2DTransform())
    tr_transforms.append(
        SpatialTransform(patch_size,
                         patch_center_dist_from_border=None,
                         do_elastic_deform=params.get("do_elastic"),
                         alpha=params.get("elastic_deform_alpha"),
                         sigma=params.get("elastic_deform_sigma"),
                         do_rotation=params.get("do_rotation"),
                         angle_x=params.get("rotation_x"),
                         angle_y=params.get("rotation_y"),
                         angle_z=params.get("rotation_z"),
                         do_scale=params.get("do_scaling"),
                         scale=params.get("scale_range"),
                         border_mode_data=params.get("border_mode_data"),
                         border_cval_data=0,
                         order_data=3,
                         border_mode_seg="constant",
                         border_cval_seg=border_val_seg,
                         order_seg=1,
                         random_crop=params.get("random_crop"),
                         p_el_per_sample=params.get("p_eldef"),
                         p_scale_per_sample=params.get("p_scale"),
                         p_rot_per_sample=params.get("p_rot")))
    if params.get("dummy_2D"):
        tr_transforms.append(Convert2DTo3DTransform())

    if params.get("do_gamma"):
        tr_transforms.append(
            GammaTransform(params.get("gamma_range"),
                           False,
                           True,
                           retain_stats=params.get("gamma_retain_stats"),
                           p_per_sample=params["p_gamma"]))

    tr_transforms.append(MirrorTransform(params.get("mirror_axes")))
    if params.get("mask_was_used_for_normalization") is not None:
        mask_was_used_for_normalization = params.get(
            "mask_was_used_for_normalization")
        tr_transforms.append(
            MaskTransform(mask_was_used_for_normalization,
                          mask_idx_in_seg=0,
                          set_outside_to=0))

    tr_transforms.append(RemoveLabelTransform(-1, 0))

    if params.get("move_last_seg_chanel_to_data"):
        tr_transforms.append(
            MoveSegAsOneHotToData(1, params.get("all_segmentation_labels"),
                                  'seg', 'data'))
        # previously written as `x and not None and x`; the `not None` operand is always True
        if params.get("advanced_pyramid_augmentations"):
            # the one-hot seg channels were appended last, so address them with negative indexes
            n_labels = len(params.get("all_segmentation_labels"))
            tr_transforms.append(
                ApplyRandomBinaryOperatorTransform(
                    channel_idx=list(range(-n_labels, 0)),
                    p_per_sample=0.4,
                    key="data",
                    strel_size=(1, 8)))
            tr_transforms.append(
                RemoveRandomConnectedComponentFromOneHotEncodingTransform(
                    channel_idx=list(range(-n_labels, 0)),
                    key="data",
                    p_per_sample=0.2,
                    fill_with_other_class_p=0.0,
                    dont_do_if_covers_more_than_X_percent=0.15))

    tr_transforms.append(RenameTransform('seg', 'target', True))
    tr_transforms.append(NumpyToTensor(['data', 'target'], 'float'))
    return Compose(tr_transforms)