示例#1
0
def get_sr_and_score(imset, model, min_L=16):
    '''
    Super resolves an imset with a given model and scores it against the HR
    ground truth when one is available.

    Args:
        imset: ImageSet, or a tuple (lrs, alphas, hrs, hr_maps, names) of
            already collated batches
        model: HRNet, pytorch model, called as model(lrs, alphas)
        min_L: int, pad length (forwarded to collateFunction)
    Returns:
        sr: np.ndarray, channel 0 of the model output for the first batch
            element (not clipped)
        scPSNR: float, shift cPSNR score of the clipped sr, or None when
            hrs is empty
    Raises:
        TypeError: if imset is neither an ImageSet nor a tuple
    '''

    if isinstance(imset, ImageSet):
        collator = collateFunction(min_L=min_L)
        lrs, alphas, hrs, hr_maps, _ = collator([imset])
    elif isinstance(imset, tuple):  # imset is a tuple of batches
        lrs, alphas, hrs, hr_maps, _ = imset
    else:
        # Fail early with a clear message instead of an UnboundLocalError
        # on the first use of lrs further down.
        raise TypeError(
            'imset must be an ImageSet or a tuple of batches, got {}'.format(
                type(imset).__name__))

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    lrs = lrs.float().to(device)
    alphas = alphas.float().to(device)

    # Pure inference: no autograd graph is needed, the output goes to numpy.
    with torch.no_grad():
        sr = model(lrs, alphas)[:, 0]
    sr = sr.detach().cpu().numpy()[0]

    if len(hrs) > 0:
        # The score is computed on the clipped image; the returned sr is not
        # clipped.
        scPSNR = shift_cPSNR(sr=np.clip(sr, 0, 1),
                             hr=hrs.numpy()[0],
                             hr_map=hr_maps.numpy()[0])
    else:
        scPSNR = None

    return sr, scPSNR
