예제 #1
0
def plot_iterations(recos,
                    iters,
                    save_name=None,
                    fig_size=(18, 4.5),
                    cmap='pink',
                    vrange=(0., 0.9)):
    """
    Plot several iterates of an iterative method side by side.

    :param recos: List of reconstructions (iterates), passed to
        ``plot_images``.
    :param iters: Iteration numbers; ``iters[i]`` becomes the title of the
        i-th subplot. Should not contain more entries than ``recos``.
    :param save_name: If truthy, the figure is additionally saved as
        ``'<save_name>.pdf'`` before it is shown.
    :param fig_size: Figure size forwarded to ``plot_images``.
    :param cmap: Colormap forwarded to ``plot_images``.
    :param vrange: Value range forwarded to ``plot_images``. Defaults to
        ``(0., 0.9)``.
    """
    _, ax = plot_images(recos,
                        fig_size=fig_size,
                        rect=(0.0, 0.0, 1.0, 1.0),
                        xticks=[],
                        yticks=[],
                        vrange=vrange,
                        cbar=False,
                        interpolation='none',
                        cmap=cmap)

    for i, iteration in enumerate(iters):
        ax[i].set_title('Iteration: %d' % iteration)

    # NOTE(review): tight_layout is called twice, as in the sibling plotting
    # helpers; presumably a workaround for the layout not settling in a
    # single pass — confirm before removing.
    plt.tight_layout()
    plt.tight_layout()

    if save_name:
        plt.savefig('%s.pdf' % save_name)

    plt.show()
예제 #2
0
파일: ct_tvadam.py 프로젝트: jleuschn/dival
def callback_func(iteration, reconstruction, loss):
    """Show the current TV iterate next to the ground truth.

    Draws ``reconstruction`` and the module-level ``gt`` side by side,
    labels the left panel with the loss and iteration number, titles the
    figure with the module-level ``TEST_SAMPLE`` index and shows it.
    """
    images = [reconstruction, gt]
    _, axes = plot_images(images, fig_size=(10, 4))
    reco_axis, gt_axis = axes[0], axes[1]
    reco_axis.set_xlabel('loss: {:f}'.format(loss))
    reco_axis.set_title('TV iteration {:d}'.format(iteration))
    gt_axis.set_title('ground truth')
    figure = reco_axis.figure
    figure.suptitle('test sample {:d}'.format(TEST_SAMPLE))
    plt.show()
예제 #3
0
def plot_reconstructions(reconstructions,
                         titles,
                         ray_trafo,
                         obs,
                         gt,
                         save_name=None,
                         fig_size=(18, 4.5),
                         cmap='pink'):
    """
    Plot a ground truth next to several reconstructions with quality labels.

    Every reconstruction is annotated with its l2 data error
    ``||ray_trafo(reco) - obs||_2``, its PSNR and its SSIM w.r.t. ``gt``;
    the ground-truth panel is annotated with its own l2 data error.

    :param reconstructions: List of reconstructions to compare with ``gt``.
    :param titles: Subplot titles; ``titles[j]`` labels
        ``reconstructions[j]``.
    :param ray_trafo: Forward operator applied to images before comparing
        them with ``obs``.
    :param obs: Observed data.
    :param gt: Ground-truth image (also sets the upper display limit to
        ``0.9 * max(gt)``).
    :param save_name: If truthy, the figure is additionally saved as
        ``'<save_name>.pdf'``.
    :param fig_size: Figure size forwarded to ``plot_images``.
    :param cmap: Colormap forwarded to ``plot_images``.
    """
    psnrs = [PSNR(reco, gt) for reco in reconstructions]
    ssims = [SSIM(reco, gt) for reco in reconstructions]

    # The observation does not change between images; convert it only once.
    obs_arr = obs.asarray()

    def data_error(image):
        # Euclidean norm of the residual ray_trafo(image) - obs.
        residual = ray_trafo(image).asarray() - obs_arr
        return np.sqrt(np.sum(np.power(residual, 2)))

    l2_error0 = data_error(gt)
    l2_error = [data_error(reco) for reco in reconstructions]

    # plot results
    _, ax = plot_images([gt] + reconstructions,
                        fig_size=fig_size,
                        rect=(0.0, 0.0, 1.0, 1.0),
                        xticks=[],
                        yticks=[],
                        vrange=(0.0, 0.9 * np.max(gt.asarray())),
                        cbar=False,
                        interpolation='none',
                        cmap=cmap)

    # set labels; '\\ell' keeps the LaTeX backslash without relying on the
    # invalid '\e' escape (SyntaxWarning on Python 3.12+), while '\n' must
    # stay a real newline, so a raw string cannot be used here.
    ax[0].set_title('Ground Truth')
    for j, (err, psnr, ssim) in enumerate(zip(l2_error, psnrs, ssims)):
        ax[j + 1].set_title(titles[j])
        ax[j + 1].set_xlabel(
            '$\\ell_2$ data error: {:.4f}\nPSNR: {:.1f}, SSIM: {:.2f}'.format(
                err, psnr, ssim))

    ax[0].set_xlabel('$\\ell_2$ data error: {:.2f}'.format(l2_error0))

    # NOTE(review): called twice as in the sibling helpers; presumably a
    # workaround for the layout not settling in one pass — confirm.
    plt.tight_layout()
    plt.tight_layout()

    if save_name:
        plt.savefig('%s.pdf' % save_name)
    plt.show()
