Exemple #1
0
def main():
    """Train a ``LearnedPDReconstructor`` on a random 10% subset of LoDoPaB.

    Parses command line options (``seed``, ``log_dir``), builds the dataset
    and reconstructor, runs a single-row Dival task table and saves the
    results report.
    """
    # load data
    options = get_parser().parse_args()

    dataset = get_standard_dataset('lodopab')
    test_data = dataset.get_data_pairs('validation')
    ray_trafo = dataset.ray_trafo

    # train on a random 10% sample only, reproducible via the seed option
    reduced_dataset = RandomSampleDataset(dataset,
                                          size_part=0.1,
                                          seed=options.seed)

    reconstructor = LearnedPDReconstructor(ray_trafo=ray_trafo, num_workers=8)
    reconstructor.load_hyper_params('params')

    reconstructor.save_best_learned_params_path = 'best-model-{}'.format(
        options.seed)
    reconstructor.log_dir = options.log_dir

    # create a Dival task table and run it
    task_table = TaskTable()
    task_table.append(reconstructor=reconstructor,
                      measures=[PSNR, SSIM],
                      test_data=test_data,
                      dataset=reduced_dataset,
                      hyper_param_choices=[reconstructor.hyper_params])
    results = task_table.run()

    # save report
    # fix: `full_name` was previously undefined and raised NameError here;
    # derive a per-seed name analogous to the model checkpoint path above
    full_name = 'results-{}'.format(options.seed)
    save_results_table(results, full_name)
Exemple #2
0
    def test_option_save_best_reconstructor(self):
        """Running with the 'save_best_reconstructor' option must write the
        hyper parameters of the best choice to '<path>_hyper_params.json'.
        """
        dataset = get_standard_dataset('ellipses', impl='skimage')
        test_data = dataset.get_data_pairs('validation', 1)
        task_table = TaskTable()
        reconstructor = FBPReconstructor(dataset.ray_trafo)
        path = 'dummypath'
        task_table.append(
            reconstructor, test_data, measures=[PSNR],
            hyper_param_choices={'filter_type': ['Ram-Lak', 'Hann'],
                                 'frequency_scaling': [0.1, 0.5, 1.0]},
            options={'save_best_reconstructor': {'path': path,
                                                 'measure': PSNR}})

        written = {}

        class RecordingStringIO(StringIO):
            # in-memory stand-in for an opened file that mirrors its
            # contents into `store` under the opened file name
            def __init__(self, store, name, *args, **kwargs):
                self.store = store
                self.name_ = name
                super().__init__(*args, **kwargs)
                self.store[self.name_] = self.getvalue()

            def close(self):
                self.store[self.name_] = self.getvalue()
                super().close()

        # intercept file writes performed by the reconstructor module
        with patch('dival.reconstructors.reconstructor.open',
                   lambda f, *a, **kw: RecordingStringIO(written, f)):
            task_table.run()

        hyper_params_file = path + '_hyper_params.json'
        self.assertIn(hyper_params_file, written)
        self.assertDictEqual(json.loads(written[hyper_params_file]),
                             {'filter_type': 'Hann',
                              'frequency_scaling': 0.5})
