def get_batches(self, batch_size=1):
    """Build the batch pipeline over ``self.data`` when no ground truth exists.

    NaNs in ``self.data`` are replaced with ``np.nan_to_num`` and the data is
    paired with an all-zero placeholder segmentation, so the same generator
    machinery can be used for pure prediction.

    :param batch_size: forwarded to ``SlicesBatchGenerator`` as ``BATCH_SIZE``.
    :return: a ``MultiThreadedAugmenter`` wrapping the slice generator.
        Original author's note on the produced batches:
        data: (batch_size, channels, x, y), seg: (batch_size, x, y, channels).
    """
    volume = np.nan_to_num(self.data)
    hp = self.HP

    # Placeholder mask for prediction-only runs (no ground truth available).
    # All three spatial axes use INPUT_DIM[0]; the last axis is the class axis.
    edge = hp.INPUT_DIM[0]
    placeholder_seg = np.zeros(
        (edge, edge, edge, hp.NR_OF_CLASSES)).astype(hp.LABELS_TYPE)

    slice_gen = SlicesBatchGenerator((volume, placeholder_seg), BATCH_SIZE=batch_size)
    slice_gen.HP = hp

    pipeline = []
    if hp.NORMALIZE_DATA:
        pipeline.append(
            ZeroMeanUnitVarianceTransform(per_channel=hp.NORMALIZE_PER_CHANNEL))
    pipeline.append(ReorderSegTransform())

    # Do not use more than one worker: per the original notes, workers return
    # in random order (losing the slice order) and the global_idx of
    # SlicesBatchGenerator stops working.
    worker_count = 1
    return MultiThreadedAugmenter(
        slice_gen,
        Compose(pipeline),
        num_processes=worker_count,
        num_cached_per_queue=2,
        seeds=None,
    )
def _augment_data(self, batch_generator, type=None):
    """Wrap ``batch_generator`` with normalisation, augmentation and tensor conversion.

    Normalisation is added when ``Config.NORMALIZE_DATA`` is set. The
    augmentation transforms are only added when ``Config.DATA_AUGMENTATION``
    is set *and* ``type == "train"``. ``NumpyToTensor`` is always appended last.

    :param batch_generator: generator handed to ``MultiThreadedAugmenter``.
    :param type: split name; only ``"train"`` enables augmentation. The name
        shadows the builtin but is part of the existing call signature.
    :return: the ``MultiThreadedAugmenter`` wrapping ``batch_generator``.
        Original author's note on the produced batches:
        data: (batch_size, channels, x, y), seg: (batch_size, channels, x, y).
    """
    cfg = self.Config

    # Original note: for 2D, 8 workers were a bit faster than 16.
    worker_count = 16 if cfg.DATA_AUGMENTATION else 6

    pipeline = []
    if cfg.NORMALIZE_DATA:
        pipeline.append(
            ZeroMeanUnitVarianceTransform(per_channel=cfg.NORMALIZE_PER_CHANNEL))

    if cfg.DATA_AUGMENTATION and type == "train":
        if cfg.DAUG_SCALE:
            # Original notes:
            #  - scale is inverted: 0.5 -> bigger; 2 -> smaller.
            #  - patch_center_dist_from_border: at 144/2=72 the patch is always
            #    exactly centred; smaller values allow it to be a bit off
            #    centre (the brain can leave the image and is then cut).
            border_margin = int(cfg.INPUT_DIM[0] / 2.) - 10  # (144,144) -> 62
            pipeline.append(
                SpatialTransform(
                    cfg.INPUT_DIM,
                    patch_center_dist_from_border=border_margin,
                    do_elastic_deform=cfg.DAUG_ELASTIC_DEFORM,
                    alpha=(90., 120.),
                    sigma=(9., 11.),
                    do_rotation=cfg.DAUG_ROTATE,
                    angle_x=(-0.8, 0.8),
                    angle_y=(-0.8, 0.8),
                    angle_z=(-0.8, 0.8),
                    do_scale=True,
                    scale=(0.9, 1.5),
                    border_mode_data='constant',
                    border_cval_data=0,
                    order_data=3,
                    border_mode_seg='constant',
                    border_cval_seg=0,
                    order_seg=0,
                    random_crop=True,
                    p_el_per_sample=0.2,
                    p_rot_per_sample=0.2,
                    p_scale_per_sample=0.2,
                ))

        if cfg.DAUG_RESAMPLE:
            pipeline.append(
                SimulateLowResolutionTransform(zoom_range=(0.5, 1), p_per_sample=0.2))

        if cfg.DAUG_NOISE:
            pipeline.append(
                GaussianNoiseTransform(noise_variance=(0, 0.05), p_per_sample=0.2))

        if cfg.DAUG_MIRROR:
            pipeline.append(MirrorTransform())

        if cfg.DAUG_FLIP_PEAKS:
            pipeline.append(FlipVectorAxisTransform())

    pipeline.append(NumpyToTensor(keys=["data", "seg"], cast_to="float"))

    # Original note: num_cached_per_queue 1 or 2 does not really make a difference.
    return MultiThreadedAugmenter(
        batch_generator,
        Compose(pipeline),
        num_processes=worker_count,
        num_cached_per_queue=1,
        seeds=None,
        pin_memory=True,
    )
