예제 #1
0
    def enhance(self, frame):
        """Super-resolve one frame and return the RGB result as a uint8 array.

        The frame is upscaled bicubically by ``self.scale``; the Cb/Cr planes
        of that upscale are combined with a luma plane predicted by
        ``self.model`` from the low-resolution frame. The elapsed wall-clock
        time of the call is printed to stdout.

        NOTE(review): the model is fed ``torch.stack([lr, lr])`` (the low-res
        tensor duplicated along a new leading axis). Whether the squeezed
        prediction then matches the (H, W) chroma planes depends on what
        ``preprocess`` and ``self.model`` return - confirm against them.
        """
        started_at = time.time()

        height, width = frame.shape[0], frame.shape[1]
        target_size = (width * self.scale, height * self.scale)  # cv2 wants (w, h)
        upscaled = cv2.resize(frame, dsize=target_size,
                              interpolation=cv2.INTER_CUBIC)

        frame_t = torch.tensor(frame).to(self.device)
        upscaled_t = torch.tensor(upscaled).to(self.device)

        _, ycbcr = preprocess(upscaled_t)
        lr, _ = preprocess(frame_t)

        with torch.no_grad():
            batch = torch.stack([lr, lr])
            prediction = self.model(batch).clamp(0.0, 1.0)

        luma = (prediction * 255.0).squeeze()

        # Channel-first stack, then permute to channel-last (H, W, 3).
        planes = [luma, ycbcr[..., 1], ycbcr[..., 2]]
        hwc = torch.stack(planes).permute(1, 2, 0)
        rgb = torch.clamp(convert_ycbcr_to_rgb(hwc), 0.0, 255.0)
        result = rgb.cpu().numpy().astype(np.uint8)

        print(time.time() - started_at)
        return result
예제 #2
0
def SRCNN2(args, image_file):
    """Super-resolve the luma channel of ``image_file`` with a trained SRCNN.

    TODO (original author): the caller should pass the image *after* it has
    been resized (and denoised), so the stored output is the denoise+resize
    image.

    Args:
        args: namespace providing ``SR_scale``; weights are read from
            ``<cwd>/SRCNN_outputs/x<SR_scale>/best.pth``.
        image_file: path of the input image (converted to RGB on load).

    Returns:
        A PIL image built from the predicted Y plane plus the input's Cb/Cr
        planes. NOTE(review): this requires the model output to have the same
        H x W as the input; an earlier inline note mentioned 510 vs 512 -
        confirm the SRCNN pads its convolutions.

    Raises:
        FileNotFoundError: if the weights file does not exist.
        KeyError: if the checkpoint contains a parameter the model lacks.
    """
    # 'cuda:0' must not contain a space - 'cuda: 0' is rejected by torch.device.
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    model = SRCNN().to(device)
    state_dict = model.state_dict()

    # Portable path construction (the original hard-coded Windows separators).
    weights_file = os.path.join(os.getcwd(), 'SRCNN_outputs',
                                'x{}'.format(args.SR_scale), 'best.pth')
    # The original tested the (always non-empty) string itself, so this check
    # never fired; test the file system instead.
    if not os.path.isfile(weights_file):
        raise FileNotFoundError(weights_file + ' not exist')

    checkpoint = torch.load(weights_file,
                            map_location=lambda storage, loc: storage)
    for name, param in checkpoint.items():
        if name not in state_dict:
            raise KeyError(name)
        # state_dict() tensors share storage with the model's parameters.
        state_dict[name].copy_(param)

    model.eval()  # evaluation mode for inference

    image = pil_image.open(image_file).convert('RGB')
    image = np.array(image).astype(np.float32)
    ycbcr = convert_rgb_to_ycbcr(image)

    # The network works on the Y plane only, normalized to [0, 1] and shaped
    # (batch=1, channel=1, H, W).
    y = ycbcr[..., 0] / 255.
    y = torch.from_numpy(y).to(device).unsqueeze(0).unsqueeze(0)

    with torch.no_grad():
        preds = model(y).clamp(0.0, 1.0)

    preds = preds.mul(255.0).cpu().numpy().squeeze(0).squeeze(0)  # tensor -> np

    # np.array stacks the planes channel-first (3, H, W); transpose to the
    # channel-last (H, W, 3) layout used by convert_ycbcr_to_rgb and PIL.
    output = np.array([preds, ycbcr[..., 1], ycbcr[..., 2]]).transpose([1, 2, 0])
    output = np.clip(convert_ycbcr_to_rgb(output), 0.0, 255.0).astype(np.uint8)
    return pil_image.fromarray(output)
