コード例 #1
0
def SRCNN2(args, image_file):
    """Super-resolve the luminance of an image with a pretrained SRCNN.

    The RGB image is converted to YCbCr. Only the Y channel is passed through
    the network; the input's Cb/Cr channels are recombined with the network
    output, and the result is converted back to RGB.

    TODO (original author's note): change this to take the after-resize image
    file, so that output3 stores the denoised + resized image.

    Args:
        args: Options object; only ``args.SR_scale`` is read, to locate the
            weights at ``<cwd>/SRCNN_outputs/x<SR_scale>/best.pth``.
        image_file: Path of the input image.

    Returns:
        The reconstructed image as a ``PIL.Image.Image`` built from a uint8
        RGB array.

    Raises:
        FileNotFoundError: If the weights file does not exist.
        KeyError: If the checkpoint contains a parameter the model lacks.
    """
    # 'cuda:0' must not contain a space, otherwise torch.device rejects it.
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    model = SRCNN().to(device)

    # os.path.join keeps the path valid on every OS (no hard-coded '\\').
    weights_dir = os.path.join(os.getcwd(), 'SRCNN_outputs',
                               'x{}'.format(args.SR_scale))
    weights_file = os.path.join(weights_dir, 'best.pth')
    if not os.path.isfile(weights_file):
        raise FileNotFoundError('weights file not found: ' + weights_file)

    # Copy checkpoint tensors into the model's parameters in place; an unknown
    # key aborts loading so a mismatched checkpoint fails loudly.
    state_dict = model.state_dict()
    checkpoint = torch.load(weights_file,
                            map_location=lambda storage, loc: storage)
    for name, param in checkpoint.items():
        if name not in state_dict:
            raise KeyError(name)
        state_dict[name].copy_(param)

    model.eval()  # evaluation mode for inference

    # Close the underlying file as soon as the pixels are decoded.
    with pil_image.open(image_file) as img:
        image = np.array(img.convert('RGB')).astype(np.float32)
    ycbcr = convert_rgb_to_ycbcr(image)

    # Network input: the Y channel divided by 255 (presumably [0, 255] ->
    # [0, 1]), with batch and channel dimensions added in front.
    # Dividing out-of-place leaves ycbcr untouched.
    y = torch.from_numpy(ycbcr[..., 0] / 255.).to(device)
    y = y.unsqueeze(0).unsqueeze(0)

    with torch.no_grad():
        preds = model(y).clamp(0.0, 1.0)

    # Rescale by 255 and drop the batch/channel dimensions again.
    preds = preds.mul(255.0).cpu().numpy().squeeze(0).squeeze(0)

    # Stack channel-last (H, W, 3): network Y plus the input's Cb and Cr.
    output = np.stack([preds, ycbcr[..., 1], ycbcr[..., 2]], axis=-1)
    output = np.clip(convert_ycbcr_to_rgb(output), 0.0, 255.0).astype(np.uint8)
    return pil_image.fromarray(output)
コード例 #2
0
def main():
    """Display one test sample next to its blurred input and SRCNN output.

    Weights are read from ``./result/train/20srcnnParms.pth``. Only the first
    batch of the test loader is used. The first image of the clean, blurred
    and restored tensors is passed to ``showImages``, and the batch's ``index``
    value is printed.
    """
    test_loader = myDataloader().getTestLoader(batch_size)

    net = SRCNN().cuda()
    net.load_state_dict(torch.load("./result/train/20srcnnParms.pth"))
    net.eval()

    with torch.no_grad():
        first_batch = next(iter(test_loader), None)
        if first_batch is None:
            # Empty loader: nothing to show.
            return
        clean, blurred, index = first_batch

        blurred = blurred.cuda()
        restored = net(blurred).cpu()
        blurred = blurred.cpu()

        for tensor in (clean[0], blurred[0], restored[0]):
            showImages(tensor)
        print(index)
