示例#1
0
    recon_sart_30 = SartReconstructor('SART30', sart_n_iter=n_iter, sart_relaxation=0.30)
    recon_sart_15 = SartReconstructor('SART15', sart_n_iter=n_iter, sart_relaxation=0.15)
    recon_sart_07 = SartReconstructor('SART07', sart_n_iter=n_iter, sart_relaxation=0.07)
    recon_sart_05 = SartReconstructor('SART05', sart_n_iter=n_iter, sart_relaxation=0.05)
    recon_sart_03 = SartReconstructor('SART03', sart_n_iter=n_iter, sart_relaxation=0.03)
    recon_sart_02 = SartReconstructor('SART02', sart_n_iter=n_iter, sart_relaxation=0.02)

    recons = [
        recon_fbp,
        recon_sart_95,
        recon_sart_90,
        recon_sart_80,
        recon_sart_50,
        recon_sart_30,
        recon_sart_15,
        recon_sart_07,
        recon_sart_05,
        recon_sart_03,
        recon_sart_02,
    ]
    
    imgs = [ r.calc(sinogram, theta) for r in recons ]

    for r in recons:
        mse, psnr, ssim = r.eval(gt)
        print( "{}: MSE:{:.5f} PSNR:{:.5f} SSIM:{:.5f}".format(
            r.name, mse, psnr, ssim
        ))

    plot_grid( [gt]+imgs, FOCUS=FOCUS, save_name='art.png', dpi=500 )
示例#2
0
                                   sart_relaxation=0.15)
    recon_dip = DgrReconstructor(
        'DGR',
        dip_n_iter=4001,
        net='skip',
        lr=0.01,
        reg_std=1. / 100,
        w_proj_loss=1.00,
        # w_perceptual_loss=0.01,
        w_tv_loss=0.00,
        w_ssim_loss=0.00)

    img_sart = recon_sart.calc(sinogram, theta)
    recon_dip.set_for_metric(gt, img_sart, FOCUS=FOCUS, log_dir='log/')
    img_dip = recon_dip.calc(sinogram, theta)

    recons = [
        recon_sart,
        recon_dip,
    ]

    for r in recons:
        mse, psnr, ssim = r.eval(gt)
        print("Noise 25 {}: MSE:{:.5f} PSNR:{:.5f} SSIM:{:.5f}".format(
            r.name, mse, psnr, ssim))

    plot_grid([gt, img_sart, img_dip],
              FOCUS=FOCUS,
              save_name='noise30.png',
              dpi=500)
示例#3
0
    recon_dip1.set_for_metric(gt, img_sart, FOCUS=FOCUS, log_dir='log')
    img_dip1 = recon_dip1.calc(sinogram, theta)

    recon_dip2.set_for_metric(gt, img_sart, FOCUS=FOCUS, log_dir='log')
    img_dip2 = recon_dip2.calc(sinogram, theta)

    recon_dip3.set_for_metric(gt, img_sart, FOCUS=FOCUS, log_dir='log')
    img_dip3 = recon_dip3.calc(sinogram, theta)

    recons = [recon_fbp, recon_sart, recon_sart_tv,
              recon_dip1, recon_dip2, recon_dip3]

    for r in recons:
        mse, psnr, ssim = r.eval(gt)
        print( "{}: MSE:{:.5f} PSNR:{:.5f} SSIM:{:.5f}".format(
            r.name, mse, psnr, ssim
        ))

    for i, img in enumerate([gt, img_fbp, img_sart, img_sart_tv, img_dip1, img_dip2, img_dip3,]):
        plot_grid([img],
            FOCUS=None, save_name=res_name+'_'+str(i)+'_all.png', dpi=500
        )

    plot_grid([
            gt, img_fbp, img_sart, img_sart_tv, 
            img_dip1, img_dip2, img_dip3,],
            FOCUS=None, save_name=res_name+'_all.png', dpi=500
        )
            
示例#4
0
# gt, sinogram, theta, FOCUS = image_to_sparse_sinogram(fname, channel=1,
#         n_proj=128, size=512, angle1=0.0, angle2=180.0, noise_pow=12.0 )

# plt.figure()
# plt.imshow(sinogram, cmap='gray')#, vmin=0.0, vmax=1.0)
# plt.figure()
# plt.imshow(gt, cmap='gray', vmin=0.0, vmax=1.0)
# plt.figure()
# plt.imshow(iradon(sinogram, theta=theta), cmap='gray')
# plt.show()

