示例#1
0
def test(save_imgs=False):
    """Evaluate deblurring + interpolation on ``validationloader`` and log PSNR.

    For each batch, all frames are averaged into a synthetic blurred image,
    two border frames are estimated from it and, per sample, assigned to
    (I0, I1) in whichever order better matches the ground-truth first/last
    frames. Every frame of the sequence is then reconstructed: the two
    endpoints directly from I0/I1, interior frames by flow-based
    interpolation. Per-frame PSNR of every sample is (re)written to
    ``<args.checkpoint_dir>/results_PSNR.csv`` after each batch, one row per
    sample, with the sample's ``validationFile[ctr_idx][b]`` path in the
    ``Name`` column.

    Args:
        save_imgs: if True, also save each reconstructed frame and its ground
            truth as PNG files in ``<args.checkpoint_dir>/Saved_imgs``.
            Defaults to False (no images are written).

    Note: depends on module-level globals not defined here, including
    ``seq_len``, ``validationloader``, ``args``, ``device``, ``mean``,
    ``std``, ``ctr_idx``, ``superslomo``, the networks ``flowComp`` /
    ``ArbTimeFlowIntrp``, ``validationFlowBackWarp`` and helpers such as
    ``meanshift``, ``center_estimation``, ``border_estimation``,
    ``compare_ftn``, ``eval_metrics`` and ``quantize``.
    """
    df_column = ['Name'] + [str(i) for i in range(1, seq_len + 1)]
    # Rows are collected in a plain list: DataFrame.append was removed in
    # pandas 2.0 (and was quadratic anyway).
    result_rows = []
    csv_path = os.path.join(args.checkpoint_dir, 'results_PSNR.csv')

    imgsave_folder = os.path.join(args.checkpoint_dir, 'Saved_imgs')
    os.makedirs(imgsave_folder, exist_ok=True)

    tqdm_loader = tqdm.tqdm(validationloader, ncols=80)

    with torch.no_grad():
        for validationData, _, validationFile in tqdm_loader:
            # Synthesize the blurred input as the mean of all frames.
            blurred_img = torch.zeros_like(validationData[0])
            for image in validationData:
                blurred_img += image
            blurred_img /= len(validationData)
            blurred_img = blurred_img.to(device)
            batch_size = blurred_img.shape[0]

            # Border estimation runs between meanshift(..., False) and
            # meanshift(..., True); presumably a de-normalize / re-normalize
            # round trip — TODO confirm meanshift's flag semantics.
            blurred_img = meanshift(blurred_img, mean, std, device, False)
            c = center_estimation(blurred_img)
            start, end = border_estimation(blurred_img, c)
            start = meanshift(start, mean, std, device, True)
            end = meanshift(end, mean, std, device, True)
            blurred_img = meanshift(blurred_img, mean, std, device, True)

            frame0 = validationData[0].to(device)
            frame1 = validationData[-1].to(device)

            # Score both assignments of (start, end) to the ground-truth
            # first/last frames and keep the closer one, per sample.
            parallel = torch.mean(compare_ftn(start, frame0) +
                                  compare_ftn(end, frame1),
                                  dim=(1, 2, 3))
            cross = torch.mean(compare_ftn(start, frame1) +
                               compare_ftn(end, frame0),
                               dim=(1, 2, 3))

            I0 = torch.zeros_like(blurred_img)
            I1 = torch.zeros_like(blurred_img)
            for b in range(batch_size):
                if parallel[b] <= cross[b]:
                    I0[b], I1[b] = start[b], end[b]
                else:
                    I0[b], I1[b] = end[b], start[b]

            # The I0 <-> I1 flow does not depend on the target time, so it is
            # computed once per batch instead of once per interior frame.
            flowOut = flowComp(torch.cat((I0, I1), dim=1))
            F_0_1 = flowOut[:, :2, :, :]
            F_1_0 = flowOut[:, 2:, :, :]

            psnrs = np.zeros((batch_size, seq_len))

            for vindex in range(seq_len):
                IFrame = validationData[vindex].to(device)

                if vindex == 0:
                    Ft_p = I0.clone()
                elif vindex == seq_len - 1:
                    Ft_p = I1.clone()
                else:
                    # One time index per sample; the coefficient helpers are
                    # given vindex - 1 (presumably they number interior frames
                    # from 0 — TODO confirm).
                    frame_index = torch.full((batch_size,), vindex - 1,
                                             dtype=torch.long)

                    fCoeff = superslomo.getFlowCoeff(frame_index, device,
                                                     seq_len)
                    F_t_0 = fCoeff[0] * F_0_1 + fCoeff[1] * F_1_0
                    F_t_1 = fCoeff[2] * F_0_1 + fCoeff[3] * F_1_0

                    g_I0_F_t_0 = validationFlowBackWarp(I0, F_t_0)
                    g_I1_F_t_1 = validationFlowBackWarp(I1, F_t_1)

                    intrp_inputs = [I0, I1, F_0_1, F_1_0, F_t_1, F_t_0,
                                    g_I1_F_t_1, g_I0_F_t_0]
                    if args.add_blur:
                        intrp_inputs.append(blurred_img)
                    intrpOut = ArbTimeFlowIntrp(torch.cat(intrp_inputs, dim=1))

                    F_t_0_f = intrpOut[:, :2, :, :] + F_t_0
                    F_t_1_f = intrpOut[:, 2:4, :, :] + F_t_1
                    V_t_0 = torch.sigmoid(intrpOut[:, 4:5, :, :])
                    V_t_1 = 1 - V_t_0

                    g_I0_F_t_0_f = validationFlowBackWarp(I0, F_t_0_f)
                    g_I1_F_t_1_f = validationFlowBackWarp(I1, F_t_1_f)

                    wCoeff = superslomo.getWarpCoeff(frame_index, device,
                                                     seq_len)

                    Ft_p = (wCoeff[0] * V_t_0 * g_I0_F_t_0_f +
                            wCoeff[1] * V_t_1 * g_I1_F_t_1_f) / (
                                wCoeff[0] * V_t_0 + wCoeff[1] * V_t_1)

                Ft_p = meanshift(Ft_p, mean, std, device, False)
                IFrame = meanshift(IFrame, mean, std, device, False)

                if save_imgs:
                    for b in range(batch_size):
                        # NOTE(review): ctr_idx is not defined in this
                        # function; presumably a module-level frame index
                        # whose path names the sequence — confirm.
                        foldername = os.path.basename(
                            os.path.dirname(validationFile[ctr_idx][b]))
                        filename = os.path.splitext(
                            os.path.basename(validationFile[vindex][b]))[0]
                        prefix = foldername + '_' + filename
                        out, gt = quantize(Ft_p[b]), quantize(IFrame[b])
                        # quantize() output is treated as [0, 255] (assumption
                        # — TODO confirm); save_image expects [0, 1] values.
                        torchvision.utils.save_image(
                            out / 255,
                            os.path.join(imgsave_folder, prefix + '_out.png'))
                        torchvision.utils.save_image(
                            gt / 255,
                            os.path.join(imgsave_folder, prefix + '.png'))

                psnr, _ = eval_metrics(Ft_p, IFrame)
                psnrs[:, vindex] = psnr.cpu().numpy()

            result_rows.extend([validationFile[ctr_idx][b], *psnrs[b]]
                               for b in range(batch_size))

            # Rewritten after every batch so partial results survive an
            # interrupted run.
            pd.DataFrame(result_rows, columns=df_column).to_csv(csv_path)
