Ejemplo n.º 1
0
    def get_batches(self, batch_size=1):
        """Build an ordered slice batch generator over ``self.data``.

        An all-zero placeholder segmentation is paired with the data so the
        pipeline can also be used for prediction when no ground truth exists.

        Args:
            batch_size: number of slices per batch.

        Returns:
            A ``MultiThreadedAugmenter`` wrapping a ``SlicesBatchGenerator``;
            optional zero-mean/unit-variance normalization is followed by
            ``ReorderSegTransform``. Output shapes:
            data (batch_size, channels, x, y), seg (batch_size, x, y, channels).
        """
        clean_data = np.nan_to_num(self.data)

        # Placeholder mask for data without ground truth.
        cube_len = self.HP.INPUT_DIM[0]
        placeholder_seg = np.zeros(
            (cube_len, cube_len, cube_len, self.HP.NR_OF_CLASSES)
        ).astype(self.HP.LABELS_TYPE)

        # Keep a single worker: several workers return batches in random order,
        # and the global_idx of SlicesBatchGenerator only works with one process.
        single_worker = 1

        slice_gen = SlicesBatchGenerator((clean_data, placeholder_seg), BATCH_SIZE=batch_size)
        slice_gen.HP = self.HP

        if self.HP.NORMALIZE_DATA:
            pipeline = [ZeroMeanUnitVarianceTransform(per_channel=self.HP.NORMALIZE_PER_CHANNEL)]
        else:
            pipeline = []
        pipeline.append(ReorderSegTransform())

        return MultiThreadedAugmenter(
            slice_gen,
            Compose(pipeline),
            num_processes=single_worker,
            num_cached_per_queue=2,
            seeds=None,
        )
Ejemplo n.º 2
0
    def _augment_data(self, batch_generator, type=None):
        """Wrap ``batch_generator`` in a multi-process transform pipeline.

        Transforms are added in this order, each only when enabled in
        ``self.Config``:
        1. Zero-mean/unit-variance normalization.
        2. Only when DATA_AUGMENTATION is on and ``type == "train"``:
           spatial transform (elastic/rotation/scale), low-resolution
           simulation, Gaussian noise, mirroring and peak-axis flipping.

        Finally "data" and "seg" are converted to float tensors.

        Args:
            batch_generator: source of batches.
            type: split name; augmentation transforms are only added for "train".

        Returns:
            A ``MultiThreadedAugmenter`` producing
            data (batch_size, channels, x, y) and seg (batch_size, channels, x, y).
        """
        cfg = self.Config
        augmenting = cfg.DATA_AUGMENTATION
        # With augmentation use 16 workers (for 2D, 8 was a bit faster than 16).
        workers = 16 if augmenting else 6

        transforms = []
        if cfg.NORMALIZE_DATA:
            transforms.append(ZeroMeanUnitVarianceTransform(per_channel=cfg.NORMALIZE_PER_CHANNEL))

        if augmenting and type == "train":
            if cfg.DAUG_SCALE:
                # scale is inverted: 0.5 -> bigger; 2 -> smaller.
                # patch_center_dist_from_border: half the input size (144/2=72)
                # keeps the patch exactly centered; the smaller value used here
                # lets it drift off-center (the brain may then be cut off).
                border_dist = int(cfg.INPUT_DIM[0] / 2.) - 10  # (144,144) -> 62
                transforms.append(SpatialTransform(
                    cfg.INPUT_DIM,
                    patch_center_dist_from_border=border_dist,
                    do_elastic_deform=cfg.DAUG_ELASTIC_DEFORM,
                    alpha=(90., 120.),
                    sigma=(9., 11.),
                    do_rotation=cfg.DAUG_ROTATE,
                    angle_x=(-0.8, 0.8),
                    angle_y=(-0.8, 0.8),
                    angle_z=(-0.8, 0.8),
                    do_scale=True,
                    scale=(0.9, 1.5),
                    border_mode_data='constant',
                    border_cval_data=0,
                    order_data=3,
                    border_mode_seg='constant',
                    border_cval_seg=0,
                    order_seg=0,
                    random_crop=True,
                    p_el_per_sample=0.2,
                    p_rot_per_sample=0.2,
                    p_scale_per_sample=0.2,
                ))

            if cfg.DAUG_RESAMPLE:
                transforms.append(SimulateLowResolutionTransform(zoom_range=(0.5, 1), p_per_sample=0.2))

            if cfg.DAUG_NOISE:
                transforms.append(GaussianNoiseTransform(noise_variance=(0, 0.05), p_per_sample=0.2))

            if cfg.DAUG_MIRROR:
                transforms.append(MirrorTransform())

            if cfg.DAUG_FLIP_PEAKS:
                transforms.append(FlipVectorAxisTransform())

        transforms.append(NumpyToTensor(keys=["data", "seg"], cast_to="float"))

        # num_cached_per_queue of 1 or 2 makes no real difference.
        return MultiThreadedAugmenter(
            batch_generator,
            Compose(transforms),
            num_processes=workers,
            num_cached_per_queue=1,
            seeds=None,
            pin_memory=True,
        )
