def test_infinite_multi_threaded(self):
    """A finite loader must eventually raise StopIteration; an infinite one must not."""
    items = list(range(123))
    workers = 3

    finite_loader = DummyDataLoader(deepcopy(items), 12, workers, 1,
                                    return_incomplete=True, shuffle=True, infinite=False)
    augmenter = MultiThreadedAugmenter(finite_loader, None, workers, 1, None, False)
    # 123 items at batch size 12 cannot sustain 1000 draws -> must raise
    with self.assertRaises(StopIteration):
        for _ in range(1000):
            batch = next(augmenter)

    infinite_loader = DummyDataLoader(deepcopy(items), 12, workers, 1,
                                      return_incomplete=True, shuffle=True, infinite=True)
    augmenter = MultiThreadedAugmenter(infinite_loader, None, workers, 1, None, False)
    # infinite=True -> all 1000 draws must succeed without StopIteration
    for _ in range(1000):
        batch = next(augmenter)
def get_generators(patch_size, batch_size, preprocess_func, output_reshape_func, num_validation, train_processes, train_cache, train_data_dir='data/train/'):
    """
    Build augmented, Keras-compatible batch generators for training and validation.

    :param patch_size: network input size, without the batch dimension
    :param batch_size: training batch size
    :param preprocess_func: callable applied to each sample before augmentation
    :param output_reshape_func: callable reshaping each augmented sample for the network
    :param num_validation: number of samples in the (single) validation batch
    :param train_processes: worker processes for loading/augmenting training data
    :param train_cache: number of augmented batches cached per queue
    :param train_data_dir: directory containing the training data
    :return: (train_generator, validation_generator), both wrapped for Keras
    """
    sample_dirs = util.get_data_list(train_data_dir)
    month_labels = util.parse_labels_months()
    train_paths, validation_paths = util.train_validation_split(sample_dirs, month_labels)

    # training: multi-process loading with the full augmentation pipeline
    train_loader = CTBatchLoader(train_paths, batch_size, patch_size,
                                 num_threads_in_multithreaded=1,
                                 preprocess_func=preprocess_func)
    train_augmenter = MultiThreadedAugmenter(train_loader, get_train_transform(patch_size),
                                             num_processes=train_processes,
                                             num_cached_per_queue=train_cache,
                                             seeds=None, pin_memory=False)
    train_generator = KerasGenerator(train_augmenter, output_reshapefunc=output_reshape_func)

    # validation: single worker, single cached batch holding all validation samples
    valid_loader = CTBatchLoader(validation_paths, num_validation, patch_size,
                                 num_threads_in_multithreaded=1,
                                 preprocess_func=preprocess_func)
    valid_augmenter = MultiThreadedAugmenter(valid_loader, get_valid_transform(patch_size),
                                             num_processes=1, num_cached_per_queue=1,
                                             seeds=None, pin_memory=False)
    valid_generator = KerasGenerator(valid_augmenter, output_reshape_func, 1)

    return train_generator, valid_generator
def __init__(self, base_dir, mode="train", batch_size=16, num_batches=10000000, seed=None, num_processes=8, num_cached_per_queue=8 * 4, target_size=128, file_pattern='*.png', do_reshuffle=True, keys=None):
    """Build the image data loader, the mode-specific transform chain, and a started
    MultiThreadedAugmenter over them.

    NOTE(review): ``seed`` is forwarded to MultiThreadedAugmenter's ``seeds`` argument
    unchanged — presumably it accepts a scalar/None; confirm against its signature.
    """
    loader = MedImageDataLoader(base_dir=base_dir, mode=mode, batch_size=batch_size,
                                num_batches=num_batches, seed=seed,
                                file_pattern=file_pattern, keys=keys)
    self.data_loader = loader
    self.batch_size = batch_size
    self.number_of_slices = 1

    self.transforms = get_transforms(mode=mode, target_size=target_size)
    self.augmenter = MultiThreadedAugmenter(loader, self.transforms,
                                            num_processes=num_processes,
                                            num_cached_per_queue=num_cached_per_queue,
                                            seeds=seed, shuffle=do_reshuffle)
    # start the worker processes immediately so the first __next__ has data ready
    self.augmenter.restart()
def test_image_pipeline_and_pin_memory(self):
    '''
    This just should not crash
    :return:
    '''
    try:
        import torch
    except ImportError:
        # torch not installed -> nothing to verify
        return

    from batchgenerators.transforms import MirrorTransform, NumpyToTensor, TransposeAxesTransform, Compose
    pipeline = Compose([
        MirrorTransform(),
        TransposeAxesTransform(transpose_any_of_these=(0, 1), p_per_sample=0.5),
        NumpyToTensor(keys='data', cast_to='float'),
    ])

    augmenter = MultiThreadedAugmenter(self.dl_images, pipeline, 4, 1, None, True)
    for _ in range(50):
        batch = augmenter.next()
        assert isinstance(batch['data'], torch.Tensor)
        # pin_memory=True must yield pinned tensors
        assert batch['data'].is_pinned()

    # let the workers finish caching; otherwise they print a harmless-looking error
    sleep(2)
def __next__(self):
    """Advance to the next cross-validation fold and return its (train, valid) augmenters.

    Augmenters are created lazily on the first fold and re-pointed at the freshly
    built batch generators on subsequent folds. Raises StopIteration once all
    folds are consumed.
    """
    if self.n > self.n_folds:
        raise StopIteration

    self._build_generator()

    if self.train_bg_augmenter is None:
        self.train_bg_augmenter = MultiThreadedAugmenter(
            self.train_batchgenerator, self.image_transform,
            self.num_workers, pin_memory=True)
    else:
        # reuse the existing worker pool, just swap the underlying generator
        self.train_bg_augmenter.set_generator(self.train_batchgenerator)

    if self.valid_bg_augmenter is None:
        self.valid_bg_augmenter = MultiThreadedAugmenter(
            self.valid_batchgenerator, self.image_transform,
            self.num_workers, pin_memory=True)
    else:
        self.valid_bg_augmenter.set_generator(self.valid_batchgenerator)

    self.n += 1
    return self.train_bg_augmenter, self.valid_bg_augmenter
def _train_single_epoch(self, batchgen: MultiThreadedAugmenter, epoch):
    """
    Trains the network a single epoch

    Parameters
    ----------
    batchgen : MultiThreadedAugmenter
        Generator yielding the training batches
    epoch : int
        current epoch (used only for the progress-bar label)
    """
    # switch module into training mode
    self.module.training = True
    # progress-bar total: batches per worker times number of workers
    # NOTE(review): assumes batchgen.generator exposes num_batches — confirm for all loaders used
    n_batches = batchgen.generator.num_batches * batchgen.num_processes
    pbar = tqdm(enumerate(batchgen), unit=' batch', total=n_batches, desc='Epoch %d' % epoch)
    for batch_nr, batch in pbar:
        data_dict = batch
        # closure performs the forward/backward/optimizer step; its returns are discarded here
        _, _, _ = self.closure_fn(self.module, data_dict,
                                  optimizers=self.optimizers,
                                  losses=self.losses,
                                  metrics=self.metrics,
                                  fold=self.fold,
                                  batch_nr=batch_nr)
    # shut down the augmenter's worker processes once the epoch is done
    batchgen._finish()
def get_next_loader(self, train=True, build_new=False):
    """Return a MultiThreadedAugmenter over the train or validation generator of the
    current fold, or None once all folds are exhausted.

    With ``build_new=True`` the underlying batch generators are rebuilt first and
    the fold counter is advanced afterwards. A single augmenter instance is reused
    across calls; only its source generator is swapped.
    """
    if self.n > self.n_folds:
        return None

    if build_new:
        self._build_generator()

    source = self.train_batchgenerator if train else self.valid_batchgenerator
    if self.bg_augmenter is None:
        self.bg_augmenter = MultiThreadedAugmenter(
            source, self.image_transform, self.num_workers, pin_memory=True)
    else:
        self.bg_augmenter.set_generator(source)

    if build_new:
        self.n += 1
    return self.bg_augmenter
def test_return_incomplete_multi_threaded(self):
    """123 items / batch 12: drop-last yields 10 full batches (120 items),
    return_incomplete yields 11 batches covering all 123 items."""
    items = list(range(123))
    bs = 12
    workers = 3

    loader = DummyDataLoader(deepcopy(items), bs, workers, 1,
                             return_incomplete=False, shuffle=False, infinite=False)
    augmenter = MultiThreadedAugmenter(loader, None, workers, 1, None, False)
    seen, batches = 0, 0
    for batch in augmenter:
        batches += 1
        assert len(batch) == bs
        seen += bs
    self.assertTrue(seen == 120)
    self.assertTrue(batches == 10)

    loader = DummyDataLoader(deepcopy(items), bs, workers, 1,
                             return_incomplete=True, shuffle=False, infinite=False)
    augmenter = MultiThreadedAugmenter(loader, None, workers, 1, None, False)
    seen, batches = 0, 0
    for batch in augmenter:
        batches += 1
        seen += len(batch)
    self.assertTrue(seen == 123)
    self.assertTrue(batches == 11)
