def plot_result(result, gt, iteration):
    """Show ``result`` as a grayscale image titled with its quality metrics.

    :param result: reconstruction to display
    :param gt: reference image for PSNR and SSIM
    :param iteration: iteration number printed at the start of the title
    """
    quality = (PSNR(result, gt), SSIM(result, gt))
    plt.imshow(result, cmap='gray')
    plt.axis('off')
    caption = '%d: PSNR: %.2f SSIM: %.4f' % ((iteration,) + quality)
    plt.title(caption)
    plt.show()
 def callback_func(iteration, reconstruction, loss):
     """Append iteration, loss, PSNR/SSIM against the global ``gt`` and the
     reconstruction to the module-level history lists.

     Calls with ``iteration == 0`` are ignored.
     """
     global iter_history, loss_history, psnr_history, ssim_history, reco_history
     if iteration != 0:
         iter_history.append(iteration)
         loss_history.append(loss)
         psnr_history.append(PSNR(reconstruction, gt))
         ssim_history.append(SSIM(reconstruction, gt))
         reco_history.append(reconstruction)
# Example #3
def plot_reconstructions(reconstructions,
                         titles,
                         ray_trafo,
                         obs,
                         gt,
                         save_name=None,
                         fig_size=(18, 4.5),
                         cmap='pink'):
    """
    Plot a ground truth next to several reconstructions of it.

    Every panel is labelled with its data discrepancy
    ``||ray_trafo(image) - obs||_2``; reconstruction panels additionally
    show PSNR and SSIM with respect to ``gt``.

    :param reconstructions: list of reconstructions to compare with ``gt``
    :param titles: panel titles, one per entry of ``reconstructions``
    :param ray_trafo: forward operator applied to each image before comparing
        it with ``obs``
    :param obs: observed data
    :param gt: ground-truth image (also sets the upper display limit to
        ``0.9 * max(gt)``)
    :param save_name: if given, the figure is saved as ``'<save_name>.pdf'``
    :param fig_size: figure size passed to ``plot_images``
    :param cmap: colormap passed to ``plot_images``
    """
    obs_arr = obs.asarray()

    def data_error(image):
        # Euclidean norm of the data-space residual ray_trafo(image) - obs.
        residual = ray_trafo(image).asarray() - obs_arr
        return np.sqrt(np.sum(np.power(residual, 2)))

    psnrs = [PSNR(reco, gt) for reco in reconstructions]
    ssims = [SSIM(reco, gt) for reco in reconstructions]

    l2_error0 = data_error(gt)
    l2_error = [data_error(reco) for reco in reconstructions]

    # plot results: ground truth first, then the reconstructions
    im, ax = plot_images([gt] + reconstructions,
                         fig_size=fig_size,
                         rect=(0.0, 0.0, 1.0, 1.0),
                         xticks=[],
                         yticks=[],
                         vrange=(0.0, 0.9 * np.max(gt.asarray())),
                         cbar=False,
                         interpolation='none',
                         cmap=cmap)

    # set labels; '\\ell' is escaped explicitly because '\e' is an invalid
    # escape sequence, while '\n' must stay a real newline.
    ax[0].set_title('Ground Truth')
    for j, (err, psnr, ssim) in enumerate(zip(l2_error, psnrs, ssims)):
        ax[j + 1].set_title(titles[j])
        ax[j + 1].set_xlabel(
            '$\\ell_2$ data error: {:.4f}\nPSNR: {:.1f}, SSIM: {:.2f}'.format(
                err, psnr, ssim))

    ax[0].set_xlabel('$\\ell_2$ data error: {:.2f}'.format(l2_error0))

    # NOTE(review): tight_layout is applied twice, as in the original;
    # presumably to let the layout settle -- confirm before removing one.
    plt.tight_layout()
    plt.tight_layout()

    if save_name:
        plt.savefig('%s.pdf' % save_name)
    plt.show()
