예제 #1
0
def train_unet(args):
    """
    Wrapper for reconstruction (U-Net) model training.

    Either builds a fresh model and optimiser from ``args`` or, when
    ``args.resume`` is set, restores model, arguments, start epoch and
    optimiser from ``args.recon_model_checkpoint``. It then runs the
    train / evaluate / visualise loop from ``start_epoch`` up to
    ``args.num_epochs`` and calls ``save_model`` after every epoch.

    :param args: Arguments object, containing training hyperparameters.
    """
    args.exp_dir.mkdir(parents=True, exist_ok=True)
    writer = SummaryWriter(log_dir=args.exp_dir / 'summary')

    try:
        # Initialised before branching so it is defined on the resume path too.
        # NOTE(review): the best dev loss of an interrupted run is not restored
        # here, so the first epoch after resuming is always flagged as new best
        # -- confirm whether the checkpoint stores it.
        best_dev_loss = 1e9
        if args.resume:
            # `args` is rebound to the arguments returned by the loader
            # (presumably the ones stored in the checkpoint).
            model, args, start_epoch, optimizer = load_recon_model(
                args.recon_model_checkpoint, optim=True)
        else:
            model = build_reconstruction_model(args)
            if args.data_parallel:
                model = torch.nn.DataParallel(model)
            optimizer = build_optim(args, model.parameters())
            start_epoch = 0
        logging.info(args)
        logging.info(model)

        # Save arguments for bookkeeping. Callable values are skipped; every
        # other value is stringified so it can be written as JSON.
        args_dict = {
            key: str(value)
            for key, value in args.__dict__.items()
            if not key.startswith('__') and not callable(value)
        }
        save_json(args.exp_dir / 'args.json', args_dict)

        train_loader = create_data_loader(args, 'train', shuffle=True)
        dev_loader = create_data_loader(args, 'val')
        display_loader = create_data_loader(args, 'val', display=True)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, args.lr_step_size,
                                                    args.lr_gamma)

        for epoch in range(start_epoch, args.num_epochs):
            train_loss, train_time = train_epoch(args, epoch, model, train_loader,
                                                 optimizer, writer)
            dev_loss, dev_l1loss, dev_time = evaluate_loss(args, epoch, model,
                                                           dev_loader, writer)
            visualize(args, epoch, model, display_loader, writer)
            scheduler.step()

            is_new_best = dev_loss < best_dev_loss
            best_dev_loss = min(best_dev_loss, dev_loss)
            save_model(args, args.exp_dir, epoch, model, optimizer, best_dev_loss,
                       is_new_best)
            logging.info(
                f'Epoch = [{epoch:4d}/{args.num_epochs:4d}] TrainL1Loss = {train_loss:.4g} DevL1Loss = {dev_l1loss:.4g} '
                f'DevLoss = {dev_loss:.4g} TrainTime = {train_time:.4f}s DevTime = {dev_time:.4f}s',
            )
    finally:
        # Close the summary writer even when training is interrupted.
        writer.close()
