def test(net, epoch_state):
    """Denoise a single clean video clip, report quality metrics and save frames.

    The clip is read with ``load_video(opt.dataset_dir)``, corrupted with
    zero-mean Gaussian noise of std ``opt.val_noiseL`` and passed through
    ``net``. The output is compared with the central ``opt.output_frame``
    frames of the clean clip: PSNR/SSIM on the frames, TPSNR/TSSIM on the
    ``video_diff`` frame differences, and an STRRED score (printed only).
    Metrics are printed and appended to ``test_delta.txt``; denoised frames
    are written under ``opt.save_dir`` using the module-level ``str_format``.

    NOTE(review): another ``def test(net, epoch_state)`` appears later in this
    file and rebinds the name ``test`` — confirm which version callers reach.

    Args:
        net: denoising model; its output is reshaped to the shape of the
            ground-truth slice.
        epoch_state: epoch number written into the log line.
    """

    def _mean(values):
        # Same reduction the original used: numpy mean converted to float.
        return float(np.array(values).mean())

    clean_seq = load_video(opt.dataset_dir)
    print('testing........')

    # Evaluate in eval mode and restore the caller's mode afterwards.
    was_training = net.training
    net.eval()
    try:
        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            clean_seq = Variable(clean_seq).cuda()
            # Central opt.output_frame frames along dim 2 of the 5-D clip.
            gt = clean_seq[:, :, int(opt.input_frame / 2 - opt.output_frame / 2):
                           int(opt.input_frame / 2 + opt.output_frame / 2), :, :]
            noise = torch.empty_like(clean_seq).normal_(mean=0, std=opt.val_noiseL)
            noisy_clean_seq = clean_seq + noise

            # CUDA executes asynchronously; synchronize so the timer covers the
            # forward pass itself rather than just the kernel launches.
            torch.cuda.synchronize()
            start_time = time.time()
            denoise_seq = net(noisy_clean_seq)
            torch.cuda.synchronize()
            denoise_time = time.time() - start_time

            N, C, L, H, W = np.shape(gt)
            denoise_seq = denoise_seq.view(N, C, L, H, W)
            diff_denoise_seq, diff_gt = video_diff(denoise_seq, gt, 7)
            psnr_val = _mean(psnr(denoise_seq, gt))
            ssim_val = _mean(cal_ssim(denoise_seq, gt))
            tpsnr_val = _mean(psnr(diff_denoise_seq, diff_gt))
            tssim_val = _mean(cal_ssim(diff_denoise_seq, diff_gt))
            strred_score = cal_strred(denoise_seq, gt)
            print(strred_score)
    finally:
        net.train(was_training)

    print(
        'valid PSNR---%f, SSIM---%4f, TPSNR---%4f, TSSIM---%4f, average time ----%f'
        % (psnr_val, ssim_val, tpsnr_val, tssim_val, denoise_time))
    with open("test_delta.txt", "a+") as log_file:
        log_file.write(
            "Epoch: %d, PSNR: %.3f, SSIM---%4f, TPSNR---%4f, TSSIM---%4f, run time: %.5f---"
            % (epoch_state, psnr_val, ssim_val, tpsnr_val, tssim_val, denoise_time)
            + "\n")

    mkdir_if_not_exist(opt.save_dir)
    # (N, C, L, H, W) -> (N, L, H, W, C): one H x W x C image per frame.
    # detach() is required before .numpy() on any tensor that tracks grad.
    save_out = denoise_seq.permute(0, 2, 3, 4, 1).detach().cpu().numpy()
    # NOTE(review): saves exactly 7 frames (matching the 7 passed to
    # video_diff); assumes L >= 7 — confirm against opt.output_frame.
    for frame_idx in range(7):
        output_img = np.clip(save_out[0, frame_idx, ...], 0.0, 1.0)
        plt.imsave(os.path.join(opt.save_dir, str_format % frame_idx), output_img)
