def __init__(self,
             x,  # (n_samples, z, y, x, channels) for 3D or (n_samples, y, x, channels) for 2D data
             y,
             transform,  # the transformations from the batchgenerators API
             batch_size=32,
             shuffle=False,
             seed=None,
             save_to_dir=None,
             save_prefix='',
             save_format='png'):
    """Set up the iterator that feeds (x, y) through batchgenerators transforms.

    Args:
        x: input samples, validated/converted by ``_check_and_transform_x``;
            axis 0 is the sample axis.
        y: targets, validated/converted by ``_check_and_transform_y``.
        transform: a ``Compose`` (used as-is), a list or tuple of transforms,
            or a single transform. Non-``Compose`` values are wrapped in ``Compose``.
        batch_size, shuffle, seed: forwarded to the parent iterator.
        save_to_dir, save_prefix, save_format: stored on the instance.
    """
    self.x = self._check_and_transform_x(x)
    self.y = self._check_and_transform_y(y)
    self.save_to_dir = save_to_dir
    self.save_prefix = save_prefix
    self.save_format = save_format
    # Count samples on the validated array rather than the raw argument, so
    # inputs converted by _check_and_transform_x (e.g. non-ndarrays) work too.
    super().__init__(self.x.shape[0], batch_size, shuffle, seed)
    if isinstance(transform, Compose):
        self.transform = transform
    elif isinstance(transform, list):
        self.transform = Compose(transform)
    elif isinstance(transform, tuple):
        # A tuple of transforms is a sequence, not a single transform.
        self.transform = Compose(list(transform))
    else:
        self.transform = Compose([transform])
    self.data_loader = self._create_data_loader()
