def test_HeMIS(p=0.0001):
    """Benchmark HeMISUnet training on an HDF5 multi-contrast dataset.

    Builds the dataset/loader, trains for ``N_EPOCHS`` and prints the mean and
    standard deviation of the time spent in each phase (init, data loading,
    forward pass, optimizer step, batch generation, scheduler, dataset reload).

    Args:
        p (float): Initial probability of a missing modality; updated each
            epoch via ``p ** (2 / 3)`` (so it grows toward 1 for p < 1) before
            being passed to ``dataset.update``.
    """
    print('[INFO]: Starting test ... \n')
    training_transform_dict = {
        "Resample": {
            "wspace": 0.75,
            "hspace": 0.75
        },
        "CenterCrop": {
            "size": [48, 48]
        },
        "NumpyToTensor": {}
    }
    transform_lst, _ = imed_transforms.prepare_transforms(training_transform_dict)

    roi_params = {"suffix": "_seg-manual", "slice_filter_roi": None}

    train_lst = ['sub-unf01']
    contrasts = ['T1w', 'T2w', 'T2star']

    print('[INFO]: Creating dataset ...\n')
    model_params = {
        "name": "HeMISUnet",
        "dropout_rate": 0.3,
        "bn_momentum": 0.9,
        "depth": 2,
        "in_channel": 1,
        "out_channel": 1,
        "missing_probability": 0.00001,
        "missing_probability_growth": 0.9,
        "contrasts": ["T1w", "T2w"],
        "ram": False,
        "path_hdf5": 'testing_data/mytestfile.hdf5',
        "csv_path": 'testing_data/hdf5.csv',
        "target_lst": ["T2w"],
        "roi_lst": ["T2w"]
    }
    contrast_params = {
        "contrast_lst": ['T1w', 'T2w', 'T2star'],
        "balance": {}
    }
    dataset = imed_adaptative.HDF5Dataset(root_dir=PATH_BIDS,
                                          subject_lst=train_lst,
                                          model_params=model_params,
                                          contrast_params=contrast_params,
                                          target_suffix=["_lesion-manual"],
                                          slice_axis=2,
                                          transform=transform_lst,
                                          metadata_choice=False,
                                          dim=2,
                                          slice_filter_fn=imed_loader_utils.SliceFilter(
                                              filter_empty_input=True,
                                              filter_empty_mask=True),
                                          roi_params=roi_params)

    dataset.load_into_ram(['T1w', 'T2w', 'T2star'])
    print("[INFO]: Dataset RAM status:")
    print(dataset.status)
    print("[INFO]: In memory Dataframe:")
    print(dataset.dataframe)

    # TODO
    # ds_train.filter_roi(nb_nonzero_thr=10)

    train_loader = DataLoader(dataset,
                              batch_size=BATCH_SIZE,
                              shuffle=True,
                              pin_memory=True,
                              collate_fn=imed_loader_utils.imed_collate,
                              num_workers=1)

    model = models.HeMISUnet(contrasts=contrasts,
                             depth=3,
                             drop_rate=DROPOUT,
                             bn_momentum=BN)
    print(model)

    cuda_available = torch.cuda.is_available()
    if cuda_available:
        torch.cuda.set_device(GPU_NUMBER)
        print("Using GPU number {}".format(GPU_NUMBER))
        model.cuda()

    # Initialing Optimizer and scheduler
    step_scheduler_batch = False
    optimizer = optim.Adam(model.parameters(), lr=INIT_LR)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, N_EPOCHS)

    # One accumulator per timed phase; entries are appended every epoch/batch.
    load_lst, reload_lst, pred_lst, opt_lst, schedul_lst, init_lst, gen_lst = [], [], [], [], [], [], []

    for epoch in tqdm(range(1, N_EPOCHS + 1), desc="Training"):
        start_time = time.time()

        start_init = time.time()
        # NOTE(review): lr is read from the scheduler but never used below — confirm intended.
        lr = scheduler.get_last_lr()[0]
        model.train()
        tot_init = time.time() - start_init
        init_lst.append(tot_init)

        num_steps = 0
        # start_gen is set to 0 so the generation timer is skipped on the first batch.
        start_gen = 0
        for i, batch in enumerate(train_loader):
            if i > 0:
                # Time spent by the DataLoader producing this batch (since end of previous step).
                tot_gen = time.time() - start_gen
                gen_lst.append(tot_gen)

            start_load = time.time()
            input_samples, gt_samples = imed_utils.unstack_tensors(batch["input"]), batch["gt"]
            print(batch["input_metadata"][0][0]["missing_mod"])
            missing_mod = imed_training.get_metadata(batch["input_metadata"], model_params)

            print("Number of missing contrasts = {}."
                  .format(len(input_samples) * len(input_samples[0]) - missing_mod.sum()))
            print("len input = {}".format(len(input_samples)))
            print("Batch = {}, {}".format(input_samples[0].shape, gt_samples[0].shape))

            if cuda_available:
                var_input = imed_utils.cuda(input_samples)
                var_gt = imed_utils.cuda(gt_samples, non_blocking=True)
            else:
                var_input = input_samples
                var_gt = gt_samples

            tot_load = time.time() - start_load
            load_lst.append(tot_load)

            start_pred = time.time()
            # HeMISUnet takes the per-contrast availability matrix as second argument.
            preds = model(var_input, missing_mod)
            tot_pred = time.time() - start_pred
            pred_lst.append(tot_pred)

            start_opt = time.time()
            loss = - losses.DiceLoss()(preds, var_gt)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            if step_scheduler_batch:
                scheduler.step()

            num_steps += 1
            tot_opt = time.time() - start_opt
            opt_lst.append(tot_opt)

            start_gen = time.time()

        start_schedul = time.time()
        if not step_scheduler_batch:
            scheduler.step()
        tot_schedul = time.time() - start_schedul
        schedul_lst.append(tot_schedul)

        start_reload = time.time()
        print("[INFO]: Updating Dataset")
        # Curriculum: raise the missing-modality probability for the next epoch.
        p = p ** (2 / 3)
        dataset.update(p=p)
        print("[INFO]: Reloading dataset")
        train_loader = DataLoader(dataset,
                                  batch_size=BATCH_SIZE,
                                  shuffle=True,
                                  pin_memory=True,
                                  collate_fn=imed_loader_utils.imed_collate,
                                  num_workers=1)
        tot_reload = time.time() - start_reload
        reload_lst.append(tot_reload)

        end_time = time.time()
        total_time = end_time - start_time
        tqdm.write("Epoch {} took {:.2f} seconds.".format(epoch, total_time))

    # Report mean and standard deviation of every timed phase.
    print('Mean SD init {} -- {}'.format(np.mean(init_lst), np.std(init_lst)))
    print('Mean SD load {} -- {}'.format(np.mean(load_lst), np.std(load_lst)))
    print('Mean SD reload {} -- {}'.format(np.mean(reload_lst), np.std(reload_lst)))
    print('Mean SD pred {} -- {}'.format(np.mean(pred_lst), np.std(pred_lst)))
    print('Mean SD opt {} -- {}'.format(np.mean(opt_lst), np.std(opt_lst)))
    print('Mean SD gen {} -- {}'.format(np.mean(gen_lst), np.std(gen_lst)))
    print('Mean SD scheduler {} -- {}'.format(np.mean(schedul_lst), np.std(schedul_lst)))