示例#2
0
def validate():
    """Run one validation pass over ``validationloader``.

    Each batch's frames are averaged into a synthetic blurred image, from
    which two border frames are estimated; per sample, the (start, end) pair
    is assigned to (I0, I1) in whichever order better matches the
    ground-truth first/last frames. The frame at index
    ``validationFrameIndex[b] + 1`` is then interpolated and scored with the
    training loss (reconstruction + warp + perceptual + flow smoothness) and
    with PSNR.

    When ``args.amp`` is set, the forward pass and loss run under CUDA
    autocast; otherwise the same code runs in full precision.

    Returns:
        tuple: (mean PSNR per batch, mean loss per batch, image grid of the
        first sample of the first batch: first frame, target frame,
        prediction, last frame).

    Note: depends on module-level globals not defined here (data loader,
    ``args``, ``device``, ``seq_len``, ``mean``, ``std``, the networks and
    loss functions, ``superslomo``, ``meanshift``, ``center_estimation``,
    ``border_estimation``, ``compare_ftn``, ``revNormalize``, ``log10``).
    """
    psnr = 0
    tloss = 0
    retImg = None
    with torch.no_grad():
        for validationData, validationFrameIndex, _ in validationloader:
            # Synthesize the blurred input as the mean of all frames.
            blurred_img = torch.zeros_like(validationData[0])
            for image in validationData:
                blurred_img += image
            blurred_img /= len(validationData)
            blurred_img = blurred_img.to(device)

            # Border estimation runs between meanshift(..., False) and
            # meanshift(..., True); presumably a de-normalize / re-normalize
            # round trip — TODO confirm meanshift's flag semantics.
            blurred_img = meanshift(blurred_img, mean, std, device, False)
            c = center_estimation(blurred_img)
            start, end = border_estimation(blurred_img, c)
            start = meanshift(start, mean, std, device, True)
            end = meanshift(end, mean, std, device, True)
            blurred_img = meanshift(blurred_img, mean, std, device, True)

            frame0 = validationData[0].to(device)
            frame1 = validationData[-1].to(device)

            # Score both assignments of (start, end) to the ground-truth
            # first/last frames and keep the closer one, per sample.
            batch_size = blurred_img.shape[0]
            parallel = torch.mean(compare_ftn(start, frame0) +
                                  compare_ftn(end, frame1),
                                  dim=(1, 2, 3))
            cross = torch.mean(compare_ftn(start, frame1) +
                               compare_ftn(end, frame0),
                               dim=(1, 2, 3))

            I0 = torch.zeros_like(blurred_img)
            I1 = torch.zeros_like(blurred_img)
            IFrame = torch.zeros_like(blurred_img)
            for b in range(batch_size):
                if parallel[b] <= cross[b]:
                    I0[b], I1[b] = start[b], end[b]
                else:
                    I0[b], I1[b] = end[b], start[b]

                # Ground-truth target frame for sample b.
                IFrame[b] = validationData[validationFrameIndex[b] + 1][b]

            # autocast(enabled=False) is a no-op, so a single code path serves
            # both the mixed-precision and the full-precision configuration.
            with torch.cuda.amp.autocast(enabled=bool(args.amp)):
                flowOut = flowComp(torch.cat((I0, I1), dim=1))
                F_0_1 = flowOut[:, :2, :, :]
                F_1_0 = flowOut[:, 2:, :, :]

                fCoeff = superslomo.getFlowCoeff(validationFrameIndex, device,
                                                 seq_len)

                F_t_0 = fCoeff[0] * F_0_1 + fCoeff[1] * F_1_0
                F_t_1 = fCoeff[2] * F_0_1 + fCoeff[3] * F_1_0

                g_I0_F_t_0 = validationFlowBackWarp(I0, F_t_0)
                g_I1_F_t_1 = validationFlowBackWarp(I1, F_t_1)

                intrp_inputs = [I0, I1, F_0_1, F_1_0, F_t_1, F_t_0,
                                g_I1_F_t_1, g_I0_F_t_0]
                if args.add_blur:
                    intrp_inputs.append(blurred_img)
                intrpOut = ArbTimeFlowIntrp(torch.cat(intrp_inputs, dim=1))

                F_t_0_f = intrpOut[:, :2, :, :] + F_t_0
                F_t_1_f = intrpOut[:, 2:4, :, :] + F_t_1
                V_t_0 = torch.sigmoid(intrpOut[:, 4:5, :, :])
                V_t_1 = 1 - V_t_0

                g_I0_F_t_0_f = validationFlowBackWarp(I0, F_t_0_f)
                g_I1_F_t_1_f = validationFlowBackWarp(I1, F_t_1_f)

                wCoeff = superslomo.getWarpCoeff(validationFrameIndex, device,
                                                 seq_len)

                Ft_p = (wCoeff[0] * V_t_0 * g_I0_F_t_0_f + wCoeff[1] * V_t_1 *
                        g_I1_F_t_1_f) / (wCoeff[0] * V_t_0 + wCoeff[1] * V_t_1)

                # Loss (same weighting as training).
                recnLoss = L1_lossFn(Ft_p, IFrame)
                prcpLoss = MSE_LossFn(vgg16_conv_4_3(Ft_p),
                                      vgg16_conv_4_3(IFrame))
                warpLoss = (L1_lossFn(g_I0_F_t_0, IFrame) +
                            L1_lossFn(g_I1_F_t_1, IFrame) +
                            L1_lossFn(validationFlowBackWarp(I0, F_1_0), I1) +
                            L1_lossFn(validationFlowBackWarp(I1, F_0_1), I0))

                # Mean absolute difference between neighbouring flow values,
                # horizontally and vertically.
                loss_smooth_1_0 = torch.mean(
                    torch.abs(F_1_0[:, :, :, :-1] - F_1_0[:, :, :, 1:])
                ) + torch.mean(torch.abs(F_1_0[:, :, :-1, :] - F_1_0[:, :, 1:, :]))
                loss_smooth_0_1 = torch.mean(
                    torch.abs(F_0_1[:, :, :, :-1] - F_0_1[:, :, :, 1:])
                ) + torch.mean(torch.abs(F_0_1[:, :, :-1, :] - F_0_1[:, :, 1:, :]))
                loss_smooth = loss_smooth_1_0 + loss_smooth_0_1

                loss = 204 * recnLoss + 102 * warpLoss + 0.005 * prcpLoss + loss_smooth

            # Preview grid for tensorboard, built from the first batch only.
            if retImg is None:
                retImg = torchvision.utils.make_grid([
                    revNormalize(frame0.cpu()[0]),
                    revNormalize(IFrame.cpu()[0]),
                    revNormalize(Ft_p.cpu()[0]),
                    revNormalize(frame1.cpu()[0])
                ], padding=10)

            tloss += loss.item()

            # PSNR with a peak value of 1 (raises ZeroDivisionError if the
            # prediction matches the target exactly).
            MSE_val = MSE_LossFn(Ft_p, IFrame)
            psnr += (10 * log10(1 / MSE_val.item()))

    return (psnr / len(validationloader)), (tloss /
                                            len(validationloader)), retImg