def wrap_transforms(self, dataloader_train, dataloader_val, train_transforms, val_transforms):
    """Wrap the train/val loaders in NonDetMultiThreadedAugmenter instances.

    Args:
        dataloader_train: loader for training batches.
        dataloader_val: loader for validation batches.
        train_transforms: list of transforms, composed for training.
        val_transforms: list of transforms, composed for validation.

    Returns:
        (tr_gen, val_gen). Training uses ``self.num_proc_DA`` workers; validation
        uses half as many, but at least one.
    """
    tr_gen = NonDetMultiThreadedAugmenter(dataloader_train, Compose(train_transforms),
                                          self.num_proc_DA, self.num_cached,
                                          seeds=None, pin_memory=self.pin_memory)
    # Plain `num_proc_DA // 2` is 0 when num_proc_DA == 1, which would request
    # zero worker processes for validation; clamp to at least one worker.
    val_gen = NonDetMultiThreadedAugmenter(dataloader_val, Compose(val_transforms),
                                           max(self.num_proc_DA // 2, 1), self.num_cached,
                                           seeds=None, pin_memory=self.pin_memory)
    return tr_gen, val_gen
def get_transforms(patch_shape=(256, 320), other_transforms=None, random_crop=False):
    """
    Initializes the transforms for training.

    Args:
        patch_shape: spatial shape of the output patch, passed to SpatialTransform.
        other_transforms: iterable of transforms appended after the spatial and
            mirror transforms (optional). Defaults to None.
        random_crop (boolean): whether or not you want to random crop or center crop.
            Currently, the Transformed3DGenerator only supports random cropping.
            Transformed2DGenerator supports both random_crop = True and False.

    Returns:
        Compose of [SpatialTransform, MirrorTransform(axes=(0, 1)), *other_transforms].
    """
    spatial_transform = SpatialTransform(patch_shape,
                                         do_elastic_deform=True, alpha=(0., 1500.), sigma=(30., 80.),
                                         do_rotation=True, angle_z=(0, 2 * np.pi),
                                         do_scale=True, scale=(0.75, 2.),
                                         border_mode_data="nearest", border_cval_data=0,
                                         order_data=1, random_crop=random_crop,
                                         p_el_per_sample=0.1, p_scale_per_sample=0.1,
                                         p_rot_per_sample=0.1)
    # Mirroring is requested along axes 0 and 1 only, also for 3D patch shapes.
    mirror_transform = MirrorTransform(axes=(0, 1))
    transforms_list = [spatial_transform, mirror_transform]
    if other_transforms is not None:
        # extend() accepts any iterable (list, tuple, generator), not just lists.
        transforms_list.extend(other_transforms)
    return Compose(transforms_list)
def get_train_transform(patch_size):
    """Return the composed training augmentation for ``patch_size``.

    Pipeline (in order): SpatialTransform_2 (elastic deformation, rotation of
    up to +/-5 degrees about each axis, scaling 0.75-1.25, constant border
    -2.34 for data / 0 for seg), mirroring on axes 0-2, multiplicative
    brightness, gamma on the inverted image, and Gaussian noise.
    """
    max_angle = 5 / 360. * 2 * np.pi
    rotation_range = (-max_angle, max_angle)
    center_dist = [size // 2 for size in patch_size]

    pipeline = [
        SpatialTransform_2(
            patch_size, center_dist,
            do_elastic_deform=True, deformation_scale=(0, 0.05),
            do_rotation=True,
            angle_x=rotation_range, angle_y=rotation_range, angle_z=rotation_range,
            do_scale=True, scale=(0.75, 1.25),
            border_mode_data='constant', border_cval_data=-2.34,
            border_mode_seg='constant', border_cval_seg=0),
        MirrorTransform(axes=(0, 1, 2)),
        BrightnessMultiplicativeTransform((0.7, 1.5), per_channel=True, p_per_sample=0.15),
        GammaTransform(gamma_range=(0.5, 2), invert_image=True, per_channel=True,
                       p_per_sample=0.15),
        GaussianNoiseTransform(noise_variance=(0, 0.15), p_per_sample=0.15),
    ]
    return Compose(pipeline)
def create_data_gen_pipeline(cf, patient_data, do_aug=True, **kwargs):
    """
    Create a multi-threaded train/val/test batch generation and augmentation pipeline.
    :param cf: config object; reads cf.da_kwargs, cf.patch_size, cf.dim, cf.roi_items,
        cf.class_specific_seg and cf.n_workers.
    :param patient_data: dictionary containing one dictionary per patient in the train/test subset.
    :param do_aug: (optional) whether to perform data augmentation (training) or not
        (validation/testing). Without augmentation, data is center-cropped to cf.patch_size[:cf.dim].
    :param kwargs: forwarded unchanged to BatchGenerator.
    :return: multithreaded_generator: MultiThreadedAugmenter with cf.n_workers processes,
        seeded with range(cf.n_workers).
    """
    # create instance of batch generator as first element in pipeline.
    data_gen = BatchGenerator(cf, patient_data, **kwargs)

    my_transforms = []
    if do_aug:
        if cf.da_kwargs["mirror"]:
            mirror_transform = Mirror(axes=cf.da_kwargs['mirror_axes'])
            my_transforms.append(mirror_transform)

        spatial_transform = SpatialTransform(patch_size=cf.patch_size[:cf.dim],
                                             patch_center_dist_from_border=cf.da_kwargs['rand_crop_dist'],
                                             do_elastic_deform=cf.da_kwargs['do_elastic_deform'],
                                             alpha=cf.da_kwargs['alpha'], sigma=cf.da_kwargs['sigma'],
                                             do_rotation=cf.da_kwargs['do_rotation'],
                                             angle_x=cf.da_kwargs['angle_x'],
                                             angle_y=cf.da_kwargs['angle_y'],
                                             angle_z=cf.da_kwargs['angle_z'],
                                             do_scale=cf.da_kwargs['do_scale'], scale=cf.da_kwargs['scale'],
                                             random_crop=cf.da_kwargs['random_crop'])
        my_transforms.append(spatial_transform)
    else:
        my_transforms.append(CenterCropTransform(crop_size=cf.patch_size[:cf.dim]))

    # Appended in both modes; presumably derives box targets from the segmentation (see the transform).
    my_transforms.append(ConvertSegToBoundingBoxCoordinates(cf.dim, cf.roi_items, False, cf.class_specific_seg))
    all_transforms = Compose(my_transforms)
    # multithreaded_generator = SingleThreadedAugmenter(data_gen, all_transforms)
    multithreaded_generator = MultiThreadedAugmenter(data_gen, all_transforms, num_processes=cf.n_workers,
                                                     seeds=range(cf.n_workers))
    return multithreaded_generator
def crop_transform_random():
    '''
    Crop a fixed-size patch from an arbitrary position of the image,
    then apply a spatial transform, and plot one augmented batch.
    :return:
    '''
    patch = (128, 128)
    loader = my_data_loader.DataLoader(camera(), 4)

    # randomly cut (128, 128) patches out of the image
    random_crop = RandomCropTransform(crop_size=patch)
    deform = SpatialTransform(patch, np.array(patch) // 2,
                              do_elastic_deform=True, alpha=(0., 1500.), sigma=(30., 50.),
                              do_rotation=True, angle_z=(0, 2 * np.pi),
                              do_scale=True, scale=(0.5, 2),
                              border_mode_data='constant', border_cval_data=0, order_data=1,
                              random_crop=False)

    augmenter = MultiThreadedAugmenter(loader, Compose([random_crop, deform]), 4, 2)
    my_data_loader.plot_batch(next(augmenter))
def test_image_pipeline_and_pin_memory(self):
    '''
    This just should not crash
    :return:
    '''
    try:
        import torch
    except ImportError:
        # skip the test when torch is not installed
        return

    composed = Compose([
        MirrorTransform(),
        TransposeAxesTransform(transpose_any_of_these=(0, 1), p_per_sample=0.5),
        NumpyToTensor(keys='data', cast_to='float'),
    ])
    loader = self.dl_images
    augmenter = MultiThreadedAugmenter(loader, composed, 4, 1, None, True)

    for _ in range(50):
        batch = augmenter.next()
        assert isinstance(batch['data'], torch.Tensor)
        assert batch['data'].is_pinned()

    # Give the augmenter time to finish caching; otherwise it prints a harmless
    # (but ugly) error that does not affect the test result.
    sleep(2)
def get_batches(self, batch_size=128, type=None, subjects=None, num_batches=None):
    """Wrap precomputed batches in a single-process MultiThreadedAugmenter.

    Args:
        batch_size: forwarded as BATCH_SIZE to SlicesBatchGeneratorPrecomputedBatches.
        type: the data handed to the batch generator (despite its name).
        subjects: only used when ``num_batches`` is None, to size exactly one epoch
            as ``len(subjects) * self.HP.INPUT_DIM[0]`` samples.
        num_batches: explicit number of batches; when given, ``subjects`` is not needed.

    Returns:
        MultiThreadedAugmenter applying no transforms.
    """
    data = type
    seg = []
    num_processes = 1  # 6 is a bit faster than 16

    if num_batches is None:
        # Number of batches for exactly one epoch. Computed only here, so that
        # callers passing num_batches are not forced to pass subjects as well.
        nr_of_samples = len(subjects) * self.HP.INPUT_DIM[0]
        num_batches_multithr = int(nr_of_samples / batch_size / num_processes)
    else:
        num_batches_multithr = int(num_batches / num_processes)

    batch_gen = SlicesBatchGeneratorPrecomputedBatches((data, seg), BATCH_SIZE=batch_size,
                                                       num_batches=num_batches_multithr, seed=None)
    batch_gen.HP = self.HP
    batch_gen = MultiThreadedAugmenter(batch_gen, Compose([]), num_processes=num_processes,
                                       num_cached_per_queue=1, seeds=None)
    return batch_gen
def get_batches(self, batch_size=1):
    """Build a single-process slice batch generator over ``self.data`` with a dummy mask.

    NaNs in ``self.data`` are replaced via np.nan_to_num. Optional zero-mean/unit-variance
    normalization (``HP.NORMALIZE_DATA``) is applied, followed by ReorderSegTransform.
    Returned batches: data (batch_size, channels, x, y), seg (batch_size, x, y, channels).
    """
    hp = self.HP
    data = np.nan_to_num(self.data)

    # All-zero mask so prediction also works for data without ground truth.
    edge = hp.INPUT_DIM[0]
    seg = np.zeros((edge, edge, edge, hp.NR_OF_CLASSES)).astype(hp.LABELS_TYPE)

    # Keep a single worker: worker processes return batches in random order,
    # and SlicesBatchGenerator's global_idx only works with num_processes=1.
    num_processes = 1

    slice_gen = SlicesBatchGenerator((data, seg), BATCH_SIZE=batch_size)
    slice_gen.HP = hp

    pipeline = []
    if hp.NORMALIZE_DATA:
        pipeline.append(ZeroMeanUnitVarianceTransform(per_channel=hp.NORMALIZE_PER_CHANNEL))
    pipeline.append(ReorderSegTransform())

    return MultiThreadedAugmenter(slice_gen, Compose(pipeline), num_processes=num_processes,
                                  num_cached_per_queue=2, seeds=None)
def spatial_transforms():
    '''
    Spatial transforms: elastic deformation, scaling and rotation.
    Plots one augmented batch of the camera image.
    :return:
    '''
    image_shape = camera().shape

    # deformation + rotation + scaling
    deform_rotate_scale = SpatialTransform(image_shape, np.array(image_shape) // 2,
                                           do_elastic_deform=True, alpha=(0., 1500.), sigma=(30., 50.),
                                           do_rotation=True, angle_z=(0, 2 * np.pi),
                                           do_scale=True, scale=(0.3, 3.),
                                           border_mode_data='constant', border_cval_data=0, order_data=1,
                                           random_crop=False)
    pipeline = Compose([deform_rotate_scale])

    loader = my_data_loader.DataLoader(camera(), 4)
    augmenter = MultiThreadedAugmenter(loader, pipeline, 4, 2)

    # show the effect of the transform
    my_data_loader.plot_batch(next(augmenter))
def _augment_data(Config, batch_generator, type=None):
    """Wrap ``batch_generator`` in a single-process augmenter with no transforms.

    ``Config`` and ``type`` are accepted for interface compatibility; they are not read here.
    """
    identity = Compose([])
    return MultiThreadedAugmenter(batch_generator, identity,
                                  num_processes=1, num_cached_per_queue=1, seeds=None)
def random_transform():
    '''
    Randomly apply the spatial transform to some batches (probability 0.5 via RndTransform),
    and plot four batches.
    :return:
    '''
    image_shape = camera().shape
    deform = SpatialTransform(image_shape, np.array(image_shape) // 2,
                              do_elastic_deform=True, alpha=(0., 1500.), sigma=(30., 50.),
                              do_rotation=True, angle_z=(0, 2 * np.pi),
                              do_scale=True, scale=(0.3, 3.),
                              border_mode_data='constant', border_cval_data=0, order_data=1,
                              random_crop=False)
    maybe_deform = RndTransform(deform, prob=0.5)

    loader = my_data_loader.DataLoader(camera(), 4)
    augmenter = MultiThreadedAugmenter(loader, Compose([maybe_deform]), 4, 2)

    for _ in range(4):
        my_data_loader.plot_batch(next(augmenter))
def get_transformer(bbox_image_shape=(256, 256, 256), deformation_scale=0.2):
    """
    Build a transform that applies only a (constrained) elastic deformation.

    :param bbox_image_shape: patch size passed to SpatialTransform_2, e.g. (256, 256, 256);
        patch_center_dist_from_border is half of each entry and random_crop is False.
    :param deformation_scale: lower bound of the deformation-scale range; the range used is
        (deformation_scale, deformation_scale + 0.1). Per the original author, 0 gives almost
        no deformation and 0.2 a very strong one, so values in 0~0.25 are reasonable.
    :return: Compose wrapping a single SpatialTransform_2.
    """
    # (Defaults to a tuple instead of a list to avoid a shared mutable default argument.)
    tr_transforms = []
    # Per the original author, this is where SpatialTransform_2 differs from SpatialTransform:
    # it offers a bounded deformation, so the image is not distorted excessively.
    # Rotation and scaling are switched off (do_rotation/do_scale=False); their ranges are kept
    # so they can be re-enabled easily.
    tr_transforms.append(
        SpatialTransform_2(
            patch_size=bbox_image_shape,
            patch_center_dist_from_border=[i // 2 for i in bbox_image_shape],
            do_elastic_deform=True,
            deformation_scale=(deformation_scale, deformation_scale + 0.1),
            do_rotation=False,
            angle_x=(- 15 / 360. * 2 * np.pi, 15 / 360. * 2 * np.pi),  # random rotation angle range
            angle_y=(- 15 / 360. * 2 * np.pi, 15 / 360. * 2 * np.pi),
            angle_z=(- 15 / 360. * 2 * np.pi, 15 / 360. * 2 * np.pi),
            do_scale=False, scale=(0.75, 1.25),
            border_mode_data='constant', border_cval_data=0,
            border_mode_seg='constant', border_cval_seg=0,
            order_seg=1, order_data=3,
            random_crop=False,
            p_el_per_sample=1.0, p_rot_per_sample=1.0, p_scale_per_sample=1.0
        ))
    # Note (original author) for the alternative SpatialTransform with alpha/sigma parameters:
    # the smaller sigma, the more local (i.e. more severe) the distortion; the larger alpha,
    # the more severe the distortion.
    all_transforms = Compose(tr_transforms)
    return all_transforms
def crop_transform_center_seg():
    '''
    Crop (data and segmentation) from the exact center of the image and plot one batch.
    :return:
    '''
    output_size = (128, 128)
    loader = my_data_loader.DataLoader(camera(), 4)
    center_crop = CenterCropSegTransform(output_size=output_size)
    augmenter = MultiThreadedAugmenter(loader, Compose([center_crop]), 4, 2)
    my_data_loader.plot_batch(next(augmenter))
def _augment_data(self, batch_generator, type=None):
    """Wrap ``batch_generator`` in a MultiThreadedAugmenter configured by ``self.Config``.

    Uses 16 worker processes when DATA_AUGMENTATION is on, else 6. Spatial, resampling,
    noise, mirror and peak-flip augmentations are only added for ``type == "train"``.
    Batches are converted to float tensors in pinned memory.
    Returned batches: data (batch_size, channels, x, y), seg (batch_size, channels, x, y).
    """
    cfg = self.Config
    # 2D: 8 is a bit faster than 16
    num_processes = 16 if cfg.DATA_AUGMENTATION else 6

    tfs = []  # transforms
    if cfg.NORMALIZE_DATA:
        tfs.append(ZeroMeanUnitVarianceTransform(per_channel=cfg.NORMALIZE_PER_CHANNEL))

    if cfg.DATA_AUGMENTATION and type == "train":
        # scale: inverted: 0.5 -> bigger; 2 -> smaller
        # patch_center_dist_from_border: if 144/2=72 -> always exactly centered;
        # otherwise a bit off center (brain can get off image and will be cut then)
        if cfg.DAUG_SCALE:
            center_dist_from_border = int(cfg.INPUT_DIM[0] / 2.) - 10  # (144,144) -> 62
            tfs.append(SpatialTransform(cfg.INPUT_DIM,
                                        patch_center_dist_from_border=center_dist_from_border,
                                        do_elastic_deform=cfg.DAUG_ELASTIC_DEFORM,
                                        alpha=(90., 120.), sigma=(9., 11.),
                                        do_rotation=cfg.DAUG_ROTATE,
                                        angle_x=(-0.8, 0.8), angle_y=(-0.8, 0.8), angle_z=(-0.8, 0.8),
                                        do_scale=True, scale=(0.9, 1.5),
                                        border_mode_data='constant', border_cval_data=0, order_data=3,
                                        border_mode_seg='constant', border_cval_seg=0, order_seg=0,
                                        random_crop=True,
                                        p_el_per_sample=0.2, p_rot_per_sample=0.2, p_scale_per_sample=0.2))
        if cfg.DAUG_RESAMPLE:
            tfs.append(SimulateLowResolutionTransform(zoom_range=(0.5, 1), p_per_sample=0.2))
        if cfg.DAUG_NOISE:
            tfs.append(GaussianNoiseTransform(noise_variance=(0, 0.05), p_per_sample=0.2))
        if cfg.DAUG_MIRROR:
            tfs.append(MirrorTransform())
        if cfg.DAUG_FLIP_PEAKS:
            tfs.append(FlipVectorAxisTransform())

    tfs.append(NumpyToTensor(keys=["data", "seg"], cast_to="float"))

    # num_cached_per_queue 1 or 2 does not really make a difference
    return MultiThreadedAugmenter(batch_generator, Compose(tfs), num_processes=num_processes,
                                  num_cached_per_queue=1, seeds=None, pin_memory=True)
def create_data_gen_pipeline(cf, cities=None, data_split='train', do_aug=True, random=True, n_batches=None):
    """
    Create a multi-threaded train/val/test batch generation and augmentation pipeline.
    :param cf: config object; reads cf.batch_size, cf.data_dir, cf.label_density, cf.resolution,
        cf.gt_instances, cf.patch_size, cf.da_kwargs, cf.ignore_label, cf.label_switches,
        cf.name2trainId and cf.n_workers.
    :param cities: list of strings or None, forwarded to BatchGenerator.
    :param data_split: (optional) data split name forwarded to BatchGenerator, default 'train'.
    :param do_aug: (optional) whether to perform spatial data augmentation (mirror + SpatialTransform,
        training) or only center-crop (validation/testing).
    :param random: bool, whether to draw random batches or go through data linearly.
    :param n_batches: (optional) number of batches, forwarded to BatchGenerator.
    :return: multithreaded_generator: MultiThreadedAugmenter with cf.n_workers processes,
        seeded with range(cf.n_workers).
    """
    data_gen = BatchGenerator(cities=cities, batch_size=cf.batch_size, data_dir=cf.data_dir,
                              label_density=cf.label_density, data_split=data_split, resolution=cf.resolution,
                              gt_instances=cf.gt_instances, n_batches=n_batches, random=random)
    my_transforms = []
    if do_aug:
        # Mirroring along axis 3 only (presumably the width axis of (b, c, y, x) batches; verify).
        mirror_transform = MirrorTransform(axes=(3,))
        my_transforms.append(mirror_transform)
        spatial_transform = SpatialTransform(patch_size=cf.patch_size[-2:],
                                             patch_center_dist_from_border=cf.da_kwargs['rand_crop_dist'],
                                             do_elastic_deform=cf.da_kwargs['do_elastic_deform'],
                                             alpha=cf.da_kwargs['alpha'], sigma=cf.da_kwargs['sigma'],
                                             do_rotation=cf.da_kwargs['do_rotation'],
                                             angle_x=cf.da_kwargs['angle_x'],
                                             angle_y=cf.da_kwargs['angle_y'],
                                             angle_z=cf.da_kwargs['angle_z'],
                                             do_scale=cf.da_kwargs['do_scale'], scale=cf.da_kwargs['scale'],
                                             random_crop=cf.da_kwargs['random_crop'],
                                             border_mode_data=cf.da_kwargs['border_mode_data'],
                                             border_mode_seg=cf.da_kwargs['border_mode_seg'],
                                             border_cval_seg=cf.da_kwargs['border_cval_seg'])
        my_transforms.append(spatial_transform)
    else:
        my_transforms.append(CenterCropTransform(crop_size=cf.patch_size[-2:]))

    # The following transforms are added regardless of do_aug.
    my_transforms.append(GammaTransform(cf.da_kwargs['gamma_range'], invert_image=False, per_channel=True,
                                        retain_stats=cf.da_kwargs['gamma_retain_stats'],
                                        p_per_sample=cf.da_kwargs['p_gamma']))
    my_transforms.append(AddLossMask(cf.ignore_label))
    if cf.label_switches is not None:
        my_transforms.append(StochasticLabelSwitches(cf.name2trainId, cf.label_switches))

    all_transforms = Compose(my_transforms)
    multithreaded_generator = MultiThreadedAugmenter(data_gen, all_transforms, num_processes=cf.n_workers,
                                                     seeds=range(cf.n_workers))
    return multithreaded_generator
def get_batches(self, batch_size=128, type=None, subjects=None, num_batches=None):
    """Wrap precomputed batches (passed in via ``type``) in a single-process augmenter.

    ``subjects`` and ``num_batches`` are accepted for interface compatibility but not read.
    No transforms are applied.
    """
    num_processes = 1  # 6 is a bit faster than 16
    precomputed = SlicesBatchGeneratorPrecomputedBatches((type, []), batch_size=batch_size)
    precomputed.HP = self.HP
    return MultiThreadedAugmenter(precomputed, Compose([]), num_processes=num_processes,
                                  num_cached_per_queue=1, seeds=None)
def create_data_gen_pipeline(patient_data, cf, test_pids=None, do_aug=True):
    """
    Create a multi-threaded train/val/test batch generation and augmentation pipeline.
    :param patient_data: dictionary containing one dictionary per patient in the train/test subset
    :param cf: config object; reads cf.batch_size, cf.pre_crop_size, cf.dim, cf.patch_size,
        cf.da_kwargs and (for train/val) cf.n_workers.
    :param test_pids: (optional) list of test patient ids, calls the test generator.
        The test generator always runs with a single worker.
    :param do_aug: (optional) whether to perform data augmentation (training) or not (validation/testing)
    :return: multithreaded_generator
    """
    if test_pids is None:
        data_gen = BatchGenerator(patient_data, batch_size=cf.batch_size,
                                  pre_crop_size=cf.pre_crop_size, dim=cf.dim)
        n_workers = cf.n_workers
    else:
        data_gen = TestGenerator(patient_data, batch_size=cf.batch_size, n_batches=None,
                                 pre_crop_size=cf.pre_crop_size, test_pids=test_pids, dim=cf.dim)
        # Use a single worker for the test generator. Kept local instead of assigning
        # cf.n_workers = 1, which would silently change every pipeline built from cf afterwards.
        n_workers = 1

    my_transforms = []
    if do_aug:
        mirror_transform = Mirror(axes=(2, 3))
        my_transforms.append(mirror_transform)
        spatial_transform = SpatialTransform(patch_size=cf.patch_size[:cf.dim],
                                             patch_center_dist_from_border=cf.da_kwargs['rand_crop_dist'],
                                             do_elastic_deform=cf.da_kwargs['do_elastic_deform'],
                                             alpha=cf.da_kwargs['alpha'], sigma=cf.da_kwargs['sigma'],
                                             do_rotation=cf.da_kwargs['do_rotation'],
                                             angle_x=cf.da_kwargs['angle_x'],
                                             angle_y=cf.da_kwargs['angle_y'],
                                             angle_z=cf.da_kwargs['angle_z'],
                                             do_scale=cf.da_kwargs['do_scale'], scale=cf.da_kwargs['scale'],
                                             random_crop=cf.da_kwargs['random_crop'])
        my_transforms.append(spatial_transform)
    else:
        my_transforms.append(CenterCropTransform(crop_size=cf.patch_size[:cf.dim]))

    # Added in all modes.
    my_transforms.append(ConvertSegToOnehotTransform(classes=(0, 1, 2)))
    my_transforms.append(TransposeChannels())
    all_transforms = Compose(my_transforms)
    multithreaded_generator = MultiThreadedAugmenter(data_gen, all_transforms, num_processes=n_workers,
                                                     seeds=range(n_workers))
    return multithreaded_generator
def combine_transform():
    '''
    Combined transform: contrast + mirroring. Plots one augmented batch.
    :return:
    '''
    # contrast transform
    contrast_transform = ContrastAugmentationTransform((0.3, 3.), preserve_range=True)
    # mirror transform
    mirror_transform = MirrorTransform(axes=(2, 3))
    pipeline = Compose([contrast_transform, mirror_transform])

    loader = my_data_loader.DataLoader(camera(), 4)
    augmenter = MultiThreadedAugmenter(loader, pipeline, 4, 2)

    # show the effect of the transforms
    my_data_loader.plot_batch(next(augmenter))
def test_image_pipeline(self):
    '''
    This just should not crash
    :return:
    '''
    composed = Compose([
        MirrorTransform(),
        TransposeAxesTransform(transpose_any_of_these=(0, 1), p_per_sample=0.5),
    ])
    loader = self.dl_images
    augmenter = MultiThreadedAugmenter(loader, composed, 4, 1, None, False)

    for _ in range(50):
        augmenter.next()

    # Give the augmenter time to finish caching; otherwise it prints a harmless
    # (but ugly) error that does not affect the test result.
    sleep(2)
def get_train_transform(patch_size):
    # we now create a list of transforms. These are not necessarily the best transforms to use for BraTS, this is just
    # to showcase some things
    tr_transforms = []

    # the first thing we want to run is the SpatialTransform. It reduces the size of our data to patch_size and thus
    # also reduces the computational cost of all subsequent operations. All subsequent operations do not modify the
    # shape and do not transform spatially, so no border artifacts will be introduced
    # Here we use the new SpatialTransform_2 which uses a new way of parameterizing elastic_deform
    # We use each of the three spatial transformations (elastic deformation, rotation, scaling) with a probability
    # of 0.1 per sample. This means that 1 - (1 - 0.1) ** 3 = 27% of samples will be augmented, the rest will just
    # be cropped
    tr_transforms.append(
        SpatialTransform_2(
            patch_size, [i // 2 for i in patch_size],
            do_elastic_deform=True, deformation_scale=(0, 0.25),
            do_rotation=True,
            angle_x=(-15 / 360. * 2 * np.pi, 15 / 360. * 2 * np.pi),
            angle_y=(-15 / 360. * 2 * np.pi, 15 / 360. * 2 * np.pi),
            angle_z=(-15 / 360. * 2 * np.pi, 15 / 360. * 2 * np.pi),
            do_scale=True, scale=(0.75, 1.25),
            border_mode_data='constant', border_cval_data=0,
            border_mode_seg='constant', border_cval_seg=0,
            order_seg=1, order_data=3,
            random_crop=True,
            p_el_per_sample=0.1, p_rot_per_sample=0.1, p_scale_per_sample=0.1))

    # now we mirror along all axes
    tr_transforms.append(MirrorTransform(axes=(0, 1, 2)))

    # brightness transform for 15% of samples
    tr_transforms.append(BrightnessMultiplicativeTransform((0.7, 1.5), per_channel=True, p_per_sample=0.15))

    # gamma transform. This is a nonlinear transformation of intensity values
    # (https://en.wikipedia.org/wiki/Gamma_correction)
    tr_transforms.append(GammaTransform(gamma_range=(0.5, 2), invert_image=False, per_channel=True,
                                        p_per_sample=0.15))
    # we can also invert the image, apply the transform and then invert back
    tr_transforms.append(GammaTransform(gamma_range=(0.5, 2), invert_image=True, per_channel=True,
                                        p_per_sample=0.15))

    # Gaussian Noise
    tr_transforms.append(GaussianNoiseTransform(noise_variance=(0, 0.05), p_per_sample=0.15))

    # blurring. Some BraTS cases have very blurry modalities. This can simulate more patients with this problem and
    # thus make the model more robust to it
    tr_transforms.append(GaussianBlurTransform(blur_sigma=(0.5, 1.5), different_sigma_per_channel=True,
                                               p_per_channel=0.5, p_per_sample=0.15))

    # now we compose these transforms together
    tr_transforms = Compose(tr_transforms)
    return tr_transforms
def _augment_data(self, batch_generator, type=None):
    """Wrap ``batch_generator`` in a MultiThreadedAugmenter configured by ``self.Config``.

    Uses 15 worker processes when DATA_AUGMENTATION is on, else 6. The spatial transform
    class is chosen by Config.SPATIAL_TRANSFORM. Augmentations are only added for
    ``type == "train"``. Batches are converted to float tensors in pinned memory.
    Returned batches: data (batch_size, channels, x, y), seg (batch_size, channels, x, y).
    """
    cfg = self.Config
    # 15 is a bit faster than 8 on cluster.
    # (multiprocessing.cpu_count() on the cluster gives all cores, not only the assigned ones.)
    num_processes = 15 if cfg.DATA_AUGMENTATION else 6

    tfs = []
    if cfg.NORMALIZE_DATA:
        tfs.append(ZeroMeanUnitVarianceTransform(per_channel=cfg.NORMALIZE_PER_CHANNEL))

    if cfg.SPATIAL_TRANSFORM == "SpatialTransformPeaks":
        spatial_cls = SpatialTransformPeaks
    elif cfg.SPATIAL_TRANSFORM == "SpatialTransformCustom":
        spatial_cls = SpatialTransformCustom
    else:
        spatial_cls = SpatialTransform

    if cfg.DATA_AUGMENTATION and type == "train":
        # patch_center_dist_from_border:
        # if 144/2=72 -> always exactly centered; otherwise a bit off center
        # (brain can get off image and will be cut then)
        if cfg.DAUG_SCALE:
            if cfg.INPUT_RESCALING:
                source_mm = 2  # for bb
                target_mm = float(cfg.RESOLUTION[:-2])
                factor = target_mm / source_mm
                scale = (factor, factor)
            else:
                scale = (0.9, 1.5)

            # None keeps the dimensions of the data;
            # the spatial transform automatically crops/pads to the correct size
            patch_size = cfg.INPUT_DIM if cfg.PAD_TO_SQUARE else None

            center_dist_from_border = int(cfg.INPUT_DIM[0] / 2.) - 10  # (144,144) -> 62
            tfs.append(spatial_cls(patch_size,
                                   patch_center_dist_from_border=center_dist_from_border,
                                   do_elastic_deform=cfg.DAUG_ELASTIC_DEFORM,
                                   alpha=cfg.DAUG_ALPHA, sigma=cfg.DAUG_SIGMA,
                                   do_rotation=cfg.DAUG_ROTATE,
                                   angle_x=cfg.DAUG_ROTATE_ANGLE,
                                   angle_y=cfg.DAUG_ROTATE_ANGLE,
                                   angle_z=cfg.DAUG_ROTATE_ANGLE,
                                   do_scale=True, scale=scale,
                                   border_mode_data='constant', border_cval_data=0, order_data=3,
                                   border_mode_seg='constant', border_cval_seg=0, order_seg=0,
                                   random_crop=True,
                                   p_el_per_sample=cfg.P_SAMP,
                                   p_rot_per_sample=cfg.P_SAMP,
                                   p_scale_per_sample=cfg.P_SAMP))

        if cfg.DAUG_RESAMPLE:
            tfs.append(SimulateLowResolutionTransform(zoom_range=(0.5, 1), p_per_sample=0.2, per_channel=False))
        if cfg.DAUG_RESAMPLE_LEGACY:
            tfs.append(ResampleTransformLegacy(zoom_range=(0.5, 1)))
        if cfg.DAUG_GAUSSIAN_BLUR:
            tfs.append(GaussianBlurTransform(blur_sigma=cfg.DAUG_BLUR_SIGMA,
                                             different_sigma_per_channel=False,
                                             p_per_sample=cfg.P_SAMP))
        if cfg.DAUG_NOISE:
            tfs.append(GaussianNoiseTransform(noise_variance=cfg.DAUG_NOISE_VARIANCE,
                                              p_per_sample=cfg.P_SAMP))
        if cfg.DAUG_MIRROR:
            tfs.append(MirrorTransform())
        if cfg.DAUG_FLIP_PEAKS:
            tfs.append(FlipVectorAxisTransform())

    tfs.append(NumpyToTensor(keys=["data", "seg"], cast_to="float"))

    # num_cached_per_queue 1 or 2 does not really make a difference
    return MultiThreadedAugmenter(batch_generator, Compose(tfs), num_processes=num_processes,
                                  num_cached_per_queue=1, seeds=None, pin_memory=True)
def get_default_augmentation(dataloader_train, dataloader_val, patch_size, params=default_3D_augmentation_params,
                             border_val_seg=-1, pin_memory=True,
                             seeds_train=None, seeds_val=None, regions=None):
    """Build the training and validation augmentation pipelines.

    Args:
        dataloader_train: loader wrapped by the training MultiThreadedAugmenter.
        dataloader_val: loader wrapped by the validation MultiThreadedAugmenter.
        patch_size: output size of the SpatialTransform. When ``params["dummy_2D"]`` is
            truthy, the first entry is dropped (the transform runs on ``patch_size[1:]``).
        params: dict of augmentation settings, read with ``params.get``; ``params["p_gamma"]``
            is indexed directly when ``do_gamma`` is truthy.
        border_val_seg: constant used for the segmentation border in the SpatialTransform.
        pin_memory: forwarded to both augmenters.
        seeds_train / seeds_val: seeds for the train / val augmenters.
        regions: if not None, 'target' is converted with ConvertSegmentationToRegionsTransform.

    Returns:
        (batchgenerator_train, batchgenerator_val). Both rename 'seg' to 'target' and
        convert 'data'/'target' with NumpyToTensor(..., 'float'). Validation uses
        max(num_threads // 2, 1) processes.

    Raises:
        AssertionError: if ``params`` still uses the obsolete 'mirror' key.
    """
    assert params.get('mirror') is None, "old version of params, use new keyword do_mirror"

    tr_transforms = []

    if params.get("selected_data_channels") is not None:
        tr_transforms.append(DataChannelSelectionTransform(params.get("selected_data_channels")))
    if params.get("selected_seg_channels") is not None:
        tr_transforms.append(SegChannelSelectionTransform(params.get("selected_seg_channels")))

    # don't do color augmentations while in 2d mode with 3d data because the color channel is overloaded!!
    if params.get("dummy_2D"):
        tr_transforms.append(Convert3DTo2DTransform())
        patch_size_spatial = patch_size[1:]
    else:
        patch_size_spatial = patch_size

    # Set order_data=0 and order_seg=0 for some more speed for cascade???
    tr_transforms.append(SpatialTransform(
        patch_size_spatial, patch_center_dist_from_border=None,
        do_elastic_deform=params.get("do_elastic"),
        alpha=params.get("elastic_deform_alpha"), sigma=params.get("elastic_deform_sigma"),
        do_rotation=params.get("do_rotation"),
        angle_x=params.get("rotation_x"), angle_y=params.get("rotation_y"), angle_z=params.get("rotation_z"),
        do_scale=params.get("do_scaling"), scale=params.get("scale_range"),
        border_mode_data=params.get("border_mode_data"), border_cval_data=0, order_data=3,
        border_mode_seg="constant", border_cval_seg=border_val_seg, order_seg=1,
        random_crop=params.get("random_crop"),
        p_el_per_sample=params.get("p_eldef"), p_scale_per_sample=params.get("p_scale"),
        p_rot_per_sample=params.get("p_rot"),
        independent_scale_for_each_axis=params.get("independent_scale_factor_for_each_axis")))

    if params.get("dummy_2D"):
        tr_transforms.append(Convert2DTo3DTransform())

    if params.get("do_gamma"):
        tr_transforms.append(GammaTransform(params.get("gamma_range"), False, True,
                                            retain_stats=params.get("gamma_retain_stats"),
                                            p_per_sample=params["p_gamma"]))

    if params.get("do_mirror"):
        tr_transforms.append(MirrorTransform(params.get("mirror_axes")))

    if params.get("mask_was_used_for_normalization") is not None:
        mask_was_used_for_normalization = params.get("mask_was_used_for_normalization")
        tr_transforms.append(MaskTransform(mask_was_used_for_normalization, mask_idx_in_seg=0, set_outside_to=0))

    tr_transforms.append(RemoveLabelTransform(-1, 0))

    if params.get("move_last_seg_chanel_to_data"):
        tr_transforms.append(MoveSegAsOneHotToData(1, params.get("all_segmentation_labels"), 'seg', 'data'))
        # The cascade augmentations operate on the one-hot channels just appended to 'data'.
        if params.get("cascade_do_cascade_augmentations"):
            # Remove the following transforms to remove cascade DA ??
            tr_transforms.append(ApplyRandomBinaryOperatorTransform(
                channel_idx=list(range(-len(params.get("all_segmentation_labels")), 0)),
                p_per_sample=params.get("cascade_random_binary_transform_p"),
                key="data",
                strel_size=params.get("cascade_random_binary_transform_size")))
            # Each keyword now receives the config entry matching its name (these two were swapped).
            tr_transforms.append(RemoveRandomConnectedComponentFromOneHotEncodingTransform(
                channel_idx=list(range(-len(params.get("all_segmentation_labels")), 0)),
                key="data",
                p_per_sample=params.get("cascade_remove_conn_comp_p"),
                fill_with_other_class_p=params.get("cascade_remove_conn_comp_fill_with_other_class_p"),
                dont_do_if_covers_more_than_X_percent=params.get(
                    "cascade_remove_conn_comp_max_size_percent_threshold")))

    tr_transforms.append(RenameTransform('seg', 'target', True))

    if regions is not None:
        tr_transforms.append(ConvertSegmentationToRegionsTransform(regions, 'target', 'target'))

    tr_transforms.append(NumpyToTensor(['data', 'target'], 'float'))
    tr_transforms = Compose(tr_transforms)

    batchgenerator_train = MultiThreadedAugmenter(dataloader_train, tr_transforms, params.get('num_threads'),
                                                  params.get("num_cached_per_thread"),
                                                  seeds=seeds_train, pin_memory=pin_memory)

    val_transforms = []
    val_transforms.append(RemoveLabelTransform(-1, 0))
    if params.get("selected_data_channels") is not None:
        val_transforms.append(DataChannelSelectionTransform(params.get("selected_data_channels")))
    if params.get("selected_seg_channels") is not None:
        val_transforms.append(SegChannelSelectionTransform(params.get("selected_seg_channels")))

    if params.get("move_last_seg_chanel_to_data"):
        val_transforms.append(MoveSegAsOneHotToData(1, params.get("all_segmentation_labels"), 'seg', 'data'))

    val_transforms.append(RenameTransform('seg', 'target', True))

    if regions is not None:
        val_transforms.append(ConvertSegmentationToRegionsTransform(regions, 'target', 'target'))

    val_transforms.append(NumpyToTensor(['data', 'target'], 'float'))
    val_transforms = Compose(val_transforms)

    # Validation uses half the training workers, but never fewer than one.
    batchgenerator_val = MultiThreadedAugmenter(dataloader_val, val_transforms,
                                                max(params.get('num_threads') // 2, 1),
                                                params.get("num_cached_per_thread"),
                                                seeds=seeds_val, pin_memory=pin_memory)
    return batchgenerator_train, batchgenerator_val
# border_mode_seg='constant', border_cval_seg=0, # order_seg=1, order_data=3, # random_crop=True, # p_el_per_sample=1.0, p_rot_per_sample=1.0, p_scale_per_sample=1.0 # )) # sigma越小,扭曲越局部(即扭曲的越严重), alpha越大扭曲的越严重 # tr_transforms.append( # SpatialTransform(bbox_image.shape, [i // 2 for i in bbox_image.shape], # do_elastic_deform=True, alpha=(1300., 1500.), sigma=(10., 11.), # do_rotation=False, angle_z=(0, 2 * np.pi), # do_scale=False, scale=(0.3, 0.7), # border_mode_data='constant', border_cval_data=0, order_data=1, # border_mode_seg='constant', border_cval_seg=0, # random_crop=False)) all_transforms = Compose(tr_transforms) bbox_image_batch_trans = all_transforms(**bbox_image_batch) # 加入**相当于 # bbox_image_batch_trans1 = all_transforms(**bbox_image_batch) # show3D(np.concatenate([np.squeeze(bbox_image_batch['data'][0],axis=0),np.squeeze(bbox_image_batch_trans['data'][1],axis=0)],axis=1)) # show3Dslice(np.concatenate([np.squeeze(bbox_image_batch['data'][0],axis=0),np.squeeze(bbox_image_batch_trans['data'][1],axis=0)],axis=1)) # show3Dslice(np.concatenate([np.squeeze(bbox_image_batch['seg'][0],axis=0),np.squeeze(bbox_image_batch_trans['seg'][1],axis=0)],axis=1)) show3D( np.concatenate([ np.squeeze(bbox_image_batch['seg'][0], axis=0), np.squeeze(bbox_image_batch_trans['seg'][0], axis=0) ], axis=1)) show3D( np.concatenate([
numpy_to_tensor = NumpyToTensor(['data', 'labels'], cast_to=None) fname = os.path.join(dataset_dir, 'cifar10_training_data.npz') dataset = np.load(fname) cifar_dataset_as_arrays = (dataset['data'], dataset['labels'], dataset['filenames']) print('batch_size', batch_size) print('num_workers', num_workers) print('pin_memory', pin_memory) print('num_epochs', num_epochs) tr_transforms = [ SpatialTransform((32, 32)) ] * 1 # SpatialTransform is computationally expensive and we need some # load on CPU so we just stack 5 of them on top of each other tr_transforms.append(numpy_to_tensor) tr_transforms = Compose(tr_transforms) cifar_dataset = CifarDataset(dataset_dir, train=True, transform=tr_transforms) dl = DataLoaderFromDataset(cifar_dataset, batch_size, num_workers, 1) mt = MultiThreadedAugmenter(dl, None, num_workers, 1, None, pin_memory) batches = 0 for _ in mt: batches += 1 assert len(_['data'].shape) == 4 assert batches == len( cifar_dataset
def get_batches(self, batch_size=1):
    """Build a batch generator over the slices of the current subject for inference.

    Args:
        batch_size: Number of slices per batch.

    Returns:
        A ``MultiThreadedAugmenter`` running a single process, so batches keep
        the original slice order. Per the original author's note, batches hold
        ``data`` as (batch_size, channels, x, y) and ``seg`` as
        (batch_size, x, y, channels) — TODO confirm against ReorderSegTransform.
    """
    # Do not use more than 1 process if you want to keep the original slice
    # order (worker processes return batches in random order).
    num_processes = 1

    if self.HP.TYPE == "combined":
        # Load from .npy file for fusion data
        data = self.subject
        seg = []
        nr_of_samples = len([self.subject]) * self.HP.INPUT_DIM[0]
        num_batches = int(nr_of_samples / batch_size / num_processes)
        batch_gen = SlicesBatchGeneratorNpyImg_fusion(
            (data, seg), BATCH_SIZE=batch_size, num_batches=num_batches,
            seed=None)
    else:
        # Load features
        if self.HP.FEATURES_FILENAME == "12g90g270g":
            data_img = nib.load(
                join(self.data_dir, "270g_125mm_peaks.nii.gz"))
        else:
            data_img = nib.load(
                join(self.data_dir, self.HP.FEATURES_FILENAME + ".nii.gz"))
        # Image.get_data() was deprecated in nibabel 3.0 and removed in 5.0.
        # np.asanyarray(img.dataobj) is the documented equivalent (it keeps
        # the on-disk dtype, unlike get_fdata()).
        data = np.asanyarray(data_img.dataobj)
        data = np.nan_to_num(data)
        data = DatasetUtils.scale_input_to_unet_shape(
            data, self.HP.DATASET, self.HP.RESOLUTION)
        # To test HCP_32g on the high-res net instead:
        # data = DatasetUtils.scale_input_to_unet_shape(data, "HCP_32g", "1.25mm")

        # Load segmentation
        if self.use_gt_mask:
            seg_img = nib.load(
                join(self.data_dir, self.HP.LABELS_FILENAME + ".nii.gz"))
            seg = np.asanyarray(seg_img.dataobj)
            already_unet_shaped_labels = [
                "bundle_peaks_11_808080", "bundle_peaks_20_808080",
                "bundle_peaks_808080", "bundle_masks_20_808080",
                "bundle_masks_72_808080", "bundle_peaks_Part1_808080",
                "bundle_peaks_Part2_808080", "bundle_peaks_Part3_808080",
                "bundle_peaks_Part4_808080"
            ]
            if self.HP.LABELS_FILENAME not in already_unet_shaped_labels:
                if self.HP.DATASET in ["HCP_2mm", "HCP_2.5mm", "HCP_32g"]:
                    # By using "HCP" with a lower resolution,
                    # scale_input_to_unet_shape downsamples the HCP-sized seg mask
                    seg = DatasetUtils.scale_input_to_unet_shape(
                        seg, "HCP", self.HP.RESOLUTION)
                else:
                    seg = DatasetUtils.scale_input_to_unet_shape(
                        seg, self.HP.DATASET, self.HP.RESOLUTION)
        else:
            # Dummy mask for when we only want to predict (no ground truth)
            seg = np.zeros(
                (self.HP.INPUT_DIM[0], self.HP.INPUT_DIM[0],
                 self.HP.INPUT_DIM[0],
                 self.HP.NR_OF_CLASSES)).astype(self.HP.LABELS_TYPE)
        batch_gen = SlicesBatchGenerator((data, seg), batch_size=batch_size)

    batch_gen.HP = self.HP

    tfs = []  # transforms
    if self.HP.NORMALIZE_DATA:
        tfs.append(ZeroMeanUnitVarianceTransform(per_channel=False))

    if self.HP.TEST_TIME_DAUG:
        center_dist_from_border = int(
            self.HP.INPUT_DIM[0] / 2.) - 10  # (144,144) -> 62
        tfs.append(
            SpatialTransform(
                self.HP.INPUT_DIM,
                patch_center_dist_from_border=center_dist_from_border,
                do_elastic_deform=True, alpha=(90., 120.), sigma=(9., 11.),
                do_rotation=True, angle_x=(-0.8, 0.8),
                angle_y=(-0.8, 0.8), angle_z=(-0.8, 0.8),
                do_scale=True, scale=(0.9, 1.5),
                border_mode_data='constant', border_cval_data=0,
                order_data=3, border_mode_seg='constant',
                border_cval_seg=0, order_seg=0, random_crop=True))
        # tfs.append(ResampleTransform(zoom_range=(0.5, 1)))
        # tfs.append(GaussianNoiseTransform(noise_variance=(0, 0.05)))
        tfs.append(
            ContrastAugmentationTransform(contrast_range=(0.7, 1.3),
                                          preserve_range=True,
                                          per_channel=False))
        tfs.append(
            BrightnessMultiplicativeTransform(multiplier_range=(0.7, 1.3),
                                              per_channel=False))

    tfs.append(ReorderSegTransform())
    # Only use num_processes=1, otherwise global_idx of
    # SlicesBatchGenerator does not work.
    batch_gen = MultiThreadedAugmenter(batch_gen, Compose(tfs),
                                       num_processes=num_processes,
                                       num_cached_per_queue=2, seeds=None)
    return batch_gen  # data: (batch_size, channels, x, y), seg: (batch_size, x, y, channels)
def get_batches(self,
                batch_size=128,
                type=None,
                subjects=None,
                num_batches=None):
    """Build a multi-process batch generator that samples random slices.

    Args:
        batch_size: Number of slices per batch.
        type: Split name. Augmentation transforms are added only when this is
            "train" and ``self.HP.DATA_AUGMENTATION`` is set. (The name shadows
            the builtin but is kept for caller compatibility.)
        subjects: Subjects handed to the slice batch generator as its data.
        num_batches: Accepted for backward compatibility; not used by this
            method (the previous per-process batch count it fed was never
            passed on to any generator).

    Returns:
        A ``MultiThreadedAugmenter``. Per the original author's note, batches
        hold ``data`` and ``seg`` as (batch_size, channels, x, y).
    """
    data = subjects
    seg = []

    # 6 processes -> >30GB RAM
    num_processes = 8 if self.HP.DATA_AUGMENTATION else 6  # 6 is a bit faster than 16

    if self.HP.TYPE == "combined":
        # Simple with .npy -> just a little bit faster than Nifti (<10%) and
        # f1 not better => use Nifti
        batch_gen = SlicesBatchGeneratorRandomNpyImg_fusion(
            (data, seg), batch_size=batch_size)
    else:
        batch_gen = SlicesBatchGeneratorRandomNiftiImg(
            (data, seg), batch_size=batch_size)
        # batch_gen = SlicesBatchGeneratorRandomNiftiImg_5slices((data, seg), batch_size=batch_size)

    batch_gen.HP = self.HP

    tfs = []  # transforms
    if self.HP.NORMALIZE_DATA:
        tfs.append(
            ZeroMeanUnitVarianceTransform(
                per_channel=self.HP.NORMALIZE_PER_CHANNEL))

    if self.HP.DATASET == "Schizo" and self.HP.RESOLUTION == "2mm":
        tfs.append(PadToMultipleTransform(16))

    if self.HP.DATA_AUGMENTATION and type == "train":
        # scale: inverted: 0.5 -> bigger; 2 -> smaller
        # patch_center_dist_from_border: if 144/2=72 -> always exactly
        # centered; otherwise a bit off center (brain can get off image and
        # will be cut then)
        # NOTE(review): elastic deformation and rotation are only applied
        # when DAUG_SCALE is set, because they live in this SpatialTransform.
        if self.HP.DAUG_SCALE:
            center_dist_from_border = int(
                self.HP.INPUT_DIM[0] / 2.) - 10  # (144,144) -> 62
            tfs.append(
                SpatialTransform(
                    self.HP.INPUT_DIM,
                    patch_center_dist_from_border=center_dist_from_border,
                    do_elastic_deform=self.HP.DAUG_ELASTIC_DEFORM,
                    alpha=(90., 120.), sigma=(9., 11.),
                    do_rotation=self.HP.DAUG_ROTATE,
                    angle_x=(-0.8, 0.8), angle_y=(-0.8, 0.8),
                    angle_z=(-0.8, 0.8),
                    do_scale=True, scale=(0.9, 1.5),
                    border_mode_data='constant', border_cval_data=0,
                    order_data=3, border_mode_seg='constant',
                    border_cval_seg=0, order_seg=0, random_crop=True))

        if self.HP.DAUG_RESAMPLE:
            tfs.append(ResampleTransform(zoom_range=(0.5, 1)))

        if self.HP.DAUG_NOISE:
            tfs.append(GaussianNoiseTransform(noise_variance=(0, 0.05)))

        if self.HP.DAUG_MIRROR:
            tfs.append(MirrorTransform())

        if self.HP.DAUG_FLIP_PEAKS:
            tfs.append(FlipVectorAxisTransform())

    # num_cached_per_queue 1 or 2 does not really make a difference
    batch_gen = MultiThreadedAugmenter(batch_gen, Compose(tfs),
                                       num_processes=num_processes,
                                       num_cached_per_queue=1, seeds=None)
    return batch_gen  # data: (batch_size, channels, x, y), seg: (batch_size, channels, x, y)
def get_transforms(mode="train",
                   n_channels=1,
                   target_size=128,
                   add_resize=False,
                   add_noise=False,
                   mask_type="",
                   batch_size=16,
                   rotate=True,
                   elastic_deform=True,
                   rnd_crop=False,
                   color_augment=True):
    """Assemble the batchgenerators pipeline for 2D training or validation.

    Args:
        mode: "train" builds the augmenting pipeline and "val" the
            deterministic one. Any other value yields only the final
            NumpyToTensor conversion.
        n_channels: Channel count used in the minimum pad size.
        target_size: Spatial edge length of the output patches.
        add_resize: Unused; kept for interface compatibility.
        add_noise: Append the mode-specific extra noise transforms.
        mask_type: When "gaussian" (train mode only), the extra noise is a
            GaussianNoiseTransform with noise_variance=(0., 0.2).
        batch_size: Leading dimension restored by the second ReshapeTransform
            in train mode.
        rotate: Forwarded as ``do_rotation`` to the SpatialTransform.
        elastic_deform: Forwarded as ``do_elastic_deform`` to the
            SpatialTransform.
        rnd_crop: Forwarded as ``random_crop`` to the SpatialTransform.
        color_augment: Add BrightnessMultiplicativeTransform in train mode.

    Returns:
        A Compose of the selected transforms, always ending in NumpyToTensor.
    """

    def pad_then_resize():
        # Common prefix of both pipelines: pad to at least target_size + 5,
        # then resize to target_size + 1.
        return [
            FillupPadTransform(min_size=(n_channels, target_size + 5,
                                         target_size + 5)),
            ResizeTransform(target_size=(target_size + 1, target_size + 1),
                            order=1,
                            concatenate_list=True),
        ]

    extra_noise = []

    if mode == "train":
        pipeline = pad_then_resize()
        # RandomCropTransform(crop_size=(target_size + 5, target_size + 5)) was used here before.
        pipeline.append(MirrorTransform(axes=(2, )))
        # Reshape to a leading dimension of 1 before SpatialTransform and back
        # to batch_size afterwards (presumably so the whole batch shares one
        # spatial draw — confirm against ReshapeTransform).
        pipeline.append(ReshapeTransform(new_shape=(1, -1, "h", "w")))
        pipeline.append(
            SpatialTransform(patch_size=(target_size, target_size),
                             random_crop=rnd_crop,
                             patch_center_dist_from_border=target_size // 2,
                             do_elastic_deform=elastic_deform,
                             alpha=(0., 100.),
                             sigma=(10., 13.),
                             do_rotation=rotate,
                             angle_x=(-0.1, 0.1),
                             angle_y=(0, 1e-8),
                             angle_z=(0, 1e-8),
                             scale=(0.9, 1.2),
                             border_mode_data="nearest",
                             border_mode_seg="nearest"))
        pipeline.append(ReshapeTransform(new_shape=(batch_size, -1, "h", "w")))
        if color_augment:
            pipeline.append(
                BrightnessMultiplicativeTransform(multiplier_range=(0.95, 1.1)))
        pipeline.extend([
            GaussianNoiseTransform(noise_variance=(0., 0.05)),
            ClipValueRange(min=-1.5, max=1.5),
        ])
        if mask_type == "gaussian":
            extra_noise.append(
                GaussianNoiseTransform(noise_variance=(0., 0.2)))
    elif mode == "val":
        pipeline = pad_then_resize()
        pipeline.extend([
            CenterCropTransform(crop_size=(target_size, target_size)),
            ClipValueRange(min=-1.5, max=1.5),
            # Keep an un-noised copy of the input under "data_clean".
            CopyTransform({"data": "data_clean"}, copy=True),
        ])
    else:
        pipeline = []

    if add_noise:
        pipeline.extend(extra_noise)

    pipeline.append(NumpyToTensor())
    return Compose(pipeline)
# in order to use neptune logging: # export NEPTUNE_API_TOKEN = '...' !!! logging.getLogger().setLevel('INFO') source_files = [__file__] if hparams.config: source_files.append(hparams.config) neptune_logger = NeptuneLogger(project_name=hparams.neptune_project, params=vars(hparams), experiment_name=hparams.experiment_name, tags=[hparams.experiment_name], upload_source_files=source_files) tb_logger = loggers.TensorBoardLogger(hparams.log_dir) transform = Compose([ BrightnessTransform(mu=0.0, sigma=0.3, data_key='data'), GammaTransform(gamma_range=(0.7, 1.3), data_key='data'), ContrastAugmentationTransform(contrast_range=(0.3, 1.7), data_key='data') ]) with open(hparams.train_set, 'r') as keyfile: train_keys = [l.strip() for l in keyfile.readlines()] print(train_keys) with open(hparams.val_set, 'r') as keyfile: val_keys = [l.strip() for l in keyfile.readlines()] print(val_keys) train_ds = MedDataset(hparams.data_path, train_keys, hparams.patches_per_subject, hparams.patch_size,
def _append_target_transforms(transforms, regions, deep_supervision_scales,
                              soft_ds, classes):
    """Append the seg -> target tail shared by the train and val pipelines."""
    transforms.append(RenameTransform('seg', 'target', True))

    if regions is not None:
        transforms.append(
            ConvertSegmentationToRegionsTransform(regions, 'target', 'target'))

    if deep_supervision_scales is not None:
        if soft_ds:
            assert classes is not None
            transforms.append(
                DownsampleSegForDSTransform3(deep_supervision_scales, 'target',
                                             'target', classes))
        else:
            transforms.append(
                DownsampleSegForDSTransform2(deep_supervision_scales,
                                             0,
                                             input_key='target',
                                             output_key='target'))

    transforms.append(NumpyToTensor(['data', 'target'], 'float'))


def _create_augmenter_for_loader(data_loader, transforms, num_threads,
                                 num_cached_per_thread, seeds, pin_memory,
                                 use_nondet):
    """Wrap ``data_loader`` in a (NonDet)MultiThreadedAugmenter running ``transforms``."""
    if use_nondet:
        # NonDetMultiThreadedAugmenter is presumably set to None by a guarded
        # import elsewhere in the file when it is unavailable.
        if NonDetMultiThreadedAugmenter is None:
            raise RuntimeError(
                'NonDetMultiThreadedAugmenter is not yet available')
        return NonDetMultiThreadedAugmenter(data_loader,
                                            transforms,
                                            num_threads,
                                            num_cached_per_thread,
                                            seeds=seeds,
                                            pin_memory=pin_memory)
    return MultiThreadedAugmenter(data_loader,
                                  transforms,
                                  num_threads,
                                  num_cached_per_thread,
                                  seeds=seeds,
                                  pin_memory=pin_memory)


def get_moreDA_augmentation(dataloader_train,
                            dataloader_val,
                            patch_size,
                            params=default_3D_augmentation_params,
                            border_val_seg=-1,
                            seeds_train=None,
                            seeds_val=None,
                            order_seg=1,
                            order_data=3,
                            deep_supervision_scales=None,
                            soft_ds=False,
                            classes=None,
                            pin_memory=True,
                            regions=None,
                            use_nondetMultiThreadedAugmenter: bool = False):
    """Build the training and validation augmenters for the "more DA" pipeline.

    Args:
        dataloader_train: Loader wrapped by the training augmenter.
        dataloader_val: Loader wrapped by the validation augmenter.
        patch_size: Output patch size. When ``params["dummy_2D"]`` is truthy,
            its first entry is dropped and the spatial transform runs on the
            remaining axes between Convert3DTo2D / Convert2DTo3D transforms.
        params: Augmentation configuration dict. It is only read, never
            modified.
        border_val_seg: Constant fill value for seg outside the image after
            the spatial transform.
        seeds_train: Seeds forwarded to the training augmenter.
        seeds_val: Seeds forwarded to the validation augmenter.
        order_seg: Interpolation order passed to SpatialTransform for seg.
        order_data: Interpolation order passed to SpatialTransform for data.
        deep_supervision_scales: If not None, ``target`` is downsampled for
            deep supervision.
        soft_ds: Selects DownsampleSegForDSTransform3, which requires
            ``classes``.
        classes: Class list used by soft deep supervision.
        pin_memory: Forwarded to both augmenters.
        regions: If not None, ``target`` is converted to regions.
        use_nondetMultiThreadedAugmenter: Use NonDetMultiThreadedAugmenter
            instead of MultiThreadedAugmenter.

    Returns:
        Tuple ``(batchgenerator_train, batchgenerator_val)``. The validation
        augmenter uses half the training threads (at least 1).

    Raises:
        AssertionError: If ``params`` uses the old "mirror" key, or if
            ``soft_ds`` is set without ``classes``.
        RuntimeError: If the non-deterministic augmenter is requested but
            unavailable.
    """
    assert params.get(
        'mirror') is None, "old version of params, use new keyword do_mirror"

    tr_transforms = []

    if params.get("selected_data_channels") is not None:
        tr_transforms.append(
            DataChannelSelectionTransform(
                params.get("selected_data_channels")))

    if params.get("selected_seg_channels") is not None:
        tr_transforms.append(
            SegChannelSelectionTransform(params.get("selected_seg_channels")))

    # don't do color augmentations while in 2d mode with 3d data because the
    # color channel is overloaded!!
    if params.get("dummy_2D"):
        ignore_axes = (0, )
        tr_transforms.append(Convert3DTo2DTransform())
        patch_size_spatial = patch_size[1:]
    else:
        patch_size_spatial = patch_size
        ignore_axes = None

    tr_transforms.append(
        SpatialTransform(
            patch_size_spatial,
            patch_center_dist_from_border=None,
            do_elastic_deform=params.get("do_elastic"),
            alpha=params.get("elastic_deform_alpha"),
            sigma=params.get("elastic_deform_sigma"),
            do_rotation=params.get("do_rotation"),
            angle_x=params.get("rotation_x"),
            angle_y=params.get("rotation_y"),
            angle_z=params.get("rotation_z"),
            p_rot_per_axis=params.get("rotation_p_per_axis"),
            do_scale=params.get("do_scaling"),
            scale=params.get("scale_range"),
            border_mode_data=params.get("border_mode_data"),
            border_cval_data=0,
            order_data=order_data,
            border_mode_seg="constant",
            border_cval_seg=border_val_seg,
            order_seg=order_seg,
            random_crop=params.get("random_crop"),
            p_el_per_sample=params.get("p_eldef"),
            p_scale_per_sample=params.get("p_scale"),
            p_rot_per_sample=params.get("p_rot"),
            independent_scale_for_each_axis=params.get(
                "independent_scale_factor_for_each_axis")))

    if params.get("dummy_2D"):
        tr_transforms.append(Convert2DTo3DTransform())

    # we need to put the color augmentations after the dummy 2d part (if
    # applicable). Otherwise the overloaded color channel gets in the way
    tr_transforms.append(GaussianNoiseTransform(p_per_sample=0.1))
    tr_transforms.append(
        GaussianBlurTransform((0.5, 1.),
                              different_sigma_per_channel=True,
                              p_per_sample=0.2,
                              p_per_channel=0.5))
    tr_transforms.append(
        BrightnessMultiplicativeTransform(multiplier_range=(0.75, 1.25),
                                          p_per_sample=0.15))

    if params.get("do_additive_brightness"):
        tr_transforms.append(
            BrightnessTransform(
                params.get("additive_brightness_mu"),
                params.get("additive_brightness_sigma"),
                True,
                p_per_sample=params.get("additive_brightness_p_per_sample"),
                p_per_channel=params.get("additive_brightness_p_per_channel")))

    tr_transforms.append(ContrastAugmentationTransform(p_per_sample=0.15))
    tr_transforms.append(
        SimulateLowResolutionTransform(zoom_range=(0.5, 1),
                                       per_channel=True,
                                       p_per_channel=0.5,
                                       order_downsample=0,
                                       order_upsample=3,
                                       p_per_sample=0.25,
                                       ignore_axes=ignore_axes))
    # inverted gamma
    tr_transforms.append(
        GammaTransform(params.get("gamma_range"),
                       True,
                       True,
                       retain_stats=params.get("gamma_retain_stats"),
                       p_per_sample=0.1))

    if params.get("do_gamma"):
        tr_transforms.append(
            GammaTransform(params.get("gamma_range"),
                           False,
                           True,
                           retain_stats=params.get("gamma_retain_stats"),
                           p_per_sample=params["p_gamma"]))

    if params.get("do_mirror") or params.get("mirror"):
        tr_transforms.append(MirrorTransform(params.get("mirror_axes")))

    if params.get("mask_was_used_for_normalization") is not None:
        mask_was_used_for_normalization = params.get(
            "mask_was_used_for_normalization")
        tr_transforms.append(
            MaskTransform(mask_was_used_for_normalization,
                          mask_idx_in_seg=0,
                          set_outside_to=0))

    tr_transforms.append(RemoveLabelTransform(-1, 0))

    if params.get("move_last_seg_chanel_to_data"):
        tr_transforms.append(
            MoveSegAsOneHotToData(1, params.get("all_segmentation_labels"),
                                  'seg', 'data'))
        if params.get("cascade_do_cascade_augmentations"):
            if params.get("cascade_random_binary_transform_p") > 0:
                tr_transforms.append(
                    ApplyRandomBinaryOperatorTransform(
                        channel_idx=list(
                            range(-len(params.get("all_segmentation_labels")),
                                  0)),
                        p_per_sample=params.get(
                            "cascade_random_binary_transform_p"),
                        key="data",
                        strel_size=params.get(
                            "cascade_random_binary_transform_size"),
                        p_per_label=params.get(
                            "cascade_random_binary_transform_p_per_label")))
            if params.get("cascade_remove_conn_comp_p") > 0:
                tr_transforms.append(
                    RemoveRandomConnectedComponentFromOneHotEncodingTransform(
                        channel_idx=list(
                            range(-len(params.get("all_segmentation_labels")),
                                  0)),
                        key="data",
                        p_per_sample=params.get("cascade_remove_conn_comp_p"),
                        # Each keyword now receives the config entry of the
                        # same name (these two were previously swapped).
                        fill_with_other_class_p=params.get(
                            "cascade_remove_conn_comp_fill_with_other_class_p"
                        ),
                        dont_do_if_covers_more_than_X_percent=params.get(
                            "cascade_remove_conn_comp_max_size_percent_threshold"
                        )))

    _append_target_transforms(tr_transforms, regions, deep_supervision_scales,
                              soft_ds, classes)
    tr_transforms = Compose(tr_transforms)

    batchgenerator_train = _create_augmenter_for_loader(
        dataloader_train, tr_transforms, params.get('num_threads'),
        params.get("num_cached_per_thread"), seeds_train, pin_memory,
        use_nondetMultiThreadedAugmenter)

    val_transforms = []
    val_transforms.append(RemoveLabelTransform(-1, 0))

    if params.get("selected_data_channels") is not None:
        val_transforms.append(
            DataChannelSelectionTransform(
                params.get("selected_data_channels")))

    if params.get("selected_seg_channels") is not None:
        val_transforms.append(
            SegChannelSelectionTransform(params.get("selected_seg_channels")))

    if params.get("move_last_seg_chanel_to_data"):
        val_transforms.append(
            MoveSegAsOneHotToData(1, params.get("all_segmentation_labels"),
                                  'seg', 'data'))

    _append_target_transforms(val_transforms, regions,
                              deep_supervision_scales, soft_ds, classes)
    val_transforms = Compose(val_transforms)

    batchgenerator_val = _create_augmenter_for_loader(
        dataloader_val, val_transforms,
        max(params.get('num_threads') // 2, 1),
        params.get("num_cached_per_thread"), seeds_val, pin_memory,
        use_nondetMultiThreadedAugmenter)

    return batchgenerator_train, batchgenerator_val