def train(model_params,
          dataset_train,
          dataset_val,
          training_params,
          log_directory,
          device,
          cuda_available=True,
          metric_fns=None,
          n_gif=0,
          resume_training=False,
          debugging=False):
    """Main command to train the network.

    Args:
        model_params (dict): Model's parameters.
        dataset_train (imed_loader): Training dataset.
        dataset_val (imed_loader): Validation dataset.
        training_params (dict): Training parameters; this function reads the keys
            "balance_samples", "batch_size", "transfer_learning", "scheduler",
            "loss", "mixup_alpha" and "training_time".
        log_directory (str): Folder where log files, best and final models are saved.
        device (str): Indicates the CPU or GPU ID.
        cuda_available (bool): If True, CUDA is available.
        metric_fns (list): List of metrics, see :mod:`ivadomed.metrics`.
        n_gif (int): Generates a GIF during training if larger than zero, one frame per epoch for a given slice.
            The parameter indicates the number of 2D slices used to generate GIFs, one GIF per slice. A GIF shows
            predictions of a given slice from the validation sub-dataset. They are saved within the log directory.
        resume_training (bool): Load a saved model ("checkpoint.pth.tar" in the log_directory) for resume training.
            This training state is saved everytime a new best model is saved in the log directory.
        debugging (bool): If True, extended verbosity and intermediate outputs.

    Returns:
        float, float, float, float: best_training_dice, best_training_loss, best_validation_dice,
            best_validation_loss.
    """
    # Write the metrics, images, etc to TensorBoard format
    writer = SummaryWriter(log_dir=log_directory)

    # BALANCE SAMPLES AND PYTORCH LOADER
    # Sample balancing is applied only when requested and the model is not HeMIS.
    conditions = all([
        training_params["balance_samples"]["applied"],
        model_params["name"] != "HeMIS"
    ])
    sampler_train, shuffle_train = get_sampler(
        dataset_train, conditions, training_params['balance_samples']['type'])

    train_loader = DataLoader(dataset_train,
                              batch_size=training_params["batch_size"],
                              shuffle=shuffle_train,
                              pin_memory=True,
                              sampler=sampler_train,
                              collate_fn=imed_loader_utils.imed_collate,
                              num_workers=0)

    gif_dict = {"image_path": [], "slice_id": [], "gif": []}
    if dataset_val:
        sampler_val, shuffle_val = get_sampler(
            dataset_val, conditions, training_params['balance_samples']['type'])

        val_loader = DataLoader(dataset_val,
                                batch_size=training_params["batch_size"],
                                shuffle=shuffle_val,
                                pin_memory=True,
                                sampler=sampler_val,
                                collate_fn=imed_loader_utils.imed_collate,
                                num_workers=0)

        # Init GIF: pick n_gif random validation slices to track across epochs.
        if n_gif > 0:
            indexes_gif = random.sample(range(len(dataset_val)), n_gif)
            for i_gif in range(n_gif):
                random_metadata = dict(
                    dataset_val[indexes_gif[i_gif]]["input_metadata"][0])
                gif_dict["image_path"].append(random_metadata['input_filenames'])
                gif_dict["slice_id"].append(random_metadata['slice_index'])
                gif_obj = imed_visualize.AnimatedGif(
                    size=dataset_val[indexes_gif[i_gif]]["input"].numpy()[0].shape)
                gif_dict["gif"].append(copy.copy(gif_obj))

    # GET MODEL
    if training_params["transfer_learning"]["retrain_model"]:
        print("\nLoading pretrained model's weights: {}.")
        print("\tFreezing the {}% first layers.".format(
            100 - training_params["transfer_learning"]['retrain_fraction'] * 100.))
        old_model_path = training_params["transfer_learning"]["retrain_model"]
        fraction = training_params["transfer_learning"]['retrain_fraction']
        if 'reset' in training_params["transfer_learning"]:
            reset = training_params["transfer_learning"]['reset']
        else:
            reset = True
        # Freeze first layers and reset last layers
        model = imed_models.set_model_for_retrain(old_model_path,
                                                  retrain_fraction=fraction,
                                                  map_location=device,
                                                  reset=reset)
    else:
        print("\nInitialising model's weights from scratch.")
        # Model class is resolved dynamically from its name in the config.
        model_class = getattr(imed_models, model_params["name"])
        model = model_class(**model_params)
    if cuda_available:
        model.cuda()

    num_epochs = training_params["training_time"]["num_epochs"]

    # OPTIMIZER
    initial_lr = training_params["scheduler"]["initial_lr"]
    # filter out the parameters you are going to fine-tuning
    params_to_opt = filter(lambda p: p.requires_grad, model.parameters())
    # Using Adam
    optimizer = optim.Adam(params_to_opt, lr=initial_lr)
    scheduler, step_scheduler_batch = get_scheduler(
        copy.copy(training_params["scheduler"]["lr_scheduler"]), optimizer,
        num_epochs)
    print("\nScheduler parameters: {}".format(
        training_params["scheduler"]["lr_scheduler"]))

    # Create dict containing gammas and betas after each FiLM layer.
    if 'film_layers' in model_params and any(model_params['film_layers']):
        gammas_dict = {i: [] for i in range(1, 2 * model_params["depth"] + 3)}
        betas_dict = {i: [] for i in range(1, 2 * model_params["depth"] + 3)}
        contrast_list = []

    # Resume
    start_epoch = 1
    resume_path = os.path.join(log_directory, "checkpoint.pth.tar")
    if resume_training:
        model, optimizer, gif_dict, start_epoch, val_loss_total_avg, scheduler, patience_count = load_checkpoint(
            model=model,
            optimizer=optimizer,
            gif_dict=gif_dict,
            scheduler=scheduler,
            fname=resume_path)
        # Individually transfer the optimizer parts
        # TODO: check if following lines are needed
        for state in optimizer.state.values():
            for k, v in state.items():
                if torch.is_tensor(v):
                    state[k] = v.to(device)

    # LOSS
    print("\nSelected Loss: {}".format(training_params["loss"]["name"]))
    print("\twith the parameters: {}".format([
        training_params["loss"][k] for k in training_params["loss"]
        if k != "name"
    ]))
    loss_fct = get_loss_function(copy.copy(training_params["loss"]))
    loss_dice_fct = imed_losses.DiceLoss(
    )  # For comparison when another loss is used

    # INIT TRAINING VARIABLES
    best_training_dice, best_training_loss = float("inf"), float("inf")
    best_validation_loss, best_validation_dice = float("inf"), float("inf")
    patience_count = 0
    begin_time = time.time()

    # EPOCH LOOP
    for epoch in tqdm(range(num_epochs), desc="Training", initial=start_epoch):
        epoch = epoch + start_epoch
        start_time = time.time()

        lr = scheduler.get_last_lr()[0]
        writer.add_scalar('learning_rate', lr, epoch)

        # Training loop -----------------------------------------------------------
        model.train()
        train_loss_total, train_dice_loss_total = 0.0, 0.0
        num_steps = 0
        for i, batch in enumerate(train_loader):
            # GET SAMPLES
            # HeMISUnet receives one tensor per contrast, hence the unstacking.
            if model_params["name"] == "HeMISUnet":
                input_samples = imed_utils.cuda(
                    imed_utils.unstack_tensors(batch["input"]), cuda_available)
            else:
                input_samples = imed_utils.cuda(batch["input"], cuda_available)
            gt_samples = imed_utils.cuda(batch["gt"],
                                         cuda_available,
                                         non_blocking=True)

            # MIXUP
            if training_params["mixup_alpha"]:
                input_samples, gt_samples = imed_mixup.mixup(
                    input_samples, gt_samples, training_params["mixup_alpha"],
                    debugging and epoch == 1, log_directory)

            # RUN MODEL
            # HeMIS and FiLM models consume batch metadata as a second input.
            if model_params["name"] == "HeMISUnet" or \
                    ('film_layers' in model_params and any(model_params['film_layers'])):
                metadata = get_metadata(batch["input_metadata"], model_params)
                preds = model(input_samples, metadata)
            else:
                preds = model(input_samples)

            # LOSS
            loss = loss_fct(preds, gt_samples)
            train_loss_total += loss.item()
            # Dice is always tracked in parallel for comparability across loss configs.
            train_dice_loss_total += loss_dice_fct(preds, gt_samples).item()

            # UPDATE OPTIMIZER
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            if step_scheduler_batch:
                scheduler.step()
            num_steps += 1

            if i == 0 and debugging:
                imed_visualize.save_tensorboard_img(
                    writer,
                    epoch,
                    "Train",
                    input_samples,
                    gt_samples,
                    preds,
                    is_three_dim=not model_params["is_2d"])

        if not step_scheduler_batch:
            scheduler.step()

        # TRAINING LOSS
        train_loss_total_avg = train_loss_total / num_steps
        msg = "Epoch {} training loss: {:.4f}.".format(epoch,
                                                       train_loss_total_avg)
        train_dice_loss_total_avg = train_dice_loss_total / num_steps
        if training_params["loss"]["name"] != "DiceLoss":
            msg += "\tDice training loss: {:.4f}.".format(
                train_dice_loss_total_avg)
        tqdm.write(msg)

        # CURRICULUM LEARNING
        if model_params["name"] == "HeMISUnet":
            # Increase the probability of a missing modality
            model_params["missing_probability"] **= model_params[
                "missing_probability_growth"]
            dataset_train.update(p=model_params["missing_probability"])

        # Validation loop -----------------------------------------------------
        model.eval()
        val_loss_total, val_dice_loss_total = 0.0, 0.0
        num_steps = 0
        metric_mgr = imed_metrics.MetricManager(metric_fns)
        if dataset_val:
            for i, batch in enumerate(val_loader):
                with torch.no_grad():
                    # GET SAMPLES
                    if model_params["name"] == "HeMISUnet":
                        input_samples = imed_utils.cuda(
                            imed_utils.unstack_tensors(batch["input"]),
                            cuda_available)
                    else:
                        input_samples = imed_utils.cuda(
                            batch["input"], cuda_available)
                    gt_samples = imed_utils.cuda(batch["gt"],
                                                 cuda_available,
                                                 non_blocking=True)

                    # RUN MODEL
                    if model_params["name"] == "HeMISUnet" or \
                            ('film_layers' in model_params and any(model_params['film_layers'])):
                        metadata = get_metadata(batch["input_metadata"],
                                                model_params)
                        preds = model(input_samples, metadata)
                    else:
                        preds = model(input_samples)

                    # LOSS
                    loss = loss_fct(preds, gt_samples)
                    val_loss_total += loss.item()
                    val_dice_loss_total += loss_dice_fct(preds,
                                                         gt_samples).item()

                    # Add frame to GIF: only for the slices selected at init.
                    for i_ in range(len(input_samples)):
                        im, pr, met = input_samples[i_].cpu().numpy()[0], preds[i_].cpu().numpy()[0], \
                                      batch["input_metadata"][i_][0]
                        for i_gif in range(n_gif):
                            if gif_dict["image_path"][i_gif] == met.__getitem__('input_filenames') and \
                                    gif_dict["slice_id"][i_gif] == met.__getitem__('slice_index'):
                                overlap = imed_visualize.overlap_im_seg(im, pr)
                                gif_dict["gif"][i_gif].add(overlap,
                                                           label=str(epoch))

                num_steps += 1

                # METRICS COMPUTATION
                gt_npy = gt_samples.cpu().numpy()
                preds_npy = preds.data.cpu().numpy()
                metric_mgr(preds_npy, gt_npy)

                if i == 0 and debugging:
                    imed_visualize.save_tensorboard_img(
                        writer,
                        epoch,
                        "Validation",
                        input_samples,
                        gt_samples,
                        preds,
                        is_three_dim=not model_params['is_2d'])

                if 'film_layers' in model_params and any(model_params['film_layers']) and debugging and \
                        epoch == num_epochs and i < int(len(dataset_val) / training_params["batch_size"]) + 1:
                    # Store the values of gammas and betas after the last epoch for each batch
                    gammas_dict, betas_dict, contrast_list = store_film_params(
                        gammas_dict, betas_dict, contrast_list,
                        batch['input_metadata'], model,
                        model_params["film_layers"], model_params["depth"])

            # METRICS COMPUTATION FOR CURRENT EPOCH
            # Previous epoch's average is kept for the early-stopping delta below.
            val_loss_total_avg_old = val_loss_total_avg if epoch > 1 else None
            metrics_dict = metric_mgr.get_results()
            metric_mgr.reset()
            writer.add_scalars('Validation/Metrics', metrics_dict, epoch)
            val_loss_total_avg = val_loss_total / num_steps
            writer.add_scalars(
                'losses', {
                    'train_loss': train_loss_total_avg,
                    'val_loss': val_loss_total_avg,
                }, epoch)
            msg = "Epoch {} validation loss: {:.4f}.".format(
                epoch, val_loss_total_avg)
            val_dice_loss_total_avg = val_dice_loss_total / num_steps
            if training_params["loss"]["name"] != "DiceLoss":
                msg += "\tDice validation loss: {:.4f}.".format(
                    val_dice_loss_total_avg)
            tqdm.write(msg)
            end_time = time.time()
            total_time = end_time - start_time
            tqdm.write("Epoch {} took {:.2f} seconds.".format(
                epoch, total_time))

            # UPDATE BEST RESULTS
            if val_loss_total_avg < best_validation_loss:
                # Save checkpoint
                state = {
                    'epoch': epoch + 1,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'gif_dict': gif_dict,
                    'scheduler': scheduler,
                    'patience_count': patience_count,
                    'validation_loss': val_loss_total_avg
                }
                torch.save(state, resume_path)

                # Save best model file
                model_path = os.path.join(log_directory, "best_model.pt")
                torch.save(model, model_path)

                # Update best scores
                best_validation_loss, best_training_loss = val_loss_total_avg, train_loss_total_avg
                best_validation_dice, best_training_dice = val_dice_loss_total_avg, train_dice_loss_total_avg

            # EARLY STOPPING
            if epoch > 1:
                # Relative improvement (%) of validation loss vs previous epoch.
                val_diff = (val_loss_total_avg_old -
                            val_loss_total_avg) * 100 / abs(val_loss_total_avg)
                if val_diff < training_params["training_time"][
                        "early_stopping_epsilon"]:
                    patience_count += 1
                if patience_count >= training_params["training_time"][
                        "early_stopping_patience"]:
                    print(
                        "Stopping training due to {} epochs without improvements"
                        .format(patience_count))
                    break

    # Save final model
    final_model_path = os.path.join(log_directory, "final_model.pt")
    torch.save(model, final_model_path)
    if 'film_layers' in model_params and any(
            model_params['film_layers']) and debugging:
        save_film_params(gammas_dict, betas_dict, contrast_list,
                         model_params["depth"], log_directory)

    # Save best model in log directory
    if os.path.isfile(resume_path):
        state = torch.load(resume_path)
        model_path = os.path.join(log_directory, "best_model.pt")
        model.load_state_dict(state['state_dict'])
        torch.save(model, model_path)
        # Save best model as ONNX in the model directory
        try:
            # Convert best model to ONNX and save it in model directory
            best_model_path = os.path.join(
                log_directory, model_params["folder_name"],
                model_params["folder_name"] + ".onnx")
            imed_utils.save_onnx_model(model, input_samples, best_model_path)
        except:
            # Save best model in model directory
            best_model_path = os.path.join(log_directory,
                                           model_params["folder_name"],
                                           model_params["folder_name"] + ".pt")
            torch.save(model, best_model_path)
            logger.warning(
                "Failed to save the model as '.onnx', saved it as '.pt': {}".
                format(best_model_path))

    # Save GIFs
    gif_folder = os.path.join(log_directory, "gifs")
    if n_gif > 0 and not os.path.isdir(gif_folder):
        os.makedirs(gif_folder)
    for i_gif in range(n_gif):
        # Output name: <subject>__<contrast-part-of-filename>__<slice_id>.gif
        fname_out = gif_dict["image_path"][i_gif].split('/')[-3] + "__"
        fname_out += gif_dict["image_path"][i_gif].split('/')[-1].split(
            ".nii.gz")[0].split(gif_dict["image_path"][i_gif].split('/')[-3] +
                                "_")[1] + "__"
        fname_out += str(gif_dict["slice_id"][i_gif]) + ".gif"
        path_gif_out = os.path.join(gif_folder, fname_out)
        gif_dict["gif"][i_gif].save(path_gif_out)

    writer.close()
    final_time = time.time()
    duration_time = final_time - begin_time
    print('begin ' + time.strftime('%H:%M:%S', time.localtime(begin_time)) +
          "| End " + time.strftime('%H:%M:%S', time.localtime(final_time)) +
          "| duration " + str(datetime.timedelta(seconds=duration_time)))

    return best_training_dice, best_training_loss, best_validation_dice, best_validation_loss
