def test(self):
    """Score the restored generator on ``self.data_loader`` and print averages.

    Drops the training wrapper ``self.WGANVGG``, builds a fresh
    ``WGAN_VGG_generator`` on ``self.device`` as ``self.WGANVGG_G`` and
    restores its weights with ``self.load_model(self.test_iters)``.

    For every (source, target) pair yielded by the loader, the generator's
    prediction is produced, and source, target and prediction are reshaped
    to a square image, denormalized and truncated before being scored with
    ``compute_measure``. When ``self.result_fig`` is set, a figure is saved
    per sample.

    Prints the mean PSNR / SSIM / RMSE of the untouched source and of the
    prediction, both as returned by ``compute_measure`` against the target.
    """
    del self.WGANVGG
    self.WGANVGG_G = WGAN_VGG_generator().to(self.device)
    self.load_model(self.test_iters)

    def _as_image(tensor, side):
        # Collapse to one (side, side) CPU image, undo the normalisation and
        # clip to the truncation window.
        return self.trunc(
            self.denormalize_(tensor.view(side, side).cpu().detach()))

    n_batches = len(self.data_loader)
    # Running sums of (PSNR, SSIM, RMSE): source-vs-target and pred-vs-target.
    ori_sums = [0, 0, 0]
    pred_sums = [0, 0, 0]

    with torch.no_grad():
        for batch_idx, (source, target) in enumerate(self.data_loader):
            side = source.shape[-1]
            source = source.unsqueeze(0).float().to(self.device)
            target = target.unsqueeze(0).float().to(self.device)
            restored = self.WGANVGG_G(source)

            source = _as_image(source, side)
            target = _as_image(target, side)
            restored = _as_image(restored, side)

            data_range = self.trunc_max - self.trunc_min
            original_result, pred_result = compute_measure(
                source, target, restored, data_range)
            for k in range(3):
                ori_sums[k] += original_result[k]
                pred_sums[k] += pred_result[k]

            if self.result_fig:
                self.save_fig(source, target, restored, batch_idx,
                              original_result, pred_result)

            printProgressBar(batch_idx, n_batches,
                             prefix="Compute measurements ..",
                             suffix='Complete',
                             length=25)

        print('\n')
        print(
            'Original\nPSNR avg: {:.4f} \nSSIM avg: {:.4f} \nRMSE avg: {:.4f}'
            .format(*[total / n_batches for total in ori_sums]))
        print(
            'After learning\nPSNR avg: {:.4f} \nSSIM avg: {:.4f} \nRMSE avg: {:.4f}'
            .format(*[total / n_batches for total in pred_sums]))
def test(self):
    """Run patch-wise inference on every image in ``self.test_list``.

    Drops the training wrapper ``self.WGANVGG``, builds a fresh
    ``WGAN_VGG_generator`` on ``self.device`` as ``self.WGANVGG_G`` and
    restores its weights with ``self.load_model()``.

    For each input image:
      * it is split into patches by ``ImageDataset`` and fed through the
        generator in batches of ``self.opt.batch_size``;
      * the predicted patches are stitched back with ``mp.recon_patches`` and
        the padding removed with ``mp.unpad_img``;
      * input, prediction and the ground truth at the same index of
        ``self.test_gt_list`` are denormalized, truncated and scored with
        ``compute_measure`` (PSNR and SSIM are accumulated, RMSE is ignored);
      * the re-normalized prediction is written to
        ``<opt.test_result_dir>/<basename(opt.checkpoint_dir)>/out-<name>``.

    Finally prints the mean PSNR/SSIM of the raw input and of the prediction.
    If ``self.test_list`` is empty, a notice is printed instead of averages.
    """
    del self.WGANVGG
    self.WGANVGG_G = WGAN_VGG_generator().to(self.device)
    self.load_model()

    # The output directory depends only on options, so create it once;
    # exist_ok avoids the exists()/makedirs() race.
    base_name = os.path.basename(self.opt.checkpoint_dir)
    test_result_dir = os.path.join(self.opt.test_result_dir, base_name)
    os.makedirs(test_result_dir, exist_ok=True)

    ori_psnr_sum, ori_ssim_sum = 0, 0
    pred_psnr_sum, pred_ssim_sum = 0, 0
    num_total_img = len(self.test_list)

    with torch.no_grad():
        for img_idx, img_path in enumerate(self.test_list):
            img_name = os.path.basename(img_path)
            img_path = os.path.abspath(img_path)
            print("[{}/{}] processing {}".format(
                img_idx, num_total_img, img_path))

            # Ground truth is paired with the input by list position.
            gt_img = imread(self.test_gt_list[img_idx])
            input_img = imread(img_path)

            img_patch_dataset = ImageDataset(self.opt, input_img)
            img_patch_dataloader = DataLoader(
                dataset=img_patch_dataset,
                batch_size=self.opt.batch_size,
                shuffle=False)  # patch order must match recon_patches
            img_shape = img_patch_dataset.get_img_shape()
            pad_img_shape = img_patch_dataset.get_padded_img_shape()

            out_list = []
            for x in img_patch_dataloader:
                pred = self.WGANVGG_G(x.float().to(self.device))
                out_list.append(pred.to('cpu').detach().numpy())
            out = np.concatenate(out_list, axis=0).squeeze()

            # Reassemble the padded image from its patches, then crop it
            # back to the original input shape.
            out_img = mp.recon_patches(out, pad_img_shape[1],
                                       pad_img_shape[0], self.opt.patch_size,
                                       self.opt.patch_offset)
            out_img = mp.unpad_img(out_img, self.opt.patch_offset, img_shape)

            input_img = self.trunc(
                self.denormalize_(torch.Tensor(input_img)).cpu().detach())
            out_img = self.trunc(
                self.denormalize_(torch.Tensor(out_img)).cpu().detach())
            gt_img = self.trunc(
                self.denormalize_(torch.Tensor(gt_img)).cpu().detach())

            data_range = self.trunc_max - self.trunc_min
            original_result, pred_result = compute_measure(
                input_img, gt_img, out_img, data_range)
            op, oos, _ = original_result
            pp, ps, _ = pred_result
            ori_psnr_sum += op
            ori_ssim_sum += oos
            pred_psnr_sum += pp
            pred_ssim_sum += ps

            # Store the prediction back in the normalized intensity range.
            out_img = self.normalize_(out_img)
            imsave(os.path.join(test_result_dir, 'out-' + img_name),
                   out_img.cpu().numpy())

    if num_total_img == 0:
        # Without this guard the averaging below would fail on an empty list.
        print("No test images found; nothing to evaluate.")
        return

    aop = ori_psnr_sum / num_total_img
    aos = ori_ssim_sum / num_total_img
    app = pred_psnr_sum / num_total_img
    aps = pred_ssim_sum / num_total_img
    print(
        "((ORIGIN)) PSNR : {:.5f}, SSIM : {:.5f}, ((PREP)) PSNR : {:.5f}, SSIM : {:.5f}"
        .format(aop, aos, app, aps))
