Example #1
0
def compute_metrics(net, valLoader, upscaling, multiscale, device="cuda:0"):
    """Average PSNR/SSIM of the network and of a bicubic baseline.

    For every batch of ``valLoader``, the network output and a bicubic
    upscaling of the low-resolution input are each compared against the
    high-resolution target. The comparison uses ``compute_psnr`` and
    ``compute_msssim``.

    Args:
        net: Super-resolution model. It is called as ``net(lowres, upscaling)``
            when ``multiscale`` is true, otherwise as ``net(lowres)``. It is
            switched to eval mode and left there.
        valLoader: Iterable of ``(highres, lowres)`` batches that supports
            ``len()``. Each batch must contain exactly one image, because the
            bicubic baseline strips the leading (batch) dimension with
            ``view``.
        upscaling: Integer scale factor. It is used for the bicubic resize
            and, in multiscale mode, passed to the network.
        multiscale: Whether ``net`` takes the scale factor as a second
            argument.
        device: Device on which inference and metrics run. Defaults to
            ``"cuda:0"``, the value previously hard-coded here.

    Returns:
        Tuple ``(psnr_low, psnr_super, ssim_low, ssim_super)``, where each
        value is summed over batches and divided by ``len(valLoader)``.
        ``*_low`` refers to the bicubic baseline and ``*_super`` to the
        network output.

    Raises:
        ZeroDivisionError: If ``valLoader`` is empty.
    """
    device = torch.device(device)
    net.eval()
    psnr_low = psnr_super = ssim_low = ssim_super = 0
    with torch.no_grad():
        for highres, lowres in valLoader:
            lowres = lowres.to(device)
            highres = highres.to(device)
            if multiscale:
                superres = net(lowres, upscaling)
            else:
                superres = net(lowres)

            # Bicubic baseline: round-trip through PIL to resize, then restore
            # the batch dimension of 1 and move back next to the target.
            image = ToImage()(lowres.cpu().view(lowres.size()[1:]))
            image = image.resize(
                (image.size[0] * upscaling, image.size[1] * upscaling),
                Image.BICUBIC)
            bicubic = ToTensor()(image).unsqueeze(0).to(device)

            psnr_super += compute_psnr(superres, highres)
            psnr_low += compute_psnr(bicubic, highres)
            ssim_super += compute_msssim(superres, highres)
            ssim_low += compute_msssim(bicubic, highres)
    n_batches = len(valLoader)
    return (psnr_low / n_batches, psnr_super / n_batches,
            ssim_low / n_batches, ssim_super / n_batches)
def _tensor_to_image(tensor):
    """Convert an image tensor to a PIL image with ``ToImage``.

    If the tensor is 4-D (a batch of one), the leading batch dimension is
    dropped first. Unbatched dataset items and network outputs can
    therefore share this helper.
    """
    tensor = tensor.detach().cpu()
    if tensor.dim() == 4:
        tensor = tensor[0]
    return ToImage()(tensor)


def create_image(net, loader, name, upscaling=2, multiscale=False):
    """Save a 4x3 comparison figure for the first item of ``loader``.

    Row 0 shows three images:

    * the low-resolution input, with the three inspected patches outlined;
    * the network reconstruction;
    * the high-resolution target.

    The first two column titles also report dataset-wide averages from
    ``compute_metrics``. Rows 1-3 show each patch cropped from the
    bicubic-upscaled input, the reconstruction and the target.

    Args:
        net: Super-resolution model. It is called as ``net(lowres, upscaling)``
            when ``multiscale`` is true, otherwise ``net(lowres)``. Inference
            runs on ``cuda:0``, and the model is left in eval mode.
        loader: Indexable dataset of unbatched ``(highres, lowres)`` tensor
            pairs. It is also wrapped in a ``DataLoader`` to compute the
            averaged metrics.
        name: Output path passed to ``savefig``.
        upscaling: Integer scale factor between low and high resolution.
        multiscale: Whether ``net`` takes the scale factor as an argument.
    """
    patch_size = 30
    # Patch boxes as (left, upper, right, lower) in low-resolution pixels.
    # They are multiplied by ``upscaling`` before cropping the large images.
    patches = np.array(
        [[85, 90, 85 + patch_size, 90 + patch_size],
         [160, 140, 160 + patch_size, 140 + patch_size],
         [350, 80, 350 + patch_size, 80 + patch_size]])

    fig = plt.figure(figsize=(15, 18), dpi=100)
    gs = fig.add_gridspec(4, 3, wspace=0.01, hspace=0.3, left=0.05,
                          top=0.95, bottom=0.02, right=0.95)

    highres, lowres = loader[0]
    # Dataset items are unbatched: add a batch dimension of 1 for the network.
    lowres = lowres.to(torch.device("cuda:0")).unsqueeze(0)
    net.eval()
    with torch.no_grad():
        if multiscale:
            superres = net(lowres, upscaling)
        else:
            superres = net(lowres)
    # highres is still an unbatched item, so the helper leaves its shape
    # alone. The old code stripped its first (channel) dimension instead.
    lowres = _tensor_to_image(lowres)
    superres = _tensor_to_image(superres)
    highres = _tensor_to_image(highres)

    # Outline the inspected patches on a copy, keeping ``lowres`` clean.
    lowres_draw = lowres.copy()
    draw = ImageDraw.Draw(lowres_draw)
    for patch in patches:
        draw.rectangle(patch.tolist(), outline='white')

    # The bicubic upscale is the baseline shown in the patch rows.
    lowres = lowres.resize((lowres.size[0] * upscaling,
                            lowres.size[1] * upscaling),
                           Image.BICUBIC)

    psnr_low, psnr_super, ssim_low, ssim_super = compute_metrics(
        net, DataLoader(loader), upscaling, multiscale)
    lowres_title = (
        "Low-resolution image\n"
        f"Average PSNR over the dataset: {psnr_low:.2f}\n"
        f"Average SSIM over the dataset: {ssim_low:.4f}"
    )
    superres_title = (
        "Reconstructed image\n"
        f"Average PSNR over the dataset: {psnr_super:.2f}\n"
        f"Average SSIM over the dataset: {ssim_super:.4f}"
    )

    top_row = [(lowres_draw, lowres_title),
               (superres, superres_title),
               (highres, "High-resolution image")]
    for col, (image, title) in enumerate(top_row):
        ax = fig.add_subplot(gs[0, col], xticks=[], yticks=[], title=title)
        if col == 0:
            ax.set_ylabel("Image")
        ax.imshow(np.array(image))

    for row, patch in enumerate(patches, start=1):
        box = tuple((patch * upscaling).tolist())
        for col, image in enumerate((lowres, superres, highres)):
            ax = fig.add_subplot(gs[row, col], xticks=[], yticks=[])
            if col == 0:
                ax.set_ylabel(f"Patch {row}")
            ax.imshow(np.array(image.crop(box)))

    fig.savefig(name)
    # This helper only writes a file. Close the figure so that repeated
    # calls do not keep accumulating open matplotlib figures.
    plt.close(fig)