Ejemplo n.º 1
0
    def get_data_pairs(self, part='train', n=None):
        """
        Return first samples from data part as :class:`.DataPairs` object.

        Only supports datasets with two elements per sample.

        Parameters
        ----------
        part : {``'train'``, ``'validation'``, ``'test'``}, optional
            The data part. Default is ``'train'``.
        n : int, optional
            Number of pairs (from beginning). If `None`, all available data
            is used (the default).

        Returns
        -------
        data_pairs : :class:`.DataPairs`
            The first `n` (or all) ``(observation, ground_truth)`` pairs
            yielded by :meth:`generator` for `part`.

        Raises
        ------
        ValueError
            If the dataset does not have exactly two elements per sample.
        """
        num_elements = self.get_num_elements_per_sample()
        if num_elements != 2:
            raise ValueError(
                '`get_data_pairs` only supports datasets with '
                '2 elements per sample, this dataset has {:d}'.format(
                    num_elements))
        observations, ground_truth = [], []
        # ``islice(gen, None)`` consumes the whole generator.
        for obs, gt in islice(self.generator(part=part), n):
            observations.append(obs)
            ground_truth.append(gt)
        name = '{} part{}'.format(part,
                                  ' 0:{:d}'.format(n) if n is not None else '')
        return DataPairs(observations, ground_truth, name=name)
Ejemplo n.º 2
0
    def test_option_save_iterates(self):
        space = odl.uniform_discr([0, 0], [1, 1], (1, 1))
        gt = space.one()
        obs = space.one()

        # Each step averages the observation (1.) with the previous iterate,
        # so the iterates approach 1.: 0., 0.5, 0.75, 0.875, ...
        class DummyReconstructor(StandardIterativeReconstructor):
            def _setup(self, observation):
                self.setup_var = 'dummy_val'

            def _compute_iterate(self, observation, reco_previous, out):
                out[:] = 0.5 * (observation + reco_previous)

        pairs = DataPairs([obs], [gt])
        task_table = TaskTable()
        task_table.append(DummyReconstructor(reco_space=space),
                          pairs,
                          hyper_param_choices={'iterations': [10]},
                          options={'save_iterates': True})

        stored = task_table.run().results['misc'][0, 0]['iterates'][0]
        # the iterate at index 2 is within 0.2 of 1., the one at index 1 not
        self.assertAlmostEqual(1., stored[2][0, 0], delta=0.2)
        self.assertNotAlmostEqual(1., stored[1][0, 0], delta=0.2)
Ejemplo n.º 3
0
    def test_option_save_best_reconstructor_reuse_iterates(self):
        # Exercise the 'save_best_reconstructor' option together with
        # 'iterations' in hyper_param_choices: `run` has a performance
        # optimization for that combination (default ``reuse_iterates=True``).
        space = odl.uniform_discr([0, 0], [1, 1], (2, 2))
        gt = space.element([[1, 0], [0, 1]])
        obs = space.element([[0, 0], [0, 0]])

        # Each step adds 0.1 everywhere, so the iterates are
        # [[0, 0], [0, 0]], [[.1, .1], [.1, .1]], [[.2, .2], [.2, .2]], ...
        # and the best one is [[.5, .5], [.5, .5]] (5 iterations).
        class DummyReconstructor(StandardIterativeReconstructor):
            def _setup(self, observation):
                self.setup_var = 'dummy_val'

            def _compute_iterate(self, observation, reco_previous, out):
                out[:] = reco_previous + 0.1

        pairs = DataPairs([obs], [gt])
        task_table = TaskTable()
        reconstructor = DummyReconstructor(reco_space=space)
        save_path = 'dummypath'
        save_options = {'path': save_path, 'measure': PSNR}
        task_table.append(reconstructor,
                          pairs,
                          measures=[PSNR],
                          hyper_param_choices={
                              'iterations': list(range(10))},
                          options={'save_best_reconstructor': save_options})

        class RecordingStringIO(StringIO):
            # In-memory file that mirrors its content into `store[key]`
            # on creation and again when it is closed.
            def __init__(self, store, key, *args, **kwargs):
                self.store = store
                self.key = key
                super().__init__(*args, **kwargs)
                self.store[self.key] = self.getvalue()

            def close(self):
                self.store[self.key] = self.getvalue()
                super().close()

        written = {}

        def fake_open(f, *args, **kwargs):
            return RecordingStringIO(written, f)

        with patch('dival.reconstructors.reconstructor.open', fake_open):
            task_table.run()

        json_file = save_path + '_hyper_params.json'
        self.assertIn(json_file, written)
        self.assertDictEqual(json.loads(written[json_file]),
                             {'iterations': 5})