Ejemplo n.º 3
0
    def _augment_data(self, batch_generator, type=None):
        """Wrap ``batch_generator`` in a multi-process transform pipeline.

        Transforms are added in this order, each only when enabled in
        ``self.Config``:
        1. Zero-mean/unit-variance normalization.
        2. Only when DATA_AUGMENTATION is on and ``type == "train"``:
           - a spatial transform whose class is chosen by SPATIAL_TRANSFORM;
           - low-resolution simulation and legacy resampling;
           - Gaussian blur and Gaussian noise;
           - mirroring and peak-axis flipping.

        Finally "data" and "seg" are converted to float tensors.

        Args:
            batch_generator: source of batches.
            type: split name; augmentation transforms are only added for "train".

        Returns:
            A ``MultiThreadedAugmenter`` producing
            data (batch_size, channels, x, y) and seg (batch_size, channels, x, y).
        """
        cfg = self.Config
        augmenting = cfg.DATA_AUGMENTATION
        # With augmentation, 15 workers was a bit faster than 8 on the cluster.
        # multiprocessing.cpu_count() is avoided: on the cluster it reports all
        # cores, not only the assigned ones.
        workers = 15 if augmenting else 6

        transforms = []
        if cfg.NORMALIZE_DATA:
            transforms.append(ZeroMeanUnitVarianceTransform(per_channel=cfg.NORMALIZE_PER_CHANNEL))

        # Resolve the spatial transform class by name; any other value falls
        # back to the standard SpatialTransform.
        transform_name = cfg.SPATIAL_TRANSFORM
        if transform_name == "SpatialTransformPeaks":
            spatial_cls = SpatialTransformPeaks
        elif transform_name == "SpatialTransformCustom":
            spatial_cls = SpatialTransformCustom
        else:
            spatial_cls = SpatialTransform

        if augmenting and type == "train":
            if cfg.DAUG_SCALE:
                if cfg.INPUT_RESCALING:
                    # Fixed scale factor RESOLUTION / 2 mm (source_mm=2 "for bb");
                    # the last two characters of RESOLUTION are stripped before
                    # parsing it as a number.
                    src_mm = 2
                    dst_mm = float(cfg.RESOLUTION[:-2])
                    factor = dst_mm / src_mm
                    scale = (factor, factor)
                else:
                    scale = (0.9, 1.5)

                # None keeps the dimensions of the data; otherwise the spatial
                # transform crops/pads to INPUT_DIM.
                patch_size = cfg.INPUT_DIM if cfg.PAD_TO_SQUARE else None

                # patch_center_dist_from_border: half the input size (144/2=72)
                # keeps the patch exactly centered; the smaller value used here
                # lets it drift off-center (the brain may then be cut off).
                border_dist = int(cfg.INPUT_DIM[0] / 2.) - 10  # (144,144) -> 62
                transforms.append(spatial_cls(
                    patch_size,
                    patch_center_dist_from_border=border_dist,
                    do_elastic_deform=cfg.DAUG_ELASTIC_DEFORM,
                    alpha=cfg.DAUG_ALPHA,
                    sigma=cfg.DAUG_SIGMA,
                    do_rotation=cfg.DAUG_ROTATE,
                    angle_x=cfg.DAUG_ROTATE_ANGLE,
                    angle_y=cfg.DAUG_ROTATE_ANGLE,
                    angle_z=cfg.DAUG_ROTATE_ANGLE,
                    do_scale=True,
                    scale=scale,
                    border_mode_data='constant',
                    border_cval_data=0,
                    order_data=3,
                    border_mode_seg='constant',
                    border_cval_seg=0,
                    order_seg=0,
                    random_crop=True,
                    p_el_per_sample=cfg.P_SAMP,
                    p_rot_per_sample=cfg.P_SAMP,
                    p_scale_per_sample=cfg.P_SAMP,
                ))

            if cfg.DAUG_RESAMPLE:
                transforms.append(SimulateLowResolutionTransform(zoom_range=(0.5, 1), p_per_sample=0.2,
                                                                 per_channel=False))

            if cfg.DAUG_RESAMPLE_LEGACY:
                transforms.append(ResampleTransformLegacy(zoom_range=(0.5, 1)))

            if cfg.DAUG_GAUSSIAN_BLUR:
                transforms.append(GaussianBlurTransform(blur_sigma=cfg.DAUG_BLUR_SIGMA,
                                                        different_sigma_per_channel=False,
                                                        p_per_sample=cfg.P_SAMP))

            if cfg.DAUG_NOISE:
                transforms.append(GaussianNoiseTransform(noise_variance=cfg.DAUG_NOISE_VARIANCE,
                                                         p_per_sample=cfg.P_SAMP))

            if cfg.DAUG_MIRROR:
                transforms.append(MirrorTransform())

            if cfg.DAUG_FLIP_PEAKS:
                transforms.append(FlipVectorAxisTransform())

        transforms.append(NumpyToTensor(keys=["data", "seg"], cast_to="float"))

        # num_cached_per_queue of 1 or 2 makes no real difference.
        return MultiThreadedAugmenter(
            batch_generator,
            Compose(transforms),
            num_processes=workers,
            num_cached_per_queue=1,
            seeds=None,
            pin_memory=True,
        )