def test_unet_time(download_data_testing_test_files, train_lst, target_lst,
                   config):
    """Benchmark a short training run of the configured model and log per-phase timings.

    Builds a BIDS-backed training dataset from ``config``-merged loader params,
    trains for ``N_EPOCHS`` and logs the mean/std of the time spent in each
    phase (init, data loading, forward pass, optimizer step, batch generation,
    scheduler).

    Args:
        download_data_testing_test_files: Fixture providing the test data files.
        train_lst (list): Subjects used for training.
        target_lst (list): Target (ground-truth) suffixes.
        config (dict): Extra loader/model parameters merged into the defaults.
    """
    cuda_available, device = imed_utils.define_device(GPU_ID)

    loader_params = {
        "data_list": train_lst,
        "dataset_type": "training",
        "requires_undo": False,
        "path_data": [__data_testing_dir__],
        "target_suffix": target_lst,
        "extensions": [".nii.gz"],
        "slice_filter_params": {
            "filter_empty_mask": False,
            "filter_empty_input": True
        },
        "patch_filter_params": {
            "filter_empty_mask": False,
            "filter_empty_input": False
        },
        "slice_axis": "axial"
    }

    # Update loader_params with config
    loader_params.update(config)

    # Get Training dataset
    bids_df = BidsDataframe(loader_params, __tmp_dir__, derivatives=True)
    ds_train = imed_loader.load_dataset(bids_df, **loader_params)

    # Loader
    # 3D models are trained with batch_size 1; 2D models use BATCH_SIZE.
    train_loader = DataLoader(
        ds_train,
        batch_size=1 if config["model_params"]["name"] == "Modified3DUNet" else BATCH_SIZE,
        shuffle=True,
        pin_memory=True,
        collate_fn=imed_loader_utils.imed_collate,
        num_workers=1)

    # MODEL
    model_params = loader_params["model_params"]
    model_params.update(MODEL_DEFAULT)
    # Get in_channel from contrast_lst
    if loader_params["multichannel"]:
        model_params["in_channel"] = len(
            loader_params["contrast_params"]["contrast_lst"])
    else:
        model_params["in_channel"] = 1
    # Get out_channel from target_suffix
    model_params["out_channel"] = len(loader_params["target_suffix"])
    # Model class is resolved dynamically from its name.
    model_class = getattr(imed_models, model_params["name"])
    model = model_class(**model_params)

    logger.debug(f"Training {model_params['name']}")
    if cuda_available:
        model.cuda()

    step_scheduler_batch = False
    # TODO: Add optim in pytest
    optimizer = optim.Adam(model.parameters(), lr=INIT_LR)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, N_EPOCHS)

    # TODO: add to pytest
    loss_fct = imed_losses.DiceLoss()

    # One accumulator per timed phase.
    load_lst, pred_lst, opt_lst, schedule_lst, init_lst, gen_lst = [], [], [], [], [], []

    for epoch in tqdm(range(1, N_EPOCHS + 1), desc="Training"):
        start_time = time.time()

        start_init = time.time()
        model.train()
        tot_init = time.time() - start_init
        init_lst.append(tot_init)

        num_steps = 0
        # start_gen is 0 so the generation timer is skipped on the first batch.
        start_gen = 0
        for i, batch in enumerate(train_loader):
            if i > 0:
                tot_gen = time.time() - start_gen
                gen_lst.append(tot_gen)

            start_load = time.time()
            input_samples = imed_utils.cuda(batch["input"], cuda_available)
            gt_samples = imed_utils.cuda(batch["gt"],
                                         cuda_available,
                                         non_blocking=True)
            tot_load = time.time() - start_load
            load_lst.append(tot_load)

            start_pred = time.time()
            # FiLM models require metadata as a second input; [[0, 1]] is a dummy value.
            if 'film_layers' in model_params:
                preds = model(input_samples, [[0, 1]])
            else:
                preds = model(input_samples)
            tot_pred = time.time() - start_pred
            pred_lst.append(tot_pred)

            start_opt = time.time()
            loss = loss_fct(preds, gt_samples)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            if step_scheduler_batch:
                scheduler.step()

            num_steps += 1
            tot_opt = time.time() - start_opt
            opt_lst.append(tot_opt)

            start_gen = time.time()

        start_schedule = time.time()
        if not step_scheduler_batch:
            scheduler.step()
        tot_schedule = time.time() - start_schedule
        schedule_lst.append(tot_schedule)

        end_time = time.time()
        total_time = end_time - start_time
        tqdm.write("Epoch {} took {:.2f} seconds.".format(epoch, total_time))

    # Report mean and standard deviation of every timed phase.
    logger.info(f"Mean SD init {np.mean(init_lst)} -- {np.std(init_lst)}")
    logger.info(f"Mean SD load {np.mean(load_lst)} -- {np.std(load_lst)}")
    logger.info(f"Mean SD pred {np.mean(pred_lst)} -- {np.std(pred_lst)}")
    logger.info(f"Mean SDopt {np.mean(opt_lst)} -- {np.std(opt_lst)}")
    logger.info(f"Mean SD gen {np.mean(gen_lst)} -- {np.std(gen_lst)}")
    logger.info(f"Mean SD scheduler {np.mean(schedule_lst)} -- {np.std(schedule_lst)}")
