Exemplo n.º 1
0
def main():
    """Evaluate a trained CDFI-AdaCoF model on two benchmarks.

    Loads the checkpoint given by ``args.checkpoint`` and runs the model over
    the Middlebury-others and UCF101-DVF test sets, writing interpolated
    frames under ``args.out_dir``.
    """
    args = parse_args()
    torch.cuda.set_device(args.gpu_id)

    model = CDFI_adacof(args).cuda()
    print('===============================')
    print("# of model parameters is: " + str(count_network_parameters(model)))

    print('Loading the model...')
    checkpoint = torch.load(args.checkpoint)
    model.load_state_dict(checkpoint['state_dict'])
    # Fix: inference should run with eval-mode batch-norm/dropout; the
    # original never switched the model out of training mode here.
    model.eval()

    print('===============================')
    print('Test: Middlebury_others')
    test_dir = args.out_dir + '/middlebury_others'
    # exist_ok avoids the check-then-create race of the original
    # os.path.exists guard.
    os.makedirs(test_dir, exist_ok=True)
    test_db = Middlebury_other('./test_data/middlebury_others/input',
                               './test_data/middlebury_others/gt')
    test_db.test(model, test_dir)

    print('===============================')
    print('Test: UCF101-DVF')
    test_dir = args.out_dir + '/ucf101-dvf'
    os.makedirs(test_dir, exist_ok=True)
    test_db = ucf_dvf('./test_data/ucf101_interp_ours')
    test_db.test(model, test_dir)
Exemplo n.º 2
0
def main():
    """Train CDFI-AdaCoF on Vimeo90K, periodically testing on Middlebury."""
    args = parse_args()
    torch.cuda.set_device(args.gpu_id)

    # Training data; the Vimeo90K validation split is currently unused.
    train_set, _val_set = Vimeo90K_interp(args.data_dir)
    loader = DataLoader(dataset=train_set,
                        batch_size=args.batch_size,
                        shuffle=True,
                        num_workers=8)

    # Benchmark set used for periodic evaluation during training.
    benchmark = Middlebury_other(args.test_input, args.test_gt)

    # Model on the selected GPU.
    model = CDFI_adacof(args).cuda()
    print("# of model parameters is: " +
          str(utility.count_network_parameters(model)))

    # Loss and trainer wiring.
    criterion = Loss(args)
    trainer = Trainer(args, loader, benchmark, model, criterion)

    # Alternate train/test epochs until the trainer signals completion.
    while not trainer.terminate():
        trainer.train()
        trainer.test()

    trainer.close()
Exemplo n.º 3
0
def main():
    """Synthesize one in-between frame from two input frames.

    Loads a trained CDFI-AdaCoF checkpoint, interpolates between
    ``args.first_frame`` and ``args.second_frame``, and writes the result to
    ``args.output_frame``.
    """
    args = parse_args()
    torch.cuda.set_device(args.gpu_id)

    model = CDFI_adacof(args).cuda()

    print('Loading the model...')
    checkpoint = torch.load(args.checkpoint)
    model.load_state_dict(checkpoint['state_dict'])

    frame_name1 = args.first_frame
    frame_name2 = args.second_frame

    # Load both frames as float tensors in [0, 1] with a batch dimension.
    transform = transforms.Compose([transforms.ToTensor()])
    frame1 = transform(Image.open(frame_name1)).unsqueeze(0).cuda()
    frame2 = transform(Image.open(frame_name2)).unsqueeze(0).cuda()

    model.eval()
    with torch.no_grad():
        frame_out = model(frame1, frame2)

    imwrite(frame_out.clone(), args.output_frame, range=(0, 1))
    # Fix: the original toggled
    # torch.set_default_tensor_type(torch.cuda.HalfTensor) when args.fp16 was
    # set *after* the output was already written, where it could no longer
    # affect this function; that dead (and globally side-effecting) toggle
    # has been removed.

# try:
#     from model.RIFE_HDv2 import Model
#     model = Model()
#     model.load_model(args.modelDir, -1)
#     print("Loaded v2.x HD model.")
# except:
#     from model.RIFE_HD import Model
#     model = Model()
#     model.load_model(args.modelDir, -1)
#     print("Loaded v1.x HD model")
from models.cdfi_adacof import CDFI_adacof

# Build the interpolation model from a fixed local checkpoint.
# NOTE(review): `device`, `args`, and `cv2` are defined outside this excerpt —
# presumably in the argument-parsing / import section of the full script.
model = CDFI_adacof(kernel_size=11, dilation=2)
checkpoint = torch.load('checkpoints/CDFI_adacof.pth')
model.load_state_dict(checkpoint['state_dict'])
model.eval()
model.to(device)

# NOTE(review): `not args.video is None` is the non-idiomatic spelling of
# `args.video is not None` (same behavior).
if not args.video is None:
    videoCapture = cv2.VideoCapture(args.video)  # open the video
    fps = videoCapture.get(cv2.CAP_PROP_FPS)  # source frame rate
    tot_frame = videoCapture.get(cv2.CAP_PROP_FRAME_COUNT)  # total frame count
    videoCapture.release()  # release the capture handle
    # When no target fps is given, derive it by doubling `exp` times.
    if args.fps is None:
        fpsNotAssigned = True
        args.fps = fps * (2 ** args.exp)  # fps * (2**1), fps * (2**2), fps * (2**3)
    else:
        fpsNotAssigned = False