Ejemplo n.º 4
0
    def get_batches(self, batch_size=1):
        """Build an ordered slice batch generator for one subject.

        For ``HP.TYPE == "combined"`` the preloaded ``self.subject`` is fed to
        ``SlicesBatchGeneratorNpyImg_fusion``. Otherwise the feature image is
        loaded from ``self.data_dir`` and rescaled to the network input shape.
        The label image is loaded too when ``self.use_gt_mask`` is set; without
        ground truth an all-zero placeholder segmentation is used instead.

        Optional normalization runs first. With ``HP.TEST_TIME_DAUG``, spatial,
        contrast and brightness augmentation follow. ``ReorderSegTransform`` is
        always applied last.

        Args:
            batch_size: number of slices per batch.

        Returns:
            A ``MultiThreadedAugmenter`` producing
            data (batch_size, channels, x, y) and seg (batch_size, x, y, channels).
        """
        # Do not use more than 1 process if the original slice order must be
        # kept: workers return batches in random order.
        num_processes = 1

        if self.HP.TYPE == "combined":
            # Load from npy file for fusion
            data = self.subject
            seg = []
            # Exactly one subject is processed here, so the number of samples is
            # its number of slices, INPUT_DIM[0].
            nr_of_samples = self.HP.INPUT_DIM[0]
            # NOTE(review): int() truncates, so when nr_of_samples is not a
            # multiple of batch_size the trailing slices are not covered by
            # num_batches. Confirm the fusion generator handles this.
            num_batches = int(nr_of_samples / batch_size / num_processes)
            batch_gen = SlicesBatchGeneratorNpyImg_fusion(
                (data, seg),
                BATCH_SIZE=batch_size,
                num_batches=num_batches,
                seed=None)
        else:
            # Load features ("12g90g270g" is read from the 270g peaks file)
            if self.HP.FEATURES_FILENAME == "12g90g270g":
                features_path = join(self.data_dir, "270g_125mm_peaks.nii.gz")
            else:
                features_path = join(self.data_dir,
                                     self.HP.FEATURES_FILENAME + ".nii.gz")
            # img.get_data() is deprecated and was removed in nibabel 5.
            # np.asanyarray(img.dataobj) is nibabel's documented replacement with
            # the same return type (get_fdata() would always return floats).
            data = np.asanyarray(nib.load(features_path).dataobj)
            data = np.nan_to_num(data)
            data = DatasetUtils.scale_input_to_unet_shape(
                data, self.HP.DATASET, self.HP.RESOLUTION)
            # To test HCP_32g on the high-res net instead:
            # data = DatasetUtils.scale_input_to_unet_shape(data, "HCP_32g", "1.25mm")

            # Load segmentation
            if self.use_gt_mask:
                seg = np.asanyarray(nib.load(
                    join(self.data_dir,
                         self.HP.LABELS_FILENAME + ".nii.gz")).dataobj)

                # Label files in this list are used as loaded (no rescaling).
                if self.HP.LABELS_FILENAME not in [
                        "bundle_peaks_11_808080", "bundle_peaks_20_808080",
                        "bundle_peaks_808080", "bundle_masks_20_808080",
                        "bundle_masks_72_808080", "bundle_peaks_Part1_808080",
                        "bundle_peaks_Part2_808080",
                        "bundle_peaks_Part3_808080",
                        "bundle_peaks_Part4_808080"
                ]:
                    if self.HP.DATASET in ["HCP_2mm", "HCP_2.5mm", "HCP_32g"]:
                        # Passing "HCP" with the lower resolution makes
                        # scale_input_to_unet_shape downsample the HCP-sized seg mask.
                        seg = DatasetUtils.scale_input_to_unet_shape(
                            seg, "HCP", self.HP.RESOLUTION)
                    else:
                        seg = DatasetUtils.scale_input_to_unet_shape(
                            seg, self.HP.DATASET, self.HP.RESOLUTION)
            else:
                # Placeholder mask for data without ground truth (prediction only).
                seg = np.zeros(
                    (self.HP.INPUT_DIM[0], self.HP.INPUT_DIM[0],
                     self.HP.INPUT_DIM[0],
                     self.HP.NR_OF_CLASSES)).astype(self.HP.LABELS_TYPE)

            batch_gen = SlicesBatchGenerator((data, seg),
                                             batch_size=batch_size)

        batch_gen.HP = self.HP
        tfs = []  # transforms

        if self.HP.NORMALIZE_DATA:
            tfs.append(ZeroMeanUnitVarianceTransform(per_channel=False))

        if self.HP.TEST_TIME_DAUG:
            center_dist_from_border = int(
                self.HP.INPUT_DIM[0] / 2.) - 10  # (144,144) -> 62
            tfs.append(
                SpatialTransform(
                    self.HP.INPUT_DIM,
                    patch_center_dist_from_border=center_dist_from_border,
                    do_elastic_deform=True,
                    alpha=(90., 120.),
                    sigma=(9., 11.),
                    do_rotation=True,
                    angle_x=(-0.8, 0.8),
                    angle_y=(-0.8, 0.8),
                    angle_z=(-0.8, 0.8),
                    do_scale=True,
                    scale=(0.9, 1.5),
                    border_mode_data='constant',
                    border_cval_data=0,
                    order_data=3,
                    border_mode_seg='constant',
                    border_cval_seg=0,
                    order_seg=0,
                    random_crop=True))
            # tfs.append(ResampleTransform(zoom_range=(0.5, 1)))
            # tfs.append(GaussianNoiseTransform(noise_variance=(0, 0.05)))
            tfs.append(
                ContrastAugmentationTransform(contrast_range=(0.7, 1.3),
                                              preserve_range=True,
                                              per_channel=False))
            tfs.append(
                BrightnessMultiplicativeTransform(multiplier_range=(0.7, 1.3),
                                                  per_channel=False))

        tfs.append(ReorderSegTransform())
        # Only use num_processes=1, otherwise global_idx of SlicesBatchGenerator
        # does not work.
        batch_gen = MultiThreadedAugmenter(
            batch_gen,
            Compose(tfs),
            num_processes=num_processes,
            num_cached_per_queue=2,
            seeds=None
        )
        return batch_gen  # data: (batch_size, channels, x, y), seg: (batch_size, x, y, channels)