def test_unet_time(train_lst, target_lst, config):
    """Benchmark a short training run of the configured model and print per-phase timings.

    Older variant of the UNet timing test (uses ``bids_path``/``load_dataset``
    directly). Trains for ``N_EPOCHS`` and prints the mean/std of the time
    spent in each phase (init, data loading, forward pass, optimizer step,
    batch generation, scheduler).

    Args:
        train_lst (list): Subjects used for training.
        target_lst (list): Target (ground-truth) suffixes.
        config (dict): Extra loader/model parameters merged into the defaults.
    """
    cuda_available, device = imed_utils.define_device(GPU_NUMBER)

    loader_params = {
        "data_list": train_lst,
        "dataset_type": "training",
        "requires_undo": False,
        "bids_path": PATH_BIDS,
        "target_suffix": target_lst,
        "slice_filter_params": {
            "filter_empty_mask": False,
            "filter_empty_input": True
        },
        "slice_axis": "axial"
    }

    # Update loader_params with config
    loader_params.update(config)

    # Get Training dataset
    ds_train = imed_loader.load_dataset(**loader_params)

    # Loader
    # 3D models are trained with batch_size 1; 2D models use BATCH_SIZE.
    train_loader = DataLoader(
        ds_train,
        batch_size=1 if config["model_params"]["name"] == "UNet3D" else BATCH_SIZE,
        shuffle=True,
        pin_memory=True,
        collate_fn=imed_loader_utils.imed_collate,
        num_workers=1)

    # MODEL
    model_params = loader_params["model_params"]
    model_params.update(MODEL_DEFAULT)
    # Get in_channel from contrast_lst
    if loader_params["multichannel"]:
        model_params["in_channel"] = len(
            loader_params["contrast_params"]["contrast_lst"])
    else:
        model_params["in_channel"] = 1
    # Get out_channel from target_suffix
    model_params["out_channel"] = len(loader_params["target_suffix"])
    # Model class is resolved dynamically from its name.
    model_class = getattr(imed_models, model_params["name"])
    model = model_class(**model_params)

    print("Training {}".format(model_params["name"]))
    if cuda_available:
        model.cuda()

    step_scheduler_batch = False
    # TODO: Add optim in pytest
    optimizer = optim.Adam(model.parameters(), lr=INIT_LR)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, N_EPOCHS)

    # TODO: add to pytest
    loss_fct = imed_losses.DiceLoss()

    # One accumulator per timed phase.
    load_lst, pred_lst, opt_lst, schedul_lst, init_lst, gen_lst = [], [], [], [], [], []

    for epoch in tqdm(range(1, N_EPOCHS + 1), desc="Training"):
        start_time = time.time()

        start_init = time.time()
        model.train()
        tot_init = time.time() - start_init
        init_lst.append(tot_init)

        num_steps = 0
        # start_gen is 0 so the generation timer is skipped on the first batch.
        start_gen = 0
        for i, batch in enumerate(train_loader):
            if i > 0:
                tot_gen = time.time() - start_gen
                gen_lst.append(tot_gen)

            start_load = time.time()
            input_samples = imed_utils.cuda(batch["input"], cuda_available)
            gt_samples = imed_utils.cuda(batch["gt"],
                                         cuda_available,
                                         non_blocking=True)
            tot_load = time.time() - start_load
            load_lst.append(tot_load)

            start_pred = time.time()
            preds = model(input_samples)
            tot_pred = time.time() - start_pred
            pred_lst.append(tot_pred)

            start_opt = time.time()
            loss = loss_fct(preds, gt_samples)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            if step_scheduler_batch:
                scheduler.step()

            num_steps += 1
            tot_opt = time.time() - start_opt
            opt_lst.append(tot_opt)

            start_gen = time.time()

        start_schedul = time.time()
        if not step_scheduler_batch:
            scheduler.step()
        tot_schedul = time.time() - start_schedul
        schedul_lst.append(tot_schedul)

        end_time = time.time()
        total_time = end_time - start_time
        tqdm.write("Epoch {} took {:.2f} seconds.".format(epoch, total_time))

    # Report mean and standard deviation of every timed phase.
    print('Mean SD init {} -- {}'.format(np.mean(init_lst), np.std(init_lst)))
    print('Mean SD load {} -- {}'.format(np.mean(load_lst), np.std(load_lst)))
    print('Mean SD pred {} -- {}'.format(np.mean(pred_lst), np.std(pred_lst)))
    print('Mean SDopt {} -- {}'.format(np.mean(opt_lst), np.std(opt_lst)))
    print('Mean SD gen {} -- {}'.format(np.mean(gen_lst), np.std(gen_lst)))
    print('Mean SD scheduler {} -- {}'.format(np.mean(schedul_lst), np.std(schedul_lst)))