def _augment_data(self, batch_generator, type=None):
    """Wrap ``batch_generator`` with normalisation, augmentation and tensor conversion.

    The spatial transform class is chosen from ``Config.SPATIAL_TRANSFORM``
    (``"SpatialTransformPeaks"``, ``"SpatialTransformCustom"``, otherwise the
    plain ``SpatialTransform``). Augmentation transforms are only added when
    ``Config.DATA_AUGMENTATION`` is set *and* ``type == "train"``; their
    parameters come from the ``DAUG_*`` / ``P_SAMP`` config entries.
    ``NumpyToTensor`` is always appended last.

    :param batch_generator: generator handed to ``MultiThreadedAugmenter``.
    :param type: split name; only ``"train"`` enables augmentation. The name
        shadows the builtin but is part of the existing call signature.
    :return: the ``MultiThreadedAugmenter`` wrapping ``batch_generator``.
        Original author's note on the produced batches:
        data: (batch_size, channels, x, y), seg: (batch_size, channels, x, y).
    """
    cfg = self.Config

    # Original notes: 15 workers were a bit faster than 8 on the cluster;
    # multiprocessing.cpu_count() was avoided because on the cluster it
    # reports all cores, not only the assigned ones.
    worker_count = 15 if cfg.DATA_AUGMENTATION else 6

    pipeline = []
    if cfg.NORMALIZE_DATA:
        pipeline.append(
            ZeroMeanUnitVarianceTransform(per_channel=cfg.NORMALIZE_PER_CHANNEL))

    # Resolved up front, independent of whether augmentation is enabled.
    if cfg.SPATIAL_TRANSFORM == "SpatialTransformPeaks":
        spatial_cls = SpatialTransformPeaks
    elif cfg.SPATIAL_TRANSFORM == "SpatialTransformCustom":
        spatial_cls = SpatialTransformCustom
    else:
        spatial_cls = SpatialTransform

    if cfg.DATA_AUGMENTATION and type == "train":
        if cfg.DAUG_SCALE:
            if cfg.INPUT_RESCALING:
                # Fixed rescale: RESOLUTION minus its last two characters is
                # parsed as the target spacing and divided by a 2 (mm) source
                # spacing (original note: "for bb").
                reference_mm = 2
                wanted_mm = float(cfg.RESOLUTION[:-2])
                ratio = wanted_mm / reference_mm
                scale_range = (ratio, ratio)
            else:
                scale_range = (0.9, 1.5)

            # None keeps the dimensions of the data (original note).
            target_patch = cfg.INPUT_DIM if cfg.PAD_TO_SQUARE else None

            # Original notes: the spatial transform automatically crops/pads
            # to the correct size. patch_center_dist_from_border: at 144/2=72
            # the patch is always exactly centred; smaller values allow it to
            # be a bit off centre (the brain can leave the image and is then cut).
            border_margin = int(cfg.INPUT_DIM[0] / 2.) - 10  # (144,144) -> 62
            pipeline.append(
                spatial_cls(
                    target_patch,
                    patch_center_dist_from_border=border_margin,
                    do_elastic_deform=cfg.DAUG_ELASTIC_DEFORM,
                    alpha=cfg.DAUG_ALPHA,
                    sigma=cfg.DAUG_SIGMA,
                    do_rotation=cfg.DAUG_ROTATE,
                    angle_x=cfg.DAUG_ROTATE_ANGLE,
                    angle_y=cfg.DAUG_ROTATE_ANGLE,
                    angle_z=cfg.DAUG_ROTATE_ANGLE,
                    do_scale=True,
                    scale=scale_range,
                    border_mode_data='constant',
                    border_cval_data=0,
                    order_data=3,
                    border_mode_seg='constant',
                    border_cval_seg=0,
                    order_seg=0,
                    random_crop=True,
                    p_el_per_sample=cfg.P_SAMP,
                    p_rot_per_sample=cfg.P_SAMP,
                    p_scale_per_sample=cfg.P_SAMP,
                ))

        if cfg.DAUG_RESAMPLE:
            pipeline.append(
                SimulateLowResolutionTransform(
                    zoom_range=(0.5, 1), p_per_sample=0.2, per_channel=False))

        if cfg.DAUG_RESAMPLE_LEGACY:
            pipeline.append(ResampleTransformLegacy(zoom_range=(0.5, 1)))

        if cfg.DAUG_GAUSSIAN_BLUR:
            pipeline.append(
                GaussianBlurTransform(
                    blur_sigma=cfg.DAUG_BLUR_SIGMA,
                    different_sigma_per_channel=False,
                    p_per_sample=cfg.P_SAMP))

        if cfg.DAUG_NOISE:
            pipeline.append(
                GaussianNoiseTransform(
                    noise_variance=cfg.DAUG_NOISE_VARIANCE,
                    p_per_sample=cfg.P_SAMP))

        if cfg.DAUG_MIRROR:
            pipeline.append(MirrorTransform())

        if cfg.DAUG_FLIP_PEAKS:
            pipeline.append(FlipVectorAxisTransform())

    pipeline.append(NumpyToTensor(keys=["data", "seg"], cast_to="float"))

    # Original note: num_cached_per_queue 1 or 2 does not really make a difference.
    return MultiThreadedAugmenter(
        batch_generator,
        Compose(pipeline),
        num_processes=worker_count,
        num_cached_per_queue=1,
        seeds=None,
        pin_memory=True,
    )
