Exemplo n.º 1
0
def test(model, Dataset, save_dir=args.save_dir, output_name='output.png'):
    """Run ``model`` over ``Dataset`` and return the mean PSNR.

    Every sample yielded by the loader is unpacked into three frames. The
    model is called on the first and third frame and its output is compared
    against the second one. Each predicted frame is written to
    ``save_dir + '/' + <batch_index> + output_name``.

    PSNR is computed as ``-10 * log10(MSE)``, i.e. for a peak signal value
    of 1, matching the ``range=(0, 1)`` passed to ``imwrite``.

    Args:
        model: Callable invoked as ``model(Frame1, Frame3)``.
        Dataset: Dataset handed to ``DataLoader`` (batch size 1, unshuffled,
            no worker processes).
        save_dir: Output directory. The default is read from the module-level
            ``args`` once, when this function is defined.
        output_name: Suffix appended to the batch index to build file names.

    Returns:
        The PSNR averaged over all batches, or ``0.0`` when the loader yields
        no batch at all (previously a ``ZeroDivisionError``).
    """
    test_loader = DataLoader(dataset=Dataset, batch_size=1, shuffle=False, num_workers=0)

    psnr_sum = 0
    total_batches = 0
    # Evaluation only: without no_grad every forward pass would record an
    # autograd graph that is never used.
    with torch.no_grad():
        for batch_index, (Frame1, Frame2, Frame3) in enumerate(test_loader):
            Frame1 = to_variable(Frame1)
            Frame2 = to_variable(Frame2)
            Frame3 = to_variable(Frame3)

            frame_out = model(Frame1, Frame3)
            gt = Frame2

            diff = gt - frame_out
            psnr_sum += -10 * log10(torch.mean(diff * diff).item())
            imwrite(frame_out, save_dir + '/' + '{}'.format(batch_index) + output_name, range=(0, 1))
            total_batches += 1

    if total_batches == 0:
        # Nothing was evaluated; avoid dividing by zero.
        return 0.0
    return psnr_sum / total_batches
 def Test(self,
          model,
          output_dir,
          current_epoch,
          logfile=None,
          output_name='output.png'):
     """Evaluate ``model`` on the stored test items and report PSNR.

     For every entry of ``self.im_list`` the model is called on the matching
     entries of ``self.input0_list`` and ``self.input1_list``; the result is
     compared with ``self.gt_list`` and written to
     ``output_dir/<item name>/<output_name>``. Per-item and average PSNR
     lines are printed and, when ``logfile`` is given, also written to it.

     PSNR is computed as ``-10 * log10(MSE)``, i.e. for a peak value of 1.

     Args:
         model: Network to evaluate; switched to eval mode here.
         output_dir: Root directory for the predicted frames.
         current_epoch: Integer written in the log header line.
         logfile: Optional object with a ``write`` method.
         output_name: File name used inside each item's directory.

     Raises:
         ZeroDivisionError: If ``self.im_list`` is empty.
     """
     model.eval()
     av_psnr = 0
     if logfile is not None:
         logfile.write('{:<7s}{:<3d}'.format('Epoch: ', current_epoch) +
                       '\n')
     # Pure evaluation: disabling autograd avoids building a graph per item.
     with torch.no_grad():
         for idx, name in enumerate(self.im_list):
             item_dir = output_dir + '/' + name
             # exist_ok replaces the racy exists()-then-makedirs() pair.
             os.makedirs(item_dir, exist_ok=True)

             frame_out = model(self.input0_list[idx], self.input1_list[idx])
             gt = self.gt_list[idx]
             diff = gt - frame_out
             psnr = -10 * log10(torch.mean(diff * diff).item())
             av_psnr += psnr

             imwrite(frame_out, item_dir + '/' + output_name, range=(0, 1))
             msg = '{:<15s}{:<20.16f}'.format(name + ': ', psnr) + '\n'
             print(msg, end='')
             if logfile is not None:
                 logfile.write(msg)
     av_psnr /= len(self.im_list)
     msg = '{:<15s}{:<20.16f}'.format('Average: ', av_psnr) + '\n'
     print(msg, end='')
     if logfile is not None:
         logfile.write(msg)
