Exemplo n.º 1
0
def get_sr_and_score(imset, model, min_L=16):
    '''
    Super resolves an imset with a given model.

    Args:
        imset: ImageSet, or a tuple (lrs, alphas, hrs, hr_maps, names) of
            already-collated batches
        model: HRNet, pytorch model, called as model(lrs, alphas)
        min_L: int, pad length (only used when imset is an ImageSet)
    Returns:
        sr: np.ndarray, model(lrs, alphas)[0, 0] moved to CPU; NOT clipped
        scPSNR: float, shift cPSNR of the [0, 1]-clipped sr against the first
            high-res target, or None when no high-res target is available
    Raises:
        TypeError: if imset is neither an ImageSet nor a tuple.
    '''
    if isinstance(imset, tuple):  # imset is a tuple of batches
        lrs, alphas, hrs, hr_maps, _ = imset
    elif isinstance(imset, ImageSet):
        collator = collateFunction(min_L=min_L)
        lrs, alphas, hrs, hr_maps, _ = collator([imset])
    else:
        # Previously this fell through and failed later with UnboundLocalError.
        raise TypeError('imset must be an ImageSet or a tuple of batches, '
                        f'got {type(imset).__name__}')

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    lrs = lrs.float().to(device)
    alphas = alphas.float().to(device)

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        sr = model(lrs, alphas)[:, 0]
    sr = sr.detach().cpu().numpy()[0]

    # The collated batch may carry no high-res target (empty or None).
    if hrs is None or len(hrs) == 0:
        return sr, None

    # The score is computed on the clipped image, but sr is returned unclipped.
    scPSNR = shift_cPSNR(sr=np.clip(sr, 0, 1),
                         hr=hrs.numpy()[0],
                         hr_map=hr_maps.numpy()[0])
    return sr, scPSNR
Exemplo n.º 2
0
def main(config):
    """
    Given a configuration, trains HRNet and ShiftNet for Multi-Frame Super
    Resolution (MFSR) via trainAndGetBestModel, which is expected to save the
    best model.

    Args:
        config: dict, configuration file. Among others, reads
            config["network"], config["paths"]["prefix"] and the "training"
            keys lr, val_proportion, batch_size, n_workers, n_views, min_L and
            beta. main itself does not modify config.
    """

    # Reproducibility options
    np.random.seed(0)  # RNG seeds
    torch.manual_seed(0)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # Initialize the network based on the network configuration
    fusion_model = HRNet(config["network"])
    regis_model = ShiftNet()

    # A single optimizer jointly updates the fusion and registration networks.
    optimizer = optim.Adam(list(fusion_model.parameters()) +
                           list(regis_model.parameters()),
                           lr=config["training"]["lr"])

    # ESA dataset
    data_directory = config["paths"]["prefix"]

    baseline_cpsnrs = None
    norm_path = os.path.join(data_directory, "norm.csv")
    if os.path.exists(norm_path):
        baseline_cpsnrs = readBaselineCPSNR(norm_path)

    train_set_directories = getImageSetDirectories(
        os.path.join(data_directory, "train"))

    val_proportion = config['training']['val_proportion']
    train_list, val_list = train_test_split(train_set_directories,
                                            test_size=val_proportion,
                                            random_state=1,
                                            shuffle=True)

    # Dataloaders
    training_config = config["training"]
    batch_size = training_config["batch_size"]
    n_workers = training_config["n_workers"]
    n_views = training_config["n_views"]
    min_L = training_config["min_L"]  # minimum number of views
    beta = training_config["beta"]

    train_dataset = ImagesetDataset(imset_dir=train_list,
                                    config=training_config,
                                    top_k=n_views,
                                    beta=beta)
    train_dataloader = DataLoader(train_dataset,
                                  batch_size=batch_size,
                                  shuffle=True,
                                  num_workers=n_workers,
                                  collate_fn=collateFunction(min_L=min_L),
                                  pin_memory=True)

    # Validation uses whole images rather than patches. Build a separate
    # (shallow) copy of the training config: the training dataset above was
    # given a reference to config["training"], so flipping create_patches in
    # place could also affect it, and it would leak into the caller's config.
    val_config = dict(training_config, create_patches=False)
    val_dataset = ImagesetDataset(imset_dir=val_list,
                                  config=val_config,
                                  top_k=n_views,
                                  beta=beta)
    val_dataloader = DataLoader(val_dataset,
                                batch_size=1,
                                shuffle=False,
                                num_workers=n_workers,
                                collate_fn=collateFunction(min_L=min_L),
                                pin_memory=True)

    dataloaders = {'train': train_dataloader, 'val': val_dataloader}

    # Train model
    torch.cuda.empty_cache()

    trainAndGetBestModel(fusion_model, regis_model, optimizer, dataloaders,
                         baseline_cpsnrs, config)
Exemplo n.º 3
0
def _to_unit_range(img):
    """Return *img* scaled to [0, 1].

    uint16 images are divided by the uint16 maximum; any other dtype must
    already lie in [0, 1] (checked with an assert).
    """
    if img.dtype.type is np.uint16:  # integer array is in the range [0, 65535]
        return img / np.iinfo(np.uint16).max  # normalize in the range [0, 1]
    assert 0 <= img.min() and img.max(
    ) <= 1, 'sr.dtype must be either uint16 (range 0-65536) or float64 in (0, 1).'
    return img


