def ToImage(frame0, frame1):
    with torch.no_grad():
        img0 = frame0.cuda()
        img1 = frame1.cuda()

        # Append the BDCN edge map of each frame as an extra input channel.
        img0_e = torch.cat([img0, torch.tanh(bdcn(img0)[0])], dim=1)
        img1_e = torch.cat([img1, torch.tanh(bdcn(img1)[0])], dim=1)

        # Stage 1: structure generator predicts a coarse intermediate frame.
        ref_imgt, _ = structure_gen((img0_e, img1_e))
        # Stage 2: detail enhancement refines the coarse prediction.
        imgt = detail_enhance((img0, img1, ref_imgt))
        # imgt = detail_enhance((img0, img1, imgt))
        imgt = torch.clamp(imgt, max=1., min=-1.)

    return imgt
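For context, a minimal inference sketch showing how ToImage might be driven. It assumes the bdcn, structure_gen and detail_enhance models above are already built, moved to CUDA, and set to eval(); the file names and the placeholder mean/std normalization are illustrative assumptions, not taken from the original code.

import torch
import torchvision.transforms as transforms
from PIL import Image

to_tensor = transforms.Compose([
    transforms.ToTensor(),
    # Placeholder normalization mapping [0, 1] -> [-1, 1]; the project may use
    # its own dataset-specific mean/std (see revNormalize in the scripts below).
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

frame0 = to_tensor(Image.open('frame_000.png')).unsqueeze(0)  # 1 x C x H x W
frame1 = to_tensor(Image.open('frame_001.png')).unsqueeze(0)

imgt = ToImage(frame0, frame1)                 # interpolated frame in [-1, 1]
imgt01 = (imgt.squeeze(0).cpu() + 1.0) / 2.0   # back to [0, 1] for saving
transforms.ToPILImage()(imgt01).save('frame_000_5.png')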
for trainIndex, trainData in tqdm(enumerate(trainloader, 0)):
    # Get the input frames and the target intermediate frame from the training set.
    start_time = time.time()
    frame0, frameT, frame1 = trainData

    if args.test:
        # Overfit on a single fixed batch for debugging.
        frame0, frameT, frame1 = test_batch

    img0 = frame0.cuda()
    img1 = frame1.cuda()
    IFrame = frameT.cuda()

    # Append BDCN edge maps as an extra input channel.
    img0_e = torch.cat([img0, torch.tanh(bdcn(img0)[0])], dim=1)
    img1_e = torch.cat([img1, torch.tanh(bdcn(img1)[0])], dim=1)
    IFrame_e = torch.cat([IFrame, torch.tanh(bdcn(IFrame)[0])], dim=1)

    if args.final:
        # Full pipeline: structure generator -> detail enhancement -> final detail enhancement.
        _, _, ref_imgt = structure_gen((img0_e, img1_e, IFrame_e))
        _, _, _, ref_imgt2 = detail_enhance((img0, img1, IFrame, ref_imgt))
        loss, MSE_val, imgt = detail_enhance_last((img0, img1, IFrame, ref_imgt2))
    else:
        if args.GEN_DE:
            # Train the structure generator on its own.
            loss, _, ref_imgt = structure_gen((img0_e, img1_e, IFrame_e))
        else:
            # Train the detail-enhancement network on top of the structure generator.
            _, _, ref_imgt = structure_gen((img0_e, img1_e, IFrame_e))
            loss, _, _, imgt = detail_enhance((img0, img1, IFrame, ref_imgt))
    # print(torch.max(torch.abs(ref_imgt - IFrame)))
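    # --- Hypothetical continuation of the training step (not part of the original
    # snippet above, which ends at the loss computation): the name `optimizer` and
    # the reduction of `loss` to a scalar are assumptions about the surrounding script.
    optimizer.zero_grad()
    loss = torch.mean(loss)
    loss.backward()
    optimizer.step()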
def validate():
    # Validation-time evaluation; mirrors the training loop above.
    psnr = 0
    ssim = 0
    tloss = 0
    flag = 1
    with torch.no_grad():
        for validationIndex, validationData in enumerate(validationloader, 0):
            frame0, frameT, frame1 = validationData

            img0 = frame0.cuda()
            img1 = frame1.cuda()
            IFrame = frameT.cuda()

            # Append BDCN edge maps as an extra input channel.
            img0_e = torch.cat([img0, torch.tanh(bdcn(img0)[0])], dim=1)
            img1_e = torch.cat([img1, torch.tanh(bdcn(img1)[0])], dim=1)
            IFrame_e = torch.cat([IFrame, torch.tanh(bdcn(IFrame)[0])], dim=1)

            if args.final:
                _, _, ref_imgt = structure_gen((img0_e, img1_e, IFrame_e))
                _, _, _, ref_imgt2 = detail_enhance((img0, img1, IFrame, ref_imgt))
                loss, MSE_val, imgt = detail_enhance_last((img0, img1, IFrame, ref_imgt2))
                SSIM = compare_ssim(
                    IFrame.reshape(1, -1, 256, 448).squeeze(0).cpu().numpy().transpose(1, 2, 0),
                    imgt.reshape(1, -1, 256, 448).squeeze(0).cpu().numpy().transpose(1, 2, 0),
                    data_range=2, multichannel=True)
            else:
                if args.GEN_DE:
                    loss, MSE_val, ref_imgt = structure_gen((img0_e, img1_e, IFrame_e))
                    SSIM = compare_ssim(
                        IFrame.reshape(1, -1, 256, 448).squeeze(0).cpu().numpy().transpose(1, 2, 0),
                        ref_imgt.reshape(1, -1, 256, 448).squeeze(0).cpu().numpy().transpose(1, 2, 0),
                        data_range=2, multichannel=True)
                else:
                    _, _, ref_imgt = structure_gen((img0_e, img1_e, IFrame_e))
                    loss, MSE_val, _, imgt = detail_enhance((img0, img1, IFrame, ref_imgt))
                    SSIM = compare_ssim(
                        IFrame.reshape(1, -1, 256, 448).squeeze(0).cpu().numpy().transpose(1, 2, 0),
                        imgt.reshape(1, -1, 256, 448).squeeze(0).cpu().numpy().transpose(1, 2, 0),
                        data_range=2, multichannel=True)

            loss = torch.mean(loss)
            MSE_val = torch.mean(MSE_val)

            if flag:
                # Keep one grid for visualization: inputs, prediction, ground truth, intermediates.
                if args.final:
                    retImg = torch.cat([revNormalize(frame0[0]).unsqueeze(0),
                                        revNormalize(frame1[0]).unsqueeze(0),
                                        revNormalize(imgt.cpu()[0]).unsqueeze(0),
                                        revNormalize(frameT[0]).unsqueeze(0),
                                        revNormalize(ref_imgt2.cpu()[0]).unsqueeze(0),
                                        revNormalize(ref_imgt.cpu()[0]).unsqueeze(0)], dim=0)
                else:
                    if args.GEN_DE:
                        retImg = torch.cat([revNormalize(frame0[0]).unsqueeze(0),
                                            revNormalize(frame1[0]).unsqueeze(0),
                                            revNormalize(ref_imgt.cpu()[0]).unsqueeze(0),
                                            revNormalize(frameT[0]).unsqueeze(0)], dim=0)
                    else:
                        retImg = torch.cat([revNormalize(frame0[0]).unsqueeze(0),
                                            revNormalize(frame1[0]).unsqueeze(0),
                                            revNormalize(imgt.cpu()[0]).unsqueeze(0),
                                            revNormalize(frameT[0]).unsqueeze(0),
                                            revNormalize(ref_imgt.cpu()[0]).unsqueeze(0)], dim=0)
                flag = 0

            # Accumulate loss, PSNR and SSIM.
            tloss += loss.item()
            psnr += (10 * log10(4 / MSE_val.item()))
            ssim += SSIM

    return (psnr / len(validationloader)), (tloss / len(validationloader)), retImg, MSE_val, (ssim / len(validationloader))
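The PSNR accumulation uses 10 * log10(4 / MSE) because the tensors are normalized to [-1, 1]: the peak-to-peak signal range is 2, so the squared peak is 4. A small sketch of the same formula with the range made explicit (the helper name is illustrative):

from math import log10

def psnr_from_mse(mse, peak=2.0):
    # For images in [-1, 1], peak = 2 and peak ** 2 = 4,
    # which reduces to the 10 * log10(4 / MSE) used in validate().
    return 10 * log10((peak ** 2) / mse)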
def validate():
    # Test-time evaluation: reports PSNR, loss, MSE and interpolation error (IE),
    # and writes the predicted and ground-truth frames to disk.
    psnr = 0
    ie = 0
    tloss = 0
    with torch.no_grad():
        for testIndex, testData in tqdm(enumerate(testloader, 0)):
            frame0, frameT, frame1 = testData

            img0 = frame0.cuda()
            img1 = frame1.cuda()
            IFrame = frameT.cuda()

            # Append BDCN edge maps as an extra input channel.
            img0_e = torch.cat([img0, torch.tanh(bdcn(img0)[0])], dim=1)
            img1_e = torch.cat([img1, torch.tanh(bdcn(img1)[0])], dim=1)
            IFrame_e = torch.cat([IFrame, torch.tanh(bdcn(IFrame)[0])], dim=1)

            _, _, ref_imgt = structure_gen((img0_e, img1_e, IFrame_e))
            loss, MSE_val, IE, imgt = detail_enhance((img0, img1, IFrame, ref_imgt))
            imgt = torch.clamp(imgt, max=1., min=-1.)

            IFrame_np = IFrame.squeeze(0).cpu().numpy()
            imgt_np = imgt.squeeze(0).cpu().numpy()

            # Convert from [-1, 1] CHW RGB to [0, 255] HWC BGR for OpenCV.
            imgt_png = np.uint8(((imgt_np + 1.0) / 2.0).transpose(1, 2, 0)[:, :, ::-1] * 255)
            IFrame_png = np.uint8(((IFrame_np + 1.0) / 2.0).transpose(1, 2, 0)[:, :, ::-1] * 255)

            imgpath = args.imgpath + '/' + str(testIndex)
            if not os.path.isdir(imgpath):
                os.system('mkdir -p %s' % imgpath)
            cv2.imwrite(imgpath + '/imgt.png', imgt_png)
            cv2.imwrite(imgpath + '/IFrame.png', IFrame_png)

            PSNR = compare_psnr(IFrame_np, imgt_np, data_range=2)
            print('PSNR:', PSNR)

            loss = torch.mean(loss)
            MSE_val = torch.mean(MSE_val)

            if testIndex % 100 == 99:
                # Periodically push a visualization grid to visdom.
                vImg = torch.cat([revNormalize(frame0[0]).unsqueeze(0),
                                  revNormalize(frame1[0]).unsqueeze(0),
                                  revNormalize(imgt.cpu()[0]).unsqueeze(0),
                                  revNormalize(frameT[0]).unsqueeze(0),
                                  revNormalize(ref_imgt.cpu()[0]).unsqueeze(0)], dim=0)
                vImg = torch.clamp(vImg, max=1., min=0)
                vis.images(vImg, win='vImage', env=args.visdom_env, nrow=2,
                           opts={'title': 'visual_image'})

            # Accumulate metrics.
            tloss += loss.item()
            psnr += PSNR
            ie += IE

    return (psnr / len(testloader)), (tloss / len(testloader)), MSE_val, (ie / len(testloader))
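A minimal sketch of how this test-time validate() might be invoked once the networks are loaded; putting the models into eval mode and the print format are assumptions about the surrounding script, not taken from the original code.

if __name__ == '__main__':
    # Assumed setup: bdcn, structure_gen and detail_enhance already restored
    # from checkpoints and moved to CUDA elsewhere in the script.
    bdcn.eval()
    structure_gen.eval()
    detail_enhance.eval()

    test_psnr, test_loss, last_mse, test_ie = validate()
    print('PSNR: %.4f  loss: %.4f  IE: %.4f' % (test_psnr, test_loss, test_ie))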