Exemplo n.º 3
0
def main():
    """Interpolate a single frame between two input images and save it.

    Command-line arguments are parsed with the module-level ``parser``.
    The file named by ``args.config`` is read line by line; lines of the form
    ``kernel_size: N``, ``flow_num: N`` or ``dilation: N`` override the
    corresponding attribute on ``args`` (as ``int``) before the model is
    built. The checkpoint's ``'state_dict'`` entry is loaded into the model,
    the model is run on ``args.first_frame`` and ``args.second_frame``, and
    the output is written to ``args.output_frame``.
    """
    args = parser.parse_args()
    torch.cuda.set_device(args.gpu_id)

    int_options = ('kernel_size', 'flow_num', 'dilation')
    # The with-block guarantees the file is closed even if a value fails
    # to parse as an integer.
    with open(args.config, 'r') as config_file:
        for line in config_file:
            # Skip lines that start with ':' (no key in front of the colon).
            # The original compared the int returned by find() with the
            # string '0', which is never equal.
            if line.find(':') == 0:
                continue
            tmp_list = line.split(': ')
            if tmp_list[0] in int_options:
                setattr(args, tmp_list[0], int(tmp_list[1]))

    model = models.Model(args)

    checkpoint = torch.load(args.checkpoint, map_location=torch.device('cpu'))
    model.load(checkpoint['state_dict'])

    frame_name1 = args.first_frame
    frame_name2 = args.second_frame

    frame1 = to_variable(transform(Image.open(frame_name1)).unsqueeze(0))
    frame2 = to_variable(transform(Image.open(frame_name2)).unsqueeze(0))

    model.eval()
    # Inference only: no autograd graph is needed for a single forward pass.
    with torch.no_grad():
        frame_out = model(frame1, frame2)
    imwrite(frame_out.clone(), args.output_frame, range=(0, 1))
Exemplo n.º 4
0
    def test(self, model, output_dir, output_name='output', file_stream=None):
        """Evaluate ``model`` on the stored items and report PSNR, SSIM and LPIPS.

        For each entry of ``self.im_list`` the matching inputs are given a
        leading axis, moved to the GPU and fed to the model. The prediction is
        saved as ``output_dir/<item name>/<output_name>.png`` and compared with
        ``self.gt_list[idx]``:

        * LPIPS via ``lpips(..., net_type='squeeze')`` on the GPU tensors;
        * PSNR and SSIM via ``skimage.metrics`` on NumPy arrays (for SSIM both
          arrays are transposed to axis order ``(1, 2, 0)`` and treated as
          multichannel).

        Each per-item line and the final average line goes through
        ``print_and_save`` when ``file_stream`` is truthy, otherwise ``print``.

        Args:
            model: Network to evaluate; switched to eval mode here.
            output_dir: Root directory for the predicted frames.
            output_name: File stem used inside each item's directory.
            file_stream: Optional stream forwarded to ``print_and_save``.

        Returns:
            The average PSNR over all items.

        Raises:
            ZeroDivisionError: If ``self.im_list`` is empty.
        """
        model.eval()
        av_ssim = 0
        av_psnr = 0
        av_lpips = 0
        with torch.no_grad():
            print('%25s%21s%21s' % ('PSNR', 'SSIM', 'lpips'))
            for idx, name in enumerate(self.im_list):
                item_dir = output_dir + '/' + name
                # exist_ok replaces the racy exists()-then-makedirs() pair.
                os.makedirs(item_dir, exist_ok=True)

                in0 = self.input0_list[idx].unsqueeze(0).cuda()
                in1 = self.input1_list[idx].unsqueeze(0).cuda()
                frame_out = model(in0, in1)

                # .cuda() returns a copy; the stored ground truth itself is
                # not moved, which is why .numpy() below works on it directly.
                # (The original's trailing `.to('cpu')` call discarded its
                # result and has been dropped.)
                lps_value = lpips(self.gt_list[idx].cuda(),
                                  frame_out,
                                  net_type='squeeze').item()

                imwrite(frame_out,
                        item_dir + '/' + output_name + '.png',
                        range=(0, 1))

                frame_out = frame_out.squeeze().detach().cpu().numpy()
                gt = self.gt_list[idx].numpy()

                psnr = skimage.metrics.peak_signal_noise_ratio(
                    image_true=gt, image_test=frame_out)
                ssim = skimage.metrics.structural_similarity(
                    np.transpose(gt, (1, 2, 0)),
                    np.transpose(frame_out, (1, 2, 0)),
                    multichannel=True)

                av_psnr += psnr
                av_ssim += ssim
                av_lpips += lps_value

                msg = '{:<15s}{:<20.16f}{:<23.16f}{:<23.16f}'.format(
                    name + ': ', psnr, ssim, lps_value)
                if file_stream:
                    print_and_save(msg, file_stream)
                else:
                    print(msg)

        count = len(self.im_list)
        av_psnr /= count
        av_ssim /= count
        av_lpips /= count
        msg = '\n{:<15s}{:<20.16f}{:<23.16f}{:<23.16f}'.format(
            'Average: ', av_psnr, av_ssim, av_lpips)
        if file_stream:
            print_and_save(msg, file_stream)
        else:
            print(msg)

        return av_psnr
 def Test(self, model, output_dir='./evaluation/output',
          output_name='frame10i11.png'):
     """Write the model's prediction for every stored input pair.

     The model is put in eval mode, then for each entry of ``self.im_list``
     it is called on the matching entries of ``self.input0_list`` and
     ``self.input1_list``. The result is saved as
     ``output_dir/<item name>/<output_name>``; the item directory is created
     when it does not exist yet.

     Args:
         model: Network to run.
         output_dir: Root directory for the results.
         output_name: File name used inside each item's directory.
     """
     model.eval()
     for position, item_name in enumerate(self.im_list):
         target_dir = output_dir + '/' + item_name
         if not os.path.exists(target_dir):
             os.makedirs(target_dir)
         prediction = model(self.input0_list[position],
                            self.input1_list[position])
         destination = target_dir + '/' + output_name
         imwrite(prediction, destination, range=(0, 1))