# Reconstruct with the 'FBP' IRadonReconstructor and show the ground truth
# next to the reconstruction; the (gt, fbp_img) pair is repeated three times
# and laid out in a single row.
recon_fbp = IRadonReconstructor('FBP')
fbp_img = recon_fbp.calc(sinogram, theta)

imgs = [gt, fbp_img] * 3

plot_grid(imgs, ZOOM=FOCUS, show=True, number_of_rows=1, plot1d=50)
        SartBM3DReconstructor('SART+BM3Dsigma0.10',
                              sart_n_iter=n_iter,
                              sart_relaxation=0.15,
                              bm3d_sigma=0.10),
        SartBM3DReconstructor('SART+BM3Dsigma0.20',
                              sart_n_iter=n_iter,
                              sart_relaxation=0.15,
                              bm3d_sigma=0.20),
        SartBM3DReconstructor('SART+BM3Dsigma0.50',
                              sart_n_iter=n_iter,
                              sart_relaxation=0.15,
                              bm3d_sigma=0.50),
        SartBM3DReconstructor('SART+BM3Dsigma0.70',
                              sart_n_iter=n_iter,
                              sart_relaxation=0.15,
                              bm3d_sigma=0.70),
        SartBM3DReconstructor('SART+BM3Dsigma0.90',
                              sart_n_iter=n_iter,
                              sart_relaxation=0.15,
                              bm3d_sigma=0.90),
    ]

    imgs = [r.calc(sinogram, theta) for r in recons]

    for r in recons:
        mse, psnr, ssim = r.eval(gt)
        print("{}: MSE:{:.5f} PSNR:{:.5f} SSIM:{:.5f}".format(
            r.name, mse, psnr, ssim))

    plot_grid([gt] + imgs, FOCUS=FOCUS, save_name='conventional.png')
示例#6
0
            'DGRH',
            dip_n_iter=2501,
            net='skipV2',
            lr=0.01,
            reg_std=1. / 100,
            w_proj_loss=0.98,
            # w_perceptual_loss=0.01,
            w_tv_loss=0.01,
            w_ssim_loss=0.01)

        img_sart = recon_sart.calc(sinogram, theta)
        recon_dip.set_for_metric(gt, img_sart, FOCUS=FOCUS, log_dir='log/')
        recon_dip2.set_for_metric(gt, img_sart, FOCUS=FOCUS, log_dir='log/')
        img_dip = recon_dip.calc(sinogram, theta)
        img_dip2 = recon_dip2.calc(sinogram, theta)

        recons = [recon_sart, recon_dip, recon_dip2]

        for r in recons:
            mse, psnr, ssim = r.eval(gt)
            print("{}:{}: MSE:{:.5f} PSNR:{:.5f} SSIM:{:.5f}".format(
                dose, r.name, mse, psnr, ssim))

        plot_grid([img_sart, img_dip, img_dip2],
                  FOCUS=FOCUS,
                  save_name='dose' + str(dose) + '.png',
                  dpi=500)
        imgs.append(img_dip)

    plot_grid([gt] + imgs, FOCUS=FOCUS, save_name='dose.png', dpi=500)
