Ejemplo n.º 1
0
def ToImage(frame0, frame1):
    """Run the two-stage model on a pair of frames, without gradient tracking.

    Both inputs are moved to the GPU. Each one is concatenated along dim 1 with
    ``tanh(bdcn(frame)[0])`` and the pair is passed to ``structure_gen``. The
    first output of ``structure_gen`` is then refined by ``detail_enhance``
    together with the raw GPU frames. ``bdcn``, ``structure_gen`` and
    ``detail_enhance`` are models defined elsewhere in the module.

    Returns:
        The ``detail_enhance`` output clamped to [-1, 1] (presumably the
        predicted intermediate frame).
    """
    with torch.no_grad():
        gpu_frames = (frame0.cuda(), frame1.cuda())
        edge_augmented = tuple(
            torch.cat([frame, torch.tanh(bdcn(frame)[0])], dim=1)
            for frame in gpu_frames
        )
        coarse, _ = structure_gen(edge_augmented)
        refined = detail_enhance((*gpu_frames, coarse))
        result = torch.clamp(refined, min=-1., max=1.)

    return result
Ejemplo n.º 2
0
    # NOTE(review): head of a training loop; the enclosing function and the rest
    # of the loop body are not part of this snippet.
    for trainIndex, trainData in tqdm(enumerate(trainloader, 0)):

        # Getting the input and the target from the training set
        # (frame0/frame1 are the model inputs, frameT is the target frame).
        start_time = time.time()
        frame0, frameT, frame1 = trainData
        if args.test:
            """
            just for 1 batch test
            """
            # Every iteration is overridden with the fixed `test_batch`
            # (defined elsewhere), so the same batch is reused on each step.
            frame0, frameT, frame1 = test_batch

        img0 = frame0.cuda()
        img1 = frame1.cuda()
        IFrame = frameT.cuda()

        # Append tanh(bdcn(x)[0]) to each frame along dim 1. `bdcn` is defined
        # elsewhere; presumably an edge detector and dim 1 the channel axis - confirm.
        img0_e = torch.cat([img0, torch.tanh(bdcn(img0)[0])], dim=1)
        img1_e = torch.cat([img1, torch.tanh(bdcn(img1)[0])], dim=1)
        IFrame_e = torch.cat([IFrame, torch.tanh(bdcn(IFrame)[0])], dim=1)

        if args.final:
            # Three chained stages; only detail_enhance_last's loss is kept,
            # the earlier stages' losses are discarded via `_`.
            _, _, ref_imgt = structure_gen((img0_e, img1_e, IFrame_e))
            _, _, _, ref_imgt2 = detail_enhance((img0, img1, IFrame, ref_imgt))
            loss, MSE_val, imgt = detail_enhance_last((img0, img1, IFrame, ref_imgt2))
        else:
            if args.GEN_DE:
                # Loss comes from structure_gen alone.
                # NOTE(review): `imgt` is not assigned on this path, and `MSE_val`
                # is only assigned when args.final is set; verify the unseen rest
                # of the loop does not read them.
                loss, _, ref_imgt = structure_gen((img0_e, img1_e, IFrame_e))
            else:
                # structure_gen output is refined by detail_enhance; the loss
                # comes from detail_enhance.
                _, _, ref_imgt = structure_gen((img0_e, img1_e, IFrame_e))
                loss, _, _, imgt = detail_enhance((img0, img1, IFrame, ref_imgt))
            # print(torch.max(torch.abs(ref_imgt - IFrame)))