Exemplo n.º 6
0
    def Test(self, model, output_dir, logfile=None):
        """Interpolate between each pair of consecutive files in ``self.input_dir``.

        The directory listing is sorted by name. For every file and its
        successor, both images are opened, passed through ``self.transform``,
        given a leading axis and wrapped with ``to_variable``; the model
        output is written to ``output_dir`` under the first file's name with
        an ``'a'`` inserted before the extension (``0001.png`` ->
        ``0001a.png``).

        Args:
            model: Callable invoked as ``model(img1, img2)``.
            output_dir: Directory receiving the interpolated frames.
            logfile: Accepted for interface compatibility; not used here.
        """
        file_list = sorted(os.listdir(self.input_dir))
        for first_name, second_name in zip(file_list, file_list[1:]):
            # splitext cuts at the last dot only, so names with several dots
            # (or none) no longer break the two-value unpacking that
            # re.split('\.') required.
            stem, ext = os.path.splitext(first_name)
            out_filename = stem + 'a' + ext

            # Context managers close both images even if the model or the
            # writer raises; they stay open until the frame has been written.
            with Image.open(self.input_dir + '/' + first_name) as im1, \
                    Image.open(self.input_dir + '/' + second_name) as im2:
                img1 = to_variable(self.transform(im1).unsqueeze(0))
                img2 = to_variable(self.transform(im2).unsqueeze(0))

                frame_out = model(img1, img2)
                imwrite(frame_out, output_dir + '/' + out_filename)
