def train_unet(args):
    """
    Wrapper for reconstruction (U-Net) model training.

    :param args: Arguments object, containing training hyperparameters.
    """
    args.exp_dir.mkdir(parents=True, exist_ok=True)
    writer = SummaryWriter(log_dir=args.exp_dir / 'summary')

    best_dev_loss = 1e9
    start_epoch = 0
    if args.resume:
        model, args, start_epoch, optimizer = load_recon_model(
            args.recon_model_checkpoint, optim=True)
    else:
        model = build_reconstruction_model(args)
        if args.data_parallel:
            model = torch.nn.DataParallel(model)
        optimizer = build_optim(args, model.parameters())

    logging.info(args)
    logging.info(model)

    # Save arguments for bookkeeping
    args_dict = {
        key: str(value) for key, value in args.__dict__.items()
        if not key.startswith('__') and not callable(key)
    }
    save_json(args.exp_dir / 'args.json', args_dict)

    train_loader = create_data_loader(args, 'train', shuffle=True)
    dev_loader = create_data_loader(args, 'val')
    display_loader = create_data_loader(args, 'val', display=True)

    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, args.lr_step_size, args.lr_gamma)

    for epoch in range(start_epoch, args.num_epochs):
        train_loss, train_time = train_epoch(args, epoch, model, train_loader, optimizer, writer)
        dev_loss, dev_l1loss, dev_time = evaluate_loss(args, epoch, model, dev_loader, writer)
        visualize(args, epoch, model, display_loader, writer)
        scheduler.step()

        is_new_best = dev_loss < best_dev_loss
        best_dev_loss = min(best_dev_loss, dev_loss)
        save_model(args, args.exp_dir, epoch, model, optimizer, best_dev_loss, is_new_best)
        logging.info(
            f'Epoch = [{epoch:4d}/{args.num_epochs:4d}] TrainL1Loss = {train_loss:.4g} DevL1Loss = {dev_l1loss:.4g} '
            f'DevLoss = {dev_loss:.4g} TrainTime = {train_time:.4f}s DevTime = {dev_time:.4f}s'
        )
    writer.close()
def run_unet(args):
    """
    Creates reconstructions of volumes in a dataset, and stores these to disk.

    :param args: Arguments object, containing evaluation hyperparameters.
    """
    recon_args, model = load_recon_model(args)
    recon_args.data_path = args.data_path  # in case the model was trained on a different machine
    data_loader = create_data_loader(recon_args, args.partition)

    model.eval()
    reconstructions = defaultdict(list)
    with torch.no_grad():
        for _, _, _, input, _, gt_mean, gt_std, fnames, slices in data_loader:
            input = input.to(args.device)
            recons = model(input).squeeze(1).to('cpu')
            for i in range(recons.shape[0]):
                # Undo the normalisation applied by the data loader
                recons[i] = recons[i] * gt_std[i] + gt_mean[i]
                reconstructions[fnames[i]].append((slices[i].numpy(), recons[i].numpy()))

    # Stack slices per volume, sorted by slice index
    reconstructions = {
        fname: np.stack([pred for _, pred in sorted(slice_preds)])
        for fname, slice_preds in reconstructions.items()
    }
    args.predictions_path = args.recon_model_checkpoint.parent / 'reconstructions'
    save_reconstructions(reconstructions, args.predictions_path)
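# Hedged usage sketch (not part of the original pipeline): reading back one volume saved by
# run_unet. This assumes save_reconstructions follows the fastMRI convention of writing one
# HDF5 file per volume containing a 'reconstruction' dataset of shape (slices, height, width);
# the helper name below is hypothetical.
def _example_load_reconstruction(predictions_path, fname):
    import h5py
    with h5py.File(predictions_path / fname, 'r') as hf:
        # Slices are stacked in ascending slice-index order by run_unet
        return hf['reconstruction'][()]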
def test(args, recon_model):
    """
    Performs evaluation of a pre-trained policy model.

    :param args: Argument object containing evaluation parameters.
    :param recon_model: reconstruction model.
    """
    model, policy_args = load_policy_model(pathlib.Path(args.policy_model_checkpoint))

    # Overwrite number of trajectories to test on
    policy_args.num_test_trajectories = args.num_test_trajectories
    if args.data_path is not None:  # Overwrite data path if provided
        policy_args.data_path = args.data_path

    # Logging of policy model
    logging.info(args)
    logging.info(recon_model)
    logging.info(model)
    if args.wandb:
        wandb.config.update(args)
        wandb.watch(model, log='all')

    # Initialise summary writer
    writer = SummaryWriter(log_dir=policy_args.run_dir / 'summary')

    # Parameter counting
    logging.info('Reconstruction model parameters: total {}, of which {} trainable and {} untrainable'.format(
        count_parameters(recon_model), count_trainable_parameters(recon_model),
        count_untrainable_parameters(recon_model)))
    logging.info('Policy model parameters: total {}, of which {} trainable and {} untrainable'.format(
        count_parameters(model), count_trainable_parameters(model),
        count_untrainable_parameters(model)))

    # Create data loader
    test_loader = create_data_loader(policy_args, 'test', shuffle=False)
    test_data_range_dict = create_data_range_dict(policy_args, test_loader)

    do_and_log_evaluation(policy_args, -1, recon_model, model, test_loader, writer,
                          'Test', test_data_range_dict)

    writer.close()
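# Hedged sketch (assumption, not the original helpers): plausible implementations of the
# parameter-counting utilities referenced in the logging above, shown only to clarify what the
# reported totals mean. The actual count_parameters / count_trainable_parameters /
# count_untrainable_parameters helpers may differ; the names below are hypothetical.
def _example_count_parameters(model):
    return sum(p.numel() for p in model.parameters())


def _example_count_trainable_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


def _example_count_untrainable_parameters(model):
    return sum(p.numel() for p in model.parameters() if not p.requires_grad)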
def compute_gradients(args, epoch):
    param_dir = (f'epoch{epoch}_t{args.num_trajectories}'
                 f'_runs{args.data_runs}_batch{args.batch_size}_bs{args.batches_step}')
    param_dir = args.policy_model_checkpoint.parent / param_dir
    param_dir.mkdir(parents=True, exist_ok=True)
    # Create storage paths
    weight_path = param_dir / f'weight_grads_r{args.data_runs}.pkl'
    bias_path = param_dir / f'bias_grads_r{args.data_runs}.pkl'

    # Check if already computed (skip computing again unless args.force_computation is set)
    if weight_path.exists() and bias_path.exists() and not args.force_computation:
        print(f'Gradients already stored in: \n {weight_path}\n {bias_path}')
        return weight_path, bias_path, param_dir
    else:
        print('Exact job gradients not already stored. Checking same params but higher number of runs...')
        # Check if all gradients are already stored in a file for more runs
        for r in range(1, 11, 1):  # Check up to 10 runs
            tmp_param_dir = (f'epoch{epoch}_t{args.num_trajectories}'
                             f'_runs{r}_batch{args.batch_size}_bs{args.batches_step}')
            tmp_weight_path = args.policy_model_checkpoint.parent / tmp_param_dir / f'weight_grads_r{r}.pkl'
            tmp_bias_path = args.policy_model_checkpoint.parent / tmp_param_dir / f'bias_grads_r{r}.pkl'
            # If the computation is already stored for a higher number of runs, grab the relevant part
            # rather than recomputing.
            if tmp_weight_path.exists() and tmp_bias_path.exists() and not args.force_computation:
                print(f'Gradients up to run {r} already stored in: \n {tmp_weight_path}\n {tmp_bias_path}')
                with open(tmp_weight_path, 'rb') as f:
                    full_weight_grads = pickle.load(f)
                with open(tmp_bias_path, 'rb') as f:
                    full_bias_grads = pickle.load(f)
                # Get the relevant part for the number of runs requested
                assert len(full_weight_grads) % r == 0, 'Something went wrong with stored gradient shape.'
                grads_per_run = len(full_weight_grads) // r
                weight_grads = full_weight_grads[:grads_per_run * args.data_runs]
                bias_grads = full_bias_grads[:grads_per_run * args.data_runs]
                print(f" Saving only grads of run {args.data_runs} to: \n {param_dir}")
                with open(weight_path, 'wb') as f:
                    pickle.dump(weight_grads, f)
                with open(bias_path, 'wb') as f:
                    pickle.dump(bias_grads, f)
                return weight_path, bias_path, param_dir

    start_run = 0
    weight_grads = []
    bias_grads = []
    # Check if some part of the gradients has already been computed
    for r in range(args.data_runs, 0, -1):
        tmp_param_dir = (f'epoch{epoch}_t{args.num_trajectories}'
                         f'_runs{r}_batch{args.batch_size}_bs{args.batches_step}')
        tmp_weight_path = args.policy_model_checkpoint.parent / tmp_param_dir / f'weight_grads_r{r}.pkl'
        tmp_bias_path = args.policy_model_checkpoint.parent / tmp_param_dir / f'bias_grads_r{r}.pkl'
        # If a part is already computed, skip that part of the computation by setting start_run to the
        # highest computed run. Also load the stored gradients.
        if tmp_weight_path.exists() and tmp_bias_path.exists() and not args.force_computation:
            print(f'Gradients up to run {r} already stored in: \n {tmp_weight_path}\n {tmp_bias_path}')
            with open(tmp_weight_path, 'rb') as f:
                weight_grads = pickle.load(f)
            with open(tmp_bias_path, 'rb') as f:
                bias_grads = pickle.load(f)
            start_run = r
            break

    model, policy_args, start_epoch, optimiser = load_policy_model(args.policy_model_checkpoint, optim=True)
    add_base_args(args, policy_args)
    recon_args, recon_model = load_recon_model(policy_args)

    loader = create_data_loader(policy_args, 'train', shuffle=True)
    data_range_dict = create_data_range_dict(policy_args, loader)

    for r in range(start_run, args.data_runs):
        print(f"\n Run {r + 1} ...")
        cbatch = 0
        tbs = 0
        for it, data in enumerate(loader):  # Randomly shuffled every time
            kspace, masked_kspace, mask, zf, gt, gt_mean, gt_std, fname, sl_idx = data
            cbatch += 1
            tbs += mask.size(0)
            # shape after unsqueeze = batch x channel x columns x rows x complex
            kspace = kspace.unsqueeze(1).to(policy_args.device)
            masked_kspace = masked_kspace.unsqueeze(1).to(policy_args.device)
            mask = mask.unsqueeze(1).to(policy_args.device)
            # shape after unsqueeze = batch x channel x columns x rows
            zf = zf.unsqueeze(1).to(policy_args.device)
            gt = gt.unsqueeze(1).to(policy_args.device)
            gt_mean = gt_mean.unsqueeze(1).unsqueeze(2).unsqueeze(3).to(policy_args.device)
            gt_std = gt_std.unsqueeze(1).unsqueeze(2).unsqueeze(3).to(policy_args.device)
            unnorm_gt = gt * gt_std + gt_mean
            data_range = torch.stack([data_range_dict[vol] for vol in fname])
            recons = recon_model(zf)

            if cbatch == 1:
                optimiser.zero_grad()

            action_list = []
            logprob_list = []
            reward_list = []
            for step in range(policy_args.acquisition_steps):  # Loop over acquisition steps
                loss, mask, masked_kspace, recons = compute_backprop_trajectory(
                    policy_args, kspace, masked_kspace, mask, unnorm_gt, recons, gt_mean, gt_std,
                    data_range, model, recon_model, step, action_list, logprob_list, reward_list)

            if cbatch == policy_args.batches_step:
                # Store gradients of the last layer's weights and biases for SNR computation
                num = len(list(model.named_parameters()))
                for i, (name, param) in enumerate(model.named_parameters()):
                    if i == num - 1:  # biases of last layer
                        bias_grads.append(param.grad.cpu().numpy())
                    elif i == num - 2:  # weights of last layer
                        weight_grads.append(param.grad.cpu().numpy())
                cbatch = 0

        print(f" - Adding grads of run {r + 1} to: \n {param_dir}")
        with open(weight_path, 'wb') as f:
            pickle.dump(weight_grads, f)
        with open(bias_path, 'wb') as f:
            pickle.dump(bias_grads, f)

    return weight_path, bias_path, param_dir
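# Hedged post-processing sketch (not part of the original module): the pickle files written by
# compute_gradients hold one gradient array per stored accumulation step. A simple per-element
# gradient SNR estimate can be formed as |mean| / std over those samples; the exact aggregation
# used downstream is an assumption here, and the function name is hypothetical.
def _example_gradient_snr(weight_path, bias_path, eps=1e-12):
    import pickle
    import numpy as np
    with open(weight_path, 'rb') as f:
        weight_grads = np.stack(pickle.load(f))  # shape: (num_samples, *weight_shape)
    with open(bias_path, 'rb') as f:
        bias_grads = np.stack(pickle.load(f))  # shape: (num_samples, *bias_shape)
    weight_snr = np.abs(weight_grads.mean(axis=0)) / (weight_grads.std(axis=0) + eps)
    bias_snr = np.abs(bias_grads.mean(axis=0)) / (bias_grads.std(axis=0) + eps)
    return weight_snr, bias_snr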
def train_and_eval(args, recon_args, recon_model):
    if args.resume:  # Check that this works
        resumed = True
        new_run_dir = args.policy_model_checkpoint.parent
        data_path = args.data_path
        # In case models have been moved to a different machine, make sure the path to the recon
        # model is the path provided.
        recon_model_checkpoint = args.recon_model_checkpoint

        model, args, start_epoch, optimiser = load_policy_model(
            pathlib.Path(args.policy_model_checkpoint), optim=True)

        args.old_run_dir = args.run_dir
        args.old_recon_model_checkpoint = args.recon_model_checkpoint
        args.old_data_path = args.data_path

        args.recon_model_checkpoint = recon_model_checkpoint
        args.run_dir = new_run_dir
        args.data_path = data_path
        args.resume = True
    else:
        resumed = False
        # Improvement model to train
        model = build_policy_model(args)
        # Add mask parameters for training
        args = add_mask_params(args)
        if args.data_parallel:
            model = torch.nn.DataParallel(model)
        optimiser = build_optim(args, model.parameters())
        start_epoch = 0
        # Create directory to store results in
        savestr = '{}_res{}_al{}_accel{}_k{}_{}_{}'.format(
            args.dataset, args.resolution, args.acquisition_steps, args.accelerations,
            args.num_trajectories, datetime.datetime.now().strftime("%Y-%m-%d_%H:%M:%S"),
            ''.join(choice(ascii_uppercase) for _ in range(5)))
        args.run_dir = args.exp_dir / savestr
        args.run_dir.mkdir(parents=True, exist_ok=False)

    args.resumed = resumed

    if args.wandb:
        allow_val_change = args.resumed  # only allow changes if resumed: otherwise something is wrong
        wandb.config.update(args, allow_val_change=allow_val_change)
        wandb.watch(model, log='all')

    # Logging
    logging.info(recon_model)
    logging.info(model)

    # Save arguments for bookkeeping
    args_dict = {
        key: str(value) for key, value in args.__dict__.items()
        if not key.startswith('__') and not callable(key)
    }
    save_json(args.run_dir / 'args.json', args_dict)

    # Initialise summary writer
    writer = SummaryWriter(log_dir=args.run_dir / 'summary')

    # Parameter counting
    logging.info('Reconstruction model parameters: total {}, of which {} trainable and {} untrainable'.format(
        count_parameters(recon_model), count_trainable_parameters(recon_model),
        count_untrainable_parameters(recon_model)))
    logging.info('Policy model parameters: total {}, of which {} trainable and {} untrainable'.format(
        count_parameters(model), count_trainable_parameters(model),
        count_untrainable_parameters(model)))

    if args.scheduler_type == 'step':
        scheduler = torch.optim.lr_scheduler.StepLR(optimiser, args.lr_step_size, args.lr_gamma)
    elif args.scheduler_type == 'multistep':
        if not isinstance(args.lr_multi_step_size, list):
            args.lr_multi_step_size = [args.lr_multi_step_size]
        scheduler = torch.optim.lr_scheduler.MultiStepLR(optimiser, args.lr_multi_step_size, args.lr_gamma)
    else:
        raise ValueError("{} is not a valid scheduler choice ('step', 'multistep')".format(args.scheduler_type))

    # Create data loaders
    train_loader = create_data_loader(args, 'train', shuffle=True)
    dev_loader = create_data_loader(args, 'val', shuffle=False)

    train_data_range_dict = create_data_range_dict(args, train_loader)
    dev_data_range_dict = create_data_range_dict(args, dev_loader)

    if not args.resume:
        if args.do_train_ssim:
            do_and_log_evaluation(args, -1, recon_model, model, train_loader, writer,
                                  'Train', train_data_range_dict)
        do_and_log_evaluation(args, -1, recon_model, model, dev_loader, writer,
                              'Val', dev_data_range_dict)

    for epoch in range(start_epoch, args.num_epochs):
        train_loss, train_time = train_epoch(args, epoch, recon_model, model, train_loader, optimiser,
                                             writer, train_data_range_dict)
        logging.info(
            f'Epoch = [{epoch+1:3d}/{args.num_epochs:3d}] TrainLoss = {train_loss:.3g} TrainTime = {train_time:.2f}s '
        )

        if args.do_train_ssim:
            do_and_log_evaluation(args, epoch, recon_model, model, train_loader, writer,
                                  'Train', train_data_range_dict)
        do_and_log_evaluation(args, epoch, recon_model, model, dev_loader, writer,
                              'Val', dev_data_range_dict)

        scheduler.step()
        save_policy_model(args, args.run_dir, epoch, model, optimiser)
    writer.close()
def main(args):
    """
    Wrapper for running baseline models.

    :param args: Arguments object containing hyperparameters for baseline models.
    """
    # For consistency
    args.val_batch_size = args.batch_size

    # Reconstruction model
    recon_args, recon_model = load_recon_model(args)
    # Add mask parameters for training
    args = add_mask_params(args)

    # Create directory to store results in
    savestr = '{}_res{}_al{}_accel{}_{}_{}_{}'.format(
        args.dataset, args.resolution, args.acquisition_steps, args.accelerations,
        args.model_type, datetime.datetime.now().strftime("%Y-%m-%d_%H:%M:%S"),
        ''.join(choice(ascii_uppercase) for _ in range(5)))
    args.run_dir = args.exp_dir / savestr
    args.run_dir.mkdir(parents=True, exist_ok=False)

    if args.wandb:
        wandb.config.update(args)

    # Logging
    logging.info(args)
    logging.info(recon_model)
    logging.info('Model type: {}'.format(args.model_type))

    # Save arguments for bookkeeping
    args_dict = {
        key: str(value) for key, value in args.__dict__.items()
        if not key.startswith('__') and not callable(key)
    }
    save_json(args.run_dir / 'args.json', args_dict)

    # Initialise summary writer
    writer = SummaryWriter(log_dir=args.run_dir / 'summary')

    if args.model_type == 'average_oracle':
        baseline_ssims, baseline_psnrs, baseline_time = run_average_oracle(args, recon_model)
    else:
        # Create data loader
        loader = create_data_loader(args, args.partition)
        data_range_dict = create_data_range_dict(args, loader)
        baseline_ssims, baseline_psnrs, baseline_time = run_baseline(args, recon_model, loader, data_range_dict)

    # Logging
    ssims_str = ", ".join(["{}: {:.4f}".format(i, l) for i, l in enumerate(baseline_ssims)])
    psnrs_str = ", ".join(["{}: {:.4f}".format(i, l) for i, l in enumerate(baseline_psnrs)])
    logging.info(f' SSIM = [{ssims_str}]')
    logging.info(f' PSNR = [{psnrs_str}]')
    logging.info(f' Time = {baseline_time:.2f}s')

    # For storing in wandb
    for epoch in range(args.num_epochs + 1):
        if args.wandb:
            wandb.log({f'{args.partition}_ssims': {str(key): val for key, val in enumerate(baseline_ssims)}},
                      step=epoch)
            wandb.log({f'{args.partition}_psnrs': {str(key): val for key, val in enumerate(baseline_psnrs)}},
                      step=epoch)

    writer.close()