Ejemplo n.º 3
0
def _ssim_over_batch(target, pred):
    """Return the SSIM between two image batches, scored as one multi-channel image.

    For a (B, C, H, W) batch both tensors are reshaped to (B*C, H, W) and
    transposed to (H, W, B*C). ``compare_ssim`` with ``multichannel=True``
    then treats them as a single image with B*C channels; skimage averages
    the per-channel SSIM. ``data_range=2`` presumes pixel values in [-1, 1]
    (confirm against the data pipeline).

    H and W are read from ``target`` rather than hard-coded as 256x448. The
    result is unchanged for 256x448 inputs, and other resolutions work instead
    of failing or silently scrambling pixels.
    """
    height, width = target.shape[-2:]

    def to_hwc(t):
        return t.reshape(1, -1, height, width).squeeze(0).cpu().numpy().transpose(1, 2, 0)

    return compare_ssim(to_hwc(target), to_hwc(pred), data_range=2, multichannel=True)


def validate():
    """Evaluate the model on ``validationloader`` and return aggregate metrics.

    The pipeline depends on the global ``args``:
      * ``args.final``  - structure_gen -> detail_enhance -> detail_enhance_last;
      * ``args.GEN_DE`` - structure_gen only;
      * otherwise       - structure_gen -> detail_enhance.

    Returns:
        tuple: ``(mean_psnr, mean_loss, ret_img, last_mse, mean_ssim)``.
        The means are taken over batches. ``ret_img`` stacks the first sample's
        panels from the first batch, for visualization. ``last_mse`` is the
        batch-mean MSE of the last batch only.

    Raises:
        ZeroDivisionError: if ``validationloader`` is empty.
    """
    psnr = 0
    ssim = 0
    tloss = 0
    retImg = None  # visualization grid, built from the first batch only

    def first_row(t):
        # First sample of a batch as a 1-row CPU tensor, de-normalized for display.
        # Tensor.cpu() returns the same tensor when it is already on the CPU.
        return revNormalize(t.cpu()[0]).unsqueeze(0)

    with torch.no_grad():
        for frame0, frameT, frame1 in validationloader:
            img0 = frame0.cuda()
            img1 = frame1.cuda()
            IFrame = frameT.cuda()

            img0_e = torch.cat([img0, torch.tanh(bdcn(img0)[0])], dim=1)
            img1_e = torch.cat([img1, torch.tanh(bdcn(img1)[0])], dim=1)
            IFrame_e = torch.cat([IFrame, torch.tanh(bdcn(IFrame)[0])], dim=1)

            if args.final:
                _, _, ref_imgt = structure_gen((img0_e, img1_e, IFrame_e))
                _, _, _, ref_imgt2 = detail_enhance((img0, img1, IFrame, ref_imgt))
                loss, MSE_val, imgt = detail_enhance_last((img0, img1, IFrame, ref_imgt2))
                pred = imgt
                panels = (frame0, frame1, imgt, frameT, ref_imgt2, ref_imgt)
            elif args.GEN_DE:
                loss, MSE_val, ref_imgt = structure_gen((img0_e, img1_e, IFrame_e))
                pred = ref_imgt
                panels = (frame0, frame1, ref_imgt, frameT)
            else:
                _, _, ref_imgt = structure_gen((img0_e, img1_e, IFrame_e))
                loss, MSE_val, _, imgt = detail_enhance((img0, img1, IFrame, ref_imgt))
                pred = imgt
                panels = (frame0, frame1, imgt, frameT, ref_imgt)

            SSIM = _ssim_over_batch(IFrame, pred)
            loss = torch.mean(loss)
            MSE_val = torch.mean(MSE_val)

            if retImg is None:
                retImg = torch.cat([first_row(t) for t in panels], dim=0)

            tloss += loss.item()
            # PSNR with a peak-to-peak range of 2 (matching data_range=2), so peak^2 == 4.
            # Computed from the batch-mean MSE, not averaged per image.
            psnr += 10 * log10(4 / MSE_val.item())
            ssim += SSIM

    num_batches = len(validationloader)
    return psnr / num_batches, tloss / num_batches, retImg, MSE_val, ssim / num_batches
