def get_train_transform(patch_size):
    """Build the composed training augmentation pipeline.

    The transforms are applied in this order: spatial transform (elastic
    deformation, rotation of up to +/-5 degrees about x, y and z, scaling in
    [0.75, 1.25]), mirroring along axes 0-2, multiplicative brightness,
    gamma with ``invert_image=True`` and additive Gaussian noise.

    :param patch_size: sequence of ints handed to ``SpatialTransform_2``;
        half of each entry is passed as its second positional argument
    :return: ``Compose`` instance wrapping all transforms
    """
    def _symmetric_radians(degrees):
        # (-degrees, +degrees) converted to radians
        return (-degrees / 360. * 2 * np.pi, degrees / 360. * 2 * np.pi)

    rotation_range = _symmetric_radians(5)
    half_patch = [i // 2 for i in patch_size]

    pipeline = [
        SpatialTransform_2(
            patch_size, half_patch,
            do_elastic_deform=True, deformation_scale=(0, 0.05),
            do_rotation=True,
            angle_x=rotation_range,
            angle_y=rotation_range,
            angle_z=rotation_range,
            do_scale=True, scale=(0.75, 1.25),
            # out-of-image data voxels are filled with -2.34, labels with 0
            border_mode_data='constant', border_cval_data=-2.34,
            border_mode_seg='constant', border_cval_seg=0),
        MirrorTransform(axes=(0, 1, 2)),
        BrightnessMultiplicativeTransform((0.7, 1.5), per_channel=True,
                                          p_per_sample=0.15),
        GammaTransform(gamma_range=(0.5, 2), invert_image=True,
                       per_channel=True, p_per_sample=0.15),
        GaussianNoiseTransform(noise_variance=(0, 0.15), p_per_sample=0.15),
    ]
    return Compose(pipeline)
def get_train_transform(patch_size):
    """Build a rotation + mirroring training pipeline.

    ``None`` is passed as the first argument of ``SpatialTransform_2``.
    Elastic deformation and scaling are switched off; rotation of up to
    +/-15 degrees per axis is enabled with a per-sample probability of 0.3.
    Mirroring along axes 0-2 is wrapped in ``RndTransform`` with ``prob=0.3``.

    :param patch_size: sequence of ints; only used to derive the second
        positional argument of ``SpatialTransform_2`` (half of each entry)
    :return: ``Compose`` instance wrapping both transforms
    """
    def _symmetric_radians(degrees):
        # (-degrees, +degrees) converted to radians
        return (-degrees / 360. * 2 * np.pi, degrees / 360. * 2 * np.pi)

    rotation_range = _symmetric_radians(15)

    spatial = SpatialTransform_2(
        None, [i // 2 for i in patch_size],
        do_elastic_deform=False, deformation_scale=(0, 0.25),
        do_rotation=True,
        angle_x=rotation_range,
        angle_y=rotation_range,
        angle_z=rotation_range,
        do_scale=False, scale=(0.8, 1.2),
        border_mode_data='constant', border_cval_data=0,
        border_mode_seg='constant', border_cval_seg=0,
        order_seg=1, order_data=3,
        random_crop=False,
        p_el_per_sample=0.3, p_rot_per_sample=0.3, p_scale_per_sample=0.3)
    random_mirror = RndTransform(MirrorTransform(axes=(0, 1, 2)), prob=0.3)

    return Compose(transforms=[spatial, random_mirror])
def test_random_distributions_2D(self):
    """Check that all 4 possible 2D mirrorings occur in approximately equal frequencies."""
    loader = BasicDataLoader((self.x_2D, self.y_2D), self.batch_size,
                             number_of_threads_in_multithreaded=None)
    augmenter = SingleThreadedAugmenter(loader, MirrorTransform((0, 1)))

    # Reference images, checked in this order; the first full match wins.
    references = (self.cam_left, self.cam_updown, self.cam_updown_left, self.cam)
    counts = np.zeros(shape=(4,))

    for _ in range(self.num_batches):
        batch = next(augmenter)
        for ix in range(self.batch_size):
            sample = batch['data'][ix, :, :, :]
            for slot, reference in enumerate(references):
                if (sample == reference).all():
                    counts[slot] = counts[slot] + 1
                    break

    # Every orientation must have been seen between 2200 and 2800 times
    # (exclusive bounds).
    evenly_spread = all(2200 < c < 2800 for c in counts)
    self.assertTrue(evenly_spread,
                    "2D Images were not mirrored along "
                    "all axes with equal probability. "
                    "This may also indicate that "
                    "mirroring is not working")
def test_image_pipeline_and_pin_memory(self):
    """Smoke test: the pipeline with tensor conversion and pinning must not crash.

    Skipped silently (plain ``return``) when torch cannot be imported.
    """
    try:
        import torch
    except ImportError:
        # don't test if torch is not installed
        return

    composed = Compose([
        MirrorTransform(),
        TransposeAxesTransform(transpose_any_of_these=(0, 1), p_per_sample=0.5),
        NumpyToTensor(keys='data', cast_to='float'),
    ])

    # Positional arguments kept exactly as in the original call; the final
    # True presumably enables pin_memory -- confirm against the
    # MultiThreadedAugmenter signature.
    augmenter = MultiThreadedAugmenter(self.dl_images, composed, 4, 1, None, True)

    for _ in range(50):
        res = augmenter.next()

    # NOTE(review): the source was whitespace-collapsed; the assertions are
    # assumed to run once on the last batch rather than inside the loop.
    assert isinstance(res['data'], torch.Tensor)
    assert res['data'].is_pinned()

    # Let the augmenter finish caching, otherwise it prints an error (which is
    # harmless and does not fail the test, but does not look pretty).
    sleep(2)
def get_transforms(patch_shape=(256, 320), other_transforms=None,
                   random_crop=False):
    """
    Initializes the transforms for training.

    The pipeline always starts with a ``SpatialTransform`` (elastic
    deformation, rotation about z over the full circle, scaling in
    [0.75, 2.], each with a per-sample probability of 0.1) followed by a
    ``MirrorTransform`` along axes 0 and 1. Any ``other_transforms`` are
    appended after these two.

    Args:
        patch_shape: patch size handed to ``SpatialTransform``.
        other_transforms: Iterable of transforms that you would like to add
            (optional). Defaults to None. Any iterable is accepted, not only
            a list.
        random_crop (boolean): whether or not you want to random crop or
            center crop. Currently, the Transformed3DGenerator only supports
            random cropping. Transformed2DGenerator supports both
            random_crop = True and False.

    Returns:
        A ``Compose`` instance wrapping the transforms.
    """
    spatial_transform = SpatialTransform(
        patch_shape,
        do_elastic_deform=True, alpha=(0., 1500.), sigma=(30., 80.),
        do_rotation=True, angle_z=(0, 2 * np.pi),
        do_scale=True, scale=(0.75, 2.),
        border_mode_data="nearest", border_cval_data=0,
        order_data=1,
        random_crop=random_crop,
        p_el_per_sample=0.1, p_scale_per_sample=0.1, p_rot_per_sample=0.1)
    mirror_transform = MirrorTransform(axes=(0, 1))

    transforms_list = [spatial_transform, mirror_transform]
    if other_transforms is not None:
        # extend() copes with tuples/generators, where ``list + tuple`` raised
        transforms_list.extend(other_transforms)
    return Compose(transforms_list)
def get_train_transform(patch_size):
    """
    Data augmentation for training data, inspired by:
    https://github.com/MIC-DKFZ/batchgenerators/blob/master/batchgenerators/examples/brats2017/brats2017_dataloader_3D.py

    :param patch_size: shape of network's input
    :return: ``Compose`` instance chaining the training transformations
    """
    def _symmetric_radians(deg):
        # (-deg, +deg) converted from degrees to radians
        return (-deg / 360 * 2 * np.pi, deg / 360 * 2 * np.pi)

    no_rotation = (0, 0)

    spatial = SpatialTransform_2(
        patch_size, (10, 10, 10),
        do_elastic_deform=True, deformation_scale=(0, 0.25),
        # rotate about z only; x and y ranges are pinned to zero
        do_rotation=True,
        angle_z=_symmetric_radians(15),
        angle_x=no_rotation,
        angle_y=no_rotation,
        do_scale=True, scale=(0.75, 1.25),
        border_mode_data='constant', border_cval_data=0,
        border_mode_seg='constant', border_cval_seg=0,
        order_seg=1,
        random_crop=False,
        p_el_per_sample=0.2, p_rot_per_sample=0.2, p_scale_per_sample=0.2,
    )
    mirror = MirrorTransform(axes=(0, 1))
    brightness = BrightnessMultiplicativeTransform((0.7, 1.5),
                                                   per_channel=True,
                                                   p_per_sample=0.2)
    gamma = GammaTransform(gamma_range=(0.2, 1.0), invert_image=False,
                           per_channel=False, p_per_sample=0.2)
    noise = GaussianNoiseTransform(noise_variance=(0, 0.05), p_per_sample=0.2)
    # NOTE(review): p_per_channel=0.0 looks like it prevents the blur from
    # ever being applied to any channel -- confirm this is intended.
    blur = GaussianBlurTransform(blur_sigma=(0.2, 1.0),
                                 different_sigma_per_channel=False,
                                 p_per_channel=0.0, p_per_sample=0.2)

    return Compose([spatial, mirror, brightness, gamma, noise, blur])
def get_train_transform(patch_size):
    """Build the showcase training pipeline (not tuned for BraTS).

    :param patch_size: sequence of ints; target patch size of the spatial
        transform, half of each entry is passed as its second argument
    :return: ``Compose`` instance wrapping the transforms
    """
    def _symmetric_radians(degrees):
        # (-degrees, +degrees) converted to radians
        return (-degrees / 360. * 2 * np.pi, degrees / 360. * 2 * np.pi)

    rotation_range = _symmetric_radians(15)

    # The SpatialTransform runs first: it reduces the data to patch_size and
    # thereby the cost of everything that follows. None of the later
    # transforms change the shape or act spatially, so they introduce no
    # border artifacts. SpatialTransform_2 uses the newer parameterization of
    # elastic deformation. Each of the three spatial augmentations has a
    # per-sample probability of 0.1, so 1 - (1 - 0.1) ** 3 = 27% of samples
    # are augmented; the rest are only cropped.
    spatial = SpatialTransform_2(
        patch_size, [i // 2 for i in patch_size],
        do_elastic_deform=True, deformation_scale=(0, 0.25),
        do_rotation=True,
        angle_x=rotation_range,
        angle_y=rotation_range,
        angle_z=rotation_range,
        do_scale=True, scale=(0.75, 1.25),
        border_mode_data='constant', border_cval_data=0,
        border_mode_seg='constant', border_cval_seg=0,
        order_seg=1, order_data=3,
        random_crop=True,
        p_el_per_sample=0.1, p_rot_per_sample=0.1, p_scale_per_sample=0.1)

    # mirror along all axes
    mirror = MirrorTransform(axes=(0, 1, 2))

    # Gamma transform: a nonlinear transformation of intensity values
    # (https://en.wikipedia.org/wiki/Gamma_correction)
    gamma = GammaTransform(gamma_range=(0.5, 2), invert_image=False,
                           per_channel=True, p_per_sample=0.15)

    # Same again, but invert the image, apply the transform and invert back
    inverted_gamma = GammaTransform(gamma_range=(0.5, 2), invert_image=True,
                                    per_channel=True, p_per_sample=0.15)

    # Gaussian noise
    noise = GaussianNoiseTransform(noise_variance=(0, 0.05), p_per_sample=0.15)

    # compose everything in application order
    return Compose([spatial, mirror, gamma, inverted_gamma, noise])
def _augment_data(self, batch_generator, type=None):
    """Wrap ``batch_generator`` in a multi-process augmentation pipeline.

    Normalization is added when ``Config.NORMALIZE_DATA`` is set. The
    augmentation transforms are only added when ``Config.DATA_AUGMENTATION``
    is set and ``type == "train"``. Conversion to float tensors is always
    the last transform.

    :param batch_generator: generator handed to ``MultiThreadedAugmenter``
    :param type: augmentations are only enabled for the value ``"train"``
    :return: the ``MultiThreadedAugmenter``; per the original comment it
        yields data: (batch_size, channels, x, y),
        seg: (batch_size, channels, x, y)
    """
    cfg = self.Config

    # 2D: 8 is a bit faster than 16
    num_processes = 16 if cfg.DATA_AUGMENTATION else 6

    tfs = []  # transforms
    if cfg.NORMALIZE_DATA:
        tfs.append(ZeroMeanUnitVarianceTransform(per_channel=cfg.NORMALIZE_PER_CHANNEL))

    # NOTE(review): nesting below was rebuilt from a whitespace-collapsed
    # source; all DAUG_* switches are assumed to sit inside the train branch.
    if cfg.DATA_AUGMENTATION and type == "train":
        # scale: inverted: 0.5 -> bigger; 2 -> smaller
        # patch_center_dist_from_border: if 144/2=72 -> always exactly
        # centered; otherwise a bit off center (brain can get off image and
        # will be cut then)
        if cfg.DAUG_SCALE:
            center_dist_from_border = int(cfg.INPUT_DIM[0] / 2.) - 10  # (144,144) -> 62
            tfs.append(SpatialTransform(
                cfg.INPUT_DIM,
                patch_center_dist_from_border=center_dist_from_border,
                do_elastic_deform=cfg.DAUG_ELASTIC_DEFORM,
                alpha=(90., 120.), sigma=(9., 11.),
                do_rotation=cfg.DAUG_ROTATE,
                angle_x=(-0.8, 0.8), angle_y=(-0.8, 0.8), angle_z=(-0.8, 0.8),
                do_scale=True, scale=(0.9, 1.5),
                border_mode_data='constant', border_cval_data=0, order_data=3,
                border_mode_seg='constant', border_cval_seg=0, order_seg=0,
                random_crop=True,
                p_el_per_sample=0.2, p_rot_per_sample=0.2, p_scale_per_sample=0.2))

        if cfg.DAUG_RESAMPLE:
            tfs.append(SimulateLowResolutionTransform(zoom_range=(0.5, 1), p_per_sample=0.2))

        if cfg.DAUG_NOISE:
            tfs.append(GaussianNoiseTransform(noise_variance=(0, 0.05), p_per_sample=0.2))

        if cfg.DAUG_MIRROR:
            tfs.append(MirrorTransform())

        if cfg.DAUG_FLIP_PEAKS:
            tfs.append(FlipVectorAxisTransform())

    tfs.append(NumpyToTensor(keys=["data", "seg"], cast_to="float"))

    # num_cached_per_queue 1 or 2 does not really make a difference
    return MultiThreadedAugmenter(batch_generator, Compose(tfs),
                                  num_processes=num_processes,
                                  num_cached_per_queue=1, seeds=None,
                                  pin_memory=True)
def test_segmentations_2D(self):
    """Check that segmentations are mirrored coherently with their images.

    Every sample's ``data`` must equal its ``seg`` after mirroring; this
    presumably relies on the fixture providing identical arrays for
    ``x_2D`` and ``y_2D`` -- verify against ``setUp``.
    """
    batch_gen = BasicDataLoader((self.x_2D, self.y_2D), self.batch_size,
                                number_of_threads_in_multithreaded=None)
    batch_gen = SingleThreadedAugmenter(batch_gen, MirrorTransform((0, 1)))

    equivalent = True
    for _ in range(self.num_batches):
        batch = next(batch_gen)
        for ix in range(self.batch_size):
            # A single differing element is already a mismatch. The previous
            # ``.all()`` only triggered when *every* element differed, so
            # partially inconsistent mirroring went unnoticed.
            if (batch['data'][ix] != batch['seg'][ix]).any():
                equivalent = False

    self.assertTrue(equivalent,
                    "2D images and seg were not mirrored in the same way (they should though because "
                    "seg needs to match the corresponding data")
def create_data_gen_pipeline(cf, cities=None, data_split='train', do_aug=True,
                             random=True, n_batches=None):
    """
    Create the multi-threaded train/val/test batch generation and
    augmentation pipeline.

    :param cf: config object; the attributes read here are batch_size,
        data_dir, label_density, resolution, gt_instances, patch_size,
        da_kwargs, ignore_label, label_switches, name2trainId and n_workers
    :param cities: list of strings or None
    :param data_split: forwarded to ``BatchGenerator``
    :param do_aug: (optional) whether to perform data augmentation (training)
        or not (validation/testing)
    :param random: bool, whether to draw random batches or go through data
        linearly
    :param n_batches: forwarded to ``BatchGenerator``
    :return: multithreaded_generator
    """
    data_gen = BatchGenerator(cities=cities, batch_size=cf.batch_size,
                              data_dir=cf.data_dir,
                              label_density=cf.label_density,
                              data_split=data_split,
                              resolution=cf.resolution,
                              gt_instances=cf.gt_instances,
                              n_batches=n_batches, random=random)

    da = cf.da_kwargs
    patch_hw = cf.patch_size[-2:]

    if do_aug:
        # random mirroring followed by the configurable spatial transform
        pipeline = [
            MirrorTransform(axes=(3,)),
            SpatialTransform(
                patch_size=patch_hw,
                patch_center_dist_from_border=da['rand_crop_dist'],
                do_elastic_deform=da['do_elastic_deform'],
                alpha=da['alpha'], sigma=da['sigma'],
                do_rotation=da['do_rotation'],
                angle_x=da['angle_x'], angle_y=da['angle_y'], angle_z=da['angle_z'],
                do_scale=da['do_scale'], scale=da['scale'],
                random_crop=da['random_crop'],
                border_mode_data=da['border_mode_data'],
                border_mode_seg=da['border_mode_seg'],
                border_cval_seg=da['border_cval_seg']),
        ]
    else:
        pipeline = [CenterCropTransform(crop_size=patch_hw)]

    # NOTE(review): the source was whitespace-collapsed; the gamma transform
    # and loss mask are assumed to apply in both branches.
    pipeline.append(GammaTransform(da['gamma_range'], invert_image=False,
                                   per_channel=True,
                                   retain_stats=da['gamma_retain_stats'],
                                   p_per_sample=da['p_gamma']))
    pipeline.append(AddLossMask(cf.ignore_label))
    if cf.label_switches is not None:
        pipeline.append(StochasticLabelSwitches(cf.name2trainId, cf.label_switches))

    # one fixed seed per worker: 0 .. n_workers - 1
    return MultiThreadedAugmenter(data_gen, Compose(pipeline),
                                  num_processes=cf.n_workers,
                                  seeds=range(cf.n_workers))
def combine_transform():
    '''
    Combined transform: contrast + mirroring.

    Builds the pipeline, runs it through a multi-threaded augmenter and
    plots one augmented batch.
    :return:
    '''
    pipeline = Compose([
        # contrast transform
        ContrastAugmentationTransform((0.3, 3.), preserve_range=True),
        # mirror transform
        MirrorTransform(axes=(2, 3)),
    ])

    loader = my_data_loader.DataLoader(camera(), 4)
    augmenter = MultiThreadedAugmenter(loader, pipeline, 4, 2)

    # show the effect of the transforms
    my_data_loader.plot_batch(next(augmenter))
def test_image_pipeline(self):
    """Smoke test: pulling 50 batches through the image pipeline must not crash."""
    composed = Compose([
        MirrorTransform(),
        TransposeAxesTransform(transpose_any_of_these=(0, 1), p_per_sample=0.5),
    ])

    # Positional arguments kept exactly as in the original call.
    augmenter = MultiThreadedAugmenter(self.dl_images, composed, 4, 1, None, False)

    for _ in range(50):
        augmenter.next()

    # Let the augmenter finish caching, otherwise it prints an error (which is
    # harmless and does not fail the test, but does not look pretty).
    sleep(2)
def get_valid_transform(patch_size):
    """
    Data augmentation for validation data, inspired by:
    https://github.com/MIC-DKFZ/batchgenerators/blob/master/batchgenerators/examples/brats2017/brats2017_dataloader_3D.py

    Elastic deformation, rotation and scaling are all switched off; what
    remains is the random crop of the spatial transform plus mirroring along
    axes 0 and 1.

    :param patch_size: shape of network's input
    :return: ``Compose`` instance chaining the validation transformations
    """
    zero_range = (0, 0)

    crop_only = SpatialTransform_2(
        patch_size, patch_size,
        do_elastic_deform=False, deformation_scale=zero_range,
        do_rotation=False,
        angle_x=zero_range, angle_y=zero_range, angle_z=zero_range,
        do_scale=False, scale=(1.0, 1.0),
        border_mode_data='constant', border_cval_data=0,
        border_mode_seg='constant', border_cval_seg=0,
        order_seg=1, order_data=3,
        random_crop=True,
        p_el_per_sample=0.1, p_rot_per_sample=0.1, p_scale_per_sample=0.1)
    mirror = MirrorTransform(axes=(0, 1))

    return Compose([crop_only, mirror])
def _augment_data(self, batch_generator, type=None):
    """Wrap ``batch_generator`` in a multi-process augmentation pipeline.

    Normalization is added when ``Config.NORMALIZE_DATA`` is set. The
    augmentation transforms are only added when ``Config.DATA_AUGMENTATION``
    is set and ``type == "train"``; their parameters come from ``Config``.
    Conversion to float tensors is always the last transform.

    :param batch_generator: generator handed to ``MultiThreadedAugmenter``
    :param type: augmentations are only enabled for the value ``"train"``
    :return: the ``MultiThreadedAugmenter``; per the original comment it
        yields data: (batch_size, channels, x, y),
        seg: (batch_size, channels, x, y)
    """
    cfg = self.Config

    # 15 is a bit faster than 8 on cluster.
    # multiprocessing.cpu_count() is not used: on the cluster it gives all
    # cores, not only the assigned ones.
    num_processes = 15 if cfg.DATA_AUGMENTATION else 6

    tfs = []
    if cfg.NORMALIZE_DATA:
        tfs.append(ZeroMeanUnitVarianceTransform(per_channel=cfg.NORMALIZE_PER_CHANNEL))

    # Pick the spatial transform implementation by its configured name.
    spatial_name = cfg.SPATIAL_TRANSFORM
    if spatial_name == "SpatialTransformPeaks":
        SpatialTransformUsed = SpatialTransformPeaks
    elif spatial_name == "SpatialTransformCustom":
        SpatialTransformUsed = SpatialTransformCustom
    else:
        SpatialTransformUsed = SpatialTransform

    # NOTE(review): nesting below was rebuilt from a whitespace-collapsed
    # source; all DAUG_* switches are assumed to sit inside the train branch.
    if cfg.DATA_AUGMENTATION and type == "train":
        if cfg.DAUG_SCALE:
            if cfg.INPUT_RESCALING:
                source_mm = 2  # for bb
                # RESOLUTION minus its last two characters is parsed as a
                # number (presumably a string like "1.25mm" -- confirm)
                target_mm = float(cfg.RESOLUTION[:-2])
                scale_factor = target_mm / source_mm
                scale = (scale_factor, scale_factor)
            else:
                scale = (0.9, 1.5)

            # None keeps the dimensions of the data
            patch_size = cfg.INPUT_DIM if cfg.PAD_TO_SQUARE else None

            # patch_center_dist_from_border:
            #   if 144/2=72 -> always exactly centered; otherwise a bit off
            #   center (brain can get off image and will be cut then).
            # The spatial transform automatically crops/pads to correct size.
            center_dist_from_border = int(cfg.INPUT_DIM[0] / 2.) - 10  # (144,144) -> 62
            tfs.append(SpatialTransformUsed(
                patch_size,
                patch_center_dist_from_border=center_dist_from_border,
                do_elastic_deform=cfg.DAUG_ELASTIC_DEFORM,
                alpha=cfg.DAUG_ALPHA, sigma=cfg.DAUG_SIGMA,
                do_rotation=cfg.DAUG_ROTATE,
                angle_x=cfg.DAUG_ROTATE_ANGLE,
                angle_y=cfg.DAUG_ROTATE_ANGLE,
                angle_z=cfg.DAUG_ROTATE_ANGLE,
                do_scale=True, scale=scale,
                border_mode_data='constant', border_cval_data=0, order_data=3,
                border_mode_seg='constant', border_cval_seg=0, order_seg=0,
                random_crop=True,
                p_el_per_sample=cfg.P_SAMP,
                p_rot_per_sample=cfg.P_SAMP,
                p_scale_per_sample=cfg.P_SAMP))

        if cfg.DAUG_RESAMPLE:
            tfs.append(SimulateLowResolutionTransform(zoom_range=(0.5, 1),
                                                      p_per_sample=0.2,
                                                      per_channel=False))

        if cfg.DAUG_RESAMPLE_LEGACY:
            tfs.append(ResampleTransformLegacy(zoom_range=(0.5, 1)))

        if cfg.DAUG_GAUSSIAN_BLUR:
            tfs.append(GaussianBlurTransform(blur_sigma=cfg.DAUG_BLUR_SIGMA,
                                             different_sigma_per_channel=False,
                                             p_per_sample=cfg.P_SAMP))

        if cfg.DAUG_NOISE:
            tfs.append(GaussianNoiseTransform(noise_variance=cfg.DAUG_NOISE_VARIANCE,
                                              p_per_sample=cfg.P_SAMP))

        if cfg.DAUG_MIRROR:
            tfs.append(MirrorTransform())

        if cfg.DAUG_FLIP_PEAKS:
            tfs.append(FlipVectorAxisTransform())

    tfs.append(NumpyToTensor(keys=["data", "seg"], cast_to="float"))

    # num_cached_per_queue 1 or 2 does not really make a difference
    return MultiThreadedAugmenter(batch_generator, Compose(tfs),
                                  num_processes=num_processes,
                                  num_cached_per_queue=1, seeds=None,
                                  pin_memory=True)
def run(self, img_data, seg_data):
    """Augment ``img_data``/``seg_data`` and return ``self.cycles`` batches.

    Transforms are enabled through the instance flags (``mirror``,
    ``contrast``, ``brightness``, ``gamma``, ``gaussian_noise`` and the
    spatial flags ``rotations``/``scaling``/``elastic_deform``/``cropping``).
    One augmented batch is drawn per cycle. Batches are concatenated along
    axis 0 with the most recently generated batch first.

    :param img_data: image data handed to ``DataParser``; when cropping is
        off, ``img_data[0].shape[0:-1]`` is used as the patch shape
    :param seg_data: segmentation/class data handed to ``DataParser``
    :return: tuple ``(aug_img_data, aug_seg_data)`` with axis 1 moved to the
        last position (channel-first -> channel-last)
    :raises ValueError: if ``self.cycles`` is smaller than 1, because there
        would be nothing to return
    """
    if self.cycles < 1:
        # Previously this fell through to np.moveaxis(None, ...) and failed
        # with an obscure axis error.
        raise ValueError("Data augmentation requires at least one cycle")

    # Key under which the batchgenerators pipeline carries the second array;
    # only "seg" is augmented together with the image.
    seg_label = "seg" if self.seg_augmentation else "class"

    # Create a parser for the batchgenerators module
    data_generator = DataParser(img_data, seg_data, seg_label)

    transforms = []
    # Mirror augmentation
    if self.mirror:
        transforms.append(MirrorTransform(axes=self.config_mirror_axes))
    # Contrast augmentation
    if self.contrast:
        transforms.append(ContrastAugmentationTransform(
            self.config_contrast_range,
            preserve_range=self.config_contrast_preserverange,
            per_channel=self.coloraug_per_channel,
            p_per_sample=self.config_p_per_sample))
    # Brightness augmentation
    if self.brightness:
        transforms.append(BrightnessMultiplicativeTransform(
            self.config_brightness_range,
            per_channel=self.coloraug_per_channel,
            p_per_sample=self.config_p_per_sample))
    # Gamma augmentation
    if self.gamma:
        transforms.append(GammaTransform(
            self.config_gamma_range,
            invert_image=False,
            per_channel=self.coloraug_per_channel,
            retain_stats=True,
            p_per_sample=self.config_p_per_sample))
    # Gaussian noise augmentation
    if self.gaussian_noise:
        transforms.append(GaussianNoiseTransform(
            self.config_gaussian_noise_range,
            p_per_sample=self.config_p_per_sample))
    # Spatial transformations (rotation, scaling, elastic deformation, crop)
    if self.rotations or self.scaling or self.elastic_deform or self.cropping:
        # Patch shape: the cropping shape, or the full image without its
        # last axis
        if self.cropping:
            patch_shape = self.cropping_patch_shape
        else:
            patch_shape = img_data[0].shape[0:-1]
        transforms.append(SpatialTransform(
            patch_shape, [i // 2 for i in patch_shape],
            do_elastic_deform=self.elastic_deform,
            alpha=self.config_elastic_deform_alpha,
            sigma=self.config_elastic_deform_sigma,
            do_rotation=self.rotations,
            angle_x=self.config_rotations_angleX,
            angle_y=self.config_rotations_angleY,
            angle_z=self.config_rotations_angleZ,
            do_scale=self.scaling,
            scale=self.config_scaling_range,
            border_mode_data='constant', border_cval_data=0,
            border_mode_seg='constant', border_cval_seg=0,
            order_data=3, order_seg=0,
            p_el_per_sample=self.config_p_per_sample,
            p_rot_per_sample=self.config_p_per_sample,
            p_scale_per_sample=self.config_p_per_sample,
            random_crop=self.cropping))

    # Assemble the transforms into an augmentation generator
    augmentation_generator = SingleThreadedAugmenter(data_generator,
                                                     Compose(transforms))

    # Draw one batch per cycle. Batches are collected first and concatenated
    # once, instead of re-copying the growing array in every iteration.
    img_batches = []
    seg_batches = []
    for _ in range(self.cycles):
        augmentation = next(augmentation_generator)
        img_batches.append(augmentation["data"])
        seg_batches.append(augmentation[seg_label])

    # Newest batch first, matching the original prepend-on-concatenate order.
    aug_img_data = np.concatenate(img_batches[::-1], axis=0)
    aug_seg_data = np.concatenate(seg_batches[::-1], axis=0)

    # Transform data from channel-first back to channel-last structure
    # Data structure channel-first 3D:  (batch, channel, x, y, z)
    # Data structure channel-last 3D:   (batch, x, y, z, channel)
    aug_img_data = np.moveaxis(aug_img_data, 1, -1)
    aug_seg_data = np.moveaxis(aug_seg_data, 1, -1)
    return aug_img_data, aug_seg_data
def main(args):
    """Configure and run a RetinaNet training experiment.

    Reads the configuration file named in ``args``, builds the augmentation
    pipeline, selects dataset class and load function for the configured
    dataset, sets up logging and finally runs either a single train/val
    experiment or a k-fold experiment.

    :param args: mapping providing ``config_file``, ``train_size``,
        ``val_size``, ``margin``, ``optimizer`` ("SGD" or "Adam") and
        ``gpus`` via ``.get``
    :raises ValueError: if ``optimizer`` is neither "SGD" nor "Adam"
    :raises KeyError: if extended augmentation is requested but neither
        crop size, image shape nor shape limit is configured
    :raises TypeError: for unsupported dataset type, level, lesion type or
        loading mode
    :raises NotImplementedError: for DDSM lesion type "both"
    """
    ########################################
    #      DEFINE THE HYPERPARAMETERS      #
    ########################################

    # load settings from config file
    config_file = args.get("config_file")
    config_handler = ConfigHandler()
    config_dict = config_handler(config_file)
    data_cfg = config_dict["data"]

    # some are changed rarely and given manually if required
    train_size = args.get("train_size")
    val_size = args.get("val_size")
    margin = args.get("margin")
    optimizer = args.get("optimizer")

    if optimizer == "Adam":
        optimizer_cls = torch.optim.Adam
    elif optimizer == "SGD":
        optimizer_cls = torch.optim.SGD
    else:
        # The exception used to be constructed without ``raise``, so an
        # invalid optimizer only surfaced later as a NameError.
        raise ValueError("Invalid optimizer")

    params = Parameters(
        fixed_params={
            "model": config_dict["model"],
            "training": {
                **config_dict["training"],
                "optimizer_cls": optimizer_cls,
                **config_dict["optimizer"],
                "criterions": {
                    "FocalLoss": losses.FocalLoss(),
                    "SmoothL1Loss": losses.SmoothL1Loss()
                },
                # "criterions": {"FocalMSELoss": losses.FocalMSELoss(),
                #                "SmoothL1Loss": losses.SmoothL1Loss()},
                # "lr_sched_cls": ReduceLROnPlateauCallbackPyTorch,
                # "lr_sched_params": {"verbose": True},
                "lr_sched_cls": None,
                "lr_sched_params": {},
                "metrics": {}
            }
        })

    ########################################
    #       DEFINE THE AUGMENTATIONS       #
    ########################################
    my_transforms = [MirrorTransform(axes=(1, 2))]

    crop_size = data_cfg["crop_size"]
    img_shape = data_cfg["img_shape"]
    shape_limit = data_cfg["shape_limit"]

    # Square inputs use Rot90Transform's default rotations; otherwise
    # num_rot is restricted to (0, 2).
    is_square = (
        (crop_size is not None and crop_size[0] == crop_size[1])
        or (img_shape is not None and len(img_shape) > 1
            and img_shape[0] == img_shape[1])
    )
    if is_square:
        rot_transform = Rot90Transform(axes=(0, 1), p_per_sample=0.5)
    else:
        rot_transform = Rot90Transform(axes=(0, 1), num_rot=(0, 2),
                                       p_per_sample=0.5)
    my_transforms.append(rot_transform)

    # apply a more extended augmentation (if desired): key present, not None
    # and truthy
    if data_cfg.get("ext_aug"):
        if crop_size is not None:
            size = [crop_size[0] + 25, crop_size[1] + 25]
        elif img_shape is not None:
            size = [img_shape[0] + 5, img_shape[1] + 5]
        elif shape_limit is not None:
            size = [shape_limit[0] + 5, shape_limit[1] + 5]
        else:
            raise KeyError("Crop size or image shape requried!")

        if crop_size is not None:
            spatial_transforms = SpatialTransform(
                [size[0] - 25, size[1] - 25],
                np.asarray(size) // 2,
                do_elastic_deform=False,
                do_rotation=True,
                angle_x=(0, 0.01 * np.pi),
                do_scale=True,
                scale=(0.9, 1.1),
                random_crop=True,
                border_mode_data="mirror",
                border_mode_seg="mirror")
            my_transforms.append(spatial_transforms)
        elif img_shape is not None or shape_limit is not None:
            spatial_transforms = SpatialTransform(
                [size[0] - 5, size[1] - 5],
                np.asarray(size) // 2,
                do_elastic_deform=False,
                do_rotation=False,
                # angle_x=(0, 0.01 * np.pi),
                do_scale=True,
                scale=(0.9, 1.1),
                random_crop=True,
                border_mode_data="constant",
                border_mode_seg="nearest")
            my_transforms.append(spatial_transforms)

    # bbox generation
    # NOTE(review): source was whitespace-collapsed; this step is assumed to
    # run regardless of the ext_aug setting.
    my_transforms.append(ConvertSegToBB(dim=2, margin=margin))

    transforms = Compose(my_transforms)

    ########################################
    #   DEFINE THE DATASETS and MANAGER    #
    ########################################

    # paths to csv files containing labels (and other information)
    csv_calc_train = '/home/temp/moriz/data/' \
                     'calc_case_description_train_set.csv'
    csv_mass_train = '/home/temp/moriz/data/' \
                     'mass_case_description_train_set.csv'

    # path to data directory
    ddsm_dir = '/home/temp/moriz/data/CBIS-DDSM/'

    # path to data directory
    inbreast_dir = '/images/Mammography/INbreast/AllDICOMs/'

    # paths to csv files containing labels (and other information)
    xls_file = '/images/Mammography/INbreast/INbreast.xls'

    # determine class and load function
    dataset_type = data_cfg["dataset_type"]
    level = data_cfg["level"]

    if dataset_type == "INbreast":
        dataset_cls = CacheINbreastDataset
        data_dir = inbreast_dir
        csv_file = None

        if level == "crops":
            load_fn = inbreast_utils.load_pos_crops
        elif level == "images":
            load_fn = inbreast_utils.load_sample
        elif level == "both":
            # TODO: fix
            load_fn = inbreast_utils.load_sample_and_crops
        else:
            raise TypeError("Level required!")

    elif dataset_type == "DDSM":
        data_dir = ddsm_dir

        if level == "crops":
            load_fn = ddsm_utils.load_pos_crops
        elif level == "images":
            load_fn = ddsm_utils.load_sample
        elif level == "images+":
            load_fn = ddsm_utils.load_sample_with_crops
        else:
            raise TypeError("Level required!")

        lesion_type = data_cfg["type"]
        if lesion_type == "mass":
            csv_file = csv_mass_train
        elif lesion_type == "calc":
            csv_file = csv_calc_train
        elif lesion_type == "both":
            raise NotImplementedError("Todo")
        else:
            raise TypeError("Unknown lesion type!")

        if "mode" in data_cfg:
            if data_cfg["mode"] == "lazy":
                dataset_cls = LazyDDSMDataset
                # lazy loading of crops needs the single-crop loader
                if level == "crops":
                    load_fn = ddsm_utils.load_single_pos_crops
            elif data_cfg["mode"] == "cache":
                dataset_cls = CacheDDSMDataset
            else:
                raise TypeError("Unsupported loading mode!")
        else:
            dataset_cls = CacheDDSMDataset

    else:
        raise TypeError("Dataset is not supported!")

    dataset_train_dict = {
        'data_path': data_dir,
        'xls_file': xls_file,
        'csv_file': csv_file,
        'load_fn': load_fn,
        'num_elements': config_dict["debug"]["n_train"],
        **data_cfg
    }

    dataset_val_dict = {
        'data_path': data_dir,
        'xls_file': xls_file,
        'csv_file': csv_file,
        'load_fn': load_fn,
        'num_elements': config_dict["debug"]["n_val"],
        **data_cfg
    }

    datamgr_train_dict = {
        'batch_size': params.nested_get("batch_size"),
        'n_process_augmentation': 4,
        'transforms': transforms,
        'sampler_cls': RandomSampler,
        'data_loader_cls': BaseDataLoader
    }

    datamgr_val_dict = {
        'batch_size': params.nested_get("batch_size"),
        'n_process_augmentation': 4,
        'transforms': transforms,
        'sampler_cls': SequentialSampler,
        'data_loader_cls': BaseDataLoader
    }

    ########################################
    #   INITIALIZE THE ACTUAL EXPERIMENT   #
    ########################################
    checkpoint_path = config_dict["checkpoint_path"]["path"]
    # if "checkpoint_path" in args and args["checkpoint_path"] is not None:
    #     checkpoint_path = args.get("checkpoint_path")

    experiment = RetinaNetExperiment(
        params,
        RetinaNet,
        name=config_dict["logging"]["name"],
        save_path=checkpoint_path,
        dataset_cls=dataset_cls,
        dataset_train_kwargs=dataset_train_dict,
        datamgr_train_kwargs=datamgr_train_dict,
        dataset_val_kwargs=dataset_val_dict,
        datamgr_val_kwargs=datamgr_val_dict,
        optim_builder=create_optims_default_pytorch,
        gpu_ids=list(range(args.get('gpus'))),
        val_score_key="val_FocalLoss",
        val_score_mode="lowest",
        checkpoint_freq=2)

    ########################################
    # LOGGING DEFINITION AND CONFIGURATION #
    ########################################
    logger_kwargs = config_dict["logging"]

    # setup initial logging
    log_file = os.path.join(experiment.save_path, 'logger.log')

    logging.basicConfig(level=logging.INFO,
                        handlers=[
                            TrixiHandler(PytorchVisdomLogger, **logger_kwargs),
                            logging.StreamHandler(),
                            logging.FileHandler(log_file)
                        ])

    # creates/registers the named logger; the returned object is not needed
    logging.getLogger("RetinaNet Logger")

    # store the effective configuration next to the checkpoints
    with open(experiment.save_path + "/config.yml", 'w') as config_out:
        yaml.dump(config_dict, config_out)

    ########################################
    #     LOAD PATHS AND EXECUTE MODEL     #
    ########################################
    seed = data_cfg["seed"]

    # values from the config file take precedence over the arguments
    if "train_size" in data_cfg:
        train_size = data_cfg["train_size"]
    if "val_size" in data_cfg:
        val_size = data_cfg["val_size"]

    # validation is only run when a fixed image or crop shape is configured
    use_validation = img_shape is not None or crop_size is not None

    if dataset_type == "INbreast":
        if not config_dict["kfold"]["enable"]:
            train_paths, _, val_paths = \
                inbreast_utils.load_single_set(inbreast_dir,
                                               xls_file=xls_file,
                                               train_size=train_size,
                                               val_size=val_size,
                                               type=data_cfg["type"],
                                               random_state=seed)

            if use_validation:
                experiment.run(train_paths, val_paths)
            else:
                experiment.run(train_paths, None)
        else:
            paths = inbreast_utils.get_paths(inbreast_dir,
                                             xls_file=xls_file,
                                             type=data_cfg["type"])

            if "splits" in config_dict["kfold"]:
                num_splits = config_dict["kfold"]["splits"]
            else:
                num_splits = 5

            experiment.kfold(paths,
                             num_splits=num_splits,
                             random_seed=seed,
                             dataset_type="INbreast")
    else:
        train_paths, val_paths, _ = \
            ddsm_utils.load_single_set(ddsm_dir,
                                       csv_file=csv_file,
                                       train_size=train_size,
                                       val_size=None,
                                       random_state=seed)

        if use_validation:
            experiment.run(train_paths, val_paths)
        else:
            experiment.run(train_paths, None)
def get_batches(self, batch_size=128, type=None, subjects=None,
                num_batches=None):
    """Create the multi-process batch generator for ``subjects``.

    A random slice generator is chosen by ``HP.TYPE`` and wrapped in a
    ``MultiThreadedAugmenter``. Augmentation transforms are only added when
    ``HP.DATA_AUGMENTATION`` is set and ``type == "train"``.

    :param batch_size: batch size handed to the slice generator
    :param type: augmentations are only enabled for the value ``"train"``
    :param subjects: subjects to sample from; ``len(subjects)`` is required
    :param num_batches: optional number of batches (see note in the body)
    :return: the ``MultiThreadedAugmenter``; per the original comment it
        yields data: (batch_size, channels, x, y),
        seg: (batch_size, channels, x, y)
    """
    hp = self.HP
    data = subjects
    seg = []

    # 6 -> >30GB RAM
    # 6 is a bit faster than 16
    num_processes = 8 if hp.DATA_AUGMENTATION else 6

    nr_of_samples = len(subjects) * hp.INPUT_DIM[0]
    # NOTE(review): num_batches_multithr is computed but not used anywhere
    # below; it is kept so the function behaves exactly as before.
    if num_batches is None:
        # number of batches for exactly one epoch
        num_batches_multithr = int(nr_of_samples / batch_size / num_processes)
    else:
        num_batches_multithr = int(num_batches / num_processes)

    if hp.TYPE == "combined":
        # Simple with .npy -> just a little bit faster than Nifti (<10%) and
        # f1 not better => use Nifti
        batch_gen = SlicesBatchGeneratorRandomNpyImg_fusion(
            (data, seg), batch_size=batch_size)
    else:
        batch_gen = SlicesBatchGeneratorRandomNiftiImg(
            (data, seg), batch_size=batch_size)
    batch_gen.HP = hp

    tfs = []  # transforms

    if hp.NORMALIZE_DATA:
        tfs.append(ZeroMeanUnitVarianceTransform(per_channel=hp.NORMALIZE_PER_CHANNEL))

    if hp.DATASET == "Schizo" and hp.RESOLUTION == "2mm":
        tfs.append(PadToMultipleTransform(16))

    # NOTE(review): nesting below was rebuilt from a whitespace-collapsed
    # source; all DAUG_* switches are assumed to sit inside the train branch.
    if hp.DATA_AUGMENTATION and type == "train":
        # scale: inverted: 0.5 -> bigger; 2 -> smaller
        # patch_center_dist_from_border: if 144/2=72 -> always exactly
        # centered; otherwise a bit off center (brain can get off image and
        # will be cut then)
        if hp.DAUG_SCALE:
            center_dist_from_border = int(hp.INPUT_DIM[0] / 2.) - 10  # (144,144) -> 62
            tfs.append(SpatialTransform(
                hp.INPUT_DIM,
                patch_center_dist_from_border=center_dist_from_border,
                do_elastic_deform=hp.DAUG_ELASTIC_DEFORM,
                alpha=(90., 120.), sigma=(9., 11.),
                do_rotation=hp.DAUG_ROTATE,
                angle_x=(-0.8, 0.8), angle_y=(-0.8, 0.8), angle_z=(-0.8, 0.8),
                do_scale=True, scale=(0.9, 1.5),
                border_mode_data='constant', border_cval_data=0, order_data=3,
                border_mode_seg='constant', border_cval_seg=0, order_seg=0,
                random_crop=True))

        if hp.DAUG_RESAMPLE:
            tfs.append(ResampleTransform(zoom_range=(0.5, 1)))

        if hp.DAUG_NOISE:
            tfs.append(GaussianNoiseTransform(noise_variance=(0, 0.05)))

        if hp.DAUG_MIRROR:
            tfs.append(MirrorTransform())

        if hp.DAUG_FLIP_PEAKS:
            tfs.append(FlipVectorAxisTransform())

    # num_cached_per_queue 1 or 2 does not really make a difference
    return MultiThreadedAugmenter(batch_gen, Compose(tfs),
                                  num_processes=num_processes,
                                  num_cached_per_queue=1, seeds=None)
def _make_training_transforms(self):
    """Assemble the list of training-time augmentation transforms.

    When ``self.no_data_augmentation`` is set, a notice is printed and an
    empty list is returned. Otherwise the list holds a spatial transform,
    an optional mirror transform, and brightness, Gaussian-noise and gamma
    transforms. Every tunable value is looked up in
    ``self.training_augmentation_args`` with the defaults written below.

    Returns:
        list: The transforms, in application order (not wrapped in a
        ``Compose``).
    """
    if self.no_data_augmentation:
        print("No data augmentation will be performed during training!")
        return []

    aug_args = self.training_augmentation_args
    patch_size = self.patch_size[::-1]  # (x, y, z) order

    def rotation_range(key):
        # Configured value is in degrees (default 15); the transform gets a
        # symmetric (-a, a) range in radians.
        radians = aug_args.get(key, 15) / 360. * 2 * np.pi
        return (-radians, radians)

    angle_x = rotation_range('angle_x')
    angle_y = rotation_range('angle_y')
    angle_z = rotation_range('angle_z')
    p_per_sample = aug_args.get('p_per_sample', 0.15)

    spatial = SpatialTransform_2(
        patch_size,
        patch_size // 2,
        do_elastic_deform=aug_args.get('do_elastic_deform', True),
        deformation_scale=aug_args.get('deformation_scale', (0, 0.25)),
        do_rotation=aug_args.get('do_rotation', True),
        angle_x=angle_x,
        angle_y=angle_y,
        angle_z=angle_z,
        do_scale=aug_args.get('do_scale', True),
        scale=aug_args.get('scale', (0.75, 1.25)),
        border_mode_data='nearest',
        border_cval_data=0,
        order_data=3,
        # border_mode_seg='nearest', border_cval_seg=0,
        # order_seg=0,
        random_crop=False,
        p_el_per_sample=aug_args.get('p_el_per_sample', 0.5),
        p_rot_per_sample=aug_args.get('p_rot_per_sample', 0.5),
        p_scale_per_sample=aug_args.get('p_scale_per_sample', 0.5))

    train_transforms = [spatial]

    if aug_args.get("do_mirror", False):
        train_transforms.append(MirrorTransform(axes=(0, 1, 2)))

    # Intensity augmentations share one per-sample probability.
    train_transforms.append(
        BrightnessMultiplicativeTransform(
            aug_args.get('brightness_range', (0.7, 1.5)),
            per_channel=True,
            p_per_sample=p_per_sample))
    train_transforms.append(
        GaussianNoiseTransform(
            noise_variance=aug_args.get('gaussian_noise_variance', (0, 0.05)),
            p_per_sample=p_per_sample))
    train_transforms.append(
        GammaTransform(
            gamma_range=aug_args.get('gamma_range', (0.5, 2)),
            invert_image=False,
            per_channel=True,
            p_per_sample=p_per_sample))

    print("train_transforms\n", train_transforms)
    return train_transforms
train_transform = transforms.Compose([ # mt_transforms.CenterCrop2D((200, 200)), mt_transforms.ElasticTransform(alpha_range=(28.0, 30.0), sigma_range=(3.5, 4.0), p=0.3), mt_transforms.RandomAffine(degrees=4.6, scale=(0.98, 1.02), translate=(0.03, 0.03)), mt_transforms.RandomTensorChannelShift((-0.10, 0.10)), mt_transforms.ToTensor() # mt_transforms.NormalizeInstance(), ]) gamma_t = GammaTransform(data_key="img", gamma_range=(0.1, 10)) mirror_t = MirrorTransform(data_key="img", label_key="seg") spatial_t = SpatialTransform(patch_size=(8, 8, 8), data_key="img", label_key="seg") gauss_noise_t = GaussianNoiseTransform(data_key="img", noise_variance=(0, 1)) zoom_t = ZoomTransform(zoom_factors=2, data_key="img") def show_basic(x, gt, info=None): if info is not None: print("Test for " + info) print("img size: {}, max: {}, min: {}, avg: {}.".format(
def get_default_augmentation(dataloader_train,
                             dataloader_val,
                             patch_size,
                             params=default_3D_augmentation_params,
                             border_val_seg=-1,
                             pin_memory=True,
                             seeds_train=None,
                             seeds_val=None,
                             regions=None):
    """Build the multi-process training and validation batch generators.

    The training pipeline applies channel selection, a spatial transform,
    optional gamma and mirror augmentation, optional masking, optional
    cascade augmentations and finally renames ``'seg'`` to ``'target'`` and
    converts to tensors. The validation pipeline only applies the
    non-augmenting steps.

    Args:
        dataloader_train: Loader wrapped by the training augmenter.
        dataloader_val: Loader wrapped by the validation augmenter.
        patch_size: Patch size for the spatial transform. Its first entry is
            dropped when ``params["dummy_2D"]`` is truthy.
        params: Mapping with the augmentation settings. Entries are read with
            ``.get``; ``"p_gamma"`` must be present when ``"do_gamma"`` is
            truthy.
        border_val_seg: ``border_cval_seg`` of the spatial transform.
        pin_memory: Forwarded to both augmenters.
        seeds_train: Seeds forwarded to the training augmenter.
        seeds_val: Seeds forwarded to the validation augmenter.
        regions: When not ``None``, targets are converted with
            ``ConvertSegmentationToRegionsTransform``.

    Returns:
        tuple: ``(batchgenerator_train, batchgenerator_val)``.

    Raises:
        AssertionError: If ``params`` still carries the legacy ``'mirror'``
            key.
    """
    assert params.get(
        'mirror') is None, "old version of params, use new keyword do_mirror"

    selected_data_channels = params.get("selected_data_channels")
    selected_seg_channels = params.get("selected_seg_channels")
    move_last_seg_channel = params.get("move_last_seg_chanel_to_data")
    all_labels = params.get("all_segmentation_labels")
    dummy_2d = params.get("dummy_2D")

    # ---- training pipeline -------------------------------------------------
    tr_transforms = []

    if selected_data_channels is not None:
        tr_transforms.append(
            DataChannelSelectionTransform(selected_data_channels))

    if selected_seg_channels is not None:
        tr_transforms.append(
            SegChannelSelectionTransform(selected_seg_channels))

    # don't do color augmentations while in 2d mode with 3d data because the
    # color channel is overloaded!!
    if dummy_2d:
        tr_transforms.append(Convert3DTo2DTransform())
        patch_size_spatial = patch_size[1:]
    else:
        patch_size_spatial = patch_size

    # Set order_data=0 and order_seg=0 for some more speed for cascade???
    tr_transforms.append(
        SpatialTransform(
            patch_size_spatial,
            patch_center_dist_from_border=None,
            do_elastic_deform=params.get("do_elastic"),
            alpha=params.get("elastic_deform_alpha"),
            sigma=params.get("elastic_deform_sigma"),
            do_rotation=params.get("do_rotation"),
            angle_x=params.get("rotation_x"),
            angle_y=params.get("rotation_y"),
            angle_z=params.get("rotation_z"),
            do_scale=params.get("do_scaling"),
            scale=params.get("scale_range"),
            border_mode_data=params.get("border_mode_data"),
            border_cval_data=0,
            order_data=3,
            border_mode_seg="constant",
            border_cval_seg=border_val_seg,
            order_seg=1,
            random_crop=params.get("random_crop"),
            p_el_per_sample=params.get("p_eldef"),
            p_scale_per_sample=params.get("p_scale"),
            p_rot_per_sample=params.get("p_rot"),
            independent_scale_for_each_axis=params.get(
                "independent_scale_factor_for_each_axis")))

    if dummy_2d:
        tr_transforms.append(Convert2DTo3DTransform())

    if params.get("do_gamma"):
        tr_transforms.append(
            GammaTransform(params.get("gamma_range"),
                           False,
                           True,
                           retain_stats=params.get("gamma_retain_stats"),
                           p_per_sample=params["p_gamma"]))

    if params.get("do_mirror"):
        tr_transforms.append(MirrorTransform(params.get("mirror_axes")))

    mask_was_used_for_normalization = params.get(
        "mask_was_used_for_normalization")
    if mask_was_used_for_normalization is not None:
        tr_transforms.append(
            MaskTransform(mask_was_used_for_normalization,
                          mask_idx_in_seg=0,
                          set_outside_to=0))

    tr_transforms.append(RemoveLabelTransform(-1, 0))

    if move_last_seg_channel:
        tr_transforms.append(
            MoveSegAsOneHotToData(1, all_labels, 'seg', 'data'))

        # The original guard read ``x and not None and x``; ``not None`` is
        # always True, so a plain truthiness test is equivalent.
        if params.get("cascade_do_cascade_augmentations"):
            # Remove the following transforms to remove cascade DA ??
            num_labels = len(all_labels)
            tr_transforms.append(
                ApplyRandomBinaryOperatorTransform(
                    channel_idx=list(range(-num_labels, 0)),
                    p_per_sample=params.get(
                        "cascade_random_binary_transform_p"),
                    key="data",
                    strel_size=params.get(
                        "cascade_random_binary_transform_size")))
            # Each keyword now receives the parameter of the same name;
            # previously the two values below were passed to each other's
            # keyword.
            tr_transforms.append(
                RemoveRandomConnectedComponentFromOneHotEncodingTransform(
                    channel_idx=list(range(-num_labels, 0)),
                    key="data",
                    p_per_sample=params.get("cascade_remove_conn_comp_p"),
                    fill_with_other_class_p=params.get(
                        "cascade_remove_conn_comp_fill_with_other_class_p"),
                    dont_do_if_covers_more_than_X_percent=params.get(
                        "cascade_remove_conn_comp_max_size_percent_threshold"
                    )))

    tr_transforms.append(RenameTransform('seg', 'target', True))

    if regions is not None:
        tr_transforms.append(
            ConvertSegmentationToRegionsTransform(regions, 'target',
                                                  'target'))

    tr_transforms.append(NumpyToTensor(['data', 'target'], 'float'))
    tr_transforms = Compose(tr_transforms)

    batchgenerator_train = MultiThreadedAugmenter(
        dataloader_train,
        tr_transforms,
        params.get('num_threads'),
        params.get("num_cached_per_thread"),
        seeds=seeds_train,
        pin_memory=pin_memory)

    # ---- validation pipeline (no augmentation) -----------------------------
    val_transforms = [RemoveLabelTransform(-1, 0)]

    if selected_data_channels is not None:
        val_transforms.append(
            DataChannelSelectionTransform(selected_data_channels))

    if selected_seg_channels is not None:
        val_transforms.append(
            SegChannelSelectionTransform(selected_seg_channels))

    if move_last_seg_channel:
        val_transforms.append(
            MoveSegAsOneHotToData(1, all_labels, 'seg', 'data'))

    val_transforms.append(RenameTransform('seg', 'target', True))

    if regions is not None:
        val_transforms.append(
            ConvertSegmentationToRegionsTransform(regions, 'target',
                                                  'target'))

    val_transforms.append(NumpyToTensor(['data', 'target'], 'float'))
    val_transforms = Compose(val_transforms)

    # Validation gets half the worker count, but at least one worker.
    batchgenerator_val = MultiThreadedAugmenter(
        dataloader_val,
        val_transforms,
        max(params.get('num_threads') // 2, 1),
        params.get("num_cached_per_thread"),
        seeds=seeds_val,
        pin_memory=pin_memory)

    return batchgenerator_train, batchgenerator_val
}, encode_block=ResBlockStack, encode_kwargs_fn=encode_kwargs_fn, decode_block=ResBlock).cuda() patch_size = (160, 160, 80) optimizer = optim.Adam(model.parameters(), lr=1e-4) scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.2, patience=30) tr_transform = Compose([ GammaTransform((0.9, 1.1)), ContrastAugmentationTransform((0.9, 1.1)), BrightnessMultiplicativeTransform((0.9, 1.1)), MirrorTransform(axes=[0]), SpatialTransform_2( patch_size, (90, 90, 50), random_crop=True, do_elastic_deform=True, deformation_scale=(0, 0.05), do_rotation=True, angle_x=(-0.1 * np.pi, 0.1 * np.pi), angle_y=(0, 0), angle_z=(0, 0), do_scale=True, scale=(0.9, 1.1), border_mode_data='constant', ), RandomCropTransform(crop_size=patch_size),
def get_moreDA_augmentation(dataloader_train,
                            dataloader_val,
                            patch_size,
                            params=default_3D_augmentation_params,
                            border_val_seg=-1,
                            seeds_train=None,
                            seeds_val=None,
                            order_seg=1,
                            order_data=3,
                            deep_supervision_scales=None,
                            soft_ds=False,
                            classes=None,
                            pin_memory=True,
                            regions=None,
                            use_nondetMultiThreadedAugmenter: bool = False):
    """Build training/validation batch generators with the extended pipeline.

    On top of the spatial transform, the training pipeline adds noise, blur,
    brightness, contrast, low-resolution simulation, gamma, optional
    mirroring, optional masking and optional cascade augmentations. Both
    pipelines end by renaming ``'seg'`` to ``'target'``, optionally
    converting to regions, optionally downsampling the target for deep
    supervision and converting to tensors.

    Args:
        dataloader_train: Loader wrapped by the training augmenter.
        dataloader_val: Loader wrapped by the validation augmenter.
        patch_size: Patch size for the spatial transform. Its first entry is
            dropped when ``params["dummy_2D"]`` is truthy.
        params: Mapping with the augmentation settings. Entries are read with
            ``.get``; ``"p_gamma"`` must be present when ``"do_gamma"`` is
            truthy.
        border_val_seg: ``border_cval_seg`` of the spatial transform.
        seeds_train: Seeds forwarded to the training augmenter.
        seeds_val: Seeds forwarded to the validation augmenter.
        order_seg: ``order_seg`` of the spatial transform.
        order_data: ``order_data`` of the spatial transform.
        deep_supervision_scales: When not ``None``, the target is downsampled
            with one of the ``DownsampleSegForDSTransform*`` transforms.
        soft_ds: Selects ``DownsampleSegForDSTransform3`` (requires
            ``classes``) instead of ``DownsampleSegForDSTransform2``.
        classes: Passed to ``DownsampleSegForDSTransform3`` when ``soft_ds``.
        pin_memory: Forwarded to both augmenters.
        regions: When not ``None``, targets are converted with
            ``ConvertSegmentationToRegionsTransform``.
        use_nondetMultiThreadedAugmenter: Use ``NonDetMultiThreadedAugmenter``
            instead of ``MultiThreadedAugmenter``.

    Returns:
        tuple: ``(batchgenerator_train, batchgenerator_val)``.

    Raises:
        AssertionError: If ``params`` still carries the legacy ``'mirror'``
            key, or ``soft_ds`` is set without ``classes``.
        RuntimeError: If the non-deterministic augmenter is requested but
            ``NonDetMultiThreadedAugmenter`` is ``None``.
    """
    assert params.get(
        'mirror') is None, "old version of params, use new keyword do_mirror"

    selected_data_channels = params.get("selected_data_channels")
    selected_seg_channels = params.get("selected_seg_channels")
    move_last_seg_channel = params.get("move_last_seg_chanel_to_data")
    all_labels = params.get("all_segmentation_labels")
    dummy_2d = params.get("dummy_2D")

    def append_target_transforms(transform_list):
        # Shared tail of both pipelines: rename seg -> target, optional
        # region conversion, optional deep-supervision downsampling, tensors.
        transform_list.append(RenameTransform('seg', 'target', True))

        if regions is not None:
            transform_list.append(
                ConvertSegmentationToRegionsTransform(regions, 'target',
                                                      'target'))

        if deep_supervision_scales is not None:
            if soft_ds:
                assert classes is not None
                transform_list.append(
                    DownsampleSegForDSTransform3(deep_supervision_scales,
                                                 'target', 'target', classes))
            else:
                transform_list.append(
                    DownsampleSegForDSTransform2(deep_supervision_scales,
                                                 0,
                                                 input_key='target',
                                                 output_key='target'))

        transform_list.append(NumpyToTensor(['data', 'target'], 'float'))

    # ---- training pipeline -------------------------------------------------
    tr_transforms = []

    if selected_data_channels is not None:
        tr_transforms.append(
            DataChannelSelectionTransform(selected_data_channels))

    if selected_seg_channels is not None:
        tr_transforms.append(
            SegChannelSelectionTransform(selected_seg_channels))

    # don't do color augmentations while in 2d mode with 3d data because the
    # color channel is overloaded!!
    if dummy_2d:
        ignore_axes = (0, )
        tr_transforms.append(Convert3DTo2DTransform())
        patch_size_spatial = patch_size[1:]
    else:
        patch_size_spatial = patch_size
        ignore_axes = None

    tr_transforms.append(
        SpatialTransform(
            patch_size_spatial,
            patch_center_dist_from_border=None,
            do_elastic_deform=params.get("do_elastic"),
            alpha=params.get("elastic_deform_alpha"),
            sigma=params.get("elastic_deform_sigma"),
            do_rotation=params.get("do_rotation"),
            angle_x=params.get("rotation_x"),
            angle_y=params.get("rotation_y"),
            angle_z=params.get("rotation_z"),
            p_rot_per_axis=params.get("rotation_p_per_axis"),
            do_scale=params.get("do_scaling"),
            scale=params.get("scale_range"),
            border_mode_data=params.get("border_mode_data"),
            border_cval_data=0,
            order_data=order_data,
            border_mode_seg="constant",
            border_cval_seg=border_val_seg,
            order_seg=order_seg,
            random_crop=params.get("random_crop"),
            p_el_per_sample=params.get("p_eldef"),
            p_scale_per_sample=params.get("p_scale"),
            p_rot_per_sample=params.get("p_rot"),
            independent_scale_for_each_axis=params.get(
                "independent_scale_factor_for_each_axis")))

    if dummy_2d:
        tr_transforms.append(Convert2DTo3DTransform())

    # we need to put the color augmentations after the dummy 2d part (if
    # applicable). Otherwise the overloaded color channel gets in the way
    tr_transforms.append(GaussianNoiseTransform(p_per_sample=0.1))
    tr_transforms.append(
        GaussianBlurTransform((0.5, 1.),
                              different_sigma_per_channel=True,
                              p_per_sample=0.2,
                              p_per_channel=0.5))
    tr_transforms.append(
        BrightnessMultiplicativeTransform(multiplier_range=(0.75, 1.25),
                                          p_per_sample=0.15))

    if params.get("do_additive_brightness"):
        tr_transforms.append(
            BrightnessTransform(
                params.get("additive_brightness_mu"),
                params.get("additive_brightness_sigma"),
                True,
                p_per_sample=params.get("additive_brightness_p_per_sample"),
                p_per_channel=params.get(
                    "additive_brightness_p_per_channel")))

    tr_transforms.append(ContrastAugmentationTransform(p_per_sample=0.15))
    tr_transforms.append(
        SimulateLowResolutionTransform(zoom_range=(0.5, 1),
                                       per_channel=True,
                                       p_per_channel=0.5,
                                       order_downsample=0,
                                       order_upsample=3,
                                       p_per_sample=0.25,
                                       ignore_axes=ignore_axes))

    # inverted gamma
    tr_transforms.append(
        GammaTransform(params.get("gamma_range"),
                       True,
                       True,
                       retain_stats=params.get("gamma_retain_stats"),
                       p_per_sample=0.1))

    if params.get("do_gamma"):
        tr_transforms.append(
            GammaTransform(params.get("gamma_range"),
                           False,
                           True,
                           retain_stats=params.get("gamma_retain_stats"),
                           p_per_sample=params["p_gamma"]))

    if params.get("do_mirror") or params.get("mirror"):
        tr_transforms.append(MirrorTransform(params.get("mirror_axes")))

    mask_was_used_for_normalization = params.get(
        "mask_was_used_for_normalization")
    if mask_was_used_for_normalization is not None:
        tr_transforms.append(
            MaskTransform(mask_was_used_for_normalization,
                          mask_idx_in_seg=0,
                          set_outside_to=0))

    tr_transforms.append(RemoveLabelTransform(-1, 0))

    if move_last_seg_channel:
        tr_transforms.append(
            MoveSegAsOneHotToData(1, all_labels, 'seg', 'data'))

        if params.get("cascade_do_cascade_augmentations"):
            num_labels = len(all_labels)

            if params.get("cascade_random_binary_transform_p") > 0:
                tr_transforms.append(
                    ApplyRandomBinaryOperatorTransform(
                        channel_idx=list(range(-num_labels, 0)),
                        p_per_sample=params.get(
                            "cascade_random_binary_transform_p"),
                        key="data",
                        strel_size=params.get(
                            "cascade_random_binary_transform_size"),
                        p_per_label=params.get(
                            "cascade_random_binary_transform_p_per_label")))

            if params.get("cascade_remove_conn_comp_p") > 0:
                # Each keyword now receives the parameter of the same name;
                # previously the two values below were passed to each
                # other's keyword.
                tr_transforms.append(
                    RemoveRandomConnectedComponentFromOneHotEncodingTransform(
                        channel_idx=list(range(-num_labels, 0)),
                        key="data",
                        p_per_sample=params.get("cascade_remove_conn_comp_p"),
                        fill_with_other_class_p=params.get(
                            "cascade_remove_conn_comp_fill_with_other_class_p"
                        ),
                        dont_do_if_covers_more_than_X_percent=params.get(
                            "cascade_remove_conn_comp_max_size_percent_threshold"
                        )))

    append_target_transforms(tr_transforms)
    tr_transforms = Compose(tr_transforms)

    # Pick the augmenter class once; it is used for both generators.
    if use_nondetMultiThreadedAugmenter:
        if NonDetMultiThreadedAugmenter is None:
            raise RuntimeError(
                'NonDetMultiThreadedAugmenter is not yet available')
        augmenter_cls = NonDetMultiThreadedAugmenter
    else:
        augmenter_cls = MultiThreadedAugmenter

    batchgenerator_train = augmenter_cls(dataloader_train,
                                         tr_transforms,
                                         params.get('num_threads'),
                                         params.get("num_cached_per_thread"),
                                         seeds=seeds_train,
                                         pin_memory=pin_memory)

    # ---- validation pipeline (no augmentation) -----------------------------
    val_transforms = [RemoveLabelTransform(-1, 0)]

    if selected_data_channels is not None:
        val_transforms.append(
            DataChannelSelectionTransform(selected_data_channels))

    if selected_seg_channels is not None:
        val_transforms.append(
            SegChannelSelectionTransform(selected_seg_channels))

    if move_last_seg_channel:
        val_transforms.append(
            MoveSegAsOneHotToData(1, all_labels, 'seg', 'data'))

    append_target_transforms(val_transforms)
    val_transforms = Compose(val_transforms)

    # Validation gets half the worker count, but at least one worker.
    batchgenerator_val = augmenter_cls(dataloader_val,
                                       val_transforms,
                                       max(params.get('num_threads') // 2, 1),
                                       params.get("num_cached_per_thread"),
                                       seeds=seeds_val,
                                       pin_memory=pin_memory)

    return batchgenerator_train, batchgenerator_val
do_rotation=True, angle_z=(0, 2 * np.pi), # 旋转 do_scale=True, scale=(0.3, 3.), # 缩放 border_mode_data='constant', border_cval_data=0, order_data=1, random_crop=False) my_transforms.append(spatial_transform) GaussianNoise = GaussianNoiseTransform() # 高斯噪声 my_transforms.append(GaussianNoise) GaussianBlur = GaussianBlurTransform() # 高斯模糊 my_transforms.append(GaussianBlur) Brightness = BrightnessTransform(0, 0.2) # 亮度 my_transforms.append(Brightness) brightness_transform = ContrastAugmentationTransform( (0.3, 3.), preserve_range=True) # 对比度 my_transforms.append(brightness_transform) SimulateLowResolution = SimulateLowResolutionTransform() # 低分辨率 my_transforms.append(SimulateLowResolution) Gamma = GammaTransform() # 伽马增强 my_transforms.append(Gamma) mirror_transform = MirrorTransform(axes=(0, 1)) # 镜像 my_transforms.append(mirror_transform) all_transforms = Compose(my_transforms) multithreaded_generator = MultiThreadedAugmenter(batchgen, all_transforms, 1, 2) t = multithreaded_generator.next() plot_batch(t)
angle_z=(-5 / 360. * 2 * np.pi, 5 / 360. * 2 * np.pi), do_scale=True, scale=(0.9, 1.02), border_mode_data='constant', border_cval_data=0, border_mode_seg='constant', border_cval_seg=0, order_seg=1, order_data=3, random_crop=False, p_el_per_sample=0.1, p_rot_per_sample=0.1, p_scale_per_sample=0.1) my_transforms.append(spatial_transform) my_transforms.append(MirrorTransform(axes=(0, 1, 2))) my_transforms.append( GammaTransform(gamma_range=(0.7, 1.), invert_image=False, per_channel=True, p_per_sample=0.1)) all_transforms = Compose(my_transforms) train_loader = SingleThreadedAugmenter( batchgen, all_transforms ) #data loader for training, applying on the fly transformation # add other data loaders test_loader = torch.utils.data.DataLoader( dataset_test,
def get_train_transforms(self) -> List[AbstractTransform]:
    """Assemble the training augmentation pipeline as a list of transforms.

    The list contains, in order: optional seg-channel selection, a spatial
    transform (wrapped in 3D<->2D conversion when ``self.do_dummy_2D_aug``),
    optional rot90/transpose, blur, noise, brightness, contrast,
    low-resolution simulation, gamma, optional mirroring, blank rectangles,
    local brightness/gamma, sharpening, optional masking, label clean-up,
    optional cascade augmentations and the final target conversion.

    Changes relative to the previous revision:
      * The gamma transform was appended twice with ``invert_image=True``;
        the second one now uses ``invert_image=False``.
      * The scale sampler no longer evaluates ``log(0)`` when ``x[y] < 6``.

    Returns:
        The transforms in application order (not wrapped in a ``Compose``).
    """
    # used for transpose and rot90: for every axis, count how many axes share
    # its patch extent and keep the axes with the highest count
    matching_axes = np.array(
        [sum(i == j for j in self.patch_size) for i in self.patch_size])
    valid_axes = list(np.where(matching_axes == np.max(matching_axes))[0])

    def log_uniform_scale(x, y):
        # Log-uniform sample between x[y] // 6 and x[y]. The lower bound is
        # clamped to 1 so that np.log never receives 0.
        return np.exp(
            np.random.uniform(np.log(max(1, x[y] // 6)), np.log(x[y])))

    tr_transforms = []

    if self.data_aug_params['selected_seg_channels'] is not None:
        tr_transforms.append(
            SegChannelSelectionTransform(
                self.data_aug_params['selected_seg_channels']))

    if self.do_dummy_2D_aug:
        ignore_axes = (0, )
        tr_transforms.append(Convert3DTo2DTransform())
        patch_size_spatial = self.patch_size[1:]
    else:
        patch_size_spatial = self.patch_size
        ignore_axes = None

    tr_transforms.append(
        SpatialTransform(
            patch_size_spatial,
            patch_center_dist_from_border=None,
            do_elastic_deform=False,
            do_rotation=True,
            angle_x=self.data_aug_params["rotation_x"],
            angle_y=self.data_aug_params["rotation_y"],
            angle_z=self.data_aug_params["rotation_z"],
            p_rot_per_axis=0.5,
            do_scale=True,
            scale=self.data_aug_params['scale_range'],
            border_mode_data="constant",
            border_cval_data=0,
            order_data=3,
            border_mode_seg="constant",
            border_cval_seg=-1,
            order_seg=1,
            random_crop=False,
            p_el_per_sample=0.2,
            p_scale_per_sample=0.2,
            p_rot_per_sample=0.4,
            independent_scale_for_each_axis=True,
        ))

    if self.do_dummy_2D_aug:
        tr_transforms.append(Convert2DTo3DTransform())

    # Only added when at least two axes share the same patch extent.
    if np.any(matching_axes > 1):
        tr_transforms.append(
            Rot90Transform((0, 1, 2, 3),
                           axes=valid_axes,
                           data_key='data',
                           label_key='seg',
                           p_per_sample=0.5))
        tr_transforms.append(
            TransposeAxesTransform(valid_axes,
                                   data_key='data',
                                   label_key='seg',
                                   p_per_sample=0.5))

    tr_transforms.append(
        OneOfTransform([
            MedianFilterTransform((2, 8),
                                  same_for_each_channel=False,
                                  p_per_sample=0.2,
                                  p_per_channel=0.5),
            GaussianBlurTransform((0.3, 1.5),
                                  different_sigma_per_channel=True,
                                  p_per_sample=0.2,
                                  p_per_channel=0.5),
        ]))

    tr_transforms.append(GaussianNoiseTransform(p_per_sample=0.1))

    tr_transforms.append(
        BrightnessTransform(0,
                            0.5,
                            per_channel=True,
                            p_per_sample=0.1,
                            p_per_channel=0.5))

    tr_transforms.append(
        OneOfTransform([
            ContrastAugmentationTransform(contrast_range=(0.5, 2),
                                          preserve_range=True,
                                          per_channel=True,
                                          data_key='data',
                                          p_per_sample=0.2,
                                          p_per_channel=0.5),
            ContrastAugmentationTransform(contrast_range=(0.5, 2),
                                          preserve_range=False,
                                          per_channel=True,
                                          data_key='data',
                                          p_per_sample=0.2,
                                          p_per_channel=0.5),
        ]))

    tr_transforms.append(
        SimulateLowResolutionTransform(zoom_range=(0.25, 1),
                                       per_channel=True,
                                       p_per_channel=0.5,
                                       order_downsample=0,
                                       order_upsample=3,
                                       p_per_sample=0.15,
                                       ignore_axes=ignore_axes))

    # One inverted and one non-inverted gamma augmentation.
    tr_transforms.append(
        GammaTransform((0.7, 1.5),
                       invert_image=True,
                       per_channel=True,
                       retain_stats=True,
                       p_per_sample=0.1))
    tr_transforms.append(
        GammaTransform((0.7, 1.5),
                       invert_image=False,
                       per_channel=True,
                       retain_stats=True,
                       p_per_sample=0.1))

    if self.do_mirroring:
        tr_transforms.append(MirrorTransform(self.mirror_axes))

    tr_transforms.append(
        BlankRectangleTransform([[max(1, p // 10), p // 3]
                                 for p in self.patch_size],
                                rectangle_value=np.mean,
                                num_rectangles=(1, 5),
                                force_square=False,
                                p_per_sample=0.4,
                                p_per_channel=0.5))

    tr_transforms.append(
        BrightnessGradientAdditiveTransform(
            log_uniform_scale,
            (-0.5, 1.5),
            # strength is drawn from [-5, -1] or [1, 5] with equal chance
            max_strength=lambda x, y: np.random.uniform(-5, -1)
            if np.random.uniform() < 0.5 else np.random.uniform(1, 5),
            mean_centered=False,
            same_for_all_channels=False,
            p_per_sample=0.3,
            p_per_channel=0.5))

    tr_transforms.append(
        LocalGammaTransform(
            log_uniform_scale,
            (-0.5, 1.5),
            # gamma is drawn from [0.01, 0.8] or [1.5, 4] with equal chance
            lambda: np.random.uniform(0.01, 0.8)
            if np.random.uniform() < 0.5 else np.random.uniform(1.5, 4),
            same_for_all_channels=False,
            p_per_sample=0.3,
            p_per_channel=0.5))

    tr_transforms.append(
        SharpeningTransform(strength=(0.1, 1),
                            same_for_each_channel=False,
                            p_per_sample=0.2,
                            p_per_channel=0.5))

    if any(self.use_mask_for_norm.values()):
        tr_transforms.append(
            MaskTransform(self.use_mask_for_norm,
                          mask_idx_in_seg=0,
                          set_outside_to=0))

    tr_transforms.append(RemoveLabelTransform(-1, 0))

    if self.data_aug_params["move_last_seg_chanel_to_data"]:
        all_class_labels = np.arange(1, self.num_classes)
        tr_transforms.append(
            MoveSegAsOneHotToData(1, all_class_labels, 'seg', 'data'))

        # Cascade augmentations work on the one-hot channels appended above,
        # so they are only reachable inside this branch.
        if self.data_aug_params["cascade_do_cascade_augmentations"]:
            num_labels = len(all_class_labels)
            tr_transforms.append(
                ApplyRandomBinaryOperatorTransform(
                    channel_idx=list(range(-num_labels, 0)),
                    p_per_sample=0.4,
                    key="data",
                    strel_size=(1, 8),
                    p_per_label=1))
            # NOTE(review): confirm that 0.15 / 0 are attached to the
            # intended keywords and are not swapped.
            tr_transforms.append(
                RemoveRandomConnectedComponentFromOneHotEncodingTransform(
                    channel_idx=list(range(-num_labels, 0)),
                    key="data",
                    p_per_sample=0.2,
                    fill_with_other_class_p=0.15,
                    dont_do_if_covers_more_than_X_percent=0))

    tr_transforms.append(RenameTransform('seg', 'target', True))

    if self.regions is not None:
        tr_transforms.append(
            ConvertSegmentationToRegionsTransform(self.regions, 'target',
                                                  'target'))

    if self.deep_supervision_scales is not None:
        tr_transforms.append(
            DownsampleSegForDSTransform2(self.deep_supervision_scales,
                                         0,
                                         input_key='target',
                                         output_key='target'))

    tr_transforms.append(NumpyToTensor(['data', 'target'], 'float'))
    return tr_transforms