def validate(I0, I1, index, length, device):
    """Predict the frame at position ``index`` between ``I0`` and ``I1``.

    Position 0 is the start frame itself, so ``I0`` is handed back untouched.
    For any other position both inputs go through ``utils.meanshift(..., True)``.
    Then:

    1. Bidirectional flow comes from ``flowComp``.
    2. The flow is blended with the ``superslomo`` flow coefficients.
    3. The result is refined by ``ArbTimeFlowIntrp``.
    4. The two warped inputs are merged using the visibility maps and warp
       coefficients.
    5. The result is passed through ``utils.meanshift(..., False)`` before
       being returned.

    ``length + 1`` is forwarded to the coefficient helpers as the sequence
    length, and ``index - 1`` as the intermediate-time index.
    """
    # Guard clause: nothing to interpolate for the first position.
    if index == 0:
        return I0

    shift_mean = [0.429, 0.431, 0.397]
    shift_std = [1, 1, 1]
    step = index - 1
    n_steps = length + 1

    with torch.no_grad():
        src = utils.meanshift(I0, shift_mean, shift_std, device, True)
        dst = utils.meanshift(I1, shift_mean, shift_std, device, True)

        # Channels 0-1 hold the 0->1 flow, the remaining channels the 1->0 flow.
        flow = flowComp(torch.cat((src, dst), dim=1))
        flow_01 = flow[:, :2, :, :]
        flow_10 = flow[:, 2:, :, :]

        fc = superslomo.getFlowCoeff(step, device, n_steps)
        flow_t0 = fc[0] * flow_01 + fc[1] * flow_10
        flow_t1 = fc[2] * flow_01 + fc[3] * flow_10

        warp_src = validationFlowBackWarp(src, flow_t0)
        warp_dst = validationFlowBackWarp(dst, flow_t1)

        refine = ArbTimeFlowIntrp(
            torch.cat((src, dst, flow_01, flow_10, flow_t1, flow_t0,
                       warp_dst, warp_src), dim=1))

        # Residual flow corrections plus a single visibility logit.
        flow_t0_ref = refine[:, :2, :, :] + flow_t0
        flow_t1_ref = refine[:, 2:4, :, :] + flow_t1
        vis_0 = torch.sigmoid(refine[:, 4:5, :, :])
        vis_1 = 1 - vis_0

        warp_src_ref = validationFlowBackWarp(src, flow_t0_ref)
        warp_dst_ref = validationFlowBackWarp(dst, flow_t1_ref)

        wc = superslomo.getWarpCoeff(step, device, n_steps)
        # Blend weights; products are formed in the same order as before.
        w0 = wc[0] * vis_0
        w1 = wc[1] * vis_1
        blended = (w0 * warp_src_ref + w1 * warp_dst_ref) / (w0 + w1)
        result = utils.meanshift(blended, shift_mean, shift_std, device, False)

    return result