예제 #4
0
reconstructor.load_hyper_params(hyper_params_path)

#%% expose FBP cache to reconstructor by assigning `fbp_dataset` attribute
# uncomment the next line to generate the cache files (~20 GB)
# generate_fbp_cache_files(dataset, ray_trafo, CACHE_FILES)
cached_fbp_dataset = get_cached_fbp_dataset(dataset, ray_trafo, CACHE_FILES)
dataset.fbp_dataset = cached_fbp_dataset

#%% train
# reduce the batch size here if the model does not fit into GPU memory
# reconstructor.batch_size = 16
reconstructor.train(dataset)

#%% evaluate
# Reconstruct every test sample and record its PSNR w.r.t. the ground truth.
recos = []
psnrs = []
for obs, gt in test_data:
    reco = reconstructor.reconstruct(obs)
    recos.append(reco)
    psnrs.append(PSNR(reco, gt))

print('mean psnr: {:f}'.format(np.mean(psnrs)))

# Plot at most the first three samples; bounding by len(recos) avoids an
# IndexError when the test set contains fewer than three samples.
for i in range(min(3, len(recos))):
    _, ax = plot_images([recos[i], test_data.ground_truth[i]],
                        fig_size=(10, 4))
    ax[0].set_xlabel('PSNR: {:.2f}'.format(psnrs[i]))
    ax[0].set_title('FBPUNetReconstructor')
    ax[1].set_title('ground truth')
    ax[0].figure.suptitle('test sample {:d}'.format(i))