def train(self):
    """Train the WGAN-VGG model for epochs ``self.start_epoch .. self.num_epochs - 1``.

    Setup:
      * A fresh run (``self.resume`` false) sets up the checkpoint directory,
        writes the CSV header to ``self.opt.log_file`` and saves the config.
      * A resumed run restores state with ``self.load_model()``.

    Each iteration performs ``self.n_d_train`` critic updates
    (``self.WGANVGG.d_loss`` with gradient penalty), then one generator
    update (``self.WGANVGG.g_loss`` with perceptual loss).

    After every epoch:
      * the model is saved;
      * the last training batch is scored with ``compute_measure``
        (data range 1);
      * a row of per-batch mean losses plus PSNR/SSIM is appended to the log.

    Raises:
        RuntimeError: if ``self.data_loader`` yields no batches in an epoch.
    """
    total_iters = 0
    start_time = time.time()

    if not self.resume:
        self.set_checkpoint_dir()
        with open(self.opt.log_file, mode='w') as f:
            f.write(
                'epoch, train__G_loss, train__P_loss, train__D_loss, train__GP_loss, PSNR, SSIM\n'
            )
        self.save_config()
    else:
        self.load_model()

    for epoch in range(self.start_epoch, self.num_epochs):
        total_d_loss = 0.0
        total_g_loss = 0.0
        total_p_loss = 0.0
        total_gp_loss = 0.0
        num_batches = 0

        for iter_, (x, y) in enumerate(self.data_loader):
            total_iters += 1
            num_batches += 1
            x = x.float().to(self.device)
            y = y.float().to(self.device)

            # Critic: gradients must be cleared before *every* critic step.
            # Otherwise each backward() adds onto the previous step's
            # gradients and later steps apply an accumulated update.
            for _ in range(self.n_d_train):
                self.optimizer_d.zero_grad()
                self.WGANVGG.discriminator.zero_grad()
                d_loss, gp_loss = self.WGANVGG.d_loss(x, y, gp=True,
                                                      return_gp=True)
                d_loss.backward()
                self.optimizer_d.step()

            # Generator: adversarial + perceptual loss.
            self.optimizer_g.zero_grad()
            self.WGANVGG.generator.zero_grad()
            g_loss, p_loss = self.WGANVGG.g_loss(x, y, perceptual=True,
                                                 return_p=True)
            g_loss.backward()
            self.optimizer_g.step()

            # Read each scalar once (each .item() is a device sync).
            g_val, p_val = g_loss.item(), p_loss.item()
            d_val, gp_val = d_loss.item(), gp_loss.item()

            if total_iters % self.print_iters == 0:
                # NOTE(review): G_LOSS subtracts 0.1 * p_loss, presumably
                # because g_loss weights the perceptual term by 0.1 — TODO
                # confirm against WGANVGG.g_loss.
                print(
                    "STEP [{}], EPOCH [{}/{}], ITER [{}/{}], TIME [{:.1f}s] >>> G_LOSS: {:.8f}, P_LOSS: {:.8f}, D_LOSS: {:.8f}, GD_LOSS: {:.8f}"
                    .format(total_iters, epoch, self.num_epochs, iter_ + 1,
                            len(self.data_loader),
                            time.time() - start_time,
                            g_val - p_val * 0.1, p_val,
                            d_val - gp_val, gp_val))

            # learning rate decay
            if total_iters % self.decay_iters == 0:
                self.lr_decay()

            total_d_loss += d_val
            total_g_loss += g_val
            total_p_loss += p_val
            total_gp_loss += gp_val

        if num_batches == 0:
            raise RuntimeError(
                "data_loader yielded no batches in epoch {}".format(epoch))

        self.save_model(epoch, g_val)

        # Quick quality check on the last training batch; no graph is needed.
        with torch.no_grad():
            pred = self.WGANVGG.generator(x)
        original_result, pred_result = compute_measure(x, y, pred, 1)
        op, oos, _ = original_result
        pp, ps, _ = pred_result
        print(
            "((ORIGIN)) PSNR : {:.5f}, SSIM : {:.5f}, ((PREP)) PSNR : {:.5f}, SSIM : {:.5f}"
            .format(op, oos, pp, ps))

        # Mean over the number of batches (not the last zero-based index).
        total_d_loss = total_d_loss / num_batches
        total_g_loss = total_g_loss / num_batches
        total_p_loss = total_p_loss / num_batches
        total_gp_loss = total_gp_loss / num_batches
        with open(self.opt.log_file, mode='a') as f:
            f.write(
                "{:d},{:.8f},{:.8f},{:.8f},{:.8f},{:.8f},{:.8f}\n".format(
                    epoch, total_g_loss, total_p_loss, total_d_loss,
                    total_gp_loss, pp, ps))