def get_batches(self, batch_size=1):
    """Build the inference batch pipeline for one subject.

    For ``HP.TYPE == "combined"`` the data is taken from ``self.subject`` and
    fed to ``SlicesBatchGeneratorNpyImg_fusion``. Otherwise the feature volume
    is loaded from a nifti file in ``self.data_dir`` and, depending on
    ``self.use_gt_mask``, paired with either the label volume from disk or an
    all-zero placeholder mask. Normalisation and optional test-time
    augmentation (``HP.TEST_TIME_DAUG``) are then applied, and
    ``ReorderSegTransform`` is always appended last.

    :param batch_size: number of slices per batch.
    :return: a ``MultiThreadedAugmenter`` wrapping the slice generator.
        Original author's note on the produced batches:
        data: (batch_size, channels, x, y), seg: (batch_size, x, y, channels).
    """
    hp = self.HP

    # Do not use more than one worker: per the original notes, workers return
    # in random order (losing the slice order) and the global_idx of
    # SlicesBatchGenerator stops working.
    worker_count = 1

    if hp.TYPE == "combined":
        # Load from npy file for fusion (original note).
        volume = self.subject
        labels = []
        # A single subject is wrapped in a list, so the length factor is 1.
        sample_count = len([self.subject]) * hp.INPUT_DIM[0]
        batch_count = int(sample_count / batch_size / worker_count)
        slice_gen = SlicesBatchGeneratorNpyImg_fusion(
            (volume, labels),
            BATCH_SIZE=batch_size,
            num_batches=batch_count,
            seed=None)
    else:
        # --- features ---
        if hp.FEATURES_FILENAME == "12g90g270g":
            features_file = "270g_125mm_peaks.nii.gz"
        else:
            features_file = hp.FEATURES_FILENAME + ".nii.gz"
        features_img = nib.load(join(self.data_dir, features_file))
        volume = np.nan_to_num(features_img.get_data())
        volume = DatasetUtils.scale_input_to_unet_shape(
            volume, hp.DATASET, hp.RESOLUTION)
        # Original note: to test HCP_32g on the high-res net one would call
        # scale_input_to_unet_shape(data, "HCP_32g", "1.25mm") instead.

        # --- segmentation ---
        if self.use_gt_mask:
            labels = nib.load(
                join(self.data_dir, hp.LABELS_FILENAME + ".nii.gz")).get_data()

            # Label files that are passed through without rescaling
            # (presumably already stored at network input size -- TODO confirm).
            labels_used_as_is = (
                "bundle_peaks_11_808080",
                "bundle_peaks_20_808080",
                "bundle_peaks_808080",
                "bundle_masks_20_808080",
                "bundle_masks_72_808080",
                "bundle_peaks_Part1_808080",
                "bundle_peaks_Part2_808080",
                "bundle_peaks_Part3_808080",
                "bundle_peaks_Part4_808080",
            )
            if hp.LABELS_FILENAME not in labels_used_as_is:
                if hp.DATASET in ["HCP_2mm", "HCP_2.5mm", "HCP_32g"]:
                    # Original note: by using "HCP" with a lower resolution,
                    # scale_input_to_unet_shape automatically downsamples the
                    # HCP-sized seg mask.
                    mask_dataset = "HCP"
                else:
                    mask_dataset = hp.DATASET
                labels = DatasetUtils.scale_input_to_unet_shape(
                    labels, mask_dataset, hp.RESOLUTION)
        else:
            # Placeholder mask for prediction-only runs (no ground truth).
            edge = hp.INPUT_DIM[0]
            labels = np.zeros(
                (edge, edge, edge, hp.NR_OF_CLASSES)).astype(hp.LABELS_TYPE)

        slice_gen = SlicesBatchGenerator((volume, labels), batch_size=batch_size)

    slice_gen.HP = hp

    pipeline = []
    if hp.NORMALIZE_DATA:
        pipeline.append(ZeroMeanUnitVarianceTransform(per_channel=False))

    if hp.TEST_TIME_DAUG:
        border_margin = int(hp.INPUT_DIM[0] / 2.) - 10  # (144,144) -> 62
        pipeline.append(
            SpatialTransform(
                hp.INPUT_DIM,
                patch_center_dist_from_border=border_margin,
                do_elastic_deform=True,
                alpha=(90., 120.),
                sigma=(9., 11.),
                do_rotation=True,
                angle_x=(-0.8, 0.8),
                angle_y=(-0.8, 0.8),
                angle_z=(-0.8, 0.8),
                do_scale=True,
                scale=(0.9, 1.5),
                border_mode_data='constant',
                border_cval_data=0,
                order_data=3,
                border_mode_seg='constant',
                border_cval_seg=0,
                order_seg=0,
                random_crop=True,
            ))
        # Resample and Gaussian-noise transforms were disabled here by the
        # original author.
        pipeline.append(
            ContrastAugmentationTransform(
                contrast_range=(0.7, 1.3),
                preserve_range=True,
                per_channel=False))
        pipeline.append(
            BrightnessMultiplicativeTransform(
                multiplier_range=(0.7, 1.3), per_channel=False))

    pipeline.append(ReorderSegTransform())

    return MultiThreadedAugmenter(
        slice_gen,
        Compose(pipeline),
        num_processes=worker_count,
        num_cached_per_queue=2,
        seeds=None,
    )