예제 #5
0
    def plot_reconstruction(self,
                            task_ind,
                            sub_task_ind=0,
                            test_ind=-1,
                            plot_ground_truth=True,
                            **kwargs):
        """Plot the reconstruction at the specified index.
        Supports only 1d and 2d reconstructions.

        Parameters
        ----------
        task_ind : int
            Index of the task.
        sub_task_ind : int, optional
            Index of the sub-task (default ``0``).
        test_ind : sequence of int or int, optional
            Index in test data. If ``-1``, plot all reconstructions (the
            default). A single index may be a Python or NumPy integer.
        plot_ground_truth : bool, optional
            Whether to show the ground truth next to the reconstruction.
            The default is ``True``.
        kwargs : dict
            Keyword arguments that are passed to ``plot_images`` (with
            ground truth) or :func:`~dival.util.plot.plot_image` (without)
            if the reconstruction is 2d.

        Returns
        -------
        ax_list : list of :class:`np.ndarray` of :class:`matplotlib.axes.Axes`
            The axes in which the reconstructions and, if requested, the
            ground truth were plotted. If a reconstruction that is neither
            1d nor 2d is encountered, a message is printed and the axes
            plotted up to that point are returned.

        Raises
        ------
        ValueError
            If a requested reconstruction is ``None``.
        """
        row = self.results.loc[task_ind, sub_task_ind]
        test_data = row.at['test_data']
        reconstructor = row.at['reconstructor']
        # The sub-task suffix of the title depends only on the task, so it
        # is computed once instead of on every iteration.
        sub_task_suffix = ('.{}'.format(sub_task_ind)
                           if len(self.results.loc[task_ind]) > 1 else '')
        if isinstance(test_ind, (int, np.integer)):
            if test_ind == -1:
                test_ind = range(len(test_data))
            else:
                test_ind = [test_ind]
        ax_list = []
        for i in test_ind:
            title = 'reconstruction for task {}{}, test_data[{}]'.format(
                task_ind, sub_task_suffix, i)
            reconstruction = row.at['reconstructions'][i]
            ground_truth = test_data.ground_truth[i]
            if reconstruction is None:
                raise ValueError('reconstruction is `None`')
            ndim = reconstruction.asarray().ndim
            if ndim not in (1, 2):
                # Return the axes collected so far (not ``None``) so the
                # documented return type holds; this also prevents a 0d
                # reconstruction from appending an undefined/stale ``ax``.
                print('only 1d and 2d reconstructions can be plotted')
                return ax_list
            if ndim == 1:
                x = reconstruction.space.points()
                _, ax = plt.subplots()
                ax.plot(x, reconstruction, label=reconstructor.name)
                if plot_ground_truth:
                    ax.plot(x, ground_truth, label='ground truth')
                ax.legend()
                ax.set_title(title)
                # wrap the single Axes so every list entry is an ndarray
                ax = np.array(ax)
            else:
                if plot_ground_truth:
                    _, ax = plot_images([reconstruction, ground_truth],
                                        **kwargs)
                    ax[1].set_title('ground truth')
                else:
                    _, ax = plot_image(reconstruction, **kwargs)
                ax[0].set_title(reconstructor.name)
                ax[0].figure.suptitle(title)
            ax_list.append(ax)
        return ax_list
예제 #6
0
def callback_func(iteration, reconstruction, loss):
    """Show the current DIP iterate next to the ground truth.

    Draws ``reconstruction`` and the module-level ``gt`` side by side,
    labels the left panel with the loss and iteration number, titles the
    figure with the module-level ``TEST_SAMPLE`` index and shows it.
    """
    images = [reconstruction, gt]
    _, axes = plot_images(images,
                          fig_size=(10, 4))
    reco_axis, gt_axis = axes[0], axes[1]
    reco_axis.set_xlabel('loss: {:f}'.format(loss))
    reco_axis.set_title('DIP iteration {:d}'.format(iteration))
    gt_axis.set_title('ground truth')
    figure = reco_axis.figure
    figure.suptitle('test sample {:d}'.format(TEST_SAMPLE))
    plt.show()

# Build the reconstructor on the dataset's ray transform and register
# `callback_func` for progress plots; `callback_func_interval=100`
# presumably means "every 100 iterations" — TODO confirm.
reconstructor = DeepImagePriorCTReconstructor(
    dataset.get_ray_trafo(impl=IMPL),
    callback_func=callback_func, callback_func_interval=100)

#%% obtain reference hyper parameters
# Download the reference 'diptv' parameters for 'lodopab' unless
# check_for_params reports them as already available (presumably a local
# cache check), then load them into the reconstructor.
if not check_for_params('diptv', 'lodopab'):
    download_params('diptv', 'lodopab')
params_path = get_params_path('diptv', 'lodopab')
reconstructor.load_params(params_path)