Exemple #3
0
 def testGradient(self):
     """Check gradients of ``TorchRayTrafoParallel2DModule`` against ODL.

     For a selection of sinogram positions (j, k), the gradient of the
     module output w.r.t. its input must match the ODL adjoint applied to
     the corresponding one-hot sinogram.
     """
     dataset = get_standard_dataset('ellipses',
                                    fixed_seeds=True,
                                    impl='astra_cuda')
     ray_trafo = dataset.get_ray_trafo(impl='astra_cuda')
     module = TorchRayTrafoParallel2DModule(ray_trafo)
     batch_size, channels = 2, 3
     # flat input with gradient tracking; reshaped views share its storage
     torch_in_ = torch.ones((batch_size * channels, ) + dataset.shape[1],
                            requires_grad=True)
     torch_in = torch_in_.view(batch_size, channels, *dataset.shape[1])
     torch_out = module(torch_in).view(-1, *dataset.shape[0])
     for i in range(batch_size * channels):
         for j in range(0, dataset.shape[0][0], 3):  # angle
             for k in range(0, dataset.shape[0][1], 12):  # detector pos
                 # reference gradient: adjoint of a one-hot sinogram
                 odl_value_in = np.zeros(dataset.shape[0])
                 odl_value_in[j, k] = 1.
                 odl_grad = ray_trafo.adjoint(odl_value_in)
                 # reset accumulated gradients before each backward pass
                 torch_in_.grad = None
                 torch_out[i, j, k].backward(retain_graph=True)
                 torch_grad_np = (torch_in_.grad[i].detach().cpu().numpy())
                 # very rough check for maximum error
                 self.assertTrue(
                     np.allclose(torch_grad_np, odl_grad, rtol=1.))
                 non_zero = np.nonzero(np.asarray(odl_grad))
                 if np.any(odl_grad):  # there seem to be cases where
                     # a pixel has no influence on the
                     # gradient
                     # tighter check for mean error
                     self.assertLess(
                         np.mean(
                             np.abs(torch_grad_np[non_zero] -
                                    odl_grad[non_zero]) /
                             np.abs(odl_grad[non_zero])), 1e-3)
def get_dataloaders_ct(batch_size=1, distributed_bool=False, num_workers=0,
                       IMPL='astra_cuda',
                       cache_dir=path.join(dataset_dir, 'cache_lodopab'),
                       include_validation=True, **kwargs):
    """Create PyTorch dataloaders for the cached-FBP LoDoPaB dataset.

    Parameters
    ----------
    batch_size : int
        Batch size for the 'train' part; 'validation' and 'test' use 1.
    distributed_bool : bool
        If ``True``, use a ``DistributedSampler`` instead of shuffling.
    num_workers : int
        Number of data loading worker processes.
    IMPL : str
        Backend for the ray transform (e.g. 'astra_cuda').
    cache_dir : str
        Directory containing the precomputed FBP cache files.
    include_validation : bool
        Whether to also build a 'validation' dataloader.
    kwargs : dict
        Ignored; accepted for interface compatibility.

    Returns
    -------
    dict
        Mapping from part name ('train', 'validation', 'test') to
        ``torch.utils.data.DataLoader``.
    """
    parts = (['train', 'validation', 'test'] if include_validation
             else ['train', 'test'])
    # only training uses the requested batch size
    batch_sizes = {part: batch_size if part == 'train' else 1
                   for part in parts}

    # (observation cache, ground truth cache); None: no ground truth cache
    cache_files = {
        part: (path.join(cache_dir, 'cache_lodopab_' + part + '_fbp.npy'),
               None)
        for part in parts}

    standard_dataset = get_standard_dataset('lodopab', impl=IMPL)
    ray_trafo = standard_dataset.get_ray_trafo(impl=IMPL)
    dataset = get_cached_fbp_dataset(standard_dataset, ray_trafo, cache_files)

    # create PyTorch datasets with a leading channel dimension
    datasets = {part: RandomAccessTorchDataset(
                    dataset=dataset, part=part,
                    reshape=((1,) + dataset.space[0].shape,
                             (1,) + dataset.space[1].shape))
                for part in parts}

    if distributed_bool:  # fix: was the non-idiomatic `== True`
        # DistributedSampler handles shuffling/sharding across workers
        dataloaders = {part: DataLoader(
                           datasets[part], batch_size=batch_sizes[part],
                           num_workers=num_workers,
                           worker_init_fn=worker_init_fn, pin_memory=True,
                           sampler=DistributedSampler(datasets[part]))
                       for part in parts}
    else:
        dataloaders = {part: DataLoader(
                           datasets[part], batch_size=batch_sizes[part],
                           shuffle=(part == 'train'),
                           worker_init_fn=worker_init_fn, pin_memory=True,
                           num_workers=num_workers)
                       for part in parts}

    return dataloaders