Ejemplo n.º 5
0
    def get_batches(self,
                    batch_size=128,
                    type=None,
                    subjects=None,
                    num_batches=None):
        """Build a randomly sampling slice batch generator wrapped in workers.

        Args:
            batch_size: number of slices per batch.
            type: split name; augmentation transforms are only added for "train".
            subjects: passed as ``data`` to the slice batch generator. Required.
            num_batches: accepted for backward compatibility but NOT used; the
                batch generators below are constructed without a batch count.

        Returns:
            A ``MultiThreadedAugmenter`` producing
            data (batch_size, channels, x, y) and seg (batch_size, channels, x, y).

        Raises:
            TypeError: if ``subjects`` is None.
        """
        # Fail fast in the caller's process instead of inside a worker.
        if subjects is None:
            raise TypeError("get_batches() requires 'subjects', got None")

        data = subjects
        seg = []

        # 6 processes -> >30GB RAM
        if self.HP.DATA_AUGMENTATION:
            num_processes = 8  # 6 is a bit faster than 16
        else:
            num_processes = 6

        # NOTE(review): num_batches is not forwarded anywhere; the generators
        # below only receive (data, seg) and batch_size. If a fixed number of
        # batches per epoch is wanted, forward it to the generator and confirm
        # that its constructor accepts one. For example, len(subjects) *
        # INPUT_DIM[0] / batch_size / num_processes gives exactly one epoch.

        if self.HP.TYPE == "combined":
            # Simple .npy loading is only slightly faster than Nifti (<10%) and
            # the f1 is not better.
            batch_gen = SlicesBatchGeneratorRandomNpyImg_fusion(
                (data, seg), batch_size=batch_size)
        else:
            batch_gen = SlicesBatchGeneratorRandomNiftiImg(
                (data, seg), batch_size=batch_size)
            # Alternative: SlicesBatchGeneratorRandomNiftiImg_5slices((data, seg), batch_size=batch_size)

        batch_gen.HP = self.HP
        tfs = []  # transforms

        if self.HP.NORMALIZE_DATA:
            tfs.append(
                ZeroMeanUnitVarianceTransform(
                    per_channel=self.HP.NORMALIZE_PER_CHANNEL))

        if self.HP.DATASET == "Schizo" and self.HP.RESOLUTION == "2mm":
            tfs.append(PadToMultipleTransform(16))

        if self.HP.DATA_AUGMENTATION and type == "train":
            # scale is inverted: 0.5 -> bigger; 2 -> smaller.
            # patch_center_dist_from_border: half the input size (144/2=72)
            # keeps the patch exactly centered; the smaller value used here lets
            # it drift off-center (the brain may then be cut off).
            if self.HP.DAUG_SCALE:
                center_dist_from_border = int(
                    self.HP.INPUT_DIM[0] / 2.) - 10  # (144,144) -> 62
                tfs.append(
                    SpatialTransform(
                        self.HP.INPUT_DIM,
                        patch_center_dist_from_border=center_dist_from_border,
                        do_elastic_deform=self.HP.DAUG_ELASTIC_DEFORM,
                        alpha=(90., 120.),
                        sigma=(9., 11.),
                        do_rotation=self.HP.DAUG_ROTATE,
                        angle_x=(-0.8, 0.8),
                        angle_y=(-0.8, 0.8),
                        angle_z=(-0.8, 0.8),
                        do_scale=True,
                        scale=(0.9, 1.5),
                        border_mode_data='constant',
                        border_cval_data=0,
                        order_data=3,
                        border_mode_seg='constant',
                        border_cval_seg=0,
                        order_seg=0,
                        random_crop=True))

            if self.HP.DAUG_RESAMPLE:
                tfs.append(ResampleTransform(zoom_range=(0.5, 1)))

            if self.HP.DAUG_NOISE:
                tfs.append(GaussianNoiseTransform(noise_variance=(0, 0.05)))

            if self.HP.DAUG_MIRROR:
                tfs.append(MirrorTransform())

            if self.HP.DAUG_FLIP_PEAKS:
                tfs.append(FlipVectorAxisTransform())

        # num_cached_per_queue of 1 or 2 makes no real difference.
        batch_gen = MultiThreadedAugmenter(batch_gen,
                                           Compose(tfs),
                                           num_processes=num_processes,
                                           num_cached_per_queue=1,
                                           seeds=None)
        return batch_gen  # data: (batch_size, channels, x, y), seg: (batch_size, channels, x, y)