示例#3
0
def validate():
    """Evaluate the interpolation networks on ``validationloader``.

    Each batch yields a (first, middle, last) frame triple and the per-sample
    time index of the middle frame. The middle frame is synthesized from the
    outer two, then scored with the training objective and with PSNR; both
    figures are averaged over the number of batches.

    Returns:
        (mean PSNR, mean loss, preview grid), where the preview shows the
        first sample of the first batch: first frame, ground-truth middle
        frame, prediction, last frame.

    Note: relies on globals defined elsewhere (data loader, networks, loss
    functions, ``validationFlowBackWarp``, ``superslomo``, ``device``,
    ``revNormalize``, ``log10``).
    """

    def smoothness(flow):
        # Mean absolute difference between horizontal neighbours plus the
        # same quantity for vertical neighbours.
        horizontal = torch.mean(
            torch.abs(flow[:, :, :, :-1] - flow[:, :, :, 1:]))
        vertical = torch.mean(
            torch.abs(flow[:, :, :-1, :] - flow[:, :, 1:, :]))
        return horizontal + vertical

    psnr_sum = 0
    loss_sum = 0
    preview = None

    with torch.no_grad():
        for frames, time_index in validationloader:
            first_cpu, middle_cpu, last_cpu = frames

            img_a = first_cpu.to(device)
            img_b = last_cpu.to(device)
            target = middle_cpu.to(device)

            # Bidirectional flow between the two outer frames.
            flows = flowComp(torch.cat((img_a, img_b), dim=1))
            flow_ab = flows[:, :2, :, :]
            flow_ba = flows[:, 2:, :, :]

            # Approximate flows from the target time to each outer frame.
            fc = superslomo.getFlowCoeff(time_index, device)
            flow_ta = fc[0] * flow_ab + fc[1] * flow_ba
            flow_tb = fc[2] * flow_ab + fc[3] * flow_ba

            warped_a = validationFlowBackWarp(img_a, flow_ta)
            warped_b = validationFlowBackWarp(img_b, flow_tb)

            refine = ArbTimeFlowIntrp(
                torch.cat((img_a, img_b, flow_ab, flow_ba, flow_tb, flow_ta,
                           warped_b, warped_a),
                          dim=1))

            # Residual flow corrections plus a visibility map (and its
            # complement) predicted by the refinement network.
            flow_ta_refined = refine[:, :2, :, :] + flow_ta
            flow_tb_refined = refine[:, 2:4, :, :] + flow_tb
            vis_a = torch.sigmoid(refine[:, 4:5, :, :])
            vis_b = 1 - vis_a

            warped_a_refined = validationFlowBackWarp(img_a, flow_ta_refined)
            warped_b_refined = validationFlowBackWarp(img_b, flow_tb_refined)

            wc = superslomo.getWarpCoeff(time_index, device)
            weight_a = wc[0] * vis_a
            weight_b = wc[1] * vis_b
            prediction = (weight_a * warped_a_refined +
                          weight_b * warped_b_refined) / (weight_a + weight_b)

            # Preview for tensorboard, taken from the first batch only.
            if preview is None:
                preview = torchvision.utils.make_grid(
                    [
                        revNormalize(first_cpu[0]),
                        revNormalize(middle_cpu[0]),
                        revNormalize(prediction.cpu()[0]),
                        revNormalize(last_cpu[0]),
                    ],
                    padding=10)

            recon = L1_lossFn(prediction, target)
            perceptual = MSE_LossFn(vgg16_conv_4_3(prediction),
                                    vgg16_conv_4_3(target))
            warp = (L1_lossFn(warped_a, target) +
                    L1_lossFn(warped_b, target) +
                    L1_lossFn(validationFlowBackWarp(img_a, flow_ba), img_b) +
                    L1_lossFn(validationFlowBackWarp(img_b, flow_ab), img_a))
            smooth = smoothness(flow_ba) + smoothness(flow_ab)

            total = 204 * recon + 102 * warp + 0.005 * perceptual + smooth
            loss_sum += total.item()

            mse = MSE_LossFn(prediction, target).item()
            psnr_sum += 10 * log10(1 / mse)

    batch_count = len(validationloader)
    return psnr_sum / batch_count, loss_sum / batch_count, preview