Ejemplo n.º 4
0
def _to_bgr_uint8(chw):
    """Convert a (C, H, W) array with values in [-1, 1] to an (H, W, C) uint8 image.

    Values are mapped through ``(x + 1) / 2 * 255`` and truncated (not rounded)
    by the uint8 cast. The channel axis is reversed, presumably RGB -> BGR for
    ``cv2.imwrite``.
    """
    return np.uint8(((chw + 1.0) / 2.0).transpose(1, 2, 0)[:, :, ::-1] * 255)


def validate():
    """Run the test set, save each prediction as PNG, and return aggregate metrics.

    For every batch of ``testloader`` (defined elsewhere):
      * the frames go through ``structure_gen`` then ``detail_enhance``;
      * the prediction, clamped to [-1, 1], and the ground-truth frame are
        written to ``<args.imgpath>/<index>/imgt.png`` and ``IFrame.png``;
      * the batch PSNR is printed.
    On every 100th batch a visualization grid is sent to visdom.

    ``squeeze(0)`` followed by ``transpose(1, 2, 0)`` only works for a batch
    size of 1, so ``testloader`` is assumed to yield single samples.

    Returns:
        tuple: ``(mean_psnr, mean_loss, last_mse, mean_ie)``.
        The means are over batches. ``last_mse`` is the batch-mean MSE of the
        last batch only. ``mean_ie`` averages the ``IE`` value returned by
        ``detail_enhance``.

    Raises:
        ZeroDivisionError: if ``testloader`` is empty.
    """
    psnr = 0
    ie = 0
    tloss = 0

    with torch.no_grad():
        for testIndex, (frame0, frameT, frame1) in tqdm(enumerate(testloader, 0)):
            img0 = frame0.cuda()
            img1 = frame1.cuda()
            IFrame = frameT.cuda()

            img0_e = torch.cat([img0, torch.tanh(bdcn(img0)[0])], dim=1)
            img1_e = torch.cat([img1, torch.tanh(bdcn(img1)[0])], dim=1)
            IFrame_e = torch.cat([IFrame, torch.tanh(bdcn(IFrame)[0])], dim=1)
            _, _, ref_imgt = structure_gen((img0_e, img1_e, IFrame_e))
            loss, MSE_val, IE, imgt = detail_enhance((img0, img1, IFrame, ref_imgt))
            imgt = torch.clamp(imgt, max=1., min=-1.)

            # Drop the batch axis (batch size 1 assumed, see docstring).
            IFrame_np = IFrame.squeeze(0).cpu().numpy()
            imgt_np = imgt.squeeze(0).cpu().numpy()

            imgpath = os.path.join(args.imgpath, str(testIndex))
            # os.makedirs replaces `os.system('mkdir -p ...')`. No shell is
            # involved, so paths containing spaces or shell metacharacters work,
            # and a failure raises instead of being silently ignored.
            os.makedirs(imgpath, exist_ok=True)
            cv2.imwrite(os.path.join(imgpath, 'imgt.png'), _to_bgr_uint8(imgt_np))
            cv2.imwrite(os.path.join(imgpath, 'IFrame.png'), _to_bgr_uint8(IFrame_np))

            PSNR = compare_psnr(IFrame_np, imgt_np, data_range=2)
            print('PSNR:', PSNR)

            loss = torch.mean(loss)
            MSE_val = torch.mean(MSE_val)

            if testIndex % 100 == 99:
                vImg = torch.cat([
                    revNormalize(frame0[0]).unsqueeze(0),
                    revNormalize(frame1[0]).unsqueeze(0),
                    revNormalize(imgt.cpu()[0]).unsqueeze(0),
                    revNormalize(frameT[0]).unsqueeze(0),
                    revNormalize(ref_imgt.cpu()[0]).unsqueeze(0),
                ], dim=0)
                vImg = torch.clamp(vImg, max=1., min=0)
                vis.images(vImg,
                           win='vImage',
                           env=args.visdom_env,
                           nrow=2,
                           opts={'title': 'visual_image'})

            tloss += loss.item()
            psnr += PSNR
            ie += IE

    num_batches = len(testloader)
    return psnr / num_batches, tloss / num_batches, MSE_val, ie / num_batches