def test(net, epoch_state):
    """Denoise every validation sequence, log averaged metrics and save frames.

    Sequences come from ``ValidSetLoader(opt.dataset_dir, ...)`` (batch size 1,
    no shuffling). Their names come from ``sep_testlist.txt`` inside
    ``opt.dataset_dir``. Each clip is corrupted with zero-mean Gaussian noise
    of std ``opt.val_noiseL``, denoised by ``net`` and compared with its
    central ``opt.output_frame`` frames. The metrics are PSNR, SSIM,
    TPSNR/TSSIM on the ``video_diff`` frame differences, and STRRED. Means
    over all sequences are printed and appended to ``test_delta.txt``. Each
    sequence's denoised frames are saved in the directory
    ``os.path.join(opt.save_dir, <sequence name>)``.

    Args:
        net: denoising model; its output is reshaped to the ground-truth shape.
        epoch_state: epoch number written into the log line.
    """

    def _mean(values):
        # Same reduction the original used: numpy mean converted to float.
        return float(np.array(values).mean())

    valid_set = ValidSetLoader(opt.dataset_dir, patch_size=opt.patch_size,
                               input_frame=opt.input_frame)
    valid_loader = DataLoader(dataset=valid_set, num_workers=opt.threads,
                              batch_size=1, shuffle=False)
    # The i-th batch is matched with the i-th name of this list.
    # NOTE(review): assumes ValidSetLoader yields sequences in the same order
    # as sep_testlist.txt — confirm.
    with open(os.path.join(opt.dataset_dir, 'sep_testlist.txt'), 'r') as f:
        seq_names = f.read().splitlines()

    psnr_list = []
    ssim_list = []
    tpsnr_list = []
    tssim_list = []
    strred_list = []
    total_time = 0.0
    # Central opt.output_frame frames along dim 2 of each 5-D clip.
    frame_start = int(opt.input_frame / 2 - opt.output_frame / 2)
    frame_stop = int(opt.input_frame / 2 + opt.output_frame / 2)
    print('testing........')

    # Evaluate in eval mode and restore the caller's mode afterwards.
    was_training = net.training
    net.eval()
    try:
        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            for idx_iter, clean_seq in enumerate(valid_loader):
                print("processing............", idx_iter)
                clean_seq = Variable(clean_seq).cuda()
                gt = clean_seq[:, :, frame_start:frame_stop, :, :]
                noise = torch.empty_like(clean_seq).normal_(mean=0, std=opt.val_noiseL)
                noisy_clean_seq = clean_seq + noise

                # CUDA executes asynchronously; synchronize so the timer covers
                # the forward pass itself rather than just the kernel launches.
                torch.cuda.synchronize()
                start_time = time.time()
                denoise_seq = net(noisy_clean_seq)
                torch.cuda.synchronize()
                total_time += time.time() - start_time

                N, C, L, H, W = np.shape(gt)
                denoise_seq = denoise_seq.view(N, C, L, H, W)
                diff_denoise_seq, diff_gt = video_diff(denoise_seq, gt, 7)
                psnr_list.append(psnr(denoise_seq, gt))
                ssim_list.append(cal_ssim(denoise_seq, gt))
                tpsnr_list.append(psnr(diff_denoise_seq, diff_gt))
                tssim_list.append(cal_ssim(diff_denoise_seq, diff_gt))
                strred_list.append(cal_strred(denoise_seq, gt))

                seq_name = seq_names[idx_iter]
                print(seq_name)
                # os.path.join works whether or not opt.save_dir ends in a
                # separator (plain concatenation required a trailing '/').
                seq_dir = os.path.join(opt.save_dir, seq_name)
                mkdir_if_not_exist(seq_dir)
                # (N, C, L, H, W) -> (N, L, H, W, C): one H x W x C image per frame.
                save_out = denoise_seq.permute(0, 2, 3, 4, 1).detach().cpu().numpy()
                # NOTE(review): saves exactly 7 frames; assumes L >= 7.
                for frame_idx in range(7):
                    output_img = np.clip(save_out[0, frame_idx, ...], 0.0, 1.0)
                    plt.imsave(os.path.join(seq_dir, str_format % frame_idx), output_img)
    finally:
        net.train(was_training)

    num_seqs = len(psnr_list)
    if num_seqs == 0:
        print('no test sequences were loaded; nothing to report')
        return
    # Mean forward time per processed sequence.
    avg_time = total_time / num_seqs
    psnr_val = _mean(psnr_list)
    ssim_val = _mean(ssim_list)
    tpsnr_val = _mean(tpsnr_list)
    tssim_val = _mean(tssim_list)
    strred_val = _mean(strred_list)

    print('valid PSNR---%f, SSIM---%4f, TPSNR---%4f, TSSIM---%4f,STRRED_SCORE---%4f, average time ----%f'
          % (psnr_val, ssim_val, tpsnr_val, tssim_val, strred_val, avg_time))
    with open("test_delta.txt", "a+") as log_file:
        log_file.write(
            "Epoch: %d, PSNR: %.3f, SSIM---%4f, TPSNR---%4f, TSSIM---%4f, STRRED_SCORE---%4f, run time: %.5f---"
            % (epoch_state, psnr_val, ssim_val, tpsnr_val, tssim_val, strred_val, avg_time)
            + "\n")