예제 #3
0
    # Excerpt: `model`, `lr`, `hr`, `bicubic`, `ycbcr` and `args` are defined
    # before this fragment.
    # Collects feature maps; presumably filled by hooks that get_feature_map
    # registers on `model` during the forward pass below - confirm.
    feature_map_list = []
    get_feature_map(model, feature_map_list)
    with torch.no_grad():
        preds = model(lr).clamp(0.0, 1.0)

    show_images("first_feature_map", feature_map_list[0])
    show_images("mid_feature_map", feature_map_list[1])

    show_images("before preds", preds)

    if args.residual:
        # Residual mode: the network output is added to the bicubic image.
        # This first PSNR is for the raw network output, before that addition.
        # NOTE(review): it is printed with the same label as the PSNR below,
        # which makes the two lines indistinguishable in the log.
        psnr = calc_psnr(hr, preds)
        print('preds and hr PSNR: {:.2f}'.format(psnr))
        # NOTE(review): the sum is not re-clamped to [0, 1].
        preds = preds + bicubic

    psnr = calc_psnr(hr, preds)
    print('preds and hr PSNR: {:.2f}'.format(psnr))

    psnr = calc_psnr(hr, bicubic)
    print('bicubic and hr PSNR: {:.2f}'.format(psnr))

    show_images("after preds", preds)
    preds = preds.mul(255.0).cpu().numpy().squeeze(0).squeeze(0)

    # Stack Y (predicted) with the bicubic Cb/Cr planes channel-first, then
    # transpose to channel-last (H, W, 3).
    output = np.array([preds, ycbcr[..., 1], ycbcr[...,
                                                   2]]).transpose([1, 2, 0])
    output = np.clip(convert_ycbcr_to_rgb(output), 0.0, 255.0).astype(np.uint8)
    output = pil_image.fromarray(output)
    # NOTE(review): str.replace rewrites *every* '.', so paths like
    # './img.png' or 'dir.v2/img.png' are mangled; os.path.splitext is safer.
    output.save(args.image_file.replace('.',
                                        '_fsrcnn_x{}.'.format(args.scale)))
예제 #4
0
파일: test.py 프로젝트: yu2guang/NCTU-CS
        # NOTE(review): cur_w and cur_h are not used in the visible code.
        cur_w, cur_h = image.width * args.scale, image.height * args.scale
        # Pre-upscale bicubically to the target size; the network then refines
        # the Y channel of this already-upscaled image.
        image = image.resize(
            (image.width * args.scale, image.height * args.scale),
            resample=pil_image.BICUBIC)

        image = np.array(image).astype(np.float32)
        ycbcr = convert_rgb_to_ycbcr(image)

        # Y plane normalized to [0, 1], shaped (1, 1, H, W) for the model.
        y = ycbcr[..., 0]
        y /= 255.
        y = torch.from_numpy(y).to(device)
        y = y.unsqueeze(0).unsqueeze(0)

        with torch.no_grad():
            preds = model(y).clamp(0.0, 1.0)

        # NOTE(review): this PSNR compares the output with the model's own
        # bicubic input `y`, not with a ground-truth high-res image, so it
        # measures how much the model changed its input rather than quality.
        psnr = calc_psnr(y, preds)
        psnr_seq.append(psnr.cpu().item())
        # NOTE(review): split('/') is not portable; os.path.basename is.
        print('{} PSNR: {:.2f}'.format(image_path.split('/')[-1], psnr))

        preds = preds.mul(255.0).cpu().numpy().squeeze(0).squeeze(0)

        # Channel-first stack, then transpose to channel-last (H, W, 3).
        output = np.array([preds, ycbcr[..., 1],
                           ycbcr[..., 2]]).transpose([1, 2, 0])
        output = np.clip(convert_ycbcr_to_rgb(output), 0.0,
                         255.0).astype(np.uint8)
        output = pil_image.fromarray(output)
        # NOTE(review): plain concatenation requires outputs_dir to end with
        # a path separator; os.path.join would not.
        output.save(args.outputs_dir + image_path.split('/')[-1])

    # Mean of the per-image PSNR values collected in the loop above.
    print(f'Average PSNR: {np.mean(psnr_seq)}')
