def get_train_transform(patch_size):
    """
    Build the augmentation pipeline for training data, inspired by:
    https://github.com/MIC-DKFZ/batchgenerators/blob/master/batchgenerators/examples/brats2017/brats2017_dataloader_3D.py
    :param patch_size: shape of network's input
    :return: the transforms below, in order, wrapped in a single ``Compose``
    """

    def symmetric_radians(deg):
        # converts a degree bound into the interval (-bound, +bound) in radians
        bound = deg / 360 * 2 * np.pi
        return (-bound, bound)

    # spatial augmentation: elastic deformation, rotation (only angle_z has a
    # non-zero range) and scaling, each applied with probability 0.2 per sample
    spatial = SpatialTransform_2(
        patch_size,
        (10, 10, 10),
        do_elastic_deform=True,
        deformation_scale=(0, 0.25),
        do_rotation=True,
        angle_z=symmetric_radians(15),
        angle_x=(0, 0),
        angle_y=(0, 0),
        do_scale=True,
        scale=(0.75, 1.25),
        border_mode_data='constant',
        border_cval_data=0,
        border_mode_seg='constant',
        border_cval_seg=0,
        order_seg=1,
        random_crop=False,
        p_el_per_sample=0.2,
        p_rot_per_sample=0.2,
        p_scale_per_sample=0.2,
    )

    mirror = MirrorTransform(axes=(0, 1))

    brightness = BrightnessMultiplicativeTransform((0.7, 1.5),
                                                   per_channel=True,
                                                   p_per_sample=0.2)

    gamma = GammaTransform(gamma_range=(0.2, 1.0),
                           invert_image=False,
                           per_channel=False,
                           p_per_sample=0.2)

    noise = GaussianNoiseTransform(noise_variance=(0, 0.05), p_per_sample=0.2)

    # NOTE(review): p_per_channel=0.0 looks like it would leave every channel
    # unblurred -- confirm this is intended
    blur = GaussianBlurTransform(blur_sigma=(0.2, 1.0),
                                 different_sigma_per_channel=False,
                                 p_per_channel=0.0,
                                 p_per_sample=0.2)

    return Compose([spatial, mirror, brightness, gamma, noise, blur])
