def setup_augmentation(self):
    """Build the train/val/test augmenters from the config's transform specs.

    Each entry of ``config.transforms_train`` / ``config.transforms_val`` is a
    dict with keys ``active``, ``type`` (transform class) and ``kwargs``;
    active entries are instantiated in sorted-key order and composed. The test
    augmenter reuses the validation transforms and kwargs.
    """
    cfg = self.config

    def active_transforms(spec):
        # Instantiate every active transform in deterministic (sorted) order.
        return [spec[name]["type"](**spec[name]["kwargs"])
                for name in sorted(spec.keys())
                if spec[name]["active"]]

    self.augmenter_train = cfg.augmenter_train(
        self.generator_train,
        Compose(active_transforms(cfg.transforms_train)),
        **cfg.augmenter_train_kwargs)

    val_transforms = active_transforms(cfg.transforms_val)
    self.augmenter_val = cfg.augmenter_val(
        self.generator_val, Compose(val_transforms), **cfg.augmenter_val_kwargs)
    # the test augmenter deliberately shares the validation transform list
    self.augmenter_test = cfg.augmenter_val(
        self.generator_test, Compose(val_transforms), **cfg.augmenter_val_kwargs)
def create_data_gen_train(patient_data_train, BATCH_SIZE, num_classes, num_workers=5,
                          num_cached_per_worker=2, do_elastic_transform=False,
                          alpha=(0., 1300.), sigma=(10., 13.), do_rotation=False,
                          a_x=(0., 2*np.pi), a_y=(0., 2*np.pi), a_z=(0., 2*np.pi),
                          do_scale=True, scale_range=(0.75, 1.25), seeds=None):
    """Create the multithreaded 2D training generator (352x352 patches).

    ``seeds`` may be None (unseeded workers), the string 'range' (worker index
    as seed), or an explicit per-worker list.
    """
    if seeds is None:
        seeds = [None] * num_workers
    elif seeds == 'range':
        seeds = range(num_workers)
    else:
        assert len(seeds) == num_workers

    patch_size = (352, 352)
    base_loader = BatchGenerator_2D(patient_data_train, BATCH_SIZE, num_batches=None,
                                    seed=False, PATCH_SIZE=patch_size)

    spatial = SpatialTransform(patch_size, list(np.array(patch_size) // 2),
                               do_elastic_transform, alpha, sigma,
                               do_rotation, a_x, a_y, a_z,
                               do_scale, scale_range,
                               'constant', 0, 3, 'constant', 0, 0,
                               random_crop=False)
    pipeline = Compose([
        Mirror((2, 3)),
        # with p=0.67 apply the full spatial transform, otherwise just random-crop
        RndTransform(spatial, prob=0.67,
                     alternative_transform=RandomCropTransform(patch_size)),
        ConvertSegToOnehotTransform(range(num_classes), seg_channel=0,
                                    output_key='seg_onehot'),
    ])

    tr_mt_gen = MultiThreadedAugmenter(base_loader, pipeline, num_workers,
                                       num_cached_per_worker, seeds)
    tr_mt_gen.restart()
    return tr_mt_gen
def get_transforms(mode="train", target_size=128):
    """Return the composed preprocessing pipeline for the given split.

    Training adds mirroring and an elastic/rotation/scale SpatialTransform on
    top of the resize; val/test only center-crop and resize. A NumpyToTensor
    conversion is always appended (for an unknown mode it is the only step).
    """
    if mode == "train":
        transform_list = [
            # CenterCropTransform(crop_size=target_size),
            ResizeTransform(target_size=target_size, order=1),
            MirrorTransform(axes=(1,)),
            SpatialTransform(patch_size=(target_size, target_size),
                             random_crop=False,
                             patch_center_dist_from_border=target_size // 2,
                             do_elastic_deform=True, alpha=(0., 1000.), sigma=(40., 60.),
                             do_rotation=True, p_rot_per_sample=0.5,
                             angle_x=(-0.1, 0.1), angle_y=(0, 1e-8), angle_z=(0, 1e-8),
                             scale=(0.5, 1.9), p_scale_per_sample=0.5,
                             border_mode_data="nearest", border_mode_seg="nearest"),
        ]
    elif mode in ("val", "test"):
        # validation and test share the same deterministic pipeline
        transform_list = [
            CenterCropTransform(crop_size=target_size),
            ResizeTransform(target_size=target_size, order=1),
        ]
    else:
        transform_list = []
    transform_list.append(NumpyToTensor())
    return Compose(transform_list)
def test_image_pipeline_and_pin_memory(self):
    """Smoke test: the pipeline with pin_memory enabled must not crash and
    must deliver pinned torch tensors."""
    try:
        import torch
    except ImportError:
        # nothing to verify if torch is not installed
        return
    from batchgenerators.transforms import MirrorTransform, NumpyToTensor, TransposeAxesTransform, Compose

    composed = Compose([
        MirrorTransform(),
        TransposeAxesTransform(transpose_any_of_these=(0, 1), p_per_sample=0.5),
        NumpyToTensor(keys='data', cast_to='float'),
    ])
    mt = MultiThreadedAugmenter(self.dl_images, composed, 4, 1, None, True)
    for _ in range(50):
        batch = mt.next()
        assert isinstance(batch['data'], torch.Tensor)
        assert batch['data'].is_pinned()
    # let mt finish caching, otherwise it's going to print an error (which is not a
    # problem and will not prevent the success of the test but it does not look pretty)
    sleep(2)
def get_train_transform(patch_size):
    """Training pipeline: rotation-only SpatialTransform_2 (elastic and scale
    disabled) plus a mirror applied with probability 0.3."""
    deg15 = 15 / 360. * 2 * np.pi  # +/- 15 degrees expressed in radians
    spatial = SpatialTransform_2(
        None, [i // 2 for i in patch_size],
        do_elastic_deform=False, deformation_scale=(0, 0.25),
        do_rotation=True,
        angle_x=(-deg15, deg15), angle_y=(-deg15, deg15), angle_z=(-deg15, deg15),
        do_scale=False, scale=(0.8, 1.2),
        border_mode_data='constant', border_cval_data=0,
        border_mode_seg='constant', border_cval_seg=0,
        order_seg=1, order_data=3, random_crop=False,
        p_el_per_sample=0.3, p_rot_per_sample=0.3, p_scale_per_sample=0.3)
    mirror = RndTransform(MirrorTransform(axes=(0, 1, 2)), prob=0.3)
    return Compose(transforms=[spatial, mirror])
def get_train_val_generators(fold):
    """Build the (train, validation) generator pair for one CV fold.

    Training data is heavily augmented via ``create_data_gen_train``;
    validation only converts the segmentation to one-hot.
    """
    tr_keys, te_keys = get_split(fold, split_seed)
    train_data = {k: dataset[k] for k in tr_keys}
    val_data = {k: dataset[k] for k in te_keys}

    data_gen_train = create_data_gen_train(
        train_data, BATCH_SIZE, num_classes, INPUT_PATCH_SIZE,
        num_workers=num_workers,
        do_elastic_transform=True, alpha=(0., 350.), sigma=(14., 17.),
        do_rotation=True, a_x=(0, 2. * np.pi),
        a_y=(-0.000001, 0.00001), a_z=(-0.000001, 0.00001),
        do_scale=True, scale_range=(0.7, 1.3),
        seeds=workers_seeds)  # new se has no brain mask

    val_loader = BatchGenerator(val_data, BATCH_SIZE, num_batches=None,
                                seed=False, PATCH_SIZE=INPUT_PATCH_SIZE)
    val_pipeline = Compose([ConvertSegToOnehotTransform(range(4), 0, 'seg_onehot')])
    data_gen_validation = MultiThreadedAugmenter(val_loader, val_pipeline, 1, 2, [0])
    return data_gen_train, data_gen_validation
def test_future(self):
    """Run the future-prediction evaluation and persist scores and info.

    Extends the default metric coordinates with the future/prior metrics,
    builds a 3-timestep generator manually (the regular generators only
    produce one timestep) and delegates scoring to ``test_inner``.
    """
    info = self.validate_make_default_info()
    info["coords"]["Metric"] = info["coords"]["Metric"] + [
        "Future KL", "Future Reconstruction NLL", "Future Reconstruction Dice",
        "Prior Maximum Dice", "Prior Best Volume Dice"
    ]

    # our regular generators only produce 1 timestep, so we create this here manually
    test_data = self.data_val if self.config.test_on_val else self.data_test
    generator = self.config.generator_val(
        test_data, self.config.batch_size_val, 3,
        number_of_threads_in_multithreaded=self.config.augmenter_val_kwargs.num_processes)

    active = []
    for name in sorted(self.config.transforms_val.keys()):
        spec = self.config.transforms_val[name]
        if spec["active"]:
            active.append(spec["type"](**spec["kwargs"]))
    augmenter = self.config.augmenter_val(
        generator, Compose(active), **self.config.augmenter_val_kwargs)

    test_scores, info = self.test_inner(augmenter, [], info, future=True)
    self.elog.save_numpy_data(np.array(test_scores), "test_future.npy")
    self.elog.save_dict(info, "test_future.json")
def get_train_transform(patch_size):
    """
    data augmentation for training data, inspired by:
    https://github.com/MIC-DKFZ/batchgenerators/blob/master/batchgenerators/examples/brats2017/brats2017_dataloader_3D.py
    :param patch_size: shape of network's input
    :return list of transformations
    """
    def rad(deg):
        # symmetric +/- deg degree range, converted to radians
        return (-deg / 360 * 2 * np.pi, deg / 360 * 2 * np.pi)

    transforms = [
        SpatialTransform_2(
            patch_size, (10, 10, 10),
            do_elastic_deform=True, deformation_scale=(0, 0.25),
            do_rotation=True, angle_z=rad(15), angle_x=(0, 0), angle_y=(0, 0),
            do_scale=True, scale=(0.75, 1.25),
            border_mode_data='constant', border_cval_data=0,
            border_mode_seg='constant', border_cval_seg=0,
            order_seg=1, random_crop=False,
            p_el_per_sample=0.2, p_rot_per_sample=0.2, p_scale_per_sample=0.2,
        ),
        MirrorTransform(axes=(0, 1)),
        BrightnessMultiplicativeTransform((0.7, 1.5), per_channel=True,
                                          p_per_sample=0.2),
        GammaTransform(gamma_range=(0.2, 1.0), invert_image=False,
                       per_channel=False, p_per_sample=0.2),
        GaussianNoiseTransform(noise_variance=(0, 0.05), p_per_sample=0.2),
        # NOTE(review): p_per_channel=0.0 means no channel is ever blurred,
        # which makes this transform a no-op -- confirm this is intended.
        GaussianBlurTransform(blur_sigma=(0.2, 1.0),
                              different_sigma_per_channel=False,
                              p_per_channel=0.0, p_per_sample=0.2),
    ]
    return Compose(transforms)
def get_no_augmentation(dataloader_train, dataloader_val, patch_size,
                        params=default_3D_augmentation_params, border_val_seg=-1):
    """
    use this instead of get_default_augmentation (drop in replacement) to turn off all
    data augmentation

    :param dataloader_train:
    :param dataloader_val:
    :param patch_size:
    :param params:
    :param border_val_seg:
    :return: (train_augmenter, val_augmenter)
    """
    tr_transforms = []
    if params.get("selected_data_channels") is not None:
        tr_transforms.append(
            DataChannelSelectionTransform(params.get("selected_data_channels")))
    if params.get("selected_seg_channels") is not None:
        tr_transforms.append(
            SegChannelSelectionTransform(params.get("selected_seg_channels")))
    tr_transforms += [
        RemoveLabelTransform(-1, 0),
        RenameTransform('seg', 'target', True),
        NumpyToTensor(['data', 'target'], 'float'),
    ]
    n_threads = params.get('num_threads')
    batchgenerator_train = MultiThreadedAugmenter(
        dataloader_train, Compose(tr_transforms), n_threads,
        params.get("num_cached_per_thread"), seeds=range(n_threads),
        pin_memory=True)
    batchgenerator_train.restart()

    # validation differs slightly: label cleanup runs before channel selection
    val_transforms = [RemoveLabelTransform(-1, 0)]
    if params.get("selected_data_channels") is not None:
        val_transforms.append(
            DataChannelSelectionTransform(params.get("selected_data_channels")))
    if params.get("selected_seg_channels") is not None:
        val_transforms.append(
            SegChannelSelectionTransform(params.get("selected_seg_channels")))
    val_transforms += [
        RenameTransform('seg', 'target', True),
        NumpyToTensor(['data', 'target'], 'float'),
    ]
    n_val_threads = max(n_threads // 2, 1)
    batchgenerator_val = MultiThreadedAugmenter(
        dataloader_val, Compose(val_transforms), n_val_threads,
        params.get("num_cached_per_thread"), seeds=range(n_val_threads),
        pin_memory=True)
    batchgenerator_val.restart()
    return batchgenerator_train, batchgenerator_val
def get_train_transform(patch_size):
    """Showcase training pipeline (not necessarily tuned for BraTS)."""
    rot = 15 / 360. * 2 * np.pi  # +/- 15 degrees in radians
    tr_transforms = [
        # SpatialTransform_2 runs first: it crops to patch_size, so all later
        # transforms operate on the smaller patch (cheaper) and, since they do
        # not transform spatially, introduce no border artifacts. It uses the
        # new elastic-deform parameterization. With p=0.1 per spatial op,
        # 1 - (1 - 0.1) ** 3 = 27% of samples get augmented, the rest are
        # just cropped.
        SpatialTransform_2(
            patch_size, [i // 2 for i in patch_size],
            do_elastic_deform=True, deformation_scale=(0, 0.25),
            do_rotation=True,
            angle_x=(-rot, rot), angle_y=(-rot, rot), angle_z=(-rot, rot),
            do_scale=True, scale=(0.75, 1.25),
            border_mode_data='constant', border_cval_data=0,
            border_mode_seg='constant', border_cval_seg=0,
            order_seg=1, order_data=3, random_crop=True,
            p_el_per_sample=0.1, p_rot_per_sample=0.1, p_scale_per_sample=0.1),
        # mirror along all spatial axes
        MirrorTransform(axes=(0, 1, 2)),
        # gamma: nonlinear intensity transform
        # (https://en.wikipedia.org/wiki/Gamma_correction)
        GammaTransform(gamma_range=(0.5, 2), invert_image=False,
                       per_channel=True, p_per_sample=0.15),
        # same gamma, but applied to the inverted image and inverted back
        GammaTransform(gamma_range=(0.5, 2), invert_image=True,
                       per_channel=True, p_per_sample=0.15),
        GaussianNoiseTransform(noise_variance=(0, 0.05), p_per_sample=0.15),
    ]
    return Compose(tr_transforms)
def setUp(self) -> None:
    """Create plain and numba-wrapped zoom/pad/compose transforms plus a
    random input batch for the tests."""
    from delira.data_loading.numba_transform import NumbaTransform, \
        NumbaCompose

    self._basic_zoom_trafo = ZoomTransform(3)
    self._numba_zoom_trafo = NumbaTransform(ZoomTransform, zoom_factors=3)
    self._basic_pad_trafo = PadTransform(new_size=(30, 30))
    self._numba_pad_trafo = NumbaTransform(PadTransform, new_size=(30, 30))
    self._basic_compose_trafo = Compose(
        [self._basic_pad_trafo, self._basic_zoom_trafo])
    # NOTE(review): the numba compose also wraps the *basic* trafos here --
    # presumably NumbaCompose converts its members internally; confirm.
    self._numba_compose_trafo = NumbaCompose(
        [self._basic_pad_trafo, self._basic_zoom_trafo])
    self._input = {"data": np.random.rand(10, 1, 24, 24)}
def get_validation_transforms(self):
    """Validation pipeline: optional channel selection, center crop, label
    cleanup, seg->target rename and tensor conversion (no augmentation)."""
    transforms = []
    if self.params.get("selected_data_channels"):
        transforms.append(DataChannelSelectionTransform(
            self.params.get("selected_data_channels")))
    if self.params.get("selected_seg_channels"):
        transforms.append(SegChannelSelectionTransform(
            self.params.get("selected_seg_channels")))
    transforms.extend([
        CenterCropTransform(self.patch_size),
        RemoveLabelTransform(-1, 0),
        RenameTransform('seg', 'target', True),
        NumpyToTensor(['data', 'target'], 'float'),
    ])
    return Compose(transforms)
def test_image_pipeline(self):
    """Smoke test: the image pipeline must run through the multithreaded
    augmenter without crashing."""
    from batchgenerators.transforms import MirrorTransform, TransposeAxesTransform, Compose

    composed = Compose([
        MirrorTransform(),
        TransposeAxesTransform(transpose_any_of_these=(0, 1), p_per_sample=0.5),
    ])
    mt = MultiThreadedAugmenter(self.dl_images, composed, 4, 1, None, False)
    for _ in range(50):
        res = mt.next()
    # let mt finish caching, otherwise it's going to print an error (which is not a
    # problem and will not prevent the success of the test but it does not look pretty)
    sleep(2)
def get_augmentation(patch_size, params=default_3D_augmentation_params,
                     border_val_seg=-1):
    """Compose a spatial transform (params-driven), optional gamma and a
    mirror into the training augmentation pipeline."""
    print(f'patch size after augmentation {patch_size}')

    spatial = SpatialTransform(
        patch_size, patch_center_dist_from_border=None,
        do_elastic_deform=params.get("do_elastic"),
        alpha=params.get("elastic_deform_alpha"),
        sigma=params.get("elastic_deform_sigma"),
        do_rotation=params.get("do_rotation"),
        angle_x=params.get("rotation_x"),
        angle_y=params.get("rotation_y"),
        angle_z=params.get("rotation_z"),
        do_scale=params.get("do_scaling"),
        scale=params.get("scale_range"),
        border_mode_data=params.get("border_mode_data"),
        border_cval_data=0, order_data=3,
        border_mode_seg="constant", border_cval_seg=border_val_seg,
        order_seg=1, random_crop=params.get("random_crop"),
        p_el_per_sample=params.get("p_eldef"),
        p_scale_per_sample=params.get("p_scale"),
        p_rot_per_sample=params.get("p_rot"))

    pipeline = [spatial]
    if params.get("do_gamma"):
        pipeline.append(GammaTransform(params.get("gamma_range"), False, True,
                                       retain_stats=params.get("gamma_retain_stats"),
                                       p_per_sample=params["p_gamma"]))
    pipeline.append(MirrorTransform(params.get("mirror_axes")))
    return Compose(pipeline)
def create_data_gen_train(patient_data_train, INPUT_PATCH_SIZE, num_classes, BATCH_SIZE,
                          contrast_range=(0.75, 1.5), gamma_range=(0.6, 2),
                          num_workers=5, num_cached_per_worker=3,
                          do_elastic_transform=False, alpha=(0., 1300.), sigma=(10., 13.),
                          do_rotation=False, a_x=(0., 2*np.pi), a_y=(0., 2*np.pi),
                          a_z=(0., 2*np.pi), do_scale=True, scale_range=(0.75, 1.25),
                          seeds=None):
    """Create the multithreaded 3D training generator with brain-mask-aware
    intensity augmentation.

    ``seeds`` may be None (unseeded), 'range' (worker index as seed) or an
    explicit per-worker list.
    """
    if seeds is None:
        seeds = [None] * num_workers
    elif seeds == 'range':
        seeds = range(num_workers)
    else:
        assert len(seeds) == num_workers

    base_loader = BatchGenerator3D_random_sampling(
        patient_data_train, BATCH_SIZE, num_batches=None, seed=False,
        patch_size=(160, 192, 160), convert_labels=True)

    pipeline = [
        DataChannelSelectionTransform([0, 1, 2, 3]),
        GenerateBrainMaskTransform(),
        MirrorTransform(),
        SpatialTransform(INPUT_PATCH_SIZE, list(np.array(INPUT_PATCH_SIZE) // 2.),
                         do_elastic_deform=do_elastic_transform, alpha=alpha, sigma=sigma,
                         do_rotation=do_rotation, angle_x=a_x, angle_y=a_y, angle_z=a_z,
                         do_scale=do_scale, scale=scale_range,
                         border_mode_data='nearest', border_cval_data=0, order_data=3,
                         border_mode_seg='constant', border_cval_seg=0, order_seg=0,
                         random_crop=True),
        # normalize, jitter intensities, then normalize again
        BrainMaskAwareStretchZeroOneTransform((-5, 5), True),
        ContrastAugmentationTransform(contrast_range, True),
        GammaTransform(gamma_range, False),
        BrainMaskAwareStretchZeroOneTransform(per_channel=True),
        BrightnessTransform(0.0, 0.1, True),
        SegChannelSelectionTransform([0]),
        ConvertSegToOnehotTransform(range(num_classes), 0, "seg_onehot"),
    ]

    gen_train = MultiThreadedAugmenter(base_loader, Compose(pipeline),
                                       num_workers, num_cached_per_worker, seeds)
    gen_train.restart()
    return gen_train
def get_data_augmenter(data, batch_size=1, mode=DataLoader.Mode.NORMAL, volumetric=True,
                       normalization_range=None, vector_generator=None, input_shape=None,
                       sample_count=1, transforms=None, threads=1, seed=None):
    """Wrap a DataLoader in a MultiThreadedAugmenter, appending a final
    PrepareForTF step to the (optional) transform list."""
    # never spawn more workers than there are batches to produce
    batch_count = int(np.ceil(len(data) / batch_size))
    threads = min(batch_count, threads)

    loader = DataLoader(data=data, batch_size=batch_size, mode=mode,
                        volumetric=volumetric,
                        normalization_range=normalization_range,
                        vector_generator=vector_generator,
                        input_shape=input_shape, sample_count=sample_count,
                        number_of_threads_in_multithreaded=threads, seed=seed)

    # build a fresh list so the caller's transform list is never mutated
    pipeline = ([] if transforms is None else transforms) + [PrepareForTF()]
    return MultiThreadedAugmenter(loader, Compose(pipeline), threads)
def get_valid_transform(patch_size):
    """
    data augmentation for validation data inspired by:
    https://github.com/MIC-DKFZ/batchgenerators/blob/master/batchgenerators/examples/brats2017/brats2017_dataloader_3D.py
    :param patch_size: shape of network's input
    :return list of transformations
    """
    # effectively an identity spatial transform: all deform/rotate/scale flags
    # are off, so the p_* values have no effect. NOTE(review): random_crop=True
    # and the mirror below still randomize validation batches -- confirm intended.
    transforms = [
        SpatialTransform_2(patch_size, patch_size,
                           do_elastic_deform=False, deformation_scale=(0, 0),
                           do_rotation=False,
                           angle_x=(0, 0), angle_y=(0, 0), angle_z=(0, 0),
                           do_scale=False, scale=(1.0, 1.0),
                           border_mode_data='constant', border_cval_data=0,
                           border_mode_seg='constant', border_cval_seg=0,
                           order_seg=1, order_data=3, random_crop=True,
                           p_el_per_sample=0.1, p_rot_per_sample=0.1,
                           p_scale_per_sample=0.1),
        MirrorTransform(axes=(0, 1)),
    ]
    return Compose(transforms)
def get_transforms(mode="train", target_size=128):
    """Return the composed pipeline for the given split.

    Every known split resizes and mirrors along axis 1; a NumpyToTensor step
    is always appended (for an unknown mode it is the only step).
    NOTE(review): mirroring is also applied for "val"/"test" -- confirm this
    randomization of evaluation data is intended.
    """
    if mode == "train":
        transform_list = [
            # CenterCropTransform(crop_size=target_size),
            ResizeTransform(target_size=(target_size, target_size), order=1),
            MirrorTransform(axes=(1, )),
        ]
    elif mode in ("val", "test"):
        # NOTE(review): unlike "train", these pass a scalar target_size to
        # ResizeTransform -- confirm both forms resize identically.
        transform_list = [
            # CenterCropTransform(crop_size=target_size),
            ResizeTransform(target_size=target_size, order=1),
            MirrorTransform(axes=(1, )),
        ]
    else:
        transform_list = []
    transform_list.append(NumpyToTensor())
    return Compose(transform_list)
from delira.training.backends import convert_torch_to_numpy
from functools import partial
from delira.data_loading import DataManager, SequentialSampler
from delira_unet import UNetTorch
from delira import set_debug_mode

if __name__ == "__main__":
    # paths must be filled in before running
    checkpoint_path = ""
    data_path = ""
    save_path = ""
    set_debug_mode(True)

    # preprocessing: keep a copy of the raw input, then scale to [-1, 1]
    transforms = Compose([
        CopyTransform("data", "data_orig"),
        # HistogramEqualization(),
        RangeTransform((-1, 1)),
        # AddGridTransform(),
    ])
    img_size = (1024, 256)
    thresh = 0.5

    print("Load Model")
    # BUG FIX: was `torch.jit.loat(checkpoint_path)` -- a typo (torch.jit has
    # no `loat`) and the loaded module was never bound, although `model` is
    # used on the next line. Load the scripted model and assign it.
    model = torch.jit.load(checkpoint_path)
    model.eval()
    if torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")
def get_training_transforms(self):
    """Assemble the full training augmentation pipeline from ``self.params``:
    channel selection, (dummy-2D) spatial transform, intensity augmentations,
    optional mirror/mask, label cleanup and tensor conversion."""
    get = self.params.get
    assert get('mirror') is None, "old version of params, use new keyword do_mirror"

    pipeline = []
    if get("selected_data_channels"):
        pipeline.append(DataChannelSelectionTransform(get("selected_data_channels")))
    if get("selected_seg_channels"):
        pipeline.append(SegChannelSelectionTransform(get("selected_seg_channels")))

    # don't do color augmentations while in 2d mode with 3d data because the
    # color channel is overloaded!!
    if get("dummy_2D", False):
        ignore_axes = (0, )
        pipeline.append(Convert3DTo2DTransform())
    else:
        ignore_axes = None

    pipeline.append(SpatialTransform(
        self._spatial_transform_patch_size,
        patch_center_dist_from_border=None,
        do_elastic_deform=get("do_elastic"),
        alpha=get("elastic_deform_alpha"),
        sigma=get("elastic_deform_sigma"),
        do_rotation=get("do_rotation"),
        angle_x=get("rotation_x"),
        angle_y=get("rotation_y"),
        angle_z=get("rotation_z"),
        do_scale=get("do_scaling"),
        scale=get("scale_range"),
        order_data=get("order_data"),
        border_mode_data=get("border_mode_data"),
        border_cval_data=get("border_cval_data"),
        order_seg=get("order_seg"),
        border_mode_seg=get("border_mode_seg"),
        border_cval_seg=get("border_cval_seg"),
        random_crop=get("random_crop"),
        p_el_per_sample=get("p_eldef"),
        p_scale_per_sample=get("p_scale"),
        p_rot_per_sample=get("p_rot"),
        independent_scale_for_each_axis=get("independent_scale_factor_for_each_axis"),
    ))
    if get("dummy_2D"):
        pipeline.append(Convert2DTo3DTransform())

    # the color augmentations must come after the dummy-2d round trip (if
    # applicable), otherwise the overloaded color channel gets in the way
    pipeline.append(GaussianNoiseTransform(p_per_sample=0.15))
    pipeline.append(GaussianBlurTransform((0.5, 1.5),
                                          different_sigma_per_channel=True,
                                          p_per_sample=0.2, p_per_channel=0.5))
    pipeline.append(BrightnessMultiplicativeTransform(multiplier_range=(0.75, 1.3),
                                                      p_per_sample=0.15))
    if get("do_additive_brightness"):
        pipeline.append(BrightnessTransform(
            get("additive_brightness_mu"),
            get("additive_brightness_sigma"),
            True,
            p_per_sample=get("additive_brightness_p_per_sample"),
            p_per_channel=get("additive_brightness_p_per_channel")))
    pipeline.append(ContrastAugmentationTransform(contrast_range=(0.65, 1.5),
                                                  p_per_sample=0.15))
    pipeline.append(SimulateLowResolutionTransform(zoom_range=(0.5, 1),
                                                   per_channel=True,
                                                   p_per_channel=0.5,
                                                   order_downsample=0,
                                                   order_upsample=3,
                                                   p_per_sample=0.25,
                                                   ignore_axes=ignore_axes))
    # inverted gamma
    pipeline.append(GammaTransform(get("gamma_range"), True, True,
                                   retain_stats=get("gamma_retain_stats"),
                                   p_per_sample=0.15))
    if get("do_gamma"):
        pipeline.append(GammaTransform(get("gamma_range"), False, True,
                                       retain_stats=get("gamma_retain_stats"),
                                       p_per_sample=self.params["p_gamma"]))

    if get("do_mirror") or get("mirror"):
        pipeline.append(MirrorTransform(get("mirror_axes")))
    if get("use_mask_for_norm"):
        pipeline.append(MaskTransform(get("use_mask_for_norm"),
                                      mask_idx_in_seg=0, set_outside_to=0))
    pipeline.append(RemoveLabelTransform(-1, 0))
    pipeline.append(RenameTransform('seg', 'target', True))
    pipeline.append(NumpyToTensor(['data', 'target'], 'float'))
    return Compose(pipeline)
def get_training_transforms(self):
    """Assemble the reduced training augmentation pipeline from
    ``self.params``: channel selection, (dummy-2D) spatial transform,
    optional gamma/mirror/mask, label cleanup and tensor conversion."""
    get = self.params.get
    assert get('mirror') is None, "old version of params, use new keyword do_mirror"

    pipeline = []
    if get("selected_data_channels"):
        pipeline.append(DataChannelSelectionTransform(get("selected_data_channels")))
    if get("selected_seg_channels"):
        pipeline.append(SegChannelSelectionTransform(get("selected_seg_channels")))

    if get("dummy_2D", False):
        # don't do color augmentations while in 2d mode with 3d data because
        # the color channel is overloaded!!
        pipeline.append(Convert3DTo2DTransform())

    pipeline.append(SpatialTransform(
        self._spatial_transform_patch_size,
        patch_center_dist_from_border=None,
        do_elastic_deform=get("do_elastic"),
        alpha=get("elastic_deform_alpha"),
        sigma=get("elastic_deform_sigma"),
        do_rotation=get("do_rotation"),
        angle_x=get("rotation_x"),
        angle_y=get("rotation_y"),
        angle_z=get("rotation_z"),
        do_scale=get("do_scaling"),
        scale=get("scale_range"),
        order_data=get("order_data"),
        border_mode_data=get("border_mode_data"),
        border_cval_data=get("border_cval_data"),
        order_seg=get("order_seg"),
        border_mode_seg=get("border_mode_seg"),
        border_cval_seg=get("border_cval_seg"),
        random_crop=get("random_crop"),
        p_el_per_sample=get("p_eldef"),
        p_scale_per_sample=get("p_scale"),
        p_rot_per_sample=get("p_rot"),
        independent_scale_for_each_axis=get("independent_scale_factor_for_each_axis"),
    ))
    if get("dummy_2D", False):
        pipeline.append(Convert2DTo3DTransform())

    if get("do_gamma", False):
        pipeline.append(GammaTransform(get("gamma_range"), False, True,
                                       retain_stats=get("gamma_retain_stats"),
                                       p_per_sample=self.params["p_gamma"]))
    if get("do_mirror", False):
        pipeline.append(MirrorTransform(get("mirror_axes")))
    if get("use_mask_for_norm"):
        pipeline.append(MaskTransform(get("use_mask_for_norm"),
                                      mask_idx_in_seg=0, set_outside_to=0))
    pipeline.append(RemoveLabelTransform(-1, 0))
    pipeline.append(RenameTransform('seg', 'target', True))
    pipeline.append(NumpyToTensor(['data', 'target'], 'float'))
    return Compose(pipeline)
def get_insaneDA_augmentation(dataloader_train, dataloader_val, patch_size,
                              params=default_3D_augmentation_params,
                              border_val_seg=-1, seeds_train=None, seeds_val=None,
                              order_seg=1, order_data=3,
                              deep_supervision_scales=None, soft_ds=False,
                              classes=None, pin_memory=True):
    """Build the heavily augmented ("insane DA") training augmenter and a
    plain validation augmenter.

    :param dataloader_train: batchgenerators loader for training batches
    :param dataloader_val: batchgenerators loader for validation batches
    :param patch_size: spatial size the SpatialTransform produces
    :param params: augmentation parameter dict (nnU-Net style)
    :param border_val_seg: border fill value for the segmentation
    :param seeds_train: per-worker seeds for the training augmenter
    :param seeds_val: per-worker seeds for the validation augmenter
    :param order_seg: interpolation order for the segmentation
    :param order_data: interpolation order for the image data
    :param deep_supervision_scales: if given, downsample 'target' accordingly
    :param soft_ds: use soft (one-hot) downsampling; requires ``classes``
    :param classes: class labels for soft deep supervision
    :param pin_memory: forwarded to MultiThreadedAugmenter
    :return: (batchgenerator_train, batchgenerator_val)
    """
    assert params.get(
        'mirror') is None, "old version of params, use new keyword do_mirror"

    tr_transforms = []
    if params.get("selected_data_channels") is not None:
        tr_transforms.append(
            DataChannelSelectionTransform(params.get("selected_data_channels")))
    if params.get("selected_seg_channels") is not None:
        tr_transforms.append(
            SegChannelSelectionTransform(params.get("selected_seg_channels")))

    # don't do color augmentations while in 2d mode with 3d data because the
    # color channel is overloaded!!
    if params.get("dummy_2D") is not None and params.get("dummy_2D"):
        ignore_axes = (0, )
        tr_transforms.append(Convert3DTo2DTransform())
    else:
        ignore_axes = None

    tr_transforms.append(
        SpatialTransform(patch_size,
                         patch_center_dist_from_border=None,
                         do_elastic_deform=params.get("do_elastic"),
                         alpha=params.get("elastic_deform_alpha"),
                         sigma=params.get("elastic_deform_sigma"),
                         do_rotation=params.get("do_rotation"),
                         angle_x=params.get("rotation_x"),
                         angle_y=params.get("rotation_y"),
                         angle_z=params.get("rotation_z"),
                         do_scale=params.get("do_scaling"),
                         scale=params.get("scale_range"),
                         border_mode_data=params.get("border_mode_data"),
                         border_cval_data=0,
                         order_data=order_data,
                         border_mode_seg="constant",
                         border_cval_seg=border_val_seg,
                         order_seg=order_seg,
                         random_crop=params.get("random_crop"),
                         p_el_per_sample=params.get("p_eldef"),
                         p_scale_per_sample=params.get("p_scale"),
                         p_rot_per_sample=params.get("p_rot"),
                         independent_scale_for_each_axis=params.get(
                             "independent_scale_factor_for_each_axis")))
    if params.get("dummy_2D"):
        tr_transforms.append(Convert2DTo3DTransform())

    # the color augmentations must come after the dummy-2d round trip (if
    # applicable), otherwise the overloaded color channel gets in the way
    tr_transforms.append(GaussianNoiseTransform(p_per_sample=0.15))
    tr_transforms.append(
        GaussianBlurTransform((0.5, 1.5),
                              different_sigma_per_channel=True,
                              p_per_sample=0.2,
                              p_per_channel=0.5))
    tr_transforms.append(
        BrightnessMultiplicativeTransform(multiplier_range=(0.70, 1.3),
                                          p_per_sample=0.15))
    tr_transforms.append(
        ContrastAugmentationTransform(contrast_range=(0.65, 1.5),
                                      p_per_sample=0.15))
    tr_transforms.append(
        SimulateLowResolutionTransform(zoom_range=(0.5, 1),
                                       per_channel=True,
                                       p_per_channel=0.5,
                                       order_downsample=0,
                                       order_upsample=3,
                                       p_per_sample=0.25,
                                       ignore_axes=ignore_axes))
    # inverted gamma
    tr_transforms.append(
        GammaTransform(params.get("gamma_range"), True, True,
                       retain_stats=params.get("gamma_retain_stats"),
                       p_per_sample=0.15))
    if params.get("do_gamma"):
        tr_transforms.append(
            GammaTransform(params.get("gamma_range"), False, True,
                           retain_stats=params.get("gamma_retain_stats"),
                           p_per_sample=params["p_gamma"]))
    if params.get("do_mirror") or params.get("mirror"):
        tr_transforms.append(MirrorTransform(params.get("mirror_axes")))
    if params.get("mask_was_used_for_normalization") is not None:
        mask_was_used_for_normalization = params.get(
            "mask_was_used_for_normalization")
        tr_transforms.append(
            MaskTransform(mask_was_used_for_normalization,
                          mask_idx_in_seg=0, set_outside_to=0))
    tr_transforms.append(RemoveLabelTransform(-1, 0))

    if params.get("move_last_seg_chanel_to_data") is not None and params.get(
            "move_last_seg_chanel_to_data"):
        tr_transforms.append(
            MoveSegAsOneHotToData(1, params.get("all_segmentation_labels"),
                                  'seg', 'data'))
        # BUG FIX: the original condition read
        # `params.get("cascade_do_cascade_augmentations") and not None and ...`;
        # `not None` is always True, so it reduced to a plain truthiness check.
        # Written here as the intended truthiness check (behavior unchanged).
        if params.get("cascade_do_cascade_augmentations"):
            if params.get("cascade_random_binary_transform_p") > 0:
                tr_transforms.append(
                    ApplyRandomBinaryOperatorTransform(
                        channel_idx=list(
                            range(-len(params.get("all_segmentation_labels")),
                                  0)),
                        p_per_sample=params.get(
                            "cascade_random_binary_transform_p"),
                        key="data",
                        strel_size=params.get(
                            "cascade_random_binary_transform_size")))
            if params.get("cascade_remove_conn_comp_p") > 0:
                # BUG FIX: the two kwargs below had each other's values --
                # fill_with_other_class_p was fed the
                # "...max_size_percent_threshold" parameter and
                # dont_do_if_covers_more_than_X_percent the
                # "...fill_with_other_class_p" parameter. Each kwarg now gets
                # the parameter of the matching name (as in upstream nnU-Net).
                tr_transforms.append(
                    RemoveRandomConnectedComponentFromOneHotEncodingTransform(
                        channel_idx=list(
                            range(-len(params.get("all_segmentation_labels")),
                                  0)),
                        key="data",
                        p_per_sample=params.get("cascade_remove_conn_comp_p"),
                        fill_with_other_class_p=params.get(
                            "cascade_remove_conn_comp_fill_with_other_class_p"),
                        dont_do_if_covers_more_than_X_percent=params.get(
                            "cascade_remove_conn_comp_max_size_percent_threshold"
                        )))

    tr_transforms.append(RenameTransform('seg', 'target', True))
    if deep_supervision_scales is not None:
        if soft_ds:
            assert classes is not None
            tr_transforms.append(
                DownsampleSegForDSTransform3(deep_supervision_scales, 'target',
                                             'target', classes))
        else:
            tr_transforms.append(
                DownsampleSegForDSTransform2(deep_supervision_scales, 0, 0,
                                             input_key='target',
                                             output_key='target'))
    tr_transforms.append(NumpyToTensor(['data', 'target'], 'float'))

    batchgenerator_train = MultiThreadedAugmenter(
        dataloader_train, Compose(tr_transforms), params.get('num_threads'),
        params.get("num_cached_per_thread"), seeds=seeds_train,
        pin_memory=pin_memory)

    val_transforms = [RemoveLabelTransform(-1, 0)]
    if params.get("selected_data_channels") is not None:
        val_transforms.append(
            DataChannelSelectionTransform(params.get("selected_data_channels")))
    if params.get("selected_seg_channels") is not None:
        val_transforms.append(
            SegChannelSelectionTransform(params.get("selected_seg_channels")))
    if params.get("move_last_seg_chanel_to_data") is not None and params.get(
            "move_last_seg_chanel_to_data"):
        val_transforms.append(
            MoveSegAsOneHotToData(1, params.get("all_segmentation_labels"),
                                  'seg', 'data'))
    val_transforms.append(RenameTransform('seg', 'target', True))
    if deep_supervision_scales is not None:
        if soft_ds:
            assert classes is not None
            val_transforms.append(
                DownsampleSegForDSTransform3(deep_supervision_scales, 'target',
                                             'target', classes))
        else:
            val_transforms.append(
                DownsampleSegForDSTransform2(deep_supervision_scales, 0, 0,
                                             input_key='target',
                                             output_key='target'))
    val_transforms.append(NumpyToTensor(['data', 'target'], 'float'))

    batchgenerator_val = MultiThreadedAugmenter(
        dataloader_val, Compose(val_transforms),
        max(params.get('num_threads') // 2, 1),
        params.get("num_cached_per_thread"), seeds=seeds_val,
        pin_memory=pin_memory)
    return batchgenerator_train, batchgenerator_val
def get_no_augmentation(dataloader_train, dataloader_val, params=default_3D_augmentation_params,
                        deep_supervision_scales=None, soft_ds=False, classes=None, pin_memory=True,
                        regions=None):
    """
    Drop-in replacement for get_default_augmentation that turns off all actual
    data augmentation. Both pipelines only do bookkeeping: optional channel
    selection, -1 -> 0 label cleanup, 'seg' -> 'target' renaming, optional
    region conversion / deep-supervision downsampling, and tensor conversion.

    Returns the (train, val) MultiThreadedAugmenter pair, both already
    restarted.
    """
    n_train_workers = params.get('num_threads')
    n_val_workers = max(n_train_workers // 2, 1)
    n_cached = params.get("num_cached_per_thread")

    def _append_seg_tail(pipeline):
        # Shared tail of both pipelines: rename seg, optionally convert to
        # regions / downsample for deep supervision, then convert to tensors.
        pipeline.append(RenameTransform('seg', 'target', True))
        if regions is not None:
            pipeline.append(ConvertSegmentationToRegionsTransform(regions, 'target', 'target'))
        if deep_supervision_scales is not None:
            if soft_ds:
                assert classes is not None
                pipeline.append(DownsampleSegForDSTransform3(deep_supervision_scales,
                                                             'target', 'target', classes))
            else:
                pipeline.append(DownsampleSegForDSTransform2(deep_supervision_scales, 0, 0,
                                                             input_key='target',
                                                             output_key='target'))
        pipeline.append(NumpyToTensor(['data', 'target'], 'float'))

    # Training pipeline: channel selection first, then label cleanup.
    train_pipeline = []
    if params.get("selected_data_channels") is not None:
        train_pipeline.append(DataChannelSelectionTransform(params.get("selected_data_channels")))
    if params.get("selected_seg_channels") is not None:
        train_pipeline.append(SegChannelSelectionTransform(params.get("selected_seg_channels")))
    train_pipeline.append(RemoveLabelTransform(-1, 0))
    _append_seg_tail(train_pipeline)

    batchgenerator_train = MultiThreadedAugmenter(dataloader_train, Compose(train_pipeline),
                                                  n_train_workers, n_cached,
                                                  seeds=range(n_train_workers),
                                                  pin_memory=pin_memory)
    batchgenerator_train.restart()

    # Validation pipeline: label cleanup first, then channel selection
    # (order kept exactly as in the training counterpart of this module).
    val_pipeline = [RemoveLabelTransform(-1, 0)]
    if params.get("selected_data_channels") is not None:
        val_pipeline.append(DataChannelSelectionTransform(params.get("selected_data_channels")))
    if params.get("selected_seg_channels") is not None:
        val_pipeline.append(SegChannelSelectionTransform(params.get("selected_seg_channels")))
    _append_seg_tail(val_pipeline)

    batchgenerator_val = MultiThreadedAugmenter(dataloader_val, Compose(val_pipeline),
                                                n_val_workers, n_cached,
                                                seeds=range(n_val_workers),
                                                pin_memory=pin_memory)
    batchgenerator_val.restart()

    return batchgenerator_train, batchgenerator_val
def get_default_augmentation(dataloader_train, dataloader_val, patch_size,
                             params=default_3D_augmentation_params,
                             border_val_seg=-1, pin_memory=True,
                             seeds_train=None, seeds_val=None, regions=None):
    """
    Build the default spatial/intensity augmentation pipeline for training and
    a minimal bookkeeping pipeline for validation, and wrap both dataloaders
    in MultiThreadedAugmenters.

    :param dataloader_train: batchgenerators-style dataloader for training
    :param dataloader_val: batchgenerators-style dataloader for validation
    :param patch_size: output patch size for SpatialTransform
    :param params: dict-like augmentation configuration (queried via .get)
    :param border_val_seg: fill value for segmentation borders after warping
    :param pin_memory: forwarded to MultiThreadedAugmenter
    :param seeds_train: per-worker seeds for the training augmenter (or None)
    :param seeds_val: per-worker seeds for the validation augmenter (or None)
    :param regions: if not None, 'target' is converted to region maps
    :return: (batchgenerator_train, batchgenerator_val)
    """
    assert params.get('mirror') is None, "old version of params, use new keyword do_mirror"

    tr_transforms = []

    if params.get("selected_data_channels") is not None:
        tr_transforms.append(DataChannelSelectionTransform(params.get("selected_data_channels")))
    if params.get("selected_seg_channels") is not None:
        tr_transforms.append(SegChannelSelectionTransform(params.get("selected_seg_channels")))

    # don't do color augmentations while in 2d mode with 3d data because the
    # color channel is overloaded!!
    if params.get("dummy_2D") is not None and params.get("dummy_2D"):
        tr_transforms.append(Convert3DTo2DTransform())

    tr_transforms.append(SpatialTransform(
        patch_size, patch_center_dist_from_border=None,
        do_elastic_deform=params.get("do_elastic"), alpha=params.get("elastic_deform_alpha"),
        sigma=params.get("elastic_deform_sigma"),
        do_rotation=params.get("do_rotation"), angle_x=params.get("rotation_x"),
        angle_y=params.get("rotation_y"), angle_z=params.get("rotation_z"),
        do_scale=params.get("do_scaling"), scale=params.get("scale_range"),
        border_mode_data=params.get("border_mode_data"), border_cval_data=0, order_data=3,
        border_mode_seg="constant", border_cval_seg=border_val_seg, order_seg=1,
        random_crop=params.get("random_crop"),
        p_el_per_sample=params.get("p_eldef"), p_scale_per_sample=params.get("p_scale"),
        p_rot_per_sample=params.get("p_rot"),
        independent_scale_for_each_axis=params.get("independent_scale_factor_for_each_axis")
    ))
    if params.get("dummy_2D") is not None and params.get("dummy_2D"):
        tr_transforms.append(Convert2DTo3DTransform())

    if params.get("do_gamma"):
        tr_transforms.append(
            GammaTransform(params.get("gamma_range"), False, True,
                           retain_stats=params.get("gamma_retain_stats"),
                           p_per_sample=params["p_gamma"]))

    if params.get("do_mirror"):
        tr_transforms.append(MirrorTransform(params.get("mirror_axes")))

    if params.get("mask_was_used_for_normalization") is not None:
        mask_was_used_for_normalization = params.get("mask_was_used_for_normalization")
        tr_transforms.append(MaskTransform(mask_was_used_for_normalization,
                                           mask_idx_in_seg=0, set_outside_to=0))

    tr_transforms.append(RemoveLabelTransform(-1, 0))

    if params.get("move_last_seg_chanel_to_data") is not None and params.get("move_last_seg_chanel_to_data"):
        tr_transforms.append(MoveSegAsOneHotToData(1, params.get("all_segmentation_labels"), 'seg', 'data'))
        # BUGFIX: the guard used to read
        #   params.get(...) and not None and params.get(...)
        # where `not None` always evaluates to True, so it only worked by
        # accident. Written as the intended explicit `is not None` check.
        if params.get("cascade_do_cascade_augmentations") is not None and params.get(
                "cascade_do_cascade_augmentations"):
            tr_transforms.append(ApplyRandomBinaryOperatorTransform(
                channel_idx=list(range(-len(params.get("all_segmentation_labels")), 0)),
                p_per_sample=params.get("cascade_random_binary_transform_p"),
                key="data",
                strel_size=params.get("cascade_random_binary_transform_size")))
            # NOTE(review): the two kwargs below appear crossed — the
            # "...max_size_percent_threshold" value feeds fill_with_other_class_p
            # and vice versa. Kept as-is because existing configs were tuned
            # against this mapping; confirm against the params definition
            # before swapping.
            tr_transforms.append(RemoveRandomConnectedComponentFromOneHotEncodingTransform(
                channel_idx=list(range(-len(params.get("all_segmentation_labels")), 0)),
                key="data",
                p_per_sample=params.get("cascade_remove_conn_comp_p"),
                fill_with_other_class_p=params.get("cascade_remove_conn_comp_max_size_percent_threshold"),
                dont_do_if_covers_more_than_X_percent=params.get("cascade_remove_conn_comp_fill_with_other_class_p")))

    tr_transforms.append(RenameTransform('seg', 'target', True))
    if regions is not None:
        tr_transforms.append(ConvertSegmentationToRegionsTransform(regions, 'target', 'target'))
    tr_transforms.append(NumpyToTensor(['data', 'target'], 'float'))
    tr_transforms = Compose(tr_transforms)

    batchgenerator_train = MultiThreadedAugmenter(dataloader_train, tr_transforms,
                                                  params.get('num_threads'),
                                                  params.get("num_cached_per_thread"),
                                                  seeds=seeds_train, pin_memory=pin_memory)

    # Validation: no augmentation, only bookkeeping transforms.
    val_transforms = []
    val_transforms.append(RemoveLabelTransform(-1, 0))
    if params.get("selected_data_channels") is not None:
        val_transforms.append(DataChannelSelectionTransform(params.get("selected_data_channels")))
    if params.get("selected_seg_channels") is not None:
        val_transforms.append(SegChannelSelectionTransform(params.get("selected_seg_channels")))
    if params.get("move_last_seg_chanel_to_data") is not None and params.get("move_last_seg_chanel_to_data"):
        val_transforms.append(MoveSegAsOneHotToData(1, params.get("all_segmentation_labels"), 'seg', 'data'))

    val_transforms.append(RenameTransform('seg', 'target', True))
    if regions is not None:
        val_transforms.append(ConvertSegmentationToRegionsTransform(regions, 'target', 'target'))
    val_transforms.append(NumpyToTensor(['data', 'target'], 'float'))
    val_transforms = Compose(val_transforms)

    batchgenerator_val = MultiThreadedAugmenter(dataloader_val, val_transforms,
                                                max(params.get('num_threads') // 2, 1),
                                                params.get("num_cached_per_thread"),
                                                seeds=seeds_val, pin_memory=pin_memory)
    return batchgenerator_train, batchgenerator_val
train_transforms = [MirrorTransform(axes=(0, 1))] post_transforms = [ ToTensor(keys=('data', 'label'), dtypes=(('data', torch.float32), ('label', torch.int64))) ] def init_fn(worker_id): seed = torch.utils.data._utils.worker._worker_info.seed import numpy as np seed32 = np.array(seed).astype(np.uint32) np.random.seed(seed32) train_dset = PytorchDatasetWrapper(ImageFolder(train_path)) train_transforms = Compose(pre_transforms + train_transforms + post_transforms) train_data = DataLoader(train_dset, shuffle=True, batch_size=32, num_workers=4, batch_transforms=train_transforms, collate_fn=numpy_collate, pin_memory=False, worker_init_fn=init_fn) val_dset = PytorchDatasetWrapper(ImageFolder(val_path)) val_transforms = Compose(pre_transforms + post_transforms) val_data = DataLoader(val_dset, shuffle=False, batch_size=32, num_workers=4,
def run(fold=0):
    """
    Train a 3D segmentation network (Theano/Lasagne) for one cross-validation
    fold of the loaded dataset and write learning curves and parameter
    snapshots to disk.

    NOTE(review): this is Python 2 code (print statements, cPickle); it will
    not run under Python 3 without porting.

    :param fold: index of the 5-fold CV split to train (0-4)
    """
    print fold
    # =================================================================================================================
    I_AM_FOLD = fold
    # Fix all RNGs so the fold split and augmentation stream are reproducible.
    np.random.seed(65432)
    lasagne.random.set_rng(np.random.RandomState(98765))
    sys.setrecursionlimit(2000)
    BATCH_SIZE = 2
    INPUT_PATCH_SIZE =(128, 128, 128)
    num_classes=4
    EXPERIMENT_NAME = "final"
    # Results live in <results_folder>/<experiment>/fold<N>; each level is
    # created on demand.
    results_dir = os.path.join(paths.results_folder)
    if not os.path.isdir(results_dir):
        os.mkdir(results_dir)
    results_dir = os.path.join(results_dir, EXPERIMENT_NAME)
    if not os.path.isdir(results_dir):
        os.mkdir(results_dir)
    results_dir = os.path.join(results_dir, "fold%d"%I_AM_FOLD)
    if not os.path.isdir(results_dir):
        os.mkdir(results_dir)
    n_epochs = 300
    lr_decay = np.float32(0.985)  # multiplicative per-epoch LR decay
    base_lr = np.float32(0.0005)
    n_batches_per_epoch = 100
    n_test_batches = 10
    n_feedbacks_per_epoch = 10.   # progress printouts/plot updates per epoch
    num_workers = 6
    workers_seeds = [123, 1234, 12345, 123456, 1234567, 12345678]
    # =================================================================================================================
    all_data = load_dataset()
    # Sort keys so the KFold split is deterministic given the fixed seed.
    keys_sorted = np.sort(all_data.keys())
    crossval_folds = KFold(len(all_data.keys()), n_folds=5, shuffle=True, random_state=123456)
    ctr = 0
    for train_idx, test_idx in crossval_folds:
        print len(train_idx), len(test_idx)
        if ctr == I_AM_FOLD:
            train_keys = [keys_sorted[i] for i in train_idx]
            test_keys = [keys_sorted[i] for i in test_idx]
            break
        ctr += 1
    train_data = {i:all_data[i] for i in train_keys}
    test_data = {i:all_data[i] for i in test_keys}
    # NOTE(review): this expects a create_data_gen_train with signature
    # (data, patch_size, num_classes, batch_size, contrast_range, ...), which
    # differs from the 2D variant defined elsewhere in this file — confirm the
    # correct helper is imported.
    data_gen_train = create_data_gen_train(train_data, INPUT_PATCH_SIZE, num_classes, BATCH_SIZE,
                                           contrast_range=(0.75, 1.5), gamma_range = (0.8, 1.5),
                                           num_workers=num_workers, num_cached_per_worker=2,
                                           do_elastic_transform=True, alpha=(0., 1300.), sigma=(10., 13.),
                                           do_rotation=True, a_x=(0., 2*np.pi), a_y=(0., 2*np.pi),
                                           a_z=(0., 2*np.pi), do_scale=True, scale_range=(0.75, 1.25),
                                           seeds=workers_seeds)
    data_gen_validation = BatchGenerator3D_random_sampling(test_data, BATCH_SIZE, num_batches=None,
                                                           seed=False, patch_size=INPUT_PATCH_SIZE,
                                                           convert_labels=True)
    # Validation pipeline: brain-mask-aware normalization + one-hot conversion,
    # no augmentation.
    val_transforms = []
    val_transforms.append(GenerateBrainMaskTransform())
    val_transforms.append(BrainMaskAwareStretchZeroOneTransform(clip_range=(-5, 5), per_channel=True))
    val_transforms.append(SegChannelSelectionTransform([0]))
    val_transforms.append(ConvertSegToOnehotTransform(range(4), 0, "seg_onehot"))
    val_transforms.append(DataChannelSelectionTransform([0, 1, 2, 3]))
    data_gen_validation = MultiThreadedAugmenter(data_gen_validation, Compose(val_transforms), 2, 2)

    # Build the symbolic graph: 5D input tensor, flattened one-hot seg matrix.
    x_sym = T.tensor5()
    seg_sym = T.matrix()

    net, seg_layer = build_net(x_sym, INPUT_PATCH_SIZE, num_classes, 4, 16, batch_size=BATCH_SIZE,
                               do_instance_norm=True)
    output_layer_for_loss = net

    # add some weight decay
    l2_loss = lasagne.regularization.regularize_network_params(output_layer_for_loss,
                                                               lasagne.regularization.l2) * 1e-5

    # the distinction between prediction_train and test is important only if we enable dropout (batch norm/inst norm
    # does not use or save moving averages)
    prediction_train = lasagne.layers.get_output(output_layer_for_loss, x_sym, deterministic=False,
                                                 batch_norm_update_averages=False,
                                                 batch_norm_use_averages=False)
    # Soft-dice loss; [:, 1:] drops the background class from the objective.
    loss_vec = - soft_dice_per_img_in_batch(prediction_train, seg_sym, BATCH_SIZE)[:, 1:]
    loss = loss_vec.mean()
    loss += l2_loss
    acc_train = T.mean(T.eq(T.argmax(prediction_train, axis=1), seg_sym.argmax(-1)),
                       dtype=theano.config.floatX)

    prediction_test = lasagne.layers.get_output(output_layer_for_loss, x_sym, deterministic=True,
                                                batch_norm_update_averages=False,
                                                batch_norm_use_averages=False)
    loss_val = - soft_dice_per_img_in_batch(prediction_test, seg_sym, BATCH_SIZE)[:, 1:]
    loss_val = loss_val.mean()
    loss_val += l2_loss
    acc = T.mean(T.eq(T.argmax(prediction_test, axis=1), seg_sym.argmax(-1)),
                 dtype=theano.config.floatX)

    # learning rate has to be a shared variable because we decrease it with every epoch
    params = lasagne.layers.get_all_params(output_layer_for_loss, trainable=True)
    learning_rate = theano.shared(base_lr)
    updates = lasagne.updates.adam(T.grad(loss, params), params, learning_rate=learning_rate,
                                   beta1=0.9, beta2=0.999)
    dc = hard_dice_per_img_in_batch(prediction_test, seg_sym.argmax(1), num_classes, BATCH_SIZE).mean(0)

    # train_fn returns [loss, acc_train, loss_vec] in that order.
    train_fn = theano.function([x_sym, seg_sym], [loss, acc_train, loss_vec], updates=updates)
    val_fn = theano.function([x_sym, seg_sym], [loss_val, acc, dc])

    all_val_dice_scores=None
    all_training_losses = []
    all_validation_losses = []
    all_validation_accuracies = []
    all_training_accuracies = []
    val_dice_scores = []
    epoch = 0
    while epoch < n_epochs:
        # Augmentation curriculum: progressively milder deformation/rotation
        # as training converges (epochs 100 and 175).
        if epoch == 100:
            data_gen_train = create_data_gen_train(train_data, INPUT_PATCH_SIZE, num_classes, BATCH_SIZE,
                                                   contrast_range=(0.85, 1.25), gamma_range = (0.8, 1.5),
                                                   num_workers=6, num_cached_per_worker=2,
                                                   do_elastic_transform=True, alpha=(0., 1000.),
                                                   sigma=(10., 13.), do_rotation=True, a_x=(0., 2*np.pi),
                                                   a_y=(-np.pi/8., np.pi/8.), a_z=(-np.pi/8., np.pi/8.),
                                                   do_scale=True, scale_range=(0.85, 1.15),
                                                   seeds=workers_seeds)
        if epoch == 175:
            data_gen_train = create_data_gen_train(train_data, INPUT_PATCH_SIZE, num_classes, BATCH_SIZE,
                                                   contrast_range=(0.9, 1.1), gamma_range = (0.85, 1.3),
                                                   num_workers=6, num_cached_per_worker=2,
                                                   do_elastic_transform=True, alpha=(0., 750.),
                                                   sigma=(10., 13.), do_rotation=True, a_x=(0., 2*np.pi),
                                                   a_y=(-0.00001, 0.00001), a_z=(-0.00001, 0.00001),
                                                   do_scale=True, scale_range=(0.85, 1.15),
                                                   seeds=workers_seeds)
        epoch_start_time = time.time()
        # Exponential LR schedule, recomputed from base_lr every epoch.
        learning_rate.set_value(np.float32(base_lr* lr_decay**(epoch)))
        print "epoch: ", epoch, " learning rate: ", learning_rate.get_value()
        train_loss = 0
        train_acc_tmp = 0
        train_loss_tmp = 0
        batch_ctr = 0
        for data_dict in data_gen_train:
            data = data_dict["data"].astype(np.float32)
            # Flatten one-hot seg from (b, c, x, y, z) to (b*x*y*z, c) to match seg_sym.
            seg = data_dict["seg_onehot"].astype(np.float32).transpose(0, 2, 3, 4, 1).reshape((-1, num_classes))
            if batch_ctr != 0 and batch_ctr % int(np.floor(n_batches_per_epoch/n_feedbacks_per_epoch)) == 0:
                print "number of batches: ", batch_ctr, "/", n_batches_per_epoch
                print "training_loss since last update: ", \
                    train_loss_tmp/np.floor(n_batches_per_epoch/n_feedbacks_per_epoch), " train accuracy: ", \
                    train_acc_tmp/np.floor(n_batches_per_epoch/n_feedbacks_per_epoch)
                all_training_losses.append(train_loss_tmp/np.floor(n_batches_per_epoch/n_feedbacks_per_epoch))
                all_training_accuracies.append(train_acc_tmp/np.floor(n_batches_per_epoch/n_feedbacks_per_epoch))
                train_loss_tmp = 0
                train_acc_tmp = 0
                if len(val_dice_scores) > 0:
                    all_val_dice_scores = np.concatenate(val_dice_scores, axis=0).reshape((-1, num_classes))
                # Plotting is best-effort; never let it kill the training run.
                # NOTE(review): num_classes is 4 but six curve labels are
                # passed — confirm against printLosses' expectations.
                try:
                    printLosses(all_training_losses, all_training_accuracies, all_validation_losses,
                                all_validation_accuracies,
                                os.path.join(results_dir, "%s.png" % EXPERIMENT_NAME),
                                n_feedbacks_per_epoch, val_dice_scores=all_val_dice_scores,
                                val_dice_scores_labels=["brain", "1", "2", "3", "4", "5"])
                except:
                    pass
            # NOTE(review): names here are shifted relative to train_fn's
            # outputs [loss, acc_train, loss_vec]: `loss_vec` receives the
            # scalar loss (so .mean() is a no-op) and `l` the per-image
            # vector. Numerically correct, but the names mislead.
            loss_vec, acc, l = train_fn(data, seg)
            loss = loss_vec.mean()
            train_loss += loss
            train_loss_tmp += loss
            train_acc_tmp += acc
            batch_ctr += 1
            if batch_ctr >= n_batches_per_epoch:
                break
        all_training_losses.append(train_loss_tmp/np.floor(n_batches_per_epoch/n_feedbacks_per_epoch))
        all_training_accuracies.append(train_acc_tmp/np.floor(n_batches_per_epoch/n_feedbacks_per_epoch))
        train_loss /= n_batches_per_epoch
        print "training loss average on epoch: ", train_loss

        # Validation: n_test_batches random patches; per-class dice with the
        # sentinel value 2 marking classes absent from a batch.
        val_loss = 0
        accuracies = []
        valid_batch_ctr = 0
        all_dice = []
        for data_dict in data_gen_validation:
            data = data_dict["data"].astype(np.float32)
            seg = data_dict["seg_onehot"].astype(np.float32).transpose(0, 2, 3, 4, 1).reshape((-1, num_classes))
            # w marks which classes are actually present in this batch.
            w = np.zeros(num_classes, dtype=np.float32)
            w[np.unique(seg.argmax(-1))] = 1
            loss, acc, dice = val_fn(data, seg)
            dice[w==0] = 2  # sentinel: class absent, exclude from the mean below
            all_dice.append(dice)
            val_loss += loss
            accuracies.append(acc)
            valid_batch_ctr += 1
            if valid_batch_ctr >= n_test_batches:
                break
        all_dice = np.vstack(all_dice)
        dice_means = np.zeros(num_classes)
        for i in range(num_classes):
            # Average only over batches where class i was present (dice != 2).
            dice_means[i] = all_dice[all_dice[:, i]!=2, i].mean()
        val_loss /= n_test_batches
        print "val loss: ", val_loss
        print "val acc: ", np.mean(accuracies), "\n"
        print "val dice: ", dice_means
        print "This epoch took %f sec" % (time.time()-epoch_start_time)
        val_dice_scores.append(dice_means)
        all_validation_losses.append(val_loss)
        all_validation_accuracies.append(np.mean(accuracies))
        all_val_dice_scores = np.concatenate(val_dice_scores, axis=0).reshape((-1, num_classes))
        try:
            printLosses(all_training_losses, all_training_accuracies, all_validation_losses,
                        all_validation_accuracies,
                        os.path.join(results_dir, "%s.png" % EXPERIMENT_NAME),
                        n_feedbacks_per_epoch, val_dice_scores=all_val_dice_scores,
                        val_dice_scores_labels=["brain", "1", "2", "3", "4", "5"])
        except:
            pass
        # Checkpoint weights and curves every epoch (overwrites previous).
        # NOTE(review): text mode 'w' works on Python 2 posix; 'wb' would be
        # the safe choice for pickles.
        with open(os.path.join(results_dir, "%s_Params.pkl" % (EXPERIMENT_NAME)), 'w') as f:
            cPickle.dump(lasagne.layers.get_all_param_values(output_layer_for_loss), f)
        with open(os.path.join(results_dir, "%s_allLossesNAccur.pkl"% (EXPERIMENT_NAME)), 'w') as f:
            cPickle.dump([all_training_losses, all_training_accuracies, all_validation_losses,
                          all_validation_accuracies, val_dice_scores], f)
        epoch += 1
p_rot_per_axis=params.get("rotation_p_per_axis"), do_scale=params.get("do_scaling"), scale=params.get("scale_range"), border_mode_data=params.get("border_mode_data"), border_cval_data=0, order_data=order_data, border_mode_seg="constant", border_cval_seg=border_val_seg, order_seg=order_seg, random_crop=params.get("random_crop"), p_el_per_sample=params.get("p_eldef"), p_scale_per_sample=params.get("p_scale"), p_rot_per_sample=params.get("p_rot"), independent_scale_for_each_axis=params.get( "independent_scale_factor_for_each_axis"))) tr_transforms = Compose(tr_transforms) batchgenerator_train = MultiThreadedAugmenter( dataloader_train, tr_transforms, params.get('num_threads'), params.get("num_cached_per_thread"), pin_memory=True) train_loader = batchgenerator_train train_batch = next(train_loader) print( train_batch.keys() ) # dict_keys(['data', 'target']), each with torch.Size([2, 1, 112, 240, 272]) print((train_batch['seg'] - train_batch['weak_label']).sum()) train_batch = next(train_loader) print((train_batch['seg'] - train_batch['weak_label']).sum()) ipdb.set_trace()
dataset_test = ConditionalGanDataset(path_test_real, load_sample_cgan_test, ['.PNG', '.png'], ['.PNG', '.png']) ### Transforms applied to data from batchgenerators.transforms import RandomCropTransform, Compose from batchgenerators.transforms.spatial_transforms import ResizeTransform,SpatialTransform transforms = Compose([ #SpatialTransform(patch_size=(1024, 1024), do_rotation=True, patch_center_dist_from_border=1024, border_mode_data='reflect', # border_mode_seg='reflect', angle_x=(args.rot_angle, args.rot_angle), angle_y=(0, 0), angle_z=(0, 0), # do_elastic_deform=False, order_data=1, order_seg=1) ResizeTransform((int(args.resize_size), int(args.resize_size)), order=1), RandomCropTransform((params.nested_get("image_size"), params.nested_get("image_size"))), ]) from delira.data_loading import BaseDataManager, SequentialSampler, RandomSampler manager_test = BaseDataManager(dataset_test, params.nested_get("batch_size"), transforms=transforms, sampler_cls=SequentialSampler, n_process_augmentation=1) import warnings warnings.simplefilter("ignore", UserWarning) # ignore UserWarnings raised by dependency code warnings.simplefilter("ignore", FutureWarning) # ignore FutureWarnings raised by dependency code
def get_default_augmentation_withEDT(dataloader_train, dataloader_val, patch_size, idx_of_edts,
                                     params=default_3D_augmentation_params,
                                     border_val_seg=-1, pin_memory=True,
                                     seeds_train=None, seeds_val=None):
    """
    Variant of get_default_augmentation for inputs that carry Euclidean
    distance transform (EDT) channels. The EDT channels (idx_of_edts) are
    moved from 'data' to a separate 'bound' key *before* any intensity
    augmentation so they are spatially transformed but never
    intensity-transformed.

    :param dataloader_train: batchgenerators-style dataloader for training
    :param dataloader_val: batchgenerators-style dataloader for validation
    :param patch_size: output patch size for SpatialTransform
    :param idx_of_edts: channel indices (within 'data') holding the EDTs
    :param params: dict-like augmentation configuration (queried via .get)
    :param border_val_seg: fill value for segmentation borders after warping
    :param pin_memory: forwarded to MultiThreadedAugmenter
    :param seeds_train: per-worker seeds for the training augmenter (or None)
    :param seeds_val: per-worker seeds for the validation augmenter (or None)
    :return: (batchgenerator_train, batchgenerator_val); batches expose
        'data', 'target' and 'bound' tensors
    """
    tr_transforms = []

    if params.get("selected_data_channels") is not None:
        tr_transforms.append(DataChannelSelectionTransform(params.get("selected_data_channels")))
    if params.get("selected_seg_channels") is not None:
        tr_transforms.append(SegChannelSelectionTransform(params.get("selected_seg_channels")))

    # don't do color augmentations while in 2d mode with 3d data because the
    # color channel is overloaded!!
    if params.get("dummy_2D") is not None and params.get("dummy_2D"):
        tr_transforms.append(Convert3DTo2DTransform())

    tr_transforms.append(SpatialTransform(
        patch_size, patch_center_dist_from_border=None,
        do_elastic_deform=params.get("do_elastic"), alpha=params.get("elastic_deform_alpha"),
        sigma=params.get("elastic_deform_sigma"),
        do_rotation=params.get("do_rotation"), angle_x=params.get("rotation_x"),
        angle_y=params.get("rotation_y"), angle_z=params.get("rotation_z"),
        do_scale=params.get("do_scaling"), scale=params.get("scale_range"),
        border_mode_data=params.get("border_mode_data"), border_cval_data=0, order_data=3,
        border_mode_seg="constant", border_cval_seg=border_val_seg, order_seg=1,
        random_crop=params.get("random_crop"),
        p_el_per_sample=params.get("p_eldef"), p_scale_per_sample=params.get("p_scale"),
        p_rot_per_sample=params.get("p_rot")))
    if params.get("dummy_2D") is not None and params.get("dummy_2D"):
        tr_transforms.append(Convert2DTo3DTransform())

    # Move the EDT channels to the 'bound' key *here*, after the spatial
    # transform but before every intensity transform, so they never get
    # intensity-augmented.
    tr_transforms.append(AppendChannelsTransform("data", "bound", idx_of_edts,
                                                 remove_from_input=True))

    if params.get("do_gamma"):
        tr_transforms.append(
            GammaTransform(params.get("gamma_range"), False, True,
                           retain_stats=params.get("gamma_retain_stats"),
                           p_per_sample=params["p_gamma"]))

    # NOTE: mirroring is applied unconditionally here (no do_mirror guard),
    # unlike get_default_augmentation.
    tr_transforms.append(MirrorTransform(params.get("mirror_axes")))

    if params.get("mask_was_used_for_normalization") is not None:
        mask_was_used_for_normalization = params.get("mask_was_used_for_normalization")
        tr_transforms.append(MaskTransform(mask_was_used_for_normalization,
                                           mask_idx_in_seg=0, set_outside_to=0))

    tr_transforms.append(RemoveLabelTransform(-1, 0))

    if params.get("move_last_seg_chanel_to_data") is not None and params.get("move_last_seg_chanel_to_data"):
        tr_transforms.append(MoveSegAsOneHotToData(1, params.get("all_segmentation_labels"), 'seg', 'data'))
        # BUGFIX: the guard used to read
        #   params.get(...) and not None and params.get(...)
        # where `not None` always evaluates to True, so it only worked by
        # accident. Written as the intended explicit `is not None` check.
        if params.get("advanced_pyramid_augmentations") is not None and params.get(
                "advanced_pyramid_augmentations"):
            tr_transforms.append(ApplyRandomBinaryOperatorTransform(
                channel_idx=list(range(-len(params.get("all_segmentation_labels")), 0)),
                p_per_sample=0.4,
                key="data",
                strel_size=(1, 8)))
            tr_transforms.append(RemoveRandomConnectedComponentFromOneHotEncodingTransform(
                channel_idx=list(range(-len(params.get("all_segmentation_labels")), 0)),
                key="data",
                p_per_sample=0.2,
                fill_with_other_class_p=0.0,
                dont_do_if_covers_more_than_X_percent=0.15))

    tr_transforms.append(RenameTransform('seg', 'target', True))
    tr_transforms.append(NumpyToTensor(['data', 'target', 'bound'], 'float'))
    tr_transforms = Compose(tr_transforms)

    batchgenerator_train = MultiThreadedAugmenter(dataloader_train, tr_transforms,
                                                  params.get('num_threads'),
                                                  params.get("num_cached_per_thread"),
                                                  seeds=seeds_train, pin_memory=pin_memory)

    # Validation: no augmentation; still split off the EDT channels.
    val_transforms = []
    val_transforms.append(RemoveLabelTransform(-1, 0))
    if params.get("selected_data_channels") is not None:
        val_transforms.append(DataChannelSelectionTransform(params.get("selected_data_channels")))
    if params.get("selected_seg_channels") is not None:
        val_transforms.append(SegChannelSelectionTransform(params.get("selected_seg_channels")))

    # Move the EDT channels to the 'bound' key, mirroring the training pipeline.
    val_transforms.append(AppendChannelsTransform("data", "bound", idx_of_edts,
                                                  remove_from_input=True))

    if params.get("move_last_seg_chanel_to_data") is not None and params.get("move_last_seg_chanel_to_data"):
        val_transforms.append(MoveSegAsOneHotToData(1, params.get("all_segmentation_labels"), 'seg', 'data'))

    val_transforms.append(RenameTransform('seg', 'target', True))
    val_transforms.append(NumpyToTensor(['data', 'target', 'bound'], 'float'))
    val_transforms = Compose(val_transforms)

    batchgenerator_val = MultiThreadedAugmenter(dataloader_val, val_transforms,
                                                max(params.get('num_threads') // 2, 1),
                                                params.get("num_cached_per_thread"),
                                                seeds=seeds_val, pin_memory=pin_memory)
    return batchgenerator_train, batchgenerator_val
def run(self, img_data, seg_data):
    """
    Apply the configured batchgenerators augmentations to image/segmentation
    data and return the augmented copies.

    The transform list is assembled from the instance's boolean switches
    (mirror, contrast, brightness, gamma, gaussian_noise, rotations, scaling,
    elastic_deform, cropping) and their config_* parameters, run
    self.cycles times through a SingleThreadedAugmenter, and the results are
    concatenated along the batch axis.

    NOTE(review): assumes img_data is channel-last (the spatial patch shape
    is taken as img_data[0].shape[0:-1]) and that DataParser yields
    channel-first batches under the 'data'/'seg' keys — confirm against the
    DataParser implementation.

    :param img_data: channel-last image array(s) to augment
    :param seg_data: corresponding segmentation array(s)
    :return: (aug_img_data, aug_seg_data), both channel-last, with
        self.cycles times the original batch size
    """
    # Create a parser for the batchgenerators module
    data_generator = DataParser(img_data, seg_data)
    # Initialize empty transform list
    transforms = []
    # Add mirror augmentation
    if self.mirror:
        aug_mirror = MirrorTransform(axes=self.config_mirror_axes)
        transforms.append(aug_mirror)
    # Add contrast augmentation
    if self.contrast:
        aug_contrast = ContrastAugmentationTransform(self.config_contrast_range,
                                                     preserve_range=True,
                                                     per_channel=True,
                                                     p_per_sample=self.config_p_per_sample)
        transforms.append(aug_contrast)
    # Add brightness augmentation
    if self.brightness:
        aug_brightness = BrightnessMultiplicativeTransform(self.config_brightness_range,
                                                           per_channel=True,
                                                           p_per_sample=self.config_p_per_sample)
        transforms.append(aug_brightness)
    # Add gamma augmentation
    if self.gamma:
        aug_gamma = GammaTransform(self.config_gamma_range,
                                   invert_image=False,
                                   per_channel=True,
                                   retain_stats=True,
                                   p_per_sample=self.config_p_per_sample)
        transforms.append(aug_gamma)
    # Add gaussian noise augmentation
    if self.gaussian_noise:
        aug_gaussian_noise = GaussianNoiseTransform(self.config_gaussian_noise_range,
                                                    p_per_sample=self.config_p_per_sample)
        transforms.append(aug_gaussian_noise)
    # Add spatial transformations as augmentation
    # (rotation, scaling, elastic deformation)
    if self.rotations or self.scaling or self.elastic_deform or \
        self.cropping:
        # Identify patch shape (full image or cropping)
        if self.cropping:
            patch_shape = self.cropping_patch_shape
        else:
            patch_shape = img_data[0].shape[0:-1]
        # Assembling the spatial transformation; patches are centered at
        # half the patch shape unless random cropping is active.
        aug_spatial_transform = SpatialTransform(patch_shape,
                                                 [i // 2 for i in patch_shape],
                                                 do_elastic_deform=self.elastic_deform,
                                                 alpha=self.config_elastic_deform_alpha,
                                                 sigma=self.config_elastic_deform_sigma,
                                                 do_rotation=self.rotations,
                                                 angle_x=self.config_rotations_angleX,
                                                 angle_y=self.config_rotations_angleY,
                                                 angle_z=self.config_rotations_angleZ,
                                                 do_scale=self.scaling,
                                                 scale=self.config_scaling_range,
                                                 border_mode_data='constant',
                                                 border_cval_data=0,
                                                 border_mode_seg='constant',
                                                 border_cval_seg=0,
                                                 order_data=3,
                                                 order_seg=0,
                                                 p_el_per_sample=self.config_p_per_sample,
                                                 p_rot_per_sample=self.config_p_per_sample,
                                                 p_scale_per_sample=self.config_p_per_sample,
                                                 random_crop=self.cropping)
        # Append spatial transformation to transformation list
        transforms.append(aug_spatial_transform)
    # Compose the batchgenerators transforms
    all_transforms = Compose(transforms)
    # Assemble transforms into a augmentation generator
    augmentation_generator = SingleThreadedAugmenter(data_generator,
                                                     all_transforms)
    # Perform the data augmentation x times (x = cycles)
    aug_img_data = None
    aug_seg_data = None
    for i in range(0, self.cycles):
        # Run the computation process for the data augmentations
        augmentation = next(augmentation_generator)
        # Access augmentated data from the batchgenerators data structure
        if aug_img_data is None and aug_seg_data is None:
            aug_img_data = augmentation["data"]
            aug_seg_data = augmentation["seg"]
        # Concatenate the new data augmentated data with the cached data
        else:
            aug_img_data = np.concatenate((augmentation["data"], aug_img_data),
                                          axis=0)
            aug_seg_data = np.concatenate((augmentation["seg"], aug_seg_data),
                                          axis=0)
    # Transform data from channel-first back to channel-last structure
    # Data structure channel-first 3D: (batch, channel, x, y, z)
    # Data structure channel-last 3D: (batch, x, y, z, channel)
    aug_img_data = np.moveaxis(aug_img_data, 1, -1)
    aug_seg_data = np.moveaxis(aug_seg_data, 1, -1)
    # Return augmentated image and segmentation data
    return aug_img_data, aug_seg_data