Пример #1
0
def valid_loss_function(model, val_loader, epoch, save_freq):
    """Run one validation pass and return the mean loss and a PSNR sample.

    Each batch from ``val_loader`` is passed through ``model`` and scored
    with ``reduce_mean(model(inputs)['output1'], targets)``.  The running
    mean of the per-batch losses is printed after every batch.  PSNR is
    computed with ``metrics.PSNR`` on the first image of the *last* batch
    (prediction vs. target).  When ``epoch`` is a multiple of ``save_freq``
    those two images are also written as JPEGs into the per-epoch
    sub-directory of ``result_dir_val``.

    Args:
        model: Callable returning a mapping with an ``'output1'`` entry.
            It is switched to eval mode and left in that mode.
        val_loader: Iterable yielding ``(inputs, targets)`` batches.
        epoch: Current epoch number; used for logging and file names.
        save_freq: Images are saved when ``epoch % save_freq == 0``.

    Returns:
        Tuple ``(final_loss, psnr)``.  Both are ``0`` when ``val_loader``
        yields no batches.
    """
    # Local import keeps this block self-contained; it is only needed for
    # the no_grad context below.
    import torch

    final_loss = 0
    psnr = 0
    loss_total = 0.0
    outputs = None
    targets = None
    batch_idx = -1

    model.eval()
    # Validation never back-propagates, so skip building the autograd graph.
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(val_loader):
            outputs = model(inputs)
            loss = reduce_mean(outputs['output1'], targets)

            # Running mean over every batch seen so far: no fixed-size
            # buffer and no silent dropping of losses that are exactly zero.
            loss_total += loss.item()
            final_loss = loss_total / (batch_idx + 1)
            print("%d %d Loss=%.10f" % (epoch, batch_idx, final_loss))

    if outputs is None:
        # Empty loader: there is nothing to score or to save.
        return final_loss, psnr

    # Build the PIL images unconditionally so PSNR can be computed on every
    # epoch, not only on the epochs where the images are written to disk.
    to_pil = transforms.ToPILImage()
    out_img = to_pil(outputs['output1'][0].cpu())
    target_img = to_pil(targets[0].cpu())

    if epoch % save_freq == 0:
        # Create and write to the same directory regardless of whether
        # result_dir_val ends with a path separator.
        epoch_dir = os.path.join(result_dir_val, '%04d' % epoch)
        os.makedirs(epoch_dir, exist_ok=True)
        out_img.save(os.path.join(
            epoch_dir,
            '%04dDBN-FlorinDS_MSE_FS_30_00_train_%d.jpg' % (epoch, batch_idx)))
        target_img.save(os.path.join(
            epoch_dir,
            '%04dDBN-FlorinDS_MSE_FS_30_00_target_%d.jpg' % (epoch, batch_idx)))

    psnr = metrics.PSNR(target_img, out_img)

    return final_loss, psnr