예제 #2
0
def run_unet(args):
    """
    Creates reconstructions of volumes in a dataset, and stores these to disk.

    The reconstruction model is run on every batch of the ``args.partition``
    data loader, its outputs are un-normalised with the per-slice ``gt_std`` /
    ``gt_mean`` delivered by the loader, and the slices of each file are
    stacked in order of slice index. The result is written to a
    ``reconstructions`` directory next to ``args.recon_model_checkpoint``;
    that path is also stored on ``args.predictions_path``.

    :param args: Arguments object, containing evaluation hyperparameters.
    """
    recon_args, model = load_recon_model(args)
    recon_args.data_path = args.data_path  # in case model was trained on different machine
    data_loader = create_data_loader(recon_args, args.partition)

    model.eval()
    reconstructions = defaultdict(list)
    with torch.no_grad():
        # Only the network input, normalisation statistics, file names and
        # slice indices of each batch are needed here.
        for _, _, _, net_input, _, gt_mean, gt_std, fnames, slices in data_loader:
            net_input = net_input.to(args.device)
            recons = model(net_input).squeeze(1).to('cpu')
            for i in range(recons.shape[0]):
                # Undo the normalisation applied to the target.
                recons[i] = recons[i] * gt_std[i] + gt_mean[i]
                reconstructions[fnames[i]].append((int(slices[i]), recons[i].numpy()))

    # Order on the slice index alone: comparing whole (index, array) tuples
    # would fall through to an ambiguous array comparison on equal indices.
    reconstructions = {
        fname: np.stack([
            pred for _, pred in sorted(slice_preds, key=lambda pair: pair[0])
        ])
        for fname, slice_preds in reconstructions.items()
    }

    args.predictions_path = args.recon_model_checkpoint.parent / 'reconstructions'
    save_reconstructions(reconstructions, args.predictions_path)
예제 #3
0
def test(args, recon_model):
    """
    Evaluate a pre-trained policy model on the test partition.

    The policy model and its stored arguments are loaded from
    ``args.policy_model_checkpoint``; the number of test trajectories (and the
    data path, when one is given) are taken from ``args`` instead of from the
    stored arguments.

    :param args: Argument object containing evaluation parameters.
    :param recon_model: reconstruction model.
    """
    checkpoint_file = pathlib.Path(args.policy_model_checkpoint)
    model, policy_args = load_policy_model(checkpoint_file)

    # Settings of the current invocation take precedence over the stored ones.
    policy_args.num_test_trajectories = args.num_test_trajectories
    if args.data_path is not None:
        policy_args.data_path = args.data_path

    # Log the arguments and both networks.
    for logged in (args, recon_model, model):
        logging.info(logged)
    if args.wandb:
        wandb.config.update(args)
        wandb.watch(model, log='all')

    summary_dir = policy_args.run_dir / 'summary'
    writer = SummaryWriter(log_dir=summary_dir)

    # Report parameter counts of both networks.
    recon_total = count_parameters(recon_model)
    recon_trainable = count_trainable_parameters(recon_model)
    recon_frozen = count_untrainable_parameters(recon_model)
    logging.info(
        'Reconstruction model parameters: total {}, of which {} trainable and {} untrainable'
        .format(recon_total, recon_trainable, recon_frozen))

    policy_total = count_parameters(model)
    policy_trainable = count_trainable_parameters(model)
    policy_frozen = count_untrainable_parameters(model)
    logging.info(
        'Policy model parameters: total {}, of which {} trainable and {} untrainable'
        .format(policy_total, policy_trainable, policy_frozen))

    # Build the test data and run the evaluation; -1 is passed as the epoch.
    test_loader = create_data_loader(policy_args, 'test', shuffle=False)
    test_data_range_dict = create_data_range_dict(policy_args, test_loader)
    do_and_log_evaluation(policy_args, -1, recon_model, model, test_loader,
                          writer, 'Test', test_data_range_dict)

    writer.close()
예제 #4
0
def _gradient_store(args, epoch, runs):
    """Return (directory, weight file, bias file) used to store gradients of `runs` data runs."""
    store_dir = args.policy_model_checkpoint.parent / (
        f'epoch{epoch}_t{args.num_trajectories}'
        f'_runs{runs}_batch{args.batch_size}_bs{args.batches_step}')
    return (store_dir,
            store_dir / f'weight_grads_r{runs}.pkl',
            store_dir / f'bias_grads_r{runs}.pkl')


