Code Example #1
File: test_evaluation.py Project: gengmufeng/dival
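Note: the test snippets in this listing are methods excerpted from dival's test suite, so the module-level imports are not shown. The block below sketches the imports they appear to rely on; the module paths for DataPairs, StandardIterativeReconstructor and CallbackStore are inferred from the dival/odl package layout and may differ from the original test file.

# imports assumed by the excerpted test methods (sketch; some module paths
# are inferred and may differ from the original test file)
import json
from io import StringIO
from unittest.mock import patch

import numpy as np
import odl
from odl.solvers import CallbackStore

from dival.data import DataPairs
from dival.datasets.standard import get_standard_dataset
from dival.evaluation import TaskTable
from dival.measure import PSNR, SSIM
from dival.reconstructors.odl_reconstructors import FBPReconstructor
from dival.reconstructors.reconstructor import StandardIterativeReconstructor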
    def test_option_save_iterates(self):
        domain = odl.uniform_discr([0, 0], [1, 1], (1, 1))
        ground_truth = domain.one()
        observation = domain.one()

        # reconstruct 1., iterates 0., 0.5, 0.75, 0.875, ...

        class DummyReconstructor(StandardIterativeReconstructor):
            def _setup(self, observation):
                self.setup_var = 'dummy_val'

            def _compute_iterate(self, observation, reco_previous, out):
                out[:] = 0.5 * (observation + reco_previous)

        test_data = DataPairs([observation], [ground_truth])
        tt = TaskTable()
        r = DummyReconstructor(reco_space=domain)
        hyper_param_choices = {'iterations': [10]}
        options = {'save_iterates': True}
        tt.append(r, test_data, hyper_param_choices=hyper_param_choices,
                  options=options)

        results = tt.run()
        self.assertAlmostEqual(
            1., results.results['misc'][0, 0]['iterates'][0][2][0, 0],
            delta=0.2)
        self.assertNotAlmostEqual(
            1., results.results['misc'][0, 0]['iterates'][0][1][0, 0],
            delta=0.2)
Code Example #2
File: test_evaluation.py Project: magicknight/dival
    def test_option_save_best_reconstructor(self):
        dataset = get_standard_dataset('ellipses', impl='skimage')
        test_data = dataset.get_data_pairs('validation', 1)
        tt = TaskTable()
        fbp_reconstructor = FBPReconstructor(dataset.ray_trafo)
        hyper_param_choices = {'filter_type': ['Ram-Lak', 'Hann'],
                               'frequency_scaling': [0.1, 0.5, 1.0]}
        known_best_choice = {'filter_type': 'Hann',
                             'frequency_scaling': 0.5}
        path = 'dummypath'
        options = {'save_best_reconstructor': {'path': path,
                                               'measure': PSNR}}
        tt.append(fbp_reconstructor, test_data, measures=[PSNR],
                  hyper_param_choices=hyper_param_choices,
                  options=options)

        class ExtStringIO(StringIO):
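            # helper that mirrors the buffer contents into the shared `ext`
            # dict (keyed by file name) on creation and on close, so files
            # "saved" via the patched `open` below can be inspected in memory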
            def __init__(self, ext, f, *args, **kwargs):
                self.ext = ext
                self.f = f
                super().__init__(*args, **kwargs)
                self.ext[self.f] = self.getvalue()

            def close(self):
                self.ext[self.f] = self.getvalue()
                super().close()
        ext = {}
        with patch('dival.reconstructors.reconstructor.open',
                   lambda f, *a, **kw: ExtStringIO(ext, f)):
            tt.run()
        self.assertIn(path + '_hyper_params.json', ext)
        self.assertDictEqual(json.loads(ext[path + '_hyper_params.json']),
                             known_best_choice)
Code Example #3
    def test_option_save_best_reconstructor_reuse_iterates(self):
        # test 'save_best_reconstructor' option together with 'iterations' in
        # hyper_param_choices, because `run` has performance optimization for
        # it (with the default argument ``reuse_iterates=True``)
        domain = odl.uniform_discr([0, 0], [1, 1], (2, 2))
        ground_truth = domain.element([[1, 0], [0, 1]])
        observation = domain.element([[0, 0], [0, 0]])

        # Reconstruct [[1, 0], [0, 1]], iterates are
        # [[0, 0], [0, 0]], [[.1, .1], [.1, .1]], [[.2, .2], [.2, .2]], ...
        # Best will be [[.5, .5], [.5, .5]].

        class DummyReconstructor(StandardIterativeReconstructor):
            def _setup(self, observation):
                self.setup_var = 'dummy_val'

            def _compute_iterate(self, observation, reco_previous, out):
                out[:] = reco_previous + 0.1

        test_data = DataPairs([observation], [ground_truth])
        tt = TaskTable()
        r = DummyReconstructor(reco_space=domain)
        hyper_param_choices = {'iterations': list(range(10))}
        known_best_choice = {'iterations': 5}
        path = 'dummypath'
        options = {'save_best_reconstructor': {'path': path, 'measure': PSNR}}
        tt.append(r,
                  test_data,
                  measures=[PSNR],
                  hyper_param_choices=hyper_param_choices,
                  options=options)

        class ExtStringIO(StringIO):
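            # helper that mirrors the buffer contents into the shared `ext`
            # dict (keyed by file name) on creation and on close, so files
            # "saved" via the patched `open` below can be inspected in memory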
            def __init__(self, ext, f, *args, **kwargs):
                self.ext = ext
                self.f = f
                super().__init__(*args, **kwargs)
                self.ext[self.f] = self.getvalue()

            def close(self):
                self.ext[self.f] = self.getvalue()
                super().close()

        ext = {}
        with patch('dival.reconstructors.reconstructor.open',
                   lambda f, *a, **kw: ExtStringIO(ext, f)):
            tt.run()
        self.assertIn(path + '_hyper_params.json', ext)
        self.assertDictEqual(json.loads(ext[path + '_hyper_params.json']),
                             known_best_choice)
Code Example #4
    def test_iterations_hyper_param_choices(self):
        # test 'iterations' in hyper_param_choices, because `run` has
        # performance optimization for it
        domain = odl.uniform_discr([0, 0], [1, 1], (1, 1))
        ground_truth = domain.one()
        observation = domain.one()

        # reconstruct 1., iterates 0., 0.5, 0.75, 0.875, ...

        class DummyReconstructor(StandardIterativeReconstructor):
            def _setup(self, observation):
                self.setup_var = 'dummy_val'

            def _compute_iterate(self, observation, reco_previous, out):
                out[:] = 0.5 * (observation + reco_previous)

        test_data = DataPairs([observation], [ground_truth])
        tt = TaskTable()
        r = DummyReconstructor(reco_space=domain)
        hyper_param_choices = {'iterations': [2, 3, 10]}
        tt.append(r, test_data, hyper_param_choices=hyper_param_choices)

        iters = []
        r.callback = CallbackStore(iters)
        results = tt.run(reuse_iterates=True)
        self.assertAlmostEqual(
            1., results.results['reconstructions'][0, 1][0][0, 0],
            delta=0.2)
        self.assertNotAlmostEqual(
            1., results.results['reconstructions'][0, 0][0][0, 0],
            delta=0.2)
        self.assertEqual(len(iters), max(hyper_param_choices['iterations']))
        print(results.results['misc'])

        iters2 = []
        r.callback = CallbackStore(iters2)
        results2 = tt.run(reuse_iterates=False)
        self.assertAlmostEqual(
            1., results2.results['reconstructions'][0, 1][0][0, 0],
            delta=0.2)
        self.assertNotAlmostEqual(
            1., results2.results['reconstructions'][0, 0][0][0, 0],
            delta=0.2)
        self.assertEqual(len(iters2), sum(hyper_param_choices['iterations']))
Code Example #5
File: test_evaluation.py Project: magicknight/dival
    def test(self):
        reco_space = odl.uniform_discr(
            min_pt=[-64, -64], max_pt=[64, 64], shape=[128, 128])
        phantom = odl.phantom.shepp_logan(reco_space, modified=True)
        geometry = odl.tomo.parallel_beam_geometry(reco_space, 30)
        ray_trafo = odl.tomo.RayTransform(reco_space, geometry, impl='skimage')
        proj_data = ray_trafo(phantom)
        observation = np.asarray(
            proj_data +
            np.random.normal(loc=0., scale=2., size=proj_data.shape))
        test_data = DataPairs(observation, phantom)
        tt = TaskTable()
        fbp_reconstructor = FBPReconstructor(ray_trafo, hyper_params={
            'filter_type': 'Hann',
            'frequency_scaling': 0.8})
        tt.append(fbp_reconstructor, test_data, measures=[PSNR, SSIM])
        tt.run()
        self.assertGreater(
            tt.results.results['measure_values'][0, 0]['psnr'][0], 15.)
Code Example #6
File: ct_example.py Project: pscoutosoares/dival
reco_space = odl.uniform_discr(min_pt=[-20, -20],
                               max_pt=[20, 20],
                               shape=[300, 300],
                               dtype='float32')
phantom = odl.phantom.shepp_logan(reco_space, modified=True)
ground_truth = phantom

geometry = odl.tomo.cone_beam_geometry(reco_space, 40, 40, 360)
ray_trafo = odl.tomo.RayTransform(reco_space, geometry, impl='astra_cpu')
proj_data = ray_trafo(phantom)
observation = (proj_data + np.random.poisson(0.3, proj_data.shape)).asarray()

test_data = DataPairs(observation, ground_truth, name='shepp-logan + pois')

# %% task table and reconstructors
eval_tt = TaskTable()

fbp_reconstructor = FBPReconstructor(ray_trafo,
                                     hyper_params={
                                         'filter_type': 'Hann',
                                         'frequency_scaling': 0.8
                                     })
cg_reconstructor = CGReconstructor(ray_trafo, reco_space.zero(), 4)
gn_reconstructor = GaussNewtonReconstructor(ray_trafo, reco_space.zero(), 2)
lw_reconstructor = LandweberReconstructor(ray_trafo, reco_space.zero(), 8)
mlem_reconstructor = MLEMReconstructor(ray_trafo, 0.5 * reco_space.one(), 1)

reconstructors = [
    fbp_reconstructor, cg_reconstructor, gn_reconstructor, lw_reconstructor,
    mlem_reconstructor
]
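The excerpt ends here. Judging by the append/run pattern in the other examples, the script presumably goes on to register each reconstructor with the task table and run the evaluation; the following is a sketch of such a continuation, not the original file (the measure list is an assumption):

# sketch of a likely continuation (not part of the excerpt above)
for reconstructor in reconstructors:
    eval_tt.append(reconstructor, test_data, measures=[PSNR, SSIM])

results = eval_tt.run()
print(results)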
Code Example #7
import numpy as np
from dival.evaluation import TaskTable
from dival.measure import PSNR, SSIM
from dival.reconstructors.odl_reconstructors import FBPReconstructor
from dival.datasets.standard import get_standard_dataset

np.random.seed(0)

# %% data
dataset = get_standard_dataset('ellipses', impl='astra_cpu')
test_data = dataset.get_data_pairs('test', 10)

# %% task table and reconstructors
eval_tt = TaskTable()

reconstructor = FBPReconstructor(dataset.ray_trafo)

eval_tt.append(reconstructor=reconstructor, measures=[PSNR, SSIM],
               test_data=test_data,
               hyper_param_choices={'filter_type': ['Ram-Lak', 'Hann'],
                                    'frequency_scaling': [0.8, 0.9, 1.]})

# %% run task table
results = eval_tt.run()
print(results.to_string(show_columns=['misc']))

# %% plot reconstructions
fig = results.plot_all_reconstructions(test_ind=range(1),
                                       fig_size=(9, 4), vrange='individual')
Code Example #8
os.environ['CUDA_VISIBLE_DEVICES'] = '0'  # TODO adjust

# %% data
dataset = get_standard_dataset('lodopab')
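# precomputed FBP reconstructions stored as .npy files (the paths below are
# specific to the original setup)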
dataset.fbp_dataset = CachedFBPDataset(
    dataset, {
        'train': '/localdata/jleuschn/lodopab/reco_fbps_train.npy',
        'validation': '/localdata/jleuschn/lodopab/reco_fbps_validation.npy',
        'test': '/localdata/jleuschn/lodopab/reco_fbps_test.npy'
    })
ray_trafo = dataset.get_ray_trafo(impl='astra_cpu')
reco_space = ray_trafo.domain
test_data = dataset.get_data_pairs('test', 7)

# %% task table and reconstructors
eval_tt = TaskTable()

fbp_reconstructor = FBPReconstructor(ray_trafo,
                                     hyper_params={
                                         'filter_type': 'Hann',
                                         'frequency_scaling': 0.8
                                     })

fbp_unet_reconstructor = FBPUNetReconstructor(ray_trafo,
                                              batch_size=64,
                                              use_cuda=True)
state_filename = 'fbp_unet_reconstructor_lodopab_baseline_state.pt'
with open(state_filename, 'wb') as file:
    r = requests.get('https://github.com/jleuschn/supp.dival/raw/master/'
                     'examples/'
                     'fbp_unet_reconstructor_lodopab_baseline_state.pt')
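The excerpt breaks off inside the with block; presumably the downloaded bytes are then written to state_filename, for example:

    # assumed continuation (the excerpt is cut off): store the downloaded
    # network parameters in the local state file
    file.write(r.content)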
Code Example #9
#                       dataset=dataset,
#                       HYPER_PARAMS_override={
#                           'l2_regularization': {
#                               'method': 'grid_search',
#                               'range': [0., 2.],
#                               'grid_search_options': {
#                                   'type': 'logarithmic',
#                                   'num_samples': 10
#                               }
#                           }})
# =============================================================================

print('optimized l2 reg. coeff.: {}'.format(
    reconstructor.hyper_params['l2_regularization']))

# %% task table
eval_tt = TaskTable()
eval_tt.append(reconstructor=reconstructor,
               test_data=test_data,
               dataset=dataset,
               measures=[L2, PSNR])

results = eval_tt.run()
print(results)

# %% plot reconstructions
fig = results.plot_reconstruction(0,
                                  test_ind=range(3),
                                  fig_size=(9, 4),
                                  vrange='individual')
Code Example #10
import numpy as np
from dival.evaluation import TaskTable
from dival.measure import L2
from dival.reconstructors.odl_reconstructors import FBPReconstructor
from dival.datasets.standard import get_standard_dataset

np.random.seed(0)

# %% data
dataset = get_standard_dataset('ellipses')
validation_data = dataset.get_data_pairs('validation', 10)
test_data = dataset.get_data_pairs('test', 10)

# %% task table and reconstructors
eval_tt = TaskTable()

reconstructor = FBPReconstructor(dataset.ray_trafo)
options = {
    'hyper_param_search': {
        'measure': L2,
        'validation_data': validation_data
    }
}

eval_tt.append(reconstructor=reconstructor,
               test_data=test_data,
               options=options)

# %% run task table
results = eval_tt.run()
print(results.to_string(formatters={'reconstructor': lambda r: r.name}))