def test(self):
        """Evaluate the restored generator on ``self.data_loader`` and print metrics.

        Builds a fresh ``WGAN_VGG_generator``, restores its weights with
        ``self.load_model(self.test_iters)``, then for every ``(x, y)`` pair
        runs the generator on ``x`` and scores both the input ``x`` and the
        prediction against ``y`` with ``compute_measure``. The PSNR / SSIM /
        RMSE values, averaged over all batches, are printed for the input
        ("Original") and for the prediction ("After learning"). When
        ``self.result_fig`` is truthy, ``self.save_fig`` is called per sample.

        If the data loader is empty, a notice is printed and nothing is
        averaged (previously this raised ZeroDivisionError).
        """
        # Release the training wrapper before building the test generator.
        # Guarded so calling test() twice does not raise AttributeError.
        if hasattr(self, 'WGANVGG'):
            del self.WGANVGG
        # load
        self.WGANVGG_G = WGAN_VGG_generator().to(self.device)
        self.load_model(self.test_iters)

        num_batches = len(self.data_loader)
        if num_batches == 0:
            # The averages below would divide by zero.
            print('No test data: data_loader is empty, nothing to evaluate.')
            return

        # Loop-invariant intensity range of the truncated images, passed to the metrics.
        data_range = self.trunc_max - self.trunc_min

        # Running sums of PSNR, SSIM, RMSE (divided by num_batches at the end).
        ori_psnr_avg, ori_ssim_avg, ori_rmse_avg = 0, 0, 0
        pred_psnr_avg, pred_ssim_avg, pred_rmse_avg = 0, 0, 0

        with torch.no_grad():
            for i, (x, y) in enumerate(self.data_loader):
                shape_ = x.shape[-1]
                # NOTE(review): the view(shape_, shape_) calls below assume each
                # batch holds exactly one square image (batch_size == 1) — confirm
                # against the loader configuration.
                x = x.unsqueeze(0).float().to(self.device)
                y = y.unsqueeze(0).float().to(self.device)

                pred = self.WGANVGG_G(x)

                # denormalize, truncate
                x = self.trunc(
                    self.denormalize_(x.view(shape_, shape_).cpu().detach()))
                y = self.trunc(
                    self.denormalize_(y.view(shape_, shape_).cpu().detach()))
                pred = self.trunc(
                    self.denormalize_(
                        pred.view(shape_, shape_).cpu().detach()))

                original_result, pred_result = compute_measure(
                    x, y, pred, data_range)
                ori_psnr_avg += original_result[0]
                ori_ssim_avg += original_result[1]
                ori_rmse_avg += original_result[2]
                pred_psnr_avg += pred_result[0]
                pred_ssim_avg += pred_result[1]
                pred_rmse_avg += pred_result[2]

                # save result figure
                if self.result_fig:
                    self.save_fig(x, y, pred, i, original_result, pred_result)

                printProgressBar(i,
                                 num_batches,
                                 prefix="Compute measurements ..",
                                 suffix='Complete',
                                 length=25)

        print('\n')
        print(
            'Original\nPSNR avg: {:.4f} \nSSIM avg: {:.4f} \nRMSE avg: {:.4f}'
            .format(ori_psnr_avg / num_batches,
                    ori_ssim_avg / num_batches,
                    ori_rmse_avg / num_batches))
        print(
            'After learning\nPSNR avg: {:.4f} \nSSIM avg: {:.4f} \nRMSE avg: {:.4f}'
            .format(pred_psnr_avg / num_batches,
                    pred_ssim_avg / num_batches,
                    pred_rmse_avg / num_batches))
