Example #1
    def _augment_data(self, batch_generator, type=None):
        tfs = []  # transforms

        if self.Config.NORMALIZE_DATA:
            tfs.append(
                ZeroMeanUnitVarianceTransform_Standalone(
                    per_channel=self.Config.NORMALIZE_PER_CHANNEL))

        # Not used, because these transformations are not easily invertible with the batchgenerators framework:
        #  mirroring would be the only easy test-time augmentation (DAug), but the model was not trained with it
        # if self.Config.TEST_TIME_DAUG:
        # from batchgenerators.transforms.spatial_transforms import SpatialTransform
        # center_dist_from_border = int(self.Config.INPUT_DIM[0] / 2.) - 10  # (144,144) -> 62
        # tfs.append(SpatialTransform(self.Config.INPUT_DIM,
        #                             patch_center_dist_from_border=center_dist_from_border,
        #                             do_elastic_deform=True, alpha=(90., 120.), sigma=(9., 11.),
        #                             do_rotation=True, angle_x=(-0.8, 0.8), angle_y=(-0.8, 0.8),
        #                             angle_z=(-0.8, 0.8),
        #                             do_scale=True, scale=(0.9, 1.5), border_mode_data='constant',
        #                             border_cval_data=0,
        #                             order_data=3,
        #                             border_mode_seg='constant', border_cval_seg=0, order_seg=0, random_crop=True))
        # tfs.append(ResampleTransform(zoom_range=(0.5, 1)))
        # tfs.append(GaussianNoiseTransform(noise_variance=(0, 0.05)))
        # tfs.append(ContrastAugmentationTransform(contrast_range=(0.7, 1.3), preserve_range=True, per_channel=False))
        # tfs.append(BrightnessMultiplicativeTransform(multiplier_range=(0.7, 1.3), per_channel=False))

        tfs.append(NumpyToTensor(keys=["data", "seg"], cast_to="float"))

        batch_gen = SingleThreadedAugmenter(batch_generator, Compose(tfs))
        return batch_gen
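A minimal usage sketch (not part of the original example), assuming `trainer` is an instance of the surrounding class and `train_loader` is a batchgenerators DataLoader yielding dicts with "data" and "seg" numpy arrays of shape (batch_size, channels, x, y); both names are hypothetical:

# Hypothetical usage; trainer and train_loader are illustrative names only.
batch_gen = trainer._augment_data(train_loader)
batch = next(batch_gen)                    # Compose(tfs) is applied to each yielded batch
data, seg = batch["data"], batch["seg"]    # torch float tensors after NumpyToTensor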
Example #2
    def _augment_data(self, batch_generator, type=None):
        tfs = []

        if self.Config.NORMALIZE_DATA:
            tfs.append(
                ZeroMeanUnitVarianceTransform_Standalone(
                    per_channel=self.Config.NORMALIZE_PER_CHANNEL))

        tfs.append(NumpyToTensor(keys=["data", "seg"], cast_to="float"))

        batch_gen = SingleThreadedAugmenter(batch_generator, Compose(tfs))
        return batch_gen
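The snippets above omit their imports. As a sketch, the import paths below match recent batchgenerators releases and may differ in older versions; ZeroMeanUnitVarianceTransform_Standalone is a project-specific transform, not part of batchgenerators:

# Import sketch for Examples #1 and #2 (batchgenerators paths only; may vary by version).
from batchgenerators.transforms.abstract_transforms import Compose
from batchgenerators.transforms.utility_transforms import NumpyToTensor
from batchgenerators.dataloading.single_threaded_augmenter import SingleThreadedAugmenter
# ZeroMeanUnitVarianceTransform_Standalone is defined in the project's own modules.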
Example #3
    def _augment_data(self, batch_generator, type=None):

        if self.Config.DATA_AUGMENTATION:
            num_processes = 15  # 15 is a bit faster than 8 on cluster
            # num_processes = multiprocessing.cpu_count()  # on cluster: gives all cores, not only assigned cores
        else:
            num_processes = 6

        tfs = []

        if self.Config.NORMALIZE_DATA:
            # TODO: use the original transform once the bug in batchgenerators is fixed
            # tfs.append(ZeroMeanUnitVarianceTransform(per_channel=self.Config.NORMALIZE_PER_CHANNEL))
            tfs.append(
                ZeroMeanUnitVarianceTransform_Standalone(
                    per_channel=self.Config.NORMALIZE_PER_CHANNEL))

        if self.Config.SPATIAL_TRANSFORM == "SpatialTransformPeaks":
            SpatialTransformUsed = SpatialTransformPeaks
        elif self.Config.SPATIAL_TRANSFORM == "SpatialTransformCustom":
            SpatialTransformUsed = SpatialTransformCustom
        else:
            SpatialTransformUsed = SpatialTransform

        if self.Config.DATA_AUGMENTATION:
            if type == "train":
                # patch_center_dist_from_border:
                #   if 144/2=72 -> patch is always exactly centered; smaller values shift it off center
                #   (the brain can then move partly off the image and get cut off)
                if self.Config.DAUG_SCALE:

                    if self.Config.INPUT_RESCALING:
                        source_mm = 2  # for bb
                        target_mm = float(self.Config.RESOLUTION[:-2])
                        scale_factor = target_mm / source_mm
                        scale = (scale_factor, scale_factor)
                    else:
                        scale = (0.9, 1.5)

                    if self.Config.PAD_TO_SQUARE:
                        patch_size = self.Config.INPUT_DIM
                    else:
                        patch_size = None  # keeps dimensions of the data

                    # spatial transform automatically crops/pads to correct size
                    center_dist_from_border = int(
                        self.Config.INPUT_DIM[0] / 2.) - 10  # (144,144) -> 62
                    tfs.append(
                        SpatialTransformUsed(
                            patch_size,
                            patch_center_dist_from_border=center_dist_from_border,
                            do_elastic_deform=self.Config.DAUG_ELASTIC_DEFORM,
                            alpha=self.Config.DAUG_ALPHA,
                            sigma=self.Config.DAUG_SIGMA,
                            do_rotation=self.Config.DAUG_ROTATE,
                            angle_x=self.Config.DAUG_ROTATE_ANGLE,
                            angle_y=self.Config.DAUG_ROTATE_ANGLE,
                            angle_z=self.Config.DAUG_ROTATE_ANGLE,
                            do_scale=True,
                            scale=scale,
                            border_mode_data='constant',
                            border_cval_data=0,
                            order_data=3,
                            border_mode_seg='constant',
                            border_cval_seg=0,
                            order_seg=0,
                            random_crop=True,
                            p_el_per_sample=self.Config.P_SAMP,
                            p_rot_per_sample=self.Config.P_SAMP,
                            p_scale_per_sample=self.Config.P_SAMP))

                if self.Config.DAUG_RESAMPLE:
                    tfs.append(
                        SimulateLowResolutionTransform(zoom_range=(0.5, 1),
                                                       p_per_sample=0.2,
                                                       per_channel=False))

                if self.Config.DAUG_RESAMPLE_LEGACY:
                    tfs.append(ResampleTransformLegacy(zoom_range=(0.5, 1)))

                if self.Config.DAUG_GAUSSIAN_BLUR:
                    tfs.append(
                        GaussianBlurTransform(
                            blur_sigma=self.Config.DAUG_BLUR_SIGMA,
                            different_sigma_per_channel=False,
                            p_per_sample=self.Config.P_SAMP))

                if self.Config.DAUG_NOISE:
                    tfs.append(
                        GaussianNoiseTransform(
                            noise_variance=self.Config.DAUG_NOISE_VARIANCE,
                            p_per_sample=self.Config.P_SAMP))

                if self.Config.DAUG_MIRROR:
                    tfs.append(MirrorTransform())

                if self.Config.DAUG_FLIP_PEAKS:
                    tfs.append(FlipVectorAxisTransform())

        tfs.append(NumpyToTensor(keys=["data", "seg"], cast_to="float"))

        # num_cached_per_queue of 1 or 2 does not make a noticeable difference
        batch_gen = MultiThreadedAugmenter(batch_generator,
                                           Compose(tfs),
                                           num_processes=num_processes,
                                           num_cached_per_queue=1,
                                           seeds=None,
                                           pin_memory=True)
        return batch_gen  # data: (batch_size, channels, x, y), seg: (batch_size, channels, x, y)
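Example #3 mixes standard batchgenerators transforms with project-specific ones. A hedged import sketch, using paths from recent batchgenerators releases (they may vary by version); the remaining transforms are defined in the surrounding project, not in batchgenerators:

# Import sketch for Example #3 (batchgenerators paths only; may vary by version).
from batchgenerators.transforms.abstract_transforms import Compose
from batchgenerators.transforms.utility_transforms import NumpyToTensor
from batchgenerators.transforms.spatial_transforms import SpatialTransform, MirrorTransform
from batchgenerators.transforms.noise_transforms import GaussianNoiseTransform, GaussianBlurTransform
from batchgenerators.transforms.resample_transforms import SimulateLowResolutionTransform
from batchgenerators.dataloading.multi_threaded_augmenter import MultiThreadedAugmenter
# SpatialTransformPeaks, SpatialTransformCustom, ResampleTransformLegacy,
# FlipVectorAxisTransform and ZeroMeanUnitVarianceTransform_Standalone are
# project-specific transforms and must be imported from the project's own modules.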