def construct_reconstructor(reconstructor_key_name_or_type, dataset_name,
                            **kwargs):
    """
    Construct reference reconstructor object (not loading parameters).

    Note: see :func:`get_reference_reconstructor` to retrieve a reference
    reconstructor with optimized parameters.

    This function implements the constructor calls which are potentially
    specific to each configuration.

    Parameters
    ----------
    reconstructor_key_name_or_type : str or type
        Key name of configuration or reconstructor type.
    dataset_name : str
        Standard dataset name.
    kwargs : dict
        Keyword arguments.
        For CT configurations this includes the ``'impl'`` used by
        :class:`odl.tomo.RayTransform`.

    Raises
    ------
    ValueError
        If the configuration does not exist.
    NotImplementedError
        If construction is not implemented for the configuration.

    Returns
    -------
    reconstructor : :class:`Reconstructor`
        The reconstructor instance.
    """
    r_key_name, r_type = validate_reconstructor_key_name_or_type(
        reconstructor_key_name_or_type, dataset_name)

    # guard: only the CT datasets are supported
    if dataset_name not in ('ellipses', 'lodopab'):
        raise NotImplementedError(
            'reference reconstructor construction is not implemented for '
            'dataset \'{}\''.format(dataset_name))

    impl = kwargs.pop('impl', 'astra_cuda')
    dataset = get_standard_dataset(dataset_name, impl=impl)

    # all currently supported configurations take the ray transform as
    # their single positional argument
    if r_key_name not in ('fbp', 'fbpunet', 'iradonmap', 'learnedgd',
                          'learnedpd', 'tvadam', 'diptv'):
        raise NotImplementedError(
            'reconstructor construction is not implemented for reference '
            'configuration \'{}\' for dataset \'{}\''
            .format(r_key_name, dataset_name))

    ray_trafo = dataset.get_ray_trafo(impl=impl)
    return r_type(ray_trafo,
                  name='{d}_{r}'.format(r=r_key_name, d=dataset_name))
Exemple #6
0
 def test_generator(self):
     """Data pairs yielded by an ``AngleSubsetDataset`` must equal the
     parent dataset's observations restricted to the selected angles."""
     full_dataset = get_standard_dataset('ellipses', fixed_seeds=True,
                                         impl='skimage')
     indices = range(0, full_dataset.shape[0][0], 2)
     subset_dataset = AngleSubsetDataset(full_dataset, indices)
     pairs_subset = subset_dataset.get_data_pairs('train', 3)
     pairs_full = full_dataset.get_data_pairs('train', 3)
     index_array = np.asarray(indices)
     for (obs_sub, gt_sub), (obs_full, gt_full) in zip(pairs_subset,
                                                       pairs_full):
         expected_obs = np.asarray(obs_full)[index_array, :]
         self.assertEqual(obs_sub.shape, expected_obs.shape)
         self.assertEqual(gt_sub.shape, gt_full.shape)
         self.assertTrue(np.all(np.asarray(obs_sub) == expected_obs))
         self.assertTrue(np.all(np.asarray(gt_sub) == np.asarray(gt_full)))