# Beispiel #2 ("Example #2")
# 0
    def test(self):
        """Run patch-wise inference over ``self.test_list`` and report metrics.

        For each input image path (paired by index with ``self.test_gt_list``):
        the image is split into patches by ``ImageDataset``, passed through the
        generator batch-by-batch, and stitched back with ``mp.recon_patches`` /
        ``mp.unpad_img``. The input, output and ground truth are denormalized
        and truncated, and then scored with ``compute_measure``. The re-normalized
        output is written to ``<opt.test_result_dir>/<basename(opt.checkpoint_dir)>/out-<name>``.
        Finally, the mean PSNR / SSIM of the input ("ORIGIN") and the output
        ("PREP") are printed.

        If ``self.test_list`` is empty, a notice is printed and the method returns
        (previously the final averaging raised NameError).
        """
        # Release the training wrapper before building the test generator.
        # Guarded so calling test() twice does not raise AttributeError.
        if hasattr(self, 'WGANVGG'):
            del self.WGANVGG
        # load
        self.WGANVGG_G = WGAN_VGG_generator().to(self.device)
        self.load_model()

        num_total_img = len(self.test_list)
        if num_total_img == 0:
            print('No test images: test_list is empty, nothing to evaluate.')
            return

        # The output directory depends only on options, so create it once.
        base_name = os.path.basename(self.opt.checkpoint_dir)
        test_result_dir = os.path.join(self.opt.test_result_dir, base_name)
        os.makedirs(test_result_dir, exist_ok=True)

        # Loop-invariant intensity range of the truncated images.
        data_range = self.trunc_max - self.trunc_min

        # Running sums of PSNR / SSIM (RMSE is computed but not reported).
        ori_psnr_avg, ori_ssim_avg = 0, 0
        pred_psnr_avg, pred_ssim_avg = 0, 0

        with torch.no_grad():
            for img_idx, img_path in enumerate(self.test_list):
                img_name = os.path.basename(img_path)
                img_path = os.path.abspath(img_path)
                # 1-based so the last image reads "[N/N]".
                print("[{}/{}] processing {}".format(
                    img_idx + 1, num_total_img, img_path))

                gt_img_path = self.test_gt_list[img_idx]
                gt_img = imread(gt_img_path)
                input_img = imread(img_path)

                # Split the input into patches and run the generator batch-wise.
                img_patch_dataset = ImageDataset(self.opt, input_img)
                img_patch_dataloader = DataLoader(
                    dataset=img_patch_dataset,
                    batch_size=self.opt.batch_size,
                    shuffle=False)

                img_shape = img_patch_dataset.get_img_shape()
                pad_img_shape = img_patch_dataset.get_padded_img_shape()

                out_list = []
                for x in img_patch_dataloader:
                    pred = self.WGANVGG_G(x.float().to(self.device))
                    out_list.append(pred.to('cpu').detach().numpy())

                out = np.concatenate(out_list, axis=0).squeeze()

                # Stitch the patches back together and remove the padding.
                out_img = mp.recon_patches(out, pad_img_shape[1],
                                           pad_img_shape[0],
                                           self.opt.patch_size,
                                           self.opt.patch_offset)
                out_img = mp.unpad_img(out_img, self.opt.patch_offset,
                                       img_shape)

                # Bring all three images to the truncated, denormalized range
                # before scoring.
                input_img = self.trunc(
                    self.denormalize_(torch.Tensor(input_img)).cpu().detach())
                out_img = self.trunc(
                    self.denormalize_(torch.Tensor(out_img)).cpu().detach())
                gt_img = self.trunc(
                    self.denormalize_(torch.Tensor(gt_img)).cpu().detach())

                original_result, pred_result = compute_measure(
                    input_img, gt_img, out_img, data_range)

                op, oos, _ = original_result
                pp, ps, _ = pred_result

                ori_psnr_avg += op
                ori_ssim_avg += oos
                pred_psnr_avg += pp
                pred_ssim_avg += ps

                # Save the prediction re-normalized via self.normalize_.
                dst_img_path = os.path.join(test_result_dir, 'out-' + img_name)
                out_img = self.normalize_(out_img)
                imsave(dst_img_path, out_img.cpu().numpy())

        aop = ori_psnr_avg / num_total_img
        aos = ori_ssim_avg / num_total_img
        app = pred_psnr_avg / num_total_img
        aps = pred_ssim_avg / num_total_img
        print(
            "((ORIGIN)) PSNR : {:.5f}, SSIM : {:.5f}, ((PREP)) PSNR : {:.5f}, SSIM : {:.5f}"
            .format(aop, aos, app, aps))