Ejemplo n.º 4
0
    def test_iterations_hyper_param_choices(self):
        # Test 'iterations' in hyper_param_choices, because `run` has a
        # performance optimization for it (``reuse_iterates``).
        domain = odl.uniform_discr([0, 0], [1, 1], (1, 1))
        ground_truth = domain.one()
        observation = domain.one()

        # reconstruct 1., iterates 0., 0.5, 0.75, 0.875, ...

        class DummyReconstructor(StandardIterativeReconstructor):
            def _setup(self, observation):
                self.setup_var = 'dummy_val'

            def _compute_iterate(self, observation, reco_previous, out):
                out[:] = 0.5 * (observation + reco_previous)

        test_data = DataPairs([observation], [ground_truth])
        tt = TaskTable()
        r = DummyReconstructor(reco_space=domain)
        hyper_param_choices = {'iterations': [2, 3, 10]}
        tt.append(r, test_data, hyper_param_choices=hyper_param_choices)

        iterations = hyper_param_choices['iterations']
        # Expected number of callback invocations: with reused iterates only
        # the largest choice is iterated; otherwise each choice runs anew.
        expected_num_iterates = {True: max(iterations),
                                 False: sum(iterations)}
        for reuse_iterates in (True, False):
            with self.subTest(reuse_iterates=reuse_iterates):
                iters = []
                r.callback = CallbackStore(iters)
                results = tt.run(reuse_iterates=reuse_iterates)
                recos = results.results['reconstructions']
                # choice 1 (3 iterations -> 0.875) is within 0.2 of 1.,
                # choice 0 (2 iterations -> 0.75) is not
                self.assertAlmostEqual(1., recos[0, 1][0][0, 0], delta=0.2)
                self.assertNotAlmostEqual(1., recos[0, 0][0][0, 0],
                                          delta=0.2)
                self.assertEqual(len(iters),
                                 expected_num_iterates[reuse_iterates])
Ejemplo n.º 5
0
 def test(self):
     # DataPairs must iterate, report its length and support slicing
     # exactly like the parallel lists it was built from.
     reco_space = odl.uniform_discr([0, 0], [1, 1], (2, 2))
     obs_space = odl.uniform_discr([0, 0], [1, 1], (1, 1))
     gts = [reco_space.one(), reco_space.zero()]
     obss = [obs_space.one(), obs_space.zero()]
     pairs = DataPairs(obss, gts)

     def check(expected_obs, expected_gt, actual_pairs):
         for o_exp, g_exp, (o, g) in zip(expected_obs, expected_gt,
                                         actual_pairs):
             self.assertEqual(o, o_exp)
             self.assertEqual(g, g_exp)

     check(obss, gts, pairs)
     self.assertEqual(len(pairs), len(obss))
     check(obss[::-1], gts[::-1], pairs[::-1])
Ejemplo n.º 6
0
 def test(self):
     # FBP on a noisy Shepp-Logan projection should reach a PSNR above 15.
     reco_space = odl.uniform_discr(
         min_pt=[-64, -64], max_pt=[64, 64], shape=[128, 128])
     phantom = odl.phantom.shepp_logan(reco_space, modified=True)
     geometry = odl.tomo.parallel_beam_geometry(reco_space, 30)
     ray_trafo = odl.tomo.RayTransform(reco_space, geometry, impl='skimage')
     proj_data = ray_trafo(phantom)
     # A dedicated, seeded generator makes the noise (and thus the asserted
     # PSNR) reproducible without touching the global numpy RNG state.
     rng = np.random.RandomState(0)
     observation = np.asarray(
         proj_data +
         rng.normal(loc=0., scale=2., size=proj_data.shape))
     test_data = DataPairs(observation, phantom)
     tt = TaskTable()
     fbp_reconstructor = FBPReconstructor(ray_trafo, hyper_params={
         'filter_type': 'Hann',
         'frequency_scaling': 0.8})
     tt.append(fbp_reconstructor, test_data, measures=[PSNR, SSIM])
     tt.run()
     self.assertGreater(
         tt.results.results['measure_values'][0, 0]['psnr'][0], 15.)
Ejemplo n.º 7
0
# Fix the global numpy seed so the Poisson noise drawn below is reproducible.
np.random.seed(0)

# %% data
# Modified Shepp-Logan phantom on a 300x300 float32 grid over [-20, 20]^2;
# it is used directly as the ground truth.
reco_space = odl.uniform_discr(min_pt=[-20, -20],
                               max_pt=[20, 20],
                               shape=[300, 300],
                               dtype='float32')
phantom = odl.phantom.shepp_logan(reco_space, modified=True)
ground_truth = phantom

# NOTE(review): the positional args presumably are src_radius=40,
# det_radius=40, num_angles=360 -- confirm against ODL `cone_beam_geometry`.
geometry = odl.tomo.cone_beam_geometry(reco_space, 40, 40, 360)
ray_trafo = odl.tomo.RayTransform(reco_space, geometry, impl='astra_cpu')
proj_data = ray_trafo(phantom)
# Add Poisson(0.3) samples to the projection data and convert the noisy
# result to an array via `.asarray()`.
observation = (proj_data + np.random.poisson(0.3, proj_data.shape)).asarray()

test_data = DataPairs(observation, ground_truth, name='shepp-logan + pois')

# %% task table and reconstructors
eval_tt = TaskTable()

