Ejemplo n.º 1
0
    def inference(self, input_image, output_dir, save_images=False):
        """Super-resolve ``input_image`` by ``self.scale`` and return the RGB result.

        Only the luminance (Y) channel goes through the network (``self.run``);
        the chroma (CbCr) channels come from a plain resized copy of the RGB
        input and are recombined with the network's Y output.

        Args:
            input_image: RGB image accepted by ``resize_image`` and
                ``convert_rgb_to_y`` (presumably an HxWx3 array — confirm
                against those helpers).
            output_dir: Directory prefix used for the debug images written
                when ``save_images`` is true.
            save_images: If true, write the original, resized, Y-channel and
                final images as JPEGs into ``output_dir``.

        Returns:
            The RGB image produced by ``convert_y_and_cbcr_to_rgb``.
        """
        # The RGB image must be resized by the same factor as the Y channel,
        # otherwise its CbCr planes will not line up with the network output.
        scaled_image = resize_image(input_image, self.scale)

        # Y channel of the input and its resized version (network inputs).
        input_y_image = convert_rgb_to_y(input_image)
        scaled_y_image = resize_image(input_y_image, self.scale)

        output_y_image = self.run(input_y_image, scaled_y_image)

        # Combine the network's Y with the CbCr of the resized RGB image.
        scaled_ycbcr_image = convert_rgb_to_ycbcr(scaled_image)
        result_image = convert_y_and_cbcr_to_rgb(output_y_image,
                                                 scaled_ycbcr_image[:, :, 1:3])

        if save_images:
            save_image(input_image, "{}/original.jpg".format(output_dir))
            save_image(scaled_image, "{}/bicubic.jpg".format(output_dir))
            save_image(scaled_y_image,
                       "{}/bicubic_y.jpg".format(output_dir),
                       is_rgb=False)
            save_image(output_y_image,
                       "{}/result_y.jpg".format(output_dir),
                       is_rgb=False)
            save_image(result_image, "{}/result.jpg".format(output_dir))

        return result_image
Ejemplo n.º 2
0
def SRCNN2(
    args, image_file
):  # CHANGE TO INPUT THE after-resize IMAGE FILE, SO IN THE OUTPUT3, NEED TO STORE THE denoise+resize image
    """Apply a pretrained SRCNN to ``image_file`` and return the result as a PIL image.

    Only the Y (luminance) channel is passed through the network; the Cb and
    Cr channels of the input are reused unchanged. Per the note on the
    signature, callers are expected to pass the already-resized image.

    Args:
        args: Namespace providing ``SR_scale``, which selects the weights file
            ``<cwd>/SRCNN_outputs/x<SR_scale>/best.pth``.
        image_file: Path of the image to enhance; it is converted to RGB.

    Returns:
        RGB ``PIL`` image whose luminance has been replaced by the SRCNN output.

    Raises:
        FileNotFoundError: If the weights file does not exist.
        KeyError: If the checkpoint contains a parameter the model does not have.
    """
    # The device string must not contain spaces: torch.device rejects 'cuda: 0'.
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    model = SRCNN().to(device)
    state_dict = model.state_dict()

    # Build the path with os.path.join so it works on every OS, not only Windows.
    weights_file = os.path.join(os.getcwd(), 'SRCNN_outputs',
                                'x{}'.format(args.SR_scale), 'best.pth')
    if not os.path.isfile(weights_file):
        raise FileNotFoundError(weights_file + ' not exist')

    # Load tensors onto the CPU; copy_ then transfers each one onto the
    # device holding the model's parameters.
    checkpoint = torch.load(weights_file,
                            map_location=lambda storage, loc: storage)
    for n, p in checkpoint.items():
        if n not in state_dict:
            raise KeyError(n)
        state_dict[n].copy_(p)

    model.eval()  # model set in evaluation mode

    # Context manager closes the underlying file once the RGB copy is made.
    with pil_image.open(image_file) as img:
        image = img.convert('RGB')

    image = np.array(image).astype(np.float32)
    ycbcr = convert_rgb_to_ycbcr(image)

    # Normalise luminance to [0, 1] and add batch/channel dims -> (1, 1, H, W).
    # Dividing out-of-place leaves ycbcr untouched.
    y = ycbcr[..., 0] / 255.
    y = torch.from_numpy(y).to(device).unsqueeze(0).unsqueeze(0)

    with torch.no_grad():
        preds = model(y).clamp(0.0, 1.0)

    # Back to the [0, 255] range and drop the batch/channel dims (tensor -> np).
    preds = preds.mul(255.0).cpu().numpy().squeeze(0).squeeze(0)

    # Stack Y', Cb, Cr along a new last axis to get a channels-last image.
    output = np.stack([preds, ycbcr[..., 1], ycbcr[..., 2]], axis=-1)
    output = np.clip(convert_ycbcr_to_rgb(output), 0.0, 255.0).astype(np.uint8)
    return pil_image.fromarray(output)