def compute_gradients(args, epoch):
    """
    Compute (or reuse) last-layer policy gradients for ``args.data_runs`` passes over the data.

    Results are pickled next to ``args.policy_model_checkpoint`` in a directory
    named after the epoch, number of trajectories, number of runs, batch size
    and batches per step. Unless ``args.force_computation`` is set, stored
    results are reused in this order:

    1. files for exactly this configuration are returned as they are;
    2. files of the same configuration with *more* runs (up to 10) are
       truncated to the requested number of runs;
    3. files with *fewer* runs are loaded and only the missing runs are
       computed.

    NOTE: the stored files are read with ``pickle``; only point this function
    at checkpoint directories you trust.

    :param args: Arguments object; reads ``num_trajectories``, ``data_runs``,
        ``batch_size``, ``batches_step``, ``policy_model_checkpoint`` and
        ``force_computation``.
    :param epoch: Epoch identifier, only used in the storage directory name.
    :return: Tuple ``(weight_path, bias_path, param_dir)``.
    """
    # Create storage path
    param_dir, weight_path, bias_path = _gradient_store(args, epoch, args.data_runs)
    param_dir.mkdir(parents=True, exist_ok=True)

    # Check if already computed (skip computing again if not args.force_computation)
    if weight_path.exists() and bias_path.exists() and not args.force_computation:
        print(f'Gradients already stored in: \n    {weight_path}\n    {bias_path}')
        return weight_path, bias_path, param_dir
    else:
        print('Exact job gradients not already stored. Checking same params but higher number of runs...')

    # Check if all gradients already stored in file for more runs. Only stores
    # with strictly more runs than requested can be truncated; a store with
    # fewer runs lacks part of the gradients and is handled further below.
    for r in range(args.data_runs + 1, 11):  # Check up to 10 runs
        _, tmp_weight_path, tmp_bias_path = _gradient_store(args, epoch, r)
        # If computation already stored for a higher number of runs, just grab the relevant bit and do not recompute.
        if tmp_weight_path.exists() and tmp_bias_path.exists() and not args.force_computation:
            print(f'Gradients up to run {r} already stored in: \n    {tmp_weight_path}\n    {tmp_bias_path}')
            with open(tmp_weight_path, 'rb') as f:
                full_weight_grads = pickle.load(f)
            with open(tmp_bias_path, 'rb') as f:
                full_bias_grads = pickle.load(f)

            # Get relevant bit for the number of runs requested
            assert len(full_weight_grads) % r == 0, 'Something went wrong with stored gradient shape.'
            grads_per_run = len(full_weight_grads) // r
            weight_grads = full_weight_grads[:grads_per_run * args.data_runs]
            bias_grads = full_bias_grads[:grads_per_run * args.data_runs]

            print(f" Saving only grads of run {args.data_runs} to: \n       {param_dir}")
            with open(weight_path, 'wb') as f:
                pickle.dump(weight_grads, f)
            with open(bias_path, 'wb') as f:
                pickle.dump(bias_grads, f)
            return weight_path, bias_path, param_dir

    start_run = 0
    weight_grads = []
    bias_grads = []

    # Check if some part of the gradients already computed
    for r in range(args.data_runs, 0, -1):
        _, tmp_weight_path, tmp_bias_path = _gradient_store(args, epoch, r)
        # If part already computed, skip this part of the computation by setting start_run to the highest
        # computed run. Also load the weights.
        if tmp_weight_path.exists() and tmp_bias_path.exists() and not args.force_computation:
            print(f'Gradients up to run {r} already stored in: \n    {tmp_weight_path}\n    {tmp_bias_path}')
            with open(tmp_weight_path, 'rb') as f:
                weight_grads = pickle.load(f)
            with open(tmp_bias_path, 'rb') as f:
                bias_grads = pickle.load(f)
            start_run = r
            break

    model, policy_args, _, optimiser = load_policy_model(args.policy_model_checkpoint)
    add_base_args(args, policy_args)
    _, recon_model = load_recon_model(policy_args)

    loader = create_data_loader(policy_args, 'train', shuffle=True)
    data_range_dict = create_data_range_dict(policy_args, loader)

    for r in range(start_run, args.data_runs):
        print(f"\n    Run {r + 1} ...")
        cbatch = 0  # number of batches seen since the last stored gradient
        for data in loader:  # Randomly shuffled every time
            kspace, masked_kspace, mask, zf, gt, gt_mean, gt_std, fname, sl_idx = data
            cbatch += 1
            # shape after unsqueeze = batch x channel x columns x rows x complex
            kspace = kspace.unsqueeze(1).to(policy_args.device)
            masked_kspace = masked_kspace.unsqueeze(1).to(policy_args.device)
            mask = mask.unsqueeze(1).to(policy_args.device)
            # shape after unsqueeze = batch x channel x columns x rows
            zf = zf.unsqueeze(1).to(policy_args.device)
            gt = gt.unsqueeze(1).to(policy_args.device)
            gt_mean = gt_mean.unsqueeze(1).unsqueeze(2).unsqueeze(3).to(policy_args.device)
            gt_std = gt_std.unsqueeze(1).unsqueeze(2).unsqueeze(3).to(policy_args.device)
            unnorm_gt = gt * gt_std + gt_mean
            data_range = torch.stack([data_range_dict[vol] for vol in fname])
            recons = recon_model(zf)

            # Gradients are reset at the first batch of every group of
            # `batches_step` batches and read out after the last one.
            if cbatch == 1:
                optimiser.zero_grad()

            action_list = []
            logprob_list = []
            reward_list = []
            for step in range(policy_args.acquisition_steps):  # Loop over acquisition steps
                # presumably this also runs the backward pass that fills
                # `param.grad` -- verify against compute_backprop_trajectory.
                _, mask, masked_kspace, recons = compute_backprop_trajectory(policy_args, kspace, masked_kspace,
                                                                             mask, unnorm_gt, recons, gt_mean,
                                                                             gt_std, data_range, model, recon_model,
                                                                             step, action_list, logprob_list,
                                                                             reward_list)

            if cbatch == policy_args.batches_step:
                # Store gradients for SNR
                num = len(list(model.named_parameters()))
                for i, (_name, param) in enumerate(model.named_parameters()):
                    if i == num - 2:  # weights of last layer
                        weight_grads.append(param.grad.cpu().numpy())
                    elif i == num - 1:  # biases of last layer
                        bias_grads.append(param.grad.cpu().numpy())
                cbatch = 0

        # Written after every run so that an interrupted job keeps its progress.
        print(f"    - Adding grads of run {r + 1} to: \n       {param_dir}")
        with open(weight_path, 'wb') as f:
            pickle.dump(weight_grads, f)
        with open(bias_path, 'wb') as f:
            pickle.dump(bias_grads, f)

    return weight_path, bias_path, param_dir
