def valid_loss_function(model, val_loader, epoch, save_freq):
    """Evaluate ``model`` on ``val_loader`` and return the mean validation loss.

    Each batch is run through ``model`` and scored with ``reduce_mean`` against
    its targets. The running mean loss is printed after every batch. On epochs
    that are a multiple of ``save_freq``, the first output image and target
    image of every batch are written under ``<result_dir_val>/<epoch>/``, and
    their PSNR is computed with ``metrics.PSNR``.

    ``reduce_mean``, ``metrics``, ``transforms`` and ``result_dir_val`` are not
    defined here. They are presumably module-level names/imports; confirm.

    Args:
        model: network whose call returns a mapping with an ``'output1'``
            entry (as indexed below).
        val_loader: iterable of ``(inputs, targets)`` batches.
        epoch: current epoch number, used in log lines and file names.
        save_freq: images are saved and PSNR computed only when
            ``epoch % save_freq == 0``.

    Returns:
        ``(final_loss, psnr)``. ``final_loss`` is the mean loss over all
        validation batches (``0`` if ``val_loader`` is empty). ``psnr`` is the
        PSNR of the last batch's first image on save epochs, otherwise ``0``.
    """
    psnr = 0
    final_loss = 0
    # Running sum instead of a fixed-size buffer: no batch-count limit, and
    # batches with a loss of exactly 0 are still counted in the mean.
    loss_sum = 0.0
    model.eval()
    # Validation only: no autograd graph is needed, which saves memory.
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(val_loader):
            outputs = model(inputs)
            loss = reduce_mean(outputs['output1'], targets)
            loss_sum += loss.item()
            final_loss = loss_sum / (batch_idx + 1)
            print("%d %d Loss=%.10f" % (epoch, batch_idx, final_loss))
            if epoch % save_freq == 0:
                # Build the directory once and reuse it for both mkdir and save,
                # so the two paths always agree regardless of a trailing '/'.
                epoch_dir = os.path.join(result_dir_val, '%04d' % epoch)
                os.makedirs(epoch_dir, exist_ok=True)
                to_pil = transforms.ToPILImage()
                out_img = to_pil(outputs['output1'][0].cpu())
                target_img = to_pil(targets[0].cpu())
                out_img.save(os.path.join(
                    epoch_dir,
                    '%04dDBN-FlorinDS_MSE_FS_30_00_train_%d.jpg' % (epoch, batch_idx)))
                target_img.save(os.path.join(
                    epoch_dir,
                    '%04dDBN-FlorinDS_MSE_FS_30_00_target_%d.jpg' % (epoch, batch_idx)))
                psnr = metrics.PSNR(target_img, out_img)
    return final_loss, psnr
