def test():
    """Evaluate every frame of each validation sequence and write PSNR to CSV.

    For each batch the frames are averaged into a blurred image, the two
    border frames are estimated from it, and every frame position of the
    sequence is reconstructed (the borders directly, the positions in between
    through the flow/interpolation networks).  PSNR and SSIM are computed per
    sample and per frame position with ``eval_metrics``; one CSV row per
    sample (name followed by ``seq_len`` PSNR values) is written to
    ``<args.checkpoint_dir>/results_PSNR.csv``.

    Relies on module-level state: ``validationloader``, ``seq_len``, ``args``,
    ``device``, ``mean``, ``std``, ``ctr_idx`` and the network/helper
    callables used below.

    Returns:
        None
    """
    df_column = ['Name']
    df_column.extend([str(i) for i in range(1, seq_len + 1)])
    # Rows are collected in a plain list and turned into a DataFrame once at
    # the end: DataFrame.append was removed in pandas 2.0 and copied the whole
    # frame on every call.
    result_rows = []

    tqdm_loader = tqdm.tqdm(validationloader, ncols=80)

    imgsave_folder = os.path.join(args.checkpoint_dir, 'Saved_imgs')
    # exist_ok avoids the check-then-create race of exists() + mkdir().
    os.makedirs(imgsave_folder, exist_ok=True)

    with torch.no_grad():
        for validationData, validationFrameIndex, validationFile in tqdm_loader:
            # Synthesize the blurred input as the average of all frames.
            blurred_img = torch.zeros_like(validationData[0])
            for image in validationData:
                blurred_img += image
            blurred_img /= len(validationData)
            blurred_img = blurred_img.to(device)
            batch_size = blurred_img.shape[0]

            # NOTE(review): meanshift(..., False) / meanshift(..., True) look
            # like de-normalize / normalize around the estimators -- confirm
            # against the helper's definition.
            blurred_img = meanshift(blurred_img, mean, std, device, False)
            c = center_estimation(blurred_img)
            start, end = border_estimation(blurred_img, c)
            start = meanshift(start, mean, std, device, True)
            end = meanshift(end, mean, std, device, True)
            blurred_img = meanshift(blurred_img, mean, std, device, True)

            frame0 = validationData[0].to(device)
            frame1 = validationData[-1].to(device)

            # The estimator does not say which border is first: per sample,
            # pick the assignment with the smaller error against the true
            # first/last frames.
            parallel = torch.mean(
                compare_ftn(start, frame0) + compare_ftn(end, frame1),
                dim=(1, 2, 3))
            cross = torch.mean(
                compare_ftn(start, frame1) + compare_ftn(end, frame0),
                dim=(1, 2, 3))

            I0 = torch.zeros_like(blurred_img)
            I1 = torch.zeros_like(blurred_img)
            for b in range(batch_size):
                if parallel[b] <= cross[b]:
                    I0[b], I1[b] = start[b], end[b]
                else:
                    I0[b], I1[b] = end[b], start[b]

            psnrs = np.zeros((batch_size, seq_len))
            # SSIM is collected per frame but not written out by this function.
            ssims = np.zeros((batch_size, seq_len))

            for vindex in range(seq_len):
                frameT = validationData[vindex]
                IFrame = frameT.to(device)

                if vindex == 0:
                    Ft_p = I0.clone()
                elif vindex == seq_len - 1:
                    Ft_p = I1.clone()
                else:
                    # Intermediate position index, the same for every sample
                    # in the batch.  Kept under its own name so the outer loop
                    # variables are not overwritten.
                    frameIndex = torch.full(
                        (batch_size,), vindex - 1, dtype=torch.long)

                    # NOTE(review): flowOut depends only on I0 and I1, so it
                    # could be computed once per batch instead of once per
                    # intermediate frame -- left in place pending confirmation
                    # that flowComp is stateless in this mode.
                    flowOut = flowComp(torch.cat((I0, I1), dim=1))
                    F_0_1 = flowOut[:, :2, :, :]
                    F_1_0 = flowOut[:, 2:, :, :]

                    fCoeff = superslomo.getFlowCoeff(frameIndex, device, seq_len)

                    F_t_0 = fCoeff[0] * F_0_1 + fCoeff[1] * F_1_0
                    F_t_1 = fCoeff[2] * F_0_1 + fCoeff[3] * F_1_0

                    g_I0_F_t_0 = validationFlowBackWarp(I0, F_t_0)
                    g_I1_F_t_1 = validationFlowBackWarp(I1, F_t_1)

                    if args.add_blur:
                        intrpOut = ArbTimeFlowIntrp(
                            torch.cat((I0, I1, F_0_1, F_1_0, F_t_1, F_t_0,
                                       g_I1_F_t_1, g_I0_F_t_0, blurred_img),
                                      dim=1))
                    else:
                        intrpOut = ArbTimeFlowIntrp(
                            torch.cat((I0, I1, F_0_1, F_1_0, F_t_1, F_t_0,
                                       g_I1_F_t_1, g_I0_F_t_0), dim=1))

                    # Channels 0-1 and 2-3 are flow residuals, channel 4 is
                    # the visibility logit.
                    F_t_0_f = intrpOut[:, :2, :, :] + F_t_0
                    F_t_1_f = intrpOut[:, 2:4, :, :] + F_t_1
                    V_t_0 = torch.sigmoid(intrpOut[:, 4:5, :, :])
                    V_t_1 = 1 - V_t_0

                    g_I0_F_t_0_f = validationFlowBackWarp(I0, F_t_0_f)
                    g_I1_F_t_1_f = validationFlowBackWarp(I1, F_t_1_f)

                    wCoeff = superslomo.getWarpCoeff(frameIndex, device, seq_len)

                    # Visibility-weighted blend of the two warped borders.
                    Ft_p = (wCoeff[0] * V_t_0 * g_I0_F_t_0_f
                            + wCoeff[1] * V_t_1 * g_I1_F_t_1_f) / (
                                wCoeff[0] * V_t_0 + wCoeff[1] * V_t_1)

                Ft_p = meanshift(Ft_p, mean, std, device, False)
                IFrame = meanshift(IFrame, mean, std, device, False)

                for b in range(batch_size):
                    # ctr_idx is not defined in this function; it is expected
                    # to come from module scope -- TODO confirm.
                    foldername = os.path.basename(
                        os.path.dirname(validationFile[ctr_idx][b]))
                    filename = os.path.splitext(
                        os.path.basename(validationFile[vindex][b]))[0]

                    out_fname = foldername + '_' + filename + '_out.png'
                    gt_fname = foldername + '_' + filename + '.png'
                    out, gt = quantize(Ft_p[b]), quantize(IFrame[b])

                    # Uncomment the two lines below to save images.
                    # torchvision.utils.save_image(out, os.path.join(imgsave_folder, out_fname), normalize=True, range=(0,255))
                    # torchvision.utils.save_image(gt, os.path.join(imgsave_folder, gt_fname), normalize=True, range=(0,255))

                psnr, ssim = eval_metrics(Ft_p, IFrame)
                psnrs[:, vindex] = psnr.cpu().numpy()
                ssims[:, vindex] = ssim.cpu().numpy()

            for b in range(batch_size):
                result_rows.append([validationFile[ctr_idx][b], *psnrs[b]])

    df = pd.DataFrame(result_rows, columns=df_column)
    df.to_csv('{}/results_PSNR.csv'.format(args.checkpoint_dir))
