def main():
    """Interpolate intermediate frames between two images and save them.

    Loads model weights from ``./train_log``, reads the two images named in
    ``args.img``, runs ``args.exp`` rounds of recursive midpoint
    interpolation (each round inserts a predicted frame between every
    adjacent pair, so the sequence grows from 2 to ``2**args.exp + 1``
    frames) and writes every frame to ``output/img{i}.png``.

    Relies on names defined elsewhere in the file: ``Model``, ``args``,
    ``fixargs``, ``device``, ``cv2``, ``torch``, ``F`` and ``os``.

    Raises:
        FileNotFoundError: if either input image cannot be read.
    """
    model = Model()
    model.load_model('./train_log')
    model.eval()
    model.device()

    # NOTE(review): fixargs is defined outside this view; presumably it
    # normalizes/validates args in place — confirm.
    fixargs(args)

    def _to_tensor(path):
        # cv2.imread signals a missing/unreadable file by returning None,
        # which would otherwise surface as an AttributeError on .transpose.
        img = cv2.imread(path)
        if img is None:
            raise FileNotFoundError('could not read image: {}'.format(path))
        # HxWxC array -> 1xCxHxW tensor scaled to [0, 1] on the target device.
        return (torch.tensor(img.transpose(2, 0, 1)).to(device) /
                255.).unsqueeze(0)

    img0 = _to_tensor(args.img[0])
    img1 = _to_tensor(args.img[1])

    # Pad right and bottom up to the next multiple of 32 (presumably a
    # network requirement — confirm); outputs are cropped back to (h, w).
    _, _, h, w = img0.shape
    ph = ((h - 1) // 32 + 1) * 32
    pw = ((w - 1) // 32 + 1) * 32
    padding = (0, pw - w, 0, ph - h)
    img0 = F.pad(img0, padding)
    img1 = F.pad(img1, padding)

    # Each round interleaves a predicted midpoint between every adjacent pair.
    img_list = [img0, img1]
    for _ in range(args.exp):
        tmp = []
        for left, right in zip(img_list[:-1], img_list[1:]):
            tmp.append(left)
            tmp.append(model.inference(left, right))
        tmp.append(img_list[-1])
        img_list = tmp

    os.makedirs('output', exist_ok=True)
    for i, frame in enumerate(img_list):
        cv2.imwrite('output/img{}.png'.format(i),
                    (frame[0] * 255).byte().cpu().numpy().transpose(
                        1, 2, 0)[:h, :w])
Exemplo n.º 2
0
model = Model()
model.load_model('train_log')
model.eval()
model.device()

# Sequence names under other-data/ and other-gt-interp/ (presumably the
# Middlebury "other" set — confirm).
name = [
    'Beanbags', 'Dimetrodon', 'DogDance', 'Grove2', 'Grove3', 'Hydrangea',
    'MiniCooper', 'RubberWhale', 'Urban2', 'Urban3', 'Venus', 'Walking'
]
# Each frame is placed top-left on a zero-filled canvas of this fixed size
# before inference; the prediction is cropped back to the frame size.
CANVAS_H, CANVAS_W = 480, 640
IE_list = []
for i in name:
    i0 = cv2.imread('other-data/{}/frame10.png'.format(i))
    i1 = cv2.imread('other-data/{}/frame11.png'.format(i))
    gt = cv2.imread('other-gt-interp/{}/frame10i11.png'.format(i))
    # cv2.imread returns None for missing/unreadable files.
    if i0 is None or i1 is None or gt is None:
        raise FileNotFoundError(
            'missing input frames or ground truth for {}'.format(i))
    # HxWxC -> CxHxW, scaled to [0, 1].
    i0 = i0.transpose(2, 0, 1) / 255.
    i1 = i1.transpose(2, 0, 1) / 255.
    h, w = i0.shape[1], i0.shape[2]
    if h > CANVAS_H or w > CANVAS_W:
        raise ValueError('frame {} is {}x{}, larger than the {}x{} canvas'
                         .format(i, h, w, CANVAS_H, CANVAS_W))
    # Channels 0-2 hold the first frame, 3-5 the second.
    imgs = torch.zeros([1, 6, CANVAS_H, CANVAS_W]).to(device)
    imgs[:, :3, :h, :w] = torch.from_numpy(i0).unsqueeze(0).float().to(device)
    imgs[:, 3:, :h, :w] = torch.from_numpy(i1).unsqueeze(0).float().to(device)
    I0 = imgs[:, :3]
    I2 = imgs[:, 3:]
    pred = model.inference(I0, I2)
    out = pred[0].detach().cpu().numpy().transpose(1, 2, 0)
    out = np.round(out[:h, :w] * 255)
    # Interpolation error: mean absolute difference in 0-255 units.
    IE_list.append(np.abs((out - gt * 1.0)).mean())
    # Running mean over the sequences processed so far.
    print(np.mean(IE_list))
Exemplo n.º 3
0
model = Model()
model.load_model('./train_log')
model.eval()
model.device()

path = 'vimeo_interp_test/'
psnr_list = []
ssim_list = []
# Each usable line of the list names a triplet directory under target/.
# im1/im3 are the model inputs and im2 is the reference middle frame.
# The with-block guarantees the list file is closed.
with open(path + 'tri_testlist.txt', 'r') as f:
    for line in f:
        name = line.strip()
        # Skip blank (or one-character) lines.
        if len(name) <= 1:
            continue
        print(path + 'target/' + name + '/im1.png')
        I0 = cv2.imread(path + 'target/' + name + '/im1.png')
        I1 = cv2.imread(path + 'target/' + name + '/im2.png')
        I2 = cv2.imread(path + 'target/' + name + '/im3.png')
        # HxWxC -> 1xCxHxW tensors in [0, 1].
        I0 = (torch.tensor(I0.transpose(2, 0, 1)).to(device) /
              255.).unsqueeze(0)
        I2 = (torch.tensor(I2.transpose(2, 0, 1)).to(device) /
              255.).unsqueeze(0)
        mid = model.inference(I0, I2)[0]
        ssim = ssim_matlab(
            torch.tensor(I1.transpose(2, 0, 1)).to(device).unsqueeze(0) / 255.,
            mid.unsqueeze(0)).cpu().numpy()
        # PSNR is measured on the prediction quantized to 8-bit levels.
        mid = np.round(
            (mid * 255).cpu().numpy()).astype('uint8').transpose(1, 2, 0) / 255.
        I1 = I1 / 255.
        psnr = -10 * math.log10(((I1 - mid) * (I1 - mid)).mean())
        psnr_list.append(psnr)
        ssim_list.append(ssim)
        # Running averages over the triplets processed so far.
        print(np.mean(psnr_list), np.mean(ssim_list))
Exemplo n.º 4
0
import cv2
import sys
sys.path.append('.')
import time
import torch
import torch.nn as nn
from model.RIFE import Model

# NOTE(review): unlike other snippets, the model is never explicitly moved to
# `device` here (no model.device() call); confirm Model() places its weights
# on the same device as the inputs below.
model = Model()
model.eval()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.set_grad_enabled(False)
if torch.cuda.is_available():
    torch.backends.cudnn.enabled = True
    # Let cuDNN autotune kernels for the fixed input size used below.
    torch.backends.cudnn.benchmark = True

# Number of warm-up passes and of timed passes.
N_ITERS = 100

I0 = torch.rand(1, 3, 480, 640).to(device)
I1 = torch.rand(1, 3, 480, 640).to(device)
with torch.no_grad():
    # Warm-up so one-time costs (e.g. cuDNN autotuning enabled above) are
    # excluded from the measurement.
    for _ in range(N_ITERS):
        model.inference(I0, I1)
    if torch.cuda.is_available():
        # Wait for queued GPU work before starting the clock.
        torch.cuda.synchronize()
    # perf_counter is monotonic and high-resolution, unlike time.time().
    time_stamp = time.perf_counter()
    for _ in range(N_ITERS):
        model.inference(I0, I1)
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    # Average seconds per inference call.
    print((time.perf_counter() - time_stamp) / N_ITERS)
      # NOTE(review): truncated fragment — the start of this expression and
      # the enclosing per-frame loop are outside this view. The visible tail
      # suggests `p` is a mean absolute difference involving a 16x16 bilinear
      # downsample of I1 (presumably compared against I0) — confirm.
      F.interpolate(I1, (16, 16), mode='bilinear',
                    align_corners=False)).abs().mean()
 # Near-identical frames (p < 5e-3): when args.skip is set, drop this frame
 # instead of interpolating; warn once every 100 skipped frames.
 if p < 5e-3 and args.skip:
     if skip_frame % 100 == 0:
         print(
             "Warning: Your video has {} static frames, skipping them may change the duration of the generated video."
             .format(skip_frame))
     skip_frame += 1
     pbar.update(1)
     continue
 # Large difference (p > 0.2), presumably a scene cut: fill the in-between
 # slots by repeating frames (lastframe, and `frame` for mid2) rather than
 # interpolating across the cut.
 if p > 0.2:
     mid1 = lastframe
     mid0 = lastframe
     mid2 = frame
 else:
     # Midpoint between I0 and I1; with args.exp == 4 a second, batched pass
     # predicts the quarter points between (I0, mid1) and (mid1, I1).
     mid1 = model.inference(I0, I1)
     if args.exp == 4:
         mid = model.inference(torch.cat((I0, mid1), 0),
                               torch.cat((mid1, I1), 0))
     # Take batch item 0, scale by 255, cast with .byte() and move channels
     # last (assumes model output is N x C x H x W in [0, 1] — confirm).
     mid1 = ((mid1[0] * 255.).byte().cpu().detach().numpy().transpose(
         1, 2, 0))
     if args.exp == 4:
         mid0 = ((mid[0] * 255.).byte().cpu().detach().numpy().transpose(
             1, 2, 0))
         mid2 = ((mid[1] * 255.).byte().cpu().detach().numpy().transpose(
             1, 2, 0))
 # Montage mode: each queued image places lastframe beside the frame being
 # emitted (concatenation along axis 1). The [:h, :w] slices presumably crop
 # away inference padding — confirm.
 if args.montage:
     buffer.put(np.concatenate((lastframe, lastframe), 1))
     if args.exp == 4:
         buffer.put(np.concatenate((lastframe, mid0[:h, :w]), 1))
     buffer.put(np.concatenate((lastframe, mid1[:h, :w]), 1))
Exemplo n.º 6
0
# NOTE(review): `model` is constructed outside this view — confirm it has
# loaded its weights before this point.
model.eval()
model.device()

path = 'UCF101/ucf101_interp_ours/'
dirs = os.listdir(path)
psnr_list = []
ssim_list = []
print(len(dirs))
# Score every triplet directory: frame_00 / frame_02 are the inputs and
# frame_01_gt is the reference middle frame. Running averages are printed
# after each sample.
for d in dirs:
    sample = path + d
    # Load first input, second input, then target. Each becomes a
    # 1xCxHxW float32 tensor in [0, 1] on `device`.
    loaded = []
    for suffix in ('/frame_00.png', '/frame_02.png', '/frame_01_gt.png'):
        scaled = cv2.imread(sample + suffix).transpose(2, 0, 1) / 255.
        loaded.append(torch.tensor(scaled).to(device).float().unsqueeze(0))
    first, last, target = loaded

    pred = model.inference(first, last)[0]
    # Both metrics score the prediction after quantizing it to 8-bit levels.
    pred_q = torch.round(pred * 255).unsqueeze(0) / 255.
    ssim = ssim_matlab(target, pred_q).detach().cpu().numpy()
    pred_np = np.round(
        pred.detach().cpu().numpy().transpose(1, 2, 0) * 255) / 255.
    target_np = target[0].cpu().numpy().transpose(1, 2, 0)
    err = target_np - pred_np
    psnr = -10 * math.log10((err * err).mean())

    psnr_list.append(psnr)
    ssim_list.append(ssim)
    print("Avg PSNR: {} SSIM: {}".format(np.mean(psnr_list),
                                         np.mean(ssim_list)))
Exemplo n.º 7
0
img1 = cv2.imread(args.img[1])

# NOTE(review): img0 is loaded before this point (outside this view),
# presumably via cv2.imread(args.img[0]) — confirm.
# Convert both HxWxC frames to 1xCxHxW tensors in [0, 1].
img0 = (torch.tensor(img0.transpose(2, 0, 1)).to(device) / 255.).unsqueeze(0)
img1 = (torch.tensor(img1.transpose(2, 0, 1)).to(device) / 255.).unsqueeze(0)
_, _, h, w = img0.shape
# Pad right and bottom up to multiples of 32; outputs are cropped back to
# (h, w) when written.
ph = ((h - 1) // 32 + 1) * 32
pw = ((w - 1) // 32 + 1) * 32
padding = (0, pw - w, 0, ph - h)
img0 = F.pad(img0, padding)
img1 = F.pad(img1, padding)

# Each pass inserts a predicted midpoint between every adjacent pair, growing
# the sequence from 2 to 2**args.times + 1 frames.
img_list = [img0, img1]
for _ in range(args.times):
    tmp = []
    for left, right in zip(img_list[:-1], img_list[1:]):
        tmp.append(left)
        tmp.append(model.inference(left, right))
    tmp.append(img_list[-1])
    img_list = tmp

os.makedirs('output', exist_ok=True)
for i, frame in enumerate(img_list):
    # NOTE(review): writes float data scaled to 0-255 and relies on
    # cv2.imwrite's own conversion to an 8-bit image — confirm.
    cv2.imwrite('output/img{}.png'.format(i),
                frame[0].detach().cpu().numpy().transpose(1, 2, 0)[:h, :w] * 255)
Exemplo n.º 8
0
path = 'datasets/test_2k_540p/'
dirs = os.listdir(path)
psnr_list = []
ssim_list = []
print(len(dirs))
# Replicate-pad 2 rows at the top and 2 at the bottom of each input (padding
# order is left, right, top, bottom); the same rows are cropped from the
# prediction below. Presumably this lifts 540-row frames to 544 = 17 * 32
# rows — confirm. Built once because it does not depend on the sample.
pader = torch.nn.ReplicationPad2d([0, 0, 2, 2])
for d in dirs:
    # frame1 / frame3 are the inputs, frame2 is the reference middle frame.
    img0 = (path + d + '/frame1.png')
    img1 = (path + d + '/frame3.png')
    gt = (path + d + '/frame2.png')
    # HxWxC -> 1xCxHxW float32 tensors in [0, 1].
    img0 = (torch.tensor(cv2.imread(img0).transpose(2, 0, 1) /
                         255.)).to(device).float().unsqueeze(0)
    img1 = (torch.tensor(cv2.imread(img1).transpose(2, 0, 1) /
                         255.)).to(device).float().unsqueeze(0)
    gt = (torch.tensor(cv2.imread(gt).transpose(2, 0, 1) /
                       255.)).to(device).float().unsqueeze(0)
    img0 = pader(img0)
    img1 = pader(img1)
    # Drop the 2 padded rows at each end of the height axis.
    pred = model.inference(img0, img1)[0][:, 2:-2]
    # Both metrics score the prediction after quantizing it to 8-bit levels.
    ssim = ssim_matlab(gt,
                       torch.round(pred * 255).unsqueeze(0) /
                       255.).detach().cpu().numpy()
    out = pred.detach().cpu().numpy().transpose(1, 2, 0)
    out = np.round(out * 255) / 255.
    gt = gt[0].cpu().numpy().transpose(1, 2, 0)
    psnr = -10 * math.log10(((gt - out) * (gt - out)).mean())
    psnr_list.append(psnr)
    ssim_list.append(ssim)
    print("Avg PSNR: {} SSIM: {}".format(np.mean(psnr_list),
                                         np.mean(ssim_list)))