def create_data_gen_train(patient_data_train, BATCH_SIZE, num_classes, num_workers=5,
                          num_cached_per_worker=2, do_elastic_transform=False,
                          alpha=(0., 1300.), sigma=(10., 13.), do_rotation=False,
                          a_x=(0., 2*np.pi), a_y=(0., 2*np.pi), a_z=(0., 2*np.pi),
                          do_scale=True, scale_range=(0.75, 1.25), seeds=None,
                          patch_size=(352, 352)):
    """Build a multithreaded training generator for 2D patches.

    The pipeline mirrors along axes (2, 3), then with probability 0.67 applies a
    spatial transform (elastic/rotation/scale as enabled by the flags) and
    otherwise a plain random crop, and finally one-hot encodes the segmentation.

    :param patient_data_train: training data passed through to BatchGenerator_2D
    :param BATCH_SIZE: samples per batch
    :param num_classes: number of segmentation labels for one-hot encoding
    :param num_workers: number of background augmentation processes
    :param num_cached_per_worker: queued batches per worker
    :param do_elastic_transform/alpha/sigma: elastic deformation switch and ranges
    :param do_rotation/a_x/a_y/a_z: rotation switch and angle ranges (radians)
    :param do_scale/scale_range: scaling switch and range
    :param seeds: None (unseeded), 'range' (0..num_workers-1), or one seed per worker
    :param patch_size: 2D patch size; previously hard-coded to (352, 352) —
        the default preserves the old behavior
    :return: a started MultiThreadedAugmenter yielding augmented batches
    """
    if seeds is None:
        seeds = [None] * num_workers
    elif seeds == 'range':
        seeds = range(num_workers)
    else:
        assert len(seeds) == num_workers
    data_gen_train = BatchGenerator_2D(patient_data_train, BATCH_SIZE, num_batches=None,
                                       seed=False, PATCH_SIZE=patch_size)
    tr_transforms = []
    tr_transforms.append(Mirror((2, 3)))
    # 67%: spatial augmentation centered on the patch; 33%: plain random crop.
    tr_transforms.append(RndTransform(
        SpatialTransform(patch_size, list(np.array(patch_size) // 2),
                         do_elastic_transform, alpha, sigma,
                         do_rotation, a_x, a_y, a_z,
                         do_scale, scale_range,
                         'constant', 0, 3, 'constant', 0, 0, random_crop=False),
        prob=0.67,
        alternative_transform=RandomCropTransform(patch_size)))
    tr_transforms.append(ConvertSegToOnehotTransform(range(num_classes),
                                                     seg_channel=0,
                                                     output_key='seg_onehot'))
    tr_composed = Compose(tr_transforms)
    tr_mt_gen = MultiThreadedAugmenter(data_gen_train, tr_composed, num_workers,
                                       num_cached_per_worker, seeds)
    tr_mt_gen.restart()
    return tr_mt_gen
class MedImageDataSet(object):
    """Iterable dataset wrapper that feeds a MedImageDataLoader through a
    MultiThreadedAugmenter with mode-dependent transforms.

    Iterating over an instance yields augmented batches produced in background
    processes; ``len()`` reports the length of the underlying data loader.
    """

    def __init__(self, base_dir, mode="train", batch_size=16, num_batches=10000000,
                 seed=None, num_processes=8, num_cached_per_queue=8 * 4,
                 target_size=128, file_pattern='*.png', do_reshuffle=True, keys=None):
        # Underlying loader produces raw (un-augmented) batches from base_dir.
        data_loader = MedImageDataLoader(base_dir=base_dir, mode=mode,
                                         batch_size=batch_size,
                                         num_batches=num_batches, seed=seed,
                                         file_pattern=file_pattern, keys=keys)
        self.data_loader = data_loader
        self.batch_size = batch_size
        #self.do_reshuffle = do_reshuffle
        self.number_of_slices = 1
        # Transform chain depends on mode (e.g. augmentation only for training).
        self.transforms = get_transforms(mode=mode, target_size=target_size)
        # NOTE(review): `seeds=seed` passes a single value where
        # MultiThreadedAugmenter conventionally expects one seed per process,
        # and `shuffle=` is not a standard MultiThreadedAugmenter kwarg —
        # presumably this targets a customized augmenter; verify against its API.
        self.augmenter = MultiThreadedAugmenter(
            data_loader, self.transforms, num_processes=num_processes,
            num_cached_per_queue=num_cached_per_queue, seeds=seed,
            shuffle=do_reshuffle)
        # Start worker processes immediately so batches are prefetched.
        self.augmenter.restart()

    def __len__(self):
        # Delegates to the wrapped data loader.
        return len(self.data_loader)

    def __iter__(self):
        # renew() re-initializes the augmenter so a fresh epoch can be iterated.
        self.augmenter.renew()
        return self.augmenter

    def __next__(self):
        return next(self.augmenter)
def get_no_augmentation(dataloader_train, dataloader_val, patch_size,
                        params=default_3D_augmentation_params, border_val_seg=-1):
    """
    use this instead of get_default_augmentation (drop in replacement) to turn off all data augmentation
    :param dataloader_train:
    :param dataloader_val:
    :param patch_size:
    :param params:
    :param border_val_seg:
    :return:
    """
    # Hoist the repeated params lookups once.
    data_channels = params.get("selected_data_channels")
    seg_channels = params.get("selected_seg_channels")
    n_threads = params.get('num_threads')
    n_cached = params.get("num_cached_per_thread")

    # Training pipeline: optional channel selection, then label cleanup and
    # conversion to torch tensors (no actual augmentation).
    train_pipeline = []
    if data_channels is not None:
        train_pipeline.append(DataChannelSelectionTransform(data_channels))
    if seg_channels is not None:
        train_pipeline.append(SegChannelSelectionTransform(seg_channels))
    train_pipeline += [
        RemoveLabelTransform(-1, 0),
        RenameTransform('seg', 'target', True),
        NumpyToTensor(['data', 'target'], 'float'),
    ]
    batchgenerator_train = MultiThreadedAugmenter(
        dataloader_train, Compose(train_pipeline), n_threads, n_cached,
        seeds=range(n_threads), pin_memory=True)
    batchgenerator_train.restart()

    # Validation pipeline: same steps, but label removal comes first here.
    val_pipeline = [RemoveLabelTransform(-1, 0)]
    if data_channels is not None:
        val_pipeline.append(DataChannelSelectionTransform(data_channels))
    if seg_channels is not None:
        val_pipeline.append(SegChannelSelectionTransform(seg_channels))
    val_pipeline += [
        RenameTransform('seg', 'target', True),
        NumpyToTensor(['data', 'target'], 'float'),
    ]
    # Validation needs fewer workers since no augmentation is applied.
    n_val_threads = max(n_threads // 2, 1)
    batchgenerator_val = MultiThreadedAugmenter(
        dataloader_val, Compose(val_pipeline), n_val_threads, n_cached,
        seeds=range(n_val_threads), pin_memory=True)
    batchgenerator_val.restart()

    return batchgenerator_train, batchgenerator_val
def create_data_gen_train(patient_data_train, INPUT_PATCH_SIZE, num_classes, BATCH_SIZE,
                          contrast_range=(0.75, 1.5), gamma_range=(0.6, 2),
                          num_workers=5, num_cached_per_worker=3,
                          do_elastic_transform=False, alpha=(0., 1300.), sigma=(10., 13.),
                          do_rotation=False, a_x=(0., 2*np.pi), a_y=(0., 2*np.pi),
                          a_z=(0., 2*np.pi), do_scale=True, scale_range=(0.75, 1.25),
                          seeds=None):
    """Build a multithreaded training generator for 3D brain volumes.

    Pipeline: select the 4 input modalities, generate a brain mask, mirror,
    spatially augment (elastic/rotation/scale per the flags) with random crop
    to INPUT_PATCH_SIZE, intensity-augment (contrast, gamma, brightness) with
    brain-mask-aware [0, 1] rescaling, then one-hot encode the segmentation.

    :param seeds: None (unseeded), 'range' (0..num_workers-1), or one per worker
    :return: a started MultiThreadedAugmenter
    """
    if seeds is None:
        seeds = [None]*num_workers
    elif seeds == 'range':
        seeds = range(num_workers)
    else:
        assert len(seeds) == num_workers
    # Samples random (160, 192, 160) crops; the spatial transform below crops
    # further down to INPUT_PATCH_SIZE.
    data_gen_train = BatchGenerator3D_random_sampling(patient_data_train, BATCH_SIZE,
                                                      num_batches=None, seed=False,
                                                      patch_size=(160, 192, 160),
                                                      convert_labels=True)
    tr_transforms = []
    # Keep only the first four data channels (presumably the four MR modalities
    # of a BraTS-style dataset — verify against the loader).
    tr_transforms.append(DataChannelSelectionTransform([0, 1, 2, 3]))
    tr_transforms.append(GenerateBrainMaskTransform())
    tr_transforms.append(MirrorTransform())
    tr_transforms.append(SpatialTransform(INPUT_PATCH_SIZE,
                                          list(np.array(INPUT_PATCH_SIZE)//2.),
                                          do_elastic_deform=do_elastic_transform,
                                          alpha=alpha, sigma=sigma,
                                          do_rotation=do_rotation, angle_x=a_x,
                                          angle_y=a_y, angle_z=a_z,
                                          do_scale=do_scale, scale=scale_range,
                                          border_mode_data='nearest', border_cval_data=0,
                                          order_data=3, border_mode_seg='constant',
                                          border_cval_seg=0, order_seg=0,
                                          random_crop=True))
    # Clip-and-rescale intensities to [0, 1] within the brain mask before the
    # contrast/gamma augmentations, and again afterwards to restore the range.
    tr_transforms.append(BrainMaskAwareStretchZeroOneTransform((-5, 5), True))
    tr_transforms.append(ContrastAugmentationTransform(contrast_range, True))
    tr_transforms.append(GammaTransform(gamma_range, False))
    tr_transforms.append(BrainMaskAwareStretchZeroOneTransform(per_channel=True))
    tr_transforms.append(BrightnessTransform(0.0, 0.1, True))
    tr_transforms.append(SegChannelSelectionTransform([0]))
    tr_transforms.append(ConvertSegToOnehotTransform(range(num_classes), 0, "seg_onehot"))
    gen_train = MultiThreadedAugmenter(data_gen_train, Compose(tr_transforms),
                                       num_workers, num_cached_per_worker, seeds)
    gen_train.restart()
    return gen_train
def main():
    """KiTS19 V-Net training entry point.

    Parses CLI args + YAML config, builds the model and (multi-threaded)
    data/augmentation generators, then runs the train/validate loop, saving a
    checkpoint whenever validation improves or on the periodic save interval.
    """
    # --------- Parse arguments ----------------------------------------------
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', type=str, default='./config_0.yaml',
                        help='Path to the configuration file.')
    parser.add_argument('--dice', action='store_true')  # be aware of this argument!!!
    parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
                        help='manual epoch number (useful on restarts)')
    parser.add_argument('--resume', default='', type=str, metavar='PATH',
                        help='path to latest checkpoint (default: none)')
    parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
                        help='evaluate model on validation set')
    parser.add_argument('-i', '--inference', default='', type=str, metavar='PATH',
                        help='run inference on data set and save results')
    parser.add_argument('--no-cuda', action='store_true')
    parser.add_argument('--save')
    parser.add_argument('--seed', type=int, default=1)
    args = parser.parse_args()

    # ---------- get the config file (config.yaml) ---------------------------
    config = get_config(args.config)
    args.cuda = not args.no_cuda and torch.cuda.is_available()
    args.save = os.path.join('./work', (datestr() + ' ' + config['filename']))
    # nll=True selects the NLL-based training path; config['dice'] switches to
    # the dice-loss path instead.
    nll = not config['dice']
    weight_decay = config['weight_decay']
    num_threads_for_kits19 = config['num_of_threads']
    patch_size = (160, 160, 128)
    num_batch_per_epoch = config['num_batch_per_epoch']
    setproctitle.setproctitle(args.save)
    start_epoch = 1

    # -------- Record best kidney segmentation dice ---------------------------
    best_tk = 0.0
    torch.manual_seed(args.seed)
    if args.cuda:
        torch.cuda.manual_seed(args.seed)

    print("build vnet")
    # Embed attention module
    model = vnet.VNet(elu=False, nll=nll, attention=config['attention'], nclass=3)
    # BUGFIX: the original re-assigned batch_size from `args.ngpu * args.batchSz`,
    # but argparse defines neither flag (AttributeError) and it clobbered the
    # correct config-based value. Use the config exclusively.
    batch_size = config['ngpu'] * config['batchSz']
    save_iter = config['model_save_iter']
    gpu_ids = range(config['ngpu'])
    model.apply(weights_init)

    # ------- Resume training from saved model --------------------------------
    if config['resume']:
        if os.path.isfile(config['resume']):
            print("=> loading checkpoint '{}'".format(config['resume']))
            checkpoint = torch.load(config['resume'])
            if config['resume'].endswith('.tar'):
                # .tar files carry a training-state dict.
                start_epoch = checkpoint['epoch']
                best_tk = checkpoint['best_tk']
                checkpoint_model = checkpoint['model_state_dict']
                # Strip DataParallel's 'module.' prefix so the bare model loads.
                model.load_state_dict(
                    {k.replace('module.', ''): v for k, v in checkpoint_model.items()})
                print("=> loaded checkpoint (epoch {})".format(checkpoint['epoch']))
            else:
                # .pkl files store the whole model object.
                # BUGFIX: the original printed checkpoint['epoch'] here too,
                # which fails when checkpoint is a model instance.
                model.load_state_dict(checkpoint.state_dict())
                print("=> loaded checkpoint")
        else:
            print("=> no checkpoint found at '{}'".format(config['resume']))
            exit(-1)

    # ------- Which loss function to use --------------------------------------
    if nll:
        training = train_bg
        validate = test_bg
    else:
        training = train_bg_dice
        validate = test_bg_dice

    print(' + Number of params: {}'.format(
        sum(p.data.nelement() for p in model.parameters())))

    # -------- Set on GPU -----------------------------------------------------
    if args.cuda:
        model = model.cuda()

    if os.path.exists(args.save):
        shutil.rmtree(args.save)
    # create the output directory and snapshot the config used for this run
    os.makedirs(args.save)
    shutil.copy(args.config, os.path.join(args.save, 'config.yaml'))

    # ------ Load Training and Validation set ---------------------------------
    preprocessed_folders = "/home/data_share/npy_data/"
    patients = get_list_of_patients(preprocessed_data_folder=preprocessed_folders)
    # Deterministic split (cross-validation helper kept for reference):
    # train, val = get_split_deterministic(patients, fold=0, num_splits=5, random_state=12345)
    train, val = patients[0:147], patients[147:189]
    # VALIDATION DATA CANNOT BE LOADED IN CASE DUE TO THE LARGE SHAPE...
    # PRINT VALIDATION CASES FOR LATER TEST USE!!
    print("Validation cases:\n", val)
    # Max shape over the validation cases (informational / kept from original;
    # the validation loader below still uses patch_size).
    shapes = [Kits2019DataLoader3D.load_patient(i)[0].shape[1:] for i in val]
    max_shape = np.max(shapes, 0)
    max_shape = np.max((max_shape, patch_size), 0)

    # data loading + augmentation
    dataloader_train = Kits2019DataLoader3D(
        train, batch_size, patch_size, num_threads_for_kits19)
    dataloader_validation = Kits2019DataLoader3D(
        val, batch_size * 2, patch_size, num_threads_for_kits19)
    tr_transforms = get_train_transform(patch_size, prob=config['prob'])

    # whether to use single/multiThreadedAugmenter ----------------------------
    if num_threads_for_kits19 > 1:
        tr_gen = MultiThreadedAugmenter(dataloader_train, tr_transforms,
                                        num_processes=num_threads_for_kits19,
                                        num_cached_per_queue=3, seeds=None,
                                        pin_memory=True)
        val_gen = MultiThreadedAugmenter(dataloader_validation, None,
                                         num_processes=max(1, num_threads_for_kits19 // 2),
                                         num_cached_per_queue=1, seeds=None,
                                         pin_memory=False)
        tr_gen.restart()
        val_gen.restart()
    else:
        tr_gen = SingleThreadedAugmenter(dataloader_train, transform=tr_transforms)
        val_gen = SingleThreadedAugmenter(dataloader_validation, transform=None)

    # ------- Set learning rate scheduler -------------------------------------
    lr_schdl = lr_scheduler.LR_Scheduler(mode=config['lr_policy'],
                                         base_lr=config['lr'],
                                         num_epochs=config['nEpochs'],
                                         iters_per_epoch=num_batch_per_epoch,
                                         lr_step=config['step_size'],
                                         warmup_epochs=config['warmup_epochs'])

    # ------ Choose Optimizer -------------------------------------------------
    if config['opt'] == 'sgd':
        optimizer = optim.SGD(model.parameters(), lr=config['lr'],
                              momentum=0.99, weight_decay=weight_decay)
    elif config['opt'] == 'adam':
        optimizer = optim.Adam(model.parameters(), lr=config['lr'],
                               weight_decay=weight_decay)
    elif config['opt'] == 'rmsprop':
        optimizer = optim.RMSprop(model.parameters(), lr=config['lr'],
                                  weight_decay=weight_decay)
    lr_plateu = optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.5,
                                                     verbose=True, threshold=1e-3,
                                                     patience=5)

    # ------- Apex Mixed Precision Acceleration -------------------------------
    model, optimizer = amp.initialize(model, optimizer, opt_level="O1")
    model = nn.parallel.DataParallel(model, device_ids=gpu_ids)

    # ------- Save training data ----------------------------------------------
    trainF = open(os.path.join(args.save, 'train.csv'), 'w')
    trainF.write('Epoch,Loss,Kidney_Dice,Tumor_Dice\n')
    testF = open(os.path.join(args.save, 'test.csv'), 'w')
    # BUGFIX: header previously ended with '\n ' (stray trailing space).
    testF.write('Epoch,Loss,Kidney_Dice,Tumor_Dice\n')

    # ------- Training Pipeline -----------------------------------------------
    try:
        for epoch in range(start_epoch, config['nEpochs'] + start_epoch):
            torch.cuda.empty_cache()
            training(args, epoch, model, tr_gen, optimizer, trainF, config, lr_schdl)
            torch.cuda.empty_cache()
            print('==>lr decay to:', optimizer.param_groups[0]['lr'])
            print('testing validation set...')
            composite_dice = validate(args, epoch, model, val_gen, optimizer,
                                      testF, config, lr_plateu)
            torch.cuda.empty_cache()
            # BUGFIX: best_tk is now only raised on improvement (the original
            # overwrote it on every periodic save, even with a worse dice), and
            # the checkpoint records the updated value rather than the stale one.
            is_best = composite_dice > best_tk
            if is_best:
                best_tk = composite_dice
            if is_best or epoch % save_iter == 0:
                model_name = 'vnet_step1_' + str(epoch) + '.tar'
                torch.save({
                    'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'best_tk': best_tk
                }, os.path.join(args.save, model_name))
    finally:
        # Always close the CSV logs, even if training aborts.
        trainF.close()
        testF.close()