示例#7
0
def test(fname, label, n_proj=32, noise_pow=25.0):
    """Benchmark classical and DGR/DIP reconstructions on a single image.

    Builds a sparse-view sinogram from ``fname`` with
    ``image_to_sparse_sinogram``. It then runs FBP, SART, SART+TV, SART+BM3D
    and a sweep of ``DgrReconstructor`` loss-weight combinations. For each
    one it logs MSE/PSNR/SSIM against the ground truth and calls
    ``save_result()``. Finally it writes a comparison grid to
    ``label + '.png'``.

    Args:
        fname: Image path forwarded to ``image_to_sparse_sinogram``.
        label: Tag used in the log and as the basename of the output grid.
        n_proj: Number of projections forwarded to
            ``image_to_sparse_sinogram`` (with angle1=0.0, angle2=180.0).
        noise_pow: Noise level forwarded to ``image_to_sparse_sinogram``.
    """
    dgr_iter = 4000
    lr = 0.01
    net = 'skip'
    noise_std = 1./100

    gt, sinogram, theta, FOCUS = image_to_sparse_sinogram(
        fname, channel=1, n_proj=n_proj, size=512,
        angle1=0.0, angle2=180.0, noise_pow=noise_pow)

    logging.warning('Starting')
    logging.warning('fname: %s %s', label, fname)
    logging.warning('n_proj: %s', n_proj)
    logging.warning('noise_pow: %s', noise_pow)
    logging.warning('dgr_n_iter: %s', dgr_iter)
    logging.warning('dgr_lr: %s', lr)
    logging.warning('dgr_net: %s', net)
    logging.warning('dgr_noise_std: %s', noise_std)

    # Classical baselines. SART+BM3D gets its own name because its output is
    # also the reference image handed to every DGR reconstructor below.
    fbp = IRadonReconstructor('FBP')
    sart = SartReconstructor('SART', sart_n_iter=40, sart_relaxation=0.15)
    sart_tv = SartTVReconstructor('SART+TV',
                                  sart_n_iter=40, sart_relaxation=0.15,
                                  tv_weight=0.5, tv_n_iter=100)
    sart_bm3d = SartBM3DReconstructor('SART+BM3D',
                                      sart_n_iter=40, sart_relaxation=0.15,
                                      bm3d_sigma=0.5)

    # (w_proj_loss, w_perceptual_loss, w_tv_loss) combinations to sweep.
    dgr_weights = [
        # projection vs. perceptual
        (1.0, 0.0, 0.0),
        (0.99, 0.01, 0.0),
        (0.90, 0.10, 0.0),
        (0.5, 0.5, 0.0),
        (0.10, 0.90, 0.0),
        (0.01, 0.99, 0.0),
        (0.0, 1.0, 0.0),
        # projection vs. TV
        (0.99, 0.0, 0.01),
        (0.9, 0.0, 0.1),
        (0.5, 0.0, 0.5),
        (0.1, 0.0, 0.9),
        (0.01, 0.0, 0.99),
        (0.00, 0.0, 1.0),
        # all three terms
        (0.33, 0.33, 0.33),
        (0.8, 0.1, 0.1),
        (0.98, 0.01, 0.01),
        (0.10, 0.80, 0.10),
        (0.01, 0.98, 0.01),
    ]
    dgr_recons = [
        DgrReconstructor(
            # Names are generated from the weights so every field uses two
            # decimals. The trailing field is fixed at 0.00 (presumably the
            # SSIM-loss weight, which this sweep does not pass - confirm).
            'DIP_{:.2f}_{:.2f}_{:.2f}_0.00'.format(w_proj, w_percep, w_tv),
            dip_n_iter=dgr_iter,
            net=net,
            lr=lr,
            reg_std=noise_std,
            w_proj_loss=w_proj,
            w_perceptual_loss=w_percep,
            w_tv_loss=w_tv,
        )
        for w_proj, w_percep, w_tv in dgr_weights
    ]

    recons = [fbp, sart, sart_tv, sart_bm3d] + dgr_recons

    img_sart_bm3d = sart_bm3d.calc(sinogram, theta)

    imgs = []
    for recon in recons:
        if recon is sart_bm3d:
            # Already reconstructed above; reuse instead of running it twice.
            img = img_sart_bm3d
        else:
            if isinstance(recon, DgrReconstructor):
                recon.set_for_metric(gt, img_sart_bm3d, FOCUS=FOCUS,
                                     log_dir='../log/dip')
            # Pass theta, matching every other calc(sinogram, theta) call.
            img = recon.calc(sinogram, theta)
        imgs.append(img)
        mse, psnr, ssim = recon.eval(gt)
        recon.save_result()
        # Log at WARNING like the rest of this function so the results stay
        # visible under Python's default (WARNING) root logging level.
        logging.warning('%s: MSE:%.5f PSNR:%.5f SSIM:%.5f',
                        recon.name, mse, psnr, ssim)

    plot_grid([gt] + imgs,
              FOCUS=FOCUS, save_name=label + '.png', dpi=500)

    logging.warning('Done. Results saved as %s', label + '.png')