示例#2
0
def trainAndGetBestModel(fusion_model, regis_model, optimizer, dataloaders,
                         baseline_cpsnrs, config):
    """
    Trains HRNet and ShiftNet for Multi-Frame Super Resolution (MFSR), and
    checkpoints both models after every epoch.

    Each epoch makes one pass over dataloaders['train'], minimising the
    negated cPSNR returned by get_loss plus a lambda-weighted penalty on the
    squared mean shift. It then scores dataloaders['val'] with shift_cPSNR.
    The validation score is the negated mean shift cPSNR (lower is better);
    it drives the ReduceLROnPlateau scheduler and is logged to TensorBoard.

    NOTE: the checkpoints 'HRNet.pth' and 'ShiftNet.pth' are overwritten
    unconditionally at the end of every epoch, so they hold the latest
    weights, not necessarily the best-scoring ones.

    Args:
        fusion_model: torch.model, HRNet
        regis_model: torch.model, ShiftNet
        optimizer: torch.optim, optimizer to minimize loss
        dataloaders: dict, wraps train and validation dataloaders
            (keys 'train' and 'val')
        baseline_cpsnrs: dict, ESA baseline scores (kept for interface
            compatibility; the validation score currently does not use it)
        config: dict, configuration file; reads config["training"] and
            config["paths"]
    """
    np.random.seed(123)  # seed all RNGs for reproducibility
    torch.manual_seed(123)

    num_epochs = config["training"]["num_epochs"]
    batch_size = config["training"]["batch_size"]
    n_views = config["training"]["n_views"]
    min_L = config["training"]["min_L"]  # minimum number of views
    beta = config["training"]["beta"]

    # One sub-folder per run, made unique by a microsecond timestamp.
    subfolder_pattern = 'batch_{}_views_{}_min_{}_beta_{}_time_{}'.format(
        batch_size, n_views, min_L, beta,
        f"{datetime.datetime.now():%Y-%m-%d-%H-%M-%S-%f}")

    checkpoint_dir_run = os.path.join(config["paths"]["checkpoint_dir"],
                                      subfolder_pattern)
    os.makedirs(checkpoint_dir_run, exist_ok=True)

    tb_logging_dir = config['paths']['tb_log_file_dir']
    logging_dir = os.path.join(tb_logging_dir, subfolder_pattern)
    os.makedirs(logging_dir, exist_ok=True)

    writer = SummaryWriter(logging_dir)

    # try/finally so the event file is flushed and closed even if training
    # is interrupted by an exception.
    try:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        reg_size = 128  # side length of the central window used for registration
        P = config["training"]["patch_size"]
        offset = (3 * P - reg_size) // 2
        reg_window = slice(offset, offset + reg_size)
        C = config["training"]["crop"]
        # NOTE: the border-crop mask is built but not applied to the loss
        # below, which uses hr_maps alone.
        torch_mask = get_crop_mask(patch_size=P, crop_size=C)
        torch_mask = torch_mask.to(device)

        fusion_model.to(device)
        regis_model.to(device)

        scheduler = lr_scheduler.ReduceLROnPlateau(
            optimizer,
            mode='min',
            factor=config['training']['lr_decay'],
            verbose=True,
            patience=config['training']['lr_step'])

        for epoch in tqdm(range(1, num_epochs + 1)):

            # Train
            fusion_model.train()
            regis_model.train()
            train_loss = 0.0  # monitor train loss

            for lrs, alphas, hrs, hr_maps, names in tqdm(dataloaders['train']):

                optimizer.zero_grad()  # zero the parameter gradients
                lrs = lrs.float().to(device)
                alphas = alphas.float().to(device)
                hr_maps = hr_maps.float().to(device)
                hrs = hrs.float().to(device)

                srs = fusion_model(lrs,
                                   alphas)  # fuse multi frames (B, 1, 3*W, 3*H)

                # Register batch wrt HR, using only the central window
                shifts = register_batch(
                    regis_model,
                    srs[:, :, reg_window, reg_window],
                    reference=hrs[:, reg_window, reg_window].view(
                        -1, 1, reg_size, reg_size))
                srs_shifted = apply_shifts(regis_model, srs, shifts,
                                           device)[:, 0]

                # Training loss: negated cPSNR plus a penalty on the mean shift
                cropped_mask = hr_maps
                loss = -get_loss(srs_shifted, hrs, cropped_mask, metric='cPSNR')
                loss = torch.mean(loss)
                loss += config["training"]["lambda"] * torch.mean(shifts)**2

                # Backprop
                loss.backward()
                optimizer.step()
                # Weight the batch loss by its share of the dataset so that
                # train_loss ends up as the per-sample mean over the epoch.
                epoch_loss = loss.detach().cpu().numpy() * len(hrs) / len(
                    dataloaders['train'].dataset)
                train_loss += epoch_loss

            # Eval
            fusion_model.eval()
            val_score = 0.0  # monitor val score

            for lrs, alphas, hrs, hr_maps, names in dataloaders['val']:
                lrs = lrs.float().to(device)
                alphas = alphas.float().to(device)
                hrs = hrs.numpy()
                hr_maps = hr_maps.numpy()

                # No gradients are needed for validation; skipping the
                # autograd graph saves memory.
                with torch.no_grad():
                    srs = fusion_model(lrs, alphas)[:, 0]  # fuse multi frames
                srs = srs.detach().cpu().numpy()

                # Accumulate the negated shift cPSNR of every sample. The
                # ESA-baseline-normalised score is not used, so
                # baseline_cpsnrs is deliberately ignored here.
                for sr, hr, hr_map in zip(srs, hrs, hr_maps):
                    val_score -= shift_cPSNR(np.clip(sr, 0, 1), hr, hr_map)

            val_score /= len(dataloaders['val'].dataset)

            # Checkpoint both networks every epoch (no best-score gating).
            torch.save(fusion_model.state_dict(),
                       os.path.join(checkpoint_dir_run, 'HRNet.pth'))
            torch.save(regis_model.state_dict(),
                       os.path.join(checkpoint_dir_run, 'ShiftNet.pth'))

            # Log the first image of the last validation batch, min-max
            # rescaled to [0, 1]; a constant image is logged as all zeros.
            sr_img = srs[0]
            sr_min = np.min(sr_img)
            sr_span = np.max(sr_img) - sr_min
            if sr_span > 0:
                sr_vis = (sr_img - sr_min) / sr_span
            else:
                sr_vis = np.zeros_like(sr_img)
            writer.add_image('SR Image', sr_vis, epoch, dataformats='HW')

            error_map = hrs[0] - srs[0]
            writer.add_image('Error Map', error_map, epoch, dataformats='HW')
            writer.add_scalar("train/loss", train_loss, epoch)
            writer.add_scalar("train/val_loss", val_score, epoch)
            scheduler.step(val_score)
    finally:
        writer.close()