0
 def Test(self, model, output_dir, logfile=None, output_name='output.png'):
     """Run ``model`` on every stored input pair and save comparison plots.

     For each name in ``self.im_list`` the model is evaluated on every pair
     ``(self.input0_list[idx], self.input1_list[idx])``. Each prediction is
     written to ``output_dir/<name>/<output_name>`` and a side-by-side figure
     (ground truth left, prediction right) is saved as
     ``output_dir/<name>/<name>_<idx>.png``. Per-pair PSNR and a per-name
     average are printed and, when ``logfile`` is given, written to it.

     PSNR is computed as ``-10 * log10(MSE)``, i.e. for a peak value of 1.

     NOTE(review): the same input lists are evaluated for every name in
     ``self.im_list`` and ``<output_name>`` is overwritten on every inner
     iteration - confirm this is intended.

     Args:
         model: Network to run; ``model.epoch.item()`` is read for the log
             header when ``logfile`` is given.
         output_dir: Root directory for images and figures.
         logfile: Optional object with a ``write`` method.
         output_name: File name of the raw prediction inside each directory.

     Raises:
         ZeroDivisionError: If ``self.input0_list`` is empty while
             ``self.im_list`` is not.
     """

     def as_image(img):
         # Drop the leading axis and move the next one to the end
         # (presumably NCHW -> HWC for imshow; confirm against the data).
         return np.moveaxis(img.cpu().detach().numpy().squeeze(0), 0, -1)

     if logfile is not None:
         logfile.write(
             '{:<7s}{:<3d}'.format('Epoch: ', model.epoch.item()) + '\n')

     # Loop-invariant: the default figure size does not change per frame.
     fig_size = plt.rcParams['figure.figsize']
     for seq_name in self.im_list:
         seq_dir = output_dir + '/' + seq_name
         os.makedirs(seq_dir, exist_ok=True)

         # Reset per name; previously the running value carried the previous
         # name's average into the next one and skewed every later average.
         av_psnr = 0
         for idx in range(len(self.input0_list)):
             fig, ax = plt.subplots(1,
                                    2,
                                    figsize=(fig_size[0] * 2,
                                             fig_size[1] * 1))
             frame_out = model(self.input0_list[idx], self.input1_list[idx])
             gt = self.gt_list[idx]
             imwrite(frame_out, seq_dir + '/' + output_name, range=(0, 1))

             frame_out = as_image(frame_out)
             gt = as_image(gt)
             ax[1].imshow(frame_out)
             ax[0].imshow(gt)
             ax[0].set_title("Ground Truth")
             ax[1].set_title("Predicted")
             fig.savefig(
                 os.path.join(output_dir, seq_name,
                              f"{seq_name}_{idx}.png"))
             # Closing (not just clearing) releases the figure; otherwise one
             # figure per frame stays registered with pyplot.
             plt.close(fig)

             psnr = -10 * log10(np.mean(
                 (gt - frame_out) * (gt - frame_out)))
             av_psnr += psnr
             msg = '{:<15s}{:<20.16f}'.format(seq_name + ': ', psnr) + '\n'
             print(msg, end='')
             if logfile is not None:
                 logfile.write(msg)
         av_psnr /= len(self.input0_list)
         msg = '{:<15s}{:<20.16f}'.format('Average: ', av_psnr) + '\n'
         print(msg, end='')
         if logfile is not None:
             logfile.write(msg)
Exemplo n.º 8
0
def main():
    """Interpolate one frame between two images with a CDFI_adacof model.

    Parses the command line, selects the GPU given by ``gpu_id``, builds the
    model on the GPU, loads the ``'state_dict'`` entry of the checkpoint, runs
    the model on the two input frames without gradient tracking and writes
    the result to ``output_frame``.
    """
    cli = parse_args()
    torch.cuda.set_device(cli.gpu_id)

    net = CDFI_adacof(cli).cuda()

    print('Loading the model...')
    saved = torch.load(cli.checkpoint)
    net.load_state_dict(saved['state_dict'])

    first_path, second_path = cli.first_frame, cli.second_frame

    to_tensor = transforms.Compose([transforms.ToTensor()])

    def load_frame(path):
        # Image -> tensor, add a leading axis, move to the GPU.
        return to_tensor(Image.open(path)).unsqueeze(0).cuda()

    first = load_frame(first_path)
    second = load_frame(second_path)

    net.eval()
    with torch.no_grad():
        interpolated = net(first, second)

    imwrite(interpolated.clone(), cli.output_frame, range=(0, 1))