Ejemplo n.º 3
0
            state_dict[n].copy_(p)
        else:
            raise KeyError(n)

    model.eval()
    psnr_seq = []
    for image_path in sorted(glob.glob('{}*.png'.format(args.images_dir))):

        image = pil_image.open(image_path).convert('RGB')
        cur_w, cur_h = image.width * args.scale, image.height * args.scale
        image = image.resize(
            (image.width * args.scale, image.height * args.scale),
            resample=pil_image.BICUBIC)

        image = np.array(image).astype(np.float32)
        ycbcr = convert_rgb_to_ycbcr(image)

        y = ycbcr[..., 0]
        y /= 255.
        y = torch.from_numpy(y).to(device)
        y = y.unsqueeze(0).unsqueeze(0)

        with torch.no_grad():
            preds = model(y).clamp(0.0, 1.0)

        psnr = calc_psnr(y, preds)
        psnr_seq.append(psnr.cpu().item())
        print('{} PSNR: {:.2f}'.format(image_path.split('/')[-1], psnr))

        preds = preds.mul(255.0).cpu().numpy().squeeze(0).squeeze(0)
Ejemplo n.º 4
0
        net.eval()

    img = Image.open(args.test_img, mode='r').convert('RGB')
    height, weight = (img.size[0] // args.scale) * args.scale, (
        img.size[1] // args.scale) * args.scale

    lr = img.resize((height // args.scale, weight // args.scale),
                    Image.BICUBIC)
    bicubic = lr.resize((height, weight), Image.BICUBIC)
    lr = pre_process(lr.convert('L')).to(device)

    tensor_sr = net(lr)
    img_y = np.array(img.convert('L')) / 255.0
    sr_y = tensor_sr.squeeze(0).squeeze(0).detach().numpy()

    ycbcr = convert_rgb_to_ycbcr(np.array(bicubic)) / 255.0
    sr_ycbcr = np.zeros((sr_y.shape[0], sr_y.shape[1], 3))
    sr_ycbcr[..., 0] = sr_y
    sr_ycbcr[..., 1:3] = ycbcr[..., 1:3]
    sr = convert_ycbcr_to_rgb(sr_ycbcr * 255.0) / 255.0

    mse = np.mean((img_y - sr_y)**2)
    PSNR = psnr_calculate(mse)

    print('PSNR: {:.2f}'.format(PSNR))

    fig = plt.figure()
    ax1 = fig.add_subplot(1, 2, 1)
    ax1.imshow(bicubic)
    ax1.title.set_text('Bicubic')
    ax2 = fig.add_subplot(1, 2, 2)
Ejemplo n.º 5
0
def main(video='/home/zqh/Videos/newland.flv', weights='./fsrcnn_x2.pth',
         threshold=0.5, stride=40, scale=2,
         orginal_method: bool = False,
         export: bool = False):
  """Play a video while upscaling each frame with FSRCNN and display it live.

  Each frame is split into windows with ``window_split``. A patch
  total-variation score, normalised by the frame's maximum, decides per
  window whether it is upscaled by the FSRCNN model (score > ``threshold``)
  or by cheap bilinear interpolation. The original frame, the TV mask and
  the upscaled result are shown in a 3-row matplotlib figure, with the
  measured per-frame fps in the title.

  Args:
    video: Path of the input video.
    weights: Path of the FSRCNN ``state_dict`` checkpoint.
    threshold: Cut-off on the normalised TV score above which a window is
      sent through the model.
    stride: Window size passed to ``window_split`` / ``window_merge``
      (presumably square, non-overlapping windows — confirm in the helpers).
    scale: Upscaling factor given to ``FSRCNN`` and the bilinear fallback.
    orginal_method: If true, run the model on the Y channel only and take
      Cb/Cr from a bicubic PIL upscale of the frame. Otherwise the RGB
      channels are reshaped so that each is processed as a 1-channel image.
    export: If true, also write the upscaled frames to
      ``<video stem>_<scale>x_out.mp4`` next to the input video.
  """
  if torch.cuda.is_available():
    device = torch.device('cuda')
  else:
    device = torch.device('cpu')

  model = FSRCNN(scale)
  # map_location lets a checkpoint saved on a GPU load on a CPU-only host.
  model.load_state_dict(torch.load(weights, map_location=device))
  model = model.eval().to(device)

  ptv = PatchTotalVariation()
  video = Path(video)
  g, length, fps, height, width = get_read_stream(video)
  print(f"Video info: \nheight: {height} width:{width} length:{length} fps:{fps}")
  print(f"Runing info: \nthreshold: {threshold} stride:{stride} scale:{scale} orginal_method:{orginal_method}")
  if export:
    video_export = video.parent / (video.stem + f'_{scale}x_out' + '.mp4')
    writer = get_writer_stream(video_export, fps, height * scale, width * scale)

  # Interactive figure: original frame / TV mask / upscaled result.
  plt.ion()
  fig, (ax1, ax2, ax3) = plt.subplots(3, 1, figsize=(6, 9))
  title = plt.title('fps:')
  ax1.set_xticks([])
  ax1.set_yticks([])
  ax1im = ax1.imshow(np.zeros((height, width, 3)))
  ax2.set_xticks([])
  ax2.set_yticks([])
  ax2im = ax2.imshow(np.zeros((height // stride, width // stride)), vmin=0, vmax=1)
  ax3.set_xticks([])
  ax3.set_yticks([])
  ax3im = ax3.imshow(np.zeros((height * scale, width * scale, 3)))
  plt.tight_layout()
  plt.show()
  for im in g:
    # Frames arrive in OpenCV's BGR order.
    orginal_rgb = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
    if orginal_method:
      hr_rgb = pil_image.fromarray(orginal_rgb, mode='RGB').resize(
          (width * scale, height * scale), resample=pil_image.BICUBIC)
      ycbcr = convert_rgb_to_ycbcr(orginal_rgb.astype(np.float32))
      hr_ycbcr = convert_rgb_to_ycbcr(np.array(hr_rgb).astype(np.float32))
      # Keep only the Y channel (with a trailing channel axis) for the model.
      im = torch.tensor(ycbcr[..., 0:1],
                        dtype=torch.float32, device=device)
    else:
      # Reuse the RGB frame converted above instead of converting again.
      im = torch.tensor(orginal_rgb, dtype=torch.float32, device=device)
    channel = im.shape[-1]

    start_time = time.time()
    split_im, hw = window_split(im, stride)
    split_tv = ptv(split_im)
    # Normalise TV to [0, 1], and set patch color == 1 where tv > threshold.
    # Skip the division on a flat frame (max == 0, e.g. a black frame),
    # which would otherwise turn every score into NaN.
    tv_max = split_tv.max()
    if tv_max > 0:
      split_tv.div_(tv_max)
    boolean = split_tv > threshold
    true_idx, true_im, false_idx, false_im = mask_thresh(split_im, hw, boolean)
    split_tv[boolean] = 1

    # The model only accepts channel == 1, so fold channels into the batch axis.
    if true_im.numel() > 0:
      true_im = (model(true_im.reshape(-1, 1, *true_im.shape[2:]) / 255.).clamp(0.0, 1.0) * 255.)
      true_im = true_im.byte().reshape((-1, channel, *true_im.shape[2:]))
    # Fast bilinear upscaling for low-detail windows.
    if false_im.numel() > 0:
      false_im = F.interpolate(false_im,
                               scale_factor=scale,
                               mode='bilinear',
                               align_corners=False)
    # Merge both groups back into frame order and stitch the windows.
    processed_im = mask_inverse(true_idx, true_im, false_idx, false_im, hw)

    new_im = window_merge(processed_im, hw, stride, scale)

    ax1im.set_data(orginal_rgb)
    ax2im.set_data(split_tv.detach().to('cpu').numpy())

    if orginal_method:
      # Recombine the upscaled Y with the bicubic Cb/Cr channels.
      new_ycbcr = np.concatenate(
          (new_im.detach().to('cpu').numpy(), hr_ycbcr[..., 1:]), -1)
      new_rgb = np.clip(convert_ycbcr_to_rgb(new_ycbcr), 0., 255.).astype('uint8')
    else:
      new_rgb = new_im.detach().to('cpu').numpy().astype('uint8')
    ax3im.set_data(new_rgb)
    title.set_text(f'fps: {1.0 / (time.time() - start_time):.3f}')
    if export:
      writer.write(cv2.cvtColor(new_rgb, cv2.COLOR_RGB2BGR))
    fig.canvas.flush_events()