예제 #5
0
def _strip_dataparallel_prefix(state_dict):
    """Return *state_dict* as an OrderedDict with a leading 'module.' removed from keys.

    Weights saved from a multi-GPU (DataParallel) model carry that prefix. Only
    keys that actually start with it are rewritten, so a checkpoint saved from
    a plain model still loads instead of losing its first 7 characters.
    """
    from collections import OrderedDict

    prefix = 'module.'
    return OrderedDict(
        (key[len(prefix):] if key.startswith(prefix) else key, value)
        for key, value in state_dict.items())


def _load_fsrcnn_checkpoint(checkpoint_path, device):
    """Build the x4 FSRCNN, load *checkpoint_path*, move it to *device*, set eval mode."""
    model = FSRCNN(scale_factor=4, num_channels=1, d=56, s=12, m=4)
    state_dict = torch.load(checkpoint_path, map_location='cpu')
    model.load_state_dict(_strip_dataparallel_prefix(state_dict))
    # The model must live on the same device as the inputs; preprocess() is
    # given `device` below, presumably placing its tensors there.
    return model.to(device).eval()


def _fsrcnn_super_resolve(model, image_path, scale, device):
    """Return the super-resolved RGB image for *image_path* as a uint8 array.

    Luma comes from *model* applied to the low-res image; Cb/Cr come from the
    YCbCr planes preprocess() returns for a bicubic upscale by *scale*.
    NOTE(review): the FSRCNN is built with scale_factor=4, so *scale* must be
    4 for the planes to line up - confirm args.scale.
    """
    lr = pil_image.open(image_path).convert('RGB')
    bicubic = lr.resize((lr.width * scale, lr.height * scale),
                        resample=pil_image.BICUBIC)
    lr_y, _ = preprocess(lr, device)
    _, ycbcr = preprocess(bicubic, device)

    with torch.no_grad():
        preds = model(lr_y).clamp(0.0, 1.0)
    preds = preds.mul(255.0).cpu().numpy().squeeze(0).squeeze(0)

    # Channel-first stack, then transpose to channel-last (H, W, 3).
    output = np.array([preds, ycbcr[..., 1], ycbcr[..., 2]]).transpose([1, 2, 0])
    return np.clip(convert_ycbcr_to_rgb(output), 0.0, 255.0).astype(np.uint8)


def main():
    """Evaluate every ``.pkl`` FSRCNN checkpoint in a fixed directory.

    For each checkpoint, every image in ``args.test_image_dir`` is
    super-resolved and saved to ``args.image_save_dir`` as
    ``<index>_fsrcnn_epoch_<N>.png``, where N is the last ``_``-separated
    token of the checkpoint's file stem.
    """
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    pkl_file = '1216_Mixed_public_data_our_FSRCNN_l1_ssim_loss_lr_decay_save_model'
    for filename in sorted(os.listdir(pkl_file)):
        if os.path.splitext(filename)[1] != '.pkl':
            continue
        print(filename)

        save_img_name = 'epoch' + '_' + filename.split('.')[-2].split('_')[-1]
        print(save_img_name)
        pretrained_model = os.path.join(pkl_file, filename)
        print(pretrained_model)

        args = Args(pretrained_model)
        print(args.pretrained_model)

        model = _load_fsrcnn_checkpoint(args.pretrained_model, device)

        image_dir = args.test_image_dir
        image_filenames = [os.path.join(image_dir, x)
                           for x in sorted(os.listdir(image_dir))]
        image_save_dir = args.image_save_dir
        os.makedirs(image_save_dir, mode=0o777, exist_ok=True)

        for idx, image_path in enumerate(image_filenames):
            output = _fsrcnn_super_resolve(model, image_path, args.scale, device)
            final_image_name = str(idx) + '_' + 'fsrcnn' + '_' + save_img_name + '.png'
            print(final_image_name)
            imageio.imsave(os.path.join(image_save_dir, final_image_name), output)