def train(train_loader, scale_factor, epoch_num):
    """Train ``Net(scale_factor)`` with MSE loss, checkpointing every epoch.

    Each batch from ``train_loader`` yields ``(LR, HR)``. The network output
    for ``LR`` is compared with ``HR[:, :, 3, :, :]`` (index 3 along dim 2 of
    ``HR``). After every epoch, the mean loss and PSNR are printed and
    appended to the history lists. A checkpoint is then written with
    ``save_checkpoint`` into ``opt.save``, and ``valid(net)`` is run.

    When ``opt.resume`` is set, the model weights, epoch counter and loss/PSNR
    histories are restored from that checkpoint. Optimizer and LR-scheduler
    state are also restored when the checkpoint contains them. Older
    checkpoints without those keys still load, but they start with a fresh
    optimizer and schedule.

    Args:
        train_loader: iterable of ``(LR, HR)`` batches.
        scale_factor: forwarded to ``Net`` and used in checkpoint file names.
        epoch_num: total number of epochs (training runs up to this epoch).
    """
    net = Net(scale_factor).cuda()
    epoch_state = 0
    loss_list = []
    psnr_list = []

    optimizer = torch.optim.Adam(net.parameters(), lr=opt.lr)
    criterion_MSE = torch.nn.MSELoss().cuda()
    # NOTE(review): scheduler.step() is called once per iteration below, so
    # opt.step is counted in iterations, not epochs — confirm this is intended.
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=opt.step, gamma=opt.gamma)

    if opt.resume:
        ckpt = torch.load(opt.resume)
        # load_state_dict copies into the existing parameters, so the optimizer
        # created above still references the right tensors.
        net.load_state_dict(ckpt['state_dict'])
        epoch_state = ckpt['epoch']
        loss_list = ckpt['loss']
        psnr_list = ckpt['psnr']
        # Without this, a resumed run restarts the LR schedule at opt.lr.
        if 'optimizer' in ckpt:
            optimizer.load_state_dict(ckpt['optimizer'])
        if 'scheduler' in ckpt:
            scheduler.load_state_dict(ckpt['scheduler'])

    for idx_epoch in range(epoch_state, epoch_num):
        loss_epoch = []
        psnr_epoch = []
        for LR, HR in train_loader:
            LR, HR = Variable(LR).cuda(), Variable(HR).cuda()
            target = HR[:, :, 3, :, :]
            SR = net(LR)
            loss = criterion_MSE(SR, target)
            loss_epoch.append(loss.item())
            # Detach as valid() does, so the metric cannot hold on to the
            # autograd graph of this iteration.
            psnr_epoch.append(psnr(SR.detach(), target))

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            scheduler.step()

        mean_loss = float(np.array(loss_epoch).mean())
        mean_psnr = float(np.array(psnr_epoch).mean())
        loss_list.append(mean_loss)
        psnr_list.append(mean_psnr)
        print(time.ctime()[4:-5] + ' Epoch---%d, loss_epoch---%f, PSNR---%f'
              % (idx_epoch + 1, mean_loss, mean_psnr))
        save_checkpoint(
            {
                'epoch': idx_epoch + 1,
                'state_dict': net.state_dict(),
                'loss': loss_list,
                'psnr': psnr_list,
                'optimizer': optimizer.state_dict(),
                'scheduler': scheduler.state_dict(),
            },
            save_path=opt.save,
            filename='model' + str(scale_factor) + '_epoch' + str(idx_epoch + 1) + '.pth.tar')
        valid(net)
def valid(net):
    """Print the mean PSNR of ``net`` on the validation loader.

    Builds ``ValidSetLoader(opt.train_dataset_dir, ...)`` with batch size 8.
    The network output for each ``LR`` batch is compared with
    ``HR[:, :, 3, :, :]``, and the mean of the per-batch ``psnr`` values is
    printed.

    NOTE(review): validation reads ``opt.train_dataset_dir`` — confirm this is
    not meant to be a held-out directory.

    Args:
        net: model under evaluation. Its train/eval mode is restored on return.
    """
    valid_set = ValidSetLoader(opt.train_dataset_dir, scale_factor=opt.scale_factor,
                               inType=opt.inType)
    # shuffle=False: the reported value is a mean over per-batch results, so
    # reshuffling batch membership made it vary from run to run.
    valid_loader = DataLoader(dataset=valid_set, num_workers=opt.threads,
                              batch_size=8, shuffle=False)
    psnr_list = []

    # valid() is called during training: switch to eval mode (e.g. for any
    # BatchNorm/Dropout layers) and restore the previous mode afterwards.
    was_training = net.training
    net.eval()
    try:
        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            for LR, HR in valid_loader:
                LR, HR = Variable(LR).cuda(), Variable(HR).cuda()
                SR = net(LR)
                psnr_list.append(psnr(SR, HR[:, :, 3, :, :]))
    finally:
        net.train(was_training)

    print('valid PSNR---%f' % (float(np.array(psnr_list).mean())))