def _load_frame(path, device):
    """Read an image with OpenCV and return it as a (1, C, H, W) float tensor in [0, 1].

    The conversion order (transpose, divide by 255 in float64, move to
    ``device``, cast to float32, add batch dim) is kept identical to the
    original inline code so metric values do not change.

    Raises:
        FileNotFoundError: if OpenCV cannot read ``path`` (``cv2.imread``
            signals this by returning ``None`` instead of raising).
    """
    img = cv2.imread(path)
    if img is None:
        raise FileNotFoundError(f"could not read image: {path}")
    return torch.tensor(img.transpose(2, 0, 1) / 255.).to(device).float().unsqueeze(0)


def _time_inference(infer, runs, warmup, use_cuda):
    """Return the trimmed mean latency of ``infer()`` in milliseconds.

    ``infer`` is called ``warmup`` times untimed, then ``runs`` times timed.
    With CUDA, CUDA events are used (``elapsed_time`` reports milliseconds).
    Otherwise ``time.perf_counter`` is used and converted to milliseconds.
    When more than two samples exist, the single fastest and slowest are
    dropped before averaging, as the original benchmark did.

    Raises:
        ValueError: if ``runs`` is not positive.
    """
    import time

    if runs <= 0:
        raise ValueError(f"runs must be positive, got {runs}")
    for _ in range(warmup):
        infer()
    if use_cuda:
        # Ensure warm-up work has finished before the first timed call.
        torch.cuda.synchronize()

    samples = []
    for _ in range(runs):
        if use_cuda:
            start = torch.cuda.Event(enable_timing=True)
            end = torch.cuda.Event(enable_timing=True)
            start.record()
            infer()
            end.record()
            # Events are asynchronous; wait so elapsed_time is valid.
            torch.cuda.synchronize()
            samples.append(start.elapsed_time(end))
        else:
            t0 = time.perf_counter()
            infer()
            samples.append((time.perf_counter() - t0) * 1000.0)

    if len(samples) > 2:
        samples.remove(min(samples))
        samples.remove(max(samples))
    return float(np.mean(samples))


def main(model_dir='../train_log/HDv2',
         data_path='../../data/UCF101/ucf101_interp_ours/',
         timing_runs=100,
         warmup_runs=10):
    """Evaluate PSNR/SSIM on the UCF101 interpolation triplets and time inference.

    Each subdirectory of ``data_path`` is expected to hold ``frame_00.png``,
    ``frame_02.png`` (inputs) and ``frame_01_gt.png`` (ground-truth middle
    frame). The prediction is quantized to 8-bit levels before scoring with
    ``compute_psnr`` / ``compute_ssim``. Inference latency is then measured on
    the last loaded pair. This assumes all triplets share one resolution —
    TODO confirm.

    Prints average PSNR, average SSIM and the trimmed mean latency per
    ``model.inference`` call, in milliseconds.

    Note: an earlier recorded run printed
    "Avg PSNR: 35.243666269214145 SSIM: 0.9683315742368154 Time: 0.133457749911717".
    That time value came from dividing the per-call mean (in ms) by 100, so
    it corresponds to roughly 13.3 ms per call.

    Args:
        model_dir: directory passed to ``Model.load_model``.
        data_path: root directory containing one subdirectory per triplet.
        timing_runs: number of timed inference calls.
        warmup_runs: untimed calls made before timing, to exclude one-off
            GPU initialization cost from the measurement.

    Raises:
        FileNotFoundError: if ``data_path`` has no entries or a frame is unreadable.
    """
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    model = Model()
    model.load_model(model_dir, -1)
    model.eval()
    model.device()

    psnr_list = []
    ssim_list = []
    img0 = img1 = None
    # Pure evaluation: no autograd graph is needed, which saves memory and
    # keeps the latency measurement free of autograd bookkeeping.
    with torch.no_grad():
        # Sorted so runs are reproducible regardless of filesystem order.
        for d in tqdm(sorted(os.listdir(data_path))):
            sample_dir = os.path.join(data_path, d)
            img0 = _load_frame(os.path.join(sample_dir, 'frame_00.png'), device)
            img1 = _load_frame(os.path.join(sample_dir, 'frame_02.png'), device)
            gt = _load_frame(os.path.join(sample_dir, 'frame_01_gt.png'), device)

            pred = model.inference(img0, img1)[0]

            out = pred.detach().cpu().numpy().transpose(1, 2, 0)
            # Quantize to 8-bit levels so scores match a saved PNG output.
            out = np.round(out * 255) / 255.
            gt_np = gt[0].cpu().numpy().transpose(1, 2, 0)
            psnr_list.append(compute_psnr(gt_np, out))
            ssim_list.append(compute_ssim(gt_np, out))

        if img0 is None:
            raise FileNotFoundError(f"no samples found in {data_path}")

        mean_ms = _time_inference(lambda: model.inference(img0, img1)[0],
                                  timing_runs, warmup_runs, use_cuda)

    print("Avg PSNR: {} SSIM: {} Time: {} ms".format(
        np.mean(psnr_list), np.mean(ssim_list), mean_ms))
from model.RIFE_HDv2 import Model model = Model() model.load_model(args.modelDir, -1) print("Loaded v2.x HD model.") except: from train_log.RIFE_HDv3 import Model model = Model() model.load_model(args.modelDir, -1) print("Loaded v3.x HD model.") except: from model.RIFE_HD import Model model = Model() model.load_model(args.modelDir, -1) print("Loaded v1.x HD model") model.eval() model.device() if args.img[0].endswith('.exr') and args.img[1].endswith('.exr'): img0 = cv2.imread(args.img[0], cv2.IMREAD_COLOR | cv2.IMREAD_ANYDEPTH) img1 = cv2.imread(args.img[1], cv2.IMREAD_COLOR | cv2.IMREAD_ANYDEPTH) img0 = (torch.tensor(img0.transpose(2, 0, 1)).to(device)).unsqueeze(0) img1 = (torch.tensor(img1.transpose(2, 0, 1)).to(device)).unsqueeze(0) else: img0 = cv2.imread(args.img[0]) img1 = cv2.imread(args.img[1]) img0 = (torch.tensor(img0.transpose(2, 0, 1)).to(device) / 255.).unsqueeze(0) img1 = (torch.tensor(img1.transpose(2, 0, 1)).to(device) / 255.).unsqueeze(0)