def create_data_gen_train(patient_data_train, BATCH_SIZE, num_classes, num_workers=5,
                          num_cached_per_worker=2, do_elastic_transform=False,
                          alpha=(0., 1300.), sigma=(10., 13.), do_rotation=False,
                          a_x=(0., 2*np.pi), a_y=(0., 2*np.pi), a_z=(0., 2*np.pi),
                          do_scale=True, scale_range=(0.75, 1.25), seeds=None):
    """Build and start a multithreaded 2D training generator (352x352 patches).

    Pipeline: ``Mirror((2, 3))`` -> ``SpatialTransform`` with probability 0.67
    (otherwise ``RandomCropTransform``) -> one-hot encoding of the segmentation
    into ``'seg_onehot'``.

    Args:
        patient_data_train: Training data handed to ``BatchGenerator_2D``.
        BATCH_SIZE: Samples per batch.
        num_classes: Number of label classes used for the one-hot encoding.
        num_workers: Number of augmentation workers.
        num_cached_per_worker: Batches each worker may queue ahead.
        do_elastic_transform, alpha, sigma: Elastic deformation switch and ranges.
        do_rotation, a_x, a_y, a_z: Rotation switch and angle ranges.
        do_scale, scale_range: Scaling switch and factor range.
        seeds: ``None`` (no seeding), the string ``'range'`` (worker ``i`` gets
            seed ``i``) or a sequence of exactly ``num_workers`` seeds.

    Returns:
        The restarted ``MultiThreadedAugmenter``.

    Raises:
        ValueError: If ``seeds`` is an unknown string or has the wrong length.
    """
    if seeds is None:
        seeds = [None] * num_workers
    elif isinstance(seeds, str):
        # Test the type explicitly: ``seeds == 'range'`` on a NumPy array is an
        # element-wise comparison (ambiguous in an ``elif``), and any other
        # string of length num_workers used to pass the length check.
        if seeds != 'range':
            raise ValueError(
                f"seeds must be None, 'range' or a sequence of {num_workers} seeds, got {seeds!r}")
        seeds = range(num_workers)
    elif len(seeds) != num_workers:
        raise ValueError(f"expected {num_workers} seeds (one per worker), got {len(seeds)}")

    patch_size = (352, 352)
    data_gen_train = BatchGenerator_2D(patient_data_train, BATCH_SIZE, num_batches=None,
                                       seed=False, PATCH_SIZE=patch_size)

    spatial = SpatialTransform(patch_size, list(np.array(patch_size) // 2),
                               do_elastic_transform, alpha, sigma,
                               do_rotation, a_x, a_y, a_z,
                               do_scale, scale_range,
                               # presumably border mode / cval / interpolation order
                               # for data, then the same three for seg
                               'constant', 0, 3, 'constant', 0, 0,
                               random_crop=False)
    tr_transforms = [
        Mirror((2, 3)),
        RndTransform(spatial, prob=0.67,
                     alternative_transform=RandomCropTransform(patch_size)),
        ConvertSegToOnehotTransform(range(num_classes), seg_channel=0, output_key='seg_onehot'),
    ]
    tr_mt_gen = MultiThreadedAugmenter(data_gen_train, Compose(tr_transforms),
                                       num_workers, num_cached_per_worker, seeds)
    tr_mt_gen.restart()
    return tr_mt_gen
def __init__(self, im_4D):
    """Keep a reference to ``im_4D`` and create the augmentation transforms.

    Two transforms are built once and stored on the instance:
    ``mirror_transform`` (a ``MirrorTransform`` over axes (0, 1, 2)) and
    ``spatial_transform`` (a ``SpatialTransform`` with zero x/y rotation ranges,
    a z rotation range of (0, 2*pi), scaling in (0.75, 1.25), elastic
    deformation and random cropping disabled, linear data interpolation).
    """
    self.im_4D = im_4D
    self.mirror_transform = MirrorTransform(axes=(0, 1, 2))

    full_turn = 2 * np.pi
    spatial_options = {
        "patch_size": None,
        "do_elastic_deform": False,
        "alpha": (0., 1000.),
        "sigma": (10., 13.),
        "do_rotation": True,
        "angle_x": (0, 0),
        "angle_y": (0, 0),
        "angle_z": (0, full_turn),
        "do_scale": True,
        "scale": (0.75, 1.25),
        "border_mode_data": 'constant',
        "border_cval_data": 0,
        "order_data": 1,
        "random_crop": False,
    }
    self.spatial_transform = SpatialTransform(**spatial_options)
def get_transforms():
    """Return the (uncomposed) list of rotation, shift and zoom transforms.

    The rotation is about z only, within +/- ``ROTATION_ANGLE`` degrees
    (converted to radians), applied with probability ``ROTATION_P``; scaling,
    elastic deformation and cropping are disabled. Shift and zoom parameters
    come from the module-level ``SHIFT_*`` / ``MAX_ZOOM`` / ``ZOOM_P`` constants.
    """
    max_angle = np.radians(ROTATION_ANGLE)
    z_rotation = SpatialTransform(
        patch_size=None,
        do_elastic_deform=False,
        do_rotation=True,
        p_rot_per_sample=ROTATION_P,
        angle_x=(0, 0),
        angle_y=(0, 0),
        angle_z=(-max_angle, max_angle),
        border_mode_data="constant",
        border_cval_data=0,
        do_scale=False,
        random_crop=False,
    )
    shift = RandomShiftTransform(shift_mu=SHIFT_MU, shift_sigma=SHIFT_SIGMA,
                                 p_per_sample=SHIFT_P, p_per_channel=1)
    zoom = RealZoomTransform(max_zoom=MAX_ZOOM, p_per_sample=ZOOM_P)
    return [z_rotation, shift, zoom]
def get_augmentation(patch_size, params=default_3D_augmentation_params, border_val_seg=-1):
    """Compose the spatial / gamma / mirror training augmentation.

    Prints the target patch size, then builds ``SpatialTransform`` (cubic data
    and linear seg interpolation; seg border filled with ``border_val_seg``),
    an optional ``GammaTransform`` when ``params["do_gamma"]`` is truthy, and an
    unconditional ``MirrorTransform`` over ``params["mirror_axes"]``.

    Raises:
        KeyError: If ``do_gamma`` is set but ``p_gamma`` is missing from ``params``.
    """
    print(f'patch size after augmentation {patch_size}')

    spatial = SpatialTransform(
        patch_size,
        patch_center_dist_from_border=None,
        do_elastic_deform=params.get("do_elastic"),
        alpha=params.get("elastic_deform_alpha"),
        sigma=params.get("elastic_deform_sigma"),
        do_rotation=params.get("do_rotation"),
        angle_x=params.get("rotation_x"),
        angle_y=params.get("rotation_y"),
        angle_z=params.get("rotation_z"),
        do_scale=params.get("do_scaling"),
        scale=params.get("scale_range"),
        border_mode_data=params.get("border_mode_data"),
        border_cval_data=0,
        order_data=3,
        border_mode_seg="constant",
        border_cval_seg=border_val_seg,
        order_seg=1,
        random_crop=params.get("random_crop"),
        p_el_per_sample=params.get("p_eldef"),
        p_scale_per_sample=params.get("p_scale"),
        p_rot_per_sample=params.get("p_rot"),
    )
    pipeline = [spatial]

    if params.get("do_gamma"):
        pipeline.append(GammaTransform(params.get("gamma_range"), False, True,
                                       retain_stats=params.get("gamma_retain_stats"),
                                       p_per_sample=params["p_gamma"]))

    pipeline.append(MirrorTransform(params.get("mirror_axes")))
    return Compose(pipeline)
def create_data_gen_train(patient_data_train, INPUT_PATCH_SIZE, num_classes, BATCH_SIZE,
                          contrast_range=(0.75, 1.5), gamma_range = (0.6, 2), num_workers=5,
                          num_cached_per_worker=3, do_elastic_transform=False,
                          alpha=(0., 1300.), sigma=(10., 13.), do_rotation=False,
                          a_x=(0., 2*np.pi), a_y=(0., 2*np.pi), a_z=(0., 2*np.pi),
                          do_scale=True, scale_range=(0.75, 1.25), seeds=None):
    """Build and start the multithreaded 3D training generator.

    Batches of size (160, 192, 160) are drawn by
    ``BatchGenerator3D_random_sampling``. They are reduced to data channels
    0-3, get a brain mask and mirroring, and are spatially transformed and
    randomly cropped to ``INPUT_PATCH_SIZE``. Intensity augmentations follow
    (brain-mask-aware stretching, contrast, gamma, brightness). Finally, seg
    channel 0 is one-hot encoded into ``"seg_onehot"``.

    Args:
        patient_data_train: Training data handed to the batch generator.
        INPUT_PATCH_SIZE: Final patch size produced by ``SpatialTransform``.
        num_classes: Number of label classes for the one-hot encoding.
        BATCH_SIZE: Samples per batch.
        contrast_range, gamma_range: Intensity augmentation ranges.
        num_workers, num_cached_per_worker: ``MultiThreadedAugmenter`` settings.
        do_elastic_transform, alpha, sigma, do_rotation, a_x, a_y, a_z,
        do_scale, scale_range: ``SpatialTransform`` settings.
        seeds: ``None``, the string ``'range'`` or a sequence of exactly
            ``num_workers`` seeds.

    Returns:
        The restarted ``MultiThreadedAugmenter``.

    Raises:
        ValueError: If ``seeds`` is an unknown string or has the wrong length.
    """
    if seeds is None:
        seeds = [None] * num_workers
    elif isinstance(seeds, str):
        # Explicit type test: ``seeds == 'range'`` is element-wise for NumPy
        # arrays, and arbitrary strings of the right length used to slip through.
        if seeds != 'range':
            raise ValueError(
                f"seeds must be None, 'range' or a sequence of {num_workers} seeds, got {seeds!r}")
        seeds = range(num_workers)
    elif len(seeds) != num_workers:
        raise ValueError(f"expected {num_workers} seeds (one per worker), got {len(seeds)}")

    data_gen_train = BatchGenerator3D_random_sampling(patient_data_train, BATCH_SIZE,
                                                      num_batches=None, seed=False,
                                                      patch_size=(160, 192, 160),
                                                      convert_labels=True)
    tr_transforms = [
        DataChannelSelectionTransform([0, 1, 2, 3]),
        GenerateBrainMaskTransform(),
        MirrorTransform(),
        SpatialTransform(INPUT_PATCH_SIZE, list(np.array(INPUT_PATCH_SIZE)//2.),
                         do_elastic_deform=do_elastic_transform, alpha=alpha, sigma=sigma,
                         do_rotation=do_rotation, angle_x=a_x, angle_y=a_y, angle_z=a_z,
                         do_scale=do_scale, scale=scale_range,
                         border_mode_data='nearest', border_cval_data=0, order_data=3,
                         border_mode_seg='constant', border_cval_seg=0, order_seg=0,
                         random_crop=True),
        BrainMaskAwareStretchZeroOneTransform((-5, 5), True),
        ContrastAugmentationTransform(contrast_range, True),
        GammaTransform(gamma_range, False),
        BrainMaskAwareStretchZeroOneTransform(per_channel=True),
        BrightnessTransform(0.0, 0.1, True),
        SegChannelSelectionTransform([0]),
        ConvertSegToOnehotTransform(range(num_classes), 0, "seg_onehot"),
    ]
    gen_train = MultiThreadedAugmenter(data_gen_train, Compose(tr_transforms),
                                       num_workers, num_cached_per_worker, seeds)
    gen_train.restart()
    return gen_train
def get_training_transforms(self):
    """Assemble the full training augmentation pipeline from ``self.params``.

    Order: optional data/seg channel selection -> ``SpatialTransform``
    (wrapped in a 3D->2D->3D round trip when ``dummy_2D`` is set) -> noise,
    blur, multiplicative and optional additive brightness, contrast,
    simulated low resolution, inverted gamma, optional gamma -> optional
    mirroring -> optional normalization mask -> ``RemoveLabelTransform(-1, 0)``
    -> rename ``seg`` to ``target`` -> ``NumpyToTensor``.

    Returns:
        A ``Compose`` of the transforms above.

    Raises:
        AssertionError: If ``self.params`` still uses the legacy ``'mirror'`` key.
        KeyError: If ``do_gamma`` is set but ``p_gamma`` is missing.
    """
    params = self.params
    assert params.get(
        'mirror'
    ) is None, "old version of params, use new keyword do_mirror"

    # SpatialTransform keyword -> params key it is read from.
    spatial_param_keys = {
        "do_elastic_deform": "do_elastic",
        "alpha": "elastic_deform_alpha",
        "sigma": "elastic_deform_sigma",
        "do_rotation": "do_rotation",
        "angle_x": "rotation_x",
        "angle_y": "rotation_y",
        "angle_z": "rotation_z",
        "do_scale": "do_scaling",
        "scale": "scale_range",
        "order_data": "order_data",
        "border_mode_data": "border_mode_data",
        "border_cval_data": "border_cval_data",
        "order_seg": "order_seg",
        "border_mode_seg": "border_mode_seg",
        "border_cval_seg": "border_cval_seg",
        "random_crop": "random_crop",
        "p_el_per_sample": "p_eldef",
        "p_scale_per_sample": "p_scale",
        "p_rot_per_sample": "p_rot",
        "independent_scale_for_each_axis": "independent_scale_factor_for_each_axis",
    }
    spatial_kwargs = {kw: params.get(key) for kw, key in spatial_param_keys.items()}

    pipeline = []
    data_channels = params.get("selected_data_channels")
    if data_channels:
        pipeline.append(DataChannelSelectionTransform(data_channels))
    seg_channels = params.get("selected_seg_channels")
    if seg_channels:
        pipeline.append(SegChannelSelectionTransform(seg_channels))

    # don't do color augmentations while in 2d mode with 3d data because the color channel is overloaded!!
    dummy_2d = params.get("dummy_2D", False)
    ignore_axes = (0, ) if dummy_2d else None
    if dummy_2d:
        pipeline.append(Convert3DTo2DTransform())
    pipeline.append(SpatialTransform(self._spatial_transform_patch_size,
                                     patch_center_dist_from_border=None,
                                     **spatial_kwargs))
    if dummy_2d:
        pipeline.append(Convert2DTo3DTransform())

    # Color augmentations come after the dummy-2D round trip; otherwise the
    # overloaded color channel gets in the way.
    pipeline += [
        GaussianNoiseTransform(p_per_sample=0.15),
        GaussianBlurTransform((0.5, 1.5), different_sigma_per_channel=True,
                              p_per_sample=0.2, p_per_channel=0.5),
        BrightnessMultiplicativeTransform(multiplier_range=(0.75, 1.3), p_per_sample=0.15),
    ]
    if params.get("do_additive_brightness"):
        pipeline.append(BrightnessTransform(
            params.get("additive_brightness_mu"),
            params.get("additive_brightness_sigma"),
            True,
            p_per_sample=params.get("additive_brightness_p_per_sample"),
            p_per_channel=params.get("additive_brightness_p_per_channel")))
    pipeline += [
        ContrastAugmentationTransform(contrast_range=(0.65, 1.5), p_per_sample=0.15),
        SimulateLowResolutionTransform(zoom_range=(0.5, 1), per_channel=True,
                                       p_per_channel=0.5, order_downsample=0,
                                       order_upsample=3, p_per_sample=0.25,
                                       ignore_axes=ignore_axes),
        # inverted gamma
        GammaTransform(params.get("gamma_range"), True, True,
                       retain_stats=params.get("gamma_retain_stats"), p_per_sample=0.15),
    ]
    if params.get("do_gamma"):
        pipeline.append(GammaTransform(params.get("gamma_range"), False, True,
                                       retain_stats=params.get("gamma_retain_stats"),
                                       p_per_sample=params["p_gamma"]))
    if params.get("do_mirror") or params.get("mirror"):
        pipeline.append(MirrorTransform(params.get("mirror_axes")))

    mask_spec = params.get("use_mask_for_norm")
    if mask_spec:
        pipeline.append(MaskTransform(mask_spec, mask_idx_in_seg=0, set_outside_to=0))

    pipeline += [
        RemoveLabelTransform(-1, 0),
        RenameTransform('seg', 'target', True),
        NumpyToTensor(['data', 'target'], 'float'),
    ]
    return Compose(pipeline)
def get_training_transforms(self):
    """Assemble the basic training augmentation pipeline from ``self.params``.

    Order: optional data/seg channel selection -> ``SpatialTransform``
    (wrapped in a 3D->2D->3D round trip when ``dummy_2D`` is set) -> optional
    gamma -> optional mirroring -> optional normalization mask ->
    ``RemoveLabelTransform(-1, 0)`` -> rename ``seg`` to ``target`` ->
    ``NumpyToTensor``.

    Returns:
        A ``Compose`` of the transforms above.

    Raises:
        AssertionError: If ``self.params`` still uses the legacy ``'mirror'`` key.
        KeyError: If ``do_gamma`` is set but ``p_gamma`` is missing.
    """
    cfg = self.params
    assert cfg.get(
        'mirror'
    ) is None, "old version of params, use new keyword do_mirror"

    # SpatialTransform keyword -> params key it is read from.
    spatial_param_keys = {
        "do_elastic_deform": "do_elastic",
        "alpha": "elastic_deform_alpha",
        "sigma": "elastic_deform_sigma",
        "do_rotation": "do_rotation",
        "angle_x": "rotation_x",
        "angle_y": "rotation_y",
        "angle_z": "rotation_z",
        "do_scale": "do_scaling",
        "scale": "scale_range",
        "order_data": "order_data",
        "border_mode_data": "border_mode_data",
        "border_cval_data": "border_cval_data",
        "order_seg": "order_seg",
        "border_mode_seg": "border_mode_seg",
        "border_cval_seg": "border_cval_seg",
        "random_crop": "random_crop",
        "p_el_per_sample": "p_eldef",
        "p_scale_per_sample": "p_scale",
        "p_rot_per_sample": "p_rot",
        "independent_scale_for_each_axis": "independent_scale_factor_for_each_axis",
    }

    steps = []
    data_channels = cfg.get("selected_data_channels")
    if data_channels:
        steps.append(DataChannelSelectionTransform(data_channels))
    seg_channels = cfg.get("selected_seg_channels")
    if seg_channels:
        steps.append(SegChannelSelectionTransform(seg_channels))

    # don't do color augmentations while in 2d mode with 3d data because the color channel is overloaded!!
    dummy_2d = cfg.get("dummy_2D", False)
    if dummy_2d:
        steps.append(Convert3DTo2DTransform())
    steps.append(SpatialTransform(
        self._spatial_transform_patch_size,
        patch_center_dist_from_border=None,
        **{kw: cfg.get(key) for kw, key in spatial_param_keys.items()}))
    if dummy_2d:
        steps.append(Convert2DTo3DTransform())

    if cfg.get("do_gamma", False):
        steps.append(GammaTransform(cfg.get("gamma_range"), False, True,
                                    retain_stats=cfg.get("gamma_retain_stats"),
                                    p_per_sample=cfg["p_gamma"]))
    if cfg.get("do_mirror", False):
        steps.append(MirrorTransform(cfg.get("mirror_axes")))

    mask_spec = cfg.get("use_mask_for_norm")
    if mask_spec:
        steps.append(MaskTransform(mask_spec, mask_idx_in_seg=0, set_outside_to=0))

    steps.extend([
        RemoveLabelTransform(-1, 0),
        RenameTransform('seg', 'target', True),
        NumpyToTensor(['data', 'target'], 'float'),
    ])
    return Compose(steps)
def get_insaneDA_augmentation(dataloader_train, dataloader_val, patch_size,
                              params=default_3D_augmentation_params,
                              border_val_seg=-1, seeds_train=None, seeds_val=None,
                              order_seg=1, order_data=3, deep_supervision_scales=None,
                              soft_ds=False, classes=None, pin_memory=True):
    """Build the "insane DA" train and validation batch generators.

    Training pipeline: optional channel selection -> ``SpatialTransform``
    (3D->2D->3D round trip when ``dummy_2D``) -> noise, blur, brightness,
    contrast, simulated low resolution, inverted gamma, optional gamma ->
    optional mirroring / normalization mask -> ``RemoveLabelTransform(-1, 0)``
    -> optional seg-to-data one-hot move plus cascade augmentations -> rename
    ``seg`` to ``target`` -> optional deep-supervision downsampling ->
    ``NumpyToTensor``. The validation pipeline only does label cleanup,
    channel selection, the one-hot move, renaming, deep supervision and tensor
    conversion.

    Args:
        dataloader_train, dataloader_val: Loaders wrapped by the augmenters.
        patch_size: Output patch size of ``SpatialTransform``.
        params: Augmentation parameter dict.
        border_val_seg: Fill value for seg outside the image after spatial transforms.
        seeds_train, seeds_val, pin_memory: Forwarded to ``MultiThreadedAugmenter``.
        order_seg, order_data: Interpolation orders for seg and data.
        deep_supervision_scales: If given, ``target`` is downsampled to these scales.
        soft_ds: Use ``DownsampleSegForDSTransform3`` (requires ``classes``)
            instead of ``DownsampleSegForDSTransform2``.
        classes: Class list for soft deep supervision.

    Returns:
        ``(batchgenerator_train, batchgenerator_val)``.

    Raises:
        AssertionError: On the legacy ``'mirror'`` key, or ``soft_ds`` without ``classes``.
    """
    assert params.get(
        'mirror') is None, "old version of params, use new keyword do_mirror"

    dummy_2d = bool(params.get("dummy_2D"))
    move_seg_to_data = bool(params.get("move_last_seg_chanel_to_data"))

    def _deep_supervision_transform():
        # Deep-supervision downsampling of ``target`` (None when disabled);
        # shared by the train and validation pipelines.
        if deep_supervision_scales is None:
            return None
        if soft_ds:
            assert classes is not None, "soft_ds requires classes"
            return DownsampleSegForDSTransform3(deep_supervision_scales, 'target',
                                                'target', classes)
        return DownsampleSegForDSTransform2(deep_supervision_scales, 0, 0,
                                            input_key='target', output_key='target')

    # ---------------- training pipeline ----------------
    tr_transforms = []
    if params.get("selected_data_channels") is not None:
        tr_transforms.append(DataChannelSelectionTransform(params.get("selected_data_channels")))
    if params.get("selected_seg_channels") is not None:
        tr_transforms.append(SegChannelSelectionTransform(params.get("selected_seg_channels")))

    # don't do color augmentations while in 2d mode with 3d data because the color channel is overloaded!!
    if dummy_2d:
        ignore_axes = (0, )
        tr_transforms.append(Convert3DTo2DTransform())
    else:
        ignore_axes = None

    tr_transforms.append(SpatialTransform(
        patch_size, patch_center_dist_from_border=None,
        do_elastic_deform=params.get("do_elastic"),
        alpha=params.get("elastic_deform_alpha"),
        sigma=params.get("elastic_deform_sigma"),
        do_rotation=params.get("do_rotation"),
        angle_x=params.get("rotation_x"),
        angle_y=params.get("rotation_y"),
        angle_z=params.get("rotation_z"),
        do_scale=params.get("do_scaling"),
        scale=params.get("scale_range"),
        border_mode_data=params.get("border_mode_data"),
        border_cval_data=0,
        order_data=order_data,
        border_mode_seg="constant",
        border_cval_seg=border_val_seg,
        order_seg=order_seg,
        random_crop=params.get("random_crop"),
        p_el_per_sample=params.get("p_eldef"),
        p_scale_per_sample=params.get("p_scale"),
        p_rot_per_sample=params.get("p_rot"),
        independent_scale_for_each_axis=params.get("independent_scale_factor_for_each_axis")))

    if dummy_2d:
        tr_transforms.append(Convert2DTo3DTransform())

    # Color augmentations go after the dummy-2D part (if applicable); otherwise
    # the overloaded color channel gets in the way.
    tr_transforms.append(GaussianNoiseTransform(p_per_sample=0.15))
    tr_transforms.append(GaussianBlurTransform((0.5, 1.5), different_sigma_per_channel=True,
                                               p_per_sample=0.2, p_per_channel=0.5))
    tr_transforms.append(BrightnessMultiplicativeTransform(multiplier_range=(0.70, 1.3),
                                                           p_per_sample=0.15))
    tr_transforms.append(ContrastAugmentationTransform(contrast_range=(0.65, 1.5),
                                                       p_per_sample=0.15))
    tr_transforms.append(SimulateLowResolutionTransform(zoom_range=(0.5, 1), per_channel=True,
                                                        p_per_channel=0.5, order_downsample=0,
                                                        order_upsample=3, p_per_sample=0.25,
                                                        ignore_axes=ignore_axes))
    # inverted gamma
    tr_transforms.append(GammaTransform(params.get("gamma_range"), True, True,
                                        retain_stats=params.get("gamma_retain_stats"),
                                        p_per_sample=0.15))
    if params.get("do_gamma"):
        tr_transforms.append(GammaTransform(params.get("gamma_range"), False, True,
                                            retain_stats=params.get("gamma_retain_stats"),
                                            p_per_sample=params["p_gamma"]))

    if params.get("do_mirror") or params.get("mirror"):
        tr_transforms.append(MirrorTransform(params.get("mirror_axes")))

    if params.get("mask_was_used_for_normalization") is not None:
        tr_transforms.append(MaskTransform(params.get("mask_was_used_for_normalization"),
                                           mask_idx_in_seg=0, set_outside_to=0))

    tr_transforms.append(RemoveLabelTransform(-1, 0))

    if move_seg_to_data:
        tr_transforms.append(MoveSegAsOneHotToData(1, params.get("all_segmentation_labels"),
                                                   'seg', 'data'))
        # The cascade augmentations act on the one-hot channels just appended to
        # ``data`` (negative channel indices), so they only apply after the move.
        if params.get("cascade_do_cascade_augmentations"):
            num_labels = len(params.get("all_segmentation_labels"))
            # ``or 0`` treats a missing probability as "disabled" instead of
            # raising TypeError on ``None > 0``.
            if (params.get("cascade_random_binary_transform_p") or 0) > 0:
                tr_transforms.append(ApplyRandomBinaryOperatorTransform(
                    channel_idx=list(range(-num_labels, 0)),
                    p_per_sample=params.get("cascade_random_binary_transform_p"),
                    key="data",
                    strel_size=params.get("cascade_random_binary_transform_size")))
            if (params.get("cascade_remove_conn_comp_p") or 0) > 0:
                # Each argument is read from the params key of the same meaning
                # (the two were previously swapped).
                tr_transforms.append(RemoveRandomConnectedComponentFromOneHotEncodingTransform(
                    channel_idx=list(range(-num_labels, 0)),
                    key="data",
                    p_per_sample=params.get("cascade_remove_conn_comp_p"),
                    fill_with_other_class_p=params.get(
                        "cascade_remove_conn_comp_fill_with_other_class_p"),
                    dont_do_if_covers_more_than_X_percent=params.get(
                        "cascade_remove_conn_comp_max_size_percent_threshold")))

    tr_transforms.append(RenameTransform('seg', 'target', True))
    ds_transform = _deep_supervision_transform()
    if ds_transform is not None:
        tr_transforms.append(ds_transform)
    tr_transforms.append(NumpyToTensor(['data', 'target'], 'float'))
    batchgenerator_train = MultiThreadedAugmenter(
        dataloader_train, Compose(tr_transforms), params.get('num_threads'),
        params.get("num_cached_per_thread"), seeds=seeds_train, pin_memory=pin_memory)

    # ---------------- validation pipeline ----------------
    val_transforms = [RemoveLabelTransform(-1, 0)]
    if params.get("selected_data_channels") is not None:
        val_transforms.append(DataChannelSelectionTransform(params.get("selected_data_channels")))
    if params.get("selected_seg_channels") is not None:
        val_transforms.append(SegChannelSelectionTransform(params.get("selected_seg_channels")))
    if move_seg_to_data:
        val_transforms.append(MoveSegAsOneHotToData(1, params.get("all_segmentation_labels"),
                                                    'seg', 'data'))
    val_transforms.append(RenameTransform('seg', 'target', True))
    ds_transform = _deep_supervision_transform()
    if ds_transform is not None:
        val_transforms.append(ds_transform)
    val_transforms.append(NumpyToTensor(['data', 'target'], 'float'))
    batchgenerator_val = MultiThreadedAugmenter(
        dataloader_val, Compose(val_transforms), max(params.get('num_threads') // 2, 1),
        params.get("num_cached_per_thread"), seeds=seeds_val, pin_memory=pin_memory)
    return batchgenerator_train, batchgenerator_val
def get_transforms(mode="train", n_channels=1, target_size=128, add_resize=False,
                   add_noise=False, mask_type="", batch_size=16, rotate=True,
                   elastic_deform=True, rnd_crop=False, color_augment=True):
    """Return the composed batchgenerators pipeline for ``mode``.

    ``"train"``: pad to ``target_size + 5``, resize to ``target_size + 1``,
    mirror along axis 2, then a ``SpatialTransform`` to ``target_size`` between
    two reshapes. After that come an optional multiplicative brightness
    change, light Gaussian noise and clipping to [-1.5, 1.5].
    ``"val"``: pad, resize, center-crop to ``target_size``, clip, and copy
    ``data`` into ``data_clean``.
    Any other mode produces only the trailing ``NumpyToTensor``.

    When ``add_noise`` is set, an extra noise stage is appended before tensor
    conversion. In train mode it is strong Gaussian noise when
    ``mask_type == "gaussian"``; in val mode it is empty.
    ``add_resize`` is accepted but not read.
    """
    padded = target_size + 5
    resized = target_size + 1
    pipeline = []
    extra_noise = []

    if mode == "train":
        pipeline.extend([
            FillupPadTransform(min_size=(n_channels, padded, padded)),
            ResizeTransform(target_size=(resized, resized), order=1, concatenate_list=True),
            MirrorTransform(axes=(2, )),
            ReshapeTransform(new_shape=(1, -1, "h", "w")),
            SpatialTransform(patch_size=(target_size, target_size),
                             random_crop=rnd_crop,
                             patch_center_dist_from_border=target_size // 2,
                             do_elastic_deform=elastic_deform,
                             alpha=(0., 100.),
                             sigma=(10., 13.),
                             do_rotation=rotate,
                             angle_x=(-0.1, 0.1),
                             angle_y=(0, 1e-8),
                             angle_z=(0, 1e-8),
                             scale=(0.9, 1.2),
                             border_mode_data="nearest",
                             border_mode_seg="nearest"),
            ReshapeTransform(new_shape=(batch_size, -1, "h", "w")),
        ])
        if color_augment:
            pipeline.append(BrightnessMultiplicativeTransform(multiplier_range=(0.95, 1.1)))
        pipeline.extend([
            GaussianNoiseTransform(noise_variance=(0., 0.05)),
            ClipValueRange(min=-1.5, max=1.5),
        ])
        if mask_type == "gaussian":
            extra_noise.append(GaussianNoiseTransform(noise_variance=(0., 0.2)))
    elif mode == "val":
        pipeline.extend([
            FillupPadTransform(min_size=(n_channels, padded, padded)),
            ResizeTransform(target_size=(resized, resized), order=1, concatenate_list=True),
            CenterCropTransform(crop_size=(target_size, target_size)),
            ClipValueRange(min=-1.5, max=1.5),
            CopyTransform({"data": "data_clean"}, copy=True),
        ])

    if add_noise:
        pipeline.extend(extra_noise)
    pipeline.append(NumpyToTensor())
    return Compose(pipeline)
def run(self, img_data, seg_data):
    """Augment ``img_data``/``seg_data`` ``self.cycles`` times and stack the results.

    The pipeline is assembled from the instance switches ``mirror``,
    ``contrast``, ``brightness``, ``gamma`` and ``gaussian_noise``, and the
    spatial group ``rotations``/``scaling``/``elastic_deform``/``cropping``,
    together with their ``config_*`` settings. It is driven by a
    ``SingleThreadedAugmenter``.

    Args:
        img_data: Image data handed to ``DataParser``.
        seg_data: Segmentation data handed to ``DataParser``.

    Returns:
        ``(aug_img_data, aug_seg_data)``: every cycle's output concatenated
        along axis 0, most recent cycle first, with axis 1 moved to the end
        (channel-first -> channel-last).

    Raises:
        ValueError: If ``self.cycles`` is smaller than 1 (previously this
            surfaced as an obscure ``np.moveaxis(None, ...)`` failure).
    """
    if self.cycles < 1:
        raise ValueError(f"cycles must be >= 1, got {self.cycles}")

    # Create a parser for the batchgenerators module
    data_generator = DataParser(img_data, seg_data)

    transforms = []
    if self.mirror:
        transforms.append(MirrorTransform(axes=self.config_mirror_axes))
    if self.contrast:
        transforms.append(ContrastAugmentationTransform(
            self.config_contrast_range, preserve_range=True, per_channel=True,
            p_per_sample=self.config_p_per_sample))
    if self.brightness:
        transforms.append(BrightnessMultiplicativeTransform(
            self.config_brightness_range, per_channel=True,
            p_per_sample=self.config_p_per_sample))
    if self.gamma:
        transforms.append(GammaTransform(
            self.config_gamma_range, invert_image=False, per_channel=True,
            retain_stats=True, p_per_sample=self.config_p_per_sample))
    if self.gaussian_noise:
        transforms.append(GaussianNoiseTransform(
            self.config_gaussian_noise_range, p_per_sample=self.config_p_per_sample))

    # Spatial transformations (rotation, scaling, elastic deformation, cropping)
    if self.rotations or self.scaling or self.elastic_deform or self.cropping:
        # Patch shape: explicit crop size, or the full image shape of the first
        # sample without its last axis.
        if self.cropping:
            patch_shape = self.cropping_patch_shape
        else:
            patch_shape = img_data[0].shape[0:-1]
        transforms.append(SpatialTransform(
            patch_shape, [i // 2 for i in patch_shape],
            do_elastic_deform=self.elastic_deform,
            alpha=self.config_elastic_deform_alpha,
            sigma=self.config_elastic_deform_sigma,
            do_rotation=self.rotations,
            angle_x=self.config_rotations_angleX,
            angle_y=self.config_rotations_angleY,
            angle_z=self.config_rotations_angleZ,
            do_scale=self.scaling,
            scale=self.config_scaling_range,
            border_mode_data='constant', border_cval_data=0,
            border_mode_seg='constant', border_cval_seg=0,
            order_data=3, order_seg=0,
            p_el_per_sample=self.config_p_per_sample,
            p_rot_per_sample=self.config_p_per_sample,
            p_scale_per_sample=self.config_p_per_sample,
            random_crop=self.cropping))

    augmentation_generator = SingleThreadedAugmenter(data_generator, Compose(transforms))

    img_batches = []
    seg_batches = []
    for _ in range(self.cycles):
        augmentation = next(augmentation_generator)
        # Snapshot each cycle's output. NOTE(review): if DataParser returns the
        # same arrays on every call, in-place transforms could otherwise
        # overwrite earlier cycles' results through shared references.
        img_batches.append(np.copy(augmentation["data"]))
        seg_batches.append(np.copy(augmentation["seg"]))

    # Concatenate once (instead of once per cycle, which copied quadratically),
    # newest cycle first to keep the previous output ordering.
    aug_img_data = np.concatenate(img_batches[::-1], axis=0)
    aug_seg_data = np.concatenate(seg_batches[::-1], axis=0)

    # Channel-first (batch, channel, x, y, z) -> channel-last (batch, x, y, z, channel)
    aug_img_data = np.moveaxis(aug_img_data, 1, -1)
    aug_seg_data = np.moveaxis(aug_seg_data, 1, -1)
    return aug_img_data, aug_seg_data
def get_default_augmentation(dataloader_train, dataloader_val, patch_size,
                             params=default_3D_augmentation_params, border_val_seg=-1,
                             pin_memory=True, seeds_train=None, seeds_val=None,
                             regions=None):
    """Build the default train and validation batch generators.

    Training pipeline: optional channel selection -> ``SpatialTransform``
    (3D->2D->3D round trip when ``dummy_2D``) -> optional gamma / mirroring /
    normalization mask -> ``RemoveLabelTransform(-1, 0)`` -> optional
    seg-to-data one-hot move with cascade augmentations -> rename ``seg`` to
    ``target`` -> optional region conversion -> ``NumpyToTensor``.

    Args:
        dataloader_train, dataloader_val: Loaders wrapped by the augmenters.
        patch_size: Output patch size of ``SpatialTransform``.
        params: Augmentation parameter dict.
        border_val_seg: Fill value for seg outside the image after spatial transforms.
        pin_memory, seeds_train, seeds_val: Forwarded to ``MultiThreadedAugmenter``.
        regions: If given, ``target`` is converted with ``ConvertSegmentationToRegionsTransform``.

    Returns:
        ``(batchgenerator_train, batchgenerator_val)``.

    Raises:
        AssertionError: If ``params`` still uses the legacy ``'mirror'`` key.
    """
    assert params.get('mirror') is None, "old version of params, use new keyword do_mirror"

    dummy_2d = bool(params.get("dummy_2D"))
    move_seg_to_data = bool(params.get("move_last_seg_chanel_to_data"))

    # ---------------- training pipeline ----------------
    tr_transforms = []
    if params.get("selected_data_channels") is not None:
        tr_transforms.append(DataChannelSelectionTransform(params.get("selected_data_channels")))
    if params.get("selected_seg_channels") is not None:
        tr_transforms.append(SegChannelSelectionTransform(params.get("selected_seg_channels")))

    # don't do color augmentations while in 2d mode with 3d data because the color channel is overloaded!!
    if dummy_2d:
        tr_transforms.append(Convert3DTo2DTransform())

    tr_transforms.append(SpatialTransform(
        patch_size, patch_center_dist_from_border=None,
        do_elastic_deform=params.get("do_elastic"),
        alpha=params.get("elastic_deform_alpha"),
        sigma=params.get("elastic_deform_sigma"),
        do_rotation=params.get("do_rotation"),
        angle_x=params.get("rotation_x"),
        angle_y=params.get("rotation_y"),
        angle_z=params.get("rotation_z"),
        do_scale=params.get("do_scaling"),
        scale=params.get("scale_range"),
        border_mode_data=params.get("border_mode_data"),
        border_cval_data=0,
        order_data=3,
        border_mode_seg="constant",
        border_cval_seg=border_val_seg,
        order_seg=1,
        random_crop=params.get("random_crop"),
        p_el_per_sample=params.get("p_eldef"),
        p_scale_per_sample=params.get("p_scale"),
        p_rot_per_sample=params.get("p_rot"),
        independent_scale_for_each_axis=params.get("independent_scale_factor_for_each_axis")))

    if dummy_2d:
        tr_transforms.append(Convert2DTo3DTransform())

    if params.get("do_gamma"):
        tr_transforms.append(GammaTransform(params.get("gamma_range"), False, True,
                                            retain_stats=params.get("gamma_retain_stats"),
                                            p_per_sample=params["p_gamma"]))
    if params.get("do_mirror"):
        tr_transforms.append(MirrorTransform(params.get("mirror_axes")))

    if params.get("mask_was_used_for_normalization") is not None:
        tr_transforms.append(MaskTransform(params.get("mask_was_used_for_normalization"),
                                           mask_idx_in_seg=0, set_outside_to=0))

    tr_transforms.append(RemoveLabelTransform(-1, 0))

    if move_seg_to_data:
        tr_transforms.append(MoveSegAsOneHotToData(1, params.get("all_segmentation_labels"),
                                                   'seg', 'data'))
        # The cascade augmentations act on the one-hot channels just appended to
        # ``data`` (negative channel indices), so they only apply after the move.
        if params.get("cascade_do_cascade_augmentations"):
            num_labels = len(params.get("all_segmentation_labels"))
            tr_transforms.append(ApplyRandomBinaryOperatorTransform(
                channel_idx=list(range(-num_labels, 0)),
                p_per_sample=params.get("cascade_random_binary_transform_p"),
                key="data",
                strel_size=params.get("cascade_random_binary_transform_size")))
            # Each argument is read from the params key of the same meaning
            # (the two were previously swapped).
            tr_transforms.append(RemoveRandomConnectedComponentFromOneHotEncodingTransform(
                channel_idx=list(range(-num_labels, 0)),
                key="data",
                p_per_sample=params.get("cascade_remove_conn_comp_p"),
                fill_with_other_class_p=params.get(
                    "cascade_remove_conn_comp_fill_with_other_class_p"),
                dont_do_if_covers_more_than_X_percent=params.get(
                    "cascade_remove_conn_comp_max_size_percent_threshold")))

    tr_transforms.append(RenameTransform('seg', 'target', True))
    if regions is not None:
        tr_transforms.append(ConvertSegmentationToRegionsTransform(regions, 'target', 'target'))
    tr_transforms.append(NumpyToTensor(['data', 'target'], 'float'))

    batchgenerator_train = MultiThreadedAugmenter(
        dataloader_train, Compose(tr_transforms), params.get('num_threads'),
        params.get("num_cached_per_thread"), seeds=seeds_train, pin_memory=pin_memory)

    # ---------------- validation pipeline ----------------
    val_transforms = [RemoveLabelTransform(-1, 0)]
    if params.get("selected_data_channels") is not None:
        val_transforms.append(DataChannelSelectionTransform(params.get("selected_data_channels")))
    if params.get("selected_seg_channels") is not None:
        val_transforms.append(SegChannelSelectionTransform(params.get("selected_seg_channels")))
    if move_seg_to_data:
        val_transforms.append(MoveSegAsOneHotToData(1, params.get("all_segmentation_labels"),
                                                    'seg', 'data'))
    val_transforms.append(RenameTransform('seg', 'target', True))
    if regions is not None:
        val_transforms.append(ConvertSegmentationToRegionsTransform(regions, 'target', 'target'))
    val_transforms.append(NumpyToTensor(['data', 'target'], 'float'))

    batchgenerator_val = MultiThreadedAugmenter(
        dataloader_val, Compose(val_transforms), max(params.get('num_threads') // 2, 1),
        params.get("num_cached_per_thread"), seeds=seeds_val, pin_memory=pin_memory)
    return batchgenerator_train, batchgenerator_val
# NOTE(review): module-level script code. As visible here, it runs on import,
# needs the BraTS_2018_BASE environment variable and project-specific packages,
# and starts augmenter workers. Confirm it sits under an ``if __name__ ==
# "__main__":`` guard (indentation was lost in this view).

# Augmenter over a dummy loader with no transform, 3 workers, 1 cached batch.
# Both ``dl`` and ``mt`` are rebound below without this instance being used
# within the visible fragment.
dl = DummyDL(num_threads_in_mt=3)
mt = MultiThreadedAugmenter(dl, None, 3, 1, None)

# ignore this code. this is work in progress
from BraTS2018.dataset_loading.load_dataset import load_dataset
from meddec.dataloading.dataset_loading import DataLoader3D
import matplotlib.pyplot as plt
import os

dataset = load_dataset(
    os.path.join(os.environ['BraTS_2018_BASE'], "BraTS2018"))
dl = DataLoader3D(dataset, (128, 128, 128), (64, 64, 64), 2)
#tr = GaussianBlurTransform(3)
# Scale range (0.6, 0.60000001) pins the scale factor at ~0.6.
tr = SpatialTransform((128, 128, 128), (64, 64, 64), False, do_rotation=False, do_scale=True, scale=(0.6, 0.60000001))
from time import time
# ``num_batches`` and ``plt`` are not used within this visible fragment.
num_batches = 10
num_threads = 8
mt = MultiThreadedAugmenter(dl, tr, num_threads, 2)
# warm up: time the first 6 batches
warum_up_times_old = []
for _ in range(6):
    a = time()
    b = next(mt)
    warum_up_times_old.append(time() - a)
def create_data_gen_train(patient_data_train, BATCH_SIZE, num_classes, patch_size,
                          num_workers=5, num_cached_per_worker=2,
                          do_elastic_transform=False, alpha=(0., 1300.), sigma=(10., 13.),
                          do_rotation=False, a_x=(0., 2 * np.pi), a_y=(0., 2 * np.pi),
                          a_z=(0., 2 * np.pi), do_scale=True, scale_range=(0.75, 1.25),
                          seeds=None):
    """Build and start the multithreaded training generator (2D augmentation of 3D stacks).

    Pipeline: motion augmentation -> mirroring over axes (3, 4) -> 3D->2D ->
    ``SpatialTransform`` with probability 0.67 (otherwise a random crop to
    ``patch_size[1:]``) -> 2D->3D -> gamma and inverted gamma (each with
    probability 0.5) -> outlier clipping at the 0.3/99.7 percentiles ->
    zero-mean/unit-variance normalization -> one-hot seg into ``'seg_onehot'``.

    Args:
        patient_data_train: Training data handed to ``BatchGenerator``.
        BATCH_SIZE: Samples per batch.
        num_classes: Number of label classes for the one-hot encoding.
        patch_size: Output patch size; ``patch_size[1:]`` is the in-plane size.
        num_workers, num_cached_per_worker: ``MultiThreadedAugmenter`` settings.
        do_elastic_transform, alpha, sigma, do_rotation, a_x, a_y, a_z,
        do_scale, scale_range: ``SpatialTransform`` settings.
        seeds: ``None``, the string ``'range'`` or a sequence of exactly
            ``num_workers`` seeds.

    Returns:
        The restarted ``MultiThreadedAugmenter``.

    Raises:
        ValueError: If ``seeds`` is an unknown string or has the wrong length.
    """
    if seeds is None:
        seeds = [None] * num_workers
    elif isinstance(seeds, str):
        # Explicit type test: ``seeds == 'range'`` is element-wise for NumPy
        # arrays, and arbitrary strings of the right length used to slip through.
        if seeds != 'range':
            raise ValueError(
                f"seeds must be None, 'range' or a sequence of {num_workers} seeds, got {seeds!r}")
        seeds = range(num_workers)
    elif len(seeds) != num_workers:
        raise ValueError(f"expected {num_workers} seeds (one per worker), got {len(seeds)}")

    # NOTE(review): the sampled size is fixed at (10, 352, 352) independently of
    # ``patch_size`` — confirm this is intended.
    data_gen_train = BatchGenerator(patient_data_train, BATCH_SIZE, num_batches=None,
                                    seed=False, PATCH_SIZE=(10, 352, 352))

    in_plane = patch_size[1:]
    tr_transforms = [
        MotionAugmentationTransform(0.1, 0, 20),
        MirrorTransform((3, 4)),
        Convert3DTo2DTransform(),
        RndTransform(SpatialTransform(in_plane, 112, do_elastic_transform, alpha, sigma,
                                      do_rotation, a_x, a_y, a_z, do_scale, scale_range,
                                      'constant', 0, 3, 'constant', 0, 0,
                                      random_crop=False),
                     prob=0.67,
                     alternative_transform=RandomCropTransform(in_plane)),
        Convert2DTo3DTransform(patch_size),
        RndTransform(GammaTransform((0.85, 1.3), False), prob=0.5),
        RndTransform(GammaTransform((0.85, 1.3), True), prob=0.5),
        CutOffOutliersTransform(0.3, 99.7, True),
        ZeroMeanUnitVarianceTransform(True),
        ConvertSegToOnehotTransform(range(num_classes), 0, 'seg_onehot'),
    ]
    tr_mt_gen = MultiThreadedAugmenter(data_gen_train, Compose(tr_transforms),
                                       num_workers, num_cached_per_worker, seeds)
    tr_mt_gen.restart()
    return tr_mt_gen
def get_default_augmentation(dataloader_train, dataloader_val=None, params=None, patch_size=None,
                             border_val_seg=-1, pin_memory=True, seeds_train=None, seeds_val=None,
                             regions=None):
    """Build the multi-threaded train (and optionally validation) batch generators.

    Training pipeline, in order:

    * optional data/seg channel selection;
    * optional 3D->2D conversion (``dummy_2D``);
    * ``SpatialTransform`` configured from ``params`` (elastic deformation,
      rotation, scaling, random crop and their probabilities);
    * optional 2D->3D back-conversion;
    * optional gamma (``do_gamma``) and mirroring (``do_mirror``);
    * ``RemoveLabelTransform(-1, 0)``;
    * renaming ``seg`` to ``target``;
    * optional region conversion;
    * tensor conversion.

    The validation pipeline only removes labels, selects channels, renames,
    converts regions and converts to tensors.

    Args:
        dataloader_train: Loader feeding the training augmenter.
        dataloader_val: Optional loader for validation. If ``None``, no
            validation generator is built.
        params: Augmentation parameter dict. Required despite the ``None``
            default, and must use ``do_mirror`` rather than the legacy
            ``mirror`` key.
        patch_size: Output patch size for ``SpatialTransform``.
        border_val_seg: Constant used for the segmentation outside the image.
        pin_memory: Forwarded to ``MultiThreadedAugmenter``.
        seeds_train, seeds_val: Per-worker seeds for the two augmenters.
        regions: If not ``None``, targets are converted to these regions.

    Returns:
        ``(batchgenerator_train, batchgenerator_val)``. The second item is
        ``None`` when ``dataloader_val`` is ``None``. Returned generators have
        been restarted.

    Raises:
        AssertionError: If ``params`` is ``None`` or contains ``mirror``.
    """
    # Check for None first: params defaults to None, and calling params.get()
    # on None would raise AttributeError instead of this helpful message.
    assert params is not None, "augmentation params expect to be not None"
    assert params.get('mirror') is None, "old version of params, use new keyword do_mirror"

    dummy_2d = bool(params.get("dummy_2D"))
    selected_data_channels = params.get("selected_data_channels")
    selected_seg_channels = params.get("selected_seg_channels")

    tr_transforms = []
    if selected_data_channels is not None:
        tr_transforms.append(DataChannelSelectionTransform(selected_data_channels))
    if selected_seg_channels is not None:
        tr_transforms.append(SegChannelSelectionTransform(selected_seg_channels))

    # don't do color augmentations while in 2d mode with 3d data because the color channel is overloaded!!
    if dummy_2d:
        tr_transforms.append(Convert3DTo2DTransform())
    tr_transforms.append(SpatialTransform(
        patch_size, patch_center_dist_from_border=None, do_elastic_deform=params.get("do_elastic"),
        alpha=params.get("elastic_deform_alpha"), sigma=params.get("elastic_deform_sigma"),
        do_rotation=params.get("do_rotation"), angle_x=params.get("rotation_x"),
        angle_y=params.get("rotation_y"), angle_z=params.get("rotation_z"),
        do_scale=params.get("do_scaling"), scale=params.get("scale_range"),
        border_mode_data=params.get("border_mode_data"), border_cval_data=0, order_data=3,
        border_mode_seg="constant", border_cval_seg=border_val_seg, order_seg=1,
        random_crop=params.get("random_crop"), p_el_per_sample=params.get("p_eldef"),
        p_scale_per_sample=params.get("p_scale"), p_rot_per_sample=params.get("p_rot"),
        independent_scale_for_each_axis=params.get("independent_scale_factor_for_each_axis")
    ))
    if dummy_2d:
        tr_transforms.append(Convert2DTo3DTransform())

    if params.get("do_gamma"):
        tr_transforms.append(
            GammaTransform(params.get("gamma_range"), False, True,
                           retain_stats=params.get("gamma_retain_stats"),
                           p_per_sample=params["p_gamma"]))
    if params.get("do_mirror"):
        tr_transforms.append(MirrorTransform(params.get("mirror_axes")))

    tr_transforms.append(RemoveLabelTransform(-1, 0))
    tr_transforms.append(RenameTransform('seg', 'target', True))
    if regions is not None:
        tr_transforms.append(ConvertSegmentationToRegionsTransform(regions, 'target', 'target'))
    tr_transforms.append(NumpyToTensor(['data', 'target'], 'float'))
    tr_transforms = Compose(tr_transforms)

    batchgenerator_train = MultiThreadedAugmenter(dataloader_train, tr_transforms, params.get('num_threads'),
                                                  params.get("num_cached_per_thread"),
                                                  seeds=seeds_train, pin_memory=pin_memory)
    batchgenerator_train.restart()

    if dataloader_val is None:
        return batchgenerator_train, None

    val_transforms = [RemoveLabelTransform(-1, 0)]
    if selected_data_channels is not None:
        val_transforms.append(DataChannelSelectionTransform(selected_data_channels))
    if selected_seg_channels is not None:
        val_transforms.append(SegChannelSelectionTransform(selected_seg_channels))
    val_transforms.append(RenameTransform('seg', 'target', True))
    if regions is not None:
        val_transforms.append(ConvertSegmentationToRegionsTransform(regions, 'target', 'target'))
    val_transforms.append(NumpyToTensor(['data', 'target'], 'float'))
    val_transforms = Compose(val_transforms)

    # Validation uses half the training worker count, but at least one worker.
    batchgenerator_val = MultiThreadedAugmenter(dataloader_val, val_transforms,
                                                max(params.get('num_threads') // 2, 1),
                                                params.get("num_cached_per_thread"),
                                                seeds=seeds_val, pin_memory=pin_memory)
    batchgenerator_val.restart()
    return batchgenerator_train, batchgenerator_val
dataset_train = dset.get_subset(idx_train) dataset_val = dset.get_subset(idx_val) dataset_test = TorchvisionClassificationDataset("mnist", train=False, img_shape=(224, 224), ) ##################### # Augmentation # ##################### base_transforms = [ZeroMeanUnitVarianceTransform(), ] train_transforms = [SpatialTransform(patch_size=(200, 200), random_crop=False, ), ] ##################### # Datamanagers # ##################### manager_train = BaseDataManager(dataset_train, params.nested_get("batch_size"), transforms=Compose( base_transforms + train_transforms), sampler_cls=RandomSampler, n_process_augmentation=n_process_augmentation) manager_val = BaseDataManager(dataset_val, params.nested_get("batch_size"), transforms=Compose(base_transforms), sampler_cls=SequentialSampler,
def get_insaneDA_augmentation(dataloader_train, dataloader_val, patch_size,
                              params=default_3D_augmentation_params, border_val_seg=-1,
                              seeds_train=None, seeds_val=None, order_seg=1, order_data=3,
                              deep_supervision_scales=None, soft_ds=False, classes=None,
                              pin_memory=True, regions=None):
    """Build heavily augmented train and lightly processed validation generators.

    Training pipeline, in order:

    * optional channel selection;
    * optional 3D->2D conversion (``dummy_2D``);
    * ``SpatialTransform`` configured from ``params``;
    * optional 2D->3D back-conversion;
    * fixed-probability intensity augmentations: Gaussian noise, Gaussian
      blur, multiplicative brightness, contrast, simulated low resolution and
      inverted gamma;
    * optional additive brightness and gamma;
    * optional mirroring;
    * optional masking;
    * label cleanup;
    * optional cascade augmentations on the one-hot segmentation channels
      in ``data``;
    * renaming ``seg`` to ``target``;
    * optional region conversion;
    * optional deep-supervision downsampling;
    * tensor conversion.

    The validation pipeline performs only the label/channel handling, renaming,
    region conversion, deep-supervision downsampling and tensor conversion.

    Args:
        dataloader_train, dataloader_val: Loaders feeding the two augmenters.
        patch_size: Output patch size for ``SpatialTransform``.
        params: Augmentation parameter dict. Must use ``do_mirror`` rather
            than the legacy ``mirror`` key.
        border_val_seg: Constant used for the segmentation outside the image.
        seeds_train, seeds_val: Per-worker seeds for the augmenters.
        order_seg, order_data: Interpolation orders for seg / data.
        deep_supervision_scales: If not ``None``, targets are downsampled to
            these scales.
        soft_ds: Use ``DownsampleSegForDSTransform3`` (requires ``classes``)
            instead of ``DownsampleSegForDSTransform2``.
        classes: Class list for soft deep supervision.
        pin_memory: Forwarded to ``MultiThreadedAugmenter``.
        regions: If not ``None``, targets are converted to these regions.

    Returns:
        ``(batchgenerator_train, batchgenerator_val)``. Neither generator is
        restarted here.

    Raises:
        AssertionError: If ``params`` contains ``mirror``, or if
            ``soft_ds`` is requested without ``classes``.
    """
    assert params.get(
        'mirror') is None, "old version of params, use new keyword do_mirror"

    def _deep_supervision_transform():
        # Downsample the target for deep supervision; soft targets need the class list.
        if soft_ds:
            assert classes is not None
            return DownsampleSegForDSTransform3(deep_supervision_scales, 'target', 'target', classes)
        return DownsampleSegForDSTransform2(deep_supervision_scales, 0, 0,
                                            input_key='target', output_key='target')

    dummy_2d = bool(params.get("dummy_2D"))
    tr_transforms = []

    if params.get("selected_data_channels") is not None:
        tr_transforms.append(
            DataChannelSelectionTransform(params.get("selected_data_channels")))
    if params.get("selected_seg_channels") is not None:
        tr_transforms.append(
            SegChannelSelectionTransform(params.get("selected_seg_channels")))

    # don't do color augmentations while in 2d mode with 3d data because the color channel is overloaded!!
    if dummy_2d:
        # In dummy-2D mode the first spatial axis is not resampled by the low-res simulation.
        ignore_axes = (0, )
        tr_transforms.append(Convert3DTo2DTransform())
    else:
        ignore_axes = None

    tr_transforms.append(
        SpatialTransform(patch_size,
                         patch_center_dist_from_border=None,
                         do_elastic_deform=params.get("do_elastic"),
                         alpha=params.get("elastic_deform_alpha"),
                         sigma=params.get("elastic_deform_sigma"),
                         do_rotation=params.get("do_rotation"),
                         angle_x=params.get("rotation_x"),
                         angle_y=params.get("rotation_y"),
                         angle_z=params.get("rotation_z"),
                         do_scale=params.get("do_scaling"),
                         scale=params.get("scale_range"),
                         border_mode_data=params.get("border_mode_data"),
                         border_cval_data=0,
                         order_data=order_data,
                         border_mode_seg="constant",
                         border_cval_seg=border_val_seg,
                         order_seg=order_seg,
                         random_crop=params.get("random_crop"),
                         p_el_per_sample=params.get("p_eldef"),
                         p_scale_per_sample=params.get("p_scale"),
                         p_rot_per_sample=params.get("p_rot"),
                         independent_scale_for_each_axis=params.get(
                             "independent_scale_factor_for_each_axis"),
                         p_independent_scale_per_axis=params.get(
                             "p_independent_scale_per_axis")))
    if dummy_2d:
        tr_transforms.append(Convert2DTo3DTransform())

    # we need to put the color augmentations after the dummy 2d part (if applicable). Otherwise the overloaded color
    # channel gets in the way
    tr_transforms.append(GaussianNoiseTransform(p_per_sample=0.15))
    tr_transforms.append(
        GaussianBlurTransform((0.5, 1.5), different_sigma_per_channel=True,
                              p_per_sample=0.2, p_per_channel=0.5))
    tr_transforms.append(
        BrightnessMultiplicativeTransform(multiplier_range=(0.70, 1.3), p_per_sample=0.15))
    tr_transforms.append(
        ContrastAugmentationTransform(contrast_range=(0.65, 1.5), p_per_sample=0.15))
    tr_transforms.append(
        SimulateLowResolutionTransform(zoom_range=(0.5, 1), per_channel=True,
                                       p_per_channel=0.5, order_downsample=0,
                                       order_upsample=3, p_per_sample=0.25,
                                       ignore_axes=ignore_axes))
    # inverted gamma
    tr_transforms.append(
        GammaTransform(params.get("gamma_range"), True, True,
                       retain_stats=params.get("gamma_retain_stats"),
                       p_per_sample=0.15))

    if params.get("do_additive_brightness"):
        tr_transforms.append(
            BrightnessTransform(
                params.get("additive_brightness_mu"),
                params.get("additive_brightness_sigma"),
                True,
                p_per_sample=params.get("additive_brightness_p_per_sample"),
                p_per_channel=params.get("additive_brightness_p_per_channel")))

    if params.get("do_gamma"):
        tr_transforms.append(
            GammaTransform(params.get("gamma_range"), False, True,
                           retain_stats=params.get("gamma_retain_stats"),
                           p_per_sample=params["p_gamma"]))

    if params.get("do_mirror") or params.get("mirror"):
        tr_transforms.append(MirrorTransform(params.get("mirror_axes")))

    if params.get("mask_was_used_for_normalization") is not None:
        mask_was_used_for_normalization = params.get("mask_was_used_for_normalization")
        tr_transforms.append(
            MaskTransform(mask_was_used_for_normalization, mask_idx_in_seg=0, set_outside_to=0))

    tr_transforms.append(RemoveLabelTransform(-1, 0))

    if params.get("move_last_seg_chanel_to_data"):
        tr_transforms.append(
            MoveSegAsOneHotToData(1, params.get("all_segmentation_labels"), 'seg', 'data'))

    # The old condition `x and not None and x` reduced to the truthiness of x.
    if params.get("cascade_do_cascade_augmentations"):
        if params.get("cascade_random_binary_transform_p") > 0:
            tr_transforms.append(
                ApplyRandomBinaryOperatorTransform(
                    channel_idx=list(range(-len(params.get("all_segmentation_labels")), 0)),
                    p_per_sample=params.get("cascade_random_binary_transform_p"),
                    key="data",
                    strel_size=params.get("cascade_random_binary_transform_size")))
        if params.get("cascade_remove_conn_comp_p") > 0:
            tr_transforms.append(
                RemoveRandomConnectedComponentFromOneHotEncodingTransform(
                    channel_idx=list(range(-len(params.get("all_segmentation_labels")), 0)),
                    key="data",
                    p_per_sample=params.get("cascade_remove_conn_comp_p"),
                    # Each keyword takes the params entry with the matching name.
                    fill_with_other_class_p=params.get(
                        "cascade_remove_conn_comp_fill_with_other_class_p"),
                    dont_do_if_covers_more_than_X_percent=params.get(
                        "cascade_remove_conn_comp_max_size_percent_threshold")))

    tr_transforms.append(RenameTransform('seg', 'target', True))
    if regions is not None:
        tr_transforms.append(
            ConvertSegmentationToRegionsTransform(regions, 'target', 'target'))
    if deep_supervision_scales is not None:
        tr_transforms.append(_deep_supervision_transform())
    tr_transforms.append(NumpyToTensor(['data', 'target'], 'float'))
    tr_transforms = Compose(tr_transforms)

    batchgenerator_train = MultiThreadedAugmenter(
        dataloader_train, tr_transforms, params.get('num_threads'),
        params.get("num_cached_per_thread"), seeds=seeds_train, pin_memory=pin_memory)

    val_transforms = [RemoveLabelTransform(-1, 0)]
    if params.get("selected_data_channels") is not None:
        val_transforms.append(
            DataChannelSelectionTransform(params.get("selected_data_channels")))
    if params.get("selected_seg_channels") is not None:
        val_transforms.append(
            SegChannelSelectionTransform(params.get("selected_seg_channels")))
    if params.get("move_last_seg_chanel_to_data"):
        val_transforms.append(
            MoveSegAsOneHotToData(1, params.get("all_segmentation_labels"), 'seg', 'data'))
    val_transforms.append(RenameTransform('seg', 'target', True))
    if regions is not None:
        val_transforms.append(
            ConvertSegmentationToRegionsTransform(regions, 'target', 'target'))
    if deep_supervision_scales is not None:
        val_transforms.append(_deep_supervision_transform())
    val_transforms.append(NumpyToTensor(['data', 'target'], 'float'))
    val_transforms = Compose(val_transforms)

    # Validation uses half the training worker count, but at least one worker.
    batchgenerator_val = MultiThreadedAugmenter(
        dataloader_val, val_transforms, max(params.get('num_threads') // 2, 1),
        params.get("num_cached_per_thread"), seeds=seeds_val, pin_memory=pin_memory)
    return batchgenerator_train, batchgenerator_val
def Transforms(patch_size, params=default_3D_augmentation_params, border_val_seg=-1):
    """Build the composed training transform pipeline described by ``params``.

    Pipeline, in order:

    * optional channel selection;
    * optional 3D->2D conversion (``dummy_2D``);
    * ``SpatialTransform``;
    * optional 2D->3D back-conversion;
    * optional gamma;
    * mirroring along ``mirror_axes`` (always applied here, independent of
      ``do_mirror``);
    * optional masking;
    * ``RemoveLabelTransform(-1, 0)``;
    * optional move of the segmentation into ``data``;
    * optional "advanced pyramid" augmentations on the one-hot channels in
      ``data``;
    * renaming ``seg`` to ``target``;
    * tensor conversion.

    Args:
        patch_size: Output patch size for ``SpatialTransform``.
        params: Augmentation parameter dict.
        border_val_seg: Constant used for the segmentation outside the image.

    Returns:
        A ``Compose`` of the configured transforms.
    """
    dummy_2d = bool(params.get("dummy_2D"))
    tr_transforms = []

    if params.get("selected_data_channels") is not None:
        tr_transforms.append(
            DataChannelSelectionTransform(params.get("selected_data_channels"), data_key="data"))
    if params.get("selected_seg_channels") is not None:
        tr_transforms.append(
            SegChannelSelectionTransform(params.get("selected_seg_channels")))

    # don't do color augmentations while in 2d mode with 3d data because the color channel is overloaded!!
    if dummy_2d:
        tr_transforms.append(Convert3DTo2DTransform())
    tr_transforms.append(
        SpatialTransform(patch_size,
                         patch_center_dist_from_border=None,
                         do_elastic_deform=params.get("do_elastic"),
                         alpha=params.get("elastic_deform_alpha"),
                         sigma=params.get("elastic_deform_sigma"),
                         do_rotation=params.get("do_rotation"),
                         angle_x=params.get("rotation_x"),
                         angle_y=params.get("rotation_y"),
                         angle_z=params.get("rotation_z"),
                         do_scale=params.get("do_scaling"),
                         scale=params.get("scale_range"),
                         border_mode_data=params.get("border_mode_data"),
                         border_cval_data=0,
                         order_data=3,
                         border_mode_seg="constant",
                         border_cval_seg=border_val_seg,
                         order_seg=1,
                         random_crop=params.get("random_crop"),
                         p_el_per_sample=params.get("p_eldef"),
                         p_scale_per_sample=params.get("p_scale"),
                         p_rot_per_sample=params.get("p_rot")))
    if dummy_2d:
        tr_transforms.append(Convert2DTo3DTransform())

    if params.get("do_gamma"):
        tr_transforms.append(
            GammaTransform(params.get("gamma_range"), False, True,
                           retain_stats=params.get("gamma_retain_stats"),
                           p_per_sample=params["p_gamma"]))

    # Mirroring is unconditional in this pipeline (``do_mirror`` is not consulted).
    tr_transforms.append(MirrorTransform(params.get("mirror_axes")))

    if params.get("mask_was_used_for_normalization") is not None:
        mask_was_used_for_normalization = params.get("mask_was_used_for_normalization")
        tr_transforms.append(
            MaskTransform(mask_was_used_for_normalization, mask_idx_in_seg=0, set_outside_to=0))

    tr_transforms.append(RemoveLabelTransform(-1, 0))

    if params.get("move_last_seg_chanel_to_data"):
        tr_transforms.append(
            MoveSegAsOneHotToData(1, params.get("all_segmentation_labels"), 'seg', 'data'))

    # The old condition `x and not None and x` reduced to the truthiness of x.
    if params.get("advanced_pyramid_augmentations"):
        tr_transforms.append(
            ApplyRandomBinaryOperatorTransform(
                channel_idx=list(range(-len(params.get("all_segmentation_labels")), 0)),
                p_per_sample=0.4,
                key="data",
                strel_size=(1, 8)))
        tr_transforms.append(
            RemoveRandomConnectedComponentFromOneHotEncodingTransform(
                channel_idx=list(range(-len(params.get("all_segmentation_labels")), 0)),
                key="data",
                p_per_sample=0.2,
                fill_with_other_class_p=0.0,
                dont_do_if_covers_more_than_X_percent=0.15))

    tr_transforms.append(RenameTransform('seg', 'target', True))
    tr_transforms.append(NumpyToTensor(['data', 'target'], 'float'))
    return Compose(tr_transforms)
SpatialTransform(patch_size, patch_center_dist_from_border=None, extra_label_keys=extra_label_keys, do_translate=False, trans_max_shifts={ 'z': 2, 'y': 2, 'x': 2 }, trans_const_channel=None, do_elastic_deform=params.get("do_elastic"), alpha=params.get("elastic_deform_alpha"), sigma=params.get("elastic_deform_sigma"), do_rotation=params.get("do_rotation"), angle_x=params.get("rotation_x"), angle_y=params.get("rotation_y"), angle_z=params.get("rotation_z"), p_rot_per_axis=params.get("rotation_p_per_axis"), do_scale=params.get("do_scaling"), scale=params.get("scale_range"), border_mode_data=params.get("border_mode_data"), border_cval_data=0, order_data=order_data, border_mode_seg="constant", border_cval_seg=border_val_seg, order_seg=order_seg, random_crop=params.get("random_crop"), p_el_per_sample=params.get("p_eldef"), p_scale_per_sample=params.get("p_scale"), p_rot_per_sample=params.get("p_rot"), independent_scale_for_each_axis=params.get( "independent_scale_factor_for_each_axis")))
def get_arteries_augmentation(dataloader_train, dataloader_val, patch_size,
                              params=default_3D_augmentation_params, border_val_seg=-1,
                              seeds_train=None, seeds_val=None, order_seg=1, order_data=3,
                              deep_supervision_scales=None, soft_ds=False, classes=None,
                              pin_memory=True, regions=None,
                              use_nondetMultiThreadedAugmenter: bool = False):
    """Build train/validation generators with a minimal, spatial-only augmentation.

    Training pipeline, in order:

    * optional channel selection;
    * ``SpatialTransform`` with elastic deformation, rotation and random crop
      disabled (only scaling, per ``params``);
    * optional mirroring;
    * ``RemoveLabelTransform(-1, 0)``;
    * renaming ``seg`` to ``target``;
    * optional region conversion;
    * tensor conversion.

    The validation pipeline performs the label/channel handling, renaming,
    region conversion and tensor conversion only.

    Note:
        ``deep_supervision_scales``, ``soft_ds``, ``classes`` and
        ``use_nondetMultiThreadedAugmenter`` are accepted for interface
        compatibility but are not used. Dummy-2D mode is not supported: no
        3D->2D conversion is performed, so no 2D->3D back-conversion is
        appended either.

    Args:
        dataloader_train, dataloader_val: Loaders feeding the two augmenters.
        patch_size: Output patch size for ``SpatialTransform``.
        params: Augmentation parameter dict. Must use ``do_mirror`` rather
            than the legacy ``mirror`` key.
        border_val_seg: Constant used for the segmentation outside the image.
        seeds_train, seeds_val: Per-worker seeds for the augmenters.
        order_seg, order_data: Interpolation orders for seg / data.
        pin_memory: Forwarded to ``MultiThreadedAugmenter``.
        regions: If not ``None``, targets are converted to these regions.

    Returns:
        ``(batchgenerator_train, batchgenerator_val)``, both restarted.

    Raises:
        AssertionError: If ``params`` contains ``mirror``.
    """
    assert params.get(
        'mirror') is None, "old version of params, use new keyword do_mirror"
    tr_transforms = []

    if params.get("selected_data_channels") is not None:
        tr_transforms.append(
            DataChannelSelectionTransform(params.get("selected_data_channels")))
    if params.get("selected_seg_channels") is not None:
        tr_transforms.append(
            SegChannelSelectionTransform(params.get("selected_seg_channels")))

    # Only scaling is enabled; elastic deformation, rotation and random cropping are off.
    tr_transforms.append(
        SpatialTransform(patch_size,
                         patch_center_dist_from_border=None,
                         do_elastic_deform=False,
                         do_rotation=False,
                         do_scale=params.get("do_scaling"),
                         scale=params.get("scale_range"),
                         border_mode_data=params.get("border_mode_data"),
                         border_cval_data=0,
                         order_data=order_data,
                         border_mode_seg="constant",
                         border_cval_seg=border_val_seg,
                         order_seg=order_seg,
                         random_crop=False,
                         p_el_per_sample=params.get("p_eldef"),
                         p_scale_per_sample=params.get("p_scale"),
                         p_rot_per_sample=params.get("p_rot"),
                         independent_scale_for_each_axis=params.get(
                             "independent_scale_factor_for_each_axis")))
    # No Convert2DTo3DTransform here: it must be paired with a preceding
    # Convert3DTo2DTransform, which this pipeline does not apply.

    if params.get("do_mirror") or params.get("mirror"):
        tr_transforms.append(MirrorTransform(params.get("mirror_axes")))

    tr_transforms.append(RemoveLabelTransform(-1, 0))
    # Cascade augmentations and deep-supervision downsampling are intentionally
    # not part of this pipeline.
    tr_transforms.append(RenameTransform('seg', 'target', True))
    if regions is not None:
        tr_transforms.append(
            ConvertSegmentationToRegionsTransform(regions, 'target', 'target'))
    tr_transforms.append(NumpyToTensor(['data', 'target'], 'float'))
    tr_transforms = Compose(tr_transforms)

    batchgenerator_train = MultiThreadedAugmenter(
        dataloader_train, tr_transforms, params.get('num_threads'),
        params.get("num_cached_per_thread"), seeds=seeds_train, pin_memory=pin_memory)
    batchgenerator_train.restart()

    val_transforms = [RemoveLabelTransform(-1, 0)]
    if params.get("selected_data_channels") is not None:
        val_transforms.append(
            DataChannelSelectionTransform(params.get("selected_data_channels")))
    if params.get("selected_seg_channels") is not None:
        val_transforms.append(
            SegChannelSelectionTransform(params.get("selected_seg_channels")))
    val_transforms.append(RenameTransform('seg', 'target', True))
    if regions is not None:
        val_transforms.append(
            ConvertSegmentationToRegionsTransform(regions, 'target', 'target'))
    val_transforms.append(NumpyToTensor(['data', 'target'], 'float'))
    val_transforms = Compose(val_transforms)

    # Validation uses half the training worker count, but at least one worker.
    batchgenerator_val = MultiThreadedAugmenter(
        dataloader_val, val_transforms, max(params.get('num_threads') // 2, 1),
        params.get("num_cached_per_thread"), seeds=seeds_val, pin_memory=pin_memory)
    batchgenerator_val.restart()
    return batchgenerator_train, batchgenerator_val
def get_default_augmentation_withEDT(dataloader_train, dataloader_val, patch_size, idx_of_edts,
                                     params=default_3D_augmentation_params, border_val_seg=-1,
                                     pin_memory=True, seeds_train=None, seeds_val=None):
    """Build train/validation generators that carry EDT channels in a separate ``bound`` key.

    The channels ``idx_of_edts`` are moved from ``data`` into ``bound`` via
    ``AppendChannelsTransform(..., remove_from_input=True)``. The move happens
    after all spatial transforms (spatial transform and mirroring), so the
    EDT stays aligned with ``data``/``seg``. It happens before the intensity
    transforms (gamma), so the EDT is not intensity-augmented.

    Args:
        dataloader_train, dataloader_val: Loaders feeding the two augmenters.
        patch_size: Output patch size for ``SpatialTransform``.
        idx_of_edts: Channel indices moved from ``data`` to ``bound``
            (presumably the EDT maps — confirm with the data loader).
        params: Augmentation parameter dict.
        border_val_seg: Constant used for the segmentation outside the image.
        pin_memory: Forwarded to ``MultiThreadedAugmenter``.
        seeds_train, seeds_val: Per-worker seeds for the augmenters.

    Returns:
        ``(batchgenerator_train, batchgenerator_val)``. Neither is restarted
        here. Batches contain ``data``, ``target`` and ``bound`` tensors.
    """
    dummy_2d = bool(params.get("dummy_2D"))
    tr_transforms = []

    if params.get("selected_data_channels") is not None:
        tr_transforms.append(
            DataChannelSelectionTransform(params.get("selected_data_channels")))
    if params.get("selected_seg_channels") is not None:
        tr_transforms.append(
            SegChannelSelectionTransform(params.get("selected_seg_channels")))

    # don't do color augmentations while in 2d mode with 3d data because the color channel is overloaded!!
    if dummy_2d:
        tr_transforms.append(Convert3DTo2DTransform())
    tr_transforms.append(
        SpatialTransform(patch_size,
                         patch_center_dist_from_border=None,
                         do_elastic_deform=params.get("do_elastic"),
                         alpha=params.get("elastic_deform_alpha"),
                         sigma=params.get("elastic_deform_sigma"),
                         do_rotation=params.get("do_rotation"),
                         angle_x=params.get("rotation_x"),
                         angle_y=params.get("rotation_y"),
                         angle_z=params.get("rotation_z"),
                         do_scale=params.get("do_scaling"),
                         scale=params.get("scale_range"),
                         border_mode_data=params.get("border_mode_data"),
                         border_cval_data=0,
                         order_data=3,
                         border_mode_seg="constant",
                         border_cval_seg=border_val_seg,
                         order_seg=1,
                         random_crop=params.get("random_crop"),
                         p_el_per_sample=params.get("p_eldef"),
                         p_scale_per_sample=params.get("p_scale"),
                         p_rot_per_sample=params.get("p_rot")))
    if dummy_2d:
        tr_transforms.append(Convert2DTo3DTransform())

    # Mirror BEFORE splitting off the EDT: MirrorTransform only flips the data
    # and seg keys, so mirroring afterwards would leave "bound" unflipped and
    # misaligned with data/target.
    tr_transforms.append(MirrorTransform(params.get("mirror_axes")))

    # Move the EDT channels to a different key so they are not intensity transformed.
    tr_transforms.append(
        AppendChannelsTransform("data", "bound", idx_of_edts, remove_from_input=True))

    if params.get("do_gamma"):
        tr_transforms.append(
            GammaTransform(params.get("gamma_range"), False, True,
                           retain_stats=params.get("gamma_retain_stats"),
                           p_per_sample=params["p_gamma"]))

    if params.get("mask_was_used_for_normalization") is not None:
        mask_was_used_for_normalization = params.get("mask_was_used_for_normalization")
        tr_transforms.append(
            MaskTransform(mask_was_used_for_normalization, mask_idx_in_seg=0, set_outside_to=0))

    tr_transforms.append(RemoveLabelTransform(-1, 0))

    if params.get("move_last_seg_chanel_to_data"):
        tr_transforms.append(
            MoveSegAsOneHotToData(1, params.get("all_segmentation_labels"), 'seg', 'data'))

    # The old condition `x and not None and x` reduced to the truthiness of x.
    if params.get("advanced_pyramid_augmentations"):
        tr_transforms.append(
            ApplyRandomBinaryOperatorTransform(
                channel_idx=list(range(-len(params.get("all_segmentation_labels")), 0)),
                p_per_sample=0.4,
                key="data",
                strel_size=(1, 8)))
        tr_transforms.append(
            RemoveRandomConnectedComponentFromOneHotEncodingTransform(
                channel_idx=list(range(-len(params.get("all_segmentation_labels")), 0)),
                key="data",
                p_per_sample=0.2,
                fill_with_other_class_p=0.0,
                dont_do_if_covers_more_than_X_percent=0.15))

    tr_transforms.append(RenameTransform('seg', 'target', True))
    tr_transforms.append(NumpyToTensor(['data', 'target', 'bound'], 'float'))
    tr_transforms = Compose(tr_transforms)

    batchgenerator_train = MultiThreadedAugmenter(
        dataloader_train, tr_transforms, params.get('num_threads'),
        params.get("num_cached_per_thread"), seeds=seeds_train, pin_memory=pin_memory)

    val_transforms = [RemoveLabelTransform(-1, 0)]
    if params.get("selected_data_channels") is not None:
        val_transforms.append(
            DataChannelSelectionTransform(params.get("selected_data_channels")))
    if params.get("selected_seg_channels") is not None:
        val_transforms.append(
            SegChannelSelectionTransform(params.get("selected_seg_channels")))

    # Move the EDT channels to a different key, mirroring the training pipeline.
    val_transforms.append(
        AppendChannelsTransform("data", "bound", idx_of_edts, remove_from_input=True))

    if params.get("move_last_seg_chanel_to_data"):
        val_transforms.append(
            MoveSegAsOneHotToData(1, params.get("all_segmentation_labels"), 'seg', 'data'))
    val_transforms.append(RenameTransform('seg', 'target', True))
    val_transforms.append(NumpyToTensor(['data', 'target', 'bound'], 'float'))
    val_transforms = Compose(val_transforms)

    # Validation uses half the training worker count, but at least one worker.
    batchgenerator_val = MultiThreadedAugmenter(
        dataloader_val, val_transforms, max(params.get('num_threads') // 2, 1),
        params.get("num_cached_per_thread"), seeds=seeds_val, pin_memory=pin_memory)
    return batchgenerator_train, batchgenerator_val
def _save_roc_curve(labels, scores, save_file):
    """Plot the ROC curve of ``scores`` against ``labels`` and save it to ``save_file``.

    A fresh figure is created and closed again, so figures that are already
    open elsewhere are neither drawn on nor left dangling.

    Parameters
    ----------
    labels : sequence
        ground-truth labels, passed unchanged to ``roc_curve``
    scores : array-like
        score for the positive class per sample, in the same order as
        ``labels``
    save_file : str
        output path of the rendered plot (format inferred by matplotlib
        from the file extension)

    Returns
    -------
    float
        area under the ROC curve
    """
    fpr, tpr, _ = roc_curve(labels, scores)
    roc_auc = auc(fpr, tpr)

    fig = plt.figure()
    plt.plot(fpr, tpr, label='ROC (AUC = %0.2f)' % roc_auc)
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('Receiver operating characteristic example')
    plt.legend(loc="lower right")
    fig.savefig(save_file)
    plt.close(fig)
    return roc_auc


def run_experiment(cp: str, test=True) -> str:
    """
    Run classification experiment on patches

    Builds the config, datasets, data managers and a ``PyTorchExperiment``,
    trains the model and (optionally) evaluates it.

    NOTE(review): an earlier version of this docstring said imports were
    moved inside this function because of logging setups; no imports are
    done here any more - confirm whether that constraint still applies.

    Parameters
    ----------
    cp : str
        path to config file
    test : bool
        if True, evaluate the model returned by training on the test set and
        on the validation set, saving ``test_roc.pdf`` and
        ``best_val_roc.pdf`` into the experiment folder

    Returns
    -------
    str
        path to experiment folder
    """
    # setup config
    ch = ConfigHandlerPyTorchDelira(cp)
    ch = feature_map_params(ch)

    if 'mixed_precision' not in ch or ch['mixed_precision'] is None:
        ch['mixed_precision'] = True
    if 'debug_delira' in ch and ch['debug_delira'] is not None:
        delira.set_debug_mode(ch['debug_delira'])
        print("Debug mode active: settings n_process_augmentation to 1!")
        # debug mode runs augmentation in a single process
        ch['augment.n_process'] = 1

    dset_keys = ['train', 'val', 'test']

    losses = {'class_ce': torch.nn.CrossEntropyLoss()}
    train_metrics = {}
    val_metrics = {'CE': metric_wrapper_pytorch(torch.nn.CrossEntropyLoss())}
    test_metrics = {'CE': metric_wrapper_pytorch(torch.nn.CrossEntropyLoss())}

    #########################
    #   Setup Parameters    #
    #########################
    params_dict = ch.get_params(losses=losses,
                                train_metrics=train_metrics,
                                val_metrics=val_metrics,
                                add_self=ch['add_config_to_params'])
    params = Parameters(**params_dict)

    #################
    #   Setup IO    #
    #################
    load_sample = load_pickle
    load_fn = LoadPatches(load_fn=load_sample,
                          patch_size=ch['patch_size'],
                          **ch['data.load_patch'])

    # one dataset per split, read from <data.path>/<split>
    datasets = {}
    for key in dset_keys:
        p = os.path.join(ch["data.path"], str(key))
        datasets[key] = BaseExtendCacheDataset(p, load_fn=load_fn,
                                               **ch['data.kwargs'])

    #############################
    #   Setup Transformations   #
    #############################
    base_transforms = [PopKeys("mapping")]

    train_transforms = []
    if ch['augment.mode']:
        logger.info("Training augmentation enabled.")
        train_transforms.append(
            SpatialTransform(patch_size=ch['patch_size'],
                             **ch['augment.kwargs']))
        train_transforms.append(MirrorTransform(axes=(0, 1)))

    process = ch['augment.n_process'] if 'augment.n_process' in ch else 1

    #########################
    #   Setup Datamanagers  #
    #########################
    # training uses augmentation + weighted sampling; val/test are
    # iterated sequentially without augmentation
    datamanagers = {}
    for key in dset_keys:
        if key == 'train':
            trafos = base_transforms + train_transforms
            sampler = WeightedPrevalenceRandomSampler
        else:
            trafos = base_transforms
            sampler = SequentialSampler
        datamanagers[key] = BaseDataManager(
            data=datasets[key],
            batch_size=params.nested_get('batch_size'),
            n_process_augmentation=process,
            transforms=Compose(trafos),
            sampler_cls=sampler,
        )

    #############################
    #   Initialize Experiment   #
    #############################
    experiment = PyTorchExperiment(
        params=params,
        model_cls=ClassNetwork,
        name=ch['exp.name'],
        save_path=ch['exp.dir'],
        optim_builder=create_optims_default_pytorch,
        trainer_cls=PyTorchNetworkTrainer,
        mixed_precision=ch['mixed_precision'],
        mixed_precision_kwargs={'verbose': False},
        key_mapping={"input_batch": "data"},
        **ch['exp.kwargs'],
    )

    # save configurations
    ch.dump(os.path.join(experiment.save_path, 'config.json'))

    #################
    #   Training    #
    #################
    model = experiment.run(datamanagers['train'],
                           datamanagers['val'],
                           save_path_exp=experiment.save_path,
                           ch=ch,
                           metric_keys={'val_CE': ['pred', 'label']},
                           val_freq=1,
                           verbose=True)

    ################
    #   Testing    #
    ################
    if test and datamanagers['test'] is not None:
        # metrics and metric_keys are used differently than in original
        # Delira implementation in order to support Evaluator
        # see mscl.training.predictor
        softmax_fn = metric_wrapper_pytorch(
            partial(torch.nn.functional.softmax, dim=1))

        # same evaluation on the test split and (again) on the val split
        for split, roc_file in (('test', 'test_roc.pdf'),
                                ('val', 'best_val_roc.pdf')):
            preds = experiment.test(
                network=model,
                test_data=datamanagers[split],
                metrics=test_metrics,
                metric_keys={'CE': ['pred', 'label']},
                verbose=True,
            )
            probs = softmax_fn(preds[0]['pred'])
            # SequentialSampler keeps dataset order, so labels line up
            # with predictions
            labels = [d['label'] for d in datasets[split]]
            # NOTE(review): assumes a binary task where column 1 of the
            # softmax output is the positive class - confirm
            _save_roc_curve(labels, probs[:, 1],
                            os.path.join(experiment.save_path, roc_file))
    return experiment.save_path