Exemplo n.º 9
0
def main():
    """Double a frame sequence by inserting one interpolated frame per pair.

    Command-line arguments are parsed with the module-level ``parser``.
    The file named by ``args.config`` is read line by line; lines of the form
    ``kernel_size: N``, ``flow_num: N`` or ``dilation: N`` override the
    corresponding attribute on ``args`` (as ``int``) before the model is
    built.

    Input frames are read from ``args.input_video`` as zero-padded
    ``<number>.png`` files starting at ``args.index_from``. The number of
    frames is the number of regular files in that directory. For every
    consecutive pair the first frame and the model output are written to
    ``args.output_video`` at positions ``2*k`` and ``2*k + 1`` (offset by
    ``args.index_from``); the last input frame is copied at the end.
    """
    args = parser.parse_args()
    torch.cuda.set_device(args.gpu_id)

    int_options = ('kernel_size', 'flow_num', 'dilation')
    # The with-block guarantees the file is closed even if a value fails
    # to parse as an integer.
    with open(args.config, 'r') as config_file:
        for line in config_file:
            # Skip lines that start with ':' (no key in front of the colon).
            if line.find(':') == 0:
                continue
            tmp_list = line.split(': ')
            if tmp_list[0] in int_options:
                setattr(args, tmp_list[0], int(tmp_list[1]))

    model = models.Model(args)

    print('Loading the model...')

    checkpoint = torch.load(args.checkpoint, map_location=torch.device('cpu'))
    model.load(checkpoint['state_dict'])

    base_dir = args.input_video

    # exist_ok replaces the racy exists()-then-makedirs() pair.
    os.makedirs(args.output_video, exist_ok=True)

    frame_len = len([
        name for name in os.listdir(base_dir)
        if os.path.isfile(os.path.join(base_dir, name))
    ])

    def frame_path(directory, number):
        # Zero-padded "<number>.png" inside `directory`.
        return directory + '/' + str(number).zfill(args.zpad) + '.png'

    # Loop-invariant: switch to eval mode once instead of on every pair.
    model.eval()
    # Inference only: without no_grad each forward pass would record an
    # autograd graph that is never used.
    with torch.no_grad():
        for offset in range(frame_len - 1):
            idx = offset + args.index_from
            print(idx, '/', frame_len - 1, end='\r')

            frame1 = to_variable(
                transform(Image.open(frame_path(base_dir, idx))).unsqueeze(0))
            frame2 = to_variable(
                transform(Image.open(frame_path(base_dir,
                                                idx + 1))).unsqueeze(0))

            frame_out = model(frame1, frame2)

            # Original frame goes to the even slot, interpolated to the odd.
            imwrite(frame1.clone(),
                    frame_path(args.output_video,
                               offset * 2 + args.index_from),
                    range=(0, 1))
            imwrite(frame_out.clone(),
                    frame_path(args.output_video,
                               offset * 2 + 1 + args.index_from),
                    range=(0, 1))

    # last frame
    print(frame_len - 1, '/', frame_len - 1)
    frame_name_last = frame_path(base_dir, frame_len + args.index_from - 1)
    frame_last = to_variable(
        transform(Image.open(frame_name_last)).unsqueeze(0))
    imwrite(frame_last.clone(),
            frame_path(args.output_video,
                       (frame_len - 1) * 2 + args.index_from),
            range=(0, 1))
Exemplo n.º 10
0
    print('--- weight loaded ---')
except:
    print('--- no weight loaded ---')

# --- Start testing --- #
# Runs MyEnsembleNet over val_loader with gradients disabled, writes every
# output to `output_dir` as "<batch_idx>.png" and prints the mean wall-clock
# time measured around each forward call.
with torch.no_grad():
    # Every output tensor is appended to img_list and kept until the loop
    # ends, although each one is only read back once for imwrite below.
    img_list = []
    # Per-batch elapsed seconds, measured with time.time().
    time_list = []
    MyEnsembleNet.eval()
    imsave_dir = output_dir
    if not os.path.exists(imsave_dir):
        os.makedirs(imsave_dir)
    for batch_idx, hazy in enumerate(val_loader):
        # print(len(val_loader))
        start = time.time()
        hazy = hazy.to(device)

        # NOTE(review): `hazy_up` is not defined anywhere in the visible code,
        # while `hazy` (just moved to `device`) is never used. Either a
        # preprocessing step producing `hazy_up` is missing or this should be
        # `hazy` - confirm.
        img_tensor = MyEnsembleNet(hazy_up)

        end = time.time()
        time_list.append((end - start))
        img_list.append(img_tensor)

        imwrite(img_list[batch_idx],
                os.path.join(imsave_dir,
                             str(batch_idx) + '.png'))
    # Raises ZeroDivisionError when val_loader yields no batch.
    time_cost = float(sum(time_list) / len(time_list))
    print('running time per image: ', time_cost)