Пример #2
0
    def test_function(model, test_loader, epoch, save_freq, result_dir_val, experiment_name, plotter1, plotter2):
        """Evaluate ``model`` on ``test_loader`` and report average PSNR/SSIM.

        Work is only done when ``epoch`` is a multiple of ``save_freq``.  On
        such epochs the first image of every batch is scored with
        ``metrics.PSNR`` and ``metrics.SSIM`` against its target, and for
        every 40th batch the reference input, the prediction and the target
        are written as JPEGs into the per-epoch sub-directory of
        ``result_dir_val``.  The averages are printed, appended as a row via
        the ``writer`` object defined outside this function, and sent to the
        two plotters.

        Args:
            model: Callable returning a mapping with an ``'output1'`` entry.
                It is switched to eval mode and left in that mode.
            test_loader: Iterable yielding ``(inputs, targets, input_refs)``.
            epoch: Current epoch number; used for logging and file names.
            save_freq: Evaluation runs when ``epoch % save_freq == 0``.
            result_dir_val: Root directory for the saved images.
            experiment_name: Unused; kept for interface compatibility.
            plotter1: Object whose ``plot`` method receives the average PSNR.
            plotter2: Object whose ``plot`` method receives the average SSIM.

        Returns:
            None.
        """
        model.eval()
        if epoch % save_freq != 0:
            # Metrics are only collected on save epochs; averaging the empty
            # lists used to raise ZeroDivisionError here.
            return

        # Local import keeps this block self-contained; it is only needed
        # for the no_grad context below.
        import torch

        # Create and write to the same directory regardless of whether
        # result_dir_val ends with a path separator.
        epoch_dir = os.path.join(result_dir_val, '%04d' % epoch)
        os.makedirs(epoch_dir, exist_ok=True)

        to_pil = transforms.ToPILImage()
        psnrs = []
        ssims = []

        # Evaluation never back-propagates, so skip the autograd graph.
        with torch.no_grad():
            for batch_idx, (inputs, targets, input_refs) in enumerate(test_loader):
                outputs = model(inputs)

                out_img = to_pil(outputs['output1'][0].cpu())
                target_img = to_pil(targets[0].cpu())
                psnrs.append(metrics.PSNR(target_img, out_img))
                ssims.append(metrics.SSIM(target_img, out_img))

                if batch_idx % 40 == 0:
                    input_img = to_pil(input_refs[0].cpu())
                    input_img.save(os.path.join(
                        epoch_dir,
                        '%04dDBN_FD_20s_512_00_input_%d.jpg' % (epoch, batch_idx)))
                    out_img.save(os.path.join(
                        epoch_dir,
                        '%04dDBN_FD_20s_512_00_train_%d-PSNR-%f.jpg' % (
                            epoch, batch_idx, psnrs[-1])))
                    target_img.save(os.path.join(
                        epoch_dir,
                        '%04dDBN_FD_20s_512_00_target_%d.jpg' % (epoch, batch_idx)))

        if not psnrs:
            # Empty loader: nothing to average or report.
            return

        average_psnr = sum(psnrs) / len(psnrs)
        average_ssim = sum(ssims) / len(ssims)
        print(epoch, average_psnr, average_ssim)
        writer.writerow([epoch, average_psnr, average_ssim])

        plotter1.plot("average_PSNR", 'Val-PSNR', "Epoch", epoch, average_psnr)
        plotter2.plot("average_SSIM", 'Val-SSIM', "Epoch", epoch, average_ssim)
Пример #3
0
            out_img=transforms.ToPILImage()(outputs['output1'][0].cpu())
            target_img = transforms.ToPILImage()(target[0].cpu())
            out_img.save(result_dir_train +'/%04d/' % epoch+ '%04dFlorinDS_MSE_FS_30_00_train_%d.jpg' % (epoch, batch_idx))
            target_img.save(result_dir_train +'/%04d/' % epoch+ '%04dFlorinDS_MSE_FS_30_00_target_%d.jpg' % (epoch, batch_idx))
            if epoch % save_freq_model == 0:  torch.save(model.state_dict(), model_dir + 'FlorinDS_MSE_FS_30_Checkpoint_e%04d' % epoch)
            model_name='FlorinDS_MSE_FS_30_Checkpoint_e%04d' % epoch

            if epoch % save_freq == 0 and batch_idx == (len(train_loader) - 1):
                model_V= M.DeBlurNet().cpu()
                model_V.load_state_dict(torch.load(model_dir + model_name))
                my_val_loss, val_psnr = valid_loss_function(model_V.cpu(), val_loader, epoch, save_freq)
                print("Validation Loss=%.10f" % my_val_loss)
                plotter1V.plot('loss', 'Validation Loss', 'Epoch', epoch, my_val_loss)
                print("PSNR: " + str(val_psnr))
                plotter2V.plot("PSNR", 'Val_PSNR', 'Epoch', epoch, val_psnr)
            psnr=metrics.PSNR(target_img,out_img)
            print("PSNR: " + str(psnr))
        # ssim=metrics.SSIM(target_img,out_img)
        plotter1.plot('Loss', 'Train Loss', 'Epoch', epoch, final_loss)
        plotter2.plot("PSNR", 'Train_PSNR', 'Epoch', epoch, psnr)
        t1=time.time()
        total=t1-to
        print(total)
        # torch.save(model.state_dict(), model_dir + 'checkpoint_curr_e%04d' % epoch)