def train(model_params, dataset_train, dataset_val, training_params, path_output, device, wandb_params=None, cuda_available=True, metric_fns=None, n_gif=0, resume_training=False, debugging=False): """Main command to train the network. Args: model_params (dict): Model's parameters. dataset_train (imed_loader): Training dataset. dataset_val (imed_loader): Validation dataset. training_params (dict): path_output (str): Folder where log files, best and final models are saved. device (str): Indicates the CPU or GPU ID. cuda_available (bool): If True, CUDA is available. metric_fns (list): List of metrics, see :mod:`ivadomed.metrics`. n_gif (int): Generates a GIF during training if larger than zero, one frame per epoch for a given slice. The parameter indicates the number of 2D slices used to generate GIFs, one GIF per slice. A GIF shows predictions of a given slice from the validation sub-dataset. They are saved within the output path. resume_training (bool): Load a saved model ("checkpoint.pth.tar" in the path_output) for resume training. This training state is saved everytime a new best model is saved in the log directory. debugging (bool): If True, extended verbosity and intermediate outputs. Returns: float, float, float, float: best_training_dice, best_training_loss, best_validation_dice, best_validation_loss. 
""" # Write the metrics, images, etc to TensorBoard format writer = SummaryWriter(log_dir=path_output) # Enable wandb tracking if the required params are found in the config file and the api key is correct wandb_tracking = imed_utils.initialize_wandb(wandb_params) if wandb_tracking: # Collect all hyperparameters into a dictionary cfg = {**training_params, **model_params} # Get the actual project, group, and run names if they exist, else choose the temporary names as default project_name = wandb_params.get(WandbKW.PROJECT_NAME, "temp_project") group_name = wandb_params.get(WandbKW.GROUP_NAME, "temp_group") run_name = wandb_params.get(WandbKW.RUN_NAME, "temp_run") if project_name == "temp_project" or group_name == "temp_group" or run_name == "temp_run": logger.info( "{PROJECT/GROUP/RUN} name not found, initializing as {'temp_project'/'temp_group'/'temp_run'}" ) # Initialize WandB with metrics and hyperparameters wandb.init(project=project_name, group=group_name, name=run_name, config=cfg) # BALANCE SAMPLES AND PYTORCH LOADER conditions = all([ training_params[TrainingParamsKW.BALANCE_SAMPLES] [BalanceSamplesKW.APPLIED], model_params[ModelParamsKW.NAME] != "HeMIS" ]) sampler_train, shuffle_train = get_sampler( dataset_train, conditions, training_params[ TrainingParamsKW.BALANCE_SAMPLES][BalanceSamplesKW.TYPE]) train_loader = DataLoader( dataset_train, batch_size=training_params[TrainingParamsKW.BATCH_SIZE], shuffle=shuffle_train, pin_memory=True, sampler=sampler_train, collate_fn=imed_loader_utils.imed_collate, num_workers=0) gif_dict = {"image_path": [], "slice_id": [], "gif": []} if dataset_val: sampler_val, shuffle_val = get_sampler( dataset_val, conditions, training_params[ TrainingParamsKW.BALANCE_SAMPLES][BalanceSamplesKW.TYPE]) val_loader = DataLoader( dataset_val, batch_size=training_params[TrainingParamsKW.BATCH_SIZE], shuffle=shuffle_val, pin_memory=True, sampler=sampler_val, collate_fn=imed_loader_utils.imed_collate, num_workers=0) # Init GIF if n_gif > 0: 
# NOTE(review): this chunk is a whitespace-mangled paste — the original newlines and
# indentation were collapsed, so each physical line below holds many statements and the
# text is not valid Python as-is. It appears to be the tail of a training entry point
# (its `def` header is above this chunk); restore formatting against the original source
# before editing logic. The section comments below only index what each line contains.
#
# Section: sample `n_gif` random validation slices and record their file path / slice id
# so matching frames can be appended to animated GIFs during validation; then build the
# model — either load + partially freeze a pretrained model (transfer learning, with the
# retrained fraction and optional `reset` read from `training_params`) or instantiate
# `model_params[ModelParamsKW.NAME]` from scratch; move to GPU when available; create an
# Adam optimizer over only the `requires_grad` parameters and an LR scheduler via
# `get_scheduler`; optionally start wandb gradient watching.
indexes_gif = random.sample(range(len(dataset_val)), n_gif) for i_gif in range(n_gif): random_metadata = dict( dataset_val[indexes_gif[i_gif]][MetadataKW.INPUT_METADATA][0]) gif_dict["image_path"].append( random_metadata[MetadataKW.INPUT_FILENAMES]) gif_dict["slice_id"].append( random_metadata[MetadataKW.SLICE_INDEX]) gif_obj = imed_visualize.AnimatedGif( size=dataset_val[indexes_gif[i_gif]]["input"].numpy()[0].shape) gif_dict["gif"].append(copy.copy(gif_obj)) # GET MODEL if training_params["transfer_learning"]["retrain_model"]: logger.info("Loading pretrained model's weights: {}.") logger.info("\tFreezing the {}% first layers.".format( 100 - training_params["transfer_learning"]['retrain_fraction'] * 100.)) old_model_path = training_params["transfer_learning"]["retrain_model"] fraction = training_params["transfer_learning"]['retrain_fraction'] if 'reset' in training_params["transfer_learning"]: reset = training_params["transfer_learning"]['reset'] else: reset = True # Freeze first layers and reset last layers model = imed_models.set_model_for_retrain(old_model_path, retrain_fraction=fraction, map_location=device, reset=reset) else: logger.info("Initialising model's weights from scratch.") model_class = getattr(imed_models, model_params[ModelParamsKW.NAME]) model = model_class(**model_params) if cuda_available: model.cuda() num_epochs = training_params["training_time"]["num_epochs"] # OPTIMIZER initial_lr = training_params["scheduler"]["initial_lr"] # filter out the parameters you are going to fine-tuning params_to_opt = filter(lambda p: p.requires_grad, model.parameters()) # Using Adam optimizer = optim.Adam(params_to_opt, lr=initial_lr) scheduler, step_scheduler_batch = get_scheduler( copy.copy(training_params["scheduler"]["lr_scheduler"]), optimizer, num_epochs) logger.info("Scheduler parameters: {}".format( training_params["scheduler"]["lr_scheduler"])) # Only call wandb methods if required params are found in the config file if wandb_tracking: # Logs gradients
# Section: finish the wandb.watch call; resume from "checkpoint.pth.tar" when
# `resume_training` is set (restoring model/optimizer/gif_dict/epoch/scheduler/patience
# and moving restored optimizer tensors to `device`); select the loss function from
# `training_params["loss"]` plus a DiceLoss kept for comparison logging; initialise the
# best-score trackers to +inf and start the epoch loop — per-epoch LR logging
# (tensorboard and optionally wandb), then the training batch loop which moves inputs/GT
# to device (HeMIS inputs are unstacked per-contrast first) and begins mixup.
(at every log_freq steps) to the dashboard. wandb.watch(model, log="gradients", log_freq=wandb_params["log_grads_every"]) # Resume start_epoch = 1 resume_path = Path(path_output, "checkpoint.pth.tar") if resume_training: model, optimizer, gif_dict, start_epoch, val_loss_total_avg, scheduler, patience_count = load_checkpoint( model=model, optimizer=optimizer, gif_dict=gif_dict, scheduler=scheduler, fname=str(resume_path)) # Individually transfer the optimizer parts # TODO: check if following lines are needed for state in optimizer.state.values(): for k, v in state.items(): if torch.is_tensor(v): state[k] = v.to(device) # LOSS logger.info("Selected Loss: {}".format(training_params["loss"]["name"])) logger.info("\twith the parameters: {}".format([ training_params["loss"][k] for k in training_params["loss"] if k != "name" ])) loss_fct = get_loss_function(copy.copy(training_params["loss"])) loss_dice_fct = imed_losses.DiceLoss( ) # For comparison when another loss is used # INIT TRAINING VARIABLES best_training_dice, best_training_loss = float("inf"), float("inf") best_validation_loss, best_validation_dice = float("inf"), float("inf") patience_count = 0 begin_time = time.time() # EPOCH LOOP for epoch in tqdm(range(num_epochs), desc="Training", initial=start_epoch): epoch = epoch + start_epoch start_time = time.time() lr = scheduler.get_last_lr()[0] writer.add_scalar('learning_rate', lr, epoch) if wandb_tracking: wandb.log({"learning_rate": lr}) # Training loop ----------------------------------------------------------- model.train() train_loss_total, train_dice_loss_total = 0.0, 0.0 num_steps = 0 for i, batch in enumerate(train_loader): # GET SAMPLES if model_params[ModelParamsKW.NAME] == ConfigKW.HEMIS_UNET: input_samples = imed_utils.cuda( imed_utils.unstack_tensors(batch["input"]), cuda_available) else: input_samples = imed_utils.cuda(batch["input"], cuda_available) gt_samples = imed_utils.cuda(batch["gt"], cuda_available, non_blocking=True) # MIXUP if
# Section: training step body — optional mixup; forward pass (HeMIS/FiLM models also
# receive metadata from `get_metadata`); accumulate the configured loss and the Dice
# loss; backprop + optimizer step (scheduler stepped per-batch or per-epoch depending
# on `step_scheduler_batch`); debug-image dump every 50th step; epoch training-loss
# averaging and logging; HeMIS curriculum update (missing-modality probability raised
# to the growth power, then `dataset_train.update(p=...)` — presumably re-samples which
# contrasts are dropped, TODO confirm against the dataset class); then `model.eval()`
# and entry into the validation batch loop under `torch.no_grad()`.
training_params["mixup_alpha"]: input_samples, gt_samples = imed_mixup.mixup( input_samples, gt_samples, training_params["mixup_alpha"], debugging and epoch == 1, path_output) # RUN MODEL if model_params[ModelParamsKW.NAME] == ConfigKW.HEMIS_UNET or \ (ModelParamsKW.FILM_LAYERS in model_params and any(model_params[ModelParamsKW.FILM_LAYERS])): metadata = get_metadata(batch[MetadataKW.INPUT_METADATA], model_params) preds = model(input_samples, metadata) else: preds = model(input_samples) # LOSS loss = loss_fct(preds, gt_samples) train_loss_total += loss.item() train_dice_loss_total += loss_dice_fct(preds, gt_samples).item() # UPDATE OPTIMIZER optimizer.zero_grad() loss.backward() optimizer.step() if step_scheduler_batch: scheduler.step() num_steps += 1 # Save image at every 50th step if debugging is true if i % 50 == 0 and debugging: imed_visualize.save_img( writer, epoch, "Train", input_samples, gt_samples, preds, wandb_tracking=wandb_tracking, is_three_dim=not model_params[ModelParamsKW.IS_2D]) if not step_scheduler_batch: scheduler.step() # TRAINING LOSS train_loss_total_avg = train_loss_total / num_steps msg = "Epoch {} training loss: {:.4f}.".format(epoch, train_loss_total_avg) train_dice_loss_total_avg = train_dice_loss_total / num_steps if training_params["loss"]["name"] != "DiceLoss": msg += "\tDice training loss: {:.4f}.".format( train_dice_loss_total_avg) logger.info(msg) tqdm.write(msg) # CURRICULUM LEARNING if model_params[ModelParamsKW.NAME] == ConfigKW.HEMIS_UNET: # Increase the probability of a missing modality model_params[ModelParamsKW.MISSING_PROBABILITY] **= model_params[ ModelParamsKW.MISSING_PROBABILITY_GROWTH] dataset_train.update( p=model_params[ModelParamsKW.MISSING_PROBABILITY]) # Validation loop ----------------------------------------------------- model.eval() val_loss_total, val_dice_loss_total = 0.0, 0.0 num_steps = 0 metric_mgr = imed_metrics.MetricManager(metric_fns) if dataset_val: for i, batch in enumerate(val_loader): with
# Section: validation step body — same sample prep and forward pass as training but
# without backprop; accumulate validation losses; append an overlay frame (image +
# prediction) to any GIF whose recorded file path and slice index match this sample's
# metadata; feed numpy predictions/GT to the metric manager; debug-image dump every
# 50th step (the inline comment says "10th" — comment/code mismatch in the original);
# after the loop, stash last epoch's average validation loss for the early-stopping
# delta, collect metrics, and start writing tensorboard scalars.
torch.no_grad(): # GET SAMPLES if model_params[ModelParamsKW.NAME] == ConfigKW.HEMIS_UNET: input_samples = imed_utils.cuda( imed_utils.unstack_tensors(batch["input"]), cuda_available) else: input_samples = imed_utils.cuda( batch["input"], cuda_available) gt_samples = imed_utils.cuda(batch["gt"], cuda_available, non_blocking=True) # RUN MODEL if model_params[ModelParamsKW.NAME] == ConfigKW.HEMIS_UNET or \ (ModelParamsKW.FILM_LAYERS in model_params and any(model_params[ModelParamsKW.FILM_LAYERS])): metadata = get_metadata( batch[MetadataKW.INPUT_METADATA], model_params) preds = model(input_samples, metadata) else: preds = model(input_samples) # LOSS loss = loss_fct(preds, gt_samples) val_loss_total += loss.item() val_dice_loss_total += loss_dice_fct(preds, gt_samples).item() # Add frame to GIF for i_ in range(len(input_samples)): im, pr, met = input_samples[i_].cpu().numpy()[0], preds[i_].cpu().numpy()[0], \ batch[MetadataKW.INPUT_METADATA][i_][0] for i_gif in range(n_gif): if gif_dict["image_path"][i_gif] == met.__getitem__('input_filenames') and \ gif_dict["slice_id"][i_gif] == met.__getitem__('slice_index'): overlap = imed_visualize.overlap_im_seg(im, pr) gif_dict["gif"][i_gif].add(overlap, label=str(epoch)) num_steps += 1 # METRICS COMPUTATION gt_npy = gt_samples.cpu().numpy() preds_npy = preds.data.cpu().numpy() metric_mgr(preds_npy, gt_npy) # Save image at every 10th step if debugging is true if i % 50 == 0 and debugging: imed_visualize.save_img( writer, epoch, "Validation", input_samples, gt_samples, preds, wandb_tracking=wandb_tracking, is_three_dim=not model_params[ModelParamsKW.IS_2D]) # METRICS COMPUTATION FOR CURRENT EPOCH val_loss_total_avg_old = val_loss_total_avg if epoch > 1 else None metrics_dict = metric_mgr.get_results() metric_mgr.reset() val_loss_total_avg = val_loss_total / num_steps # log losses on Tensorboard by default writer.add_scalars('Validation/Metrics', metrics_dict, epoch) writer.add_scalars( 'losses', { 'train_loss':
# Section: finish the loss-scalar logging (tensorboard + optional wandb); log epoch
# validation loss and wall-clock time; when validation loss improves, save a resumable
# checkpoint dict (epoch+1, state_dicts, gif_dict, scheduler, patience, val loss) and
# the whole model as "best_model.pt", and update the best loss/dice trackers; early
# stopping — when the relative val-loss improvement falls below
# `early_stopping_epsilon`, bump `patience_count` and break out of the epoch loop once
# it reaches `early_stopping_patience`; after the loop, save "final_model.pt" and begin
# reloading the best checkpoint's weights if one was written.
train_loss_total_avg, 'val_loss': val_loss_total_avg, }, epoch) # log on wandb if the corresponding dictionary is provided if wandb_tracking: wandb.log({"validation-metrics": metrics_dict}) wandb.log({ "losses": { 'train_loss': train_loss_total_avg, 'val_loss': val_loss_total_avg, } }) msg = "Epoch {} validation loss: {:.4f}.".format( epoch, val_loss_total_avg) val_dice_loss_total_avg = val_dice_loss_total / num_steps if training_params["loss"]["name"] != "DiceLoss": msg += "\tDice validation loss: {:.4f}.".format( val_dice_loss_total_avg) logger.info(msg) end_time = time.time() total_time = end_time - start_time msg_epoch = "Epoch {} took {:.2f} seconds.".format( epoch, total_time) logger.info(msg_epoch) # UPDATE BEST RESULTS if val_loss_total_avg < best_validation_loss: # Save checkpoint state = { 'epoch': epoch + 1, 'state_dict': model.state_dict(), 'optimizer': optimizer.state_dict(), 'gif_dict': gif_dict, 'scheduler': scheduler, 'patience_count': patience_count, 'validation_loss': val_loss_total_avg } torch.save(state, resume_path) # Save best model file model_path = Path(path_output, "best_model.pt") torch.save(model, model_path) # Update best scores best_validation_loss, best_training_loss = val_loss_total_avg, train_loss_total_avg best_validation_dice, best_training_dice = val_dice_loss_total_avg, train_dice_loss_total_avg # EARLY STOPPING if epoch > 1: val_diff = (val_loss_total_avg_old - val_loss_total_avg) * 100 / abs(val_loss_total_avg) if val_diff < training_params["training_time"][ "early_stopping_epsilon"]: patience_count += 1 if patience_count >= training_params["training_time"][ "early_stopping_patience"]: logger.info( "Stopping training due to {} epochs without improvements" .format(patience_count)) break # Save final model final_model_path = Path(path_output, "final_model.pt") torch.save(model, final_model_path) # Save best model in output path if resume_path.is_file(): state = torch.load(resume_path) model_path = Path(path_output,
# Section: restore the best state_dict into `model` and re-save "best_model.pt";
# best-effort ONNX export into the model folder (uses `input_samples` left over from
# the last batch as the example input; failures are logged, not raised) plus a .pt
# export; write all accumulated GIFs to <path_output>/gifs with names derived from the
# BIDS-style image path and slice id; close the tensorboard writer, finish wandb, log
# the total run duration, and return the best training/validation dice and losses.
"best_model.pt") model.load_state_dict(state['state_dict']) torch.save(model, model_path) # Save best model as ONNX in the model directory try: # Convert best model to ONNX and save it in model directory best_model_path = Path( path_output, model_params[ModelParamsKW.FOLDER_NAME], model_params[ModelParamsKW.FOLDER_NAME] + ".onnx") imed_utils.save_onnx_model(model, input_samples, str(best_model_path)) logger.info(f"Model saved as '.onnx': {best_model_path}") except Exception as e: logger.warning(f"Failed to save the model as '.onnx': {e}") # Save best model as PT in the model directory best_model_path = Path(path_output, model_params[ModelParamsKW.FOLDER_NAME], model_params[ModelParamsKW.FOLDER_NAME] + ".pt") torch.save(model, best_model_path) logger.info(f"Model saved as '.pt': {best_model_path}") # Save GIFs gif_folder = Path(path_output, "gifs") if n_gif > 0 and not gif_folder.is_dir(): gif_folder.mkdir(parents=True) for i_gif in range(n_gif): fname_out = gif_dict["image_path"][i_gif].split(os.sep)[-3] + "__" fname_out += gif_dict["image_path"][i_gif].split( os.sep)[-1].split(".nii.gz")[0].split( gif_dict["image_path"][i_gif].split(os.sep)[-3] + "_")[1] + "__" fname_out += str(gif_dict["slice_id"][i_gif]) + ".gif" path_gif_out = Path(gif_folder, fname_out) gif_dict["gif"][i_gif].save(str(path_gif_out)) writer.close() wandb.finish() final_time = time.time() duration_time = final_time - begin_time logger.info('begin ' + time.strftime('%H:%M:%S', time.localtime(begin_time)) + "| End " + time.strftime('%H:%M:%S', time.localtime(final_time)) + "| duration " + str(datetime.timedelta(seconds=duration_time))) return best_training_dice, best_training_loss, best_validation_dice, best_validation_loss