예제 #6
0
        img.size[1] // args.scale) * args.scale
    # NOTE(review): PIL's resize takes (width, height). Confirm that
    # `height`/`weight` (probably meant "width") are bound in that order in
    # the statement above this excerpt.
    # Downscale to make the low-res input, then bicubic-upscale it back.
    lr = img.resize((height // args.scale, weight // args.scale),
                    Image.BICUBIC)
    bicubic = lr.resize((height, weight), Image.BICUBIC)
    lr = pre_process(lr.convert('L')).to(device)

    # NOTE(review): not wrapped in torch.no_grad(), hence the .detach() below;
    # .numpy() also fails if `device` is CUDA unless .cpu() is added.
    tensor_sr = net(lr)
    # NOTE(review): assumes `img` was already cropped to (height, weight)
    # above; otherwise img_y and sr_y differ in shape.
    img_y = np.array(img.convert('L')) / 255.0
    sr_y = tensor_sr.squeeze(0).squeeze(0).detach().numpy()

    # Keep the network's luma and take Cb/Cr from the bicubic upscale.
    # NOTE(review): sr_y is PIL 'L' luma while Cb/Cr come from
    # convert_rgb_to_ycbcr - confirm both use the same Y convention.
    ycbcr = convert_rgb_to_ycbcr(np.array(bicubic)) / 255.0
    sr_ycbcr = np.zeros((sr_y.shape[0], sr_y.shape[1], 3))
    sr_ycbcr[..., 0] = sr_y
    sr_ycbcr[..., 1:3] = ycbcr[..., 1:3]
    sr = convert_ycbcr_to_rgb(sr_ycbcr * 255.0) / 255.0

    # PSNR on the luma channel; sr_y is not clipped to [0, 1] first.
    mse = np.mean((img_y - sr_y)**2)
    PSNR = psnr_calculate(mse)

    print('PSNR: {:.2f}'.format(PSNR))

    # Side-by-side display; `sr` is not clipped to [0, 1] before imshow.
    fig = plt.figure()
    ax1 = fig.add_subplot(1, 2, 1)
    ax1.imshow(bicubic)
    ax1.title.set_text('Bicubic')
    ax2 = fig.add_subplot(1, 2, 2)
    ax2.imshow(sr)
    ax2.title.set_text('Reconstructed')
    plt.show()
예제 #7
0
def main(video='/home/zqh/Videos/newland.flv', weights='./fsrcnn_x2.pth',
         threshold=0.5, stride=40, scale=2,
         orginal_method: bool = False,
         export: bool = False):
  """Play a video, super-resolving only high-detail patches, and display it live.

  Each frame is split into patches with ``window_split(im, stride)``. Patches
  whose normalized total variation exceeds ``threshold`` go through FSRCNN;
  the rest are bilinearly upscaled. The original frame, the TV map and the
  merged result are drawn in a three-row matplotlib figure.

  Args:
    video: path of the input video.
    weights: FSRCNN state-dict file.
    threshold: cut-off on the [0, 1]-normalized TV for using the network.
    stride: patch stride passed to window_split/window_merge.
    scale: upscaling factor.
    orginal_method: if True, run the network on the Y channel only and take
      Cb/Cr from a bicubic upscale; otherwise run it on each RGB channel.
    export: if True, also write the result to
      ``<video dir>/<video stem>_<scale>x_out.mp4``.
  """
  if torch.cuda.is_available():
    device = torch.device('cuda')
  else:
    device = torch.device('cpu')

  model = FSRCNN(scale)
  # map_location lets a checkpoint saved from a GPU run load on a CPU-only host.
  model.load_state_dict(torch.load(weights, map_location=device))
  model = model.eval().to(device)

  ptv = PatchTotalVariation()
  video = Path(video)
  g, length, fps, height, width = get_read_stream(video)
  print(f"Video info: \nheight: {height} width:{width} length:{length} fps:{fps}")
  print(f"Runing info: \nthreshold: {threshold} stride:{stride} scale:{scale} orginal_method:{orginal_method}")
  if export:
    video_export = video.parent / (video.stem + f'_{scale}x_out' + '.mp4')
    writer = get_writer_stream(video_export, fps, height * scale, width * scale)
    # NOTE(review): the writer is never closed/released in this function;
    # confirm whether get_writer_stream's object needs that to finalize the file.

  plt.ion()
  fig, (ax1, ax2, ax3) = plt.subplots(3, 1, figsize=(6, 9))
  title = plt.title('fps:')
  ax1.set_xticks([])
  ax1.set_yticks([])
  ax1im = ax1.imshow(np.zeros((height, width, 3)))
  ax2.set_xticks([])
  ax2.set_yticks([])
  ax2im = ax2.imshow(np.zeros((height // stride, width // stride)), vmin=0, vmax=1)
  ax3.set_xticks([])
  ax3.set_yticks([])
  ax3im = ax3.imshow(np.zeros((height * scale, width * scale, 3)))
  plt.tight_layout()
  plt.show()
  for im in g:
    orginal_rgb = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
    if orginal_method:
      hr_rgb = pil_image.fromarray(orginal_rgb, mode='RGB').resize(
          (width * scale, height * scale), resample=pil_image.BICUBIC)
      ycbcr = convert_rgb_to_ycbcr(orginal_rgb.astype(np.float32))
      hr_ycbcr = convert_rgb_to_ycbcr(np.array(hr_rgb).astype(np.float32))
      im = torch.tensor(ycbcr[..., 0:1],
                        dtype=torch.float32, device=device)
    else:
      im = torch.tensor(orginal_rgb, dtype=torch.float32, device=device)
    channel = im.shape[-1]

    start_time = time.time()
    # Inference only: no autograd graph is needed for any per-frame work.
    with torch.no_grad():
      split_im, hw = window_split(im, stride)
      split_tv = ptv(split_im)
      # Normalize TV to [0, 1]. A perfectly flat frame has max 0, and 0/0
      # would fill the map with NaN, so skip the division in that case.
      tv_max = split_tv.max()
      if tv_max > 0:
        split_tv.div_(tv_max)
      boolean = split_tv > threshold
      true_idx, true_im, false_idx, false_im = mask_thresh(split_im, hw, boolean)
      # Mark network-processed patches as 1 in the displayed TV map.
      split_tv[boolean] = 1

      # The model accepts only channel == 1, so fold channels into the batch
      # axis for the forward pass and unfold them afterwards.
      if true_im.numel() > 0:
        true_im = (model(true_im.reshape(-1, 1, *true_im.shape[2:]) / 255.).clamp(0.0, 1.0) * 255.)
        true_im = true_im.byte().reshape((-1, channel, *true_im.shape[2:]))
      # Cheap bilinear upscaling for the low-detail patches.
      if false_im.numel() > 0:
        false_im = F.interpolate(false_im,
                                 scale_factor=scale,
                                 mode='bilinear',
                                 align_corners=False)
      # Put both patch sets back in order and reassemble the frame.
      processed_im = mask_inverse(true_idx, true_im, false_idx, false_im, hw)
      new_im = window_merge(processed_im, hw, stride, scale)

    ax1im.set_data(orginal_rgb)
    ax2im.set_data(split_tv.detach().to('cpu').numpy())

    if orginal_method:
      new_ycbcr = np.concatenate(
          (new_im.detach().to('cpu').numpy(), hr_ycbcr[..., 1:]), -1)
      new_rgb = np.clip(convert_ycbcr_to_rgb(new_ycbcr), 0., 255.).astype('uint8')
    else:
      new_rgb = new_im.detach().to('cpu').numpy().astype('uint8')
    ax3im.set_data(new_rgb)
    title.set_text(f'fps: {1.0 / (time.time() - start_time):.3f}')
    if export:
      writer.write(cv2.cvtColor(new_rgb, cv2.COLOR_RGB2BGR))
    fig.canvas.flush_events()