示例#4
0
                I0[b], I1[b] = end[b], start[b]

            IFrame[b] = trainData[trainFrameIndex[b] + 1][b]

        optimizer.zero_grad()

        if args.amp:
            with torch.cuda.amp.autocast():
                # Calculate flow between reference frames I0 and I1
                flowOut = flowComp(torch.cat((I0, I1), dim=1))

                # Extracting flows between I0 and I1 - F_0_1 and F_1_0
                F_0_1 = flowOut[:, :2, :, :]
                F_1_0 = flowOut[:, 2:, :, :]

                fCoeff = superslomo.getFlowCoeff(trainFrameIndex, device,
                                                 seq_len)

                # Calculate intermediate flows
                F_t_0 = fCoeff[0] * F_0_1 + fCoeff[1] * F_1_0
                F_t_1 = fCoeff[2] * F_0_1 + fCoeff[3] * F_1_0

                # Get intermediate frames from the intermediate flows
                g_I0_F_t_0 = trainFlowBackWarp(I0, F_t_0)
                g_I1_F_t_1 = trainFlowBackWarp(I1, F_t_1)

                # Calculate optical flow residuals and visibility maps
                if args.add_blur:
                    intrpOut = ArbTimeFlowIntrp(
                        torch.cat((I0, I1, F_0_1, F_1_0, F_t_1, F_t_0,
                                   g_I1_F_t_1, g_I0_F_t_0, blurred_img),
                                  dim=1))