예제 #5
0
def train_and_eval(args, recon_args, recon_model):
    """
    Train a policy model and evaluate it after every epoch.

    When ``args.resume`` is set, model, arguments, start epoch and optimiser
    are restored from ``args.policy_model_checkpoint``; the run directory, data
    path and reconstruction checkpoint of the current invocation replace the
    restored ones, which are kept under ``old_*`` attributes. Otherwise a new
    model, optimiser and uniquely named run directory are created.

    :param args: Arguments object containing training hyperparameters.
    :param recon_args: Arguments of the reconstruction model (not used here).
    :param recon_model: Reconstruction model passed on to training and evaluation.
    :raises ValueError: If ``args.scheduler_type`` is neither 'step' nor 'multistep'.
    """
    if args.resume:
        # Check that this works
        resumed = True
        new_run_dir = args.policy_model_checkpoint.parent
        data_path = args.data_path
        # In case models have been moved to a different machine, make sure the path to the recon model is the
        # path provided.
        recon_model_checkpoint = args.recon_model_checkpoint

        # `args` is rebound to the arguments returned by the loader.
        model, args, start_epoch, optimiser = load_policy_model(
            pathlib.Path(args.policy_model_checkpoint), optim=True)

        args.old_run_dir = args.run_dir
        args.old_recon_model_checkpoint = args.recon_model_checkpoint
        args.old_data_path = args.data_path

        args.recon_model_checkpoint = recon_model_checkpoint
        args.run_dir = new_run_dir
        args.data_path = data_path
        args.resume = True
    else:
        resumed = False
        # Improvement model to train
        model = build_policy_model(args)
        # Add mask parameters for training
        args = add_mask_params(args)
        if args.data_parallel:
            model = torch.nn.DataParallel(model)
        optimiser = build_optim(args, model.parameters())
        start_epoch = 0
        # Create directory to store results in; the timestamp plus five random
        # uppercase letters keep concurrent runs apart.
        savestr = '{}_res{}_al{}_accel{}_k{}_{}_{}'.format(
            args.dataset, args.resolution, args.acquisition_steps,
            args.accelerations, args.num_trajectories,
            datetime.datetime.now().strftime("%Y-%m-%d_%H:%M:%S"),
            ''.join(choice(ascii_uppercase) for _ in range(5)))
        args.run_dir = args.exp_dir / savestr
        args.run_dir.mkdir(parents=True, exist_ok=False)

    args.resumed = resumed

    if args.wandb:
        allow_val_change = args.resumed  # only allow changes if resumed: otherwise something is wrong.
        wandb.config.update(args, allow_val_change=allow_val_change)
        wandb.watch(model, log='all')

    # Logging
    logging.info(recon_model)
    logging.info(model)
    # Save arguments for bookkeeping. Callable values are skipped; every other
    # value is stringified so it can be written as JSON.
    args_dict = {
        key: str(value)
        for key, value in args.__dict__.items()
        if not key.startswith('__') and not callable(value)
    }
    save_json(args.run_dir / 'args.json', args_dict)

    # Initialise summary writer
    writer = SummaryWriter(log_dir=args.run_dir / 'summary')

    try:
        # Parameter counting
        logging.info(
            'Reconstruction model parameters: total {}, of which {} trainable and {} untrainable'
            .format(count_parameters(recon_model),
                    count_trainable_parameters(recon_model),
                    count_untrainable_parameters(recon_model)))
        logging.info(
            'Policy model parameters: total {}, of which {} trainable and {} untrainable'
            .format(count_parameters(model), count_trainable_parameters(model),
                    count_untrainable_parameters(model)))

        if args.scheduler_type == 'step':
            scheduler = torch.optim.lr_scheduler.StepLR(optimiser,
                                                        args.lr_step_size,
                                                        args.lr_gamma)
        elif args.scheduler_type == 'multistep':
            # A single milestone may be given as a bare value.
            if not isinstance(args.lr_multi_step_size, list):
                args.lr_multi_step_size = [args.lr_multi_step_size]
            scheduler = torch.optim.lr_scheduler.MultiStepLR(
                optimiser, args.lr_multi_step_size, args.lr_gamma)
        else:
            raise ValueError(
                "{} is not a valid scheduler choice ('step', 'multistep')".format(
                    args.scheduler_type))

        # Create data loaders
        train_loader = create_data_loader(args, 'train', shuffle=True)
        dev_loader = create_data_loader(args, 'val', shuffle=False)

        train_data_range_dict = create_data_range_dict(args, train_loader)
        dev_data_range_dict = create_data_range_dict(args, dev_loader)

        # Evaluate the untrained model (passed as epoch -1) for fresh runs only.
        if not args.resume:
            if args.do_train_ssim:
                do_and_log_evaluation(args, -1, recon_model, model, train_loader,
                                      writer, 'Train', train_data_range_dict)
            do_and_log_evaluation(args, -1, recon_model, model, dev_loader, writer,
                                  'Val', dev_data_range_dict)

        for epoch in range(start_epoch, args.num_epochs):
            train_loss, train_time = train_epoch(args, epoch, recon_model, model,
                                                 train_loader, optimiser, writer,
                                                 train_data_range_dict)
            logging.info(
                f'Epoch = [{epoch+1:3d}/{args.num_epochs:3d}] TrainLoss = {train_loss:.3g} TrainTime = {train_time:.2f}s '
            )

            if args.do_train_ssim:
                do_and_log_evaluation(args, epoch, recon_model, model,
                                      train_loader, writer, 'Train',
                                      train_data_range_dict)
            do_and_log_evaluation(args, epoch, recon_model, model, dev_loader,
                                  writer, 'Val', dev_data_range_dict)

            scheduler.step()
            save_policy_model(args, args.run_dir, epoch, model, optimiser)
    finally:
        # Close the summary writer even when training is interrupted.
        writer.close()