def test_function(model, test_loader,epoch,save_freq, result_dir_val,experiment_name,plotter1,plotter2): out_imgs=[] target_imgs=[] input_imgs=[] batch_i=0 psnrs= [] ssims=[] average_psnr=0 average_ssim=0 final_loss=0 model.eval() for batch_idx, (inputs, targets,input_refs) in enumerate(test_loader): inputs, targets,input_refs = inputs, targets,input_refs in_img = inputs target = targets input_ref= input_refs out_img = model(in_img) out_img = out_img out_imgs=transforms.ToPILImage()(out_img['output1'][0].cpu()) target_imgs = transforms.ToPILImage()(target[0].cpu()) input_imgs = transforms.ToPILImage()(input_ref[0].cpu()) if epoch % save_freq == 0: if not os.path.isdir(result_dir_val + '%04d' % epoch): os.makedirs(result_dir_val + '%04d' % epoch) psnrs.append(metrics.PSNR(target_imgs, out_imgs)) ssims.append(metrics.SSIM(target_imgs, out_imgs)) # paralel_imgs = [] if batch_idx% 40==0: input_imgs.save( result_dir_val + '/%04d/' % epoch + '%04dDBN_FD_20s_512_00_input_%d.jpg' % (epoch, batch_idx)) out_imgs.save( result_dir_val + '/%04d/' % epoch + '%04dDBN_FD_20s_512_00_train_%d-PSNR-%f.jpg' % ( epoch, batch_idx,psnrs[batch_idx])) target_imgs.save( result_dir_val + '/%04d/' % epoch + '%04dDBN_FD_20s_512_00_target_%d.jpg' % (epoch, batch_idx)) # paralel_imgs.append(input_imgs) # paralel_imgs.append(out_imgs) # paralel_imgs.append(target_imgs) # UtilsImage.uniImage(paralel_imgs).save( # result_dir_val + '/%04d/' % epoch + '%04dDBN_D_ER_FDS_MSE_FS_Result_%d-PSNR-%d.jpg' % ( # epoch, batch_i, psnrs[batch_idx])) # if epoch % save_freq == 0: # if not os.path.isdir(result_dir_val + '%04d' % epoch): # os.makedirs(result_dir_val + '%04d' % epoch) # for batch_i in range(0,test_loader): # psnrs[batch_i] = metrics.PSNR(target_imgs[batch_i], out_imgs[batch_i]) # ssims[batch_i] = metrics.SSIM(target_imgs[batch_i], out_imgs[batch_i]) # if batch_i % 20 ==0: # paralel_imgs =[] # input_imgs[batch_i].save(result_dir_val +'/%04d/' % epoch+ '%04dDBNP_D_ER_FDS_MSE_FS_00_input_%d.jpg' % 
(epoch, batch_i)) # out_imgs[batch_i].save(result_dir_val +'/%04d/' % epoch+ '%04dDBNP_D_ER_FDS_MSE_FS_00_train_%d-PSNR-%d.jpg' % (epoch, batch_i,psnrs[batch_i])) # target_imgs[batch_i].save(result_dir_val +'/%04d/' % epoch+ '%04dDBN_D_ER_FDS_MSE_FS_00_target_%d.jpg' % (epoch, batch_i)) # paralel_imgs.append( input_imgs[batch_i] ) # paralel_imgs.append (out_imgs[batch_i] ) # paralel_imgs.append(target_imgs[batch_i] ) # UtilsImage.uniImage(paralel_imgs).save(result_dir_val +'/%04d/' % epoch+ '%04dDBN_D_ER_FDS_MSE_FS_Result_%d-PSNR-%d.jpg' % (epoch, batch_i,psnrs[batch_i])) # # wandb.log({ '%04dDBN_D_ER_FDS_MSE_FS_Result_%d.jpg' % (epoch, batch_i) : wandb.Image( Utils.uniImage(paralel_imgs)) # # , "PSNR: " :psnrs[batch_i] # # }) for i in range(0,psnrs.__len__()): average_psnr += psnrs[i] average_ssim += ssims[i] average_psnr=average_psnr/psnrs.__len__() average_ssim=average_ssim/ssims.__len__() print(epoch,average_psnr,average_ssim) writer.writerow([epoch, average_psnr, average_ssim]) plotter1.plot("average_PSNR", 'Val-PSNR', "Epoch", epoch, average_psnr) plotter2.plot("average_SSIM", 'Val-SSIM', "Epoch", epoch, average_ssim)
out_img=transforms.ToPILImage()(outputs['output1'][0].cpu()) target_img = transforms.ToPILImage()(target[0].cpu()) out_img.save(result_dir_train +'/%04d/' % epoch+ '%04dFlorinDS_MSE_FS_30_00_train_%d.jpg' % (epoch, batch_idx)) target_img.save(result_dir_train +'/%04d/' % epoch+ '%04dFlorinDS_MSE_FS_30_00_target_%d.jpg' % (epoch, batch_idx)) if epoch % save_freq_model == 0: torch.save(model.state_dict(), model_dir + 'FlorinDS_MSE_FS_30_Checkpoint_e%04d' % epoch) model_name='FlorinDS_MSE_FS_30_Checkpoint_e%04d' % epoch if epoch % save_freq == 0 and batch_idx == (len(train_loader) - 1): model_V= M.DeBlurNet().cpu() model_V.load_state_dict(torch.load(model_dir + model_name)) my_val_loss, val_psnr = valid_loss_function(model_V.cpu(), val_loader, epoch, save_freq) print("Validation Loss=%.10f" % my_val_loss) plotter1V.plot('loss', 'Validation Loss', 'Epoch', epoch, my_val_loss) print("PSNR: " + str(val_psnr)) plotter2V.plot("PSNR", 'Val_PSNR', 'Epoch', epoch, val_psnr) psnr=metrics.PSNR(target_img,out_img) print("PSNR: " + str(psnr)) # ssim=metrics.SSIM(target_img,out_img) plotter1.plot('Loss', 'Train Loss', 'Epoch', epoch, final_loss) plotter2.plot("PSNR", 'Train_PSNR', 'Epoch', epoch, psnr) t1=time.time() total=t1-to print(total) # torch.save(model.state_dict(), model_dir + 'checkpoint_curr_e%04d' % epoch)