示例#5
0
def validate():
    """Run one validation pass over ``validationloader``.

    Each batch's frames are averaged into a synthetic blurred image, from
    which two border frames are estimated. For every sample, the
    (start, end) pair is assigned to (I0, I1) in whichever order better
    matches the ground-truth first/last frames. The frame at index
    ``validationFrameIndex[i] + 1`` is then interpolated, with the blurred
    image passed to ``ArbTimeFlowIntrp`` as an extra input, and scored with
    the training loss and with PSNR.

    Returns:
        tuple: (mean PSNR per batch, mean loss per batch, image grid of the
        first sample of the first batch: first frame, target frame,
        prediction, last frame).

    Note: depends on module-level globals not defined here (data loader,
    networks, loss functions, ``superslomo``, ``seq_len``, ``device``,
    ``center_estimation``, ``border_estimation``, ``revNormalize``, ``nn``,
    ``log10``).
    """
    psnr = 0
    tloss = 0
    retImg = None
    # Unreduced L1, built once: with the default reduction the distance is
    # averaged over the whole batch, which would force one start/end
    # ordering on every sample.
    endpoint_l1 = nn.L1Loss(reduction='none')
    with torch.no_grad():
        for validationData, validationFrameIndex, _ in validationloader:
            # Synthesize the blurred input as the mean of all frames.
            blurred_img = torch.zeros_like(validationData[0])
            for image in validationData:
                blurred_img += image
            blurred_img /= len(validationData)
            blurred_img = blurred_img.to(device)

            c = center_estimation(blurred_img)
            start, end = border_estimation(blurred_img, c)

            frame0 = validationData[0].to(device)
            frame1 = validationData[-1].to(device)

            # Per-sample distance of both (start, end) assignments to the
            # ground-truth endpoints; reducing over dims 1-3 assumes
            # (batch, channel, height, width) tensors — TODO confirm.
            parallel = torch.mean(endpoint_l1(start, frame0) +
                                  endpoint_l1(end, frame1),
                                  dim=(1, 2, 3))
            cross = torch.mean(endpoint_l1(start, frame1) +
                               endpoint_l1(end, frame0),
                               dim=(1, 2, 3))
            keep_order = (parallel <= cross)[:, None, None, None]
            I0 = torch.where(keep_order, start, end)
            I1 = torch.where(keep_order, end, start)

            # Ground-truth target frame for each sample (already on device,
            # since I0 is).
            IFrame = torch.zeros_like(I0)
            for i, fidx in enumerate(validationFrameIndex):
                IFrame[i] = validationData[fidx.item() + 1][i]

            flowOut = flowComp(torch.cat((I0, I1), dim=1))
            F_0_1 = flowOut[:, :2, :, :]
            F_1_0 = flowOut[:, 2:, :, :]

            fCoeff = superslomo.getFlowCoeff(validationFrameIndex, device, seq_len)

            F_t_0 = fCoeff[0] * F_0_1 + fCoeff[1] * F_1_0
            F_t_1 = fCoeff[2] * F_0_1 + fCoeff[3] * F_1_0

            g_I0_F_t_0 = validationFlowBackWarp(I0, F_t_0)
            g_I1_F_t_1 = validationFlowBackWarp(I1, F_t_1)

            intrpOut = ArbTimeFlowIntrp(
                torch.cat((I0, I1, F_0_1, F_1_0, F_t_1, F_t_0, g_I1_F_t_1,
                           g_I0_F_t_0, blurred_img),
                          dim=1))

            F_t_0_f = intrpOut[:, :2, :, :] + F_t_0
            F_t_1_f = intrpOut[:, 2:4, :, :] + F_t_1
            V_t_0 = torch.sigmoid(intrpOut[:, 4:5, :, :])
            V_t_1 = 1 - V_t_0

            g_I0_F_t_0_f = validationFlowBackWarp(I0, F_t_0_f)
            g_I1_F_t_1_f = validationFlowBackWarp(I1, F_t_1_f)

            wCoeff = superslomo.getWarpCoeff(validationFrameIndex, device, seq_len)

            Ft_p = (wCoeff[0] * V_t_0 * g_I0_F_t_0_f + wCoeff[1] * V_t_1 *
                    g_I1_F_t_1_f) / (wCoeff[0] * V_t_0 + wCoeff[1] * V_t_1)

            # Preview grid for tensorboard, built from the first batch only.
            if retImg is None:
                retImg = torchvision.utils.make_grid([
                    revNormalize(frame0.cpu()[0]),
                    revNormalize(IFrame.cpu()[0]),
                    revNormalize(Ft_p.cpu()[0]),
                    revNormalize(frame1.cpu()[0])
                ], padding=10)

            # Loss (same weighting as training).
            recnLoss = L1_lossFn(Ft_p, IFrame)

            prcpLoss = MSE_LossFn(vgg16_conv_4_3(Ft_p), vgg16_conv_4_3(IFrame))

            warpLoss = (L1_lossFn(g_I0_F_t_0, IFrame) +
                        L1_lossFn(g_I1_F_t_1, IFrame) +
                        L1_lossFn(validationFlowBackWarp(I0, F_1_0), I1) +
                        L1_lossFn(validationFlowBackWarp(I1, F_0_1), I0))

            # Mean absolute difference between neighbouring flow values,
            # horizontally and vertically.
            loss_smooth_1_0 = torch.mean(
                torch.abs(F_1_0[:, :, :, :-1] - F_1_0[:, :, :, 1:])
            ) + torch.mean(torch.abs(F_1_0[:, :, :-1, :] - F_1_0[:, :, 1:, :]))
            loss_smooth_0_1 = torch.mean(
                torch.abs(F_0_1[:, :, :, :-1] - F_0_1[:, :, :, 1:])
            ) + torch.mean(torch.abs(F_0_1[:, :, :-1, :] - F_0_1[:, :, 1:, :]))
            loss_smooth = loss_smooth_1_0 + loss_smooth_0_1

            loss = 204 * recnLoss + 102 * warpLoss + 0.005 * prcpLoss + loss_smooth
            tloss += loss.item()

            # PSNR with a peak value of 1.
            MSE_val = MSE_LossFn(Ft_p, IFrame)
            psnr += (10 * log10(1 / MSE_val.item()))

    return (psnr / len(validationloader)), (tloss / len(validationloader)), retImg