def validate():
    """Run one pass over ``validationloader`` and report PSNR, loss and a preview.

    For each batch the frames are averaged into a blurred image, the two
    border frames are estimated from it and ordered per sample against the
    true first/last frames, and the frame selected by
    ``validationFrameIndex`` is interpolated between them.  When ``args.amp``
    is set, the forward pass and the loss are evaluated under
    ``torch.cuda.amp.autocast()``.

    Relies on module-level state: ``validationloader``, ``args``, ``device``,
    ``mean``, ``std``, ``seq_len`` and the network/loss/helper callables used
    below.  For details see training.

    Returns:
        tuple: ``(psnr, loss, retImg)`` -- the PSNR and the loss, each summed
        over batches and divided by ``len(validationloader)``, and an image
        grid (first frame, target, prediction, last frame) built from the
        first sample of the first batch.
    """
    # Local import: nullcontext lets the AMP and non-AMP paths share one body
    # without touching torch.cuda.amp when AMP is disabled.
    from contextlib import nullcontext

    psnr = 0
    tloss = 0
    retImg = None

    with torch.no_grad():
        for validationData, validationFrameIndex, _ in validationloader:
            # Synthesize the blurred input as the average of all frames.
            blurred_img = torch.zeros_like(validationData[0])
            for image in validationData:
                blurred_img += image
            blurred_img /= len(validationData)
            blurred_img = blurred_img.to(device)

            # NOTE(review): meanshift(..., False) / meanshift(..., True) look
            # like de-normalize / normalize around the estimators -- confirm
            # against the helper's definition.
            blurred_img = meanshift(blurred_img, mean, std, device, False)
            c = center_estimation(blurred_img)
            start, end = border_estimation(blurred_img, c)
            start = meanshift(start, mean, std, device, True)
            end = meanshift(end, mean, std, device, True)
            blurred_img = meanshift(blurred_img, mean, std, device, True)

            frame0 = validationData[0].to(device)
            frame1 = validationData[-1].to(device)
            batch_size = blurred_img.shape[0]

            # Per-sample error of both possible border orderings.
            parallel = torch.mean(
                compare_ftn(start, frame0) + compare_ftn(end, frame1),
                dim=(1, 2, 3))
            cross = torch.mean(
                compare_ftn(start, frame1) + compare_ftn(end, frame0),
                dim=(1, 2, 3))

            I0 = torch.zeros_like(blurred_img)
            I1 = torch.zeros_like(blurred_img)
            IFrame = torch.zeros_like(blurred_img)
            for b in range(batch_size):
                if parallel[b] <= cross[b]:
                    I0[b], I1[b] = start[b], end[b]
                else:
                    I0[b], I1[b] = end[b], start[b]
                # The frame index is offset by one into the sequence.
                IFrame[b] = validationData[validationFrameIndex[b] + 1][b]

            # Both branches of the original ran exactly the same statements;
            # only the surrounding context differs.
            amp_context = torch.cuda.amp.autocast() if args.amp else nullcontext()
            with amp_context:
                flowOut = flowComp(torch.cat((I0, I1), dim=1))
                F_0_1 = flowOut[:, :2, :, :]
                F_1_0 = flowOut[:, 2:, :, :]

                fCoeff = superslomo.getFlowCoeff(validationFrameIndex, device, seq_len)

                F_t_0 = fCoeff[0] * F_0_1 + fCoeff[1] * F_1_0
                F_t_1 = fCoeff[2] * F_0_1 + fCoeff[3] * F_1_0

                g_I0_F_t_0 = validationFlowBackWarp(I0, F_t_0)
                g_I1_F_t_1 = validationFlowBackWarp(I1, F_t_1)

                if args.add_blur:
                    intrpOut = ArbTimeFlowIntrp(
                        torch.cat((I0, I1, F_0_1, F_1_0, F_t_1, F_t_0,
                                   g_I1_F_t_1, g_I0_F_t_0, blurred_img), dim=1))
                else:
                    intrpOut = ArbTimeFlowIntrp(
                        torch.cat((I0, I1, F_0_1, F_1_0, F_t_1, F_t_0,
                                   g_I1_F_t_1, g_I0_F_t_0), dim=1))

                # Channels 0-1 and 2-3 are flow residuals, channel 4 is the
                # visibility logit.
                F_t_0_f = intrpOut[:, :2, :, :] + F_t_0
                F_t_1_f = intrpOut[:, 2:4, :, :] + F_t_1
                V_t_0 = torch.sigmoid(intrpOut[:, 4:5, :, :])
                V_t_1 = 1 - V_t_0

                g_I0_F_t_0_f = validationFlowBackWarp(I0, F_t_0_f)
                g_I1_F_t_1_f = validationFlowBackWarp(I1, F_t_1_f)

                wCoeff = superslomo.getWarpCoeff(validationFrameIndex, device, seq_len)

                # Visibility-weighted blend of the two warped borders.
                Ft_p = (wCoeff[0] * V_t_0 * g_I0_F_t_0_f
                        + wCoeff[1] * V_t_1 * g_I1_F_t_1_f) / (
                            wCoeff[0] * V_t_0 + wCoeff[1] * V_t_1)

                # loss
                recnLoss = L1_lossFn(Ft_p, IFrame)

                prcpLoss = MSE_LossFn(vgg16_conv_4_3(Ft_p), vgg16_conv_4_3(IFrame))

                warpLoss = L1_lossFn(g_I0_F_t_0, IFrame) + L1_lossFn(
                    g_I1_F_t_1, IFrame) + L1_lossFn(
                        validationFlowBackWarp(I0, F_1_0), I1) + L1_lossFn(
                            validationFlowBackWarp(I1, F_0_1), I0)

                loss_smooth_1_0 = torch.mean(
                    torch.abs(F_1_0[:, :, :, :-1] - F_1_0[:, :, :, 1:])) + torch.mean(
                        torch.abs(F_1_0[:, :, :-1, :] - F_1_0[:, :, 1:, :]))
                loss_smooth_0_1 = torch.mean(
                    torch.abs(F_0_1[:, :, :, :-1] - F_0_1[:, :, :, 1:])) + torch.mean(
                        torch.abs(F_0_1[:, :, :-1, :] - F_0_1[:, :, 1:, :]))
                loss_smooth = loss_smooth_1_0 + loss_smooth_0_1

                loss = 204 * recnLoss + 102 * warpLoss + 0.005 * prcpLoss + loss_smooth

            # For tensorboard: keep a preview from the first batch only.
            if retImg is None:
                retImg = torchvision.utils.make_grid([
                    revNormalize(frame0.cpu()[0]),
                    revNormalize(IFrame.cpu()[0]),
                    revNormalize(Ft_p.cpu()[0]),
                    revNormalize(frame1.cpu()[0])
                ], padding=10)

            tloss += loss.item()

            # psnr
            MSE_val = MSE_LossFn(Ft_p, IFrame)
            psnr += (10 * log10(1 / MSE_val.item()))

    return (psnr / len(validationloader)), (tloss / len(validationloader)), retImg