# Example #4
def plot_reconstructors_tests(reconstructors,
                              ray_trafo,
                              test_data,
                              save_name=None,
                              fig_size=(18, 4.5),
                              cmap='pink'):
    """
    For each test sample, plot the ground truth next to the reconstructions
    computed by several reconstructors.

    One figure (4 columns) is created per sample. Every panel is labelled
    with its data discrepancy ``||ray_trafo(image) - y_delta||_2``;
    reconstruction panels additionally show PSNR and SSIM with respect to
    the ground truth.

    :param reconstructors: list of reconstructor objects; each provides
        ``reconstruct(y_delta)`` and a ``name`` used as panel title
    :param ray_trafo: forward operator used to compute the data discrepancy
    :param test_data: indexable collection (supporting ``len``) of
        ``(y_delta, x)`` pairs, i.e. observation and ground truth
    :param save_name: if given, figure ``i`` is saved as
        ``'<save_name>-<i>.pdf'``
    :param fig_size: figure size passed to ``plot_images``
    :param cmap: colormap passed to ``plot_images``
    """
    titles = [reconstructor.name for reconstructor in reconstructors]

    def data_error(image, obs_arr):
        # Euclidean norm of the data-space residual ray_trafo(image) - obs.
        residual = ray_trafo(image).asarray() - obs_arr
        return np.sqrt(np.sum(np.power(residual, 2)))

    for i in range(len(test_data)):
        y_delta, x = test_data[i]
        y_delta_arr = y_delta.asarray()

        # compute reconstructions and psnr and ssim measures
        recos = [r.reconstruct(y_delta) for r in reconstructors]
        l2_error = [data_error(reco, y_delta_arr) for reco in recos]
        l2_error0 = data_error(x, y_delta_arr)

        psnrs = [PSNR(reco, x) for reco in recos]
        ssims = [SSIM(reco, x) for reco in recos]

        # plot results: ground truth first, then the reconstructions
        im, ax = plot_images([x] + recos,
                             fig_size=fig_size,
                             rect=(0.0, 0.0, 1.0, 1.0),
                             ncols=4,
                             nrows=-1,
                             xticks=[],
                             yticks=[],
                             vrange=(0.0, 0.9 * np.max(x.asarray())),
                             cbar=False,
                             interpolation='none',
                             cmap=cmap)

        # set labels; ax is flattened because the grid may have several rows.
        # '\\ell' is escaped explicitly ('\e' is an invalid escape sequence),
        # while '\n' must stay a real newline.
        ax = ax.reshape(-1)
        ax[0].set_title('Ground Truth')
        ax[0].set_xlabel('$\\ell_2$ data error: {:.2f}'.format(l2_error0))
        for j, (err, psnr, ssim) in enumerate(zip(l2_error, psnrs, ssims)):
            ax[j + 1].set_title(titles[j])
            ax[j + 1].set_xlabel(
                '$\\ell_2$ data error: {:.4f}\nPSNR: {:.1f}, SSIM: {:.2f}'.
                format(err, psnr, ssim))

        # NOTE(review): tight_layout is applied twice, as in the original;
        # presumably to let the layout settle -- confirm before removing one.
        plt.tight_layout()
        plt.tight_layout()

        if save_name:
            plt.savefig('%s-%d.pdf' % (save_name, i))
        plt.show()
# Example #5
                 [0, 4, 4, 4, 4]]

# Reconstruct `obs` with a Deep Image Prior once per entry of `skip_channels`
# and append each result to `results`. `params`, `ch`, `sc`, `ray_trafo`,
# `obs` and `results` are defined earlier in the script (not visible here).
for skip in skip_channels:
    print('Skip Channels:')
    print(skip)
    # NOTE(review): channels/scales do not depend on `skip`; they are set to
    # the same `ch`/`sc` values on every iteration.
    params.channels = (ch,) * sc
    params.skip_channels = skip
    params.scales = sc
    reconstructor = DeepImagePriorReconstructor(ray_trafo=ray_trafo, hyper_params=params.dict, name='DIP')
    result = reconstructor.reconstruct(obs)
    results.append(result)