# Beispiel #3 ("Example #3")
# 0
    def train(self):
        """Train the WGAN-VGG model for epochs ``self.start_epoch`` .. ``self.num_epochs - 1``.

        Setup depends on ``self.resume``:
        - Fresh run (``self.resume`` falsy): creates the checkpoint directory,
          writes the CSV header to ``self.opt.log_file`` and saves the config.
        - Resumed run: restores state with ``self.load_model()``.

        Each iteration performs ``self.n_d_train`` critic (discriminator)
        updates and then one generator update with perceptual loss. Losses are
        printed every ``self.print_iters`` iterations, and ``self.lr_decay()``
        runs every ``self.decay_iters`` iterations.

        After each epoch:
        - the model is saved;
        - PSNR/SSIM are computed on the epoch's last batch only;
        - a CSV row is appended to the log file, holding the per-batch mean
          losses and those PSNR/SSIM values.

        Raises:
            RuntimeError: if ``self.data_loader`` yields no batches in an epoch.
        """
        import torch  # local import: only needed for torch.no_grad below

        total_iters = 0
        start_time = time.time()

        if not self.resume:
            self.set_checkpoint_dir()
            with open(self.opt.log_file, mode='w') as f:
                f.write(
                    'epoch, train__G_loss, train__P_loss, train__D_loss, train__GP_loss, PSNR, SSIM\n'
                )
            self.save_config()
        else:
            self.load_model()

        for epoch in range(self.start_epoch, self.num_epochs):

            total_d_loss = 0.0
            total_g_loss = 0.0
            total_p_loss = 0.0
            total_gp_loss = 0.0
            num_batches = 0

            for iter_, (x, y) in enumerate(self.data_loader):
                total_iters += 1
                num_batches += 1

                x = x.float().to(self.device)
                y = y.float().to(self.device)

                # Critic updates. Gradients are cleared before EVERY backward
                # so each of the n_d_train steps applies only its own gradient.
                # Zeroing once outside the loop let gradients accumulate
                # across the critic steps.
                for _ in range(self.n_d_train):
                    self.optimizer_d.zero_grad()
                    self.WGANVGG.discriminator.zero_grad()
                    d_loss, gp_loss = self.WGANVGG.d_loss(x,
                                                          y,
                                                          gp=True,
                                                          return_gp=True)
                    d_loss.backward()
                    self.optimizer_d.step()

                # generator, perceptual loss
                self.optimizer_g.zero_grad()
                self.WGANVGG.generator.zero_grad()
                g_loss, p_loss = self.WGANVGG.g_loss(x,
                                                     y,
                                                     perceptual=True,
                                                     return_p=True)
                g_loss.backward()
                self.optimizer_g.step()

                # print
                if total_iters % self.print_iters == 0:
                    # NOTE(review): G_LOSS subtracts 0.1 * p_loss, which presumes
                    # g_loss = adversarial + 0.1 * perceptual inside
                    # WGANVGG.g_loss — confirm. The "GD_LOSS" column prints gp_loss.
                    print(
                        "STEP [{}], EPOCH [{}/{}], ITER [{}/{}], TIME [{:.1f}s] >>> G_LOSS: {:.8f}, P_LOSS: {:.8f}, D_LOSS: {:.8f}, GD_LOSS: {:.8f}"
                        .format(total_iters, epoch, self.num_epochs, iter_ + 1,
                                len(self.data_loader),
                                time.time() - start_time,
                                g_loss.item() - p_loss.item() * 0.1,
                                p_loss.item(),
                                d_loss.item() - gp_loss.item(),
                                gp_loss.item()))
                # learning rate decay
                if total_iters % self.decay_iters == 0:
                    self.lr_decay()

                # Losses from the last critic step and the generator step.
                total_d_loss += d_loss.item()
                total_g_loss += g_loss.item()
                total_p_loss += p_loss.item()
                total_gp_loss += gp_loss.item()

            if num_batches == 0:
                raise RuntimeError(
                    'data_loader yielded no batches; cannot train epoch {}'.format(epoch))

            # save model
            self.save_model(epoch, g_loss.item())

            # Quick quality check on the epoch's last batch only. It runs under
            # no_grad because the result is never backpropagated.
            with torch.no_grad():
                pred = self.WGANVGG.generator(x)
                original_result, pred_result = compute_measure(x, y, pred, 1)

            op, oos, _ = original_result
            pp, ps, _ = pred_result
            print(
                "((ORIGIN)) PSNR : {:.5f}, SSIM : {:.5f}, ((PREP)) PSNR : {:.5f}, SSIM : {:.5f}"
                .format(op, oos, pp, ps))

            # Per-batch means. Divides by the batch count; the previous code
            # divided by the last 0-based index, which was off by one and
            # crashed on a single-batch loader.
            total_d_loss = total_d_loss / num_batches
            total_g_loss = total_g_loss / num_batches
            total_p_loss = total_p_loss / num_batches
            total_gp_loss = total_gp_loss / num_batches

            with open(self.opt.log_file, mode='a') as f:
                f.write(
                    "{:d},{:.8f},{:.8f},{:.8f},{:.8f},{:.8f},{:.8f}\n".format(
                        epoch, total_g_loss, total_p_loss, total_d_loss,
                        total_gp_loss, pp, ps))