# Filtered back-projection with a Hann filter at 0.8 frequency scaling.
fbp_reconstructor = FBPReconstructor(ray_trafo,
                                     hyper_params={
                                         'filter_type': 'Hann',
                                         'frequency_scaling': 0.8
                                     })
# Iterative reconstructors: second argument is the starting element
# (zeros, or 0.5 everywhere for MLEM). NOTE(review): the trailing integer is
# presumably the iteration count -- confirm against each constructor.
cg_reconstructor = CGReconstructor(ray_trafo, reco_space.zero(), 4)
gn_reconstructor = GaussNewtonReconstructor(ray_trafo, reco_space.zero(), 2)
lw_reconstructor = LandweberReconstructor(ray_trafo, reco_space.zero(), 8)
mlem_reconstructor = MLEMReconstructor(ray_trafo, 0.5 * reco_space.one(), 1)

reconstructors = [
Ejemplo n.º 8
0
    def get_data_pairs_per_index(self, part='train', index=None):
        """
        Return specific samples from data part as :class:`.DataPairs` object.

        Only supports datasets with two elements per sample.

        For datasets not supporting random access, samples are extracted from
        :meth:`generator`, which can be computationally expensive.

        Parameters
        ----------
        part : {``'train'``, ``'validation'``, ``'test'``}, optional
            The data part. Default is ``'train'``.
        index : int or list of int, optional
            Indices of the samples in the data part. Default is ``[0]``.
            Indices may repeat; the returned pairs follow the order of
            `index`. numpy integer scalars are accepted as well.

        Returns
        -------
        data_pairs : :class:`.DataPairs`
            The requested ``(observation, ground_truth)`` pairs.

        Raises
        ------
        ValueError
            If the dataset does not have exactly two elements per sample, if
            `index` is neither an integer nor a list, or if any index lies
            outside ``[0, self.get_len(part) - 1]``.
        """
        num_elements = self.get_num_elements_per_sample()
        if num_elements != 2:
            raise ValueError(
                '`get_data_pairs_per_index` only supports datasets with '
                '2 elements per sample, this dataset has {:d}'.format(
                    num_elements))
        if index is None:
            index = [0]
        if isinstance(index, np.integer):
            # treat numpy integer scalars like Python ints
            index = int(index)
        if isinstance(index, int):
            index = [index]
        elif not isinstance(index, list):
            raise ValueError('`index` must be an integer or a list of '
                             'integer elements')

        name = '{} part: index{}'.format(part, index)

        if len(index) == 0:
            return DataPairs([], [], name=name)

        last_index = self.get_len(part) - 1
        if not (min(index) >= 0 and max(index) <= last_index):
            raise ValueError('index out of bounds. All indices must be '
                             'between 0 and {} (inclusively).'.format(
                                 last_index))

        if self.supports_random_access():
            observations, ground_truth = [], []
            for current_index in index:
                obs, gt = self.get_sample(current_index, part=part)
                observations.append(obs)
                ground_truth.append(gt)
        else:
            # Map each requested sample index to the output positions that
            # need it (an index may be requested several times), then make a
            # single pass over the generator, stopping right after the last
            # requested sample.
            positions_per_index = {}
            for pos, idx in enumerate(index):
                positions_per_index.setdefault(idx, []).append(pos)
            observations = [None] * len(index)
            ground_truth = [None] * len(index)
            remaining = len(positions_per_index)
            for i, (obs, gt) in enumerate(self.generator(part=part)):
                positions = positions_per_index.get(i)
                if positions is None:
                    continue
                for pos in positions:
                    observations[pos] = obs
                    ground_truth[pos] = gt
                remaining -= 1
                if remaining == 0:
                    break

        return DataPairs(observations, ground_truth, name=name)
Ejemplo n.º 9
0
from dival.util.odl_utility import uniform_discr_element
from dival.data import DataPairs
from dival.evaluation import TaskTable
from dival.measure import L2
from dival import Reconstructor
import numpy as np

# Fix the global numpy seed so the Gaussian noise added below is reproducible.
np.random.seed(1)

# Ground truth is the sequence 0..6; the observation is the ground truth
# shifted by +1 with additive noise from `np.random.normal` (default
# loc=0, scale=1).
ground_truth = uniform_discr_element([0, 1, 2, 3, 4, 5, 6])
observation = ground_truth + 1
observation += np.random.normal(size=observation.shape)
test_data = DataPairs(observation, ground_truth, name='x + 1 + normal')
eval_tt = TaskTable()


class MinusOneReconstructor(Reconstructor):
    """Toy reconstructor that undoes the ``+ 1`` offset by subtracting 1."""

    def reconstruct(self, observation):
        shifted = observation - 1
        return shifted


# Evaluate the toy reconstructor on `test_data` using the L2 measure.
reconstructor = MinusOneReconstructor(name='y-1')
eval_tt.append(reconstructor=reconstructor, test_data=test_data,
               measures=[L2])
results = eval_tt.run()
# NOTE(review): the argument 0 is presumably the task index -- confirm.
results.plot_reconstruction(0)
print(results)