def get_sr_and_score(imset,
                     model,
                     aposterior_gt,
                     next_sr,
                     num_frames,
                     min_L=16):
    '''
    Super resolves an imset with a given model and computes quality metrics.

    Args:
        imset: ImageSet, or a tuple (lrs, alphas, hrs, hr_maps, names) of
            already-collated batches
        model: HRNet, pytorch model, called as model(lrs, alphas)
        aposterior_gt: 2D array or None, reference image for the a-posteriori
            SSIM; when None that score is reported as 1.0
        next_sr: 2D array or None, another super-resolved image compared
            against this one (delta metrics); when None the delta metrics
            are None
        num_frames: int, passed to collateFunction when imset is an ImageSet
        min_L: int, pad length
    Returns:
        12-tuple (sr, val_gt_SSIM, val_aposterior_SSIM, val_cPSNR,
        val_usual_PSNR, val_shift_cPSNR, val_cMSE, val_L2, val_shift_cMSE,
        val_delta_cMSE, val_delta_L2, val_delta_shift_cMSE).
        sr is the 2D array model(lrs, alphas)[0, 0] clipped to [0, 1].
        Metrics against the high-res target are None when the batch has no
        high-res target; val_delta_cMSE / val_delta_shift_cMSE additionally
        need the high-res map and are None without it.
    Raises:
        TypeError: if imset is neither an ImageSet nor a tuple.
    '''
    if isinstance(imset, tuple):  # imset is a tuple of batches
        lrs, alphas, hrs, hr_maps, _ = imset
    elif isinstance(imset, ImageSet):
        collator = collateFunction(num_frames, min_L=min_L)
        lrs, alphas, hrs, hr_maps, _ = collator([imset])
    else:
        raise TypeError('imset must be an ImageSet or a tuple of batches, '
                        f'got {type(imset).__name__}')

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    lrs = lrs.float().to(device)
    alphas = alphas.float().to(device)

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        sr = model(lrs, alphas)[:, 0]
    sr = np.clip(sr.detach().cpu().numpy()[0], 0, 1)
    assert sr.ndim == 2

    cur_sr = _to_unit_range(sr)
    # cMSE / shift_cPSNR / shift_cMSE are called with a leading batch axis.
    cur_sr_3d = cur_sr[None, ]

    # The collated batch may carry no high-res target (empty or None); only
    # then are the ground-truth metrics skipped.
    has_hr = hrs is not None and len(hrs) > 0
    if has_hr:
        cur_hr = hrs.numpy()[0]
        cur_hr_map = hr_maps.numpy()[0]
        assert cur_hr.ndim == 2
        assert cur_hr_map.ndim == 2
        if cur_hr.dtype.type is np.uint16:
            cur_hr = cur_hr / np.iinfo(np.uint16).max
        cur_hr_3d = cur_hr[None, ]
        cur_hr_map_3d = cur_hr_map[None, ]

        val_gt_SSIM = cSSIM(sr=cur_sr, hr=cur_hr)
        val_L2 = mean_squared_error(cur_hr, cur_sr)
        val_usual_PSNR = -10 * np.log10(val_L2)
        val_cMSE = cMSE(sr=cur_sr_3d, hr=cur_hr_3d, hr_map=cur_hr_map_3d)
        val_cPSNR = -10 * np.log10(val_cMSE)
        val_shift_cPSNR = shift_cPSNR(sr=cur_sr_3d,
                                      hr=cur_hr_3d,
                                      hr_map=cur_hr_map_3d)
        val_shift_cMSE = shift_cMSE(sr=cur_sr_3d,
                                    hr=cur_hr_3d,
                                    hr_map=cur_hr_map_3d)
    else:
        val_gt_SSIM = val_L2 = val_usual_PSNR = None
        val_cMSE = val_cPSNR = val_shift_cPSNR = val_shift_cMSE = None

    if aposterior_gt is None:
        val_aposterior_SSIM = 1.0
    else:
        val_aposterior_SSIM = cSSIM(sr=cur_sr, hr=aposterior_gt)

    val_delta_L2 = val_delta_cMSE = val_delta_shift_cMSE = None
    if next_sr is not None:
        assert next_sr.ndim == 2
        # Normalize first so every delta metric compares images on the same
        # [0, 1] scale.
        next_sr = _to_unit_range(next_sr)
        val_delta_L2 = mean_squared_error(next_sr, cur_sr)
        if has_hr:
            # Give next_sr the same leading batch axis as cur_sr_3d.
            next_sr_3d = next_sr[None, ]
            val_delta_cMSE = cMSE(sr=cur_sr_3d,
                                  hr=next_sr_3d,
                                  hr_map=cur_hr_map_3d)
            val_delta_shift_cMSE = shift_cMSE(sr=cur_sr_3d,
                                              hr=next_sr_3d,
                                              hr_map=cur_hr_map_3d)

    return (sr, val_gt_SSIM, val_aposterior_SSIM, val_cPSNR, val_usual_PSNR,
            val_shift_cPSNR, val_cMSE, val_L2, val_shift_cMSE, val_delta_cMSE,
            val_delta_L2, val_delta_shift_cMSE)