def get_train_transform(patch_size):
    """Build the training augmentation pipeline for the given patch size.

    Combines a mild spatial augmentation (elastic deformation, +/-5 degree
    rotations, scaling) with mirroring along all axes, multiplicative
    brightness, inverted gamma correction and Gaussian noise.

    :param patch_size: spatial shape of the network input patch
    :return: a Compose object chaining all training transforms
    """
    # +/-5 degrees expressed in radians, shared by all three rotation axes
    angle = (-5 / 360. * 2 * np.pi, 5 / 360. * 2 * np.pi)

    spatial = SpatialTransform_2(
        patch_size, [i // 2 for i in patch_size],
        do_elastic_deform=True,
        deformation_scale=(0, 0.05),
        do_rotation=True,
        angle_x=angle,
        angle_y=angle,
        angle_z=angle,
        do_scale=True,
        scale=(0.75, 1.25),
        border_mode_data='constant',
        border_cval_data=-2.34,
        border_mode_seg='constant',
        border_cval_seg=0)

    pipeline = [
        spatial,
        MirrorTransform(axes=(0, 1, 2)),
        BrightnessMultiplicativeTransform((0.7, 1.5),
                                          per_channel=True,
                                          p_per_sample=0.15),
        GammaTransform(gamma_range=(0.5, 2),
                       invert_image=True,
                       per_channel=True,
                       p_per_sample=0.15),
        GaussianNoiseTransform(noise_variance=(0, 0.15), p_per_sample=0.15),
    ]
    return Compose(pipeline)
# Example #2 (score: 0) — scraped-snippet separator
def get_train_transform(patch_size):
    """Training augmentation: rotation-only spatial transform plus random mirroring.

    Elastic deformation and scaling are configured but switched off
    (do_elastic_deform / do_scale are False). Passing None as the target
    patch size to SpatialTransform_2 leaves the spatial dimensions of the
    incoming data unchanged.

    :param patch_size: used only to derive patch_center_dist_from_border
    :return: composed transform pipeline
    """
    # +/-15 degrees in radians, identical for all three rotation axes
    rot_range = (-15 / 360. * 2 * np.pi, 15 / 360. * 2 * np.pi)

    spatial = SpatialTransform_2(
        None, [i // 2 for i in patch_size],
        do_elastic_deform=False,
        deformation_scale=(0, 0.25),
        do_rotation=True,
        angle_x=rot_range,
        angle_y=rot_range,
        angle_z=rot_range,
        do_scale=False,
        scale=(0.8, 1.2),
        border_mode_data='constant',
        border_cval_data=0,
        border_mode_seg='constant',
        border_cval_seg=0,
        order_seg=1,
        order_data=3,
        random_crop=False,
        p_el_per_sample=0.3,
        p_rot_per_sample=0.3,
        p_scale_per_sample=0.3)

    # Mirror along all three axes, applied to 30% of batches.
    mirror = RndTransform(MirrorTransform(axes=(0, 1, 2)), prob=0.3)

    return Compose(transforms=[spatial, mirror])
    def test_random_distributions_2D(self):
        """Check that all 4 possible 2D mirrorings occur with roughly equal frequency."""
        loader = BasicDataLoader((self.x_2D, self.y_2D), self.batch_size,
                                 number_of_threads_in_multithreaded=None)
        augmenter = SingleThreadedAugmenter(loader, MirrorTransform((0, 1)))

        # One counter per mirroring variant, in the same order as `variants`:
        # left-right, up-down, both, untouched.
        counts = np.zeros(shape=(4,))
        variants = [self.cam_left, self.cam_updown, self.cam_updown_left, self.cam]

        for _ in range(self.num_batches):
            batch = next(augmenter)

            for ix in range(self.batch_size):
                sample = batch['data'][ix, :, :, :]
                # First exact match wins, mirroring the original elif chain.
                for v_idx, variant in enumerate(variants):
                    if (sample == variant).all():
                        counts[v_idx] += 1
                        break

        within_bounds = [1 if (2200 < c < 2800) else 0 for c in counts]
        self.assertTrue(within_bounds == [1] * 4, "2D Images were not mirrored along "
                                                  "all axes with equal probability. "
                                                  "This may also indicate that "
                                                  "mirroring is not working")
# Example #4 (score: 0) — scraped-snippet separator
    def test_image_pipeline_and_pin_memory(self):
        '''
        This just should not crash
        :return:
        '''
        try:
            import torch
        except ImportError:
            # torch is optional: skip silently when it is not installed
            return

        pipeline = Compose([
            MirrorTransform(),
            TransposeAxesTransform(transpose_any_of_these=(0, 1),
                                   p_per_sample=0.5),
            NumpyToTensor(keys='data', cast_to='float'),
        ])

        mt = MultiThreadedAugmenter(self.dl_images, pipeline, 4, 1, None, True)

        res = None
        for _ in range(50):
            res = mt.next()

        assert isinstance(res['data'], torch.Tensor)
        assert res['data'].is_pinned()

        # let mt finish caching, otherwise it's going to print an error (which is not a problem and will not prevent
        # the success of the test but it does not look pretty)
        sleep(2)
# Example #5 (score: 0) — scraped-snippet separator
def get_transforms(
        patch_shape=(256, 320), other_transforms=None, random_crop=False):
    """
    Initializes the transforms for training.

    Args:
        patch_shape: spatial shape of the output patches.
        other_transforms: optional list of extra transforms appended after the
            spatial and mirror transforms. Defaults to None.
        random_crop (boolean): random crop vs. center crop. Currently, the
            Transformed3DGenerator only supports random cropping;
            Transformed2DGenerator supports both random_crop = True and False.
    """
    ndim = len(patch_shape)
    transforms_list = [
        SpatialTransform(patch_shape,
                         do_elastic_deform=True,
                         alpha=(0., 1500.),
                         sigma=(30., 80.),
                         do_rotation=True,
                         angle_z=(0, 2 * np.pi),
                         do_scale=True,
                         scale=(0.75, 2.),
                         border_mode_data="nearest",
                         border_cval_data=0,
                         order_data=1,
                         random_crop=random_crop,
                         p_el_per_sample=0.1,
                         p_scale_per_sample=0.1,
                         p_rot_per_sample=0.1),
        MirrorTransform(axes=(0, 1)),
    ]
    if other_transforms is not None:
        transforms_list = transforms_list + other_transforms
    return Compose(transforms_list)
def get_train_transform(patch_size):
    """
    data augmentation for training data, inspired by:
    https://github.com/MIC-DKFZ/batchgenerators/blob/master/batchgenerators/examples/brats2017/brats2017_dataloader_3D.py
    :param patch_size: shape of network's input
    :return: composed list of transformations
    """

    def _sym_rad(degrees):
        # symmetric interval (-degrees, +degrees) converted to radians
        return (-degrees / 360 * 2 * np.pi, degrees / 360 * 2 * np.pi)

    spatial = SpatialTransform_2(
        patch_size,
        (10, 10, 10),
        do_elastic_deform=True,
        deformation_scale=(0, 0.25),
        do_rotation=True,
        angle_z=_sym_rad(15),
        angle_x=(0, 0),
        angle_y=(0, 0),
        do_scale=True,
        scale=(0.75, 1.25),
        border_mode_data='constant',
        border_cval_data=0,
        border_mode_seg='constant',
        border_cval_seg=0,
        order_seg=1,
        random_crop=False,
        p_el_per_sample=0.2,
        p_rot_per_sample=0.2,
        p_scale_per_sample=0.2,
    )

    pipeline = [
        spatial,
        # mirror only along the first two spatial axes
        MirrorTransform(axes=(0, 1)),
        BrightnessMultiplicativeTransform((0.7, 1.5),
                                          per_channel=True,
                                          p_per_sample=0.2),
        GammaTransform(gamma_range=(0.2, 1.0),
                       invert_image=False,
                       per_channel=False,
                       p_per_sample=0.2),
        GaussianNoiseTransform(noise_variance=(0, 0.05), p_per_sample=0.2),
        # NOTE(review): p_per_channel=0.0 means no channel is ever blurred even
        # when a sample is selected (p_per_sample=0.2) — confirm intended.
        GaussianBlurTransform(blur_sigma=(0.2, 1.0),
                              different_sigma_per_channel=False,
                              p_per_channel=0.0,
                              p_per_sample=0.2),
    ]

    return Compose(pipeline)
# Example #7 (score: 0) — scraped-snippet separator
def get_train_transform(patch_size):
    """Training transforms for the BraTS showcase example.

    These are not necessarily the best transforms to use for BraTS; they are
    just meant to showcase how a pipeline is assembled.

    :param patch_size: target spatial shape of the cropped samples
    :return: composed transform pipeline
    """
    # +/-15 degrees in radians, shared by all three rotation axes
    rot = (-15 / 360. * 2 * np.pi, 15 / 360. * 2 * np.pi)

    # SpatialTransform_2 runs first: it reduces the data to patch_size (which
    # cuts the cost of everything downstream) and is the only transform that
    # moves voxels, so the later intensity transforms introduce no border
    # artifacts. It uses the new parameterization of elastic_deform. Each
    # spatial augmentation fires with p=0.1 per sample, i.e.
    # 1 - (1 - 0.1) ** 3 ~= 27% of samples get at least one of them; the rest
    # are only cropped.
    spatial = SpatialTransform_2(
        patch_size, [i // 2 for i in patch_size],
        do_elastic_deform=True,
        deformation_scale=(0, 0.25),
        do_rotation=True,
        angle_x=rot,
        angle_y=rot,
        angle_z=rot,
        do_scale=True,
        scale=(0.75, 1.25),
        border_mode_data='constant',
        border_cval_data=0,
        border_mode_seg='constant',
        border_cval_seg=0,
        order_seg=1,
        order_data=3,
        random_crop=True,
        p_el_per_sample=0.1,
        p_rot_per_sample=0.1,
        p_scale_per_sample=0.1)

    return Compose([
        spatial,
        # mirror along all spatial axes
        MirrorTransform(axes=(0, 1, 2)),
        # gamma: nonlinear intensity remapping
        # (https://en.wikipedia.org/wiki/Gamma_correction)
        GammaTransform(gamma_range=(0.5, 2),
                       invert_image=False,
                       per_channel=True,
                       p_per_sample=0.15),
        # same gamma transform applied to the inverted image, inverted back after
        GammaTransform(gamma_range=(0.5, 2),
                       invert_image=True,
                       per_channel=True,
                       p_per_sample=0.15),
        # additive Gaussian noise
        GaussianNoiseTransform(noise_variance=(0, 0.05), p_per_sample=0.15),
    ])
    def _augment_data(self, batch_generator, type=None):
        """Wrap `batch_generator` in a normalization/augmentation pipeline.

        :param batch_generator: upstream loader yielding {'data', 'seg'} batches
        :param type: split name; augmentations are only added when it equals
            "train" and Config.DATA_AUGMENTATION is set.
            NOTE(review): parameter shadows the builtin `type`.
        :return: MultiThreadedAugmenter producing float tensors;
            data: (batch_size, channels, x, y), seg: (batch_size, channels, x, y)
        """
        # More worker processes only pay off when augmentation is enabled.
        if self.Config.DATA_AUGMENTATION:
            num_processes = 16  # 2D: 8 is a bit faster than 16
            # num_processes = 8
        else:
            num_processes = 6

        tfs = []  #transforms

        if self.Config.NORMALIZE_DATA:
            tfs.append(ZeroMeanUnitVarianceTransform(per_channel=self.Config.NORMALIZE_PER_CHANNEL))

        if self.Config.DATA_AUGMENTATION:
            if type == "train":
                # scale: inverted: 0.5 -> bigger; 2 -> smaller
                # patch_center_dist_from_border: if 144/2=72 -> always exactly centered; otherwise a bit off center (brain can get off image and will be cut then)

                if self.Config.DAUG_SCALE:
                    center_dist_from_border = int(self.Config.INPUT_DIM[0] / 2.) - 10  # (144,144) -> 62
                    tfs.append(SpatialTransform(self.Config.INPUT_DIM,
                                                patch_center_dist_from_border=center_dist_from_border,
                                                do_elastic_deform=self.Config.DAUG_ELASTIC_DEFORM,
                                                alpha=(90., 120.), sigma=(9., 11.),
                                                do_rotation=self.Config.DAUG_ROTATE,
                                                angle_x=(-0.8, 0.8), angle_y=(-0.8, 0.8), angle_z=(-0.8, 0.8),
                                                do_scale=True, scale=(0.9, 1.5), border_mode_data='constant',
                                                border_cval_data=0,
                                                order_data=3,
                                                border_mode_seg='constant', border_cval_seg=0,
                                                order_seg=0, random_crop=True, p_el_per_sample=0.2,
                                                p_rot_per_sample=0.2, p_scale_per_sample=0.2))

                # Simulates lower acquisition resolution by down- and upsampling.
                if self.Config.DAUG_RESAMPLE:
                    tfs.append(SimulateLowResolutionTransform(zoom_range=(0.5, 1), p_per_sample=0.2))

                if self.Config.DAUG_NOISE:
                    tfs.append(GaussianNoiseTransform(noise_variance=(0, 0.05), p_per_sample=0.2))

                if self.Config.DAUG_MIRROR:
                    tfs.append(MirrorTransform())

                # Peak-specific augmentation (flips vector orientations).
                if self.Config.DAUG_FLIP_PEAKS:
                    tfs.append(FlipVectorAxisTransform())

        # Conversion to torch tensors always runs last, regardless of split.
        tfs.append(NumpyToTensor(keys=["data", "seg"], cast_to="float"))

        #num_cached_per_queue 1 or 2 does not really make a difference
        batch_gen = MultiThreadedAugmenter(batch_generator, Compose(tfs), num_processes=num_processes,
                                           num_cached_per_queue=1, seeds=None, pin_memory=True)
        return batch_gen    # data: (batch_size, channels, x, y), seg: (batch_size, channels, x, y)
    def test_segmentations_2D(self):
        """Check that segmentations are mirrored coherently with images.

        The loader supplies data and seg that should stay element-wise equal
        after augmentation if both were mirrored the same way.
        """
        batch_gen = BasicDataLoader((self.x_2D, self.y_2D), self.batch_size, number_of_threads_in_multithreaded=None)
        batch_gen = SingleThreadedAugmenter(batch_gen, MirrorTransform((0, 1)))

        equivalent = True

        for b in range(self.num_batches):
            batch = next(batch_gen)
            for ix in range(self.batch_size):
                # BUGFIX: use .any() instead of .all(). A single differing
                # element already proves data and seg were not mirrored
                # identically; the old .all() only flagged a mismatch when
                # *every* element differed, which let incoherent mirroring
                # pass the test.
                if (batch['data'][ix] != batch['seg'][ix]).any():
                    equivalent = False

        self.assertTrue(equivalent, "2D images and seg were not mirrored in the same way (they should though because "
                                    "seg needs to match the corresponding data")
def create_data_gen_pipeline(cf, cities=None, data_split='train', do_aug=True, random=True, n_batches=None):
    """
    Create a multi-threaded train/val/test batch generation and augmentation pipeline.
    :param cf: config object providing batch size, paths and augmentation kwargs
    :param cities: list of strings or None
    :param data_split: which data split the BatchGenerator draws from
    :param do_aug: (optional) whether to perform data augmentation (training) or not (validation/testing)
    :param random: bool, whether to draw random batches or go through data linearly
    :param n_batches: optional number of batches per epoch
    :return: multithreaded_generator
    """
    data_gen = BatchGenerator(cities=cities, batch_size=cf.batch_size, data_dir=cf.data_dir,
                              label_density=cf.label_density, data_split=data_split, resolution=cf.resolution,
                              gt_instances=cf.gt_instances, n_batches=n_batches, random=random)

    dak = cf.da_kwargs  # shorthand for the augmentation parameter dict
    transforms = []
    if do_aug:
        transforms.append(MirrorTransform(axes=(3,)))
        transforms.append(SpatialTransform(patch_size=cf.patch_size[-2:],
                                           patch_center_dist_from_border=dak['rand_crop_dist'],
                                           do_elastic_deform=dak['do_elastic_deform'],
                                           alpha=dak['alpha'], sigma=dak['sigma'],
                                           do_rotation=dak['do_rotation'], angle_x=dak['angle_x'],
                                           angle_y=dak['angle_y'], angle_z=dak['angle_z'],
                                           do_scale=dak['do_scale'], scale=dak['scale'],
                                           random_crop=dak['random_crop'],
                                           border_mode_data=dak['border_mode_data'],
                                           border_mode_seg=dak['border_mode_seg'],
                                           border_cval_seg=dak['border_cval_seg']))
    else:
        # no augmentation: crop deterministically to the target patch size
        transforms.append(CenterCropTransform(crop_size=cf.patch_size[-2:]))

    # Gamma and loss-mask transforms run for every split.
    transforms.append(GammaTransform(dak['gamma_range'], invert_image=False, per_channel=True,
                                     retain_stats=dak['gamma_retain_stats'],
                                     p_per_sample=dak['p_gamma']))
    transforms.append(AddLossMask(cf.ignore_label))
    if cf.label_switches is not None:
        transforms.append(StochasticLabelSwitches(cf.name2trainId, cf.label_switches))

    return MultiThreadedAugmenter(data_gen, Compose(transforms), num_processes=cf.n_workers,
                                  seeds=range(cf.n_workers))
# Example #11 (score: 0) — scraped-snippet separator
def combine_transform():
    """Demo of a combined transform: contrast augmentation + mirroring.

    Feeds the `camera` test image through a multi-threaded augmenter and
    plots one augmented batch.
    """
    # contrast transform
    contrast = ContrastAugmentationTransform((0.3, 3.),
                                             preserve_range=True)
    # mirror transform
    mirror = MirrorTransform(axes=(2, 3))

    pipeline = Compose([contrast, mirror])

    batchgen = my_data_loader.DataLoader(camera(), 4)
    multithreaded_generator = MultiThreadedAugmenter(batchgen, pipeline,
                                                     4, 2)

    # show the effect of the combined transforms
    my_data_loader.plot_batch(multithreaded_generator.__next__())
# Example #12 (score: 0) — scraped-snippet separator
    def test_image_pipeline(self):
        '''
        This just should not crash
        :return:
        '''
        pipeline = Compose([
            MirrorTransform(),
            TransposeAxesTransform(transpose_any_of_these=(0, 1),
                                   p_per_sample=0.5),
        ])

        mt = MultiThreadedAugmenter(self.dl_images, pipeline, 4, 1, None, False)

        for _ in range(50):
            res = mt.next()

        # let mt finish caching, otherwise it's going to print an error (which is not a problem and will not prevent
        # the success of the test but it does not look pretty)
        sleep(2)
def get_valid_transform(patch_size):
    """
    data augmentation for validation data
    inspired by:
    https://github.com/MIC-DKFZ/batchgenerators/blob/master/batchgenerators/examples/brats2017/brats2017_dataloader_3D.py
    :param patch_size: shape of network's input
    :return: composed list of transformations
    """
    no_rotation = (0, 0)

    # All geometric augmentations are disabled (do_elastic_deform, do_rotation
    # and do_scale are False), so the spatial transform only crops to patch_size.
    # NOTE(review): random_crop=True plus the MirrorTransform below make
    # validation non-deterministic — confirm this is intended.
    spatial = SpatialTransform_2(patch_size,
                                 patch_size,
                                 do_elastic_deform=False,
                                 deformation_scale=(0, 0),
                                 do_rotation=False,
                                 angle_x=no_rotation,
                                 angle_y=no_rotation,
                                 angle_z=no_rotation,
                                 do_scale=False,
                                 scale=(1.0, 1.0),
                                 border_mode_data='constant',
                                 border_cval_data=0,
                                 border_mode_seg='constant',
                                 border_cval_seg=0,
                                 order_seg=1,
                                 order_data=3,
                                 random_crop=True,
                                 p_el_per_sample=0.1,
                                 p_rot_per_sample=0.1,
                                 p_scale_per_sample=0.1)

    return Compose([spatial, MirrorTransform(axes=(0, 1))])
# Example #14 (score: 0) — scraped-snippet separator
    def _augment_data(self, batch_generator, type=None):
        """Wrap `batch_generator` in a normalization/augmentation pipeline.

        :param batch_generator: upstream loader yielding {'data', 'seg'} batches
        :param type: split name; augmentations are only added when it equals
            "train" and Config.DATA_AUGMENTATION is set.
            NOTE(review): parameter shadows the builtin `type`.
        :return: MultiThreadedAugmenter producing float tensors;
            data: (batch_size, channels, x, y), seg: (batch_size, channels, x, y)
        """
        if self.Config.DATA_AUGMENTATION:
            num_processes = 15  # 15 is a bit faster than 8 on cluster
            # num_processes = multiprocessing.cpu_count()  # on cluster: gives all cores, not only assigned cores
        else:
            num_processes = 6

        tfs = []

        if self.Config.NORMALIZE_DATA:
            tfs.append(ZeroMeanUnitVarianceTransform(per_channel=self.Config.NORMALIZE_PER_CHANNEL))

        # Choose the spatial transform implementation from config (defaults to
        # the plain batchgenerators SpatialTransform).
        if self.Config.SPATIAL_TRANSFORM == "SpatialTransformPeaks":
            SpatialTransformUsed = SpatialTransformPeaks
        elif self.Config.SPATIAL_TRANSFORM == "SpatialTransformCustom":
            SpatialTransformUsed = SpatialTransformCustom
        else:
            SpatialTransformUsed = SpatialTransform

        if self.Config.DATA_AUGMENTATION:
            if type == "train":
                # patch_center_dist_from_border:
                #   if 144/2=72 -> always exactly centered; otherwise a bit off center
                #   (brain can get off image and will be cut then)
                if self.Config.DAUG_SCALE:

                    # When input rescaling is on, use a fixed scale factor
                    # derived from the target resolution; otherwise a range.
                    if self.Config.INPUT_RESCALING:
                        source_mm = 2  # for bb
                        target_mm = float(self.Config.RESOLUTION[:-2])
                        scale_factor = target_mm / source_mm
                        scale = (scale_factor, scale_factor)
                    else:
                        scale = (0.9, 1.5)

                    if self.Config.PAD_TO_SQUARE:
                        patch_size = self.Config.INPUT_DIM
                    else:
                        patch_size = None  # keeps dimensions of the data

                    # spatial transform automatically crops/pads to correct size
                    center_dist_from_border = int(self.Config.INPUT_DIM[0] / 2.) - 10  # (144,144) -> 62
                    tfs.append(SpatialTransformUsed(patch_size,
                                                patch_center_dist_from_border=center_dist_from_border,
                                                do_elastic_deform=self.Config.DAUG_ELASTIC_DEFORM,
                                                alpha=self.Config.DAUG_ALPHA, sigma=self.Config.DAUG_SIGMA,
                                                do_rotation=self.Config.DAUG_ROTATE,
                                                angle_x=self.Config.DAUG_ROTATE_ANGLE,
                                                angle_y=self.Config.DAUG_ROTATE_ANGLE,
                                                angle_z=self.Config.DAUG_ROTATE_ANGLE,
                                                do_scale=True, scale=scale, border_mode_data='constant',
                                                border_cval_data=0,
                                                order_data=3,
                                                border_mode_seg='constant', border_cval_seg=0,
                                                order_seg=0, random_crop=True,
                                                p_el_per_sample=self.Config.P_SAMP,
                                                p_rot_per_sample=self.Config.P_SAMP,
                                                p_scale_per_sample=self.Config.P_SAMP))

                if self.Config.DAUG_RESAMPLE:
                    tfs.append(SimulateLowResolutionTransform(zoom_range=(0.5, 1), p_per_sample=0.2, per_channel=False))

                if self.Config.DAUG_RESAMPLE_LEGACY:
                    tfs.append(ResampleTransformLegacy(zoom_range=(0.5, 1)))

                if self.Config.DAUG_GAUSSIAN_BLUR:
                    tfs.append(GaussianBlurTransform(blur_sigma=self.Config.DAUG_BLUR_SIGMA,
                                                     different_sigma_per_channel=False,
                                                     p_per_sample=self.Config.P_SAMP))

                if self.Config.DAUG_NOISE:
                    tfs.append(GaussianNoiseTransform(noise_variance=self.Config.DAUG_NOISE_VARIANCE,
                                                      p_per_sample=self.Config.P_SAMP))

                if self.Config.DAUG_MIRROR:
                    tfs.append(MirrorTransform())

                if self.Config.DAUG_FLIP_PEAKS:
                    tfs.append(FlipVectorAxisTransform())

        # Conversion to torch tensors always runs last, regardless of split.
        tfs.append(NumpyToTensor(keys=["data", "seg"], cast_to="float"))

        #num_cached_per_queue 1 or 2 does not really make a difference
        batch_gen = MultiThreadedAugmenter(batch_generator, Compose(tfs), num_processes=num_processes,
                                           num_cached_per_queue=1, seeds=None, pin_memory=True)
        return batch_gen  # data: (batch_size, channels, x, y), seg: (batch_size, channels, x, y)
# Example #15 (score: 0) — scraped-snippet separator
 def run(self, img_data, seg_data):
     """Run the configured augmentations `self.cycles` times and return the results.

     :param img_data: channel-last image array (the code moves axis 1 to the
         end on output, and derives patch shapes from shape[0:-1] — so the
         last axis is treated as channels; confirm against caller)
     :param seg_data: matching segmentation (or classification) data
     :return: tuple (aug_img_data, aug_seg_data), channel-last, with the
         augmented batches of all cycles concatenated along the batch axis
     """
     # Define label for segmentation for segmentation augmentation
     if self.seg_augmentation: seg_label = "seg"
     else: seg_label = "class"
     # Create a parser for the batchgenerators module
     data_generator = DataParser(img_data, seg_data, seg_label)
     # Initialize empty transform list
     transforms = []
     # Add mirror augmentation
     if self.mirror:
         aug_mirror = MirrorTransform(axes=self.config_mirror_axes)
         transforms.append(aug_mirror)
     # Add contrast augmentation
     if self.contrast:
         aug_contrast = ContrastAugmentationTransform(
             self.config_contrast_range,
             preserve_range=self.config_contrast_preserverange,
             per_channel=self.coloraug_per_channel,
             p_per_sample=self.config_p_per_sample)
         transforms.append(aug_contrast)
     # Add brightness augmentation
     if self.brightness:
         aug_brightness = BrightnessMultiplicativeTransform(
             self.config_brightness_range,
             per_channel=self.coloraug_per_channel,
             p_per_sample=self.config_p_per_sample)
         transforms.append(aug_brightness)
     # Add gamma augmentation
     if self.gamma:
         aug_gamma = GammaTransform(self.config_gamma_range,
                                    invert_image=False,
                                    per_channel=self.coloraug_per_channel,
                                    retain_stats=True,
                                    p_per_sample=self.config_p_per_sample)
         transforms.append(aug_gamma)
     # Add gaussian noise augmentation
     if self.gaussian_noise:
         aug_gaussian_noise = GaussianNoiseTransform(
             self.config_gaussian_noise_range,
             p_per_sample=self.config_p_per_sample)
         transforms.append(aug_gaussian_noise)
     # Add spatial transformations as augmentation
     # (rotation, scaling, elastic deformation)
     if self.rotations or self.scaling or self.elastic_deform or \
         self.cropping:
         # Identify patch shape (full image or cropping)
         if self.cropping: patch_shape = self.cropping_patch_shape
         else: patch_shape = img_data[0].shape[0:-1]
         # Assembling the spatial transformation
         aug_spatial_transform = SpatialTransform(
             patch_shape, [i // 2 for i in patch_shape],
             do_elastic_deform=self.elastic_deform,
             alpha=self.config_elastic_deform_alpha,
             sigma=self.config_elastic_deform_sigma,
             do_rotation=self.rotations,
             angle_x=self.config_rotations_angleX,
             angle_y=self.config_rotations_angleY,
             angle_z=self.config_rotations_angleZ,
             do_scale=self.scaling,
             scale=self.config_scaling_range,
             border_mode_data='constant',
             border_cval_data=0,
             border_mode_seg='constant',
             border_cval_seg=0,
             order_data=3,
             order_seg=0,
             p_el_per_sample=self.config_p_per_sample,
             p_rot_per_sample=self.config_p_per_sample,
             p_scale_per_sample=self.config_p_per_sample,
             random_crop=self.cropping)
         # Append spatial transformation to transformation list
         transforms.append(aug_spatial_transform)
     # Compose the batchgenerators transforms
     all_transforms = Compose(transforms)
     # Assemble transforms into a augmentation generator
     augmentation_generator = SingleThreadedAugmenter(
         data_generator, all_transforms)
     # Perform the data augmentation x times (x = cycles)
     aug_img_data = None
     aug_seg_data = None
     for i in range(0, self.cycles):
         # Run the computation process for the data augmentations
         augmentation = next(augmentation_generator)
         # Access augmentated data from the batchgenerators data structure
         if aug_img_data is None and aug_seg_data is None:
             aug_img_data = augmentation["data"]
             aug_seg_data = augmentation[seg_label]
         # Concatenate the new data augmentated data with the cached data
         else:
             aug_img_data = np.concatenate(
                 (augmentation["data"], aug_img_data), axis=0)
             aug_seg_data = np.concatenate(
                 (augmentation[seg_label], aug_seg_data), axis=0)
     # Transform data from channel-first back to channel-last structure
     # Data structure channel-first 3D:  (batch, channel, x, y, z)
     # Data structure channel-last 3D:   (batch, x, y, z, channel)
     aug_img_data = np.moveaxis(aug_img_data, 1, -1)
     aug_seg_data = np.moveaxis(aug_seg_data, 1, -1)
     # Return augmentated image and segmentation data
     return aug_img_data, aug_seg_data
# Example #16 (score: 0) — scraped-snippet separator
def main(args):
    """Configure and run a RetinaNet mammography training experiment.

    Reads the YAML config referenced by ``args["config_file"]``, builds the
    data-augmentation pipeline, selects the dataset class and load function
    for either the INbreast or (CBIS-)DDSM dataset, configures logging, and
    finally executes either a single training run or a k-fold cross
    validation.

    Args:
        args: dict-like with keys "config_file", "train_size", "val_size",
            "margin", "optimizer" and "gpus".

    Raises:
        ValueError: if an unsupported optimizer name is given.
        TypeError: for unsupported dataset type / level / lesion type or
            loading mode settings.
        KeyError: if extended augmentation is requested but neither crop
            size, image shape nor shape limit is configured.
    """
    ########################################
    #                                      #
    #      DEFINE THE HYPERPARAMETERS      #
    #                                      #
    ########################################

    # load settings from config file
    config_file = args.get("config_file")
    config_handler = ConfigHandler()
    config_dict = config_handler(config_file)

    # some are changed rarely and given manually if required
    train_size = args.get("train_size")
    val_size = args.get("val_size")
    margin = args.get("margin")
    optimizer = args.get("optimizer")

    if optimizer != "SGD" and optimizer != "Adam":
        # BUGFIX: the exception was previously created but never raised,
        # so an invalid optimizer name silently fell through to SGD.
        raise ValueError("Invalid optimizer")
    elif optimizer == "Adam":
        optimizer_cls = torch.optim.Adam
    else:
        optimizer_cls = torch.optim.SGD

    params = Parameters(
        fixed_params={
            "model": config_dict["model"],
            "training": {
                **config_dict["training"],
                "optimizer_cls": optimizer_cls,
                **config_dict["optimizer"],
                "criterions": {
                    "FocalLoss": losses.FocalLoss(),
                    "SmoothL1Loss": losses.SmoothL1Loss()
                },
                #  "criterions": {"FocalMSELoss": losses.FocalMSELoss(),
                #                 "SmoothL1Loss": losses.SmoothL1Loss()},
                # "lr_sched_cls": ReduceLROnPlateauCallbackPyTorch,
                # "lr_sched_params": {"verbose": True},
                "lr_sched_cls": None,
                "lr_sched_params": {},
                "metrics": {}
            }
        })

    ########################################
    #                                      #
    #        DEFINE THE AUGMENTATIONS      #
    #                                      #
    ########################################

    my_transforms = []
    mirror_transform = MirrorTransform(axes=(1, 2))
    my_transforms.append(mirror_transform)
    crop_size = config_dict["data"]["crop_size"]
    img_shape = config_dict["data"]["img_shape"]
    shape_limit = config_dict["data"]["shape_limit"]

    # Full 90-degree rotations are only valid for square inputs; otherwise
    # restrict the number of rotations so the shape is preserved.
    if (crop_size is not None and crop_size[0] == crop_size[1]) or \
        (img_shape is not None and len(img_shape) > 1
            and img_shape[0] == img_shape[1]):
        rot_transform = Rot90Transform(axes=(0, 1), p_per_sample=0.5)
        my_transforms.append(rot_transform)
    else:
        rot_transform = Rot90Transform(axes=(0, 1),
                                       num_rot=(0, 2),
                                       p_per_sample=0.5)
        my_transforms.append(rot_transform)

    # apply a more extended augmentation (if desired)
    if "ext_aug" in config_dict["data"].keys() and \
            config_dict["data"]["ext_aug"] is not None and \
            config_dict["data"]["ext_aug"]:

        # oversize the working canvas so the spatial transform can random-crop
        # back down to the target size without introducing border artifacts
        if crop_size is not None:
            size = [crop_size[0] + 25, crop_size[1] + 25]
        elif img_shape is not None:
            size = [img_shape[0] + 5, img_shape[1] + 5]
        elif shape_limit is not None:
            size = [shape_limit[0] + 5, shape_limit[1] + 5]
        else:
            raise KeyError("Crop size or image shape required!")

        if crop_size is not None:
            spatial_transforms = SpatialTransform([size[0] - 25, size[1] - 25],
                                                  np.asarray(size) // 2,
                                                  do_elastic_deform=False,
                                                  do_rotation=True,
                                                  angle_x=(0, 0.01 * np.pi),
                                                  do_scale=True,
                                                  scale=(0.9, 1.1),
                                                  random_crop=True,
                                                  border_mode_data="mirror",
                                                  border_mode_seg="mirror")
            my_transforms.append(spatial_transforms)

        elif img_shape is not None or shape_limit is not None:
            spatial_transforms = SpatialTransform(
                [size[0] - 5, size[1] - 5],
                np.asarray(size) // 2,
                do_elastic_deform=False,
                do_rotation=False,
                #angle_x=(0, 0.01 * np.pi),
                do_scale=True,
                scale=(0.9, 1.1),
                random_crop=True,
                border_mode_data="constant",
                border_mode_seg="nearest")
            my_transforms.append(spatial_transforms)

    # bbox generation: convert segmentation masks into bounding boxes
    bb_transform = ConvertSegToBB(dim=2, margin=margin)
    my_transforms.append(bb_transform)

    transforms = Compose(my_transforms)

    ########################################
    #                                      #
    #   DEFINE THE DATASETS and MANAGER    #
    #                                      #
    ########################################

    # paths to csv files containing labels (and other information)
    csv_calc_train = '/home/temp/moriz/data/' \
                     'calc_case_description_train_set.csv'
    csv_mass_train = '/home/temp/moriz/data/' \
                     'mass_case_description_train_set.csv'

    # path to data directory
    ddsm_dir = '/home/temp/moriz/data/CBIS-DDSM/'

    # path to data directory
    inbreast_dir = '/images/Mammography/INbreast/AllDICOMs/'

    # paths to csv files containing labels (and other information)
    xls_file = '/images/Mammography/INbreast/INbreast.xls'

    # determine class and load function
    if config_dict["data"]["dataset_type"] == "INbreast":
        dataset_cls = CacheINbreastDataset
        data_dir = inbreast_dir
        csv_file = None

        if config_dict["data"]["level"] == "crops":
            load_fn = inbreast_utils.load_pos_crops
        elif config_dict["data"]["level"] == "images":
            load_fn = inbreast_utils.load_sample
        elif config_dict["data"]["level"] == "both":
            #TODO: fix
            load_fn = inbreast_utils.load_sample_and_crops
        else:
            raise TypeError("Level required!")
    elif config_dict["data"]["dataset_type"] == "DDSM":
        data_dir = ddsm_dir

        if config_dict["data"]["level"] == "crops":
            load_fn = ddsm_utils.load_pos_crops
        elif config_dict["data"]["level"] == "images":
            load_fn = ddsm_utils.load_sample
        elif config_dict["data"]["level"] == "images+":
            load_fn = ddsm_utils.load_sample_with_crops
        else:
            raise TypeError("Level required!")

        if config_dict["data"]["type"] == "mass":
            csv_file = csv_mass_train
        elif config_dict["data"]["type"] == "calc":
            csv_file = csv_calc_train
        elif config_dict["data"]["type"] == "both":
            raise NotImplementedError("Todo")
        else:
            raise TypeError("Unknown lesion type!")

        if "mode" in config_dict["data"].keys():
            if config_dict["data"]["mode"] == "lazy":
                dataset_cls = LazyDDSMDataset

                if config_dict["data"]["level"] == "crops":
                    load_fn = ddsm_utils.load_single_pos_crops

            elif config_dict["data"]["mode"] == "cache":
                dataset_cls = CacheDDSMDataset
            else:
                raise TypeError("Unsupported loading mode!")
        else:
            # default to the caching dataset when no mode is given
            dataset_cls = CacheDDSMDataset

    else:
        raise TypeError("Dataset is not supported!")

    dataset_train_dict = {
        'data_path': data_dir,
        'xls_file': xls_file,
        'csv_file': csv_file,
        'load_fn': load_fn,
        'num_elements': config_dict["debug"]["n_train"],
        **config_dict["data"]
    }

    dataset_val_dict = {
        'data_path': data_dir,
        'xls_file': xls_file,
        'csv_file': csv_file,
        'load_fn': load_fn,
        'num_elements': config_dict["debug"]["n_val"],
        **config_dict["data"]
    }

    datamgr_train_dict = {
        'batch_size': params.nested_get("batch_size"),
        'n_process_augmentation': 4,
        'transforms': transforms,
        'sampler_cls': RandomSampler,
        'data_loader_cls': BaseDataLoader
    }

    datamgr_val_dict = {
        'batch_size': params.nested_get("batch_size"),
        'n_process_augmentation': 4,
        'transforms': transforms,
        'sampler_cls': SequentialSampler,
        'data_loader_cls': BaseDataLoader
    }

    ########################################
    #                                      #
    #   INITIALIZE THE ACTUAL EXPERIMENT   #
    #                                      #
    ########################################
    checkpoint_path = config_dict["checkpoint_path"]["path"]
    # if "checkpoint_path" in args and args["checkpoint_path"] is not None:
    #     checkpoint_path = args.get("checkpoint_path")

    experiment = \
        RetinaNetExperiment(params,
                            RetinaNet,
                            name = config_dict["logging"]["name"],
                            save_path = checkpoint_path,
                            dataset_cls=dataset_cls,
                            dataset_train_kwargs=dataset_train_dict,
                            datamgr_train_kwargs=datamgr_train_dict,
                            dataset_val_kwargs=dataset_val_dict,
                            datamgr_val_kwargs=datamgr_val_dict,
                            optim_builder=create_optims_default_pytorch,
                            gpu_ids=list(range(args.get('gpus'))),
                            val_score_key="val_FocalLoss",
                            val_score_mode="lowest",
                            checkpoint_freq=2)

    ########################################
    #                                      #
    # LOGGING DEFINITION AND CONFIGURATION #
    #                                      #
    ########################################

    logger_kwargs = config_dict["logging"]

    # setup initial logging
    log_file = os.path.join(experiment.save_path, 'logger.log')

    logging.basicConfig(level=logging.INFO,
                        handlers=[
                            TrixiHandler(PytorchVisdomLogger,
                                         **config_dict["logging"]),
                            logging.StreamHandler(),
                            logging.FileHandler(log_file)
                        ])

    logger = logging.getLogger("RetinaNet Logger")

    # persist the effective configuration next to the checkpoints
    with open(os.path.join(experiment.save_path, "config.yml"), 'w') as file:
        yaml.dump(config_dict, file)

    ########################################
    #                                      #
    #       LOAD PATHS AND EXECUTE MODEL   #
    #                                      #
    ########################################
    seed = config_dict["data"]["seed"]

    # config-file values take precedence over the CLI defaults
    if "train_size" in config_dict["data"].keys():
        train_size = config_dict["data"]["train_size"]

    if "val_size" in config_dict["data"].keys():
        val_size = config_dict["data"]["val_size"]

    if config_dict["data"]["dataset_type"] == "INbreast":
        if not config_dict["kfold"]["enable"]:

            train_paths, _, val_paths = \
                inbreast_utils.load_single_set(inbreast_dir,
                                               xls_file=xls_file,
                                               train_size=train_size,
                                               val_size=val_size,
                                               type=config_dict["data"]["type"],
                                               random_state=seed)

            # validation requires a fixed spatial size; skip it otherwise
            if img_shape is not None or crop_size is not None:
                experiment.run(train_paths, val_paths)
            else:
                experiment.run(train_paths, None)

        else:
            paths = inbreast_utils.get_paths(inbreast_dir,
                                             xls_file=xls_file,
                                             type=config_dict["data"]["type"])

            if "splits" in config_dict["kfold"].keys():
                num_splits = config_dict["kfold"]["splits"]
            else:
                num_splits = 5

            experiment.kfold(paths,
                             num_splits=num_splits,
                             random_seed=seed,
                             dataset_type="INbreast")

    else:
        train_paths, val_paths, _ = \
            ddsm_utils.load_single_set(ddsm_dir,
                                       csv_file=csv_file,
                                       train_size=train_size,
                                       val_size=None,
                                       random_state=seed)

        if img_shape is not None or crop_size is not None:
            experiment.run(train_paths, val_paths)
        else:
            experiment.run(train_paths, None)
Пример #17
0
    def get_batches(self,
                    batch_size=128,
                    type=None,
                    subjects=None,
                    num_batches=None):
        """Build a multi-threaded batch generator over the given subjects.

        Behaviour is driven by ``self.HP``: normalization, optional padding
        for the Schizo/2mm setup, and (when ``type == "train"``) whichever
        data augmentations are enabled via the ``DAUG_*`` flags.

        Returns the augmenter; batches have shape
        data: (batch_size, channels, x, y), seg: (batch_size, channels, x, y).
        """
        hp = self.HP
        data, seg = subjects, []

        # 6 -> >30GB RAM; with augmentation 8 workers (a bit faster than 16)
        num_processes = 8 if hp.DATA_AUGMENTATION else 6

        nr_of_samples = len(subjects) * hp.INPUT_DIM[0]
        if num_batches is None:
            # number of batches for exactly one epoch
            num_batches_multithr = int(nr_of_samples / batch_size /
                                       num_processes)
        else:
            num_batches_multithr = int(num_batches / num_processes)

        # Simple with .npy -> just a little bit faster than Nifti (<10%) and
        # f1 not better => use Nifti
        if hp.TYPE == "combined":
            batch_gen = SlicesBatchGeneratorRandomNpyImg_fusion(
                (data, seg), batch_size=batch_size)
        else:
            batch_gen = SlicesBatchGeneratorRandomNiftiImg(
                (data, seg), batch_size=batch_size)
            # batch_gen = SlicesBatchGeneratorRandomNiftiImg_5slices((data, seg), batch_size=batch_size)
        batch_gen.HP = hp

        transform_list = []

        if hp.NORMALIZE_DATA:
            transform_list.append(
                ZeroMeanUnitVarianceTransform(
                    per_channel=hp.NORMALIZE_PER_CHANNEL))

        if hp.DATASET == "Schizo" and hp.RESOLUTION == "2mm":
            transform_list.append(PadToMultipleTransform(16))

        if hp.DATA_AUGMENTATION and type == "train":
            # scale: inverted: 0.5 -> bigger; 2 -> smaller
            # patch_center_dist_from_border: if 144/2=72 -> always exactly
            # centered; otherwise a bit off center (brain can get off image
            # and will be cut then)
            if hp.DAUG_SCALE:
                center_dist = int(hp.INPUT_DIM[0] / 2.) - 10  # (144,144) -> 62
                transform_list.append(
                    SpatialTransform(
                        hp.INPUT_DIM,
                        patch_center_dist_from_border=center_dist,
                        do_elastic_deform=hp.DAUG_ELASTIC_DEFORM,
                        alpha=(90., 120.),
                        sigma=(9., 11.),
                        do_rotation=hp.DAUG_ROTATE,
                        angle_x=(-0.8, 0.8),
                        angle_y=(-0.8, 0.8),
                        angle_z=(-0.8, 0.8),
                        do_scale=True,
                        scale=(0.9, 1.5),
                        border_mode_data='constant',
                        border_cval_data=0,
                        order_data=3,
                        border_mode_seg='constant',
                        border_cval_seg=0,
                        order_seg=0,
                        random_crop=True))

            if hp.DAUG_RESAMPLE:
                transform_list.append(ResampleTransform(zoom_range=(0.5, 1)))

            if hp.DAUG_NOISE:
                transform_list.append(
                    GaussianNoiseTransform(noise_variance=(0, 0.05)))

            if hp.DAUG_MIRROR:
                transform_list.append(MirrorTransform())

            if hp.DAUG_FLIP_PEAKS:
                transform_list.append(FlipVectorAxisTransform())

        # num_cached_per_queue 1 or 2 does not really make a difference
        return MultiThreadedAugmenter(batch_gen,
                                      Compose(transform_list),
                                      num_processes=num_processes,
                                      num_cached_per_queue=1,
                                      seeds=None)
Пример #18
0
    def _make_training_transforms(self):
        """Assemble the training-time augmentation pipeline.

        Returns an empty list when ``self.no_data_augmentation`` is set;
        otherwise a list of batchgenerators transforms (spatial, optional
        mirroring, brightness, Gaussian noise, gamma), parameterized by
        ``self.training_augmentation_args`` with sensible defaults.
        """
        if self.no_data_augmentation:
            print("No data augmentation will be performed during training!")
            return []

        aug = self.training_augmentation_args
        patch_size = self.patch_size[::-1]  # (x, y, z) order
        p_per_sample = aug.get('p_per_sample', 0.15)

        def _angle_range(key):
            # symmetric rotation range in radians from a degree setting
            deg = aug.get(key, 15)
            return (-deg / 360. * 2 * np.pi, deg / 360. * 2 * np.pi)

        spatial = SpatialTransform_2(
            patch_size,
            patch_size // 2,
            do_elastic_deform=aug.get('do_elastic_deform', True),
            deformation_scale=aug.get('deformation_scale', (0, 0.25)),
            do_rotation=aug.get('do_rotation', True),
            angle_x=_angle_range('angle_x'),
            angle_y=_angle_range('angle_y'),
            angle_z=_angle_range('angle_z'),
            do_scale=aug.get('do_scale', True),
            scale=aug.get('scale', (0.75, 1.25)),
            border_mode_data='nearest',
            border_cval_data=0,
            order_data=3,
            # border_mode_seg='nearest', border_cval_seg=0,
            # order_seg=0,
            random_crop=False,
            p_el_per_sample=aug.get('p_el_per_sample', 0.5),
            p_rot_per_sample=aug.get('p_rot_per_sample', 0.5),
            p_scale_per_sample=aug.get('p_scale_per_sample', 0.5))
        train_transforms = [spatial]

        if aug.get("do_mirror", False):
            train_transforms.append(MirrorTransform(axes=(0, 1, 2)))

        train_transforms.append(
            BrightnessMultiplicativeTransform(
                aug.get('brightness_range', (0.7, 1.5)),
                per_channel=True,
                p_per_sample=p_per_sample))
        train_transforms.append(
            GaussianNoiseTransform(
                noise_variance=aug.get('gaussian_noise_variance', (0, 0.05)),
                p_per_sample=p_per_sample))
        train_transforms.append(
            GammaTransform(gamma_range=aug.get('gamma_range', (0.5, 2)),
                           invert_image=False,
                           per_channel=True,
                           p_per_sample=p_per_sample))

        print("train_transforms\n", train_transforms)

        return train_transforms
Пример #19
0
# Training pipeline: elastic deformation, small random affine jitter and a
# per-channel intensity shift, then conversion to tensors.
train_transform = transforms.Compose([
    # mt_transforms.CenterCrop2D((200, 200)),
    mt_transforms.ElasticTransform(alpha_range=(28.0, 30.0),
                                   sigma_range=(3.5, 4.0),
                                   p=0.3),
    mt_transforms.RandomAffine(degrees=4.6,
                               scale=(0.98, 1.02),
                               translate=(0.03, 0.03)),
    mt_transforms.RandomTensorChannelShift((-0.10, 0.10)),
    mt_transforms.ToTensor()
    # mt_transforms.NormalizeInstance(),
])

# Standalone batchgenerators transforms operating on the "img"/"seg" keys of
# the batch dict (presumably exercised individually below — confirm usage).
gamma_t = GammaTransform(data_key="img", gamma_range=(0.1, 10))

mirror_t = MirrorTransform(data_key="img", label_key="seg")

spatial_t = SpatialTransform(patch_size=(8, 8, 8),
                             data_key="img",
                             label_key="seg")

gauss_noise_t = GaussianNoiseTransform(data_key="img", noise_variance=(0, 1))

zoom_t = ZoomTransform(zoom_factors=2, data_key="img")


def show_basic(x, gt, info=None):
    if info is not None:
        print("Test for " + info)

    print("img size: {}, max: {}, min: {}, avg: {}.".format(
def get_default_augmentation(dataloader_train,
                             dataloader_val,
                             patch_size,
                             params=default_3D_augmentation_params,
                             border_val_seg=-1,
                             pin_memory=True,
                             seeds_train=None,
                             seeds_val=None,
                             regions=None):
    """Wrap the train/val dataloaders in multi-threaded augmenters.

    Builds the default training transform pipeline (channel selection,
    spatial transform, gamma, mirroring, cascade-specific transforms, ...)
    from ``params`` plus a minimal validation pipeline, and returns both
    wrapped in ``MultiThreadedAugmenter`` instances.

    Args:
        dataloader_train / dataloader_val: batchgenerators dataloaders.
        patch_size: spatial patch size; in dummy-2D mode the leading axis
            is stripped before the spatial transform.
        params: augmentation parameter dict (module-level default).
        border_val_seg: constant fill value for segmentation borders.
        pin_memory: forwarded to the augmenters.
        seeds_train / seeds_val: per-worker seeds.
        regions: if given, segmentations are converted to region maps.

    Returns:
        Tuple ``(batchgenerator_train, batchgenerator_val)``.
    """
    assert params.get(
        'mirror') is None, "old version of params, use new keyword do_mirror"
    tr_transforms = []

    if params.get("selected_data_channels") is not None:
        tr_transforms.append(
            DataChannelSelectionTransform(
                params.get("selected_data_channels")))

    if params.get("selected_seg_channels") is not None:
        tr_transforms.append(
            SegChannelSelectionTransform(params.get("selected_seg_channels")))

    # don't do color augmentations while in 2d mode with 3d data because the color channel is overloaded!!
    if params.get("dummy_2D") is not None and params.get("dummy_2D"):
        tr_transforms.append(Convert3DTo2DTransform())
        patch_size_spatial = patch_size[1:]
    else:
        patch_size_spatial = patch_size

    # Set order_data=0 and order_seg=0 for some more speed for cascade???
    tr_transforms.append(
        SpatialTransform(patch_size_spatial,
                         patch_center_dist_from_border=None,
                         do_elastic_deform=params.get("do_elastic"),
                         alpha=params.get("elastic_deform_alpha"),
                         sigma=params.get("elastic_deform_sigma"),
                         do_rotation=params.get("do_rotation"),
                         angle_x=params.get("rotation_x"),
                         angle_y=params.get("rotation_y"),
                         angle_z=params.get("rotation_z"),
                         do_scale=params.get("do_scaling"),
                         scale=params.get("scale_range"),
                         border_mode_data=params.get("border_mode_data"),
                         border_cval_data=0,
                         order_data=3,
                         border_mode_seg="constant",
                         border_cval_seg=border_val_seg,
                         order_seg=1,
                         random_crop=params.get("random_crop"),
                         p_el_per_sample=params.get("p_eldef"),
                         p_scale_per_sample=params.get("p_scale"),
                         p_rot_per_sample=params.get("p_rot"),
                         independent_scale_for_each_axis=params.get(
                             "independent_scale_factor_for_each_axis")))
    if params.get("dummy_2D") is not None and params.get("dummy_2D"):
        tr_transforms.append(Convert2DTo3DTransform())

    if params.get("do_gamma"):
        tr_transforms.append(
            GammaTransform(params.get("gamma_range"),
                           False,
                           True,
                           retain_stats=params.get("gamma_retain_stats"),
                           p_per_sample=params["p_gamma"]))

    if params.get("do_mirror"):
        tr_transforms.append(MirrorTransform(params.get("mirror_axes")))

    if params.get("mask_was_used_for_normalization") is not None:
        mask_was_used_for_normalization = params.get(
            "mask_was_used_for_normalization")
        tr_transforms.append(
            MaskTransform(mask_was_used_for_normalization,
                          mask_idx_in_seg=0,
                          set_outside_to=0))

    tr_transforms.append(RemoveLabelTransform(-1, 0))

    if params.get("move_last_seg_chanel_to_data") is not None and params.get(
            "move_last_seg_chanel_to_data"):
        tr_transforms.append(
            MoveSegAsOneHotToData(1, params.get("all_segmentation_labels"),
                                  'seg', 'data'))
        # BUGFIX: was "... and not None and ...", which is always True;
        # the intended check is "is not None".
        if params.get("cascade_do_cascade_augmentations"
                      ) is not None and params.get(
                          "cascade_do_cascade_augmentations"):
            # Remove the following transforms to remove cascade DA ??
            tr_transforms.append(
                ApplyRandomBinaryOperatorTransform(
                    channel_idx=list(
                        range(-len(params.get("all_segmentation_labels")), 0)),
                    p_per_sample=params.get(
                        "cascade_random_binary_transform_p"),
                    key="data",
                    strel_size=params.get(
                        "cascade_random_binary_transform_size")))
            tr_transforms.append(
                RemoveRandomConnectedComponentFromOneHotEncodingTransform(
                    channel_idx=list(
                        range(-len(params.get("all_segmentation_labels")), 0)),
                    key="data",
                    p_per_sample=params.get("cascade_remove_conn_comp_p"),
                    # BUGFIX: these two values were swapped relative to the
                    # parameter names their config keys refer to.
                    fill_with_other_class_p=params.get(
                        "cascade_remove_conn_comp_fill_with_other_class_p"),
                    dont_do_if_covers_more_than_X_percent=params.get(
                        "cascade_remove_conn_comp_max_size_percent_threshold")))

    tr_transforms.append(RenameTransform('seg', 'target', True))

    if regions is not None:
        tr_transforms.append(
            ConvertSegmentationToRegionsTransform(regions, 'target', 'target'))

    tr_transforms.append(NumpyToTensor(['data', 'target'], 'float'))

    tr_transforms = Compose(tr_transforms)
    # from batchgenerators.dataloading import SingleThreadedAugmenter
    # batchgenerator_train = SingleThreadedAugmenter(dataloader_train, tr_transforms)
    # import IPython;IPython.embed()

    batchgenerator_train = MultiThreadedAugmenter(
        dataloader_train,
        tr_transforms,
        params.get('num_threads'),
        params.get("num_cached_per_thread"),
        seeds=seeds_train,
        pin_memory=pin_memory)

    # Validation gets no stochastic augmentation: only label cleanup,
    # channel selection and the conversions needed by the network.
    val_transforms = []
    val_transforms.append(RemoveLabelTransform(-1, 0))
    if params.get("selected_data_channels") is not None:
        val_transforms.append(
            DataChannelSelectionTransform(
                params.get("selected_data_channels")))
    if params.get("selected_seg_channels") is not None:
        val_transforms.append(
            SegChannelSelectionTransform(params.get("selected_seg_channels")))

    if params.get("move_last_seg_chanel_to_data") is not None and params.get(
            "move_last_seg_chanel_to_data"):
        val_transforms.append(
            MoveSegAsOneHotToData(1, params.get("all_segmentation_labels"),
                                  'seg', 'data'))

    val_transforms.append(RenameTransform('seg', 'target', True))

    if regions is not None:
        val_transforms.append(
            ConvertSegmentationToRegionsTransform(regions, 'target', 'target'))

    val_transforms.append(NumpyToTensor(['data', 'target'], 'float'))
    val_transforms = Compose(val_transforms)

    # batchgenerator_val = SingleThreadedAugmenter(dataloader_val, val_transforms)
    batchgenerator_val = MultiThreadedAugmenter(
        dataloader_val,
        val_transforms,
        max(params.get('num_threads') // 2, 1),
        params.get("num_cached_per_thread"),
        seeds=seeds_val,
        pin_memory=pin_memory)
    return batchgenerator_train, batchgenerator_val
             },
             encode_block=ResBlockStack,
             encode_kwargs_fn=encode_kwargs_fn,
             decode_block=ResBlock).cuda()

# Training setup: 3D patch size (x, y, z — TODO confirm axis order against
# the dataloader), Adam optimizer on the model built above, and an LR
# schedule that cuts the rate by 5x after 30 stagnant epochs.
patch_size = (160, 160, 80)
optimizer = optim.Adam(model.parameters(), lr=1e-4)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                 factor=0.2,
                                                 patience=30)

tr_transform = Compose([
    GammaTransform((0.9, 1.1)),
    ContrastAugmentationTransform((0.9, 1.1)),
    BrightnessMultiplicativeTransform((0.9, 1.1)),
    MirrorTransform(axes=[0]),
    SpatialTransform_2(
        patch_size,
        (90, 90, 50),
        random_crop=True,
        do_elastic_deform=True,
        deformation_scale=(0, 0.05),
        do_rotation=True,
        angle_x=(-0.1 * np.pi, 0.1 * np.pi),
        angle_y=(0, 0),
        angle_z=(0, 0),
        do_scale=True,
        scale=(0.9, 1.1),
        border_mode_data='constant',
    ),
    RandomCropTransform(crop_size=patch_size),
Пример #22
0
def get_moreDA_augmentation(dataloader_train,
                            dataloader_val,
                            patch_size,
                            params=default_3D_augmentation_params,
                            border_val_seg=-1,
                            seeds_train=None,
                            seeds_val=None,
                            order_seg=1,
                            order_data=3,
                            deep_supervision_scales=None,
                            soft_ds=False,
                            classes=None,
                            pin_memory=True,
                            regions=None,
                            use_nondetMultiThreadedAugmenter: bool = False):
    """Wrap train/val dataloaders in multi-threaded augmentation pipelines.

    The training pipeline applies (in order): channel selection, spatial
    transform (elastic/rotation/scaling), optional dummy-2D handling,
    Gaussian noise/blur, brightness, contrast, simulated low resolution,
    gamma, mirroring, normalization-mask cleanup, label cleanup, optional
    cascade-specific one-hot seg augmentations, region conversion, deep
    supervision downsampling, and tensor conversion. The validation
    pipeline applies only the non-destructive subset (channel selection,
    label cleanup, region conversion, DS downsampling, tensor conversion).

    All augmentation settings are read from ``params`` (a dict). Returns
    ``(batchgenerator_train, batchgenerator_val)``.

    NOTE(review): ``params`` defaults to the shared module-level dict
    ``default_3D_augmentation_params`` — mutating it affects later calls.
    """
    # Reject the pre-"do_mirror" parameter dict layout early.
    assert params.get(
        'mirror') is None, "old version of params, use new keyword do_mirror"

    tr_transforms = []

    if params.get("selected_data_channels") is not None:
        tr_transforms.append(
            DataChannelSelectionTransform(
                params.get("selected_data_channels")))

    if params.get("selected_seg_channels") is not None:
        tr_transforms.append(
            SegChannelSelectionTransform(params.get("selected_seg_channels")))

    # don't do color augmentations while in 2d mode with 3d data because the color channel is overloaded!!
    if params.get("dummy_2D") is not None and params.get("dummy_2D"):
        ignore_axes = (0, )
        tr_transforms.append(Convert3DTo2DTransform())
        patch_size_spatial = patch_size[1:]
    else:
        patch_size_spatial = patch_size
        ignore_axes = None

    tr_transforms.append(
        SpatialTransform(patch_size_spatial,
                         patch_center_dist_from_border=None,
                         do_elastic_deform=params.get("do_elastic"),
                         alpha=params.get("elastic_deform_alpha"),
                         sigma=params.get("elastic_deform_sigma"),
                         do_rotation=params.get("do_rotation"),
                         angle_x=params.get("rotation_x"),
                         angle_y=params.get("rotation_y"),
                         angle_z=params.get("rotation_z"),
                         p_rot_per_axis=params.get("rotation_p_per_axis"),
                         do_scale=params.get("do_scaling"),
                         scale=params.get("scale_range"),
                         border_mode_data=params.get("border_mode_data"),
                         border_cval_data=0,
                         order_data=order_data,
                         border_mode_seg="constant",
                         border_cval_seg=border_val_seg,
                         order_seg=order_seg,
                         random_crop=params.get("random_crop"),
                         p_el_per_sample=params.get("p_eldef"),
                         p_scale_per_sample=params.get("p_scale"),
                         p_rot_per_sample=params.get("p_rot"),
                         independent_scale_for_each_axis=params.get(
                             "independent_scale_factor_for_each_axis")))

    if params.get("dummy_2D"):
        tr_transforms.append(Convert2DTo3DTransform())

    # we need to put the color augmentations after the dummy 2d part (if applicable). Otherwise the overloaded color
    # channel gets in the way
    tr_transforms.append(GaussianNoiseTransform(p_per_sample=0.1))
    tr_transforms.append(
        GaussianBlurTransform((0.5, 1.),
                              different_sigma_per_channel=True,
                              p_per_sample=0.2,
                              p_per_channel=0.5))
    tr_transforms.append(
        BrightnessMultiplicativeTransform(multiplier_range=(0.75, 1.25),
                                          p_per_sample=0.15))

    if params.get("do_additive_brightness"):
        tr_transforms.append(
            BrightnessTransform(
                params.get("additive_brightness_mu"),
                params.get("additive_brightness_sigma"),
                True,
                p_per_sample=params.get("additive_brightness_p_per_sample"),
                p_per_channel=params.get("additive_brightness_p_per_channel")))

    tr_transforms.append(ContrastAugmentationTransform(p_per_sample=0.15))
    tr_transforms.append(
        SimulateLowResolutionTransform(zoom_range=(0.5, 1),
                                       per_channel=True,
                                       p_per_channel=0.5,
                                       order_downsample=0,
                                       order_upsample=3,
                                       p_per_sample=0.25,
                                       ignore_axes=ignore_axes))
    # The inverted-gamma transform is applied unconditionally; the
    # non-inverted one below is gated on "do_gamma".
    tr_transforms.append(
        GammaTransform(params.get("gamma_range"),
                       True,
                       True,
                       retain_stats=params.get("gamma_retain_stats"),
                       p_per_sample=0.1))  # inverted gamma

    if params.get("do_gamma"):
        tr_transforms.append(
            GammaTransform(params.get("gamma_range"),
                           False,
                           True,
                           retain_stats=params.get("gamma_retain_stats"),
                           p_per_sample=params["p_gamma"]))

    # NOTE(review): params.get("mirror") is always None here (asserted at
    # the top of the function), so only "do_mirror" can enable this branch.
    if params.get("do_mirror") or params.get("mirror"):
        tr_transforms.append(MirrorTransform(params.get("mirror_axes")))

    if params.get("mask_was_used_for_normalization") is not None:
        mask_was_used_for_normalization = params.get(
            "mask_was_used_for_normalization")
        tr_transforms.append(
            MaskTransform(mask_was_used_for_normalization,
                          mask_idx_in_seg=0,
                          set_outside_to=0))

    tr_transforms.append(RemoveLabelTransform(-1, 0))

    # Cascade support: move the previous stage's segmentation (last seg
    # channel) into the data as one-hot channels and optionally corrupt it.
    if params.get("move_last_seg_chanel_to_data") is not None and params.get(
            "move_last_seg_chanel_to_data"):
        tr_transforms.append(
            MoveSegAsOneHotToData(1, params.get("all_segmentation_labels"),
                                  'seg', 'data'))
        if params.get(
                "cascade_do_cascade_augmentations") is not None and params.get(
                    "cascade_do_cascade_augmentations"):
            if params.get("cascade_random_binary_transform_p") > 0:
                tr_transforms.append(
                    ApplyRandomBinaryOperatorTransform(
                        channel_idx=list(
                            range(-len(params.get("all_segmentation_labels")),
                                  0)),
                        p_per_sample=params.get(
                            "cascade_random_binary_transform_p"),
                        key="data",
                        strel_size=params.get(
                            "cascade_random_binary_transform_size"),
                        p_per_label=params.get(
                            "cascade_random_binary_transform_p_per_label")))
            if params.get("cascade_remove_conn_comp_p") > 0:
                # NOTE(review): the lookup keys for fill_with_other_class_p
                # and dont_do_if_covers_more_than_X_percent appear swapped
                # relative to the keyword-argument names — confirm whether
                # this cross-wiring is intentional before changing it.
                tr_transforms.append(
                    RemoveRandomConnectedComponentFromOneHotEncodingTransform(
                        channel_idx=list(
                            range(-len(params.get("all_segmentation_labels")),
                                  0)),
                        key="data",
                        p_per_sample=params.get("cascade_remove_conn_comp_p"),
                        fill_with_other_class_p=params.get(
                            "cascade_remove_conn_comp_max_size_percent_threshold"
                        ),
                        dont_do_if_covers_more_than_X_percent=params.get(
                            "cascade_remove_conn_comp_fill_with_other_class_p")
                    ))

    tr_transforms.append(RenameTransform('seg', 'target', True))

    if regions is not None:
        tr_transforms.append(
            ConvertSegmentationToRegionsTransform(regions, 'target', 'target'))

    if deep_supervision_scales is not None:
        if soft_ds:
            assert classes is not None
            tr_transforms.append(
                DownsampleSegForDSTransform3(deep_supervision_scales, 'target',
                                             'target', classes))
        else:
            tr_transforms.append(
                DownsampleSegForDSTransform2(deep_supervision_scales,
                                             0,
                                             input_key='target',
                                             output_key='target'))

    tr_transforms.append(NumpyToTensor(['data', 'target'], 'float'))
    tr_transforms = Compose(tr_transforms)

    if use_nondetMultiThreadedAugmenter:
        if NonDetMultiThreadedAugmenter is None:
            raise RuntimeError(
                'NonDetMultiThreadedAugmenter is not yet available')
        batchgenerator_train = NonDetMultiThreadedAugmenter(
            dataloader_train,
            tr_transforms,
            params.get('num_threads'),
            params.get("num_cached_per_thread"),
            seeds=seeds_train,
            pin_memory=pin_memory)
    else:
        batchgenerator_train = MultiThreadedAugmenter(
            dataloader_train,
            tr_transforms,
            params.get('num_threads'),
            params.get("num_cached_per_thread"),
            seeds=seeds_train,
            pin_memory=pin_memory)
    # batchgenerator_train = SingleThreadedAugmenter(dataloader_train, tr_transforms)
    # import IPython;IPython.embed()

    # Validation pipeline: no intensity/spatial augmentation, only the
    # bookkeeping transforms needed to match the training data layout.
    val_transforms = []
    val_transforms.append(RemoveLabelTransform(-1, 0))
    if params.get("selected_data_channels") is not None:
        val_transforms.append(
            DataChannelSelectionTransform(
                params.get("selected_data_channels")))
    if params.get("selected_seg_channels") is not None:
        val_transforms.append(
            SegChannelSelectionTransform(params.get("selected_seg_channels")))

    if params.get("move_last_seg_chanel_to_data") is not None and params.get(
            "move_last_seg_chanel_to_data"):
        val_transforms.append(
            MoveSegAsOneHotToData(1, params.get("all_segmentation_labels"),
                                  'seg', 'data'))

    val_transforms.append(RenameTransform('seg', 'target', True))

    if regions is not None:
        val_transforms.append(
            ConvertSegmentationToRegionsTransform(regions, 'target', 'target'))

    if deep_supervision_scales is not None:
        if soft_ds:
            assert classes is not None
            val_transforms.append(
                DownsampleSegForDSTransform3(deep_supervision_scales, 'target',
                                             'target', classes))
        else:
            val_transforms.append(
                DownsampleSegForDSTransform2(deep_supervision_scales,
                                             0,
                                             input_key='target',
                                             output_key='target'))

    val_transforms.append(NumpyToTensor(['data', 'target'], 'float'))
    val_transforms = Compose(val_transforms)

    # Validation uses half the worker threads of training (at least one).
    if use_nondetMultiThreadedAugmenter:
        if NonDetMultiThreadedAugmenter is None:
            raise RuntimeError(
                'NonDetMultiThreadedAugmenter is not yet available')
        batchgenerator_val = NonDetMultiThreadedAugmenter(
            dataloader_val,
            val_transforms,
            max(params.get('num_threads') // 2, 1),
            params.get("num_cached_per_thread"),
            seeds=seeds_val,
            pin_memory=pin_memory)
    else:
        batchgenerator_val = MultiThreadedAugmenter(
            dataloader_val,
            val_transforms,
            max(params.get('num_threads') // 2, 1),
            params.get("num_cached_per_thread"),
            seeds=seeds_val,
            pin_memory=pin_memory)
    # batchgenerator_val = SingleThreadedAugmenter(dataloader_val, val_transforms)

    return batchgenerator_train, batchgenerator_val
    do_rotation=True,
    angle_z=(0, 2 * np.pi),  # 旋转
    do_scale=True,
    scale=(0.3, 3.),  # 缩放
    border_mode_data='constant',
    border_cval_data=0,
    order_data=1,
    random_crop=False)
my_transforms.append(spatial_transform)
GaussianNoise = GaussianNoiseTransform()  # Gaussian noise
my_transforms.append(GaussianNoise)
GaussianBlur = GaussianBlurTransform()  # Gaussian blur
my_transforms.append(GaussianBlur)
Brightness = BrightnessTransform(0, 0.2)  # additive brightness (mu=0, sigma=0.2)
my_transforms.append(Brightness)
brightness_transform = ContrastAugmentationTransform(
    (0.3, 3.), preserve_range=True)  # contrast
my_transforms.append(brightness_transform)
SimulateLowResolution = SimulateLowResolutionTransform()  # simulated low resolution
my_transforms.append(SimulateLowResolution)
Gamma = GammaTransform()  # gamma augmentation
my_transforms.append(Gamma)
mirror_transform = MirrorTransform(axes=(0, 1))  # mirroring
my_transforms.append(mirror_transform)
all_transforms = Compose(my_transforms)
# 1 worker thread, 2 batches cached ahead.
multithreaded_generator = MultiThreadedAugmenter(batchgen, all_transforms, 1,
                                                 2)

# Pull one augmented batch and visualize it.
t = multithreaded_generator.next()
plot_batch(t)
Пример #24
0
    angle_z=(-5 / 360. * 2 * np.pi, 5 / 360. * 2 * np.pi),
    do_scale=True,
    scale=(0.9, 1.02),
    border_mode_data='constant',
    border_cval_data=0,
    border_mode_seg='constant',
    border_cval_seg=0,
    order_seg=1,
    order_data=3,
    random_crop=False,
    p_el_per_sample=0.1,
    p_rot_per_sample=0.1,
    p_scale_per_sample=0.1)

my_transforms.append(spatial_transform)
# Random mirroring along all three spatial axes.
my_transforms.append(MirrorTransform(axes=(0, 1, 2)))
my_transforms.append(
    GammaTransform(gamma_range=(0.7, 1.),
                   invert_image=False,
                   per_channel=True,
                   p_per_sample=0.1))

all_transforms = Compose(my_transforms)

train_loader = SingleThreadedAugmenter(
    batchgen, all_transforms
)  #data loader for training, applying on the fly transformation

# add other data loaders
test_loader = torch.utils.data.DataLoader(
    dataset_test,
Пример #25
0
    def get_train_transforms(self) -> List[AbstractTransform]:
        """Assemble the training-time augmentation pipeline.

        Builds, in order: seg-channel selection, spatial transform
        (rotation/scaling, optionally applied slice-wise for anisotropic
        data), rot90/transpose between same-extent axes, intensity
        augmentations (median filter or Gaussian blur, Gaussian noise,
        brightness, contrast, simulated low resolution, gamma, blank
        rectangles, local brightness/gamma gradients, sharpening),
        mirroring, normalization-mask cleanup, cascade-specific one-hot
        seg augmentations, region/deep-supervision conversion and tensor
        conversion.

        Returns:
            The list of transforms to apply to each training batch.
        """
        # Axes whose patch extent equals that of other axes can be swapped
        # (transposed) or rotated by multiples of 90 degrees without
        # changing the patch shape; collect those axes here.
        matching_axes = np.array(
            [sum([i == j for j in self.patch_size]) for i in self.patch_size])
        valid_axes = list(np.where(matching_axes == np.max(matching_axes))[0])

        tr_transforms = []

        if self.data_aug_params['selected_seg_channels'] is not None:
            tr_transforms.append(
                SegChannelSelectionTransform(
                    self.data_aug_params['selected_seg_channels']))

        # For strongly anisotropic data, apply the spatial transform to 2D
        # slices only (first axis is excluded from resampling artifacts).
        if self.do_dummy_2D_aug:
            ignore_axes = (0, )
            tr_transforms.append(Convert3DTo2DTransform())
            patch_size_spatial = self.patch_size[1:]
        else:
            patch_size_spatial = self.patch_size
            ignore_axes = None

        tr_transforms.append(
            SpatialTransform(
                patch_size_spatial,
                patch_center_dist_from_border=None,
                do_elastic_deform=False,
                do_rotation=True,
                angle_x=self.data_aug_params["rotation_x"],
                angle_y=self.data_aug_params["rotation_y"],
                angle_z=self.data_aug_params["rotation_z"],
                p_rot_per_axis=0.5,
                do_scale=True,
                scale=self.data_aug_params['scale_range'],
                border_mode_data="constant",
                border_cval_data=0,
                order_data=3,
                border_mode_seg="constant",
                border_cval_seg=-1,
                order_seg=1,
                random_crop=False,
                p_el_per_sample=0.2,
                p_scale_per_sample=0.2,
                p_rot_per_sample=0.4,
                independent_scale_for_each_axis=True,
            ))

        if self.do_dummy_2D_aug:
            tr_transforms.append(Convert2DTo3DTransform())

        # At least two axes share the same extent: allow 90-degree
        # rotations and axis transposition between them.
        if np.any(matching_axes > 1):
            tr_transforms.append(
                Rot90Transform((0, 1, 2, 3),
                               axes=valid_axes,
                               data_key='data',
                               label_key='seg',
                               p_per_sample=0.5))
            tr_transforms.append(
                TransposeAxesTransform(valid_axes,
                                       data_key='data',
                                       label_key='seg',
                                       p_per_sample=0.5))

        # Exactly one of median filtering or Gaussian blur per sample.
        tr_transforms.append(
            OneOfTransform([
                MedianFilterTransform((2, 8),
                                      same_for_each_channel=False,
                                      p_per_sample=0.2,
                                      p_per_channel=0.5),
                GaussianBlurTransform((0.3, 1.5),
                                      different_sigma_per_channel=True,
                                      p_per_sample=0.2,
                                      p_per_channel=0.5)
            ]))

        tr_transforms.append(GaussianNoiseTransform(p_per_sample=0.1))

        tr_transforms.append(
            BrightnessTransform(0,
                                0.5,
                                per_channel=True,
                                p_per_sample=0.1,
                                p_per_channel=0.5))

        # Contrast augmentation, randomly with or without range clipping.
        tr_transforms.append(
            OneOfTransform([
                ContrastAugmentationTransform(contrast_range=(0.5, 2),
                                              preserve_range=True,
                                              per_channel=True,
                                              data_key='data',
                                              p_per_sample=0.2,
                                              p_per_channel=0.5),
                ContrastAugmentationTransform(contrast_range=(0.5, 2),
                                              preserve_range=False,
                                              per_channel=True,
                                              data_key='data',
                                              p_per_sample=0.2,
                                              p_per_channel=0.5),
            ]))

        tr_transforms.append(
            SimulateLowResolutionTransform(zoom_range=(0.25, 1),
                                           per_channel=True,
                                           p_per_channel=0.5,
                                           order_downsample=0,
                                           order_upsample=3,
                                           p_per_sample=0.15,
                                           ignore_axes=ignore_axes))

        # Gamma augmentation: once on the inverted image, once on the
        # original. (Fix: the second instance was a copy of the first with
        # invert_image=True; the inverted/non-inverted pair is the intended
        # pattern, as used by the other augmentation pipelines in this file.)
        tr_transforms.append(
            GammaTransform((0.7, 1.5),
                           invert_image=True,
                           per_channel=True,
                           retain_stats=True,
                           p_per_sample=0.1))
        tr_transforms.append(
            GammaTransform((0.7, 1.5),
                           invert_image=False,
                           per_channel=True,
                           retain_stats=True,
                           p_per_sample=0.1))

        if self.do_mirroring:
            tr_transforms.append(MirrorTransform(self.mirror_axes))

        # Blank out 1-5 random rectangles (per-axis size between 1/10 and
        # 1/3 of the patch extent), filled with the image mean.
        tr_transforms.append(
            BlankRectangleTransform([[max(1, p // 10), p // 3]
                                     for p in self.patch_size],
                                    rectangle_value=np.mean,
                                    num_rectangles=(1, 5),
                                    force_square=False,
                                    p_per_sample=0.4,
                                    p_per_channel=0.5))

        # NOTE(review): x[y] // 6 is 0 for patch extents < 6, which makes
        # np.log emit -inf — confirm all patch extents are >= 6.
        tr_transforms.append(
            BrightnessGradientAdditiveTransform(
                lambda x, y: np.exp(
                    np.random.uniform(np.log(x[y] // 6), np.log(x[y]))),
                (-0.5, 1.5),
                max_strength=lambda x, y: np.random.uniform(-5, -1)
                if np.random.uniform() < 0.5 else np.random.uniform(1, 5),
                mean_centered=False,
                same_for_all_channels=False,
                p_per_sample=0.3,
                p_per_channel=0.5))

        tr_transforms.append(
            LocalGammaTransform(
                lambda x, y: np.exp(
                    np.random.uniform(np.log(x[y] // 6), np.log(x[y]))),
                (-0.5, 1.5),
                lambda: np.random.uniform(0.01, 0.8)
                if np.random.uniform() < 0.5 else np.random.uniform(1.5, 4),
                same_for_all_channels=False,
                p_per_sample=0.3,
                p_per_channel=0.5))

        tr_transforms.append(
            SharpeningTransform(strength=(0.1, 1),
                                same_for_each_channel=False,
                                p_per_sample=0.2,
                                p_per_channel=0.5))

        # Zero out everything outside the normalization mask so augmented
        # intensities do not leak into the masked-out background.
        if any(self.use_mask_for_norm.values()):
            tr_transforms.append(
                MaskTransform(self.use_mask_for_norm,
                              mask_idx_in_seg=0,
                              set_outside_to=0))

        tr_transforms.append(RemoveLabelTransform(-1, 0))

        # Cascade support: move the previous stage's segmentation into the
        # data as one-hot channels, then randomly corrupt it so the network
        # cannot blindly trust it.
        if self.data_aug_params["move_last_seg_chanel_to_data"]:
            all_class_labels = np.arange(1, self.num_classes)
            tr_transforms.append(
                MoveSegAsOneHotToData(1, all_class_labels, 'seg', 'data'))
            if self.data_aug_params["cascade_do_cascade_augmentations"]:
                tr_transforms.append(
                    ApplyRandomBinaryOperatorTransform(channel_idx=list(
                        range(-len(all_class_labels), 0)),
                                                       p_per_sample=0.4,
                                                       key="data",
                                                       strel_size=(1, 8),
                                                       p_per_label=1))

                tr_transforms.append(
                    RemoveRandomConnectedComponentFromOneHotEncodingTransform(
                        channel_idx=list(range(-len(all_class_labels), 0)),
                        key="data",
                        p_per_sample=0.2,
                        fill_with_other_class_p=0.15,
                        dont_do_if_covers_more_than_X_percent=0))

        tr_transforms.append(RenameTransform('seg', 'target', True))

        if self.regions is not None:
            tr_transforms.append(
                ConvertSegmentationToRegionsTransform(self.regions, 'target',
                                                      'target'))

        if self.deep_supervision_scales is not None:
            tr_transforms.append(
                DownsampleSegForDSTransform2(self.deep_supervision_scales,
                                             0,
                                             input_key='target',
                                             output_key='target'))

        tr_transforms.append(NumpyToTensor(['data', 'target'], 'float'))
        return tr_transforms