def validate():
    """Score the interpolation networks on ``validationloader``.

    Every batch supplies a (first, target, last) frame triplet together with
    the index of the target frame.  The target is predicted from the first
    and last frames, and the training loss and PSNR are accumulated.  For
    details see training.

    Returns:
        tuple: PSNR and loss, each summed over batches and divided by
        ``len(validationloader)``, followed by an image grid (first frame,
        target, prediction, last frame) taken from the first sample of the
        first batch.
    """

    def flow_smoothness(flow):
        # Mean absolute difference between horizontally adjacent values plus
        # the same for vertically adjacent values.
        along_width = torch.mean(
            torch.abs(flow[:, :, :, :-1] - flow[:, :, :, 1:]))
        along_height = torch.mean(
            torch.abs(flow[:, :, :-1, :] - flow[:, :, 1:, :]))
        return along_width + along_height

    psnr_sum = 0
    loss_sum = 0
    grid_pending = True

    with torch.no_grad():
        for triplet, target_index in validationloader:
            first_cpu, target_cpu, last_cpu = triplet

            first = first_cpu.to(device)
            last = last_cpu.to(device)
            target = target_cpu.to(device)

            # Bidirectional flow between the two input frames.
            flows = flowComp(torch.cat((first, last), dim=1))
            flow_fwd = flows[:, :2, :, :]
            flow_bwd = flows[:, 2:, :, :]

            flow_coeff = superslomo.getFlowCoeff(target_index, device)

            flow_t_first = flow_coeff[0] * flow_fwd + flow_coeff[1] * flow_bwd
            flow_t_last = flow_coeff[2] * flow_fwd + flow_coeff[3] * flow_bwd

            warped_first = validationFlowBackWarp(first, flow_t_first)
            warped_last = validationFlowBackWarp(last, flow_t_last)

            refinement = ArbTimeFlowIntrp(
                torch.cat((first, last, flow_fwd, flow_bwd, flow_t_last,
                           flow_t_first, warped_last, warped_first), dim=1))

            # Channels 0-1 and 2-3 are flow residuals, channel 4 is the
            # visibility logit.
            flow_t_first_refined = refinement[:, :2, :, :] + flow_t_first
            flow_t_last_refined = refinement[:, 2:4, :, :] + flow_t_last
            vis_first = torch.sigmoid(refinement[:, 4:5, :, :])
            vis_last = 1 - vis_first

            rewarped_first = validationFlowBackWarp(first, flow_t_first_refined)
            rewarped_last = validationFlowBackWarp(last, flow_t_last_refined)

            warp_coeff = superslomo.getWarpCoeff(target_index, device)

            weight_first = warp_coeff[0] * vis_first
            weight_last = warp_coeff[1] * vis_last
            prediction = (weight_first * rewarped_first
                          + weight_last * rewarped_last) / (
                              weight_first + weight_last)

            # For tensorboard: only the first batch contributes a preview.
            if grid_pending:
                preview_grid = torchvision.utils.make_grid([
                    revNormalize(first_cpu[0]),
                    revNormalize(target_cpu[0]),
                    revNormalize(prediction.cpu()[0]),
                    revNormalize(last_cpu[0])
                ], padding=10)
                grid_pending = False

            # Loss terms, weighted as in training.
            reconstruction = L1_lossFn(prediction, target)
            perceptual = MSE_LossFn(vgg16_conv_4_3(prediction),
                                    vgg16_conv_4_3(target))
            warping = (L1_lossFn(warped_first, target)
                       + L1_lossFn(warped_last, target)
                       + L1_lossFn(validationFlowBackWarp(first, flow_bwd), last)
                       + L1_lossFn(validationFlowBackWarp(last, flow_fwd), first))
            smoothness = flow_smoothness(flow_bwd) + flow_smoothness(flow_fwd)

            batch_loss = (204 * reconstruction + 102 * warping
                          + 0.005 * perceptual + smoothness)
            loss_sum += batch_loss.item()

            # PSNR from the batch MSE.
            batch_mse = MSE_LossFn(prediction, target)
            psnr_sum += (10 * log10(1 / batch_mse.item()))

    batches = len(validationloader)
    return (psnr_sum / batches), (loss_sum / batches), preview_grid