# Plot every result on a 3x4 grid (so at most 12 panels) with PSNR/SSIM
# against `gt` as the x-label.
fig = plt.figure(figsize=(9.1, 8.3))
for i in range(len(results)):
    ax = fig.add_subplot(3, 4, i+1)
    psnr = PSNR(results[i], gt)
    ssim = SSIM(results[i], gt)
    plot_image(results[i], ax=ax, xticks=[], yticks=[], cmap='pink')
    # Assumes the first 8 entries of `results` come from an earlier
    # channels/scales sweep (indexed via `channels`/`scales`) and the rest
    # from the skip-channel loop above, in order -- confirm against the
    # earlier part of the script.
    if i < 8:
        ax.set_title('Channels: %d, Scales: %d' % (channels[i], scales[i]))
    else:
        ax.set_title('Skip: {}'.format(skip_channels[i-8]))
    ax.set_xlabel('PSNR: %.2f, SSIM: %.4f' % (psnr, ssim))

# tight_layout applied twice, as elsewhere in these examples.
plt.tight_layout()
plt.tight_layout()

# Save the figure both as PDF and as PGF before showing it.
plt.savefig('ellipses_architectures.pdf')
plt.savefig('ellipses_architectures.pgf')
plt.show()
# Example #6
    def train(self, dataset):
        """
        Train ``self.model`` with a self-supervised loss in observation space.

        For every batch the (optionally opnorm-normalised) observation is
        masked with ``self.masker``. The full and the masked observation are
        both reconstructed with ``self.fbp_op`` and passed through the model.
        The loss is ``l_rec + 2 * sqrt(l_inv)``, where

        * ``l_rec`` is the MSE between ``ray_trafo_module(model(fbp(y)))``
          and ``y``, and
        * ``l_inv`` is the squared difference, restricted to the mask and
          divided by the mask sum, between the projections of the outputs
          for the full and the masked input.

        PSNR/SSIM against the ground truth are not part of the loss. They are
        logged and used to keep the weights with the best validation PSNR,
        which are loaded back into ``self.model`` at the end.

        :param dataset: dataset with ``'train'`` and ``'validation'`` parts,
            used via ``create_torch_dataset``, ``space``,
            ``supports_random_access`` and ``get_data_pairs``.
        """
        if self.torch_manual_seed:
            torch.random.manual_seed(self.torch_manual_seed)

        self.init_transform(dataset=dataset)

        # create PyTorch datasets
        dataset_train = dataset.create_torch_dataset(
            part='train',
            reshape=((1, ) + dataset.space[0].shape,
                     (1, ) + dataset.space[1].shape),
            transform=self._transform)

        dataset_validation = dataset.create_torch_dataset(
            part='validation',
            reshape=((1, ) + dataset.space[0].shape,
                     (1, ) + dataset.space[1].shape))

        criterion = torch.nn.MSELoss()
        criterion_sum = torch.nn.MSELoss(reduction='sum')
        self.init_optimizer(dataset_train=dataset_train)

        # create PyTorch dataloaders
        shuffle = (dataset.supports_random_access()
                   if self.shuffle == 'auto' else self.shuffle)
        data_loaders = {
            'train':
            DataLoader(dataset_train,
                       batch_size=self.batch_size,
                       num_workers=self.num_data_loader_workers,
                       shuffle=shuffle,
                       pin_memory=True,
                       worker_init_fn=self.worker_init_fn),
            'validation':
            DataLoader(dataset_validation,
                       batch_size=self.batch_size,
                       num_workers=self.num_data_loader_workers,
                       shuffle=shuffle,
                       pin_memory=True,
                       worker_init_fn=self.worker_init_fn)
        }

        dataset_sizes = {
            'train': len(dataset_train),
            'validation': len(dataset_validation)
        }

        self.init_scheduler(dataset_train=dataset_train)
        # CyclicLR/OneCycleLR step after every optimizer step; any other
        # scheduler steps once per training epoch.
        schedule_every_batch = (
            self._scheduler is not None
            and isinstance(self._scheduler, (CyclicLR, OneCycleLR)))

        best_model_wts = deepcopy(self.model.state_dict())
        best_psnr = 0

        writer = None
        if self.log_dir is not None:
            if not TENSORBOARD_AVAILABLE:
                raise ImportError(
                    'Missing tensorboard. Please install it or disable '
                    'logging by specifying `log_dir=None`.')
            writer = SummaryWriter(log_dir=self.log_dir, max_queue=0)
            validation_samples = dataset.get_data_pairs(
                'validation', self.log_num_validation_samples)

        def fbp_batch(observations, shape):
            # Apply self.fbp_op to channel 0 of every sample of the CPU
            # tensor `observations`; return a float32 tensor of `shape`.
            fbp = np.zeros(shape, dtype=np.float32)
            for k in range(len(observations)):
                fbp[k, 0, :, :] = self.fbp_op(observations[k, 0].numpy())
            return torch.from_numpy(fbp)

        self.model.to(self.device)
        self.model.train()

        for epoch in range(self.epochs):
            # Each epoch has a training and validation phase
            for phase in ['train', 'validation']:
                if phase == 'train':
                    self.model.train()  # Set model to training mode
                else:
                    self.model.eval()  # Set model to evaluate mode

                running_psnr = 0.0
                running_ssim = 0.0
                running_loss = 0.0
                running_size = 0
                num_iter = 0
                with tqdm(data_loaders[phase],
                          desc='epoch {:d}'.format(epoch + 1),
                          disable=not self.show_pbar) as pbar:
                    for inputs, labels in pbar:
                        num_iter += 1
                        if self.normalize_by_opnorm:
                            inputs = (1. / self.opnorm) * inputs
                        # NOTE(review): `% (n_masks - 1)` never selects mask
                        # index n_masks - 1 and fails for n_masks == 1;
                        # confirm this is intended.
                        masked_inputs, mask = self.masker.mask(
                            inputs, num_iter % (self.masker.n_masks - 1))

                        # FBP reconstructions of masked and full observations
                        masked_inputs_fbp = fbp_batch(masked_inputs,
                                                      labels.shape)
                        inputs_fbp = fbp_batch(inputs, labels.shape)

                        masked_inputs_fbp = masked_inputs_fbp.to(self.device)
                        inputs_fbp = inputs_fbp.to(self.device)
                        inputs = inputs.to(self.device)
                        labels = labels.to(self.device)
                        mask = mask.to(self.device)

                        # zero the parameter gradients
                        self._optimizer.zero_grad()

                        # forward
                        # track gradients only if in train phase
                        with torch.set_grad_enabled(phase == 'train'):
                            out_raw = self.model(inputs_fbp)
                            out_masked = self.model(masked_inputs_fbp)
                            proj_raw = self.ray_trafo_module(out_raw)
                            proj_masked = self.ray_trafo_module(out_masked)
                            l_rec = criterion(proj_raw, inputs)
                            l_inv = criterion_sum(proj_raw * mask, proj_masked
                                                  * mask) / mask.sum()
                            loss = l_rec + 2 * torch.sqrt(l_inv)

                            # backward + optimize only if in training phase
                            if phase == 'train':
                                loss.backward()
                                torch.nn.utils.clip_grad_norm_(
                                    self.model.parameters(), max_norm=1)
                                self._optimizer.step()
                                if (self._scheduler is not None
                                        and schedule_every_batch):
                                    self._scheduler.step()

                        for i in range(out_raw.shape[0]):
                            labels_ = labels[i, 0].detach().cpu().numpy()
                            outputs_ = out_raw[i, 0].detach().cpu().numpy()
                            running_psnr += PSNR(outputs_, labels_)
                            running_ssim += SSIM(outputs_, labels_)

                        # statistics
                        running_loss += loss.item() * out_raw.shape[0]
                        running_size += out_raw.shape[0]

                        pbar.set_postfix({
                            'phase': phase,
                            'loss': running_loss / running_size,
                            'psnr': running_psnr / running_size,
                            'ssim': running_ssim / running_size
                        })
                        if self.log_dir is not None and phase == 'train':
                            step = (epoch * ceil(
                                dataset_sizes['train'] / self.batch_size) +
                                    ceil(running_size / self.batch_size))
                            writer.add_scalar(
                                'loss/{}'.format(phase),
                                torch.tensor(running_loss / running_size),
                                step)
                            writer.add_scalar(
                                'psnr/{}'.format(phase),
                                torch.tensor(running_psnr / running_size),
                                step)
                            writer.add_scalar(
                                'ssim/{}'.format(phase),
                                torch.tensor(running_ssim / running_size),
                                step)

                    # Epoch-level schedulers step once per epoch, after the
                    # training phase only (not again after validation).
                    if (phase == 'train' and self._scheduler is not None
                            and not schedule_every_batch):
                        self._scheduler.step()

                    epoch_loss = running_loss / dataset_sizes[phase]
                    epoch_psnr = running_psnr / dataset_sizes[phase]
                    epoch_ssim = running_ssim / dataset_sizes[phase]

                    if self.log_dir is not None and phase == 'validation':
                        step = (epoch + 1) * ceil(
                            dataset_sizes['train'] / self.batch_size)
                        writer.add_scalar('loss/{}'.format(phase), epoch_loss,
                                          step)
                        writer.add_scalar('psnr/{}'.format(phase), epoch_psnr,
                                          step)
                        writer.add_scalar('ssim/{}'.format(phase), epoch_ssim,
                                          step)
                    # deep copy the model (if it is the best one seen so far)
                    if phase == 'validation' and epoch_psnr > best_psnr:
                        best_psnr = epoch_psnr
                        best_model_wts = deepcopy(self.model.state_dict())
                        if self.save_best_learned_params_path is not None:
                            self.save_learned_params(
                                self.save_best_learned_params_path)

                if (phase == 'validation' and self.log_dir is not None
                        and self.log_num_validation_samples > 0):
                    with torch.no_grad():
                        val_images = []
                        for (y, x) in validation_samples:
                            # Build the same model input as in the loops
                            # above: FBP of the (optionally normalised)
                            # observation, not the raw observation.
                            y = torch.from_numpy(np.asarray(y,
                                                            dtype=np.float32))
                            if self.normalize_by_opnorm:
                                y = (1. / self.opnorm) * y
                            y_fbp = np.asarray(self.fbp_op(y.numpy()),
                                               dtype=np.float32)
                            y_fbp = torch.from_numpy(y_fbp)[None, None].to(
                                self.device)
                            x = torch.from_numpy(
                                np.asarray(x, dtype=np.float32))[None,
                                                                 None].to(
                                                                     self.device)
                            reco = self.model(y_fbp)
                            # rescale to [0, 1] for display; clamp avoids a
                            # division by zero for constant outputs
                            reco -= torch.min(reco)
                            reco /= torch.clamp(torch.max(reco), min=1e-12)
                            val_images += [reco, x]
                        writer.add_images(
                            'validation_samples',
                            torch.cat(val_images), (epoch + 1) *
                            (ceil(dataset_sizes['train'] / self.batch_size)),
                            dataformats='NCWH')

        print('Best val psnr: {:4f}'.format(best_psnr))
        self.model.load_state_dict(best_model_wts)
        if writer is not None:
            writer.close()