Example #2
0
def test(save_images=False):
    """Run the networks over ``validationloader`` and log PSNR per frame position.

    For every batch:

    1. A blurred input is synthesized as the average of all frames in the
       sequence.
    2. Two border frames are estimated from it and assigned to ``I0``/``I1``
       in whichever order scores lower under ``compare_ftn`` against the
       ground-truth first/last frames.
    3. The intermediate positions are interpolated.

    The PSNR values from ``eval_metrics`` are stored per sample. They are
    written to ``<checkpoint_dir>/results_PSNR.csv`` after every batch, so
    partial results survive an interrupted run.

    Args:
        save_images: when True, also write each prediction and its ground
            truth as PNGs into ``<checkpoint_dir>/Saved_imgs``. Defaults to
            False, matching the previous behavior where saving was commented
            out.
    """
    df_column = ['Name'] + [str(i) for i in range(1, seq_len + 1)]
    # Rows are collected in a list: ``DataFrame.append`` was removed in
    # pandas 2.0, and per-row appends are quadratic anyway.
    results = []

    tqdm_loader = tqdm.tqdm(validationloader, ncols=80)

    imgsave_folder = os.path.join(args.checkpoint_dir, 'Saved_imgs')
    os.makedirs(imgsave_folder, exist_ok=True)

    with torch.no_grad():
        for validationData, _, validationFile in tqdm_loader:

            # Synthetic blurred input: mean of every frame in the sequence.
            blurred_img = torch.zeros_like(validationData[0])
            for image in validationData:
                blurred_img += image
            blurred_img /= len(validationData)
            blurred_img = blurred_img.to(device)
            batch_size = blurred_img.shape[0]

            # Border estimation runs on meanshift(..., False) output; the
            # estimates are then shifted with meanshift(..., True) for the
            # networks, like blurred_img itself.
            blurred_img = meanshift(blurred_img, mean, std, device, False)
            c = center_estimation(blurred_img)
            start, end = border_estimation(blurred_img, c)
            start = meanshift(start, mean, std, device, True)
            end = meanshift(end, mean, std, device, True)
            blurred_img = meanshift(blurred_img, mean, std, device, True)

            frame0 = validationData[0].to(device)
            frame1 = validationData[-1].to(device)

            # The two estimates carry no temporal order; per sample, keep the
            # assignment whose compare_ftn score is lower.
            parallel = torch.mean(compare_ftn(start, frame0) +
                                  compare_ftn(end, frame1),
                                  dim=(1, 2, 3))
            cross = torch.mean(compare_ftn(start, frame1) +
                               compare_ftn(end, frame0),
                               dim=(1, 2, 3))

            I0 = torch.zeros_like(blurred_img)
            I1 = torch.zeros_like(blurred_img)
            for b in range(batch_size):
                if parallel[b] <= cross[b]:
                    I0[b], I1[b] = start[b], end[b]
                else:
                    I0[b], I1[b] = end[b], start[b]

            # The flow between I0 and I1 does not depend on the target
            # position, so run flowComp once per batch instead of once per
            # intermediate frame. Intermediate frames exist only when
            # seq_len > 2.
            if seq_len > 2:
                flowOut = flowComp(torch.cat((I0, I1), dim=1))
                F_0_1 = flowOut[:, :2, :, :]
                F_1_0 = flowOut[:, 2:, :, :]

            psnrs = np.zeros((batch_size, seq_len))

            for vindex in range(seq_len):
                IFrame = validationData[vindex].to(device)

                if vindex == 0:
                    Ft_p = I0.clone()
                elif vindex == seq_len - 1:
                    Ft_p = I1.clone()
                else:
                    # Zero-based intermediate-time index, one per sample.
                    t_index = torch.full((batch_size,), vindex - 1,
                                         dtype=torch.long)

                    fCoeff = superslomo.getFlowCoeff(t_index, device, seq_len)
                    F_t_0 = fCoeff[0] * F_0_1 + fCoeff[1] * F_1_0
                    F_t_1 = fCoeff[2] * F_0_1 + fCoeff[3] * F_1_0

                    g_I0_F_t_0 = validationFlowBackWarp(I0, F_t_0)
                    g_I1_F_t_1 = validationFlowBackWarp(I1, F_t_1)

                    intrp_inputs = [I0, I1, F_0_1, F_1_0, F_t_1, F_t_0,
                                    g_I1_F_t_1, g_I0_F_t_0]
                    if args.add_blur:
                        intrp_inputs.append(blurred_img)
                    intrpOut = ArbTimeFlowIntrp(torch.cat(intrp_inputs, dim=1))

                    F_t_0_f = intrpOut[:, :2, :, :] + F_t_0
                    F_t_1_f = intrpOut[:, 2:4, :, :] + F_t_1
                    V_t_0 = torch.sigmoid(intrpOut[:, 4:5, :, :])
                    V_t_1 = 1 - V_t_0

                    g_I0_F_t_0_f = validationFlowBackWarp(I0, F_t_0_f)
                    g_I1_F_t_1_f = validationFlowBackWarp(I1, F_t_1_f)

                    wCoeff = superslomo.getWarpCoeff(t_index, device, seq_len)

                    Ft_p = (wCoeff[0] * V_t_0 * g_I0_F_t_0_f +
                            wCoeff[1] * V_t_1 * g_I1_F_t_1_f) / (
                                wCoeff[0] * V_t_0 + wCoeff[1] * V_t_1)

                Ft_p = meanshift(Ft_p, mean, std, device, False)
                IFrame = meanshift(IFrame, mean, std, device, False)

                if save_images:
                    # NOTE(review): relies on a module-level ``torchvision``
                    # import. ``value_range`` replaced the older ``range``
                    # keyword of torchvision's make_grid/save_image.
                    for b in range(batch_size):
                        foldername = os.path.basename(
                            os.path.dirname(validationFile[ctr_idx][b]))
                        filename = os.path.splitext(
                            os.path.basename(validationFile[vindex][b]))[0]

                        out_fname = foldername + '_' + filename + '_out.png'
                        gt_fname = foldername + '_' + filename + '.png'
                        out, gt = quantize(Ft_p[b]), quantize(IFrame[b])

                        torchvision.utils.save_image(
                            out, os.path.join(imgsave_folder, out_fname),
                            normalize=True, value_range=(0, 255))
                        torchvision.utils.save_image(
                            gt, os.path.join(imgsave_folder, gt_fname),
                            normalize=True, value_range=(0, 255))

                psnr, _ = eval_metrics(Ft_p, IFrame)
                psnrs[:, vindex] = psnr.cpu().numpy()

            for b in range(batch_size):
                results.append([validationFile[ctr_idx][b], *psnrs[b]])

            # Rewritten after each batch so progress is kept on disk.
            pd.DataFrame(results, columns=df_column).to_csv(
                '{}/results_PSNR.csv'.format(args.checkpoint_dir))