示例#8
0
        recon_sart,
        recon_sart_tv,
        recon_bm3d,
        recon_dip,
        recon_dip_rand,
    ]
    #   recon_n2self_selfsuper,
    #   recon_n2self_learned_single,
    #   recon_n2self_learned_selfsuper]

    for r in recons:
        mse, psnr, ssim = r.eval(gt)
        print("{}: MSE:{:.5f} PSNR:{:.5f} SSIM:{:.5f}".format(
            r.name, mse, psnr, ssim))

    plot_grid(
        [
            gt,
            img_fbp,
            img_sart,
            img_sart_tv,
            img_sart_bm3d,
            img_dip,
            img_dip_rand,
        ],
        #img_n2self_selfsuper,
        #img_n2self_learned_single, img_n2self_learned_selfsuper],
        FOCUS=FOCUS,
        save_name='all.png',
        dpi=500)
示例#9
0
            'DGRH',
            dip_n_iter=2501,
            net='skipV2',
            lr=0.01,
            reg_std=1. / 100,
            w_proj_loss=0.98,
            # w_perceptual_loss=0.01,
            w_tv_loss=0.01,
            w_ssim_loss=0.01)

        img_sart = recon_sart.calc(sinogram, theta)
        recon_dip.set_for_metric(gt, img_sart, FOCUS=FOCUS, log_dir='log/')
        recon_dip2.set_for_metric(gt, img_sart, FOCUS=FOCUS, log_dir='log/')
        img_dip = recon_dip.calc(sinogram, theta)
        img_dip2 = recon_dip2.calc(sinogram, theta)

        recons = [recon_sart, recon_dip, recon_dip2]

        for r in recons:
            mse, psnr, ssim = r.eval(gt)
            print("{}:{}: MSE:{:.5f} PSNR:{:.5f} SSIM:{:.5f}".format(
                view, r.name, mse, psnr, ssim))

        plot_grid([img_sart, img_dip, img_dip2],
                  FOCUS=FOCUS,
                  save_name='view' + str(view) + '.png',
                  dpi=500)
        imgs.append(img_dip)

    plot_grid([gt] + imgs, FOCUS=FOCUS, save_name='view.png', dpi=500)
示例#10
0
        w_ssim_loss=0.01,
        channels=[64, 128, 256],
    )

    img_sart = recon_sart.calc(sinogram, theta)
    img_sart_tv = recon_sart_tv.calc(sinogram, theta)
    recon_dip1.set_for_metric(gt, img_sart, FOCUS=FOCUS, log_dir='log/')
    img_dip1 = recon_dip1.calc(sinogram, theta)
    recon_dip2.set_for_metric(gt, img_sart, FOCUS=FOCUS, log_dir='log/')
    img_dip2 = recon_dip2.calc(sinogram, theta)
    recon_dip3.set_for_metric(gt, img_sart, FOCUS=FOCUS, log_dir='log/')
    img_dip3 = recon_dip3.calc(sinogram, theta)

    recons = [
        recon_sart,
        recon_sart_tv,
        recon_dip1,
        recon_dip2,
        recon_dip3,
    ]

    for r in recons:
        mse, psnr, ssim = r.eval(gt)
        print("{}: MSE:{:.5f} PSNR:{:.5f} SSIM:{:.5f}".format(
            r.name, mse, psnr, ssim))

    plot_grid([gt, img_sart, img_sart_tv, img_dip1, img_dip2, img_dip3],
              FOCUS=FOCUS,
              save_name='net_arc.png',
              dpi=500)