def get_no_augmentation(dataloader_train, dataloader_val,
                        params=default_3D_augmentation_params,
                        deep_supervision_scales=None, soft_ds=False,
                        classes=None, pin_memory=True, regions=None):
    """
    use this instead of get_default_augmentation (drop in replacement) to turn off all data augmentation

    Builds train/val MultiThreadedAugmenters whose pipelines only do
    housekeeping (channel selection, label cleanup, optional region conversion
    and deep-supervision downsampling, tensor conversion) — no augmentation.

    :param params: dict of pipeline settings (channel selections, thread counts)
    :param deep_supervision_scales: if given, add per-scale downsampled targets
    :param soft_ds: use soft (one-hot, DSTransform3) downsampling; needs `classes`
    :param regions: if given, convert the segmentation into region masks
    :return: (batchgenerator_train, batchgenerator_val), both already restarted
    """
    tr_transforms = []
    if params.get("selected_data_channels") is not None:
        tr_transforms.append(
            DataChannelSelectionTransform(
                params.get("selected_data_channels")))
    if params.get("selected_seg_channels") is not None:
        tr_transforms.append(
            SegChannelSelectionTransform(params.get("selected_seg_channels")))
    # Map the 'ignore' label -1 to 0.
    tr_transforms.append(RemoveLabelTransform(-1, 0))
    tr_transforms.append(RenameTransform('seg', 'target', True))
    if regions is not None:
        tr_transforms.append(
            ConvertSegmentationToRegionsTransform(regions, 'target', 'target'))
    if deep_supervision_scales is not None:
        if soft_ds:
            assert classes is not None
            tr_transforms.append(
                DownsampleSegForDSTransform3(deep_supervision_scales,
                                             'target', 'target', classes))
        else:
            tr_transforms.append(
                DownsampleSegForDSTransform2(deep_supervision_scales, 0, 0,
                                             input_key='target',
                                             output_key='target'))
    tr_transforms.append(NumpyToTensor(['data', 'target'], 'float'))
    tr_transforms = Compose(tr_transforms)
    batchgenerator_train = MultiThreadedAugmenter(
        dataloader_train, tr_transforms, params.get('num_threads'),
        params.get("num_cached_per_thread"),
        seeds=range(params.get('num_threads')), pin_memory=pin_memory)
    batchgenerator_train.restart()

    # Validation pipeline: same steps; note label removal precedes channel
    # selection here (order kept from the original).
    val_transforms = []
    val_transforms.append(RemoveLabelTransform(-1, 0))
    if params.get("selected_data_channels") is not None:
        val_transforms.append(
            DataChannelSelectionTransform(
                params.get("selected_data_channels")))
    if params.get("selected_seg_channels") is not None:
        val_transforms.append(
            SegChannelSelectionTransform(params.get("selected_seg_channels")))
    val_transforms.append(RenameTransform('seg', 'target', True))
    if regions is not None:
        val_transforms.append(
            ConvertSegmentationToRegionsTransform(regions, 'target', 'target'))
    if deep_supervision_scales is not None:
        if soft_ds:
            assert classes is not None
            val_transforms.append(
                DownsampleSegForDSTransform3(deep_supervision_scales,
                                             'target', 'target', classes))
        else:
            val_transforms.append(
                DownsampleSegForDSTransform2(deep_supervision_scales, 0, 0,
                                             input_key='target',
                                             output_key='target'))
    val_transforms.append(NumpyToTensor(['data', 'target'], 'float'))
    val_transforms = Compose(val_transforms)
    # Validation uses half the worker count (at least one) — no augmentation work.
    batchgenerator_val = MultiThreadedAugmenter(
        dataloader_val, val_transforms,
        max(params.get('num_threads') // 2, 1),
        params.get("num_cached_per_thread"),
        seeds=range(max(params.get('num_threads') // 2, 1)),
        pin_memory=pin_memory)
    batchgenerator_val.restart()
    return batchgenerator_train, batchgenerator_val
num_cached_per_queue=3, seeds=None, pin_memory=False) # we need less processes for vlaidation because we dont apply transformations val_gen = MultiThreadedAugmenter(dataloader_validation, None, num_processes=max( 1, num_threads_for_brats_example // 2), num_cached_per_queue=1, seeds=None, pin_memory=False) # lets start the MultiThreadedAugmenter. This is not necessary but allows them to start generating training # batches while other things run in the main thread tr_gen.restart() val_gen.restart() # now if this was a network training you would run epochs like this (remember tr_gen and val_gen generate # inifinite examples! Don't do "for batch in tr_gen:"!!!): num_batches_per_epoch = 10 num_validation_batches_per_epoch = 3 num_epochs = 5 # let's run this to get a time on how long it takes time_per_epoch = [] start = time() for epoch in range(num_epochs): start_epoch = time() for b in range(num_batches_per_epoch): batch = next(tr_gen) # do network training here with this batch
val=get_file_list(brats_preprocessed_folder,valid_ids_path) shapes = [brats_dataloader.load_patient(i)[0].shape[1:] for i in train] max_shape = np.max(shapes, 0) max_shape = list(np.max((max_shape, patch_size), 0)) dataloader_train = brats_dataloader(train, batch_size, max_shape, num_threads,return_incomplete=True) dataloader_validation = brats_dataloader(val,batch_size, None,1,infinite=False,shuffle=False,return_incomplete=True) tr_transforms = get_train_transform(patch_size) tr_gen = MultiThreadedAugmenter(dataloader_train, tr_transforms, num_processes=num_threads, num_cached_per_queue=3, seeds=None, pin_memory=False) tr_gen.restart() log=open(log_path,'w') log.write('epoch,loss,valid loss\n') min_loss=1000 num_batches_per_epoch=int(math.ceil(len(train)/batch_size)) num_validation_batches_per_epoch=int(math.ceil(len(val)/batch_size)) current_lr=lr for epoch in range(max_epoch): raw_loss=0 with trange(num_batches_per_epoch) as t:
def get_arteries_augmentation(dataloader_train, dataloader_val, patch_size,
                              params=default_3D_augmentation_params,
                              border_val_seg=-1, seeds_train=None, seeds_val=None,
                              order_seg=1, order_data=3,
                              deep_supervision_scales=None, soft_ds=False,
                              classes=None, pin_memory=True, regions=None,
                              use_nondetMultiThreadedAugmenter: bool = False):
    """Build train/val augmentation pipelines for the arteries task.

    A reduced nnU-Net-style augmentation: spatial transform with scaling only
    (elastic deformation and rotation disabled), optional mirroring, label
    cleanup and optional region conversion. Cascade and deep-supervision
    steps are deliberately commented out.

    :param seeds_train: per-process seeds for the training augmenter (or None)
    :param seeds_val: per-process seeds for the validation augmenter (or None)
    :param order_seg: interpolation order for segmentations
    :param order_data: interpolation order for image data
    :param use_nondetMultiThreadedAugmenter: accepted for API compatibility;
        not used in this implementation
    :return: (batchgenerator_train, batchgenerator_val), both already restarted
    """
    assert params.get(
        'mirror') is None, "old version of params, use new keyword do_mirror"
    tr_transforms = []
    if params.get("selected_data_channels") is not None:
        tr_transforms.append(
            DataChannelSelectionTransform(
                params.get("selected_data_channels")))
    if params.get("selected_seg_channels") is not None:
        tr_transforms.append(
            SegChannelSelectionTransform(params.get("selected_seg_channels")))
    # don't do color augmentations while in 2d mode with 3d data because the color channel is overloaded!!
    # if params.get("dummy_2D") is not None and params.get("dummy_2D"):
    #     ignore_axes = (0,)
    #     tr_transforms.append(Convert3DTo2DTransform())
    # else:
    #     ignore_axes = None
    # Spatial transform: scaling only (elastic + rotation hard-disabled here).
    tr_transforms.append(
        SpatialTransform(patch_size, patch_center_dist_from_border=None,
                         do_elastic_deform=False, do_rotation=False,
                         do_scale=params.get("do_scaling"),
                         scale=params.get("scale_range"),
                         border_mode_data=params.get("border_mode_data"),
                         border_cval_data=0, order_data=order_data,
                         border_mode_seg="constant",
                         border_cval_seg=border_val_seg, order_seg=order_seg,
                         random_crop=False,
                         p_el_per_sample=params.get("p_eldef"),
                         p_scale_per_sample=params.get("p_scale"),
                         p_rot_per_sample=params.get("p_rot"),
                         independent_scale_for_each_axis=params.get(
                             "independent_scale_factor_for_each_axis")))
    if params.get("dummy_2D"):
        tr_transforms.append(Convert2DTo3DTransform())
    if params.get("do_mirror") or params.get("mirror"):
        tr_transforms.append(MirrorTransform(params.get("mirror_axes")))
    # Map the 'ignore' label -1 to 0.
    tr_transforms.append(RemoveLabelTransform(-1, 0))
    # if params.get("move_last_seg_chanel_to_data") is not None and params.get("move_last_seg_chanel_to_data"):
    #     tr_transforms.append(MoveSegAsOneHotToData(1, params.get("all_segmentation_labels"), 'seg', 'data'))
    # if params.get("cascade_do_cascade_augmentations") is not None and params.get(
    #         "cascade_do_cascade_augmentations"):
    #     if params.get("cascade_random_binary_transform_p") > 0:
    #         tr_transforms.append(ApplyRandomBinaryOperatorTransform(
    #             channel_idx=list(range(-len(params.get("all_segmentation_labels")), 0)),
    #             p_per_sample=params.get("cascade_random_binary_transform_p"),
    #             key="data",
    #             strel_size=params.get("cascade_random_binary_transform_size"),
    #             p_per_label=params.get("cascade_random_binary_transform_p_per_label")))
    #     if params.get("cascade_remove_conn_comp_p") > 0:
    #         tr_transforms.append(
    #             RemoveRandomConnectedComponentFromOneHotEncodingTransform(
    #                 channel_idx=list(range(-len(params.get("all_segmentation_labels")), 0)),
    #                 key="data",
    #                 p_per_sample=params.get("cascade_remove_conn_comp_p"),
    #                 fill_with_other_class_p=params.get("cascade_remove_conn_comp_max_size_percent_threshold"),
    #                 dont_do_if_covers_more_than_X_percent=params.get(
    #                     "cascade_remove_conn_comp_fill_with_other_class_p")))
    tr_transforms.append(RenameTransform('seg', 'target', True))
    if regions is not None:
        tr_transforms.append(
            ConvertSegmentationToRegionsTransform(regions, 'target', 'target'))
    # if deep_supervision_scales is not None:
    #     if soft_ds:
    #         assert classes is not None
    #         tr_transforms.append(DownsampleSegForDSTransform3(deep_supervision_scales, 'target', 'target', classes))
    #     else:
    #         tr_transforms.append(DownsampleSegForDSTransform2(deep_supervision_scales, 0, 0, input_key='target',
    #                                                           output_key='target'))
    tr_transforms.append(NumpyToTensor(['data', 'target'], 'float'))
    tr_transforms = Compose(tr_transforms)
    batchgenerator_train = MultiThreadedAugmenter(
        dataloader_train, tr_transforms, params.get('num_threads'),
        params.get("num_cached_per_thread"), seeds=seeds_train,
        pin_memory=pin_memory)
    # batchgenerator_train = SingleThreadedAugmenter(dataloader_train, tr_transforms)
    # import IPython;IPython.embed()
    batchgenerator_train.restart()

    # Validation: housekeeping only, no augmentation.
    val_transforms = []
    val_transforms.append(RemoveLabelTransform(-1, 0))
    if params.get("selected_data_channels") is not None:
        val_transforms.append(
            DataChannelSelectionTransform(
                params.get("selected_data_channels")))
    if params.get("selected_seg_channels") is not None:
        val_transforms.append(
            SegChannelSelectionTransform(params.get("selected_seg_channels")))
    # if params.get("move_last_seg_chanel_to_data") is not None and params.get("move_last_seg_chanel_to_data"):
    #     val_transforms.append(MoveSegAsOneHotToData(1, params.get("all_segmentation_labels"), 'seg', 'data'))
    val_transforms.append(RenameTransform('seg', 'target', True))
    if regions is not None:
        val_transforms.append(
            ConvertSegmentationToRegionsTransform(regions, 'target', 'target'))
    # if deep_supervision_scales is not None:
    #     if soft_ds:
    #         assert classes is not None
    #         val_transforms.append(DownsampleSegForDSTransform3(deep_supervision_scales, 'target', 'target', classes))
    #     else:
    #         val_transforms.append(DownsampleSegForDSTransform2(deep_supervision_scales, 0, 0, input_key='target',
    #                                                            output_key='target'))
    val_transforms.append(NumpyToTensor(['data', 'target'], 'float'))
    val_transforms = Compose(val_transforms)
    # Half the workers for validation (at least one).
    batchgenerator_val = MultiThreadedAugmenter(
        dataloader_val, val_transforms,
        max(params.get('num_threads') // 2, 1),
        params.get("num_cached_per_thread"), seeds=seeds_val,
        pin_memory=pin_memory)
    # batchgenerator_val = SingleThreadedAugmenter(dataloader_val, val_transforms)
    batchgenerator_val.restart()
    return batchgenerator_train, batchgenerator_val
def create_data_gen_train(patient_data_train, BATCH_SIZE, num_classes, patch_size,
                          num_workers=5, num_cached_per_worker=2,
                          do_elastic_transform=False, alpha=(0., 1300.),
                          sigma=(10., 13.), do_rotation=False,
                          a_x=(0., 2 * np.pi), a_y=(0., 2 * np.pi),
                          a_z=(0., 2 * np.pi), do_scale=True,
                          scale_range=(0.75, 1.25), seeds=None):
    """Assemble the multithreaded training generator for 2D-in-3D cine data.

    Augmentation: motion artifacts, mirroring, per-slice (2D) spatial
    augmentation via the 3D<->2D conversion pair, gamma augmentation,
    outlier clipping, normalization, and one-hot segmentation encoding.

    ``seeds`` may be None (unseeded workers), the string 'range'
    (seeds 0..num_workers-1), or a sequence with one seed per worker.
    """
    # Resolve per-worker seeds.
    if seeds is None:
        worker_seeds = [None] * num_workers
    elif seeds == 'range':
        worker_seeds = range(num_workers)
    else:
        assert len(seeds) == num_workers
        worker_seeds = seeds

    base_generator = BatchGenerator(patient_data_train, BATCH_SIZE,
                                    num_batches=None, seed=False,
                                    PATCH_SIZE=(10, 352, 352))

    slice_shape = patch_size[1:]
    # 2D spatial augmentation applied with probability 0.67; otherwise a plain
    # random crop to the slice shape.
    spatial = SpatialTransform(slice_shape, 112, do_elastic_transform, alpha,
                               sigma, do_rotation, a_x, a_y, a_z, do_scale,
                               scale_range, 'constant', 0, 3, 'constant', 0, 0,
                               random_crop=False)
    maybe_spatial = RndTransform(
        spatial, prob=0.67,
        alternative_transform=RandomCropTransform(slice_shape))

    pipeline = [
        MotionAugmentationTransform(0.1, 0, 20),
        MirrorTransform((3, 4)),
        # Treat the stack of slices as independent 2D images for the spatial
        # augmentation, then restore the 3D layout.
        Convert3DTo2DTransform(),
        maybe_spatial,
        Convert2DTo3DTransform(patch_size),
        # Two independent 50% gamma augmentations (non-inverted and inverted).
        RndTransform(GammaTransform((0.85, 1.3), False), prob=0.5),
        RndTransform(GammaTransform((0.85, 1.3), True), prob=0.5),
        CutOffOutliersTransform(0.3, 99.7, True),
        ZeroMeanUnitVarianceTransform(True),
        ConvertSegToOnehotTransform(range(num_classes), 0, 'seg_onehot'),
    ]

    generator = MultiThreadedAugmenter(base_generator, Compose(pipeline),
                                       num_workers, num_cached_per_worker,
                                       worker_seeds)
    generator.restart()
    return generator
def get_default_augmentation(dataloader_train, dataloader_val=None, params=None,
                             patch_size=None, border_val_seg=-1, pin_memory=True,
                             seeds_train=None, seeds_val=None, regions=None):
    """Build the default nnU-Net-style train/val augmentation generators.

    Training: spatial transform (elastic/rotation/scale per ``params``),
    optional gamma and mirroring, label cleanup, optional region conversion,
    tensor conversion. Validation: housekeeping only.

    :param dataloader_val: if None, only the training generator is built and
        ``(batchgenerator_train, None)`` is returned
    :param params: augmentation parameter dict; must not be None
    :raises AssertionError: if ``params`` is None or uses the obsolete
        'mirror' key instead of 'do_mirror'
    :return: (batchgenerator_train, batchgenerator_val_or_None), restarted
    """
    # BUGFIX: the None-check must run before params is dereferenced; the
    # original called params.get('mirror') first, so params=None raised
    # AttributeError instead of the intended assertion message.
    assert params is not None, "augmentation params expect to be not None"
    assert params.get('mirror') is None, "old version of params, use new keyword do_mirror"
    tr_transforms = []
    if params.get("selected_data_channels") is not None:
        tr_transforms.append(DataChannelSelectionTransform(params.get("selected_data_channels")))
    if params.get("selected_seg_channels") is not None:
        tr_transforms.append(SegChannelSelectionTransform(params.get("selected_seg_channels")))
    # don't do color augmentations while in 2d mode with 3d data because the color channel is overloaded!!
    if params.get("dummy_2D") is not None and params.get("dummy_2D"):
        tr_transforms.append(Convert3DTo2DTransform())
    tr_transforms.append(SpatialTransform(
        patch_size, patch_center_dist_from_border=None,
        do_elastic_deform=params.get("do_elastic"),
        alpha=params.get("elastic_deform_alpha"),
        sigma=params.get("elastic_deform_sigma"),
        do_rotation=params.get("do_rotation"),
        angle_x=params.get("rotation_x"), angle_y=params.get("rotation_y"),
        angle_z=params.get("rotation_z"),
        do_scale=params.get("do_scaling"), scale=params.get("scale_range"),
        border_mode_data=params.get("border_mode_data"), border_cval_data=0,
        order_data=3, border_mode_seg="constant",
        border_cval_seg=border_val_seg, order_seg=1,
        random_crop=params.get("random_crop"),
        p_el_per_sample=params.get("p_eldef"),
        p_scale_per_sample=params.get("p_scale"),
        p_rot_per_sample=params.get("p_rot"),
        independent_scale_for_each_axis=params.get("independent_scale_factor_for_each_axis")
    ))
    if params.get("dummy_2D") is not None and params.get("dummy_2D"):
        tr_transforms.append(Convert2DTo3DTransform())
    if params.get("do_gamma"):
        tr_transforms.append(
            GammaTransform(params.get("gamma_range"), False, True,
                           retain_stats=params.get("gamma_retain_stats"),
                           p_per_sample=params["p_gamma"]))
    if params.get("do_mirror"):
        tr_transforms.append(MirrorTransform(params.get("mirror_axes")))
    # Map the 'ignore' label -1 to 0.
    tr_transforms.append(RemoveLabelTransform(-1, 0))
    tr_transforms.append(RenameTransform('seg', 'target', True))
    if regions is not None:
        tr_transforms.append(ConvertSegmentationToRegionsTransform(regions, 'target', 'target'))
    tr_transforms.append(NumpyToTensor(['data', 'target'], 'float'))
    tr_transforms = Compose(tr_transforms)
    batchgenerator_train = MultiThreadedAugmenter(dataloader_train, tr_transforms,
                                                  params.get('num_threads'),
                                                  params.get("num_cached_per_thread"),
                                                  seeds=seeds_train,
                                                  pin_memory=pin_memory)
    batchgenerator_train.restart()

    # No validation loader: train-only mode.
    if dataloader_val is None:
        return batchgenerator_train, None

    # Validation: housekeeping only (no augmentation).
    val_transforms = []
    val_transforms.append(RemoveLabelTransform(-1, 0))
    if params.get("selected_data_channels") is not None:
        val_transforms.append(DataChannelSelectionTransform(params.get("selected_data_channels")))
    if params.get("selected_seg_channels") is not None:
        val_transforms.append(SegChannelSelectionTransform(params.get("selected_seg_channels")))
    val_transforms.append(RenameTransform('seg', 'target', True))
    if regions is not None:
        val_transforms.append(ConvertSegmentationToRegionsTransform(regions, 'target', 'target'))
    val_transforms.append(NumpyToTensor(['data', 'target'], 'float'))
    val_transforms = Compose(val_transforms)
    batchgenerator_val = MultiThreadedAugmenter(dataloader_val, val_transforms,
                                                max(params.get('num_threads') // 2, 1),
                                                params.get("num_cached_per_thread"),
                                                seeds=seeds_val,
                                                pin_memory=pin_memory)
    batchgenerator_val.restart()
    return batchgenerator_train, batchgenerator_val
def main():
    """Adversarial domain-adaptation training for 3D segmentation.

    Trains a DenseNet_3D segmenter on a labeled source domain while a 3D
    discriminator pushes the (entropy-weighted) predictions on an unlabeled
    target domain to look like source predictions, with gradient accumulation
    over ``nb_accu_iters`` sub-iterations per optimizer step.
    """
    # init log
    now = time.strftime('%Y%m%d-%H%M%S', time.localtime(time.time()))
    log_path = train_params['log_path']
    if not os.path.isdir(log_path):
        os.makedirs(log_path)
    log = open(os.path.join(log_path, 'log_{}.txt'.format(now)), 'w')
    print_log('save path : {}'.format(log_path), log)

    # prepare dataset: labeled source set and unlabeled target-domain set
    dataset = load_train_dataset(phase='train',
                                 data_list_path=file_paths['train_list'])
    da_dataset = load_train_dataset_d(phase='train',
                                      data_list_path=file_paths['test_list'])

    # augmentation (same transform chain for both domains)
    aug_transforms = get_train_transform(model_params['patch_size'])
    # source domain
    src_data_gen = ISeg2019DataLoader3D(dataset, model_params['batch_size'],
                                        model_params['patch_size'],
                                        nb_modalities=2,
                                        num_threads_in_multithreaded=4)
    src_aug_gen = MultiThreadedAugmenter(src_data_gen, aug_transforms,
                                         num_processes=4, num_cached_per_queue=4)
    src_aug_gen.restart()
    # target domain (unlabeled)
    tgt_data_gen = ISeg2019DataLoader3D_Unlabel(da_dataset,
                                                model_params['batch_size'],
                                                model_params['patch_size'],
                                                nb_modalities=2,
                                                num_threads_in_multithreaded=4)
    tgt_aug_gen = MultiThreadedAugmenter(tgt_data_gen, aug_transforms,
                                         num_processes=4, num_cached_per_queue=4)
    tgt_aug_gen.restart()

    # define network: segmenter + per-class discriminator
    net = DenseNet_3D(ver=model_params['model_ver'])
    net_d = FCDiscriminator3D(num_classes=model_params['nb_classes'], ndf=32)

    # define loss
    seg_loss = torch.nn.CrossEntropyLoss()
    bce_loss = torch.nn.BCEWithLogitsLoss()

    # define optimizer (separate optimizers for G and D)
    optimizer = torch.optim.Adam(net.parameters(), lr=train_params['lr_rate'],
                                 weight_decay=train_params['weight_decay'],
                                 betas=(train_params['momentum'], 0.999))
    optimizer.zero_grad()
    optimizer_d = torch.optim.Adam(net_d.parameters(),
                                   lr=train_params['lr_rate_d'],
                                   weight_decay=train_params['weight_decay'],
                                   betas=(0.9, 0.99))
    optimizer_d.zero_grad()

    start_step = 0
    best_dice = 0.
    # Optionally resume the segmenter weights (optimizer state is not restored;
    # start_step also stays 0 — NOTE(review): confirm this is intended).
    if train_params['resume_path'] is not None:
        print_log("=======> loading checkpoint '{}'".format(
            train_params['resume_path']), log=log)
        checkpoint = torch.load(train_params['resume_path'])
        net.load_state_dict(checkpoint['model_state_dict'])
        print_log("=======> loaded checkpoint '{}'".format(
            train_params['resume_path']), log=log)

    scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                train_params['lr_step_size'],
                                                train_params['lr_gamma'])
    scheduler_d = torch.optim.lr_scheduler.StepLR(optimizer_d,
                                                  train_params['lr_step_size'],
                                                  train_params['lr_gamma'])

    # start training
    net.cuda()
    net_d.cuda()
    seg_loss.cuda()
    bce_loss.cuda()
    net.train()
    net_d.train()

    # Domain labels for the discriminator's BCE targets.
    source_label = 1
    target_label = 0
    for step in range(start_step, train_params['nb_iters']):
        loss_seg_value = 0.
        loss_adv_target_value = 0.
        loss_D_src_value = 0.
        loss_D_tgt_value = 0.
        optimizer.zero_grad()
        optimizer_d.zero_grad()
        # Gradient accumulation: each sub-iteration contributes
        # loss / nb_accu_iters so the summed gradient equals the mean.
        for sub_iter in range(train_params['nb_accu_iters']):
            # train G: freeze the discriminator so only the segmenter updates
            for param in net_d.parameters():
                param.requires_grad = False
            # train with source (supervised segmentation loss)
            src_batch = next(src_aug_gen)
            src_input_img = torch.from_numpy(src_batch['data']).cuda()
            src_input_label = torch.from_numpy(
                np.squeeze(src_batch['seg'], axis=1).astype(np.int64)).cuda()
            src_seg_out = net(src_input_img)
            loss = seg_loss(src_seg_out, src_input_label)
            loss_seg = loss / train_params['nb_accu_iters']
            loss_seg.backward()
            loss_seg_value += loss.data.cpu().numpy()
            # train with target: fool D into labeling target predictions as
            # source (adversarial loss on entropy-weighted softmax)
            tgt_batch = next(tgt_aug_gen)
            tgt_input_img = torch.from_numpy(tgt_batch['data']).cuda()
            tgt_seg_out = net(tgt_input_img)
            tgt_d_out = net_d(prob_2_entropy(F.softmax(tgt_seg_out, dim=1)))
            loss = bce_loss(
                tgt_d_out,
                Variable(
                    torch.FloatTensor(
                        tgt_d_out.data.size()).fill_(source_label)).cuda())
            loss_adv_tgt = train_params[
                'lambda_adv_target'] * loss / train_params['nb_accu_iters']
            loss_adv_tgt.backward()
            loss_adv_target_value += loss.data.cpu().numpy()
            # train D: unfreeze discriminator; detach segmenter outputs so
            # only D receives gradients
            for param in net_d.parameters():
                param.requires_grad = True
            # train with source (D should say "source")
            src_seg_out = src_seg_out.detach()
            src_d_out = net_d(prob_2_entropy(F.softmax(src_seg_out, dim=1)))
            loss = bce_loss(
                src_d_out,
                Variable(
                    torch.FloatTensor(
                        src_d_out.data.size()).fill_(source_label)).cuda())
            loss_d_src = loss / train_params['nb_accu_iters']
            loss_d_src.backward()
            loss_D_src_value += loss.data.cpu().numpy()
            # train with target (D should say "target")
            tgt_seg_out = tgt_seg_out.detach()
            tgt_d_out = net_d(prob_2_entropy(F.softmax(tgt_seg_out, dim=1)))
            loss = bce_loss(
                tgt_d_out,
                Variable(
                    torch.FloatTensor(
                        tgt_d_out.data.size()).fill_(target_label)).cuda())
            loss_d_tgt = loss / train_params['nb_accu_iters']
            loss_d_tgt.backward()
            loss_D_tgt_value += loss.data.cpu().numpy()
        # Apply the accumulated gradients, then step both LR schedules.
        optimizer.step()
        scheduler.step()
        optimizer_d.step()
        scheduler_d.step()
        log_str = 'step {}: lr:{:.8f}, lr_d:{:.8f}, loss_seg:{:.6f}, loss_adv:{:.6f}, loss_D_src:{:.6f}, loss_D_tgt:{:.6f}'\
            .format(step, scheduler.get_lr()[0], scheduler_d.get_lr()[0],
                    loss_seg_value, loss_adv_target_value, loss_D_src_value,
                    loss_D_tgt_value)
        print_log(log_str, log)
        # val and save per N iterations
        if (step + 1) % train_params['snapshot_step_size'] == 0:
            net.eval()
            # NOTE(review): validation runs on the (source) training dataset,
            # not a held-out set — confirm this is intentional.
            val_avg_dice1, val_avg_dice2, val_avg_dice3, val_avg_dice = validation(
                net, dataset)
            val_log_str = 'val step: val_avg_dice:{}, val_avg_dice1:{}, val_avg_dice2:{}, val_avg_dice3:{}' \
                .format(val_avg_dice, val_avg_dice1, val_avg_dice2, val_avg_dice3)
            print_log(val_log_str, log)
            is_best = False
            if val_avg_dice > best_dice:
                best_dice = val_avg_dice
                is_best = True
            save_checkpoint({
                'model_state_dict': net.state_dict(),
            }, is_best, train_params['model_snapshot_path'],
                'checkpoint-{}.pth'.format(step + 1), 'model_best.pth')
            save_checkpoint({
                'model_d': net_d.state_dict(),
            }, is_best, train_params['model_snapshot_path'],
                'checkpoint-d-{}.pth'.format(step + 1), 'model_d_best.pth')
            net.train()
    log.close()