# Example #7
from dival.measure import PSNR, SSIM
from dival.reconstructors.odl_reconstructors import FBPReconstructor
# FBP baseline on the dataset's ray transform: Hann filter, frequency
# scaling 1.0.
fbp_reconstructor = FBPReconstructor(dataset.ray_trafo,
                                     hyper_params={
                                         'filter_type': 'Hann',
                                         'frequency_scaling': 1.0
                                     })
# Plot FBP vs. ground truth and report PSNR/SSIM for the first 3 training
# samples (islice stops the generator after 3 items).
for i, (obs, gt) in islice(enumerate(dataset.generator(part='train')), 3):
    reco = fbp_reconstructor.reconstruct(obs)
    # clip to [0, 1] before plotting and computing metrics
    reco = np.clip(reco, 0., 1.)
    _, ax = plot_images([reco, gt], fig_size=(10, 4))
    ax[0].figure.suptitle('train sample {:d}'.format(i))
    ax[0].set_title('FBP reconstruction')
    ax[1].set_title('Ground truth')
    psnr = PSNR(reco, gt)
    ssim = SSIM(reco, gt)
    ax[0].set_xlabel('PSNR: {:.2f}dB, SSIM: {:.3f}'.format(psnr, ssim))
    print('metrics for FBP reconstruction on sample {:d}:'.format(i))
    print('PSNR: {:.2f}dB, SSIM: {:.3f}'.format(psnr, ssim))
plt.show()

# %% simulate and store fan beam observations
SKIP_SIMULATION = False

if not SKIP_SIMULATION:
    from dival.util.input import input_yes_no
    print('start simulating and storing fan beam observations for all lodopab '
          'ground truth samples? [y]/n')
    if not input_yes_no():
        raise RuntimeError('cancelled by user')