Example #3
0
def _flow_smoothness(flow):
    """Return mean |horizontal difference| + mean |vertical difference| of ``flow``.

    The slicing assumes a 4-D tensor whose last two axes are spatial
    (presumably N x C x H x W -- TODO confirm against flowComp's output).
    """
    return (torch.mean(torch.abs(flow[:, :, :, :-1] - flow[:, :, :, 1:])) +
            torch.mean(torch.abs(flow[:, :, :-1, :] - flow[:, :, 1:, :])))


def validate():
    """Evaluate the interpolation networks on ``validationloader``.

    Mirrors the training forward pass and loss (see training for details)
    without tracking gradients. When ``args.amp`` is set, the forward pass
    and loss run under CUDA autocast.

    Returns:
        A tuple ``(mean_psnr, mean_loss, retImg)``:

        - The two means are averaged over batches.
        - ``retImg`` is a ``torchvision.utils.make_grid`` image of
          (frame0, target, prediction, frame1) for the first sample of the
          first batch.
    """
    psnr = 0
    tloss = 0
    need_preview = True
    with torch.no_grad():
        for validationData, validationFrameIndex, _ in validationloader:
            # Synthetic blurred input: mean of every frame in the sequence.
            blurred_img = torch.zeros_like(validationData[0])
            for image in validationData:
                blurred_img += image
            blurred_img /= len(validationData)
            blurred_img = blurred_img.to(device)

            blurred_img = meanshift(blurred_img, mean, std, device, False)
            c = center_estimation(blurred_img)
            start, end = border_estimation(blurred_img, c)
            start = meanshift(start, mean, std, device, True)
            end = meanshift(end, mean, std, device, True)
            blurred_img = meanshift(blurred_img, mean, std, device, True)

            frame0 = validationData[0].to(device)
            frame1 = validationData[-1].to(device)

            batch_size = blurred_img.shape[0]
            # Per sample, keep the start/end ordering with the lower
            # compare_ftn score against the ground-truth border frames.
            parallel = torch.mean(compare_ftn(start, frame0) +
                                  compare_ftn(end, frame1),
                                  dim=(1, 2, 3))
            cross = torch.mean(compare_ftn(start, frame1) +
                               compare_ftn(end, frame0),
                               dim=(1, 2, 3))

            I0 = torch.zeros_like(blurred_img)
            I1 = torch.zeros_like(blurred_img)
            IFrame = torch.zeros_like(blurred_img)
            for b in range(batch_size):
                if parallel[b] <= cross[b]:
                    I0[b], I1[b] = start[b], end[b]
                else:
                    I0[b], I1[b] = end[b], start[b]

                # Target frame: the sample's intermediate index, offset by one
                # past the first frame.
                IFrame[b] = validationData[validationFrameIndex[b] + 1][b]

            # A single code path for both precisions. With enabled=False,
            # autocast leaves ops in default precision, which matches the old
            # non-AMP branch. (Assumes validate() is not itself called inside
            # an enabled autocast region.)
            with torch.cuda.amp.autocast(enabled=bool(args.amp)):
                flowOut = flowComp(torch.cat((I0, I1), dim=1))
                F_0_1 = flowOut[:, :2, :, :]
                F_1_0 = flowOut[:, 2:, :, :]

                fCoeff = superslomo.getFlowCoeff(validationFrameIndex, device,
                                                 seq_len)

                F_t_0 = fCoeff[0] * F_0_1 + fCoeff[1] * F_1_0
                F_t_1 = fCoeff[2] * F_0_1 + fCoeff[3] * F_1_0

                g_I0_F_t_0 = validationFlowBackWarp(I0, F_t_0)
                g_I1_F_t_1 = validationFlowBackWarp(I1, F_t_1)

                intrp_inputs = [I0, I1, F_0_1, F_1_0, F_t_1, F_t_0,
                                g_I1_F_t_1, g_I0_F_t_0]
                if args.add_blur:
                    intrp_inputs.append(blurred_img)
                intrpOut = ArbTimeFlowIntrp(torch.cat(intrp_inputs, dim=1))

                F_t_0_f = intrpOut[:, :2, :, :] + F_t_0
                F_t_1_f = intrpOut[:, 2:4, :, :] + F_t_1
                V_t_0 = torch.sigmoid(intrpOut[:, 4:5, :, :])
                V_t_1 = 1 - V_t_0

                g_I0_F_t_0_f = validationFlowBackWarp(I0, F_t_0_f)
                g_I1_F_t_1_f = validationFlowBackWarp(I1, F_t_1_f)

                wCoeff = superslomo.getWarpCoeff(validationFrameIndex, device,
                                                 seq_len)

                Ft_p = (wCoeff[0] * V_t_0 * g_I0_F_t_0_f +
                        wCoeff[1] * V_t_1 * g_I1_F_t_1_f) / (
                            wCoeff[0] * V_t_0 + wCoeff[1] * V_t_1)

                # Loss terms, weighted as in training.
                recnLoss = L1_lossFn(Ft_p, IFrame)
                prcpLoss = MSE_LossFn(vgg16_conv_4_3(Ft_p),
                                      vgg16_conv_4_3(IFrame))
                warpLoss = (L1_lossFn(g_I0_F_t_0, IFrame) +
                            L1_lossFn(g_I1_F_t_1, IFrame) +
                            L1_lossFn(validationFlowBackWarp(I0, F_1_0), I1) +
                            L1_lossFn(validationFlowBackWarp(I1, F_0_1), I0))
                loss_smooth = _flow_smoothness(F_1_0) + _flow_smoothness(F_0_1)

                loss = 204 * recnLoss + 102 * warpLoss + 0.005 * prcpLoss + loss_smooth

            # Preview grid for tensorboard, built from the first batch only.
            # Index before .cpu() so only one sample is copied to the host.
            if need_preview:
                retImg = torchvision.utils.make_grid([
                    revNormalize(frame0[0].cpu()),
                    revNormalize(IFrame[0].cpu()),
                    revNormalize(Ft_p[0].cpu()),
                    revNormalize(frame1[0].cpu()),
                ], padding=10)
                need_preview = False

            tloss += loss.item()
            # PSNR with a peak value of 1.
            MSE_val = MSE_LossFn(Ft_p, IFrame)
            psnr += (10 * log10(1 / MSE_val.item()))

    return (psnr / len(validationloader)), (tloss /
                                            len(validationloader)), retImg
Example #4
0
    iLoss = 0

    tqdm_trainloader = tqdm.tqdm(trainloader, ncols=80)

    for trainIndex, (trainData, trainFrameIndex,
                     _) in enumerate(tqdm_trainloader):
        ## Getting the input and the target from the training set
        # frame0, frameT, frame1 = trainData
        blurred_img = torch.zeros_like(trainData[0])
        for image in trainData:
            blurred_img += image
        blurred_img /= len(trainData)
        blurred_img = blurred_img.to(device)

        with torch.no_grad():
            blurred_img = meanshift(blurred_img, mean, std, device, False)
            c = center_estimation(blurred_img)
            start, end = border_estimation(blurred_img, c)
            start = meanshift(start, mean, std, device, True)
            end = meanshift(end, mean, std, device, True)
            blurred_img = meanshift(blurred_img, mean, std, device, True)

        frame0 = trainData[0].to(device)
        frame1 = trainData[-1].to(device)

        batch_size = blurred_img.shape[0]
        parallel = torch.mean(compare_ftn(start, frame0) +
                              compare_ftn(end, frame1),
                              dim=(1, 2, 3))
        cross = torch.mean(compare_ftn(start, frame1) +
                           compare_ftn(end, frame0),