intrpOut = ArbTimeFlowIntrp( torch.cat((I0, I1, F_0_1, F_1_0, F_t_1, F_t_0, g_I1_F_t_1, g_I0_F_t_0), dim=1)) # Extract optical flow residuals and visibility maps F_t_0_f = intrpOut[:, :2, :, :] + F_t_0 F_t_1_f = intrpOut[:, 2:4, :, :] + F_t_1 V_t_0 = torch.sigmoid(intrpOut[:, 4:5, :, :]) V_t_1 = 1 - V_t_0 # Get intermediate frames from the intermediate flows g_I0_F_t_0_f = trainFlowBackWarp(I0, F_t_0_f) g_I1_F_t_1_f = trainFlowBackWarp(I1, F_t_1_f) wCoeff = superslomo.getWarpCoeff(trainFrameIndex, device, seq_len) # Calculate final intermediate frame Ft_p = (wCoeff[0] * V_t_0 * g_I0_F_t_0_f + wCoeff[1] * V_t_1 * g_I1_F_t_1_f) / (wCoeff[0] * V_t_0 + wCoeff[1] * V_t_1) # Loss recnLoss = L1_lossFn(Ft_p, IFrame) prcpLoss = MSE_LossFn(vgg16_conv_4_3(Ft_p), vgg16_conv_4_3(IFrame)) warpLoss = L1_lossFn(g_I0_F_t_0, IFrame) + L1_lossFn( g_I1_F_t_1, IFrame) + L1_lossFn( trainFlowBackWarp(I0, F_1_0), I1) + L1_lossFn( trainFlowBackWarp(I1, F_0_1), I0)
def validate():
    """Score the blur-conditioned interpolation networks on ``validationloader``.

    For each batch the frames are averaged into a blurred image, the two
    border frames are estimated from it, and one ordering of the borders is
    chosen for the whole batch by comparing L1 errors against the true
    first/last frames.  The frame selected by the batch's frame index is then
    interpolated, with the blurred image fed to the interpolation network as
    an extra input.  For details see training.

    Returns:
        tuple: PSNR and loss, each summed over batches and divided by
        ``len(validationloader)``, followed by an image grid (first frame,
        target, prediction, last frame) taken from the first sample of the
        first batch.
    """

    def flow_smoothness(flow):
        # Mean absolute difference between horizontally adjacent values plus
        # the same for vertically adjacent values.
        along_width = torch.mean(
            torch.abs(flow[:, :, :, :-1] - flow[:, :, :, 1:]))
        along_height = torch.mean(
            torch.abs(flow[:, :, :-1, :] - flow[:, :, 1:, :]))
        return along_width + along_height

    psnr_sum = 0
    loss_sum = 0
    grid_pending = True
    # Stateless criterion used only to decide the border ordering.
    l1_distance = nn.L1Loss()

    with torch.no_grad():
        for frames, target_index, _ in validationloader:
            # Synthesize the blurred input as the average of all frames.
            blur = torch.zeros_like(frames[0])
            for frame in frames:
                blur += frame
            blur /= len(frames)
            blur = blur.to(device)

            center = center_estimation(blur)
            start, end = border_estimation(blur, center)

            first_true = frames[0].to(device)
            last_true = frames[-1].to(device)

            # Keep the estimated order when it matches the true borders at
            # least as well as the swapped order.
            straight_cost = l1_distance(start, first_true) + l1_distance(end, last_true)
            swapped_cost = l1_distance(start, last_true) + l1_distance(end, first_true)
            if straight_cost <= swapped_cost:
                first, last = start, end
            else:
                first, last = end, start

            # Gather each sample's ground-truth frame; the frame index is
            # offset by one into the sequence.
            target_buf = torch.zeros_like(first)
            for sample, frame_idx in enumerate(target_index):
                target_buf[sample] = frames[frame_idx.item()+1][sample]
            target = target_buf.to(device)

            # Bidirectional flow between the two border frames.
            flows = flowComp(torch.cat((first, last), dim=1))
            flow_fwd = flows[:,:2,:,:]
            flow_bwd = flows[:,2:,:,:]

            flow_coeff = superslomo.getFlowCoeff(target_index, device, seq_len)

            flow_t_first = flow_coeff[0] * flow_fwd + flow_coeff[1] * flow_bwd
            flow_t_last = flow_coeff[2] * flow_fwd + flow_coeff[3] * flow_bwd

            warped_first = validationFlowBackWarp(first, flow_t_first)
            warped_last = validationFlowBackWarp(last, flow_t_last)

            refinement = ArbTimeFlowIntrp(
                torch.cat((first, last, flow_fwd, flow_bwd, flow_t_last,
                           flow_t_first, warped_last, warped_first, blur),
                          dim=1))

            # Channels 0-1 and 2-3 are flow residuals, channel 4 is the
            # visibility logit.
            flow_t_first_refined = refinement[:, :2, :, :] + flow_t_first
            flow_t_last_refined = refinement[:, 2:4, :, :] + flow_t_last
            vis_first = torch.sigmoid(refinement[:, 4:5, :, :])
            vis_last = 1 - vis_first

            rewarped_first = validationFlowBackWarp(first, flow_t_first_refined)
            rewarped_last = validationFlowBackWarp(last, flow_t_last_refined)

            warp_coeff = superslomo.getWarpCoeff(target_index, device, seq_len)

            weight_first = warp_coeff[0] * vis_first
            weight_last = warp_coeff[1] * vis_last
            prediction = (weight_first * rewarped_first
                          + weight_last * rewarped_last) / (
                              weight_first + weight_last)

            # For tensorboard: only the first batch contributes a preview.
            if grid_pending:
                preview_grid = torchvision.utils.make_grid([
                    revNormalize(first_true.cpu()[0]),
                    revNormalize(target_buf.cpu()[0]),
                    revNormalize(prediction.cpu()[0]),
                    revNormalize(last_true.cpu()[0])
                ], padding=10)
                grid_pending = False

            # Loss terms, weighted as in training.
            reconstruction = L1_lossFn(prediction, target)
            perceptual = MSE_LossFn(vgg16_conv_4_3(prediction),
                                    vgg16_conv_4_3(target))
            warping = (L1_lossFn(warped_first, target)
                       + L1_lossFn(warped_last, target)
                       + L1_lossFn(validationFlowBackWarp(first, flow_bwd), last)
                       + L1_lossFn(validationFlowBackWarp(last, flow_fwd), first))
            smoothness = flow_smoothness(flow_bwd) + flow_smoothness(flow_fwd)

            batch_loss = (204 * reconstruction + 102 * warping
                          + 0.005 * perceptual + smoothness)
            loss_sum += batch_loss.item()

            # PSNR from the batch MSE.
            batch_mse = MSE_LossFn(prediction, target)
            psnr_sum += (10 * log10(1 / batch_mse.item()))

    batches = len(validationloader)
    return (psnr_sum / batches), (loss_sum / batches), preview_grid