예제 #6
0
def main(args):
    """
    Wrapper for running baseline models.

    Loads the reconstruction model, creates a uniquely named run directory
    under ``args.exp_dir``, runs the baseline selected by ``args.model_type``
    and logs the resulting SSIM / PSNR values and the run time.

    :param args: Arguments object containing hyperparameters for baseline models.
    """

    # For consistency
    args.val_batch_size = args.batch_size
    # Reconstruction model
    recon_args, recon_model = load_recon_model(args)
    # Add mask parameters for training
    args = add_mask_params(args)

    # Create directory to store results in; the timestamp plus five random
    # uppercase letters keep concurrent runs apart.
    savestr = '{}_res{}_al{}_accel{}_{}_{}_{}'.format(
        args.dataset, args.resolution, args.acquisition_steps,
        args.accelerations, args.model_type,
        datetime.datetime.now().strftime("%Y-%m-%d_%H:%M:%S"),
        ''.join(choice(ascii_uppercase) for _ in range(5)))
    args.run_dir = args.exp_dir / savestr
    args.run_dir.mkdir(parents=True, exist_ok=False)

    if args.wandb:
        wandb.config.update(args)

    # Logging
    logging.info(args)
    logging.info(recon_model)
    logging.info('Model type: {}'.format(args.model_type))

    # Save arguments for bookkeeping. Callable values are skipped; every other
    # value is stringified so it can be written as JSON.
    args_dict = {
        key: str(value)
        for key, value in args.__dict__.items()
        if not key.startswith('__') and not callable(value)
    }
    save_json(args.run_dir / 'args.json', args_dict)

    # Initialise summary writer
    writer = SummaryWriter(log_dir=args.run_dir / 'summary')

    try:
        if args.model_type == 'average_oracle':
            baseline_ssims, baseline_psnrs, baseline_time = run_average_oracle(
                args, recon_model)
        else:
            # Create data loader
            loader = create_data_loader(args, args.partition)
            data_range_dict = create_data_range_dict(args, loader)
            baseline_ssims, baseline_psnrs, baseline_time = run_baseline(
                args, recon_model, loader, data_range_dict)

        # Logging
        ssims_str = ", ".join(
            ["{}: {:.4f}".format(idx, value) for idx, value in enumerate(baseline_ssims)])
        psnrs_str = ", ".join(
            ["{}: {:.4f}".format(idx, value) for idx, value in enumerate(baseline_psnrs)])
        logging.info(f'  SSIM = [{ssims_str}]')
        logging.info(f'  PSNR = [{psnrs_str}]')
        logging.info(f'  Time = {baseline_time:.2f}s')

        # For storing in wandb: the same values are logged at every step from 0
        # up to and including args.num_epochs.
        if args.wandb:
            for epoch in range(args.num_epochs + 1):
                wandb.log(
                    {
                        f'{args.partition}_ssims':
                        {str(key): val
                         for key, val in enumerate(baseline_ssims)}
                    },
                    step=epoch)
                wandb.log(
                    {
                        f'{args.partition}_psnrs':
                        {str(key): val
                         for key, val in enumerate(baseline_psnrs)}
                    },
                    step=epoch)
    finally:
        # Close the summary writer even when the baseline run fails.
        writer.close()