示例#11
0
    def calc(self, projs, theta):
        """Reconstruct an image from projections by fitting a DIP network.

        Builds Radon/IRadon operators for ``theta`` and stores them on
        ``self``. It then optimises a network from ``self._get_net()`` with
        Adam (``lr=self.lr``) for ``self.n_iter`` steps on
        ``self._calc_loss(x_iter, projs, x_initial)``. ``x_initial`` is
        ``self.noisy``.

        Every ``self.SHOW_EVERY`` iterations, the clipped output is scored
        against ``self.gt`` and ``self.noisy``, printed, and saved as
        ``<log_dir>/<i>.png``. The snapshot with the lowest recorded loss is
        kept.

        ``self.gt``, ``self.noisy``, ``self.FOCUS`` and ``self.log_dir`` must
        already be set. The call sites shown here set them with
        ``set_for_metric`` before ``calc``.

        Args:
            projs: Sinogram (numpy array), converted with ``np_to_torch``.
            theta: Numpy array of projection angles.

        Returns:
            The lowest-loss recorded output as a float64 array clipped to
            [0, 1], also stored in ``self.image_r``. ``None`` only when
            ``self.n_iter`` is 0.
        """
        # Projection geometry for this call.
        self.N_PROJ = len(theta)
        self.theta = torch.from_numpy(theta).to(self.DEVICE)
        self.radon = Radon(self.IMAGE_SIZE, self.theta, True).to(self.DEVICE)
        self.iradon = IRadon(self.IMAGE_SIZE, self.theta, True).to(self.DEVICE)

        # makedirs creates missing parent directories (callers pass nested
        # paths such as '../log/dip') and tolerates an existing directory.
        os.makedirs(self.log_dir, exist_ok=True)

        net = self._get_net()

        x_initial = np_to_torch(self.noisy).type(self.DTYPE)
        net_input = torch.rand(1, self.INPUT_DEPTH,
                               self.IMAGE_SIZE, self.IMAGE_SIZE).type(self.DTYPE)
        net_input_saved = net_input.detach().clone()
        noise = net_input.detach().clone()  # buffer refilled by normal_() each step

        n_params = sum(p.numel() for p in net.parameters())
        print('Number of params: %d' % n_params)

        optimizer = torch.optim.Adam(net.parameters(), lr=self.lr)
        projs = np_to_torch(projs).type(self.DTYPE)

        # Histories are appended only every SHOW_EVERY iterations.
        loss_hist = []
        rmse_hist = []  # NOTE: stores mean_squared_error (MSE), not RMSE
        ssim_hist = []
        ssim_noisy_hist = []
        psnr_hist = []
        psnr_noisy_hist = []
        # Track the best loss explicitly so that the first recorded snapshot
        # is always a valid fallback and the result is never None after at
        # least one iteration.
        best_loss = float('inf')
        best_result = None

        print('Reconstructing with DIP...')
        for i in tqdm(range(self.n_iter)):
            optimizer.zero_grad()

            # Perturb the fixed random input with fresh Gaussian noise.
            if self.reg_std > 0:
                net_input = net_input_saved + (noise.normal_() * self.reg_std)

            x_iter = net(net_input)
            loss, (proj_l, percep_l, tv_l, ssim_l) = self._calc_loss(x_iter, projs, x_initial)
            loss.backward()
            optimizer.step()

            if i % self.SHOW_EVERY != 0:
                continue

            # x_iter is the output that produced `loss` (computed before step()).
            x_iter_npy = np.clip(torch_to_np(x_iter), 0, 1).astype(np.float64)

            rmse_hist.append(mean_squared_error(x_iter_npy, self.gt))
            # The default (no channel axis) matches the deprecated
            # multichannel=False.
            ssim_hist.append(structural_similarity(x_iter_npy, self.gt))
            ssim_noisy_hist.append(structural_similarity(x_iter_npy, self.noisy))
            psnr_hist.append(peak_signal_noise_ratio(x_iter_npy, self.gt))
            psnr_noisy_hist.append(peak_signal_noise_ratio(x_iter_npy, self.noisy))
            loss_hist.append(loss.item())
            print('{}/{}- psnr: {:.3f} - psnr_noisy: {:.3f} - ssim: {:.3f} - ssim_noisy: {:.3f} - rmse: {:.5f} - loss: {:.5f} '.format(
                self.name, i, psnr_hist[-1], psnr_noisy_hist[-1], ssim_hist[-1], ssim_noisy_hist[-1], rmse_hist[-1], loss_hist[-1]
            ))

            if loss_hist[-1] < best_loss:
                best_loss = loss_hist[-1]
                best_result = x_iter_npy.copy()
            plot_grid([x_iter_npy], self.FOCUS, save_name=self.log_dir+'/{}.png'.format(i))

        self.image_r = best_result
        return self.image_r