示例#3
0
def get_sr_and_score(imset,
                     model,
                     aposterior_gt,
                     next_sr,
                     num_frames,
                     min_L=16):
    '''
    Super resolves an imset with a given model and computes a set of quality
    metrics for the result.

    Args:
        imset: ImageSet, or a tuple (lrs, alphas, hrs, hr_maps, names) of
            already collated batches
        model: HRNet, pytorch model, called as model(lrs, alphas)
        aposterior_gt: reference image passed to cSSIM, or None
        next_sr: 2-D np.ndarray (uint16, or float in [0, 1]) to compare the
            current result against, or None
        num_frames: forwarded to collateFunction (presumably the number of
            LR frames to use - confirm against collateFunction)
        min_L: int, pad length (forwarded to collateFunction)
    Returns:
        A 12-tuple
        (sr, val_gt_SSIM, val_aposterior_SSIM, val_cPSNR, val_usual_PSNR,
         val_shift_cPSNR, val_cMSE, val_L2, val_shift_cMSE, val_delta_cMSE,
         val_delta_L2, val_delta_shift_cMSE)
        where sr is the 2-D super resolved image clipped to [0, 1].
        Metrics that need the HR ground truth are None when hrs is empty;
        val_aposterior_SSIM is 1.0 when aposterior_gt is None; the delta
        metrics are None when next_sr is None (the two clear-pixel delta
        metrics are also None when no HR mask is available).
    Raises:
        TypeError: if imset is neither an ImageSet nor a tuple
    '''

    if isinstance(imset, ImageSet):
        collator = collateFunction(num_frames, min_L=min_L)
        lrs, alphas, hrs, hr_maps, _ = collator([imset])
    elif isinstance(imset, tuple):  # imset is a tuple of batches
        lrs, alphas, hrs, hr_maps, _ = imset
    else:
        raise TypeError(
            'imset must be an ImageSet or a tuple of batches, got {}'.format(
                type(imset).__name__))

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    lrs = lrs.float().to(device)
    alphas = alphas.float().to(device)

    # Pure inference: no autograd graph is needed, the output goes to numpy.
    with torch.no_grad():
        sr = model(lrs, alphas)[:, 0]
    sr = sr.detach().cpu().numpy()[0]
    sr = np.clip(sr, 0, 1)

    uint16_max = np.iinfo(np.uint16).max
    range_msg = 'sr.dtype must be either uint16 (range 0-65536) or float64 in (0, 1).'

    cur_sr = sr
    assert cur_sr.ndim == 2
    if cur_sr.dtype.type is np.uint16:  # integer array is in the range [0, 65536]
        cur_sr = cur_sr / uint16_max  # normalize in the range [0, 1]
    else:
        assert 0 <= cur_sr.min() and cur_sr.max() <= 1, range_msg

    # Only touch the ground truth when there is one; indexing an empty hrs
    # would raise before the "no HR" branches below could ever run.
    has_hr = len(hrs) > 0
    if has_hr:
        cur_hr = hrs.numpy()[0]
        cur_hr_map = hr_maps.numpy()[0]
        assert cur_hr.ndim == 2
        assert cur_hr_map.ndim == 2
        if cur_hr.dtype.type is np.uint16:
            cur_hr = cur_hr / uint16_max
        val_gt_SSIM = cSSIM(sr=cur_sr, hr=cur_hr)
        val_L2 = mean_squared_error(cur_hr, cur_sr)
    else:
        cur_hr = None
        cur_hr_map = None
        val_gt_SSIM = None
        val_L2 = None

    if aposterior_gt is None:
        val_aposterior_SSIM = 1.0
    else:
        val_aposterior_SSIM = cSSIM(sr=cur_sr, hr=aposterior_gt)

    if next_sr is None:
        val_delta_L2 = None
    else:
        assert next_sr.ndim == 2
        # Bring next_sr to [0, 1] *before* comparing it with cur_sr, so both
        # images are on the same scale.
        if next_sr.dtype.type is np.uint16:  # integer array is in the range [0, 65536]
            next_sr = next_sr / uint16_max  # normalize in the range [0, 1]
        else:
            assert 0 <= next_sr.min() and next_sr.max() <= 1, range_msg
        val_delta_L2 = mean_squared_error(next_sr, cur_sr)

    # The cMSE-family helpers are fed a leading axis of size 1.
    if cur_sr.ndim == 2:
        cur_sr = cur_sr[None, ]
        if has_hr:
            cur_hr = cur_hr[None, ]
            cur_hr_map = cur_hr_map[None, ]

    if has_hr:
        val_cMSE = cMSE(sr=cur_sr, hr=cur_hr, hr_map=cur_hr_map)
        val_cPSNR = -10 * np.log10(val_cMSE)
        val_usual_PSNR = -10 * np.log10(val_L2)
        val_shift_cPSNR = shift_cPSNR(sr=cur_sr, hr=cur_hr, hr_map=cur_hr_map)
        val_shift_cMSE = shift_cMSE(sr=cur_sr, hr=cur_hr, hr_map=cur_hr_map)
    else:
        val_cMSE = None
        val_cPSNR = None
        val_usual_PSNR = None
        val_shift_cPSNR = None
        val_shift_cMSE = None

    if next_sr is None or cur_hr_map is None:
        # The clear-pixel delta metrics need both a comparison image and the
        # HR mask.
        val_delta_cMSE = None
        val_delta_shift_cMSE = None
    else:
        # Expand next_sr based on its own rank so it matches cur_sr.
        if next_sr.ndim == 2:
            next_sr = next_sr[None, ]

        val_delta_cMSE = cMSE(sr=cur_sr, hr=next_sr, hr_map=cur_hr_map)
        val_delta_shift_cMSE = shift_cMSE(sr=cur_sr,
                                          hr=next_sr,
                                          hr_map=cur_hr_map)

    return sr, val_gt_SSIM, val_aposterior_SSIM, val_cPSNR, val_usual_PSNR, val_shift_cPSNR, val_cMSE, \
           val_L2, val_shift_cMSE, val_delta_cMSE, val_delta_L2, val_delta_shift_cMSE