def get_train_transform(patch_size):
    """Build an example training augmentation pipeline for 3D BraTS patches.

    These are not necessarily the best transforms to use for BraTS; the pipeline is
    meant to showcase how batchgenerators transforms are combined.

    :param patch_size: spatial shape of the patches fed to the network.
    :return: a ``Compose`` object that applies all transforms in order.
    """
    tr_transforms = []

    # The first transform is the spatial transform. It crops the data to patch_size,
    # which also reduces the computational cost of every subsequent operation. None of
    # the subsequent operations change the shape of the data or resample it, so no
    # border artifacts will be introduced.
    # SpatialTransform_2 parameterizes elastic deformation via ``deformation_scale``.
    # Each of the three spatial augmentations (elastic deformation, rotation, scaling)
    # is applied with probability 0.1 per sample, so 1 - (1 - 0.1) ** 3 = 27% of the
    # samples get at least one of them; the rest are only cropped.
    max_angle = 15 / 360. * 2 * np.pi  # 15 degrees expressed in radians
    tr_transforms.append(
        SpatialTransform_2(
            patch_size, [i // 2 for i in patch_size],
            do_elastic_deform=True, deformation_scale=(0, 0.25),
            do_rotation=True,
            angle_x=(-max_angle, max_angle),
            angle_y=(-max_angle, max_angle),
            angle_z=(-max_angle, max_angle),
            do_scale=True, scale=(0.75, 1.25),
            border_mode_data='constant', border_cval_data=0,
            border_mode_seg='constant', border_cval_seg=0,
            order_seg=1, order_data=3,
            random_crop=True,
            p_el_per_sample=0.1, p_rot_per_sample=0.1, p_scale_per_sample=0.1))

    # Mirror along axes 0, 1 and 2.
    tr_transforms.append(MirrorTransform(axes=(0, 1, 2)))

    # Gamma transform: a nonlinear transformation of intensity values
    # (https://en.wikipedia.org/wiki/Gamma_correction).
    tr_transforms.append(
        GammaTransform(gamma_range=(0.5, 2), invert_image=False, per_channel=True,
                       p_per_sample=0.15))
    # The same gamma transform applied to the inverted image (invert, apply, invert back).
    tr_transforms.append(
        GammaTransform(gamma_range=(0.5, 2), invert_image=True, per_channel=True,
                       p_per_sample=0.15))

    # Additive Gaussian noise.
    tr_transforms.append(
        GaussianNoiseTransform(noise_variance=(0, 0.05), p_per_sample=0.15))

    # Chain all transforms so they run in the order they were appended.
    return Compose(tr_transforms)
def get_train_transform(patch_size):
    """Return the composed training augmentation pipeline.

    Applied in order: spatial augmentation with cropping to ``patch_size``, mirroring
    along axes (0, 1, 2), multiplicative brightness, gamma on the inverted image, and
    additive Gaussian noise.
    """
    five_degrees = 5 / 360. * 2 * np.pi
    rotation_range = (-five_degrees, five_degrees)

    spatial = SpatialTransform_2(
        patch_size,
        [extent // 2 for extent in patch_size],
        do_elastic_deform=True,
        deformation_scale=(0, 0.05),
        do_rotation=True,
        angle_x=rotation_range,
        angle_y=rotation_range,
        angle_z=rotation_range,
        do_scale=True,
        scale=(0.75, 1.25),
        border_mode_data='constant',
        # NOTE(review): -2.34 is presumably the background intensity after
        # normalization — confirm against the preprocessing.
        border_cval_data=-2.34,
        border_mode_seg='constant',
        border_cval_seg=0,
    )

    pipeline = [
        spatial,
        MirrorTransform(axes=(0, 1, 2)),
        BrightnessMultiplicativeTransform((0.7, 1.5), per_channel=True, p_per_sample=0.15),
        GammaTransform(gamma_range=(0.5, 2), invert_image=True, per_channel=True,
                       p_per_sample=0.15),
        GaussianNoiseTransform(noise_variance=(0, 0.15), p_per_sample=0.15),
    ]
    return Compose(pipeline)
def get_train_transform(patch_size):
    """Data augmentation pipeline for training data.

    Inspired by:
    https://github.com/MIC-DKFZ/batchgenerators/blob/master/batchgenerators/examples/brats2017/brats2017_dataloader_3D.py

    :param patch_size: shape of the network's input
    :return: a ``Compose`` object chaining the transformations (not a plain list)
    """

    def rad(deg):
        """Return the symmetric interval (-deg, +deg) converted from degrees to radians."""
        return (-deg / 360 * 2 * np.pi, deg / 360 * 2 * np.pi)

    train_transforms = []

    # Spatial augmentation; rotation is only around z (angle_x/angle_y are fixed to 0).
    train_transforms.append(
        SpatialTransform_2(
            patch_size, (10, 10, 10),
            do_elastic_deform=True, deformation_scale=(0, 0.25),
            do_rotation=True,
            angle_z=rad(15), angle_x=(0, 0), angle_y=(0, 0),
            do_scale=True, scale=(0.75, 1.25),
            border_mode_data='constant', border_cval_data=0,
            border_mode_seg='constant', border_cval_seg=0,
            order_seg=1,
            random_crop=False,
            p_el_per_sample=0.2, p_rot_per_sample=0.2, p_scale_per_sample=0.2,
        ))
    # Mirror only along axes 0 and 1.
    train_transforms.append(MirrorTransform(axes=(0, 1)))
    train_transforms.append(
        BrightnessMultiplicativeTransform((0.7, 1.5), per_channel=True, p_per_sample=0.2))
    train_transforms.append(
        GammaTransform(gamma_range=(0.2, 1.0), invert_image=False, per_channel=False,
                       p_per_sample=0.2))
    train_transforms.append(
        GaussianNoiseTransform(noise_variance=(0, 0.05), p_per_sample=0.2))
    # NOTE(review): with p_per_channel=0.0 essentially no channel is ever selected for
    # blurring, so this transform looks like a no-op even though p_per_sample=0.2.
    # Confirm whether blur was meant to be disabled or p_per_channel should be > 0.
    train_transforms.append(
        GaussianBlurTransform(blur_sigma=(0.2, 1.0), different_sigma_per_channel=False,
                              p_per_channel=0.0, p_per_sample=0.2))

    return Compose(train_transforms)
def _make_training_transforms(self): if self.no_data_augmentation: print("No data augmentation will be performed during training!") return [] patch_size = self.patch_size[::-1] # (x, y, z) order rot_angle_x = self.training_augmentation_args.get('angle_x', 15) rot_angle_y = self.training_augmentation_args.get('angle_y', 15) rot_angle_z = self.training_augmentation_args.get('angle_z', 15) p_per_sample = self.training_augmentation_args.get( 'p_per_sample', 0.15) train_transforms = [ SpatialTransform_2( patch_size, patch_size // 2, do_elastic_deform=self.training_augmentation_args.get( 'do_elastic_deform', True), deformation_scale=self.training_augmentation_args.get( 'deformation_scale', (0, 0.25)), do_rotation=self.training_augmentation_args.get( 'do_rotation', True), angle_x=(-rot_angle_x / 360. * 2 * np.pi, rot_angle_x / 360. * 2 * np.pi), angle_y=(-rot_angle_y / 360. * 2 * np.pi, rot_angle_y / 360. * 2 * np.pi), angle_z=(-rot_angle_z / 360. * 2 * np.pi, rot_angle_z / 360. * 2 * np.pi), do_scale=self.training_augmentation_args.get('do_scale', True), scale=self.training_augmentation_args.get( 'scale', (0.75, 1.25)), border_mode_data='nearest', border_cval_data=0, order_data=3, # border_mode_seg='nearest', border_cval_seg=0, # order_seg=0, random_crop=False, p_el_per_sample=self.training_augmentation_args.get( 'p_el_per_sample', 0.5), p_rot_per_sample=self.training_augmentation_args.get( 'p_rot_per_sample', 0.5), p_scale_per_sample=self.training_augmentation_args.get( 'p_scale_per_sample', 0.5)) ] if self.training_augmentation_args.get("do_mirror", False): train_transforms.append(MirrorTransform(axes=(0, 1, 2))) train_transforms.append( BrightnessMultiplicativeTransform( self.training_augmentation_args.get('brightness_range', (0.7, 1.5)), per_channel=True, p_per_sample=p_per_sample)) train_transforms.append( GaussianNoiseTransform( noise_variance=self.training_augmentation_args.get( 'gaussian_noise_variance', (0, 0.05)), p_per_sample=p_per_sample)) 
train_transforms.append( GammaTransform(gamma_range=self.training_augmentation_args.get( 'gamma_range', (0.5, 2)), invert_image=False, per_channel=True, p_per_sample=p_per_sample)) print("train_transforms\n", train_transforms) return train_transforms
def get_default_augmentation(dataloader_train,
                             dataloader_val,
                             patch_size,
                             params=default_3D_augmentation_params,
                             border_val_seg=-1,
                             pin_memory=True,
                             seeds_train=None,
                             seeds_val=None,
                             regions=None):
    """Build the default training and validation batch generators.

    :param dataloader_train: data loader producing the training batches.
    :param dataloader_val: data loader producing the validation batches.
    :param patch_size: patch shape; when ``params['dummy_2D']`` is set, the first entry
        is dropped for the spatial transform.
    :param params: augmentation parameter dict (read with ``params.get``).
    :param border_val_seg: segmentation value used when padding during the spatial transform.
    :param pin_memory: forwarded to ``MultiThreadedAugmenter``.
    :param seeds_train: seeds for the training augmenter workers.
    :param seeds_val: seeds for the validation augmenter workers.
    :param regions: if given, the target is converted into these regions.
    :return: ``(batchgenerator_train, batchgenerator_val)``, both ``MultiThreadedAugmenter``.
    """
    assert params.get('mirror') is None, "old version of params, use new keyword do_mirror"

    num_threads = params.get('num_threads')
    num_cached_per_thread = params.get("num_cached_per_thread")

    # ---- training pipeline ----
    tr_transforms = []

    if params.get("selected_data_channels") is not None:
        tr_transforms.append(
            DataChannelSelectionTransform(params.get("selected_data_channels")))
    if params.get("selected_seg_channels") is not None:
        tr_transforms.append(
            SegChannelSelectionTransform(params.get("selected_seg_channels")))

    # don't do color augmentations while in 2d mode with 3d data because the color
    # channel is overloaded!!
    dummy_2d = bool(params.get("dummy_2D"))
    if dummy_2d:
        tr_transforms.append(Convert3DTo2DTransform())
        patch_size_spatial = patch_size[1:]
    else:
        patch_size_spatial = patch_size

    # Set order_data=0 and order_seg=0 for some more speed for cascade???
    tr_transforms.append(
        SpatialTransform(patch_size_spatial,
                         patch_center_dist_from_border=None,
                         do_elastic_deform=params.get("do_elastic"),
                         alpha=params.get("elastic_deform_alpha"),
                         sigma=params.get("elastic_deform_sigma"),
                         do_rotation=params.get("do_rotation"),
                         angle_x=params.get("rotation_x"),
                         angle_y=params.get("rotation_y"),
                         angle_z=params.get("rotation_z"),
                         do_scale=params.get("do_scaling"),
                         scale=params.get("scale_range"),
                         border_mode_data=params.get("border_mode_data"),
                         border_cval_data=0,
                         order_data=3,
                         border_mode_seg="constant",
                         border_cval_seg=border_val_seg,
                         order_seg=1,
                         random_crop=params.get("random_crop"),
                         p_el_per_sample=params.get("p_eldef"),
                         p_scale_per_sample=params.get("p_scale"),
                         p_rot_per_sample=params.get("p_rot"),
                         independent_scale_for_each_axis=params.get(
                             "independent_scale_factor_for_each_axis")))

    if dummy_2d:
        tr_transforms.append(Convert2DTo3DTransform())

    if params.get("do_gamma"):
        tr_transforms.append(
            GammaTransform(params.get("gamma_range"), False, True,
                           retain_stats=params.get("gamma_retain_stats"),
                           p_per_sample=params["p_gamma"]))

    if params.get("do_mirror"):
        tr_transforms.append(MirrorTransform(params.get("mirror_axes")))

    if params.get("mask_was_used_for_normalization") is not None:
        tr_transforms.append(
            MaskTransform(params.get("mask_was_used_for_normalization"),
                          mask_idx_in_seg=0, set_outside_to=0))

    # Map label -1 to 0 (with the default border_val_seg=-1, -1 is the padding value
    # the spatial transform uses for the segmentation).
    tr_transforms.append(RemoveLabelTransform(-1, 0))

    if params.get("move_last_seg_chanel_to_data"):
        tr_transforms.append(
            MoveSegAsOneHotToData(1, params.get("all_segmentation_labels"), 'seg', 'data'))
        
    # The original condition read `params.get(...) and not None and params.get(...)`;
    # `not None` is always True, so this is simply a truthiness test.
    if params.get("cascade_do_cascade_augmentations"):
        # One-hot segmentation channels were appended at the end of 'data'.
        num_labels = len(params.get("all_segmentation_labels"))
        # Remove the following transforms to remove cascade DA ??
        tr_transforms.append(
            ApplyRandomBinaryOperatorTransform(
                channel_idx=list(range(-num_labels, 0)),
                p_per_sample=params.get("cascade_random_binary_transform_p"),
                key="data",
                strel_size=params.get("cascade_random_binary_transform_size")))
        # Each keyword now receives the params entry of the same name (they were swapped).
        tr_transforms.append(
            RemoveRandomConnectedComponentFromOneHotEncodingTransform(
                channel_idx=list(range(-num_labels, 0)),
                key="data",
                p_per_sample=params.get("cascade_remove_conn_comp_p"),
                fill_with_other_class_p=params.get(
                    "cascade_remove_conn_comp_fill_with_other_class_p"),
                dont_do_if_covers_more_than_X_percent=params.get(
                    "cascade_remove_conn_comp_max_size_percent_threshold")))

    tr_transforms.append(RenameTransform('seg', 'target', True))

    if regions is not None:
        tr_transforms.append(
            ConvertSegmentationToRegionsTransform(regions, 'target', 'target'))

    tr_transforms.append(NumpyToTensor(['data', 'target'], 'float'))
    tr_transforms = Compose(tr_transforms)

    # For debugging, SingleThreadedAugmenter(dataloader_train, tr_transforms) can be used instead.
    batchgenerator_train = MultiThreadedAugmenter(
        dataloader_train, tr_transforms, num_threads, num_cached_per_thread,
        seeds=seeds_train, pin_memory=pin_memory)

    # ---- validation pipeline (no augmentation, only format conversion) ----
    val_transforms = [RemoveLabelTransform(-1, 0)]
    if params.get("selected_data_channels") is not None:
        val_transforms.append(
            DataChannelSelectionTransform(params.get("selected_data_channels")))
    if params.get("selected_seg_channels") is not None:
        val_transforms.append(
            SegChannelSelectionTransform(params.get("selected_seg_channels")))
    if params.get("move_last_seg_chanel_to_data"):
        val_transforms.append(
            MoveSegAsOneHotToData(1, params.get("all_segmentation_labels"), 'seg', 'data'))

    val_transforms.append(RenameTransform('seg', 'target', True))

    if regions is not None:
        val_transforms.append(
            ConvertSegmentationToRegionsTransform(regions, 'target', 'target'))

    val_transforms.append(NumpyToTensor(['data', 'target'], 'float'))
    val_transforms = Compose(val_transforms)

    # Validation uses half as many worker threads (at least one).
    batchgenerator_val = MultiThreadedAugmenter(
        dataloader_val, val_transforms, max(num_threads // 2, 1), num_cached_per_thread,
        seeds=seeds_val, pin_memory=pin_memory)
    return batchgenerator_train, batchgenerator_val
def get_moreDA_augmentation(dataloader_train,
                            dataloader_val,
                            patch_size,
                            params=default_3D_augmentation_params,
                            border_val_seg=-1,
                            seeds_train=None,
                            seeds_val=None,
                            order_seg=1,
                            order_data=3,
                            deep_supervision_scales=None,
                            soft_ds=False,
                            classes=None,
                            pin_memory=True,
                            regions=None,
                            use_nondetMultiThreadedAugmenter: bool = False):
    """Build training/validation batch generators with the extended ("more DA") pipeline.

    :param dataloader_train: data loader producing the training batches.
    :param dataloader_val: data loader producing the validation batches.
    :param patch_size: patch shape; when ``params['dummy_2D']`` is set, the first entry
        is dropped for the spatial transform.
    :param params: augmentation parameter dict (read with ``params.get``).
    :param border_val_seg: segmentation value used when padding during the spatial transform.
    :param seeds_train: seeds for the training augmenter workers.
    :param seeds_val: seeds for the validation augmenter workers.
    :param order_seg: interpolation order for the segmentation in the spatial transform.
    :param order_data: interpolation order for the data in the spatial transform.
    :param deep_supervision_scales: if given, the target is downsampled to these scales.
    :param soft_ds: use DownsampleSegForDSTransform3 (requires ``classes``) instead of
        DownsampleSegForDSTransform2.
    :param classes: class list for ``soft_ds``.
    :param pin_memory: forwarded to the augmenters.
    :param regions: if given, the target is converted into these regions.
    :param use_nondetMultiThreadedAugmenter: use NonDetMultiThreadedAugmenter instead of
        MultiThreadedAugmenter.
    :return: ``(batchgenerator_train, batchgenerator_val)``.
    :raises RuntimeError: if the non-deterministic augmenter is requested but unavailable.
    """
    assert params.get('mirror') is None, "old version of params, use new keyword do_mirror"

    num_threads = params.get('num_threads')

    def _make_augmenter(loader, transforms, n_threads, seeds):
        # Wrap `loader` in the multi-threaded augmenter variant selected by the caller.
        if use_nondetMultiThreadedAugmenter:
            if NonDetMultiThreadedAugmenter is None:
                raise RuntimeError('NonDetMultiThreadedAugmenter is not yet available')
            augmenter_cls = NonDetMultiThreadedAugmenter
        else:
            augmenter_cls = MultiThreadedAugmenter
        return augmenter_cls(loader, transforms, n_threads,
                             params.get("num_cached_per_thread"),
                             seeds=seeds, pin_memory=pin_memory)

    def _append_deep_supervision(transforms):
        # Downsample the target for deep supervision, if requested.
        if deep_supervision_scales is None:
            return
        if soft_ds:
            assert classes is not None
            transforms.append(
                DownsampleSegForDSTransform3(deep_supervision_scales, 'target', 'target',
                                             classes))
        else:
            transforms.append(
                DownsampleSegForDSTransform2(deep_supervision_scales, 0,
                                             input_key='target', output_key='target'))

    # ---- training pipeline ----
    tr_transforms = []

    if params.get("selected_data_channels") is not None:
        tr_transforms.append(
            DataChannelSelectionTransform(params.get("selected_data_channels")))
    if params.get("selected_seg_channels") is not None:
        tr_transforms.append(
            SegChannelSelectionTransform(params.get("selected_seg_channels")))

    # don't do color augmentations while in 2d mode with 3d data because the color
    # channel is overloaded!!
    dummy_2d = bool(params.get("dummy_2D"))
    if dummy_2d:
        ignore_axes = (0, )
        tr_transforms.append(Convert3DTo2DTransform())
        patch_size_spatial = patch_size[1:]
    else:
        patch_size_spatial = patch_size
        ignore_axes = None

    tr_transforms.append(
        SpatialTransform(patch_size_spatial,
                         patch_center_dist_from_border=None,
                         do_elastic_deform=params.get("do_elastic"),
                         alpha=params.get("elastic_deform_alpha"),
                         sigma=params.get("elastic_deform_sigma"),
                         do_rotation=params.get("do_rotation"),
                         angle_x=params.get("rotation_x"),
                         angle_y=params.get("rotation_y"),
                         angle_z=params.get("rotation_z"),
                         p_rot_per_axis=params.get("rotation_p_per_axis"),
                         do_scale=params.get("do_scaling"),
                         scale=params.get("scale_range"),
                         border_mode_data=params.get("border_mode_data"),
                         border_cval_data=0,
                         order_data=order_data,
                         border_mode_seg="constant",
                         border_cval_seg=border_val_seg,
                         order_seg=order_seg,
                         random_crop=params.get("random_crop"),
                         p_el_per_sample=params.get("p_eldef"),
                         p_scale_per_sample=params.get("p_scale"),
                         p_rot_per_sample=params.get("p_rot"),
                         independent_scale_for_each_axis=params.get(
                             "independent_scale_factor_for_each_axis")))

    if dummy_2d:
        tr_transforms.append(Convert2DTo3DTransform())

    # we need to put the color augmentations after the dummy 2d part (if applicable).
    # Otherwise the overloaded color channel gets in the way
    tr_transforms.append(GaussianNoiseTransform(p_per_sample=0.1))
    tr_transforms.append(
        GaussianBlurTransform((0.5, 1.), different_sigma_per_channel=True,
                              p_per_sample=0.2, p_per_channel=0.5))
    tr_transforms.append(
        BrightnessMultiplicativeTransform(multiplier_range=(0.75, 1.25), p_per_sample=0.15))

    if params.get("do_additive_brightness"):
        tr_transforms.append(
            BrightnessTransform(
                params.get("additive_brightness_mu"),
                params.get("additive_brightness_sigma"),
                True,
                p_per_sample=params.get("additive_brightness_p_per_sample"),
                p_per_channel=params.get("additive_brightness_p_per_channel")))

    tr_transforms.append(ContrastAugmentationTransform(p_per_sample=0.15))
    tr_transforms.append(
        SimulateLowResolutionTransform(zoom_range=(0.5, 1), per_channel=True,
                                       p_per_channel=0.5, order_downsample=0,
                                       order_upsample=3, p_per_sample=0.25,
                                       ignore_axes=ignore_axes))
    # inverted gamma (invert_image=True)
    tr_transforms.append(
        GammaTransform(params.get("gamma_range"), True, True,
                       retain_stats=params.get("gamma_retain_stats"), p_per_sample=0.1))

    if params.get("do_gamma"):
        tr_transforms.append(
            GammaTransform(params.get("gamma_range"), False, True,
                           retain_stats=params.get("gamma_retain_stats"),
                           p_per_sample=params["p_gamma"]))

    if params.get("do_mirror") or params.get("mirror"):
        tr_transforms.append(MirrorTransform(params.get("mirror_axes")))

    if params.get("mask_was_used_for_normalization") is not None:
        tr_transforms.append(
            MaskTransform(params.get("mask_was_used_for_normalization"),
                          mask_idx_in_seg=0, set_outside_to=0))

    # Map label -1 to 0 (with the default border_val_seg=-1, -1 is the padding value
    # the spatial transform uses for the segmentation).
    tr_transforms.append(RemoveLabelTransform(-1, 0))

    if params.get("move_last_seg_chanel_to_data"):
        tr_transforms.append(
            MoveSegAsOneHotToData(1, params.get("all_segmentation_labels"), 'seg', 'data'))

    if params.get("cascade_do_cascade_augmentations"):
        # One-hot segmentation channels were appended at the end of 'data'.
        num_labels = len(params.get("all_segmentation_labels"))
        # Missing probabilities are treated as 0 (disabled) instead of raising TypeError.
        if (params.get("cascade_random_binary_transform_p") or 0) > 0:
            tr_transforms.append(
                ApplyRandomBinaryOperatorTransform(
                    channel_idx=list(range(-num_labels, 0)),
                    p_per_sample=params.get("cascade_random_binary_transform_p"),
                    key="data",
                    strel_size=params.get("cascade_random_binary_transform_size"),
                    p_per_label=params.get("cascade_random_binary_transform_p_per_label")))
        if (params.get("cascade_remove_conn_comp_p") or 0) > 0:
            # Each keyword now receives the params entry of the same name (they were swapped).
            tr_transforms.append(
                RemoveRandomConnectedComponentFromOneHotEncodingTransform(
                    channel_idx=list(range(-num_labels, 0)),
                    key="data",
                    p_per_sample=params.get("cascade_remove_conn_comp_p"),
                    fill_with_other_class_p=params.get(
                        "cascade_remove_conn_comp_fill_with_other_class_p"),
                    dont_do_if_covers_more_than_X_percent=params.get(
                        "cascade_remove_conn_comp_max_size_percent_threshold")))

    tr_transforms.append(RenameTransform('seg', 'target', True))

    if regions is not None:
        tr_transforms.append(
            ConvertSegmentationToRegionsTransform(regions, 'target', 'target'))

    _append_deep_supervision(tr_transforms)

    tr_transforms.append(NumpyToTensor(['data', 'target'], 'float'))
    tr_transforms = Compose(tr_transforms)

    batchgenerator_train = _make_augmenter(dataloader_train, tr_transforms, num_threads,
                                           seeds_train)

    # ---- validation pipeline (no augmentation, only format conversion) ----
    val_transforms = [RemoveLabelTransform(-1, 0)]
    if params.get("selected_data_channels") is not None:
        val_transforms.append(
            DataChannelSelectionTransform(params.get("selected_data_channels")))
    if params.get("selected_seg_channels") is not None:
        val_transforms.append(
            SegChannelSelectionTransform(params.get("selected_seg_channels")))
    if params.get("move_last_seg_chanel_to_data"):
        val_transforms.append(
            MoveSegAsOneHotToData(1, params.get("all_segmentation_labels"), 'seg', 'data'))

    val_transforms.append(RenameTransform('seg', 'target', True))

    if regions is not None:
        val_transforms.append(
            ConvertSegmentationToRegionsTransform(regions, 'target', 'target'))

    _append_deep_supervision(val_transforms)

    val_transforms.append(NumpyToTensor(['data', 'target'], 'float'))
    val_transforms = Compose(val_transforms)

    # Validation uses half as many worker threads (at least one).
    batchgenerator_val = _make_augmenter(dataloader_val, val_transforms,
                                         max(num_threads // 2, 1), seeds_val)
    return batchgenerator_train, batchgenerator_val
def get_train_transforms(self) -> List[AbstractTransform]:
    """Build the training augmentation pipeline (list of transforms, not composed).

    Reads ``self.patch_size``, ``self.data_aug_params``, ``self.do_dummy_2D_aug``,
    ``self.do_mirroring``, ``self.mirror_axes``, ``self.use_mask_for_norm``,
    ``self.num_classes``, ``self.regions`` and ``self.deep_supervision_scales``.
    """
    # Used for transpose and rot90: those keep the patch shape only when applied among
    # axes of equal length. matching_axes[i] counts how many axes share axis i's
    # extent; valid_axes are the axes with the highest count.
    matching_axes = np.array(
        [sum([i == j for j in self.patch_size]) for i in self.patch_size])
    valid_axes = list(np.where(matching_axes == np.max(matching_axes))[0])

    def _sample_region_scale(x, y):
        # Log-uniform sample between ~1/6 of x[y] and x[y] (presumably x is the image
        # shape and y an axis index — confirm). The lower bound is clamped to 1 so that
        # extents below 6 do not produce log(0) = -inf (and hence NaN scales).
        return np.exp(np.random.uniform(np.log(max(1, x[y] // 6)), np.log(x[y])))

    tr_transforms = []

    if self.data_aug_params['selected_seg_channels'] is not None:
        tr_transforms.append(
            SegChannelSelectionTransform(self.data_aug_params['selected_seg_channels']))

    if self.do_dummy_2D_aug:
        ignore_axes = (0, )
        tr_transforms.append(Convert3DTo2DTransform())
        patch_size_spatial = self.patch_size[1:]
    else:
        patch_size_spatial = self.patch_size
        ignore_axes = None

    # NOTE(review): p_el_per_sample=0.2 has no effect because do_elastic_deform=False.
    tr_transforms.append(
        SpatialTransform(
            patch_size_spatial,
            patch_center_dist_from_border=None,
            do_elastic_deform=False,
            do_rotation=True,
            angle_x=self.data_aug_params["rotation_x"],
            angle_y=self.data_aug_params["rotation_y"],
            angle_z=self.data_aug_params["rotation_z"],
            p_rot_per_axis=0.5,
            do_scale=True,
            scale=self.data_aug_params['scale_range'],
            border_mode_data="constant",
            border_cval_data=0,
            order_data=3,
            border_mode_seg="constant",
            border_cval_seg=-1,
            order_seg=1,
            random_crop=False,
            p_el_per_sample=0.2,
            p_scale_per_sample=0.2,
            p_rot_per_sample=0.4,
            independent_scale_for_each_axis=True,
        ))

    if self.do_dummy_2D_aug:
        tr_transforms.append(Convert2DTo3DTransform())

    # Only meaningful when at least two axes share the same extent.
    if np.any(matching_axes > 1):
        tr_transforms.append(
            Rot90Transform((0, 1, 2, 3), axes=valid_axes, data_key='data',
                           label_key='seg', p_per_sample=0.5))
        tr_transforms.append(
            TransposeAxesTransform(valid_axes, data_key='data', label_key='seg',
                                   p_per_sample=0.5))

    tr_transforms.append(
        OneOfTransform([
            MedianFilterTransform((2, 8), same_for_each_channel=False,
                                  p_per_sample=0.2, p_per_channel=0.5),
            GaussianBlurTransform((0.3, 1.5), different_sigma_per_channel=True,
                                  p_per_sample=0.2, p_per_channel=0.5)
        ]))

    tr_transforms.append(GaussianNoiseTransform(p_per_sample=0.1))

    tr_transforms.append(
        BrightnessTransform(0, 0.5, per_channel=True, p_per_sample=0.1, p_per_channel=0.5))

    tr_transforms.append(
        OneOfTransform([
            ContrastAugmentationTransform(contrast_range=(0.5, 2), preserve_range=True,
                                          per_channel=True, data_key='data',
                                          p_per_sample=0.2, p_per_channel=0.5),
            ContrastAugmentationTransform(contrast_range=(0.5, 2), preserve_range=False,
                                          per_channel=True, data_key='data',
                                          p_per_sample=0.2, p_per_channel=0.5),
        ]))

    tr_transforms.append(
        SimulateLowResolutionTransform(zoom_range=(0.25, 1), per_channel=True,
                                       p_per_channel=0.5, order_downsample=0,
                                       order_upsample=3, p_per_sample=0.15,
                                       ignore_axes=ignore_axes))

    # NOTE(review): both gamma transforms below are identical (invert_image=True). The
    # "moreDA" pipeline pairs one inverted with one non-inverted gamma — confirm
    # whether the second one was meant to use invert_image=False.
    tr_transforms.append(
        GammaTransform((0.7, 1.5), invert_image=True, per_channel=True, retain_stats=True,
                       p_per_sample=0.1))
    tr_transforms.append(
        GammaTransform((0.7, 1.5), invert_image=True, per_channel=True, retain_stats=True,
                       p_per_sample=0.1))

    if self.do_mirroring:
        tr_transforms.append(MirrorTransform(self.mirror_axes))

    tr_transforms.append(
        BlankRectangleTransform([[max(1, p // 10), p // 3] for p in self.patch_size],
                                rectangle_value=np.mean,
                                num_rectangles=(1, 5),
                                force_square=False,
                                p_per_sample=0.4,
                                p_per_channel=0.5))

    tr_transforms.append(
        BrightnessGradientAdditiveTransform(
            _sample_region_scale,
            (-0.5, 1.5),
            # Strength is drawn from [-5, -1) or [1, 5) with equal probability.
            max_strength=lambda x, y: np.random.uniform(-5, -1)
            if np.random.uniform() < 0.5 else np.random.uniform(1, 5),
            mean_centered=False,
            same_for_all_channels=False,
            p_per_sample=0.3,
            p_per_channel=0.5))

    tr_transforms.append(
        LocalGammaTransform(
            _sample_region_scale,
            (-0.5, 1.5),
            # Gamma is drawn from [0.01, 0.8) or [1.5, 4) with equal probability.
            lambda: np.random.uniform(0.01, 0.8)
            if np.random.uniform() < 0.5 else np.random.uniform(1.5, 4),
            same_for_all_channels=False,
            p_per_sample=0.3,
            p_per_channel=0.5))

    tr_transforms.append(
        SharpeningTransform(strength=(0.1, 1), same_for_each_channel=False,
                            p_per_sample=0.2, p_per_channel=0.5))

    if any(self.use_mask_for_norm.values()):
        tr_transforms.append(
            MaskTransform(self.use_mask_for_norm, mask_idx_in_seg=0, set_outside_to=0))

    # -1 is the segmentation padding value used by the spatial transform above.
    tr_transforms.append(RemoveLabelTransform(-1, 0))

    if self.data_aug_params["move_last_seg_chanel_to_data"]:
        all_class_labels = np.arange(1, self.num_classes)
        tr_transforms.append(MoveSegAsOneHotToData(1, all_class_labels, 'seg', 'data'))
        # Cascade augmentations act on the one-hot channels just appended to 'data', so
        # they are nested here (all_class_labels only exists in this branch).
        if self.data_aug_params["cascade_do_cascade_augmentations"]:
            tr_transforms.append(
                ApplyRandomBinaryOperatorTransform(
                    channel_idx=list(range(-len(all_class_labels), 0)),
                    p_per_sample=0.4,
                    key="data",
                    strel_size=(1, 8),
                    p_per_label=1))
            tr_transforms.append(
                RemoveRandomConnectedComponentFromOneHotEncodingTransform(
                    channel_idx=list(range(-len(all_class_labels), 0)),
                    key="data",
                    p_per_sample=0.2,
                    fill_with_other_class_p=0.15,
                    dont_do_if_covers_more_than_X_percent=0))

    tr_transforms.append(RenameTransform('seg', 'target', True))

    if self.regions is not None:
        tr_transforms.append(
            ConvertSegmentationToRegionsTransform(self.regions, 'target', 'target'))

    if self.deep_supervision_scales is not None:
        tr_transforms.append(
            DownsampleSegForDSTransform2(self.deep_supervision_scales, 0,
                                         input_key='target', output_key='target'))

    tr_transforms.append(NumpyToTensor(['data', 'target'], 'float'))
    return tr_transforms
# in order to use neptune logging: # export NEPTUNE_API_TOKEN = '...' !!! logging.getLogger().setLevel('INFO') source_files = [__file__] if hparams.config: source_files.append(hparams.config) neptune_logger = NeptuneLogger(project_name=hparams.neptune_project, params=vars(hparams), experiment_name=hparams.experiment_name, tags=[hparams.experiment_name], upload_source_files=source_files) tb_logger = loggers.TensorBoardLogger(hparams.log_dir) transform = Compose([ BrightnessTransform(mu=0.0, sigma=0.3, data_key='data'), GammaTransform(gamma_range=(0.7, 1.3), data_key='data'), ContrastAugmentationTransform(contrast_range=(0.3, 1.7), data_key='data') ]) with open(hparams.train_set, 'r') as keyfile: train_keys = [l.strip() for l in keyfile.readlines()] print(train_keys) with open(hparams.val_set, 'r') as keyfile: val_keys = [l.strip() for l in keyfile.readlines()] print(val_keys) train_ds = MedDataset(hparams.data_path, train_keys, hparams.patches_per_subject, hparams.patch_size,
border_mode_data='constant', border_cval_data=0,
    border_mode_seg='constant', border_cval_seg=0,
    order_seg=1, order_data=3,
    random_crop=False,
    p_el_per_sample=0.1, p_rot_per_sample=0.1, p_scale_per_sample=0.1)
# NOTE(review): fragment — the call closed above (presumably the construction of
# `spatial_transform`) starts before this chunk of the file.
my_transforms.append(spatial_transform)
my_transforms.append(MirrorTransform(axes=(0, 1, 2)))
my_transforms.append(
    GammaTransform(gamma_range=(0.7, 1.), invert_image=False, per_channel=True, p_per_sample=0.1))
all_transforms = Compose(my_transforms)
# Training data loader: applies the composed transforms on the fly, in the main process.
train_loader = SingleThreadedAugmenter(
    batchgen, all_transforms
)
# add other data loaders
test_loader = torch.utils.data.DataLoader(
    dataset_test,
    batch_size=1,
    shuffle=False,
    num_workers=0,
)
}, up_kwargs={
    'attention': True
}, encode_block=ResBlockStack, encode_kwargs_fn=encode_kwargs_fn, decode_block=ResBlock).cuda()
# NOTE(review): fragment — the model constructor above starts before this chunk and the
# SpatialTransform_2(...) call below continues after it.
patch_size = (160, 160, 80)
optimizer = optim.Adam(model.parameters(), lr=1e-4)
# Multiply the learning rate by 0.2 after 30 scheduler steps without improvement.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.2, patience=30)
# NOTE(review): the spatial transform (which crops to patch_size) runs after the intensity
# transforms here, so those operate on the uncropped data; confirm the order is intended.
tr_transform = Compose([
    GammaTransform((0.9, 1.1)),
    ContrastAugmentationTransform((0.9, 1.1)),
    BrightnessMultiplicativeTransform((0.9, 1.1)),
    MirrorTransform(axes=[0]),
    SpatialTransform_2(
        patch_size, (90, 90, 50),
        random_crop=True,
        do_elastic_deform=True, deformation_scale=(0, 0.05),
        do_rotation=True,
        angle_x=(-0.1 * np.pi, 0.1 * np.pi),  # +/- 18 degrees; y and z rotation disabled
        angle_y=(0, 0), angle_z=(0, 0),
        do_scale=True, scale=(0.9, 1.1),
from torchvision import transforms

# NOTE(review): `mt_transforms` is not imported in the visible code — confirm its source.
train_transform = transforms.Compose([
    # mt_transforms.CenterCrop2D((200, 200)),
    mt_transforms.ElasticTransform(alpha_range=(28.0, 30.0), sigma_range=(3.5, 4.0), p=0.3),
    mt_transforms.RandomAffine(degrees=4.6, scale=(0.98, 1.02), translate=(0.03, 0.03)),
    mt_transforms.RandomTensorChannelShift((-0.10, 0.10)),
    mt_transforms.ToTensor()
    # mt_transforms.NormalizeInstance(),
])

# Individual batchgenerators transforms reading the image from the "img" key (and the
# label from "seg" where label_key is given).
gamma_t = GammaTransform(data_key="img", gamma_range=(0.1, 10))
mirror_t = MirrorTransform(data_key="img", label_key="seg")
spatial_t = SpatialTransform(patch_size=(8, 8, 8), data_key="img", label_key="seg")
gauss_noise_t = GaussianNoiseTransform(data_key="img", noise_variance=(0, 1))
zoom_t = ZoomTransform(zoom_factors=2, data_key="img")


def show_basic(x, gt, info=None):
    """Print a "Test for <info>" header when ``info`` is given.

    NOTE(review): ``x`` and ``gt`` are unused in the visible body; the function may be
    truncated in this chunk — confirm.
    """
    if info is not None:
        print("Test for " + info)
def run(self, img_data, seg_data):
    """Run the configured batchgenerators augmentations ``self.cycles`` times.

    :param img_data: image batch. The final moveaxis converts "back" to channel-last,
        so the input is presumably channel-last, e.g. (batch, x, y, z, channel) — confirm.
    :param seg_data: segmentation batch matching ``img_data``.
    :return: ``(aug_img_data, aug_seg_data)``, channel-last. Batches from all cycles are
        stacked along axis 0, newest cycle first.
    :raises ValueError: if ``self.cycles`` is smaller than 1 (nothing to return).
    """
    # The segmentation is stored under "seg" only when it should be augmented together
    # with the image; otherwise it is stored under "class".
    seg_label = "seg" if self.seg_augmentation else "class"

    # Create a parser for the batchgenerators module
    data_generator = DataParser(img_data, seg_data, seg_label)

    transforms = []

    # Add mirror augmentation
    if self.mirror:
        transforms.append(MirrorTransform(axes=self.config_mirror_axes))

    # Add contrast augmentation
    if self.contrast:
        transforms.append(
            ContrastAugmentationTransform(
                self.config_contrast_range,
                preserve_range=self.config_contrast_preserverange,
                per_channel=self.coloraug_per_channel,
                p_per_sample=self.config_p_per_sample))

    # Add brightness augmentation
    if self.brightness:
        transforms.append(
            BrightnessMultiplicativeTransform(
                self.config_brightness_range,
                per_channel=self.coloraug_per_channel,
                p_per_sample=self.config_p_per_sample))

    # Add gamma augmentation
    if self.gamma:
        transforms.append(
            GammaTransform(self.config_gamma_range,
                           invert_image=False,
                           per_channel=self.coloraug_per_channel,
                           retain_stats=True,
                           p_per_sample=self.config_p_per_sample))

    # Add gaussian noise augmentation
    if self.gaussian_noise:
        transforms.append(
            GaussianNoiseTransform(self.config_gaussian_noise_range,
                                   p_per_sample=self.config_p_per_sample))

    # Add spatial transformations as augmentation (rotation, scaling, elastic
    # deformation, cropping)
    if self.rotations or self.scaling or self.elastic_deform or self.cropping:
        # Patch shape: the cropping shape, or the full spatial shape of one image
        # (the last axis is dropped).
        if self.cropping:
            patch_shape = self.cropping_patch_shape
        else:
            patch_shape = img_data[0].shape[0:-1]
        transforms.append(
            SpatialTransform(
                patch_shape, [i // 2 for i in patch_shape],
                do_elastic_deform=self.elastic_deform,
                alpha=self.config_elastic_deform_alpha,
                sigma=self.config_elastic_deform_sigma,
                do_rotation=self.rotations,
                angle_x=self.config_rotations_angleX,
                angle_y=self.config_rotations_angleY,
                angle_z=self.config_rotations_angleZ,
                do_scale=self.scaling,
                scale=self.config_scaling_range,
                border_mode_data='constant', border_cval_data=0,
                border_mode_seg='constant', border_cval_seg=0,
                order_data=3, order_seg=0,
                p_el_per_sample=self.config_p_per_sample,
                p_rot_per_sample=self.config_p_per_sample,
                p_scale_per_sample=self.config_p_per_sample,
                random_crop=self.cropping))

    # Assemble transforms into an augmentation generator
    augmentation_generator = SingleThreadedAugmenter(data_generator, Compose(transforms))

    # Collect one augmented batch per cycle and concatenate once at the end, instead of
    # re-concatenating inside the loop (which copies the growing array every cycle).
    img_batches = []
    seg_batches = []
    for _ in range(self.cycles):
        augmentation = next(augmentation_generator)
        img_batches.append(augmentation["data"])
        seg_batches.append(augmentation[seg_label])

    if not img_batches:
        raise ValueError("cycles must be >= 1 to produce augmented data")

    # Newest cycle first, matching the previous prepend-style accumulation.
    aug_img_data = np.concatenate(img_batches[::-1], axis=0)
    aug_seg_data = np.concatenate(seg_batches[::-1], axis=0)

    # Transform data from channel-first back to channel-last structure
    # Data structure channel-first 3D: (batch, channel, x, y, z)
    # Data structure channel-last 3D: (batch, x, y, z, channel)
    aug_img_data = np.moveaxis(aug_img_data, 1, -1)
    aug_seg_data = np.moveaxis(aug_seg_data, 1, -1)

    # Return augmented image and segmentation data
    return aug_img_data, aug_seg_data