示例#12
0
        # w_perceptual_loss=0.01,
        w_tv_loss=0.01,
        w_ssim_loss=0.01)

    img_sart = recon_sart.calc(sinogram, theta)
    img_sart_tv = recon_sart_tv.calc(sinogram, theta)
    recon_dip1.set_for_metric(gt, img_sart, FOCUS=FOCUS, log_dir='log/')
    img_dip1 = recon_dip1.calc(sinogram, theta)
    recon_dip2.set_for_metric(gt, img_sart, FOCUS=FOCUS, log_dir='log/')
    img_dip2 = recon_dip2.calc(sinogram, theta)
    recon_dip3.set_for_metric(gt, img_sart, FOCUS=FOCUS, log_dir='log/')
    img_dip3 = recon_dip3.calc(sinogram, theta)

    recons = [
        recon_sart,
        recon_sart_tv,
        recon_dip1,
        recon_dip2,
        recon_dip3,
    ]

    for r in recons:
        mse, psnr, ssim = r.eval(gt)
        print("{}: MSE:{:.5f} PSNR:{:.5f} SSIM:{:.5f}".format(
            r.name, mse, psnr, ssim))

    plot_grid([gt, img_sart, img_sart_tv, img_dip1, img_dip2, img_dip3],
              FOCUS=FOCUS,
              save_name='num_params.png',
              dpi=500)
示例#13
0
    def _calc_unsupervised(self, projs, theta):
        """Fit ``self.net`` to the sinogram with masked self-supervision.

        Training iterations mask ``projs`` with ``self.masker``, reconstruct
        through ``self.ir`` and ``self.net``, and reproject with ``self.r``.
        The MSE is taken only on the masked entries, then back-propagated and
        stepped with ``self.optimizer``.

        Every ``self.SHOW_EVERY``-th iteration (including the first) is a
        validation pass instead. It runs in eval mode on the full projections
        and never updates parameters. Its metrics are appended to
        ``self.rmse_hist`` / ``self.ssim_hist`` / ``self.psnr_hist`` /
        ``self.loss_hist``, and ``[gt, output]`` is saved to
        ``<log_dir>/<i>.png``.

        Args:
            projs: Sinogram (numpy array), converted with ``np_to_torch``.
            theta: Not used by this method.

        Returns:
            ``self.best_result`` (also stored in ``self.image_r``). This is
            the validation output selected by the loss-history rule below.
        """
        # makedirs creates missing parents and tolerates an existing directory.
        os.makedirs(self.log_dir, exist_ok=True)

        projs = np_to_torch(projs).type(self.DTYPE)

        for i in tqdm(range(self.n_iter)):
            self.optimizer.zero_grad()

            if i % self.SHOW_EVERY != 0:
                # Training step on the masked entries only.
                self.net.train()
                self.filter.train()
                # Cycle through the masks with the loop index, so that
                # successive training steps hide different entries.
                net_input, mask = self.masker.mask(projs, i % (self.masker.n_masks - 1))
                x_iter = self.net(self.ir(net_input))
                loss = self.mse(self.r(x_iter) * mask, projs * mask)
                loss.backward()
                self.optimizer.step()
                continue

            # Validation: full projections, eval mode, and no backward/step,
            # so the unmasked loss is only measured and never trained on.
            self.net.eval()
            self.filter.eval()
            x_iter = self.net(self.ir(projs))
            loss = self.mse(self.r(x_iter), projs)

            x_iter_npy = np.clip(torch_to_np(x_iter), 0, 1).astype(np.float64)
            # NOTE: rmse_hist stores mean_squared_error (MSE), not RMSE.
            self.rmse_hist.append(mean_squared_error(x_iter_npy, self.gt))
            # The default (no channel axis) matches the deprecated
            # multichannel=False.
            self.ssim_hist.append(structural_similarity(x_iter_npy, self.gt))
            self.psnr_hist.append(peak_signal_noise_ratio(x_iter_npy, self.gt))
            self.loss_hist.append(loss.item())
            print('{}/{}- psnr: {:.3f} - ssim: {:.3f} - rmse: {:.5f} - loss: {:.5f} '.format(
                self.name, i, self.psnr_hist[-1], self.ssim_hist[-1], self.rmse_hist[-1], self.loss_hist[-1]
            ))

            # For the first few iterations, take the output unconditionally.
            # Afterwards, keep it only if it beats every earlier recorded loss.
            if i > 2:
                if self.loss_hist[-1] < min(self.loss_hist[0:-1]):
                    self.best_result = x_iter_npy.copy()
            else:
                self.best_result = x_iter_npy.copy()
            plot_grid([self.gt, x_iter_npy], self.FOCUS, save_name=self.log_dir+'/{}.png'.format(i))

        self.image_r = self.best_result
        return self.image_r