def get_train_transform(patch_size, prob=0.1):
    """
    Build a showcase augmentation pipeline (not necessarily the best choice of transforms for BraTS).

    :param patch_size: target patch size handed to the spatial transform
    :param prob: base probability; the spatial transforms use it directly, the intensity transforms use 0.05 + prob
    :return: a ``Compose`` of all transforms, in the order they are created below
    """
    # rotation range of +-15 degrees expressed in radians, shared by all three axes
    max_angle = 15 / 360. * 2 * np.pi
    rotation_range = (-max_angle, max_angle)

    # per-sample probability shared by every intensity transform (0.15 with the default prob)
    p_intensity = 0.05 + prob

    # The spatial transform runs first: it reduces the data to patch_size, which also makes every later transform
    # cheaper. SpatialTransform_2 is used for its newer parameterization of the elastic deformation.
    # Elastic deformation, rotation and scaling each fire with probability `prob` per sample, so with the default
    # prob=0.1 about 1 - (1 - 0.1) ** 3 = 27% of samples get at least one of them.
    spatial = SpatialTransform_2(
        patch_size,
        [i // 2 for i in patch_size],
        do_elastic_deform=True,
        deformation_scale=(0, 0.25),
        do_rotation=True,
        angle_x=rotation_range,
        angle_y=rotation_range,
        angle_z=rotation_range,
        do_scale=True,
        scale=(0.75, 1.25),
        border_mode_data='constant',
        border_cval_data=0,
        border_mode_seg='constant',
        border_cval_seg=0,
        order_seg=1,
        order_data=3,
        random_crop=True,
        p_el_per_sample=prob,
        p_rot_per_sample=prob,
        p_scale_per_sample=prob,
    )

    # mirroring along all three axes
    mirror = MirrorTransform(axes=(0, 1, 2))

    # multiplicative brightness change
    brightness = BrightnessMultiplicativeTransform(
        (0.7, 1.5), per_channel=True, p_per_sample=p_intensity)

    # gamma is a nonlinear transformation of the intensity values
    # (https://en.wikipedia.org/wiki/Gamma_correction); it is added once on the plain image and once with
    # invert_image=True
    gamma = GammaTransform(
        gamma_range=(0.5, 2), invert_image=False, per_channel=True, p_per_sample=p_intensity)
    gamma_inverted = GammaTransform(
        gamma_range=(0.5, 2), invert_image=True, per_channel=True, p_per_sample=p_intensity)

    # additive Gaussian noise
    noise = GaussianNoiseTransform(
        noise_variance=(0, 0.05), p_per_sample=p_intensity)

    # blurring: some BraTS cases have very blurry modalities; simulating more of them is meant to make the model
    # more robust to that
    blur = GaussianBlurTransform(
        blur_sigma=(0.5, 1.5), different_sigma_per_channel=True,
        p_per_channel=prob, p_per_sample=p_intensity)

    return Compose(
        [spatial, mirror, brightness, gamma, gamma_inverted, noise, blur])
Example #3
0
    def _augment_data(self, batch_generator, type=None):
        """
        Wrap ``batch_generator`` in a MultiThreadedAugmenter running the transforms enabled in ``self.Config``.

        :param batch_generator: generator handed to the MultiThreadedAugmenter
        :param type: augmentation transforms are only added when this equals "train"
        :return: the MultiThreadedAugmenter; original note on its output:
                 data: (batch_size, channels, x, y), seg: (batch_size, channels, x, y)
        """
        cfg = self.Config
        augment = cfg.DATA_AUGMENTATION

        # 15 is a bit faster than 8 on cluster
        # (multiprocessing.cpu_count() is not used: on cluster it gives all cores, not only assigned cores)
        num_processes = 15 if augment else 6

        pipeline = []

        if cfg.NORMALIZE_DATA:
            pipeline.append(ZeroMeanUnitVarianceTransform(per_channel=cfg.NORMALIZE_PER_CHANNEL))

        # pick the spatial transform class named in the config; anything unrecognised falls back to SpatialTransform
        spatial_name = cfg.SPATIAL_TRANSFORM
        if spatial_name == "SpatialTransformPeaks":
            spatial_cls = SpatialTransformPeaks
        elif spatial_name == "SpatialTransformCustom":
            spatial_cls = SpatialTransformCustom
        else:
            spatial_cls = SpatialTransform

        if augment and type == "train":
            if cfg.DAUG_SCALE:
                if cfg.INPUT_RESCALING:
                    # fixed rescaling factor derived from the configured resolution string
                    # (presumably something like "1.25mm" -- the last two characters are stripped)
                    source_mm = 2  # original note: "for bb"
                    target_mm = float(cfg.RESOLUTION[:-2])
                    factor = target_mm / source_mm
                    scale = (factor, factor)
                else:
                    scale = (0.9, 1.5)

                # None keeps the dimensions of the data
                patch_size = cfg.INPUT_DIM if cfg.PAD_TO_SQUARE else None

                # patch_center_dist_from_border:
                #   if 144/2=72 -> always exactly centered; otherwise a bit off center
                #   (brain can get off image and will be cut then)
                # e.g. INPUT_DIM (144, 144) -> 62
                center_dist_from_border = int(cfg.INPUT_DIM[0] / 2.) - 10

                # spatial transform automatically crops/pads to correct size
                spatial = spatial_cls(
                    patch_size,
                    patch_center_dist_from_border=center_dist_from_border,
                    do_elastic_deform=cfg.DAUG_ELASTIC_DEFORM,
                    alpha=cfg.DAUG_ALPHA,
                    sigma=cfg.DAUG_SIGMA,
                    do_rotation=cfg.DAUG_ROTATE,
                    angle_x=cfg.DAUG_ROTATE_ANGLE,
                    angle_y=cfg.DAUG_ROTATE_ANGLE,
                    angle_z=cfg.DAUG_ROTATE_ANGLE,
                    do_scale=True,
                    scale=scale,
                    border_mode_data='constant',
                    border_cval_data=0,
                    order_data=3,
                    border_mode_seg='constant',
                    border_cval_seg=0,
                    order_seg=0,
                    random_crop=True,
                    p_el_per_sample=cfg.P_SAMP,
                    p_rot_per_sample=cfg.P_SAMP,
                    p_scale_per_sample=cfg.P_SAMP)
                pipeline.append(spatial)

            if cfg.DAUG_RESAMPLE:
                pipeline.append(
                    SimulateLowResolutionTransform(zoom_range=(0.5, 1), p_per_sample=0.2, per_channel=False))

            if cfg.DAUG_RESAMPLE_LEGACY:
                pipeline.append(ResampleTransformLegacy(zoom_range=(0.5, 1)))

            if cfg.DAUG_GAUSSIAN_BLUR:
                pipeline.append(
                    GaussianBlurTransform(blur_sigma=cfg.DAUG_BLUR_SIGMA,
                                          different_sigma_per_channel=False,
                                          p_per_sample=cfg.P_SAMP))

            if cfg.DAUG_NOISE:
                pipeline.append(
                    GaussianNoiseTransform(noise_variance=cfg.DAUG_NOISE_VARIANCE,
                                           p_per_sample=cfg.P_SAMP))

            if cfg.DAUG_MIRROR:
                pipeline.append(MirrorTransform())

            if cfg.DAUG_FLIP_PEAKS:
                pipeline.append(FlipVectorAxisTransform())

        # always added, with or without augmentation
        pipeline.append(NumpyToTensor(keys=["data", "seg"], cast_to="float"))

        # num_cached_per_queue 1 or 2 does not really make a difference
        return MultiThreadedAugmenter(batch_generator, Compose(pipeline), num_processes=num_processes,
                                      num_cached_per_queue=1, seeds=None, pin_memory=True)
Example #4
0
def get_insaneDA_augmentation(dataloader_train,
                              dataloader_val,
                              patch_size,
                              params=default_3D_augmentation_params,
                              border_val_seg=-1,
                              seeds_train=None,
                              seeds_val=None,
                              order_seg=1,
                              order_data=3,
                              deep_supervision_scales=None,
                              soft_ds=False,
                              classes=None,
                              pin_memory=True):
    """
    Build the multi-threaded training and validation augmenters.

    The training pipeline applies spatial, intensity, mirroring and (optionally) cascade augmentations;
    the validation pipeline only does channel selection, label cleanup, renaming and tensor conversion.

    :param dataloader_train: data loader wrapped by the training MultiThreadedAugmenter
    :param dataloader_val: data loader wrapped by the validation MultiThreadedAugmenter
    :param patch_size: patch size handed to the SpatialTransform
    :param params: dict of augmentation settings; read via ``params.get(...)`` (``p_gamma`` is read via
                   ``params[...]`` when ``do_gamma`` is set). Must not contain the legacy key ``mirror``.
    :param border_val_seg: passed as ``border_cval_seg`` to the SpatialTransform
    :param seeds_train: seeds for the training augmenter
    :param seeds_val: seeds for the validation augmenter
    :param order_seg: passed as ``order_seg`` to the SpatialTransform
    :param order_data: passed as ``order_data`` to the SpatialTransform
    :param deep_supervision_scales: if not None, a segmentation downsampling transform is appended to both pipelines
    :param soft_ds: selects DownsampleSegForDSTransform3 (requires ``classes``) instead of
                    DownsampleSegForDSTransform2
    :param classes: class list handed to DownsampleSegForDSTransform3; must not be None when ``soft_ds`` is set
    :param pin_memory: forwarded to both MultiThreadedAugmenters
    :return: tuple ``(batchgenerator_train, batchgenerator_val)``
    """
    assert params.get(
        'mirror') is None, "old version of params, use new keyword do_mirror"

    tr_transforms = []

    if params.get("selected_data_channels") is not None:
        tr_transforms.append(
            DataChannelSelectionTransform(
                params.get("selected_data_channels")))

    if params.get("selected_seg_channels") is not None:
        tr_transforms.append(
            SegChannelSelectionTransform(params.get("selected_seg_channels")))

    # don't do color augmentations while in 2d mode with 3d data because the color channel is overloaded!!
    if params.get("dummy_2D"):
        ignore_axes = (0, )
        tr_transforms.append(Convert3DTo2DTransform())
    else:
        ignore_axes = None

    tr_transforms.append(
        SpatialTransform(patch_size,
                         patch_center_dist_from_border=None,
                         do_elastic_deform=params.get("do_elastic"),
                         alpha=params.get("elastic_deform_alpha"),
                         sigma=params.get("elastic_deform_sigma"),
                         do_rotation=params.get("do_rotation"),
                         angle_x=params.get("rotation_x"),
                         angle_y=params.get("rotation_y"),
                         angle_z=params.get("rotation_z"),
                         do_scale=params.get("do_scaling"),
                         scale=params.get("scale_range"),
                         border_mode_data=params.get("border_mode_data"),
                         border_cval_data=0,
                         order_data=order_data,
                         border_mode_seg="constant",
                         border_cval_seg=border_val_seg,
                         order_seg=order_seg,
                         random_crop=params.get("random_crop"),
                         p_el_per_sample=params.get("p_eldef"),
                         p_scale_per_sample=params.get("p_scale"),
                         p_rot_per_sample=params.get("p_rot"),
                         independent_scale_for_each_axis=params.get(
                             "independent_scale_factor_for_each_axis")))

    if params.get("dummy_2D"):
        tr_transforms.append(Convert2DTo3DTransform())

    # we need to put the color augmentations after the dummy 2d part (if applicable). Otherwise the overloaded color
    # channel gets in the way
    tr_transforms.append(GaussianNoiseTransform(p_per_sample=0.15))
    tr_transforms.append(
        GaussianBlurTransform((0.5, 1.5),
                              different_sigma_per_channel=True,
                              p_per_sample=0.2,
                              p_per_channel=0.5))
    tr_transforms.append(
        BrightnessMultiplicativeTransform(multiplier_range=(0.70, 1.3),
                                          p_per_sample=0.15))
    tr_transforms.append(
        ContrastAugmentationTransform(contrast_range=(0.65, 1.5),
                                      p_per_sample=0.15))
    tr_transforms.append(
        SimulateLowResolutionTransform(zoom_range=(0.5, 1),
                                       per_channel=True,
                                       p_per_channel=0.5,
                                       order_downsample=0,
                                       order_upsample=3,
                                       p_per_sample=0.25,
                                       ignore_axes=ignore_axes))
    # inverted gamma is always part of the pipeline; the non-inverted one below is gated by "do_gamma"
    tr_transforms.append(
        GammaTransform(params.get("gamma_range"),
                       True,
                       True,
                       retain_stats=params.get("gamma_retain_stats"),
                       p_per_sample=0.15))

    if params.get("do_gamma"):
        tr_transforms.append(
            GammaTransform(params.get("gamma_range"),
                           False,
                           True,
                           retain_stats=params.get("gamma_retain_stats"),
                           p_per_sample=params["p_gamma"]))

    # the "mirror" lookup can only be truthy when the assert above was stripped (python -O)
    if params.get("do_mirror") or params.get("mirror"):
        tr_transforms.append(MirrorTransform(params.get("mirror_axes")))

    if params.get("mask_was_used_for_normalization") is not None:
        mask_was_used_for_normalization = params.get(
            "mask_was_used_for_normalization")
        tr_transforms.append(
            MaskTransform(mask_was_used_for_normalization,
                          mask_idx_in_seg=0,
                          set_outside_to=0))

    tr_transforms.append(RemoveLabelTransform(-1, 0))

    if params.get("move_last_seg_chanel_to_data"):
        tr_transforms.append(
            MoveSegAsOneHotToData(1, params.get("all_segmentation_labels"),
                                  'seg', 'data'))
        if params.get("cascade_do_cascade_augmentations"):
            if params.get("cascade_random_binary_transform_p") > 0:
                tr_transforms.append(
                    ApplyRandomBinaryOperatorTransform(
                        channel_idx=list(
                            range(-len(params.get("all_segmentation_labels")),
                                  0)),
                        p_per_sample=params.get(
                            "cascade_random_binary_transform_p"),
                        key="data",
                        strel_size=params.get(
                            "cascade_random_binary_transform_size")))
            if params.get("cascade_remove_conn_comp_p") > 0:
                # each keyword now receives the param of the matching name; previously the two values
                # were cross-wired (the size threshold was passed as the fill probability and vice versa)
                tr_transforms.append(
                    RemoveRandomConnectedComponentFromOneHotEncodingTransform(
                        channel_idx=list(
                            range(-len(params.get("all_segmentation_labels")),
                                  0)),
                        key="data",
                        p_per_sample=params.get("cascade_remove_conn_comp_p"),
                        fill_with_other_class_p=params.get(
                            "cascade_remove_conn_comp_fill_with_other_class_p"
                        ),
                        dont_do_if_covers_more_than_X_percent=params.get(
                            "cascade_remove_conn_comp_max_size_percent_threshold"
                        )))

    tr_transforms.append(RenameTransform('seg', 'target', True))

    if deep_supervision_scales is not None:
        if soft_ds:
            assert classes is not None
            tr_transforms.append(
                DownsampleSegForDSTransform3(deep_supervision_scales, 'target',
                                             'target', classes))
        else:
            tr_transforms.append(
                DownsampleSegForDSTransform2(deep_supervision_scales,
                                             0,
                                             0,
                                             input_key='target',
                                             output_key='target'))
    tr_transforms.append(NumpyToTensor(['data', 'target'], 'float'))
    tr_transforms = Compose(tr_transforms)

    batchgenerator_train = MultiThreadedAugmenter(
        dataloader_train,
        tr_transforms,
        params.get('num_threads'),
        params.get("num_cached_per_thread"),
        seeds=seeds_train,
        pin_memory=pin_memory)

    # validation: no augmentation, only the transforms needed to produce the same output layout
    val_transforms = []
    val_transforms.append(RemoveLabelTransform(-1, 0))
    if params.get("selected_data_channels") is not None:
        val_transforms.append(
            DataChannelSelectionTransform(
                params.get("selected_data_channels")))
    if params.get("selected_seg_channels") is not None:
        val_transforms.append(
            SegChannelSelectionTransform(params.get("selected_seg_channels")))

    if params.get("move_last_seg_chanel_to_data"):
        val_transforms.append(
            MoveSegAsOneHotToData(1, params.get("all_segmentation_labels"),
                                  'seg', 'data'))

    val_transforms.append(RenameTransform('seg', 'target', True))

    if deep_supervision_scales is not None:
        if soft_ds:
            assert classes is not None
            val_transforms.append(
                DownsampleSegForDSTransform3(deep_supervision_scales, 'target',
                                             'target', classes))
        else:
            val_transforms.append(
                DownsampleSegForDSTransform2(deep_supervision_scales,
                                             0,
                                             0,
                                             input_key='target',
                                             output_key='target'))

    val_transforms.append(NumpyToTensor(['data', 'target'], 'float'))
    val_transforms = Compose(val_transforms)

    # validation uses half the worker threads of training, but at least one
    batchgenerator_val = MultiThreadedAugmenter(
        dataloader_val,
        val_transforms,
        max(params.get('num_threads') // 2, 1),
        params.get("num_cached_per_thread"),
        seeds=seeds_val,
        pin_memory=pin_memory)
    return batchgenerator_train, batchgenerator_val
    def get_train_transforms(self) -> List[AbstractTransform]:
        """
        Assemble the ordered list of training transforms from the trainer's settings.

        Reads ``self.patch_size``, ``self.data_aug_params``, ``self.do_dummy_2D_aug``, ``self.do_mirroring``,
        ``self.mirror_axes``, ``self.use_mask_for_norm``, ``self.num_classes``, ``self.regions`` and
        ``self.deep_supervision_scales``.

        :return: plain list of transforms (not wrapped in a Compose), ending with NumpyToTensor
        """
        # used for transpose and rot90: for each axis count how many axes share its patch size, then keep the
        # axes whose count is the maximum
        same_size_counts = np.array(
            [sum([a == b for b in self.patch_size]) for a in self.patch_size])
        valid_axes = list(
            np.where(same_size_counts == np.max(same_size_counts))[0])

        pipeline = []

        seg_channels = self.data_aug_params['selected_seg_channels']
        if seg_channels is not None:
            pipeline.append(SegChannelSelectionTransform(seg_channels))

        if self.do_dummy_2D_aug:
            pipeline.append(Convert3DTo2DTransform())
            ignore_axes = (0, )
            patch_size_spatial = self.patch_size[1:]
        else:
            ignore_axes = None
            patch_size_spatial = self.patch_size

        spatial_kwargs = dict(
            patch_center_dist_from_border=None,
            do_elastic_deform=False,
            do_rotation=True,
            angle_x=self.data_aug_params["rotation_x"],
            angle_y=self.data_aug_params["rotation_y"],
            angle_z=self.data_aug_params["rotation_z"],
            p_rot_per_axis=0.5,
            do_scale=True,
            scale=self.data_aug_params['scale_range'],
            border_mode_data="constant",
            border_cval_data=0,
            order_data=3,
            border_mode_seg="constant",
            border_cval_seg=-1,
            order_seg=1,
            random_crop=False,
            p_el_per_sample=0.2,
            p_scale_per_sample=0.2,
            p_rot_per_sample=0.4,
            independent_scale_for_each_axis=True,
        )
        pipeline.append(SpatialTransform(patch_size_spatial, **spatial_kwargs))

        if self.do_dummy_2D_aug:
            pipeline.append(Convert2DTo3DTransform())

        # rot90 and transpose are only added when at least two axes have the same patch size
        if np.any(same_size_counts > 1):
            pipeline.extend([
                Rot90Transform((0, 1, 2, 3),
                               axes=valid_axes,
                               data_key='data',
                               label_key='seg',
                               p_per_sample=0.5),
                TransposeAxesTransform(valid_axes,
                                       data_key='data',
                                       label_key='seg',
                                       p_per_sample=0.5),
            ])

        smoothing = OneOfTransform([
            MedianFilterTransform((2, 8),
                                  same_for_each_channel=False,
                                  p_per_sample=0.2,
                                  p_per_channel=0.5),
            GaussianBlurTransform((0.3, 1.5),
                                  different_sigma_per_channel=True,
                                  p_per_sample=0.2,
                                  p_per_channel=0.5)
        ])
        noise = GaussianNoiseTransform(p_per_sample=0.1)
        brightness = BrightnessTransform(0,
                                         0.5,
                                         per_channel=True,
                                         p_per_sample=0.1,
                                         p_per_channel=0.5)
        # the two contrast variants differ only in preserve_range
        contrast = OneOfTransform([
            ContrastAugmentationTransform(contrast_range=(0.5, 2),
                                          preserve_range=True,
                                          per_channel=True,
                                          data_key='data',
                                          p_per_sample=0.2,
                                          p_per_channel=0.5),
            ContrastAugmentationTransform(contrast_range=(0.5, 2),
                                          preserve_range=False,
                                          per_channel=True,
                                          data_key='data',
                                          p_per_sample=0.2,
                                          p_per_channel=0.5),
        ])
        low_res = SimulateLowResolutionTransform(zoom_range=(0.25, 1),
                                                 per_channel=True,
                                                 p_per_channel=0.5,
                                                 order_downsample=0,
                                                 order_upsample=3,
                                                 p_per_sample=0.15,
                                                 ignore_axes=ignore_axes)
        # NOTE(review): both gamma transforms are configured identically (invert_image=True) -- confirm whether
        # one of them was meant to be the non-inverted variant
        gamma_first = GammaTransform((0.7, 1.5),
                                     invert_image=True,
                                     per_channel=True,
                                     retain_stats=True,
                                     p_per_sample=0.1)
        gamma_second = GammaTransform((0.7, 1.5),
                                      invert_image=True,
                                      per_channel=True,
                                      retain_stats=True,
                                      p_per_sample=0.1)
        pipeline.extend([
            smoothing, noise, brightness, contrast, low_res, gamma_first,
            gamma_second
        ])

        if self.do_mirroring:
            pipeline.append(MirrorTransform(self.mirror_axes))

        def log_uniform_scale(x, y):
            # draws exp(u) with u uniform between log(x[y] // 6) and log(x[y])
            return np.exp(
                np.random.uniform(np.log(x[y] // 6), np.log(x[y])))

        def signed_strength(x, y):
            # with probability 0.5 a value from [-5, -1), otherwise one from [1, 5)
            if np.random.uniform() < 0.5:
                return np.random.uniform(-5, -1)
            return np.random.uniform(1, 5)

        def local_gamma_value():
            # with probability 0.5 a value from [0.01, 0.8), otherwise one from [1.5, 4)
            if np.random.uniform() < 0.5:
                return np.random.uniform(0.01, 0.8)
            return np.random.uniform(1.5, 4)

        # rectangle side lengths per axis range from max(1, p // 10) to p // 3
        rectangle_sizes = [[max(1, p // 10), p // 3] for p in self.patch_size]
        pipeline.extend([
            BlankRectangleTransform(rectangle_sizes,
                                    rectangle_value=np.mean,
                                    num_rectangles=(1, 5),
                                    force_square=False,
                                    p_per_sample=0.4,
                                    p_per_channel=0.5),
            BrightnessGradientAdditiveTransform(log_uniform_scale,
                                                (-0.5, 1.5),
                                                max_strength=signed_strength,
                                                mean_centered=False,
                                                same_for_all_channels=False,
                                                p_per_sample=0.3,
                                                p_per_channel=0.5),
            LocalGammaTransform(log_uniform_scale,
                                (-0.5, 1.5),
                                local_gamma_value,
                                same_for_all_channels=False,
                                p_per_sample=0.3,
                                p_per_channel=0.5),
            SharpeningTransform(strength=(0.1, 1),
                                same_for_each_channel=False,
                                p_per_sample=0.2,
                                p_per_channel=0.5),
        ])

        if any(self.use_mask_for_norm.values()):
            pipeline.append(
                MaskTransform(self.use_mask_for_norm,
                              mask_idx_in_seg=0,
                              set_outside_to=0))

        pipeline.append(RemoveLabelTransform(-1, 0))

        if self.data_aug_params["move_last_seg_chanel_to_data"]:
            all_class_labels = np.arange(1, self.num_classes)
            num_labels = len(all_class_labels)
            pipeline.append(
                MoveSegAsOneHotToData(1, all_class_labels, 'seg', 'data'))
            if self.data_aug_params["cascade_do_cascade_augmentations"]:
                # the cascade augmentations address the last num_labels channels of "data"
                pipeline.append(
                    ApplyRandomBinaryOperatorTransform(
                        channel_idx=list(range(-num_labels, 0)),
                        p_per_sample=0.4,
                        key="data",
                        strel_size=(1, 8),
                        p_per_label=1))
                pipeline.append(
                    RemoveRandomConnectedComponentFromOneHotEncodingTransform(
                        channel_idx=list(range(-num_labels, 0)),
                        key="data",
                        p_per_sample=0.2,
                        fill_with_other_class_p=0.15,
                        dont_do_if_covers_more_than_X_percent=0))

        pipeline.append(RenameTransform('seg', 'target', True))

        if self.regions is not None:
            pipeline.append(
                ConvertSegmentationToRegionsTransform(self.regions, 'target',
                                                      'target'))

        if self.deep_supervision_scales is not None:
            pipeline.append(
                DownsampleSegForDSTransform2(self.deep_supervision_scales,
                                             0,
                                             input_key='target',
                                             output_key='target'))

        pipeline.append(NumpyToTensor(['data', 'target'], 'float'))
        return pipeline
def get_insaneDA_augmentation(dataloader_train,
                              dataloader_val,
                              patch_size,
                              params=default_3D_augmentation_params,
                              border_val_seg=-1,
                              seeds_train=None,
                              seeds_val=None,
                              order_seg=1,
                              order_data=3,
                              deep_supervision_scales=None,
                              soft_ds=False,
                              classes=None,
                              pin_memory=True,
                              regions=None):
    """
    Build the training and validation augmentation pipelines and wrap both
    dataloaders in a MultiThreadedAugmenter.

    The training pipeline applies, in this order: optional channel selection,
    a SpatialTransform (bracketed by 3D->2D / 2D->3D conversion when
    ``params['dummy_2D']`` is truthy), Gaussian noise, Gaussian blur,
    multiplicative brightness, contrast, simulated low resolution, inverted
    gamma, optional additive brightness, optional gamma, optional mirroring,
    optional masking, removal of label -1, optional cascade transforms, the
    'seg' -> 'target' rename, optional region conversion, optional
    deep-supervision downsampling and finally NumpyToTensor.
    The validation pipeline only does label removal, channel selection, the
    optional one-hot move, renaming, region conversion, deep-supervision
    downsampling and tensor conversion.

    :param dataloader_train: loader wrapped by the training augmenter
    :param dataloader_val: loader wrapped by the validation augmenter
    :param patch_size: forwarded as first argument to SpatialTransform
    :param params: mapping of augmentation settings, read through ``.get()``
        (``p_gamma`` is read with ``[]`` when ``do_gamma`` is truthy)
    :param border_val_seg: ``border_cval_seg`` of the SpatialTransform
    :param seeds_train: forwarded as ``seeds`` to the training augmenter
    :param seeds_val: forwarded as ``seeds`` to the validation augmenter
    :param order_seg: ``order_seg`` of the SpatialTransform
    :param order_data: ``order_data`` of the SpatialTransform
    :param deep_supervision_scales: if not None, a DownsampleSegForDS
        transform is appended to both pipelines
    :param soft_ds: selects DownsampleSegForDSTransform3 (requires
        ``classes``) instead of DownsampleSegForDSTransform2
    :param classes: passed to DownsampleSegForDSTransform3 when ``soft_ds``
    :param pin_memory: forwarded to both augmenters
    :param regions: if not None, ConvertSegmentationToRegionsTransform is
        appended to both pipelines
    :return: (batchgenerator_train, batchgenerator_val)
    """
    assert params.get(
        'mirror') is None, "old version of params, use new keyword do_mirror"

    tr_transforms = []

    # Example argument values, kept from the original author's notes:
    # 'patch_size': array([288, 320]),
    # 'border_val_seg': -1,
    # 'seeds_train': None,
    # 'seeds_val': None,
    # 'order_seg': 1,
    # 'order_data': 3,
    # 'deep_supervision_scales': [[1, 1, 1],
    #                             [1.0, 0.5, 0.5],
    #                             [1.0, 0.25, 0.25],
    #                             [0.5, 0.125, 0.125],
    #                             [0.5, 0.0625, 0.0625]],
    # 'soft_ds': False,
    # 'classes': None,
    # 'pin_memory': True,
    # 'regions': None
    # params
    # {'selected_data_channels': None,
    #  'selected_seg_channels': [0],
    #  'do_elastic': True,
    #  'elastic_deform_alpha': (0.0, 300.0),
    #  'elastic_deform_sigma': (9.0, 15.0),
    #  'p_eldef': 0.1,
    #  'do_scaling': True,
    #  'scale_range': (0.65, 1.6),
    #  'independent_scale_factor_for_each_axis': True,
    #  'p_independent_scale_per_axis': 0.3,
    #  'p_scale': 0.3,
    #  'do_rotation': True,
    #  'rotation_x': (-3.141592653589793, 3.141592653589793),
    #  'rotation_y': (-0.5235987755982988, 0.5235987755982988),
    #  'rotation_z': (-0.5235987755982988, 0.5235987755982988),
    #  'rotation_p_per_axis': 1,
    #  'p_rot': 0.7,
    #  'random_crop': False,
    #  'random_crop_dist_to_border': None,
    #  'do_gamma': True,
    #  'gamma_retain_stats': True,
    #  'gamma_range': (0.5, 1.6),
    #  'p_gamma': 0.3,
    #  'do_mirror': True,
    #  'mirror_axes': (0, 1, 2),
    #  'dummy_2D': True,
    #  'mask_was_used_for_normalization': OrderedDict([(0, False)]),
    #  'border_mode_data': 'constant',
    #  'all_segmentation_labels': None,
    #  'move_last_seg_chanel_to_data': False,
    #  'cascade_do_cascade_augmentations': False,
    #  'cascade_random_binary_transform_p': 0.4,
    #  'cascade_random_binary_transform_p_per_label': 1,
    #  'cascade_random_binary_transform_size': (1, 8),
    #  'cascade_remove_conn_comp_p': 0.2,
    #  'cascade_remove_conn_comp_max_size_percent_threshold': 0.15,
    #  'cascade_remove_conn_comp_fill_with_other_class_p': 0.0,
    #  'do_additive_brightness': True,
    #  'additive_brightness_p_per_sample': 0.3,
    #  'additive_brightness_p_per_channel': 1,
    #  'additive_brightness_mu': 0,
    #  'additive_brightness_sigma': 0.2,
    #  'num_threads': 12,
    #  'num_cached_per_thread': 1,
    #  'patch_size_for_spatialtransform': array([288, 320])}

    # ---------------------------------------------------------------- train
    if params.get("selected_data_channels") is not None:
        tr_transforms.append(
            DataChannelSelectionTransform(
                params.get("selected_data_channels")))

    if params.get("selected_seg_channels") is not None:
        tr_transforms.append(
            SegChannelSelectionTransform(params.get("selected_seg_channels")))

    # don't do color augmentations while in 2d mode with 3d data because the
    # color channel is overloaded!!
    dummy_2d = bool(params.get("dummy_2D"))
    if dummy_2d:
        ignore_axes = (0, )
        tr_transforms.append(Convert3DTo2DTransform())
    else:
        ignore_axes = None

    tr_transforms.append(
        SpatialTransform(patch_size,
                         patch_center_dist_from_border=None,
                         do_elastic_deform=params.get("do_elastic"),
                         alpha=params.get("elastic_deform_alpha"),
                         sigma=params.get("elastic_deform_sigma"),
                         do_rotation=params.get("do_rotation"),
                         angle_x=params.get("rotation_x"),
                         angle_y=params.get("rotation_y"),
                         angle_z=params.get("rotation_z"),
                         do_scale=params.get("do_scaling"),
                         scale=params.get("scale_range"),
                         border_mode_data=params.get("border_mode_data"),
                         border_cval_data=0,
                         order_data=order_data,
                         border_mode_seg="constant",
                         border_cval_seg=border_val_seg,
                         order_seg=order_seg,
                         random_crop=params.get("random_crop"),
                         p_el_per_sample=params.get("p_eldef"),
                         p_scale_per_sample=params.get("p_scale"),
                         p_rot_per_sample=params.get("p_rot"),
                         independent_scale_for_each_axis=params.get(
                             "independent_scale_factor_for_each_axis"),
                         p_independent_scale_per_axis=params.get(
                             "p_independent_scale_per_axis")))

    if dummy_2d:
        tr_transforms.append(Convert2DTo3DTransform())

    # we need to put the color augmentations after the dummy 2d part (if
    # applicable). Otherwise the overloaded color channel gets in the way
    tr_transforms.append(GaussianNoiseTransform(p_per_sample=0.15))
    tr_transforms.append(
        GaussianBlurTransform((0.5, 1.5),
                              different_sigma_per_channel=True,
                              p_per_sample=0.2,
                              p_per_channel=0.5))
    tr_transforms.append(
        BrightnessMultiplicativeTransform(multiplier_range=(0.70, 1.3),
                                          p_per_sample=0.15))
    tr_transforms.append(
        ContrastAugmentationTransform(contrast_range=(0.65, 1.5),
                                      p_per_sample=0.15))
    tr_transforms.append(
        SimulateLowResolutionTransform(zoom_range=(0.5, 1),
                                       per_channel=True,
                                       p_per_channel=0.5,
                                       order_downsample=0,
                                       order_upsample=3,
                                       p_per_sample=0.25,
                                       ignore_axes=ignore_axes))
    tr_transforms.append(
        GammaTransform(params.get("gamma_range"),
                       True,
                       True,
                       retain_stats=params.get("gamma_retain_stats"),
                       p_per_sample=0.15))  # inverted gamma

    if params.get("do_additive_brightness"):
        tr_transforms.append(
            BrightnessTransform(
                params.get("additive_brightness_mu"),
                params.get("additive_brightness_sigma"),
                True,
                p_per_sample=params.get("additive_brightness_p_per_sample"),
                p_per_channel=params.get("additive_brightness_p_per_channel")))

    if params.get("do_gamma"):
        tr_transforms.append(
            GammaTransform(params.get("gamma_range"),
                           False,
                           True,
                           retain_stats=params.get("gamma_retain_stats"),
                           p_per_sample=params["p_gamma"]))

    # 'mirror' is guaranteed to be None by the assert above (unless asserts
    # are disabled), so in practice only 'do_mirror' decides here.
    if params.get("do_mirror") or params.get("mirror"):
        tr_transforms.append(MirrorTransform(params.get("mirror_axes")))

    if params.get("mask_was_used_for_normalization") is not None:
        mask_was_used_for_normalization = params.get(
            "mask_was_used_for_normalization")
        tr_transforms.append(
            MaskTransform(mask_was_used_for_normalization,
                          mask_idx_in_seg=0,
                          set_outside_to=0))

    tr_transforms.append(RemoveLabelTransform(-1, 0))

    if params.get("move_last_seg_chanel_to_data"):
        tr_transforms.append(
            MoveSegAsOneHotToData(1, params.get("all_segmentation_labels"),
                                  'seg', 'data'))
        if params.get("cascade_do_cascade_augmentations"):
            if params.get("cascade_random_binary_transform_p") > 0:
                tr_transforms.append(
                    ApplyRandomBinaryOperatorTransform(
                        channel_idx=list(
                            range(-len(params.get("all_segmentation_labels")),
                                  0)),
                        p_per_sample=params.get(
                            "cascade_random_binary_transform_p"),
                        key="data",
                        strel_size=params.get(
                            "cascade_random_binary_transform_size")))
            if params.get("cascade_remove_conn_comp_p") > 0:
                # NOTE: the two keyword arguments below used to receive each
                # other's parameter (the size threshold was passed as the
                # fill probability and vice versa). Each keyword now gets the
                # parameter whose name matches it.
                tr_transforms.append(
                    RemoveRandomConnectedComponentFromOneHotEncodingTransform(
                        channel_idx=list(
                            range(-len(params.get("all_segmentation_labels")),
                                  0)),
                        key="data",
                        p_per_sample=params.get("cascade_remove_conn_comp_p"),
                        fill_with_other_class_p=params.get(
                            "cascade_remove_conn_comp_fill_with_other_class_p"
                        ),
                        dont_do_if_covers_more_than_X_percent=params.get(
                            "cascade_remove_conn_comp_max_size_percent_threshold"
                        )))

    tr_transforms.append(RenameTransform('seg', 'target', True))

    if regions is not None:
        tr_transforms.append(
            ConvertSegmentationToRegionsTransform(regions, 'target', 'target'))

    if deep_supervision_scales is not None:
        if soft_ds:
            assert classes is not None
            tr_transforms.append(
                DownsampleSegForDSTransform3(deep_supervision_scales, 'target',
                                             'target', classes))
        else:
            tr_transforms.append(
                DownsampleSegForDSTransform2(deep_supervision_scales,
                                             0,
                                             0,
                                             input_key='target',
                                             output_key='target'))

    tr_transforms.append(NumpyToTensor(['data', 'target'], 'float'))
    tr_transforms = Compose(tr_transforms)

    batchgenerator_train = MultiThreadedAugmenter(
        dataloader_train,
        tr_transforms,
        params.get('num_threads'),
        params.get("num_cached_per_thread"),
        seeds=seeds_train,
        pin_memory=pin_memory)

    # ----------------------------------------------------------- validation
    # No random augmentation here: only the bookkeeping transforms that the
    # training pipeline also applies.
    val_transforms = [RemoveLabelTransform(-1, 0)]

    if params.get("selected_data_channels") is not None:
        val_transforms.append(
            DataChannelSelectionTransform(
                params.get("selected_data_channels")))

    if params.get("selected_seg_channels") is not None:
        val_transforms.append(
            SegChannelSelectionTransform(params.get("selected_seg_channels")))

    if params.get("move_last_seg_chanel_to_data"):
        val_transforms.append(
            MoveSegAsOneHotToData(1, params.get("all_segmentation_labels"),
                                  'seg', 'data'))

    val_transforms.append(RenameTransform('seg', 'target', True))

    if regions is not None:
        val_transforms.append(
            ConvertSegmentationToRegionsTransform(regions, 'target', 'target'))

    if deep_supervision_scales is not None:
        if soft_ds:
            assert classes is not None
            val_transforms.append(
                DownsampleSegForDSTransform3(deep_supervision_scales, 'target',
                                             'target', classes))
        else:
            val_transforms.append(
                DownsampleSegForDSTransform2(deep_supervision_scales,
                                             0,
                                             0,
                                             input_key='target',
                                             output_key='target'))

    val_transforms.append(NumpyToTensor(['data', 'target'], 'float'))
    val_transforms = Compose(val_transforms)

    # Validation uses half as many workers as training, but at least one.
    batchgenerator_val = MultiThreadedAugmenter(
        dataloader_val,
        val_transforms,
        max(params.get('num_threads') // 2, 1),
        params.get("num_cached_per_thread"),
        seeds=seeds_val,
        pin_memory=pin_memory)
    return batchgenerator_train, batchgenerator_val
# Example #7
# 0
def _build_moreDA_val_transforms(params, deep_supervision_scales, soft_ds,
                                 classes, extra_label_keys):
    """
    Compose the augmentation-free pipeline used by get_moreDA_augmentation.

    Steps: remove label -1, optional data/seg channel selection, rename
    'data' -> 'image' and 'seg' -> 'gt', optional deep-supervision
    downsampling of 'gt', and NumpyToTensor on 'image', 'gt' and any
    ``extra_label_keys``.

    :return: a Compose object wrapping the transforms listed above
    """
    to_tensor_keys = ['image', 'gt']
    if extra_label_keys is not None:
        to_tensor_keys = to_tensor_keys + extra_label_keys

    val_transforms = [
        RemoveLabelTransform(-1, 0, extra_label_keys=extra_label_keys)
    ]
    if params.get("selected_data_channels") is not None:
        val_transforms.append(
            DataChannelSelectionTransform(
                params.get("selected_data_channels")))
    if params.get("selected_seg_channels") is not None:
        val_transforms.append(
            SegChannelSelectionTransform(
                params.get("selected_seg_channels")))
    val_transforms.append(RenameTransform('data', 'image', True))
    val_transforms.append(RenameTransform('seg', 'gt', True))

    if deep_supervision_scales is not None:
        if soft_ds:
            assert classes is not None
            val_transforms.append(
                DownsampleSegForDSTransform3(deep_supervision_scales, 'gt',
                                             'gt', classes))
        else:
            val_transforms.append(
                DownsampleSegForDSTransform2(
                    deep_supervision_scales,
                    0,
                    0,
                    input_key='gt',
                    output_key='gt',
                    extra_label_keys=extra_label_keys))

    val_transforms.append(NumpyToTensor(to_tensor_keys, 'float'))
    return Compose(val_transforms)


def get_moreDA_augmentation(
    dataloader_train,
    dataloader_val,
    patch_size,
    params=default_3D_augmentation_params,
    border_val_seg=-1,
    seeds_train=None,
    seeds_val=None,
    order_seg=1,
    order_data=3,
    deep_supervision_scales=None,
    soft_ds=False,
    classes=None,
    pin_memory=True,
    anisotropy=False,
    extra_label_keys=None,
    val_mode=False,
    use_conf=False,
):
    """
    Wrap the dataloaders in batch generators that apply the transforms.

    With ``val_mode=False`` the training loader gets the full augmentation
    pipeline inside a MultiThreadedAugmenter and the validation loader gets
    the augmentation-free pipeline inside a SingleThreadedAugmenter.
    With ``val_mode=True`` both loaders get the augmentation-free pipeline
    inside SingleThreadedAugmenters; the training generator is None when
    ``dataloader_train`` is None.

    Both pipelines rename 'data' -> 'image' and 'seg' -> 'gt', so batches
    carry the keys 'image' and 'gt' (plus ``extra_label_keys``, if given).

    :param patch_size: forwarded to SpatialTransform; its first entry is
        dropped when ``anisotropy`` or ``params['dummy_2D']`` is truthy
    :param params: mapping of augmentation settings, read through ``.get()``.
        It is never modified: the dummy-2D overrides are applied to a copy.
    :param seeds_train: a single seed; replicated once per training worker
    :param seeds_val: currently unused, the validation generator is a
        SingleThreadedAugmenter which is built without seeds
    :param pin_memory: forwarded to the training MultiThreadedAugmenter only
    :param anisotropy: enables the dummy-2D spatial augmentation
    :param extra_label_keys: optional list of additional keys forwarded to
        the transforms that accept ``extra_label_keys`` and to NumpyToTensor
    :param val_mode: build only augmentation-free generators
    :param use_conf: force one worker and one cached batch per worker for
        the training generator
    :return: (batchgenerator_train, batchgenerator_val)
    """
    if val_mode:
        val_transforms = _build_moreDA_val_transforms(
            params, deep_supervision_scales, soft_ds, classes,
            extra_label_keys)
        batchgenerator_val = SingleThreadedAugmenter(dataloader_val,
                                                     val_transforms)
        if dataloader_train is not None:
            batchgenerator_train = SingleThreadedAugmenter(
                dataloader_train, val_transforms)
        else:
            batchgenerator_train = None
        return batchgenerator_train, batchgenerator_val

    assert params.get(
        'mirror') is None, "old version of params, use new keyword do_mirror"

    tr_transforms = []
    if params.get("selected_data_channels") is not None:
        tr_transforms.append(
            DataChannelSelectionTransform(
                params.get("selected_data_channels")))
    if params.get("selected_seg_channels") is not None:
        tr_transforms.append(
            SegChannelSelectionTransform(
                params.get("selected_seg_channels")))

    # anisotropic setting
    dummy_2d = bool(anisotropy or params.get("dummy_2D"))
    if dummy_2d:
        ignore_axes = (0, )
        tr_transforms.append(
            Convert3DTo2DTransform(extra_label_keys=extra_label_keys))
        patch_size = patch_size[1:]  # 2D patch size

        print('Using dummy2d data augmentation')
        # Work on a shallow copy: writing the overrides into the caller's
        # mapping would also rewrite the shared module-level default and
        # leak the 2D settings into every later call.
        params = dict(params)
        params["elastic_deform_alpha"] = (0., 200.)
        params["elastic_deform_sigma"] = (9., 13.)
        params["rotation_x"] = (-180. / 360 * 2. * np.pi,
                                180. / 360 * 2. * np.pi)
        params["rotation_y"] = (-0. / 360 * 2. * np.pi,
                                0. / 360 * 2. * np.pi)
        params["rotation_z"] = (-0. / 360 * 2. * np.pi,
                                0. / 360 * 2. * np.pi)
    else:
        ignore_axes = None

    # 1. Spatial Transform: rotation, scaling
    tr_transforms.append(
        SpatialTransform(patch_size,
                         patch_center_dist_from_border=None,
                         do_elastic_deform=params.get("do_elastic"),
                         alpha=params.get("elastic_deform_alpha"),
                         sigma=params.get("elastic_deform_sigma"),
                         do_rotation=params.get("do_rotation"),
                         angle_x=params.get("rotation_x"),
                         angle_y=params.get("rotation_y"),
                         angle_z=params.get("rotation_z"),
                         p_rot_per_axis=params.get("rotation_p_per_axis"),
                         do_scale=params.get("do_scaling"),
                         scale=params.get("scale_range"),
                         border_mode_data=params.get("border_mode_data"),
                         border_cval_data=0,
                         order_data=order_data,
                         border_mode_seg="constant",
                         border_cval_seg=border_val_seg,
                         order_seg=order_seg,
                         random_crop=params.get("random_crop"),
                         p_el_per_sample=params.get("p_eldef"),
                         p_scale_per_sample=params.get("p_scale"),
                         p_rot_per_sample=params.get("p_rot"),
                         independent_scale_for_each_axis=params.get(
                             "independent_scale_factor_for_each_axis"),
                         extra_label_keys=extra_label_keys))

    if dummy_2d:
        tr_transforms.append(
            Convert2DTo3DTransform(extra_label_keys=extra_label_keys))

    # 2. Noise Augmentation: gaussian noise, gaussian blur
    tr_transforms.append(GaussianNoiseTransform(p_per_sample=0.1))
    tr_transforms.append(
        GaussianBlurTransform((0.5, 1.),
                              different_sigma_per_channel=True,
                              p_per_sample=0.2,
                              p_per_channel=0.5))

    # 3. Color Augmentation: brightness, contrast, low resolution, gamma
    tr_transforms.append(
        BrightnessMultiplicativeTransform(multiplier_range=(0.75, 1.25),
                                          p_per_sample=0.15))
    if params.get("do_additive_brightness"):
        tr_transforms.append(
            BrightnessTransform(params.get("additive_brightness_mu"),
                                params.get("additive_brightness_sigma"),
                                True,
                                p_per_sample=params.get(
                                    "additive_brightness_p_per_sample"),
                                p_per_channel=params.get(
                                    "additive_brightness_p_per_channel")))
    tr_transforms.append(ContrastAugmentationTransform(p_per_sample=0.15))
    tr_transforms.append(
        SimulateLowResolutionTransform(zoom_range=(0.5, 1),
                                       per_channel=True,
                                       p_per_channel=0.5,
                                       order_downsample=0,
                                       order_upsample=3,
                                       p_per_sample=0.25,
                                       ignore_axes=ignore_axes))
    tr_transforms.append(
        GammaTransform(params.get("gamma_range"),
                       True,
                       True,
                       retain_stats=params.get("gamma_retain_stats"),
                       p_per_sample=0.1))  # inverted gamma
    if params.get("do_gamma"):
        tr_transforms.append(
            GammaTransform(params.get("gamma_range"),
                           False,
                           True,
                           retain_stats=params.get("gamma_retain_stats"),
                           p_per_sample=params["p_gamma"]))

    # 4. Mirror Transform
    if params.get("do_mirror") or params.get("mirror"):
        tr_transforms.append(
            MirrorTransform(params.get("mirror_axes"),
                            extra_label_keys=extra_label_keys))

    # The MaskTransform step ('mask_was_used_for_normalization') is not
    # applied in this pipeline.

    tr_transforms.append(
        RemoveLabelTransform(-1, 0, extra_label_keys=extra_label_keys))
    tr_transforms.append(RenameTransform('data', 'image', True))
    tr_transforms.append(RenameTransform('seg', 'gt', True))

    if deep_supervision_scales is not None:
        if soft_ds:
            assert classes is not None
            tr_transforms.append(
                DownsampleSegForDSTransform3(deep_supervision_scales, 'gt',
                                             'gt', classes))
        else:
            tr_transforms.append(
                DownsampleSegForDSTransform2(
                    deep_supervision_scales,
                    0,
                    0,
                    input_key='gt',
                    output_key='gt',
                    extra_label_keys=extra_label_keys))

    to_tensor_keys = ['image', 'gt']
    if extra_label_keys is not None:
        to_tensor_keys = to_tensor_keys + extra_label_keys
    tr_transforms.append(NumpyToTensor(to_tensor_keys, 'float'))
    tr_transforms = Compose(tr_transforms)

    if use_conf:
        num_threads = 1
        num_cached_per_thread = 1
    else:
        num_threads = params.get('num_threads')
        num_cached_per_thread = params.get("num_cached_per_thread")

    # One seed per worker actually started. This used to be sized by
    # params['num_threads'] even when use_conf forced a single worker.
    if seeds_train is not None:
        seeds_train = [seeds_train] * num_threads

    batchgenerator_train = MultiThreadedAugmenter(dataloader_train,
                                                  tr_transforms,
                                                  num_threads,
                                                  num_cached_per_thread,
                                                  seeds=seeds_train,
                                                  pin_memory=pin_memory)

    # Validation runs single-threaded; seeds_val and pin_memory are not
    # applied to it.
    val_transforms = _build_moreDA_val_transforms(params,
                                                  deep_supervision_scales,
                                                  soft_ds, classes,
                                                  extra_label_keys)
    batchgenerator_val = SingleThreadedAugmenter(dataloader_val,
                                                 val_transforms)

    return batchgenerator_train, batchgenerator_val
# Demo script: build one instance of each augmentation, collect them in
# `my_transforms`, compose them and run them through a MultiThreadedAugmenter.
# `img`, `my_transforms` and `batchgen` are defined outside this excerpt.
spatial_transform = SpatialTransform(
    img.shape,
    np.array(img.shape) // 2,
    do_elastic_deform=False,
    do_rotation=True,
    angle_z=(0, 2 * np.pi),  # rotation
    do_scale=True,
    scale=(0.3, 3.),  # scaling
    border_mode_data='constant',
    border_cval_data=0,
    order_data=1,
    random_crop=False)
my_transforms.append(spatial_transform)
GaussianNoise = GaussianNoiseTransform()  # Gaussian noise
my_transforms.append(GaussianNoise)
GaussianBlur = GaussianBlurTransform()  # Gaussian blur
my_transforms.append(GaussianBlur)
Brightness = BrightnessTransform(0, 0.2)  # brightness
my_transforms.append(Brightness)
# NOTE(review): despite its name, `brightness_transform` holds the contrast
# augmentation, not a brightness one.
brightness_transform = ContrastAugmentationTransform(
    (0.3, 3.), preserve_range=True)  # contrast
my_transforms.append(brightness_transform)
SimulateLowResolution = SimulateLowResolutionTransform()  # low resolution
my_transforms.append(SimulateLowResolution)
Gamma = GammaTransform()  # gamma augmentation
my_transforms.append(Gamma)
mirror_transform = MirrorTransform(axes=(0, 1))  # mirroring
my_transforms.append(mirror_transform)
all_transforms = Compose(my_transforms)
# The positional 1 and 2 are presumably the worker count and the number of
# cached batches per worker — TODO confirm against MultiThreadedAugmenter.
multithreaded_generator = MultiThreadedAugmenter(batchgen, all_transforms, 1,
                                                 2)