# ============================================================================= # Model 1 # Reconstruction with Laplacian Regularization print('===========================================') print('Laplacian Regularization...') data_dir = "../data/EITData" x_lap = callLapReg(data_dir=data_dir, y_test=test_data) results = [test_images, x_lap] titles = ['Truth', 'Lap. Reg'] dir_name = "./figures" if not os.path.exists(dir_name): os.makedirs(dir_name) print('Create path : {}'.format(dir_name)) # Evalute reconstructed images with PSNR, SSIM, RMSE. p_reg, s_reg, m_reg = compute_measure(test_images, x_lap, 1) print('PSNR: {:.5f}\t SSIM: {:.5f} \t RMSE: {:.5f}'.format( p_reg, s_reg, m_reg)) show_image_matrix(dir_name + "/LapFigs.png", results, titles=titles, indices=slice(0, 15)) # ============================================================================= # Model 2 # Total Variation with FISTA (https://sites.google.com/site/amirbeck314/software) # Digital Object Identifier 10.1109/TIP.2009.2028250 # ============================================================================= # Model 3 # Lap. Reg. + U-net
# Back-projection baseline: reconstruct every sinogram in `test_data` with
# `iradon` over the projection angles `theta`, then compare against
# `test_images`.
# NOTE(review): assuming `iradon` is skimage.transform.iradon, its default
# ramp filter makes this a filtered back-projection, while the title and file
# name below say 'LBP'. They are kept as-is so existing outputs keep their names.
X_fbp = torch.zeros_like(test_images)
# Iterate over the sinograms actually present rather than the nominal
# `batch_size`. A shorter final batch would otherwise raise IndexError, and a
# longer one would leave trailing reconstructions all-zero and skew the
# metrics.
for i in range(len(test_data)):
    sino = test_data[i].squeeze()
    X0 = iradon(sino, theta=theta)
    X_fbp[i] = torch.from_numpy(X0)
results = [test_images, X_fbp]
titles = ['Truth', 'LBP']
show_image_matrix(dir_name + "/LBP.png", results, titles=titles,
                  indices=slice(0, num_display))

# Evaluate reconstructed images with PSNR, SSIM, RMSE.
p_reg, s_reg, m_reg = compute_measure(test_images, X_fbp, 1)
print('PSNR: {:.5f}\t SSIM: {:.5f} \t RMSE: {:.5f}'.format(
    p_reg, s_reg, m_reg))
# =============================================================================
# Model 2
# Total Variation with FISTA (https://sites.google.com/site/amirbeck314/software)
# Digital Object Identifier 10.1109/TIP.2009.2028250
# =============================================================================
# Model 3
# Lap. Reg. + U-net
# Use these parameters to steer the training
print('===========================================')
print('Lap. Reg. + U-net...')
use_cuda = True