# Beispiel #4 ("Example #4")
# 0
# =============================================================================
# Model 1
# Reconstruction with Laplacian Regularization
print('===========================================')
print('Laplacian Regularization...')
data_dir = "../data/EITData"
x_lap = callLapReg(data_dir=data_dir, y_test=test_data)

# Figure panels: ground truth next to the Laplacian-regularized result.
results, titles = [test_images, x_lap], ['Truth', 'Lap. Reg']

# Output folder for figures; the message is printed only when it is created.
dir_name = "./figures"
if not os.path.exists(dir_name):
    os.makedirs(dir_name)
    print(f'Create path : {dir_name}')

# Score the reconstruction against the ground truth. The three returned
# values are reported as PSNR, SSIM and RMSE.
p_reg, s_reg, m_reg = compute_measure(test_images, x_lap, 1)
print(f'PSNR: {p_reg:.5f}\t SSIM: {s_reg:.5f} \t RMSE: {m_reg:.5f}')
show_image_matrix(f"{dir_name}/LapFigs.png", results, titles=titles,
                  indices=slice(0, 15))

# =============================================================================
# Model 2
# Total Variation with FISTA (https://sites.google.com/site/amirbeck314/software)
# Digital Object Identifier 10.1109/TIP.2009.2028250

# =============================================================================
# Model 3
# Lap. Reg. + U-net
# Beispiel #5 ("Example #5")
# 0
# Baseline reconstruction: invert each sinogram independently with `iradon`.
# NOTE(review): `iradon` is imported outside this excerpt, presumably from
# skimage.transform, whose default ramp filter makes this filtered
# back-projection. The figure title below says 'LBP' — confirm the intended label.
# Preallocate an output tensor like `test_images` (same shape/dtype/device).
X_fbp = torch.zeros_like(test_images)

# Only the first `batch_size` samples are reconstructed.
# NOTE(review): assumes batch_size <= len(test_images) == len(test_data) — confirm.
for i in range(batch_size):
    sino = test_data[i].squeeze()
    X0 = iradon(sino, theta=theta)
    # iradon returns a NumPy array; the assignment casts it to X_fbp's dtype.
    X_fbp[i] = torch.from_numpy(X0)

results = [test_images, X_fbp]
titles = ['Truth', 'LBP']
show_image_matrix(dir_name + "/LBP.png",
                  results,
                  titles=titles,
                  indices=slice(0, num_display))

# Evaluate the reconstructed images; the three values are printed as PSNR, SSIM, RMSE.
p_reg, s_reg, m_reg = compute_measure(test_images, X_fbp, 1)
print('PSNR: {:.5f}\t SSIM: {:.5f} \t RMSE: {:.5f}'.format(
    p_reg, s_reg, m_reg))

# =============================================================================
# Model 2
# Total Variation with FISTA (https://sites.google.com/site/amirbeck314/software)
# Digital Object Identifier 10.1109/TIP.2009.2028250

# =============================================================================
# Model 3
# Lap. Reg. + U-net
# Use these parameters to steer the training of the Lap. Reg. + U-net model.
# The code that consumes them lies beyond this excerpt.
print('===========================================')
print('Lap. Reg. + U-net...')
# NOTE(review): presumably selects GPU execution for the U-net stage — confirm
# where this flag is read.
use_cuda = True