Exemple #7
0
 def test_get_ray_trafo(self):
     """The ray transform of an ``AngleSubsetDataset`` must use exactly
     the selected subset of the parent geometry's angles."""
     full_dataset = get_standard_dataset('ellipses', fixed_seeds=True,
                                         impl='skimage')
     num_angles = full_dataset.shape[0][0]
     # several subset patterns: strided, first half, second half, and the
     # union of the first and last quarter
     index_choices = (
         range(0, num_angles, 2),
         range(0, num_angles // 2),
         range(num_angles // 2, num_angles),
         np.concatenate([np.arange(0, int(num_angles * 1/4)),
                         np.arange(int(num_angles * 3/4), num_angles)]))
     for indices in index_choices:
         subset_dataset = AngleSubsetDataset(full_dataset, indices)
         subset_ray_trafo = subset_dataset.get_ray_trafo(impl='skimage')
         self.assertEqual(subset_ray_trafo.range.shape[0], len(indices))
         expected_angles = full_dataset.get_ray_trafo(
             impl='skimage').geometry.angles[np.asarray(indices)]
         self.assertEqual(subset_ray_trafo.geometry.angles.shape,
                          expected_angles.shape)
         self.assertTrue(
             np.all(subset_ray_trafo.geometry.angles == expected_angles))
Exemple #8
0
    def prepare_data(self, *args, **kwargs):
        """Load, normalize and mask the LoDoPaB ground truth images.

        Loads the train/validation/test ground truth parts, normalizes them
        with statistics computed on the training part, zeroes everything
        outside the reconstruction circle and builds the projection dataset.

        Side effects: sets ``self.mean``, ``self.std`` and ``self.gt_ds``.
        """
        lodopab = dival.get_standard_dataset('lodopab', impl='astra_cpu')
        assert self.gt_shape <= self.IMG_SHAPE, 'GT is larger than original images.'

        # fixed per-part sample counts used by this module
        gt_train = self._load_gt_part(lodopab, 'train', 4000)
        gt_val = self._load_gt_part(lodopab, 'validation', 400)
        gt_test = self._load_gt_part(lodopab, 'test', 3553)

        gt_train = torch.from_numpy(gt_train)
        gt_val = torch.from_numpy(gt_val)
        gt_test = torch.from_numpy(gt_test)

        assert gt_train.shape[1] == self.gt_shape
        assert gt_train.shape[2] == self.gt_shape

        # normalization statistics come from the training part only
        self.mean = gt_train.mean()
        self.std = gt_train.std()

        gt_train = normalize(gt_train, self.mean, self.std)
        gt_val = normalize(gt_val, self.mean, self.std)
        gt_test = normalize(gt_test, self.mean, self.std)

        # zero out everything outside the reconstruction circle
        circle = self.__get_circle__()
        gt_train *= circle
        gt_val *= circle
        gt_test *= circle

        ds_factory = GroundTruthDatasetFactory(gt_train, gt_val, gt_test,
                                               inner_circle=self.inner_circle)
        self.gt_ds = ds_factory.build_projection_dataset(
            num_angles=self.num_angles,
            upscale_shape=self.gt_shape + (self.gt_shape // 2 - 7),
            impl='astra_cpu')

    def _load_gt_part(self, lodopab, part, num_samples):
        """Return ``num_samples`` ground truth images of ``part`` as a numpy
        array, cropped to ``(gt_shape, gt_shape)``."""
        if self.gt_shape < self.IMG_SHAPE:
            # center crop; 362 is the raw LoDoPaB ground truth size
            # (NOTE(review): this implies IMG_SHAPE is the raw size minus
            # one, presumably 361 — confirm against the class constants)
            crop_off = (362 - self.gt_shape) // 2
            crop = slice(crop_off, -(crop_off + 1))
        else:
            # gt_shape == IMG_SHAPE: drop the first row and column
            crop = slice(1, None)
        return np.array(
            [lodopab.get_sample(i, part=part, out=(False, True))[1][crop, crop]
             for i in range(num_samples)])
Exemple #9
0
 def test(self):
     """Check ``TorchRayTrafoParallel2DAdjointModule`` against ODL.

     For several (batch_size, channels) combinations the module applied to
     a batch of observations must match ``ray_trafo.adjoint`` applied to
     each observation individually.
     """
     dataset = get_standard_dataset('ellipses',
                                    fixed_seeds=True,
                                    impl='astra_cuda')
     ray_trafo = dataset.get_ray_trafo(impl='astra_cuda')
     module = TorchRayTrafoParallel2DAdjointModule(ray_trafo)
     for batch_size, channels in [(1, 3), (5, 1), (2, 3)]:
         test_data = dataset.get_data_pairs(part='train',
                                            n=batch_size * channels)
         # stack observations into a (batch, channels, *obs_shape) tensor
         torch_in = (torch.from_numpy(np.asarray(
             test_data.observations)).view(batch_size, channels,
                                           *dataset.shape[0]))
         # flatten batch and channel dims to compare sample-by-sample
         torch_out = module(torch_in).view(-1, *dataset.shape[1])
         for i, odl_in in enumerate(test_data.observations):
             odl_out = ray_trafo.adjoint(odl_in)
             self.assertTrue(
                 np.allclose(torch_out[i].detach().cpu().numpy(),
                             odl_out,
                             rtol=1.e-4))
def load_standard_dataset(dataset, impl=None, ordered=False):
    """
    Loads a Dival standard dataset.

    :param dataset: Name of the standard dataset. Besides the plain Dival
        names (e.g. 'ellipses', 'lodopab'), the variant 'lodopab_200'
        selects the lodopab dataset restricted to 200 equispaced angles.
    :param impl: Backend for the Ray Transform. If ``None``, 'astra_cuda'
        is used when CUDA is available, otherwise 'astra_cpu'.
    :param ordered: Whether to order by patient id for 'lodopab' dataset.
    :return: Dival dataset.
    """
    # pick the fastest available backend if none was requested
    if impl is None:
        impl = 'astra_cpu'
        if torch.cuda.is_available():
            impl = 'astra_cuda'
    kwargs = {'impl': impl}
    if dataset == 'ellipses':
        kwargs['fixed_seeds'] = True
    # we do not use 'sorted_by_patient' here in order to be transparent to
    # `CachedDataset`, where a `ReorderedDataset` is handled specially
    # if dataset == 'lodopab':
    # kwargs['sorted_by_patient'] = ordered

    # strip a variant suffix such as '_200' to get the Dival dataset name
    dataset_name = dataset.split('_')[0]
    dataset_out = get_standard_dataset(dataset_name, **kwargs)

    if dataset == 'lodopab_200':
        # keep every 5th of the 1000 angles -> 200 angles
        angles = list(range(0, 1000, 5))
        dataset_out = AngleSubsetDataset(dataset_out, angles)

    if dataset_name == 'lodopab' and ordered:
        # reorder samples by patient id; forward the ray transform of the
        # wrapped dataset so callers can still access it on the wrapper
        idx = get_lodopab_idx_sorted_by_patient()
        dataset_ordered = ReorderedDataset(dataset_out, idx)
        dataset_ordered.ray_trafo = dataset_out.ray_trafo
        dataset_ordered.get_ray_trafo = dataset_out.get_ray_trafo
        dataset_out = dataset_ordered

    return dataset_out
Exemple #11
0
                                        get_cached_fbp_dataset)
from dival.reference_reconstructors import (check_for_params, download_params,
                                            get_hyper_params_path)
from dival.util.plot import plot_images

IMPL = 'astra_cuda'  # ray transform backend

LOG_DIR = './logs/lodopab_fbpunet'
SAVE_BEST_LEARNED_PARAMS_PATH = './params/lodopab_fbpunet'

# precomputed FBP cache files per dataset part
# (NOTE(review): the second tuple element is presumably an optional cache
# for the ground truth component — confirm against get_cached_fbp_dataset)
CACHE_FILES = {
    'train': ('./cache_lodopab_train_fbp.npy', None),
    'validation': ('./cache_lodopab_validation_fbp.npy', None)
}

dataset = get_standard_dataset('lodopab', impl=IMPL)
ray_trafo = dataset.get_ray_trafo(impl=IMPL)
test_data = dataset.get_data_pairs('test', 100)

reconstructor = FBPUNetReconstructor(
    ray_trafo,
    log_dir=LOG_DIR,
    save_best_learned_params_path=SAVE_BEST_LEARNED_PARAMS_PATH)

#%% obtain reference hyper parameters
# download the reference (non-learned) hyper parameters on first use
if not check_for_params('fbpunet', 'lodopab', include_learned=False):
    download_params('fbpunet', 'lodopab', include_learned=False)
hyper_params_path = get_hyper_params_path('fbpunet', 'lodopab')
reconstructor.load_hyper_params(hyper_params_path)

#%% expose FBP cache to reconstructor by assigning `fbp_dataset` attribute
import numpy as np
import matplotlib.pyplot as plt
from dival import get_standard_dataset
from dival.evaluation import TaskTable
from dival.measure import PSNR, SSIM
from dival.reference_reconstructors import get_reference_reconstructor

# restrict all CUDA computations to the first GPU
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

IMPL = 'astra_cuda'  # ray transform backend
DATASET = 'ellipses'

# make numpy-based randomness reproducible
np.random.seed(0)

# %% data
dataset = get_standard_dataset(DATASET, impl=IMPL)
ray_trafo = dataset.get_ray_trafo(impl=IMPL)
reco_space = ray_trafo.domain
test_data = dataset.get_data_pairs('test', 100)

# %% task table and reconstructors
eval_tt = TaskTable()

# reference reconstructors with pre-optimized (downloadable) parameters
fbp_reconstructor = get_reference_reconstructor('fbp', DATASET)
tvadam_reconstructor = get_reference_reconstructor('tvadam', DATASET)
fbpunet_reconstructor = get_reference_reconstructor('fbpunet', DATASET)
iradonmap_reconstructor = get_reference_reconstructor('iradonmap', DATASET)
learnedgd_reconstructor = get_reference_reconstructor('learnedgd', DATASET)
learnedpd_reconstructor = get_reference_reconstructor('learnedpd', DATASET)

reconstructors = [
from dival.measure import PSNR

from dival.datasets.fbp_dataset import (generate_fbp_cache_files,
                                        get_cached_fbp_dataset)
from dival.reference_reconstructors import (check_for_params, download_params,
                                            get_hyper_params_path)
from dival.util.plot import plot_images

from self_supervised_ct.n2self import N2SelfReconstructor

IMPL = 'astra_cpu'  # ray transform backend (CPU variant)

LOG_DIR = './logs/ellipses_n2self'
SAVE_BEST_LEARNED_PARAMS_PATH = './params/ellipses_n2self'

dataset = get_standard_dataset('ellipses', impl=IMPL)
ray_trafo = dataset.get_ray_trafo(impl=IMPL)
test_data = dataset.get_data_pairs('test', 100)

reconstructor = N2SelfReconstructor(
    ray_trafo,
    log_dir=LOG_DIR,
    save_best_learned_params_path=SAVE_BEST_LEARNED_PARAMS_PATH)

#%% obtain reference hyper parameters
# NOTE(review): hyper parameters are downloaded/loaded for the 'fbpunet'
# configuration although the reconstructor is N2Self — confirm that this
# reuse of the fbpunet reference hyper parameters is intended.
if not check_for_params('fbpunet', 'ellipses', include_learned=False):
    download_params('fbpunet', 'ellipses', include_learned=False)
hyper_params_path = get_hyper_params_path('fbpunet', 'ellipses')
reconstructor.load_hyper_params(hyper_params_path)
print(reconstructor.HYPER_PARAMS)
# reconstructor.lr = 0.0001
Exemple #14
0
# command line interface
parser = argparse.ArgumentParser(description='Primal Dual stuff')
parser.add_argument('model_type', default='classic')
# NOTE(review): presumably a checkpoint/epoch id to resume from — confirm
parser.add_argument('-restore', type=int, default=None)
# negative base-10 exponent of the learning rate (see `lr` below)
parser.add_argument('-lr', type=int, default=4)
parser.add_argument('--test', action='store_true', default=False)
args = parser.parse_args()

# model and training constants
channels_in = 1
channels_out = 1
batch_size = 3
msd_width = 1
msd_depth = 20
msd_dilations = [1, 2, 3, 4, 5]
lr = 10**(-args.lr)  # e.g. the default 4 yields 1e-4

dataset = get_standard_dataset('lodopab')


class UnsqueezingDataset(torch.utils.data.Dataset):
    """Quickfix for squeezing in extra dim in dataset retrieved from
    dival's `get_standard_dataset`.
    """
    def __init__(self, ds):
        # wrapped dataset yielding (sinogram, reconstruction) pairs
        self._ds = ds

    def __getitem__(self, item):
        """Return the (sinogram, reconstruction) pair at `item`, each with
        a leading channel dimension of size 1 added."""
        sino = torch.unsqueeze(self._ds[item][0], dim=0)
        recon = torch.unsqueeze(self._ds[item][1], dim=0)
        return (sino, recon)

    def __add__(self, other):
Exemple #15
0
    dataset.dataset.dataset.forward_op = OperatorComp(
        odl.tomo.RayTransform(
            ray_trafo.domain, ray_trafo.geometry, impl=ray_trafo.impl),
        dataset.dataset.dataset.forward_op.right)


if __name__ == '__main__':
    try:
        set_start_method('spawn')
    except RuntimeError:
        if get_start_method() != 'spawn':
            raise RuntimeError(
                'Could not set multiprocessing start method to \'spawn\'')

    dataset = get_standard_dataset('ellipses',
                                    fixed_seeds=False,
                                    fixed_noise_seeds=False,
                                    impl=IMPL)
    ray_trafo = dataset.get_ray_trafo(impl=IMPL)
    patch_ray_trafo_for_pickling(ray_trafo)
    test_data = dataset.get_data_pairs('test', 100)
    
    reconstructor = FBPUNetReconstructor(
        ray_trafo, log_dir=LOG_DIR,
        save_best_learned_params_path=SAVE_BEST_LEARNED_PARAMS_PATH,
        allow_multiple_workers_without_random_access=True,
        num_data_loader_workers=2,  # more workers lead to increased VRAM usage
        worker_init_fn=worker_init_fn,  # for recreating forward_op
        )
    
    #%% obtain reference hyper parameters
    if not check_for_params('fbpunet', 'ellipses', include_learned=False):
Exemple #16
0
import numpy as np
from dival import get_standard_dataset

# NOTE: in order to run this test there are some requirements:
# - dival library installed
# - public lodopab dataset downloaded and configured with dival
# - lodopab challenge set observations downloaded (path can be adjusted below)

from lodopab_challenge.challenge_set import (
    config, NUM_ANGLES, NUM_DET_PIXELS, MU_MAX, get_observation,
    get_observations, generator, transform_to_pre_log,
    replace_min_photon_count)

# config['data_path'] = '/localdata/lodopab_challenge_set'

lodopab = get_standard_dataset('lodopab', impl='skimage')

class TestGetObservation(unittest.TestCase):
    def test(self):
        n = 3
        for i, obs2 in zip(range(n), generator()):
            obs = get_observation(i)
            self.assertEqual(obs.shape, (NUM_ANGLES, NUM_DET_PIXELS))
            self.assertEqual(obs.dtype, np.float32)
            self.assertTrue(np.all(obs == obs2))
            obs = np.zeros((NUM_ANGLES, NUM_DET_PIXELS), dtype=np.float32)
            obs_ = get_observation(i, out=obs)
            self.assertIs(obs_, obs)
            self.assertTrue(np.all(obs == obs2))
            obs = lodopab.space[0].zero()
            obs_ = get_observation(i, out=obs)