コード例 #3
0
ファイル: test.py プロジェクト: ruczhouyujie/zeroshot
def main():
    """Plot where the SRCNN reconstruction diverges from the ground truth.

    Loads ``./result/train/<epoch>srcnnParms.pth`` and takes the first batch of
    the test loader. It computes the cosine similarity along dim 1 between the
    clean ``pic`` and the model output for ``blurPic``, then keeps the first
    sample. The result is shown as a binary grayscale map: 1 where the
    similarity is not >= the sample's mean similarity, 0 elsewhere.
    """
    dataloaders = myDataloader()
    test_loader = dataloaders.getTestLoader(batch_size)

    model = SRCNN().cuda()
    model.load_state_dict(
        torch.load("./result/train/" + str(epoch) + "srcnnParms.pth"))
    model.eval()
    with torch.no_grad():
        # Only the first batch is inspected (see the break below).
        for pic, blurPic, index in test_loader:
            pic = pic.cuda()
            out = model(blurPic.cuda())
            # Similarity across dim 1, first sample of the batch only.
            res = torch.cosine_similarity(pic, out, dim=1)[0]
            # 1 marks below-average (poorly reconstructed) positions.
            output = 1 - (res >= res.mean()).long()
            plt.figure()
            plt.title(index)
            plt.imshow(output.cpu(), cmap="gray")
            plt.show()
            break
コード例 #4
0
    cudnn.benchmark = True
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    model = SRCNN().to(device)

    state_dict = model.state_dict()
    for n, p in torch.load(args.weights_file,
                           map_location=lambda storage, loc: storage).items():
        if n in state_dict.keys():
            state_dict[n].copy_(p)
        else:
            raise KeyError(n)

    # model, optim = torch.load(model.state_dict(), os.path.join(args.weights_file, 'epoch_150.pth'))

    model.eval()

    image = Image.open(args.image_file).convert('RGB')
    resample = image

    image_width = (image.width // args.scale) * args.scale
    image_height = (image.height // args.scale) * args.scale
    image = image.resize((image_width, image_height), resample=Image.BICUBIC)
    image = image.resize(
        (image.width // args.scale, image.height // args.scale),
        resample=Image.BICUBIC)
    image = image.resize((image.width * args.scale, image.height * args.scale),
                         resample=Image.BICUBIC)
    image.save(args.image_file.replace('.',
                                       '_bicubic_x{}.'.format(args.scale)))
コード例 #5
0
    return vars(parser.parse_args())


if __name__ == '__main__':
    # Command-line entry point: bicubically enlarge one image by `scale`, refine
    # it with a trained SRCNN and save the result next to the input as
    # "outputx{scale}_{input filename}".
    args = get_arguments()

    weight = args.get('weight')
    p = args.get('input')
    upscale = args.get('scale')

    dirpath = os.path.dirname(p)
    input_filename = os.path.basename(p)

    gen = SRCNN()
    # The model stays on the CPU, so map checkpoint tensors there too; this
    # makes GPU-saved weights loadable on machines without CUDA.
    gen.load_state_dict(torch.load(weight, map_location='cpu'))
    gen.eval()

    # Context manager closes the image file once the tensor has been built.
    with Image.open(p) as img:
        # PIL sizes are (width, height); T.Resize expects (height, width).
        new_size = [int(x * upscale) for x in img.size]
        converter = T.Compose([
            T.Resize(size=new_size[::-1], interpolation=Image.BICUBIC),
            T.ToTensor()
        ])
        x = converter(img)

    with torch.no_grad():
        # Add a batch dimension for the model, then take the single result.
        pred = gen(x.unsqueeze(0))[0]

    save_image(pred, os.path.join(dirpath, f'outputx{upscale}_{input_filename}'))
コード例 #6
0
    img = Image.open(path).convert('YCbCr')
    y, cb, cr = img.split()
    img = img.resize(
        (int(img.size[0] * zoom_factor), int(img.size[1] * zoom_factor)),
        Image.BICUBIC)  # first, we upscale the image via bicubic interpolation
    img.save(f'{path.stem}_bicubic.jpg')

    img_to_tensor = transforms.ToTensor()
    input = img_to_tensor(y).view(
        1, -1, y.size[1], y.size[0])  # we only work with the "Y" channel

    device = torch.device("cpu")
    model = SRCNN()
    model.load_state_dict(torch.load('model_21.pth'))
    model = model.eval().to(device)

    input = input.to(device)
    with torch.no_grad():
        out = model(input)
        out = out.cpu()

    out_img_y = out[0].numpy()
    out_img_y *= 255.0
    out_img_y = out_img_y.clip(0, 255)
    out_img_y = Image.fromarray(np.uint8(out_img_y[0]), mode='L')

    out_img = Image.merge(
        'YCbCr',
        [out_img_y,
         cb.resize(out_img_y.size),