def get_batches(self, batch_size=128, type=None, subjects=None, num_batches=None):
    """Build the training/validation batch pipeline for a list of subjects.

    A random slice generator is created (the npy fusion variant for
    ``HP.TYPE == "combined"``, the nifti variant otherwise), given access to
    the hyper-parameters through its ``HP`` attribute, and wrapped in a
    ``MultiThreadedAugmenter`` with normalisation, optional padding and --
    only when ``HP.DATA_AUGMENTATION`` is set and ``type == "train"`` -- the
    augmentation transforms.

    :param batch_size: forwarded to the slice generator.
    :param type: split name; only ``"train"`` enables augmentation. The name
        shadows the builtin but is part of the existing call signature.
    :param subjects: subjects handed to the slice generator as its data.
        Must not be ``None``.
    :param num_batches: accepted for backward compatibility only. It used to
        feed a per-worker batch count that was never passed on to the
        generator, so it has no effect.
    :return: the ``MultiThreadedAugmenter`` wrapping the slice generator.
        Original author's note on the produced batches:
        data: (batch_size, channels, x, y), seg: (batch_size, channels, x, y).
    :raises TypeError: if ``subjects`` is ``None``.
    """
    # The removed dead computation `len(subjects) * INPUT_DIM[0]` made a
    # missing `subjects` fail with TypeError; keep that failure explicit.
    if subjects is None:
        raise TypeError("get_batches() requires 'subjects' (got None)")

    hp = self.HP
    data = subjects
    seg = []

    # Original notes: 6 workers already need >30GB RAM; with augmentation
    # enabled 6 workers were a bit faster than 16.
    num_processes = 8 if hp.DATA_AUGMENTATION else 6

    if hp.TYPE == "combined":
        # Original note: plain .npy input is only slightly faster than nifti
        # (<10%) and f1 is not better, so nifti is used for everything else.
        batch_gen = SlicesBatchGeneratorRandomNpyImg_fusion(
            (data, seg), batch_size=batch_size)
    else:
        batch_gen = SlicesBatchGeneratorRandomNiftiImg(
            (data, seg), batch_size=batch_size)
    batch_gen.HP = hp

    tfs = []  # transforms
    if hp.NORMALIZE_DATA:
        tfs.append(
            ZeroMeanUnitVarianceTransform(per_channel=hp.NORMALIZE_PER_CHANNEL))

    if hp.DATASET == "Schizo" and hp.RESOLUTION == "2mm":
        tfs.append(PadToMultipleTransform(16))

    if hp.DATA_AUGMENTATION and type == "train":
        if hp.DAUG_SCALE:
            # Original notes:
            #  - scale is inverted: 0.5 -> bigger; 2 -> smaller.
            #  - patch_center_dist_from_border: at 144/2=72 the patch is always
            #    exactly centred; smaller values allow it to be a bit off
            #    centre (the brain can leave the image and is then cut).
            center_dist_from_border = int(hp.INPUT_DIM[0] / 2.) - 10  # (144,144) -> 62
            tfs.append(
                SpatialTransform(
                    hp.INPUT_DIM,
                    patch_center_dist_from_border=center_dist_from_border,
                    do_elastic_deform=hp.DAUG_ELASTIC_DEFORM,
                    alpha=(90., 120.),
                    sigma=(9., 11.),
                    do_rotation=hp.DAUG_ROTATE,
                    angle_x=(-0.8, 0.8),
                    angle_y=(-0.8, 0.8),
                    angle_z=(-0.8, 0.8),
                    do_scale=True,
                    scale=(0.9, 1.5),
                    border_mode_data='constant',
                    border_cval_data=0,
                    order_data=3,
                    border_mode_seg='constant',
                    border_cval_seg=0,
                    order_seg=0,
                    random_crop=True,
                ))

        if hp.DAUG_RESAMPLE:
            tfs.append(ResampleTransform(zoom_range=(0.5, 1)))

        if hp.DAUG_NOISE:
            tfs.append(GaussianNoiseTransform(noise_variance=(0, 0.05)))

        if hp.DAUG_MIRROR:
            tfs.append(MirrorTransform())

        if hp.DAUG_FLIP_PEAKS:
            tfs.append(FlipVectorAxisTransform())

    # Original note: num_cached_per_queue 1 or 2 does not really make a difference.
    return MultiThreadedAugmenter(
        batch_gen,
        Compose(tfs),
        num_processes=num_processes,
        num_cached_per_queue=1,
        seeds=None,
    )