def main():
    """Benchmark CDFI-AdaCoF on UCF101 triplets.

    For each directory under the dataset path, interpolates the middle frame
    from frame_00/frame_02, accumulates PSNR/SSIM against frame_01_gt, then
    measures inference time on the last pair over 100 timed runs.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = CDFI_adacof(kernel_size=11, dilation=2)
    checkpoint = torch.load('../checkpoints/CDFI_adacof.pth')
    model.load_state_dict(checkpoint['state_dict'])
    model.eval()
    model.to(device)

    # Fix: the original referenced a module-level `transform` that is not
    # defined anywhere in this script; define it locally.
    transform = transforms.Compose([transforms.ToTensor()])

    path = '../../data/UCF101/ucf101_interp_ours/'
    dirs = os.listdir(path)

    psnr_list = []
    ssim_list = []
    time_list = []
    for d in tqdm(dirs):
        # Each triplet directory holds the two inputs and the ground truth.
        img0 = Image.open(path + d + '/frame_00.png')
        img1 = Image.open(path + d + '/frame_02.png')
        gt = Image.open(path + d + '/frame_01_gt.png')

        img0 = transform(img0).unsqueeze(0).cuda()
        img1 = transform(img1).unsqueeze(0).cuda()
        gt = transform(gt).unsqueeze(0).cuda()
        # Grayscale inputs are expanded to 3 channels for the model; the
        # ground truth is left as-is (it is only used for the metrics).
        if img1.size(1) == 1:
            img0 = img0.expand(-1, 3, -1, -1)
            img1 = img1.expand(-1, 3, -1, -1)

        # Fix: run inference without autograd — the original built a graph
        # for every item, growing GPU memory across the whole dataset.
        with torch.no_grad():
            pred = model(img0, img1)[0]
        pred = torch.clamp(pred, 0, 1)

        # Quantize the prediction to 8-bit levels before computing metrics.
        out = pred.detach().cpu().numpy().transpose(1, 2, 0)
        out = np.round(out * 255) / 255.
        gt = gt[0].cpu().numpy().transpose(1, 2, 0)
        print(gt.max(), out.max())
        psnr_list.append(compute_psnr(gt, out))
        ssim_list.append(compute_ssim(gt, out))

    # Inference time on the last (img0, img1) pair: 100 timed runs with the
    # single best and worst samples discarded.
    # NOTE(review): assumes `dirs` is non-empty — otherwise img0/img1 are
    # unbound here (true of the original as well).
    with torch.no_grad():
        for _ in range(100):
            start = torch.cuda.Event(enable_timing=True)
            end = torch.cuda.Event(enable_timing=True)
            start.record()
            pred = model(img0, img1)[0]
            end.record()
            torch.cuda.synchronize()
            time_list.append(start.elapsed_time(end))
    time_list.remove(min(time_list))
    time_list.remove(max(time_list))
    # NOTE(review): elapsed_time() already yields milliseconds per run and
    # np.mean already averages; the extra division by 100 is preserved from
    # the original but looks suspicious — verify the intended unit.
    print("Avg PSNR: {} SSIM: {} Time: {}".format(np.mean(psnr_list),
                                                  np.mean(ssim_list),
                                                  np.mean(time_list) / 100))
Exemplo n.º 6
0
def main():
    """Double the frame rate of an image-sequence 'video'.

    For every adjacent pair of frames in ``args.input_video``, writes the
    first frame and a synthesized in-between frame into
    ``args.output_video``, then copies the final frame through unchanged.
    Frame files are named by zero-padded index (``args.zpad`` digits) with
    extension ``args.img_format``, starting at ``args.index_from``.
    """
    args = parse_args()
    torch.cuda.set_device(args.gpu_id)

    transform = transforms.Compose([transforms.ToTensor()])

    model = CDFI_adacof(args).cuda()

    print('Loading the model...')
    checkpoint = torch.load(args.checkpoint)
    model.load_state_dict(checkpoint['state_dict'])
    # Fix: switch to eval mode once up front instead of on every loop
    # iteration as the original did.
    model.eval()

    base_dir = args.input_video

    # exist_ok avoids the check-then-create race of the original guard.
    os.makedirs(args.output_video, exist_ok=True)

    # Count only regular files directly inside the input directory.
    frame_len = len([
        name for name in os.listdir(base_dir)
        if os.path.isfile(os.path.join(base_dir, name))
    ])

    for idx in range(frame_len - 1):
        idx += args.index_from
        print(idx, '/', frame_len - 1, end='\r')

        frame_name1 = base_dir + '/' + str(idx).zfill(
            args.zpad) + args.img_format
        frame_name2 = base_dir + '/' + str(idx + 1).zfill(
            args.zpad) + args.img_format

        frame1 = transform(Image.open(frame_name1)).unsqueeze(0).cuda()
        frame2 = transform(Image.open(frame_name2)).unsqueeze(0).cuda()

        with torch.no_grad():
            frame_out = model(frame1, frame2)

        # Output index 2k carries the original frame, 2k+1 the interpolation
        # (both offset by args.index_from).
        imwrite(frame1.clone(),
                args.output_video + '/' +
                str((idx - args.index_from) * 2 + args.index_from).zfill(
                    args.zpad) + args.img_format,
                range=(0, 1))
        imwrite(frame_out.clone(),
                args.output_video + '/' +
                str((idx - args.index_from) * 2 + 1 + args.index_from).zfill(
                    args.zpad) + args.img_format,
                range=(0, 1))

    # The last input frame has no successor to pair with; copy it through.
    print(frame_len - 1, '/', frame_len - 1)
    frame_name_last = base_dir + '/' + str(frame_len + args.index_from -
                                           1).zfill(
                                               args.zpad) + args.img_format
    frame_last = transform(Image.open(frame_name_last)).unsqueeze(0)
    imwrite(frame_last.clone(),
            args.output_video + '/' +
            str((frame_len - 1) * 2 + args.index_from).zfill(args.zpad) +
            args.img_format,
            range=(0, 1))