# writer.close()
Exemplo n.º 11
0
def main():
    """Double a frame sequence with a CDFI_adacof model.

    Input frames are read from ``args.input_video`` as zero-padded
    ``<number><args.img_format>`` files starting at ``args.index_from``. The
    number of frames is the number of regular files in that directory. For
    every consecutive pair the first frame and the model output are written
    to ``args.output_video`` at positions ``2*k`` and ``2*k + 1`` (offset by
    ``args.index_from``); the last input frame is copied at the end.
    """
    args = parse_args()
    torch.cuda.set_device(args.gpu_id)

    transform = transforms.Compose([transforms.ToTensor()])

    model = CDFI_adacof(args).cuda()

    print('Loading the model...')
    checkpoint = torch.load(args.checkpoint)
    model.load_state_dict(checkpoint['state_dict'])

    base_dir = args.input_video

    # exist_ok replaces the racy exists()-then-makedirs() pair.
    os.makedirs(args.output_video, exist_ok=True)

    frame_len = len([
        name for name in os.listdir(base_dir)
        if os.path.isfile(os.path.join(base_dir, name))
    ])

    def frame_path(directory, number):
        # Zero-padded "<number><img_format>" inside `directory`.
        return directory + '/' + str(number).zfill(
            args.zpad) + args.img_format

    # Loop-invariant: switch to eval mode once instead of on every pair.
    model.eval()
    for offset in range(frame_len - 1):
        idx = offset + args.index_from
        print(idx, '/', frame_len - 1, end='\r')

        frame1 = transform(Image.open(frame_path(base_dir,
                                                 idx))).unsqueeze(0).cuda()
        frame2 = transform(Image.open(frame_path(
            base_dir, idx + 1))).unsqueeze(0).cuda()

        with torch.no_grad():
            frame_out = model(frame1, frame2)

        # Original frame goes to the even slot, interpolated to the odd one.
        imwrite(frame1.clone(),
                frame_path(args.output_video, offset * 2 + args.index_from),
                range=(0, 1))
        imwrite(frame_out.clone(),
                frame_path(args.output_video,
                           offset * 2 + 1 + args.index_from),
                range=(0, 1))

    # last frame
    print(frame_len - 1, '/', frame_len - 1)
    frame_name_last = frame_path(base_dir, frame_len + args.index_from - 1)
    # The last frame is only copied, so it is not moved to the GPU.
    frame_last = transform(Image.open(frame_name_last)).unsqueeze(0)
    imwrite(frame_last.clone(),
            frame_path(args.output_video,
                       (frame_len - 1) * 2 + args.index_from),
            range=(0, 1))
Exemplo n.º 12
0
            motion_vector_u = int(accumulated_motion_vectors[iiii, 0])
            motion_vector_v = int(accumulated_motion_vectors[iiii, 1])
            cropped_mask = large_mask_chain[iiii][:, :,
                                                  HHH + motion_vector_u:-HHH +
                                                  motion_vector_u,
                                                  WWW + motion_vector_v:-WWW +
                                                  motion_vector_v]
            print(motion_vector_u)
            print(motion_vector_v)
            print(cropped_mask.shape)
            """imwrite(output_frames[iiii][:, :, HHH+motion_vector_u:-HHH+motion_vector_u, WWW+motion_vector_v:-WWW+motion_vector_v], os.path.join(OUTPUT_PATH, avi_name, str(iiii+1).zfill(5)+'.png'), range=(0, 1))"""
            # if OOM
            imwrite(torch.from_numpy(np.load(
                output_frames[iiii]))[:, :, HHH + motion_vector_u:-HHH +
                                      motion_vector_u, WWW +
                                      motion_vector_v:-WWW + motion_vector_v],
                    os.path.join(OUTPUT_PATH,
                                 str(iiii + 1).zfill(5) + '.png'),
                    range=(0, 1))

            summed_mask = (torch.sum(1.0 - cropped_mask)).cpu().numpy()
            loss += summed_mask
        print(loss)

        # loss without adjustment
        loss = 0.0
        for iiii in range(len(large_mask_chain)):
            cropped_mask = large_mask_chain[iiii][:, :, HHH:-HHH, WWW:-WWW]
            summed_mask = (torch.sum(1.0 - cropped_mask)).cpu().numpy()
            loss += summed_mask
        print(loss)