def create_data_gen_train(patient_data_train, BATCH_SIZE, num_classes, num_workers=5, num_cached_per_worker=2, do_elastic_transform=False, alpha=(0., 1300.), sigma=(10., 13.), do_rotation=False, a_x=(0., 2*np.pi), a_y=(0., 2*np.pi), a_z=(0., 2*np.pi), do_scale=True, scale_range=(0.75, 1.25), seeds=None):
    """Build and start a MultiThreadedAugmenter over 352x352 2D training patches.

    seeds may be None (unseeded workers), the string 'range' (worker index as
    seed), or an explicit list with one seed per worker.
    """
    if seeds is None:
        seeds = [None] * num_workers
    elif seeds == 'range':
        seeds = range(num_workers)
    else:
        assert len(seeds) == num_workers

    patch_shape = (352, 352)
    base_generator = BatchGenerator_2D(patient_data_train, BATCH_SIZE,
                                       num_batches=None, seed=False,
                                       PATCH_SIZE=patch_shape)

    spatial = SpatialTransform(patch_shape, list(np.array(patch_shape) // 2),
                               do_elastic_transform, alpha, sigma,
                               do_rotation, a_x, a_y, a_z,
                               do_scale, scale_range,
                               'constant', 0, 3, 'constant', 0, 0,
                               random_crop=False)
    pipeline = Compose([
        Mirror((2, 3)),
        # with p=0.67 apply the spatial transform, otherwise fall back to a plain random crop
        RndTransform(spatial, prob=0.67,
                     alternative_transform=RandomCropTransform(patch_shape)),
        ConvertSegToOnehotTransform(range(num_classes), seg_channel=0, output_key='seg_onehot'),
    ])

    augmenter = MultiThreadedAugmenter(base_generator, pipeline,
                                       num_workers, num_cached_per_worker, seeds)
    augmenter.restart()
    return augmenter
def test_no_crash(self):
    """
    This one should just not crash, that's all
    :return:
    """
    augmenter = MultiThreadedAugmenter(self.dl_images, None, self.num_threads, 1, None, False)
    for _ in range(20):
        _ = augmenter.next()
def get_batchgen(self, seed=1):
    """
    Create DataLoader and Batchgenerator

    Parameters
    ----------
    seed : int
        seed for Random Number Generator

    Returns
    -------
    MultiThreadedAugmenter
        Batchgenerator

    Raises
    ------
    AssertionError
        :attr:`BaseDataManager.n_batches` is smaller than or equal to zero
    """
    assert self.n_batches > 0

    loader = self.data_loader_cls(self.dataset,
                                  batch_size=self.batch_size,
                                  num_batches=self.n_batches,
                                  seed=seed,
                                  sampler=self.sampler)
    # every augmentation worker gets the same seed
    worker_seeds = self.n_process_augmentation * [seed]
    return MultiThreadedAugmenter(loader, self.transforms,
                                  self.n_process_augmentation,
                                  num_cached_per_queue=2,
                                  seeds=worker_seeds)
def test_return_all_indices_multi_threaded_shuffle_True(self):
    """With shuffle=True every index must be returned exactly once per epoch,
    but not in the original order."""
    items = list(range(123))
    workers = 3
    for bs in [1, 3, 75, 12, 23]:
        loader = DummyDataLoader(deepcopy(items), bs, workers, 1,
                                 return_incomplete=True, shuffle=True, infinite=False)
        augmenter = MultiThreadedAugmenter(loader, None, workers, 1, None, False)
        for _ in range(3):
            returned = []
            for batch in augmenter:
                returned += batch
            assert len(returned) == len(items)
            # shuffled -> must not come back in ascending order
            assert not all([a == b for a, b in zip(returned, items)])
            returned.sort()
            # ...but after sorting it must be exactly 0..122
            assert all([a == b for a, b in zip(returned, items)])
def load_dataset(self):
    """Split patients deterministically into train/val, size patches to cover the
    largest training volume, and build the multithreaded training generator."""
    patients = get_file_list(self.args.data_path, self.args.train_ids_path)
    train, val = get_split_deterministic(patients, fold=0, num_splits=5, random_state=12345)

    # the sampling patch must cover both the requested patch size and the largest volume
    shapes = [brats_dataloader.load_patient(case)[0].shape[1:] for case in train]
    max_shape = list(np.max((np.max(shapes, 0), self.args.patch_size), 0))

    loader = brats_dataloader(train, self.args.batch_size, max_shape,
                              self.args.num_threads,
                              return_incomplete=True, infinite=False)
    self.train_data_loader = MultiThreadedAugmenter(
        loader, get_train_transform(self.args.patch_size),
        num_processes=self.args.num_threads,
        num_cached_per_queue=3, seeds=None, pin_memory=False)

    self.num_batches_per_epoch = int(math.ceil(len(train) / self.args.batch_size))
    self.val = val
def get_train_val_generators(fold):
    """Build the augmented training generator and a minimally-wrapped validation
    generator for one cross-validation fold."""
    tr_keys, te_keys = get_split(fold, split_seed)
    train_data = {k: dataset[k] for k in tr_keys}
    val_data = {k: dataset[k] for k in te_keys}

    # new se has no brain mask
    data_gen_train = create_data_gen_train(
        train_data, BATCH_SIZE, num_classes, INPUT_PATCH_SIZE,
        num_workers=num_workers,
        do_elastic_transform=True, alpha=(0., 350.), sigma=(14., 17.),
        do_rotation=True, a_x=(0, 2. * np.pi),
        a_y=(-0.000001, 0.00001), a_z=(-0.000001, 0.00001),
        do_scale=True, scale_range=(0.7, 1.3),
        seeds=workers_seeds)

    # validation: no augmentation, just one-hot conversion of the segmentation
    val_base = BatchGenerator(val_data, BATCH_SIZE, num_batches=None,
                              seed=False, PATCH_SIZE=INPUT_PATCH_SIZE)
    val_pipeline = Compose([ConvertSegToOnehotTransform(range(4), 0, 'seg_onehot')])
    data_gen_validation = MultiThreadedAugmenter(val_base, val_pipeline, 1, 2, [0])

    return data_gen_train, data_gen_validation
def test_return_incomplete_multi_threaded(self):
    """Drop-last must yield 10 full batches of unique items (120 total);
    return_incomplete must yield 11 batches covering all 123 unique items."""
    items = list(range(123))
    bs = 12
    workers = 3

    loader = DummyDataLoader(deepcopy(items), bs, workers, 1,
                             return_incomplete=False, shuffle=False, infinite=False)
    augmenter = MultiThreadedAugmenter(loader, None, workers, 1, None, False)
    returned = []
    total, count = 0, 0
    for batch in augmenter:
        count += 1
        assert len(batch) == bs
        total += len(batch)
        returned += batch
    self.assertTrue(total == 120)
    self.assertTrue(count == 10)
    # no item may be delivered twice
    self.assertTrue(len(np.unique(returned)) == total)

    loader = DummyDataLoader(deepcopy(items), bs, workers, 1,
                             return_incomplete=True, shuffle=False, infinite=False)
    augmenter = MultiThreadedAugmenter(loader, None, workers, 1, None, False)
    returned = []
    total, count = 0, 0
    for batch in augmenter:
        count += 1
        total += len(batch)
        returned += batch
    self.assertTrue(total == 123)
    self.assertTrue(count == 11)
    # every single item must appear exactly once
    self.assertTrue(len(np.unique(returned)) == len(items))
def test_image_pipeline(self):
    '''
    This just should not crash
    :return:
    '''
    from batchgenerators.transforms import MirrorTransform, TransposeAxesTransform, Compose
    pipeline = Compose([
        MirrorTransform(),
        TransposeAxesTransform(transpose_any_of_these=(0, 1), p_per_sample=0.5),
    ])

    augmenter = MultiThreadedAugmenter(self.dl_images, pipeline, 4, 1, None, False)
    for _ in range(50):
        res = augmenter.next()

    # let the workers finish caching; otherwise they print a harmless-looking error
    sleep(2)
def create_data_gen_train(patient_data_train, INPUT_PATCH_SIZE, num_classes, BATCH_SIZE, contrast_range=(0.75, 1.5), gamma_range = (0.6, 2), num_workers=5, num_cached_per_worker=3, do_elastic_transform=False, alpha=(0., 1300.), sigma=(10., 13.), do_rotation=False, a_x=(0., 2*np.pi), a_y=(0., 2*np.pi), a_z=(0., 2*np.pi), do_scale=True, scale_range=(0.75, 1.25), seeds=None):
    """Build and start a MultiThreadedAugmenter over randomly sampled 3D brain patches.

    seeds may be None (unseeded workers), the string 'range' (worker index as seed),
    or an explicit list with one seed per worker.
    """
    if seeds is None:
        seeds = [None]*num_workers
    elif seeds == 'range':
        seeds = range(num_workers)
    else:
        assert len(seeds) == num_workers
    # base generator samples (160, 192, 160) regions; SpatialTransform below crops to INPUT_PATCH_SIZE
    data_gen_train = BatchGenerator3D_random_sampling(patient_data_train, BATCH_SIZE, num_batches=None, seed=False,
                                                      patch_size=(160, 192, 160), convert_labels=True)
    tr_transforms = []
    # keep the four image modalities
    tr_transforms.append(DataChannelSelectionTransform([0, 1, 2, 3]))
    tr_transforms.append(GenerateBrainMaskTransform())
    tr_transforms.append(MirrorTransform())
    tr_transforms.append(SpatialTransform(INPUT_PATCH_SIZE, list(np.array(INPUT_PATCH_SIZE)//2.),
                                          do_elastic_deform=do_elastic_transform, alpha=alpha, sigma=sigma,
                                          do_rotation=do_rotation, angle_x=a_x, angle_y=a_y, angle_z=a_z,
                                          do_scale=do_scale, scale=scale_range,
                                          border_mode_data='nearest', border_cval_data=0, order_data=3,
                                          border_mode_seg='constant', border_cval_seg=0, order_seg=0,
                                          random_crop=True))
    # intensity chain: normalize within brain mask, then contrast/gamma/brightness jitter, then re-normalize
    tr_transforms.append(BrainMaskAwareStretchZeroOneTransform((-5, 5), True))
    tr_transforms.append(ContrastAugmentationTransform(contrast_range, True))
    tr_transforms.append(GammaTransform(gamma_range, False))
    tr_transforms.append(BrainMaskAwareStretchZeroOneTransform(per_channel=True))
    tr_transforms.append(BrightnessTransform(0.0, 0.1, True))
    # keep only the first seg channel, then one-hot encode it under 'seg_onehot'
    tr_transforms.append(SegChannelSelectionTransform([0]))
    tr_transforms.append(ConvertSegToOnehotTransform(range(num_classes), 0, "seg_onehot"))
    gen_train = MultiThreadedAugmenter(data_gen_train, Compose(tr_transforms), num_workers,
                                       num_cached_per_worker, seeds)
    # start worker processes immediately
    gen_train.restart()
    return gen_train
def get_no_augmentation(dataloader_train, dataloader_val, patch_size, params=default_3D_augmentation_params, border_val_seg=-1):
    """
    use this instead of get_default_augmentation (drop in replacement) to turn off all data augmentation
    :param dataloader_train:
    :param dataloader_val:
    :param patch_size: unused; kept for drop-in signature compatibility with get_default_augmentation
    :param params: augmentation-parameter dict (only channel selection and threading keys are read here)
    :param border_val_seg: unused; kept for drop-in signature compatibility
    :return: (train batchgenerator, val batchgenerator), both already restarted
    """
    tr_transforms = []

    if params.get("selected_data_channels") is not None:
        tr_transforms.append(DataChannelSelectionTransform(params.get("selected_data_channels")))

    if params.get("selected_seg_channels") is not None:
        tr_transforms.append(SegChannelSelectionTransform(params.get("selected_seg_channels")))

    # map the "ignore" label -1 to background, rename seg->target, convert to torch tensors
    tr_transforms.append(RemoveLabelTransform(-1, 0))
    tr_transforms.append(RenameTransform('seg', 'target', True))
    tr_transforms.append(NumpyToTensor(['data', 'target'], 'float'))
    tr_transforms = Compose(tr_transforms)

    batchgenerator_train = MultiThreadedAugmenter(dataloader_train, tr_transforms, params.get('num_threads'),
                                                  params.get("num_cached_per_thread"),
                                                  seeds=range(params.get('num_threads')), pin_memory=True)
    batchgenerator_train.restart()

    # validation pipeline: same minimal transforms, built in a slightly different order
    val_transforms = []
    val_transforms.append(RemoveLabelTransform(-1, 0))
    if params.get("selected_data_channels") is not None:
        val_transforms.append(DataChannelSelectionTransform(params.get("selected_data_channels")))
    if params.get("selected_seg_channels") is not None:
        val_transforms.append(SegChannelSelectionTransform(params.get("selected_seg_channels")))

    val_transforms.append(RenameTransform('seg', 'target', True))
    val_transforms.append(NumpyToTensor(['data', 'target'], 'float'))
    val_transforms = Compose(val_transforms)

    # validation uses half the threads (at least one)
    batchgenerator_val = MultiThreadedAugmenter(dataloader_val, val_transforms,
                                                max(params.get('num_threads')//2, 1),
                                                params.get("num_cached_per_thread"),
                                                seeds=range(max(params.get('num_threads')//2, 1)),
                                                pin_memory=True)
    batchgenerator_val.restart()
    return batchgenerator_train, batchgenerator_val
def test_order(self):
    """
    Coordinating workers in a multiprocessing environment is difficult. We want DummyDL in a
    multithreaded environment to still give us the numbers from 0 to 99 in ascending order
    :return:
    """
    augmenter = MultiThreadedAugmenter(self.dl, None, self.num_threads, 1, None, False)

    produced = [item for item in augmenter]
    assert len(produced) == 100

    before_sorting = deepcopy(produced)
    produced.sort()
    # sorting must be a no-op (output was already ascending)...
    assert all((a == b for a, b in zip(produced, before_sorting)))
    # ...and the sorted output must be exactly 0..99
    assert all((a == b for a, b in zip(produced, np.arange(0, 100))))
def get_data_augmenter(data, batch_size=1, mode=DataLoader.Mode.NORMAL, volumetric=True, normalization_range=None, vector_generator=None, input_shape=None, sample_count=1, transforms=None, threads=1, seed=None):
    """Wrap *data* in a DataLoader and return a MultiThreadedAugmenter whose pipeline
    ends with a TensorFlow-preparation transform."""
    if transforms is None:
        transforms = []
    # never spawn more workers than there are batches
    threads = min(int(np.ceil(len(data) / batch_size)), threads)

    loader = DataLoader(data=data, batch_size=batch_size, mode=mode,
                        volumetric=volumetric,
                        normalization_range=normalization_range,
                        vector_generator=vector_generator,
                        input_shape=input_shape,
                        sample_count=sample_count,
                        number_of_threads_in_multithreaded=threads,
                        seed=seed)

    # append PrepareForTF without mutating the caller's transform list
    pipeline = Compose(list(transforms) + [PrepareForTF()])
    return MultiThreadedAugmenter(loader, pipeline, threads)
def main():
    """KiTS19 VNet training entry point: parses CLI args + YAML config, builds the
    model, data generators, optimizer and LR schedulers, then runs the train/validate
    loop, checkpointing the best kidney/tumor dice."""
    # --------- Parse arguments ---------------------------------------------------------------
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', type=str, default='./config_0.yaml',
                        help='Path to the configuration file.')
    parser.add_argument('--dice', action='store_true')  # be aware of this argument!!!
    parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
                        help='manual epoch number (useful on restarts)')
    parser.add_argument('--resume', default='', type=str, metavar='PATH',
                        help='path to latest checkpoint (default: none)')
    parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
                        help='evaluate model on validation set')
    parser.add_argument('-i', '--inference', default='', type=str, metavar='PATH',
                        help='run inference on data set and save results')
    # 1e-8 works well for lung masks but seems to prevent
    # rapid learning for nodule masks
    parser.add_argument('--no-cuda', action='store_true')
    parser.add_argument('--save')
    parser.add_argument('--seed', type=int, default=1)
    args = parser.parse_args()

    # ---------- get the config file(config.yaml) --------------------------------------------
    config = get_config(args.config)
    args.cuda = not args.no_cuda and torch.cuda.is_available()
    args.save = os.path.join('./work', (datestr() + ' ' + config['filename']))

    # nll=True -> NLL loss path; config['dice'] switches to the dice-loss path
    nll = True
    if config['dice']:
        nll = False
    weight_decay = config['weight_decay']
    num_threads_for_kits19 = config['num_of_threads']
    patch_size = (160, 160, 128)
    num_batch_per_epoch = config['num_batch_per_epoch']
    setproctitle.setproctitle(args.save)
    # NOTE(review): start_epoch is hard-coded to 1 here; the --start-epoch CLI option
    # (default 0) is never used — confirm intended
    start_epoch = 1

    # -------- Record best kidney segmentation dice -------------------------------------------
    best_tk = 0.0
    torch.manual_seed(args.seed)
    if args.cuda:
        torch.cuda.manual_seed(args.seed)

    print("build vnet")
    # Embed attention module
    model = vnet.VNet(elu=False, nll=nll, attention=config['attention'], nclass=3)  # mark
    batch_size = config['ngpu'] * config['batchSz']
    save_iter = config['model_save_iter']  #
    # NOTE(review): argparse defines no --ngpu/--batchSz, so args.ngpu will raise
    # AttributeError here, and this line clobbers the config-derived batch_size above
    # — almost certainly dead/leftover code; confirm and remove
    batch_size = args.ngpu*args.batchSz
    gpu_ids = range(config['ngpu'])
    # print(gpu_ids)
    model.apply(weights_init)

    # ------- Resume training from saved model -----------------------------------------------
    if config['resume']:
        if os.path.isfile(config['resume']):
            print("=> loading checkpoint '{}'".format(config['resume']))
            checkpoint = torch.load(config['resume'])
            # .tar files: dict with epoch/best_tk/state_dict (strip DataParallel 'module.' prefix)
            if config['resume'].endswith('.tar'):
                # print(checkpoint, "tar")
                start_epoch = checkpoint['epoch']
                best_tk = checkpoint['best_tk']
                checkpoint_model = checkpoint['model_state_dict']
                model.load_state_dict(
                    {k.replace('module.', ''): v for k, v in checkpoint_model.items()})
            # .pkl files for the whole model
            else:
                # print(checkpoint, "pkl")
                model.load_state_dict(checkpoint.state_dict())
            # NOTE(review): for .pkl checkpoints this subscript will fail (checkpoint is a
            # model object, not a dict) — confirm
            print("=> loaded checkpoint (epoch {})".format(
                checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(config['resume']))
            exit(-1)
    else:
        pass

    # ------- Which loss function to use ------------------------------------------------------
    if nll:
        training = train_bg
        validate = test_bg
        # class_balance = True
    else:
        training = train_bg_dice
        validate = test_bg_dice
        # class_balance = False
    # -----------------------------------------------------------------------------------------
    print(' + Number of params: {}'.format(
        sum([p.data.nelement() for p in model.parameters()])))

    # -------- Set on GPU ---------------------------------------------------------------------
    if args.cuda:
        model = model.cuda()

    # output directory is wiped and recreated on every run
    if os.path.exists(args.save):
        shutil.rmtree(args.save)
    # create the output directory
    os.makedirs(args.save)
    # save the config file to the output folder
    shutil.copy(args.config, os.path.join(args.save, 'config.yaml'))
    # kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {}

    # ------ Load Training and Validation set --------------------------------------------
    preprocessed_folders = "/home/data_share/npy_data/"
    patients = get_list_of_patients(
        preprocessed_data_folder=preprocessed_folders)
    # split num_split cross-validation sets
    # train, val = get_split_deterministic(
    #     patients, fold=0, num_splits=5, random_state=12345)
    train, val = patients[0:147], patients[147:189]
    # VALIDATION DATA CANNOT BE LOADED IN CASE DUE TO THE LARGE SHAPE...
    # PRINT VALIDATION CASES FOR LATER TEST USE!!
    print("Validation cases:\n", val)

    # set max shape for validation set
    # NOTE(review): max_shape is computed but never used below — confirm leftover
    shapes = [Kits2019DataLoader3D.load_patient(
        i)[0].shape[1:] for i in val]
    max_shape = np.max(shapes, 0)
    max_shape = np.max((max_shape, patch_size), 0)

    # data loading + augmentation
    dataloader_train = Kits2019DataLoader3D(
        train, batch_size, patch_size, num_threads_for_kits19)
    dataloader_validation = Kits2019DataLoader3D(
        val, batch_size * 2, patch_size, num_threads_for_kits19)
    tr_transforms = get_train_transform(patch_size, prob=config['prob'])

    # whether to use single/multiThreadedAugmenter ------------------------------------------
    if num_threads_for_kits19 > 1:
        tr_gen = MultiThreadedAugmenter(dataloader_train, tr_transforms,
                                        num_processes=num_threads_for_kits19,
                                        num_cached_per_queue=3, seeds=None,
                                        pin_memory=True)
        val_gen = MultiThreadedAugmenter(dataloader_validation, None,
                                         num_processes=max(1, num_threads_for_kits19//2),
                                         num_cached_per_queue=1, seeds=None,
                                         pin_memory=False)
        tr_gen.restart()
        val_gen.restart()
    else:
        tr_gen = SingleThreadedAugmenter(dataloader_train, transform=tr_transforms)
        val_gen = SingleThreadedAugmenter(dataloader_validation, transform=None)

    # ------- Set learning rate scheduler ----------------------------------------------------
    lr_schdl = lr_scheduler.LR_Scheduler(mode=config['lr_policy'],
                                         base_lr=config['lr'],
                                         num_epochs=config['nEpochs'],
                                         iters_per_epoch=num_batch_per_epoch,
                                         lr_step=config['step_size'],
                                         warmup_epochs=config['warmup_epochs'])

    # ------ Choose Optimizer ----------------------------------------------------------------
    # NOTE(review): no else branch — an unknown config['opt'] leaves optimizer undefined
    if config['opt'] == 'sgd':
        optimizer = optim.SGD(model.parameters(), lr=config['lr'],
                              momentum=0.99, weight_decay=weight_decay)
    elif config['opt'] == 'adam':
        optimizer = optim.Adam(
            model.parameters(), lr=config['lr'], weight_decay=weight_decay)
    elif config['opt'] == 'rmsprop':
        optimizer = optim.RMSprop(
            model.parameters(), lr=config['lr'], weight_decay=weight_decay)
    lr_plateu = optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.5,
                                                     verbose=True, threshold=1e-3,
                                                     patience=5)

    # ------- Apex Mixed Precision Acceleration ----------------------------------------------
    model, optimizer = amp.initialize(model, optimizer, opt_level="O1")
    model = nn.parallel.DataParallel(model, device_ids=gpu_ids)

    # ------- Save training data -------------------------------------------------------------
    trainF = open(os.path.join(args.save, 'train.csv'), 'w')
    trainF.write('Epoch,Loss,Kidney_Dice,Tumor_Dice\n')
    testF = open(os.path.join(args.save, 'test.csv'), 'w')
    # NOTE(review): trailing space in this header string preserved as-is
    testF.write('Epoch,Loss,Kidney_Dice,Tumor_Dice\n ')

    # ------- Training Pipeline --------------------------------------------------------------
    for epoch in range(start_epoch, config['nEpochs'] + start_epoch):
        torch.cuda.empty_cache()
        training(args, epoch, model, tr_gen, optimizer, trainF, config, lr_schdl)
        torch.cuda.empty_cache()
        print('==>lr decay to:', optimizer.param_groups[0]['lr'])
        print('testing validation set...')
        composite_dice = validate(args, epoch, model, val_gen, optimizer,
                                  testF, config, lr_plateu)
        torch.cuda.empty_cache()
        # save model with best result and routinely
        if composite_dice > best_tk or epoch % config['model_save_iter'] == 0:
            # model_name = 'vnet_epoch_step1_' + str(epoch) + '.pkl'
            model_name = 'vnet_step1_' + str(epoch) + '.tar'
            # torch.save(model, os.path.join(args.save, model_name))
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'best_tk': best_tk
            }, os.path.join(args.save, model_name))
            best_tk = composite_dice
    # ----------------------------------------------------------------------------------------
    trainF.close()
    testF.close()
def get_no_augmentation(dataloader_train, dataloader_val, params=default_3D_augmentation_params, deep_supervision_scales=None, soft_ds=False, classes=None, pin_memory=True, regions=None):
    """
    use this instead of get_default_augmentation (drop in replacement) to turn off all data augmentation

    Only channel selection, label cleanup, optional region conversion, optional deep-supervision
    downsampling and tensor conversion remain in the pipeline.
    """
    tr_transforms = []

    if params.get("selected_data_channels") is not None:
        tr_transforms.append(
            DataChannelSelectionTransform(
                params.get("selected_data_channels")))

    if params.get("selected_seg_channels") is not None:
        tr_transforms.append(
            SegChannelSelectionTransform(params.get("selected_seg_channels")))

    # map the "ignore" label -1 to background, then rename seg->target
    tr_transforms.append(RemoveLabelTransform(-1, 0))
    tr_transforms.append(RenameTransform('seg', 'target', True))

    # region-based training converts the label map into overlapping region masks
    if regions is not None:
        tr_transforms.append(
            ConvertSegmentationToRegionsTransform(regions, 'target', 'target'))

    # deep supervision needs the target at every decoder resolution
    if deep_supervision_scales is not None:
        if soft_ds:
            assert classes is not None
            tr_transforms.append(
                DownsampleSegForDSTransform3(deep_supervision_scales, 'target',
                                             'target', classes))
        else:
            tr_transforms.append(
                DownsampleSegForDSTransform2(deep_supervision_scales, 0, 0,
                                             input_key='target',
                                             output_key='target'))

    tr_transforms.append(NumpyToTensor(['data', 'target'], 'float'))
    tr_transforms = Compose(tr_transforms)

    batchgenerator_train = MultiThreadedAugmenter(
        dataloader_train, tr_transforms, params.get('num_threads'),
        params.get("num_cached_per_thread"),
        seeds=range(params.get('num_threads')), pin_memory=pin_memory)
    batchgenerator_train.restart()

    # validation pipeline: identical minimal transforms
    val_transforms = []
    val_transforms.append(RemoveLabelTransform(-1, 0))
    if params.get("selected_data_channels") is not None:
        val_transforms.append(
            DataChannelSelectionTransform(
                params.get("selected_data_channels")))
    if params.get("selected_seg_channels") is not None:
        val_transforms.append(
            SegChannelSelectionTransform(params.get("selected_seg_channels")))

    val_transforms.append(RenameTransform('seg', 'target', True))

    if regions is not None:
        val_transforms.append(
            ConvertSegmentationToRegionsTransform(regions, 'target', 'target'))

    if deep_supervision_scales is not None:
        if soft_ds:
            assert classes is not None
            val_transforms.append(
                DownsampleSegForDSTransform3(deep_supervision_scales, 'target',
                                             'target', classes))
        else:
            val_transforms.append(
                DownsampleSegForDSTransform2(deep_supervision_scales, 0, 0,
                                             input_key='target',
                                             output_key='target'))

    val_transforms.append(NumpyToTensor(['data', 'target'], 'float'))
    val_transforms = Compose(val_transforms)

    # validation uses half the worker count (at least one)
    batchgenerator_val = MultiThreadedAugmenter(
        dataloader_val, val_transforms,
        max(params.get('num_threads') // 2, 1),
        params.get("num_cached_per_thread"),
        seeds=range(max(params.get('num_threads') // 2, 1)),
        pin_memory=pin_memory)
    batchgenerator_val.restart()
    return batchgenerator_train, batchgenerator_val
scale=params.get("scale_range"), border_mode_data=params.get("border_mode_data"), border_cval_data=0, order_data=order_data, border_mode_seg="constant", border_cval_seg=border_val_seg, order_seg=order_seg, random_crop=params.get("random_crop"), p_el_per_sample=params.get("p_eldef"), p_scale_per_sample=params.get("p_scale"), p_rot_per_sample=params.get("p_rot"), independent_scale_for_each_axis=params.get( "independent_scale_factor_for_each_axis"))) tr_transforms = Compose(tr_transforms) batchgenerator_train = MultiThreadedAugmenter( dataloader_train, tr_transforms, params.get('num_threads'), params.get("num_cached_per_thread"), pin_memory=True) train_loader = batchgenerator_train train_batch = next(train_loader) print( train_batch.keys() ) # dict_keys(['data', 'target']), each with torch.Size([2, 1, 112, 240, 272]) print((train_batch['seg'] - train_batch['weak_label']).sum()) train_batch = next(train_loader) print((train_batch['seg'] - train_batch['weak_label']).sum()) ipdb.set_trace() print(sum(1 for _ in train_loader))
def get_default_augmentation(dataloader_train, dataloader_val, patch_size, params=default_3D_augmentation_params,
                             border_val_seg=-1, pin_memory=True,
                             seeds_train=None, seeds_val=None, regions=None):
    """Build the default training/validation augmentation pipelines and wrap the
    dataloaders in MultiThreadedAugmenters.

    :param dataloader_train: training batch loader
    :param dataloader_val: validation batch loader
    :param patch_size: output patch size of the spatial transform
    :param params: augmentation-parameter dict (see default_3D_augmentation_params)
    :param border_val_seg: constant fill value for segmentation borders
    :param pin_memory: forward to MultiThreadedAugmenter (pins output tensors)
    :param seeds_train: per-worker seeds for the training augmenter (None = unseeded)
    :param seeds_val: per-worker seeds for the validation augmenter (None = unseeded)
    :param regions: if not None, convert label maps to region masks for region-based training
    :return: (train batchgenerator, val batchgenerator)
    """
    assert params.get('mirror') is None, "old version of params, use new keyword do_mirror"
    tr_transforms = []

    if params.get("selected_data_channels") is not None:
        tr_transforms.append(DataChannelSelectionTransform(params.get("selected_data_channels")))

    if params.get("selected_seg_channels") is not None:
        tr_transforms.append(SegChannelSelectionTransform(params.get("selected_seg_channels")))

    # don't do color augmentations while in 2d mode with 3d data because the color channel is overloaded!!
    if params.get("dummy_2D") is not None and params.get("dummy_2D"):
        tr_transforms.append(Convert3DTo2DTransform())

    tr_transforms.append(SpatialTransform(
        patch_size, patch_center_dist_from_border=None,
        do_elastic_deform=params.get("do_elastic"), alpha=params.get("elastic_deform_alpha"),
        sigma=params.get("elastic_deform_sigma"),
        do_rotation=params.get("do_rotation"), angle_x=params.get("rotation_x"),
        angle_y=params.get("rotation_y"), angle_z=params.get("rotation_z"),
        do_scale=params.get("do_scaling"), scale=params.get("scale_range"),
        border_mode_data=params.get("border_mode_data"), border_cval_data=0, order_data=3,
        border_mode_seg="constant", border_cval_seg=border_val_seg, order_seg=1,
        random_crop=params.get("random_crop"), p_el_per_sample=params.get("p_eldef"),
        p_scale_per_sample=params.get("p_scale"), p_rot_per_sample=params.get("p_rot"),
        independent_scale_for_each_axis=params.get("independent_scale_factor_for_each_axis")
    ))
    if params.get("dummy_2D") is not None and params.get("dummy_2D"):
        tr_transforms.append(Convert2DTo3DTransform())

    if params.get("do_gamma"):
        tr_transforms.append(
            GammaTransform(params.get("gamma_range"), False, True,
                           retain_stats=params.get("gamma_retain_stats"),
                           p_per_sample=params["p_gamma"]))

    if params.get("do_mirror"):
        tr_transforms.append(MirrorTransform(params.get("mirror_axes")))

    if params.get("mask_was_used_for_normalization") is not None:
        mask_was_used_for_normalization = params.get("mask_was_used_for_normalization")
        tr_transforms.append(MaskTransform(mask_was_used_for_normalization, mask_idx_in_seg=0, set_outside_to=0))

    tr_transforms.append(RemoveLabelTransform(-1, 0))

    if params.get("move_last_seg_chanel_to_data") is not None and params.get("move_last_seg_chanel_to_data"):
        tr_transforms.append(MoveSegAsOneHotToData(1, params.get("all_segmentation_labels"), 'seg', 'data'))
        # BUGFIX: was `params.get(...) and not None and params.get(...)` — `not None` is the
        # constant True, so the None-check was a no-op typo for `is not None`
        if params.get("cascade_do_cascade_augmentations") is not None and params.get(
                "cascade_do_cascade_augmentations"):
            tr_transforms.append(ApplyRandomBinaryOperatorTransform(
                channel_idx=list(range(-len(params.get("all_segmentation_labels")), 0)),
                p_per_sample=params.get("cascade_random_binary_transform_p"),
                key="data",
                strel_size=params.get("cascade_random_binary_transform_size")))
            # BUGFIX: the two kwargs below had their param values crossed
            # (fill_with_other_class_p received the max-size threshold and vice versa)
            tr_transforms.append(RemoveRandomConnectedComponentFromOneHotEncodingTransform(
                channel_idx=list(range(-len(params.get("all_segmentation_labels")), 0)),
                key="data",
                p_per_sample=params.get("cascade_remove_conn_comp_p"),
                fill_with_other_class_p=params.get("cascade_remove_conn_comp_fill_with_other_class_p"),
                dont_do_if_covers_more_than_X_percent=params.get(
                    "cascade_remove_conn_comp_max_size_percent_threshold")))

    tr_transforms.append(RenameTransform('seg', 'target', True))

    if regions is not None:
        tr_transforms.append(ConvertSegmentationToRegionsTransform(regions, 'target', 'target'))

    tr_transforms.append(NumpyToTensor(['data', 'target'], 'float'))
    tr_transforms = Compose(tr_transforms)

    batchgenerator_train = MultiThreadedAugmenter(dataloader_train, tr_transforms, params.get('num_threads'),
                                                  params.get("num_cached_per_thread"), seeds=seeds_train,
                                                  pin_memory=pin_memory)

    val_transforms = []
    val_transforms.append(RemoveLabelTransform(-1, 0))
    if params.get("selected_data_channels") is not None:
        val_transforms.append(DataChannelSelectionTransform(params.get("selected_data_channels")))
    if params.get("selected_seg_channels") is not None:
        val_transforms.append(SegChannelSelectionTransform(params.get("selected_seg_channels")))

    if params.get("move_last_seg_chanel_to_data") is not None and params.get("move_last_seg_chanel_to_data"):
        val_transforms.append(MoveSegAsOneHotToData(1, params.get("all_segmentation_labels"), 'seg', 'data'))

    val_transforms.append(RenameTransform('seg', 'target', True))

    if regions is not None:
        val_transforms.append(ConvertSegmentationToRegionsTransform(regions, 'target', 'target'))

    val_transforms.append(NumpyToTensor(['data', 'target'], 'float'))
    val_transforms = Compose(val_transforms)

    # validation uses half the workers (at least one) and no spatial/intensity augmentation
    batchgenerator_val = MultiThreadedAugmenter(dataloader_val, val_transforms,
                                                max(params.get('num_threads') // 2, 1),
                                                params.get("num_cached_per_thread"),
                                                seeds=seeds_val, pin_memory=pin_memory)
    return batchgenerator_train, batchgenerator_val
def run(fold=0):
    """Train a 3D segmentation network (Theano/Lasagne) on one cross-validation fold.

    NOTE(review): this function is Python 2 code (print statements, cPickle) —
    it will not run under Python 3 without porting.

    :param fold: index (0-4) of the 5-fold cross-validation split to train on
    """
    print fold
    # =================================================================================================================
    I_AM_FOLD = fold
    # fixed seeds for reproducibility of both numpy and lasagne initializers
    np.random.seed(65432)
    lasagne.random.set_rng(np.random.RandomState(98765))
    sys.setrecursionlimit(2000)
    BATCH_SIZE = 2
    INPUT_PATCH_SIZE =(128, 128, 128)
    num_classes=4
    EXPERIMENT_NAME = "final"
    # build results_folder/EXPERIMENT_NAME/fold<k>/, creating each level if missing
    results_dir = os.path.join(paths.results_folder)
    if not os.path.isdir(results_dir):
        os.mkdir(results_dir)
    results_dir = os.path.join(results_dir, EXPERIMENT_NAME)
    if not os.path.isdir(results_dir):
        os.mkdir(results_dir)
    results_dir = os.path.join(results_dir, "fold%d"%I_AM_FOLD)
    if not os.path.isdir(results_dir):
        os.mkdir(results_dir)
    # training schedule constants
    n_epochs = 300
    lr_decay = np.float32(0.985)    # exponential LR decay per epoch
    base_lr = np.float32(0.0005)
    n_batches_per_epoch = 100
    n_test_batches = 10
    n_feedbacks_per_epoch = 10.     # progress printouts per epoch
    num_workers = 6
    workers_seeds = [123, 1234, 12345, 123456, 1234567, 12345678]
    # =================================================================================================================
    all_data = load_dataset()
    # sort keys so the KFold split is deterministic across runs
    keys_sorted = np.sort(all_data.keys())
    crossval_folds = KFold(len(all_data.keys()), n_folds=5, shuffle=True, random_state=123456)
    # iterate the folds until we reach the one this process is responsible for
    ctr = 0
    for train_idx, test_idx in crossval_folds:
        print len(train_idx), len(test_idx)
        if ctr == I_AM_FOLD:
            train_keys = [keys_sorted[i] for i in train_idx]
            test_keys = [keys_sorted[i] for i in test_idx]
            break
        ctr += 1
    train_data = {i:all_data[i] for i in train_keys}
    test_data = {i:all_data[i] for i in test_keys}
    # multithreaded, augmented training batch generator (aggressive augmentation at the start)
    data_gen_train = create_data_gen_train(train_data, INPUT_PATCH_SIZE, num_classes, BATCH_SIZE,
                                           contrast_range=(0.75, 1.5), gamma_range = (0.8, 1.5),
                                           num_workers=num_workers, num_cached_per_worker=2,
                                           do_elastic_transform=True, alpha=(0., 1300.), sigma=(10., 13.),
                                           do_rotation=True, a_x=(0., 2*np.pi), a_y=(0., 2*np.pi), a_z=(0., 2*np.pi),
                                           do_scale=True, scale_range=(0.75, 1.25), seeds=workers_seeds)
    # validation generator: random patch sampling, no spatial augmentation
    data_gen_validation = BatchGenerator3D_random_sampling(test_data, BATCH_SIZE, num_batches=None, seed=False,
                                                           patch_size=INPUT_PATCH_SIZE, convert_labels=True)
    val_transforms = []
    val_transforms.append(GenerateBrainMaskTransform())
    val_transforms.append(BrainMaskAwareStretchZeroOneTransform(clip_range=(-5, 5), per_channel=True))
    val_transforms.append(SegChannelSelectionTransform([0]))
    val_transforms.append(ConvertSegToOnehotTransform(range(4), 0, "seg_onehot"))
    val_transforms.append(DataChannelSelectionTransform([0, 1, 2, 3]))
    data_gen_validation = MultiThreadedAugmenter(data_gen_validation, Compose(val_transforms), 2, 2)
    # symbolic inputs: 5D image tensor (b, c, x, y, z) and flattened one-hot target matrix
    x_sym = T.tensor5()
    seg_sym = T.matrix()
    net, seg_layer = build_net(x_sym, INPUT_PATCH_SIZE, num_classes, 4, 16, batch_size=BATCH_SIZE,
                               do_instance_norm=True)
    output_layer_for_loss = net
    # add some weight decay
    l2_loss = lasagne.regularization.regularize_network_params(output_layer_for_loss, lasagne.regularization.l2) * 1e-5
    # the distinction between prediction_train and test is important only if we enable dropout (batch norm/inst norm
    # does not use or save moving averages)
    prediction_train = lasagne.layers.get_output(output_layer_for_loss, x_sym, deterministic=False,
                                                 batch_norm_update_averages=False, batch_norm_use_averages=False)
    # negative soft dice, background channel (index 0) excluded
    loss_vec = - soft_dice_per_img_in_batch(prediction_train, seg_sym, BATCH_SIZE)[:, 1:]
    loss = loss_vec.mean()
    loss += l2_loss
    acc_train = T.mean(T.eq(T.argmax(prediction_train, axis=1), seg_sym.argmax(-1)), dtype=theano.config.floatX)
    prediction_test = lasagne.layers.get_output(output_layer_for_loss, x_sym, deterministic=True,
                                                batch_norm_update_averages=False, batch_norm_use_averages=False)
    loss_val = - soft_dice_per_img_in_batch(prediction_test, seg_sym, BATCH_SIZE)[:, 1:]
    loss_val = loss_val.mean()
    loss_val += l2_loss
    acc = T.mean(T.eq(T.argmax(prediction_test, axis=1), seg_sym.argmax(-1)), dtype=theano.config.floatX)
    # learning rate has to be a shared variable because we decrease it with every epoch
    params = lasagne.layers.get_all_params(output_layer_for_loss, trainable=True)
    learning_rate = theano.shared(base_lr)
    updates = lasagne.updates.adam(T.grad(loss, params), params, learning_rate=learning_rate, beta1=0.9, beta2=0.999)
    dc = hard_dice_per_img_in_batch(prediction_test, seg_sym.argmax(1), num_classes, BATCH_SIZE).mean(0)
    # compiled functions: train_fn returns [loss, acc_train, loss_vec]; val_fn returns [loss_val, acc, dc]
    train_fn = theano.function([x_sym, seg_sym], [loss, acc_train, loss_vec], updates=updates)
    val_fn = theano.function([x_sym, seg_sym], [loss_val, acc, dc])
    # bookkeeping containers for plotting/serialization
    all_val_dice_scores=None
    all_training_losses = []
    all_validation_losses = []
    all_validation_accuracies = []
    all_training_accuracies = []
    val_dice_scores = []
    epoch = 0
    while epoch < n_epochs:
        # augmentation is made progressively milder at epochs 100 and 175
        if epoch == 100:
            data_gen_train = create_data_gen_train(train_data, INPUT_PATCH_SIZE, num_classes, BATCH_SIZE,
                                                   contrast_range=(0.85, 1.25), gamma_range = (0.8, 1.5),
                                                   num_workers=6, num_cached_per_worker=2,
                                                   do_elastic_transform=True, alpha=(0., 1000.), sigma=(10., 13.),
                                                   do_rotation=True, a_x=(0., 2*np.pi), a_y=(-np.pi/8., np.pi/8.),
                                                   a_z=(-np.pi/8., np.pi/8.), do_scale=True, scale_range=(0.85, 1.15),
                                                   seeds=workers_seeds)
        if epoch == 175:
            data_gen_train = create_data_gen_train(train_data, INPUT_PATCH_SIZE, num_classes, BATCH_SIZE,
                                                   contrast_range=(0.9, 1.1), gamma_range = (0.85, 1.3),
                                                   num_workers=6, num_cached_per_worker=2,
                                                   do_elastic_transform=True, alpha=(0., 750.), sigma=(10., 13.),
                                                   do_rotation=True, a_x=(0., 2*np.pi), a_y=(-0.00001, 0.00001),
                                                   a_z=(-0.00001, 0.00001), do_scale=True, scale_range=(0.85, 1.15),
                                                   seeds=workers_seeds)
        epoch_start_time = time.time()
        # exponential decay schedule: base_lr * lr_decay^epoch
        learning_rate.set_value(np.float32(base_lr* lr_decay**(epoch)))
        print "epoch: ", epoch, " learning rate: ", learning_rate.get_value()
        train_loss = 0
        train_acc_tmp = 0
        train_loss_tmp = 0
        batch_ctr = 0
        for data_dict in data_gen_train:
            data = data_dict["data"].astype(np.float32)
            # flatten one-hot seg to (voxels, num_classes) to match seg_sym (a matrix)
            seg = data_dict["seg_onehot"].astype(np.float32).transpose(0, 2, 3, 4, 1).reshape((-1, num_classes))
            # periodic progress report (n_feedbacks_per_epoch times per epoch)
            if batch_ctr != 0 and batch_ctr % int(np.floor(n_batches_per_epoch/n_feedbacks_per_epoch)) == 0:
                print "number of batches: ", batch_ctr, "/", n_batches_per_epoch
                print "training_loss since last update: ", \
                    train_loss_tmp/np.floor(n_batches_per_epoch/n_feedbacks_per_epoch), " train accuracy: ", \
                    train_acc_tmp/np.floor(n_batches_per_epoch/n_feedbacks_per_epoch)
                all_training_losses.append(train_loss_tmp/np.floor(n_batches_per_epoch/n_feedbacks_per_epoch))
                all_training_accuracies.append(train_acc_tmp/np.floor(n_batches_per_epoch/n_feedbacks_per_epoch))
                train_loss_tmp = 0
                train_acc_tmp = 0
                if len(val_dice_scores) > 0:
                    all_val_dice_scores = np.concatenate(val_dice_scores, axis=0).reshape((-1, num_classes))
                # plotting is best-effort only; NOTE(review): the bare except silently hides all plot errors
                try:
                    printLosses(all_training_losses, all_training_accuracies, all_validation_losses,
                                all_validation_accuracies, os.path.join(results_dir, "%s.png" % EXPERIMENT_NAME),
                                n_feedbacks_per_epoch, val_dice_scores=all_val_dice_scores,
                                val_dice_scores_labels=["brain", "1", "2", "3", "4", "5"])
                except:
                    pass
            # NOTE(review): train_fn returns [loss, acc_train, loss_vec]; the unpacking names here are
            # shuffled — 'loss_vec' receives the scalar loss and 'l' the per-sample vector. Numerically
            # harmless (.mean() of a scalar is itself) but misleading; verify before refactoring.
            loss_vec, acc, l = train_fn(data, seg)
            loss = loss_vec.mean()
            train_loss += loss
            train_loss_tmp += loss
            train_acc_tmp += acc
            batch_ctr += 1
            if batch_ctr >= n_batches_per_epoch:
                break
        all_training_losses.append(train_loss_tmp/np.floor(n_batches_per_epoch/n_feedbacks_per_epoch))
        all_training_accuracies.append(train_acc_tmp/np.floor(n_batches_per_epoch/n_feedbacks_per_epoch))
        train_loss /= n_batches_per_epoch
        print "training loss average on epoch: ", train_loss
        # ---- validation over n_test_batches random patches ----
        val_loss = 0
        accuracies = []
        valid_batch_ctr = 0
        all_dice = []
        for data_dict in data_gen_validation:
            data = data_dict["data"].astype(np.float32)
            seg = data_dict["seg_onehot"].astype(np.float32).transpose(0, 2, 3, 4, 1).reshape((-1, num_classes))
            # w marks which classes are present in this batch; dice for absent classes is set to the
            # sentinel value 2 and excluded from the per-class means below
            w = np.zeros(num_classes, dtype=np.float32)
            w[np.unique(seg.argmax(-1))] = 1
            loss, acc, dice = val_fn(data, seg)
            dice[w==0] = 2
            all_dice.append(dice)
            val_loss += loss
            accuracies.append(acc)
            valid_batch_ctr += 1
            if valid_batch_ctr >= n_test_batches:
                break
        all_dice = np.vstack(all_dice)
        dice_means = np.zeros(num_classes)
        for i in range(num_classes):
            # mean dice per class, ignoring batches where the class was absent (sentinel 2)
            dice_means[i] = all_dice[all_dice[:, i]!=2, i].mean()
        val_loss /= n_test_batches
        print "val loss: ", val_loss
        print "val acc: ", np.mean(accuracies), "\n"
        print "val dice: ", dice_means
        print "This epoch took %f sec" % (time.time()-epoch_start_time)
        val_dice_scores.append(dice_means)
        all_validation_losses.append(val_loss)
        all_validation_accuracies.append(np.mean(accuracies))
        all_val_dice_scores = np.concatenate(val_dice_scores, axis=0).reshape((-1, num_classes))
        try:
            printLosses(all_training_losses, all_training_accuracies, all_validation_losses,
                        all_validation_accuracies, os.path.join(results_dir, "%s.png" % EXPERIMENT_NAME),
                        n_feedbacks_per_epoch, val_dice_scores=all_val_dice_scores,
                        val_dice_scores_labels=["brain", "1", "2", "3", "4", "5"])
        except:
            pass
        # checkpoint params and history every epoch.
        # NOTE(review): 'w' (text mode) with cPickle is Python-2 specific; Python 3 would require 'wb'
        with open(os.path.join(results_dir, "%s_Params.pkl" % (EXPERIMENT_NAME)), 'w') as f:
            cPickle.dump(lasagne.layers.get_all_param_values(output_layer_for_loss), f)
        with open(os.path.join(results_dir, "%s_allLossesNAccur.pkl"% (EXPERIMENT_NAME)), 'w') as f:
            cPickle.dump([all_training_losses, all_training_accuracies, all_validation_losses,
                          all_validation_accuracies, val_dice_scores], f)
        epoch += 1
dataloader_train = BraTS2017DataLoader3D(train, batch_size, max_shape, 1) # during training I like to run a validation from time to time to see where I am standing. This is not a correct # validation because just like training this is patch-based but it's good enough. We don't do augmentation for the # validation, so patch_size is used as shape target here dataloader_validation = BraTS2017DataLoader3D(val, batch_size, patch_size, 1) tr_transforms = get_train_transform(patch_size) # finally we can create multithreaded transforms that we can actually use for training # we don't pin memory here because this is pytorch specific. tr_gen = MultiThreadedAugmenter( dataloader_train, tr_transforms, num_processes=num_threads_for_brats_example, num_cached_per_queue=3, seeds=None, pin_memory=False) # we need less processes for vlaidation because we dont apply transformations val_gen = MultiThreadedAugmenter(dataloader_validation, None, num_processes=max( 1, num_threads_for_brats_example // 2), num_cached_per_queue=1, seeds=None, pin_memory=False) # lets start the MultiThreadedAugmenter. This is not necessary but allows them to start generating training # batches while other things run in the main thread
shuffle=False, return_incomplete=True) dataloader_validation = BraTS2017DataLoader3D(val, batch_size, None, 1, infinite=False, shuffle=False, return_incomplete=True) tr_transforms = get_train_transform(patch_size) tr_gen = MultiThreadedAugmenter( dataloader_train, tr_transforms, num_processes=num_threads_for_brats_example, num_cached_per_queue=3, seeds=None, pin_memory=False) #tr_gen.restart() num_batches_per_epoch = 100 num_validation_batches_per_epoch = 20 num_epochs = 5 # let's run this to get a time on how long it takes time_per_epoch = [] start = time() for epoch in range(num_epochs): start_epoch = time()
val_dl = BRATSDataLoader(patients_val, batch_size=batch_size, patch_size=patch_size, in_channels=in_channels) #%% tr_transforms = get_train_transform(patch_size) #%% # finally we can create multithreaded transforms that we can actually use for training # we don't pin memory here because this is pytorch specific. tr_gen = MultiThreadedAugmenter( train_dl, tr_transforms, num_processes=4, # num_processes=4 num_cached_per_queue=3, seeds=None, pin_memory=False) # we need less processes for vlaidation because we dont apply transformations val_gen = MultiThreadedAugmenter( val_dl, None, num_processes=max(1, 4 // 2), # num_processes=max(1, 4 // 2) num_cached_per_queue=1, seeds=None, pin_memory=False) #%% tr_gen.restart() val_gen.restart()
def get_default_augmentation_withEDT(dataloader_train, dataloader_val, patch_size, idx_of_edts,
                                     params=default_3D_augmentation_params,
                                     border_val_seg=-1, pin_memory=True,
                                     seeds_train=None, seeds_val=None):
    """Build train/val augmentation pipelines where Euclidean distance transform (EDT)
    channels are split off into a separate 'bound' key so they are spatially transformed
    together with the data but never intensity-augmented.

    :param dataloader_train: batch loader for training data
    :param dataloader_val: batch loader for validation data
    :param patch_size: target spatial patch size for SpatialTransform
    :param idx_of_edts: channel indices of the EDT channels to move from 'data' to 'bound'
    :param params: augmentation parameter dict (queried via .get throughout)
    :param border_val_seg: fill value for seg outside the sampled patch
    :param pin_memory: forwarded to MultiThreadedAugmenter (pytorch-specific)
    :param seeds_train: per-worker seeds for the training augmenter
    :param seeds_val: per-worker seeds for the validation augmenter
    :return: (batchgenerator_train, batchgenerator_val) MultiThreadedAugmenter pair
    """
    tr_transforms = []

    if params.get("selected_data_channels") is not None:
        tr_transforms.append(
            DataChannelSelectionTransform(
                params.get("selected_data_channels")))

    if params.get("selected_seg_channels") is not None:
        tr_transforms.append(
            SegChannelSelectionTransform(params.get("selected_seg_channels")))

    # don't do color augmentations while in 2d mode with 3d data because the color channel is overloaded!!
    if params.get("dummy_2D") is not None and params.get("dummy_2D"):
        tr_transforms.append(Convert3DTo2DTransform())

    # spatial augmentation: elastic deformation, rotation and scaling, all parameterized via params
    tr_transforms.append(
        SpatialTransform(patch_size,
                         patch_center_dist_from_border=None,
                         do_elastic_deform=params.get("do_elastic"),
                         alpha=params.get("elastic_deform_alpha"),
                         sigma=params.get("elastic_deform_sigma"),
                         do_rotation=params.get("do_rotation"),
                         angle_x=params.get("rotation_x"),
                         angle_y=params.get("rotation_y"),
                         angle_z=params.get("rotation_z"),
                         do_scale=params.get("do_scaling"),
                         scale=params.get("scale_range"),
                         border_mode_data=params.get("border_mode_data"),
                         border_cval_data=0,
                         order_data=3,
                         border_mode_seg="constant",
                         border_cval_seg=border_val_seg,
                         order_seg=1,
                         random_crop=params.get("random_crop"),
                         p_el_per_sample=params.get("p_eldef"),
                         p_scale_per_sample=params.get("p_scale"),
                         p_rot_per_sample=params.get("p_rot")))

    if params.get("dummy_2D") is not None and params.get("dummy_2D"):
        tr_transforms.append(Convert2DTo3DTransform())

    """
    ##############################################################
    ##############################################################
    Here we insert moving the EDT to a different key so that it does not get intensity transformed
    ##############################################################
    ##############################################################
    """
    tr_transforms.append(
        AppendChannelsTransform("data", "bound", idx_of_edts,
                                remove_from_input=True))

    if params.get("do_gamma"):
        tr_transforms.append(
            GammaTransform(params.get("gamma_range"),
                           False,
                           True,
                           retain_stats=params.get("gamma_retain_stats"),
                           p_per_sample=params["p_gamma"]))

    tr_transforms.append(MirrorTransform(params.get("mirror_axes")))

    if params.get("mask_was_used_for_normalization") is not None:
        mask_was_used_for_normalization = params.get(
            "mask_was_used_for_normalization")
        tr_transforms.append(
            MaskTransform(mask_was_used_for_normalization,
                          mask_idx_in_seg=0,
                          set_outside_to=0))

    tr_transforms.append(RemoveLabelTransform(-1, 0))

    if params.get("move_last_seg_chanel_to_data") is not None and params.get(
            "move_last_seg_chanel_to_data"):
        tr_transforms.append(
            MoveSegAsOneHotToData(1, params.get("all_segmentation_labels"),
                                  'seg', 'data'))
        # FIX: was `params.get(...) and not None and params.get(...)` — `not None` is always
        # True, so it was a no-op; `is not None` is what the sibling checks in this file use.
        if params.get("advanced_pyramid_augmentations") is not None and params.get(
                "advanced_pyramid_augmentations"):
            tr_transforms.append(
                ApplyRandomBinaryOperatorTransform(channel_idx=list(
                    range(-len(params.get("all_segmentation_labels")), 0)),
                                                   p_per_sample=0.4,
                                                   key="data",
                                                   strel_size=(1, 8)))
            tr_transforms.append(
                RemoveRandomConnectedComponentFromOneHotEncodingTransform(
                    channel_idx=list(
                        range(-len(params.get("all_segmentation_labels")), 0)),
                    key="data",
                    p_per_sample=0.2,
                    fill_with_other_class_p=0.0,
                    dont_do_if_covers_more_than_X_percent=0.15))

    tr_transforms.append(RenameTransform('seg', 'target', True))
    tr_transforms.append(NumpyToTensor(['data', 'target', 'bound'], 'float'))
    tr_transforms = Compose(tr_transforms)

    batchgenerator_train = MultiThreadedAugmenter(
        dataloader_train, tr_transforms, params.get('num_threads'),
        params.get("num_cached_per_thread"),
        seeds=seeds_train,
        pin_memory=pin_memory)

    val_transforms = []
    val_transforms.append(RemoveLabelTransform(-1, 0))
    if params.get("selected_data_channels") is not None:
        val_transforms.append(
            DataChannelSelectionTransform(
                params.get("selected_data_channels")))
    if params.get("selected_seg_channels") is not None:
        val_transforms.append(
            SegChannelSelectionTransform(params.get("selected_seg_channels")))

    """
    ##############################################################
    ##############################################################
    Here we insert moving the EDT to a different key
    ##############################################################
    ##############################################################
    """
    val_transforms.append(
        AppendChannelsTransform("data", "bound", idx_of_edts,
                                remove_from_input=True))

    if params.get("move_last_seg_chanel_to_data") is not None and params.get(
            "move_last_seg_chanel_to_data"):
        val_transforms.append(
            MoveSegAsOneHotToData(1, params.get("all_segmentation_labels"),
                                  'seg', 'data'))

    val_transforms.append(RenameTransform('seg', 'target', True))
    val_transforms.append(NumpyToTensor(['data', 'target', 'bound'], 'float'))
    val_transforms = Compose(val_transforms)

    # validation uses fewer worker processes because no heavy augmentation runs there
    batchgenerator_val = MultiThreadedAugmenter(
        dataloader_val, val_transforms,
        max(params.get('num_threads') // 2, 1),
        params.get("num_cached_per_thread"),
        seeds=seeds_val,
        pin_memory=pin_memory)
    return batchgenerator_train, batchgenerator_val
def get_insaneDA_augmentation(dataloader_train, dataloader_val, patch_size,
                              params=default_3D_augmentation_params,
                              border_val_seg=-1,
                              seeds_train=None, seeds_val=None,
                              order_seg=1, order_data=3,
                              deep_supervision_scales=None,
                              soft_ds=False, classes=None,
                              pin_memory=True):
    """Build the 'insane' (heavy) data augmentation train/val pipelines.

    Compared to the default pipeline this adds Gaussian noise/blur, brightness,
    contrast, simulated low resolution and inverted-gamma transforms with fixed
    per-sample probabilities.

    :param dataloader_train: batch loader for training data
    :param dataloader_val: batch loader for validation data
    :param patch_size: target spatial patch size for SpatialTransform
    :param params: augmentation parameter dict (new-style keys; 'mirror' must be absent)
    :param border_val_seg: fill value for seg outside the sampled patch
    :param seeds_train: per-worker seeds for the training augmenter
    :param seeds_val: per-worker seeds for the validation augmenter
    :param order_seg: interpolation order for seg in SpatialTransform
    :param order_data: interpolation order for data in SpatialTransform
    :param deep_supervision_scales: if not None, downsample 'target' for deep supervision
    :param soft_ds: use soft (one-hot) downsampling; requires `classes`
    :param classes: class labels, required when soft_ds is True
    :param pin_memory: forwarded to MultiThreadedAugmenter (pytorch-specific)
    :return: (batchgenerator_train, batchgenerator_val) MultiThreadedAugmenter pair
    """
    assert params.get(
        'mirror') is None, "old version of params, use new keyword do_mirror"

    tr_transforms = []

    if params.get("selected_data_channels") is not None:
        tr_transforms.append(
            DataChannelSelectionTransform(
                params.get("selected_data_channels")))

    if params.get("selected_seg_channels") is not None:
        tr_transforms.append(
            SegChannelSelectionTransform(params.get("selected_seg_channels")))

    # don't do color augmentations while in 2d mode with 3d data because the color channel is overloaded!!
    if params.get("dummy_2D") is not None and params.get("dummy_2D"):
        ignore_axes = (0, )  # pseudo-z axis must be excluded from low-res simulation
        tr_transforms.append(Convert3DTo2DTransform())
    else:
        ignore_axes = None

    tr_transforms.append(
        SpatialTransform(patch_size,
                         patch_center_dist_from_border=None,
                         do_elastic_deform=params.get("do_elastic"),
                         alpha=params.get("elastic_deform_alpha"),
                         sigma=params.get("elastic_deform_sigma"),
                         do_rotation=params.get("do_rotation"),
                         angle_x=params.get("rotation_x"),
                         angle_y=params.get("rotation_y"),
                         angle_z=params.get("rotation_z"),
                         do_scale=params.get("do_scaling"),
                         scale=params.get("scale_range"),
                         border_mode_data=params.get("border_mode_data"),
                         border_cval_data=0,
                         order_data=order_data,
                         border_mode_seg="constant",
                         border_cval_seg=border_val_seg,
                         order_seg=order_seg,
                         random_crop=params.get("random_crop"),
                         p_el_per_sample=params.get("p_eldef"),
                         p_scale_per_sample=params.get("p_scale"),
                         p_rot_per_sample=params.get("p_rot"),
                         independent_scale_for_each_axis=params.get(
                             "independent_scale_factor_for_each_axis")))

    if params.get("dummy_2D"):
        tr_transforms.append(Convert2DTo3DTransform())

    # we need to put the color augmentations after the dummy 2d part (if applicable). Otherwise the overloaded color
    # channel gets in the way
    tr_transforms.append(GaussianNoiseTransform(p_per_sample=0.15))
    tr_transforms.append(
        GaussianBlurTransform((0.5, 1.5),
                              different_sigma_per_channel=True,
                              p_per_sample=0.2,
                              p_per_channel=0.5))
    tr_transforms.append(
        BrightnessMultiplicativeTransform(multiplier_range=(0.70, 1.3),
                                          p_per_sample=0.15))
    tr_transforms.append(
        ContrastAugmentationTransform(contrast_range=(0.65, 1.5),
                                      p_per_sample=0.15))
    tr_transforms.append(
        SimulateLowResolutionTransform(zoom_range=(0.5, 1),
                                       per_channel=True,
                                       p_per_channel=0.5,
                                       order_downsample=0,
                                       order_upsample=3,
                                       p_per_sample=0.25,
                                       ignore_axes=ignore_axes))
    # inverted gamma (invert_image=True), always applied with p=0.15
    tr_transforms.append(
        GammaTransform(params.get("gamma_range"),
                       True,
                       True,
                       retain_stats=params.get("gamma_retain_stats"),
                       p_per_sample=0.15))

    if params.get("do_gamma"):
        tr_transforms.append(
            GammaTransform(params.get("gamma_range"),
                           False,
                           True,
                           retain_stats=params.get("gamma_retain_stats"),
                           p_per_sample=params["p_gamma"]))

    if params.get("do_mirror") or params.get("mirror"):
        tr_transforms.append(MirrorTransform(params.get("mirror_axes")))

    if params.get("mask_was_used_for_normalization") is not None:
        mask_was_used_for_normalization = params.get(
            "mask_was_used_for_normalization")
        tr_transforms.append(
            MaskTransform(mask_was_used_for_normalization,
                          mask_idx_in_seg=0,
                          set_outside_to=0))

    tr_transforms.append(RemoveLabelTransform(-1, 0))

    if params.get("move_last_seg_chanel_to_data") is not None and params.get(
            "move_last_seg_chanel_to_data"):
        tr_transforms.append(
            MoveSegAsOneHotToData(1, params.get("all_segmentation_labels"),
                                  'seg', 'data'))
        # FIX: was `params.get(...) and not None and params.get(...)` — `not None` evaluates to
        # True, making the middle term a no-op; `is not None` matches the intent and the other
        # None-guards in this function.
        if params.get("cascade_do_cascade_augmentations") is not None and params.get(
                "cascade_do_cascade_augmentations"):
            if params.get("cascade_random_binary_transform_p") > 0:
                tr_transforms.append(
                    ApplyRandomBinaryOperatorTransform(
                        channel_idx=list(
                            range(-len(params.get("all_segmentation_labels")),
                                  0)),
                        p_per_sample=params.get(
                            "cascade_random_binary_transform_p"),
                        key="data",
                        strel_size=params.get(
                            "cascade_random_binary_transform_size")))
            if params.get("cascade_remove_conn_comp_p") > 0:
                # NOTE(review): the two kwargs below receive param keys whose names look swapped
                # (fill_with_other_class_p <- ...max_size_percent_threshold and vice versa). This
                # mirrors the upstream code, so it is kept as-is — confirm against the values set
                # in default_3D_augmentation_params before changing.
                tr_transforms.append(
                    RemoveRandomConnectedComponentFromOneHotEncodingTransform(
                        channel_idx=list(
                            range(-len(params.get("all_segmentation_labels")),
                                  0)),
                        key="data",
                        p_per_sample=params.get("cascade_remove_conn_comp_p"),
                        fill_with_other_class_p=params.get(
                            "cascade_remove_conn_comp_max_size_percent_threshold"
                        ),
                        dont_do_if_covers_more_than_X_percent=params.get(
                            "cascade_remove_conn_comp_fill_with_other_class_p")
                    ))

    tr_transforms.append(RenameTransform('seg', 'target', True))

    if deep_supervision_scales is not None:
        if soft_ds:
            assert classes is not None
            tr_transforms.append(
                DownsampleSegForDSTransform3(deep_supervision_scales, 'target',
                                             'target', classes))
        else:
            tr_transforms.append(
                DownsampleSegForDSTransform2(deep_supervision_scales, 0, 0,
                                             input_key='target',
                                             output_key='target'))

    tr_transforms.append(NumpyToTensor(['data', 'target'], 'float'))
    tr_transforms = Compose(tr_transforms)

    batchgenerator_train = MultiThreadedAugmenter(
        dataloader_train, tr_transforms, params.get('num_threads'),
        params.get("num_cached_per_thread"),
        seeds=seeds_train,
        pin_memory=pin_memory)

    val_transforms = []
    val_transforms.append(RemoveLabelTransform(-1, 0))
    if params.get("selected_data_channels") is not None:
        val_transforms.append(
            DataChannelSelectionTransform(
                params.get("selected_data_channels")))
    if params.get("selected_seg_channels") is not None:
        val_transforms.append(
            SegChannelSelectionTransform(params.get("selected_seg_channels")))

    if params.get("move_last_seg_chanel_to_data") is not None and params.get(
            "move_last_seg_chanel_to_data"):
        val_transforms.append(
            MoveSegAsOneHotToData(1, params.get("all_segmentation_labels"),
                                  'seg', 'data'))

    val_transforms.append(RenameTransform('seg', 'target', True))

    if deep_supervision_scales is not None:
        if soft_ds:
            assert classes is not None
            val_transforms.append(
                DownsampleSegForDSTransform3(deep_supervision_scales, 'target',
                                             'target', classes))
        else:
            val_transforms.append(
                DownsampleSegForDSTransform2(deep_supervision_scales, 0, 0,
                                             input_key='target',
                                             output_key='target'))

    val_transforms.append(NumpyToTensor(['data', 'target'], 'float'))
    val_transforms = Compose(val_transforms)

    # validation uses fewer worker processes because no heavy augmentation runs there
    batchgenerator_val = MultiThreadedAugmenter(
        dataloader_val, val_transforms,
        max(params.get('num_threads') // 2, 1),
        params.get("num_cached_per_thread"),
        seeds=seeds_val,
        pin_memory=pin_memory)
    return batchgenerator_train, batchgenerator_val