#%% evaluate
# Reconstruct the single observation `obs` and score it against `gt`
# (both bound outside this chunk).
reco = reconstructor.reconstruct(obs)
psnr = PSNR(reco, gt)

print('psnr: {:f}'.format(psnr))
# Reconstruction (left, labelled with its PSNR) next to the ground truth.
_, ax = plot_images([reco, gt],
                    fig_size=(10, 4))
ax[0].set_xlabel('PSNR: {:.2f}'.format(psnr))
ax[0].set_title('DeepImagePriorCTReconstructor')
ax[1].set_title('ground truth')
ax[0].figure.suptitle('test sample {:d}'.format(TEST_SAMPLE))
예제 #7
0
def plot_reconstructors_tests(reconstructors,
                              ray_trafo,
                              test_data,
                              save_name=None,
                              fig_size=(18, 4.5),
                              cmap='pink'):
    """
    Plot, for every test sample, the ground truth and several reconstructions.

    One figure is produced per sample. Each reconstruction is annotated with
    its l2 data error ``||ray_trafo(reco) - y_delta||_2``, its PSNR and its
    SSIM w.r.t. the ground truth; the ground-truth panel is annotated with
    its own l2 data error.

    :param reconstructors: List of Reconstructor objects used to compute the
        reconstructions; their ``name`` attributes become the subplot titles.
    :param ray_trafo: Forward operator used to compute the data errors.
    :param test_data: Indexable data supporting ``len``; ``test_data[i]``
        yields a pair ``(observation, ground_truth)``.
    :param save_name: If truthy, the figure for sample ``i`` is additionally
        saved as ``'<save_name>-<i>.pdf'``.
    :param fig_size: Figure size forwarded to ``plot_images``.
    :param cmap: Colormap forwarded to ``plot_images``.
    """
    titles = [reconstructor.name for reconstructor in reconstructors]

    def data_error(image, data_arr):
        # Euclidean norm of the residual ray_trafo(image) - data.
        residual = ray_trafo(image).asarray() - data_arr
        return np.sqrt(np.sum(np.power(residual, 2)))

    for i in range(len(test_data)):
        y_delta, x = test_data[i]
        # The observation array is shared by all error computations below.
        y_delta_arr = y_delta.asarray()

        # compute reconstructions and psnr and ssim measures
        recos = [r.reconstruct(y_delta) for r in reconstructors]
        l2_error = [data_error(reco, y_delta_arr) for reco in recos]
        l2_error0 = data_error(x, y_delta_arr)

        psnrs = [PSNR(reco, x) for reco in recos]
        ssims = [SSIM(reco, x) for reco in recos]

        # plot results
        _, ax = plot_images([x] + recos,
                            fig_size=fig_size,
                            rect=(0.0, 0.0, 1.0, 1.0),
                            ncols=4,
                            nrows=-1,
                            xticks=[],
                            yticks=[],
                            vrange=(0.0, 0.9 * np.max(x.asarray())),
                            cbar=False,
                            interpolation='none',
                            cmap=cmap)

        # set labels; the axes grid is flattened so it can be indexed
        # linearly. '\\ell' avoids the invalid '\e' escape (SyntaxWarning on
        # Python 3.12+) while keeping '\n' a real newline.
        ax = ax.reshape(-1)
        ax[0].set_title('Ground Truth')
        ax[0].set_xlabel('$\\ell_2$ data error: {:.2f}'.format(l2_error0))
        for j, (err, psnr, ssim) in enumerate(zip(l2_error, psnrs, ssims)):
            ax[j + 1].set_title(titles[j])
            ax[j + 1].set_xlabel(
                '$\\ell_2$ data error: {:.4f}\nPSNR: {:.1f}, SSIM: {:.2f}'.
                format(err, psnr, ssim))

        # NOTE(review): called twice as in the sibling helpers; presumably a
        # workaround for the layout not settling in one pass — confirm.
        plt.tight_layout()
        plt.tight_layout()

        if save_name:
            plt.savefig('%s-%d.pdf' % (save_name, i))
        plt.show()