def main():
    """Run a DiVaL task table for a LearnedPD reconstructor on LoDoPaB.

    Command-line options come from ``get_parser()``: ``seed`` selects the
    ``RandomSampleDataset`` subset (``size_part=0.1``) and names the
    best-parameter file, ``log_dir`` sets the reconstructor's log directory.
    The reconstructor is evaluated with PSNR and SSIM on the validation
    pairs, and the results are written via ``save_results_table``.
    """
    args = get_parser().parse_args()

    # --- data -------------------------------------------------------------
    lodopab = get_standard_dataset('lodopab')
    validation_pairs = lodopab.get_data_pairs('validation')
    operator = lodopab.ray_trafo
    subset = RandomSampleDataset(lodopab, size_part=0.1, seed=args.seed)

    # --- reconstructor ----------------------------------------------------
    learned_pd = LearnedPDReconstructor(ray_trafo=operator, num_workers=8)
    learned_pd.load_hyper_params('params')
    learned_pd.save_best_learned_params_path = 'best-model-{}'.format(
        args.seed)
    learned_pd.log_dir = args.log_dir

    # --- single-configuration DiVaL task table ----------------------------
    table = TaskTable()
    table.append(reconstructor=learned_pd,
                 measures=[PSNR, SSIM],
                 test_data=validation_pairs,
                 dataset=subset,
                 hyper_param_choices=[learned_pd.hyper_params])
    outcome = table.run()

    # --- report -----------------------------------------------------------
    # NOTE(review): `full_name` is not defined in this function; it must be
    # a module-level name defined elsewhere, otherwise this raises NameError.
    save_results_table(outcome, full_name)
def test_option_save_best_reconstructor(self):
    """The ``save_best_reconstructor`` TaskTable option must write the
    hyper parameters of the best choice (by PSNR) to
    ``<path>_hyper_params.json``.

    ``open`` inside ``dival.reconstructors.reconstructor`` is patched to
    return in-memory buffers whose contents are captured in a dict.
    """
    ellipses = get_standard_dataset('ellipses', impl='skimage')
    pairs = ellipses.get_data_pairs('validation', 1)
    table = TaskTable()
    fbp = FBPReconstructor(ellipses.ray_trafo)

    choices = {'filter_type': ['Ram-Lak', 'Hann'],
               'frequency_scaling': [0.1, 0.5, 1.0]}
    expected_best = {'filter_type': 'Hann',
                     'frequency_scaling': 0.5}
    save_path = 'dummypath'
    json_name = save_path + '_hyper_params.json'

    table.append(fbp, pairs, measures=[PSNR],
                 hyper_param_choices=choices,
                 options={'save_best_reconstructor': {'path': save_path,
                                                      'measure': PSNR}})

    class CapturingStringIO(StringIO):
        """StringIO that mirrors its text into ``store[fname]``."""

        def __init__(self, store, fname, *args, **kwargs):
            self._store = store
            self._fname = fname
            super().__init__(*args, **kwargs)
            # register the (still empty) file immediately
            self._store[self._fname] = self.getvalue()

        def close(self):
            # capture the final text before the buffer is released
            self._store[self._fname] = self.getvalue()
            super().close()

    written = {}  # file name -> captured text

    def fake_open(fname, *args, **kwargs):
        return CapturingStringIO(written, fname)

    with patch('dival.reconstructors.reconstructor.open', fake_open):
        table.run()

    self.assertIn(json_name, written)
    self.assertDictEqual(json.loads(written[json_name]), expected_best)
def testGradient(self):
    """Compare autograd gradients of ``TorchRayTrafoParallel2DModule`` with
    the ODL adjoint applied to unit vectors.

    Every 3rd angle and every 12th detector position are checked for each
    of the ``batch_size * channels`` samples.
    """
    dataset = get_standard_dataset('ellipses', fixed_seeds=True,
                                   impl='astra_cuda')
    ray_trafo = dataset.get_ray_trafo(impl='astra_cuda')
    module = TorchRayTrafoParallel2DModule(ray_trafo)
    obs_shape = dataset.shape[0]
    reco_shape = dataset.shape[1]
    n_batch, n_channels = 2, 3
    n_samples = n_batch * n_channels

    # leaf tensor whose .grad is inspected; the module sees a 4-D view of it
    leaf = torch.ones((n_samples, ) + reco_shape, requires_grad=True)
    projections = module(leaf.view(n_batch, n_channels, *reco_shape)).view(
        -1, *obs_shape)

    for sample in range(n_samples):
        for angle in range(0, obs_shape[0], 3):
            for det_pos in range(0, obs_shape[1], 12):
                unit = np.zeros(obs_shape)
                unit[angle, det_pos] = 1.
                expected = ray_trafo.adjoint(unit)

                leaf.grad = None
                # the graph is reused for every entry, so keep it alive
                projections[sample, angle, det_pos].backward(
                    retain_graph=True)
                actual = leaf.grad[sample].detach().cpu().numpy()

                # very rough check for maximum error
                self.assertTrue(np.allclose(actual, expected, rtol=1.))

                # there seem to be cases where a pixel has no influence on
                # the gradient
                if not np.any(expected):
                    continue
                nz = np.nonzero(np.asarray(expected))
                rel_err = (np.abs(actual[nz] - expected[nz])
                           / np.abs(expected[nz]))
                # tighter check for mean error
                self.assertLess(np.mean(rel_err), 1e-3)
def get_dataloaders_ct(batch_size=1, distributed_bool=False, num_workers=0,
                       IMPL='astra_cuda',
                       cache_dir=path.join(dataset_dir, 'cache_lodopab'),
                       include_validation=True, **kwargs):
    """Build PyTorch data loaders for LoDoPaB with cached FBP inputs.

    Parameters
    ----------
    batch_size : int, optional
        Batch size of the ``'train'`` loader. ``'validation'`` and
        ``'test'`` always use batch size 1.
    distributed_bool : bool, optional
        If true, every loader uses a ``DistributedSampler``; otherwise the
        ``shuffle`` flag of ``DataLoader`` is used.
    num_workers : int, optional
        Number of ``DataLoader`` worker processes.
    IMPL : str, optional
        Ray transform backend passed to dival.
    cache_dir : str, optional
        Directory containing the ``cache_lodopab_<part>_fbp.npy`` files.
    include_validation : bool, optional
        Whether to also build a ``'validation'`` loader.
    **kwargs
        Accepted for call compatibility; unused.

    Returns
    -------
    dict
        Maps each part name to its ``DataLoader``.
    """
    if include_validation:
        parts = ['train', 'validation', 'test']
    else:
        parts = ['train', 'test']
    batch_sizes = {part: (batch_size if part == 'train' else 1)
                   for part in parts}
    CACHE_FILES = {part: (path.join(cache_dir,
                                    'cache_lodopab_' + part + '_fbp.npy'),
                          None)
                   for part in parts}

    standard_dataset = get_standard_dataset('lodopab', impl=IMPL)
    ray_trafo = standard_dataset.get_ray_trafo(impl=IMPL)
    dataset = get_cached_fbp_dataset(standard_dataset, ray_trafo, CACHE_FILES)

    # create PyTorch datasets; a leading singleton dim is added to both
    # elements of each pair
    reshape = ((1,) + dataset.space[0].shape, (1,) + dataset.space[1].shape)
    datasets = {part: RandomAccessTorchDataset(dataset=dataset, part=part,
                                               reshape=reshape)
                for part in parts}

    def _make_loader(part):
        # Build the loader for one part; only 'train' is shuffled.
        shuffle = part == 'train'
        if distributed_bool:
            # DistributedSampler shuffles by default; pass `shuffle`
            # explicitly so validation/test keep a fixed order, matching
            # the non-distributed branch. DataLoader must not also receive
            # `shuffle` when a sampler is given.
            order_kwargs = {
                'sampler': DistributedSampler(datasets[part],
                                              shuffle=shuffle)}
        else:
            order_kwargs = {'shuffle': shuffle}
        return DataLoader(datasets[part], batch_size=batch_sizes[part],
                          num_workers=num_workers,
                          worker_init_fn=worker_init_fn, pin_memory=True,
                          **order_kwargs)

    return {part: _make_loader(part) for part in parts}
def construct_reconstructor(reconstructor_key_name_or_type, dataset_name,
                            **kwargs):
    """
    Construct reference reconstructor object (not loading parameters).

    Note: see :func:`get_reference_reconstructor` to retrieve a reference
    reconstructor with optimized parameters.

    This function implements the constructor calls, which are potentially
    specific to each configuration.

    Parameters
    ----------
    reconstructor_key_name_or_type : str or type
        Key name of configuration or reconstructor type.
    dataset_name : str
        Standard dataset name.
    kwargs : dict
        Keyword arguments.
        For CT configurations this includes the ``'impl'`` used by
        :class:`odl.tomo.RayTransform` (default: ``'astra_cuda'``).
        Other keyword arguments are currently ignored.

    Raises
    ------
    ValueError
        If the configuration does not exist.
    NotImplementedError
        If construction is not implemented for the configuration.

    Returns
    -------
    reconstructor : :class:`Reconstructor`
        The reconstructor instance.
    """
    r_key_name, r_type = validate_reconstructor_key_name_or_type(
        reconstructor_key_name_or_type, dataset_name)

    if dataset_name not in ('ellipses', 'lodopab'):
        raise NotImplementedError(
            'reference reconstructor construction is not implemented for '
            'dataset \'{}\''.format(dataset_name))

    # All supported configurations are CT reconstructors that receive the
    # dataset's ray transform as their only positional argument.
    ct_key_names = ('fbp', 'fbpunet', 'iradonmap', 'learnedgd', 'learnedpd',
                    'tvadam', 'diptv')
    if r_key_name not in ct_key_names:
        # checked before loading the dataset so unsupported keys fail fast
        raise NotImplementedError(
            'reconstructor construction is not implemented for reference '
            'configuration \'{}\' for dataset \'{}\''
            .format(r_key_name, dataset_name))

    impl = kwargs.pop('impl', 'astra_cuda')
    dataset = get_standard_dataset(dataset_name, impl=impl)
    ray_trafo = dataset.get_ray_trafo(impl=impl)
    name = '{d}_{r}'.format(r=r_key_name, d=dataset_name)
    return r_type(ray_trafo, name=name)
def test_generator(self):
    """Pairs generated by ``AngleSubsetDataset`` must equal the full
    dataset's pairs, with observations restricted to every second angle."""
    full = get_standard_dataset('ellipses', fixed_seeds=True, impl='skimage')
    angle_indices = range(0, full.shape[0][0], 2)
    subset = AngleSubsetDataset(full, angle_indices)
    subset_pairs = subset.get_data_pairs('train', 3)
    full_pairs = full.get_data_pairs('train', 3)
    selected = np.asarray(angle_indices)

    for (sub_obs, sub_gt), (obs, gt) in zip(subset_pairs, full_pairs):
        expected_obs = np.asarray(obs)[selected, :]
        self.assertEqual(sub_obs.shape, expected_obs.shape)
        self.assertEqual(sub_gt.shape, gt.shape)
        obs_match = np.all(np.asarray(sub_obs) == expected_obs)
        self.assertTrue(obs_match)
        gt_match = np.all(np.asarray(sub_gt) == np.asarray(gt))
        self.assertTrue(gt_match)
def test_get_ray_trafo(self):
    """``AngleSubsetDataset.get_ray_trafo`` must restrict the range and the
    geometry angles to the selected indices.

    The index sets checked are strided, first half, second half, and the
    non-contiguous first + last quarter.
    """
    d = get_standard_dataset('ellipses', fixed_seeds=True, impl='skimage')
    num_angles = d.shape[0][0]
    # loop-invariant: angles of the full ray transform (built once)
    full_angles = d.get_ray_trafo(impl='skimage').geometry.angles
    index_sets = (
        range(0, num_angles, 2),
        range(0, num_angles // 2),
        range(num_angles // 2, num_angles),
        np.concatenate(
            [np.arange(0, int(num_angles * 1/4)),
             np.arange(int(num_angles * 3/4), num_angles)]))
    for angle_indices in index_sets:
        asd = AngleSubsetDataset(d, angle_indices)
        ray_trafo = asd.get_ray_trafo(impl='skimage')
        self.assertEqual(ray_trafo.range.shape[0], len(angle_indices))
        angles_subset = full_angles[np.asarray(angle_indices)]
        self.assertEqual(ray_trafo.geometry.angles.shape,
                         angles_subset.shape)
        self.assertTrue(np.all(ray_trafo.geometry.angles == angles_subset))
def prepare_data(self, *args, **kwargs):
    """Load, crop, normalize and mask LoDoPaB ground truths, then build the
    projection dataset ``self.gt_ds``.

    Reads ground truths only: samples 0..3999 of ``'train'``, 0..399 of
    ``'validation'`` and 0..3552 of ``'test'``. Sets ``self.mean`` and
    ``self.std`` from the training ground truths. ``*args``/``**kwargs``
    are accepted but unused.
    """
    lodopab = dival.get_standard_dataset('lodopab', impl='astra_cpu')
    assert self.gt_shape <= self.IMG_SHAPE, 'GT is larger than original images.'

    if self.gt_shape < self.IMG_SHAPE:
        # NOTE(review): 362 is presumably the raw LoDoPaB image side length.
        # This window yields `gt_shape` pixels only for odd `gt_shape`; for
        # even values it is one pixel short and the shape assert below fails.
        off = (362 - self.gt_shape) // 2
        window = (slice(off, -(off + 1)), slice(off, -(off + 1)))
    else:
        # drop the first row and the first column
        window = (slice(1, None), slice(1, None))

    def _ground_truths(part, count):
        # out=(False, True) requests only the ground truth (element 1)
        return np.array([
            lodopab.get_sample(idx, part=part, out=(False, True))[1][window]
            for idx in range(count)])

    gt_train = torch.from_numpy(_ground_truths('train', 4000))
    gt_val = torch.from_numpy(_ground_truths('validation', 400))
    gt_test = torch.from_numpy(_ground_truths('test', 3553))

    assert gt_train.shape[1] == self.gt_shape
    assert gt_train.shape[2] == self.gt_shape

    # normalization statistics are taken from the training split only
    self.mean = gt_train.mean()
    self.std = gt_train.std()
    gt_train = normalize(gt_train, self.mean, self.std)
    gt_val = normalize(gt_val, self.mean, self.std)
    gt_test = normalize(gt_test, self.mean, self.std)

    # apply the mask returned by __get_circle__ to every split
    circle = self.__get_circle__()
    gt_train *= circle
    gt_val *= circle
    gt_test *= circle

    factory = GroundTruthDatasetFactory(gt_train, gt_val, gt_test,
                                        inner_circle=self.inner_circle)
    self.gt_ds = factory.build_projection_dataset(
        num_angles=self.num_angles,
        upscale_shape=self.gt_shape + (self.gt_shape // 2 - 7),
        impl='astra_cpu')
def test(self):
    """``TorchRayTrafoParallel2DAdjointModule`` must match ODL's
    ``ray_trafo.adjoint`` for several (batch_size, channels) layouts."""
    dataset = get_standard_dataset('ellipses', fixed_seeds=True,
                                   impl='astra_cuda')
    ray_trafo = dataset.get_ray_trafo(impl='astra_cuda')
    module = TorchRayTrafoParallel2DAdjointModule(ray_trafo)
    obs_shape, reco_shape = dataset.shape[0], dataset.shape[1]
    layouts = [(1, 3), (5, 1), (2, 3)]

    for n_batch, n_channels in layouts:
        pairs = dataset.get_data_pairs(part='train', n=n_batch * n_channels)
        stacked = torch.from_numpy(np.asarray(pairs.observations))
        batched = stacked.view(n_batch, n_channels, *obs_shape)
        backprojected = module(batched).view(-1, *reco_shape)
        for idx, obs in enumerate(pairs.observations):
            expected = ray_trafo.adjoint(obs)
            actual = backprojected[idx].detach().cpu().numpy()
            self.assertTrue(np.allclose(actual, expected, rtol=1.e-4))
def load_standard_dataset(dataset, impl=None, ordered=False):
    """
    Loads a Dival standard dataset.

    :param dataset: Name of the standard dataset. ``'lodopab_200'`` loads
        ``'lodopab'`` restricted (via ``AngleSubsetDataset``) to the angle
        indices ``0, 5, ..., 995`` (200 angles). For ``'ellipses'``,
        ``fixed_seeds=True`` is passed.
    :param impl: Backend for the Ray Transform. If ``None``,
        ``'astra_cuda'`` is used when ``torch.cuda.is_available()``,
        otherwise ``'astra_cpu'``.
    :param ordered: Whether to order by patient id for 'lodopab' datasets
        (also applies to ``'lodopab_200'``).
    :return: Dival dataset.
    """
    if impl is None:
        impl = 'astra_cuda' if torch.cuda.is_available() else 'astra_cpu'
    kwargs = {'impl': impl}
    if dataset == 'ellipses':
        kwargs['fixed_seeds'] = True
    # 'sorted_by_patient' is deliberately not passed for 'lodopab' in order
    # to be transparent to `CachedDataset`, where a `ReorderedDataset` is
    # handled specially; the ordering is applied below instead.
    dataset_name = dataset.split('_')[0]
    dataset_out = get_standard_dataset(dataset_name, **kwargs)
    if dataset == 'lodopab_200':
        angles = list(range(0, 1000, 5))
        dataset_out = AngleSubsetDataset(dataset_out, angles)
    if dataset_name == 'lodopab' and ordered:
        idx = get_lodopab_idx_sorted_by_patient()
        dataset_ordered = ReorderedDataset(dataset_out, idx)
        # expose the wrapped dataset's ray transform on the wrapper
        dataset_ordered.ray_trafo = dataset_out.ray_trafo
        dataset_ordered.get_ray_trafo = dataset_out.get_ray_trafo
        dataset_out = dataset_ordered
    return dataset_out
get_cached_fbp_dataset)
from dival.reference_reconstructors import (check_for_params,
                                            download_params,
                                            get_hyper_params_path)
from dival.util.plot import plot_images

# Configuration for an FBPUNetReconstructor on the LoDoPaB dataset.
IMPL = 'astra_cuda'  # ray transform backend

LOG_DIR = './logs/lodopab_fbpunet'
SAVE_BEST_LEARNED_PARAMS_PATH = './params/lodopab_fbpunet'
# Cache file per dataset part for the FBP reconstructions; used further
# below, outside this view.
# NOTE(review): meaning of the `None` tuple entry is not visible here — confirm.
CACHE_FILES = {
    'train':
        ('./cache_lodopab_train_fbp.npy', None),
    'validation':
        ('./cache_lodopab_validation_fbp.npy', None)
}

dataset = get_standard_dataset('lodopab', impl=IMPL)
ray_trafo = dataset.get_ray_trafo(impl=IMPL)
test_data = dataset.get_data_pairs('test', 100)  # 100 test pairs

reconstructor = FBPUNetReconstructor(
    ray_trafo, log_dir=LOG_DIR,
    save_best_learned_params_path=SAVE_BEST_LEARNED_PARAMS_PATH)

#%% obtain reference hyper parameters
# Download the reference hyper parameters (not the learned parameters,
# include_learned=False) only if they are not present yet, then load them.
if not check_for_params('fbpunet', 'lodopab', include_learned=False):
    download_params('fbpunet', 'lodopab', include_learned=False)
hyper_params_path = get_hyper_params_path('fbpunet', 'lodopab')
reconstructor.load_hyper_params(hyper_params_path)

#%% expose FBP cache to reconstructor by assigning `fbp_dataset` attribute
import numpy as np
import matplotlib.pyplot as plt
from dival import get_standard_dataset
from dival.evaluation import TaskTable
from dival.measure import PSNR, SSIM
from dival.reference_reconstructors import get_reference_reconstructor

# NOTE(review): `os` is used below but not imported in this view — confirm
# that an `import os` precedes this chunk.
# Make only GPU 0 visible to this process.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
IMPL = 'astra_cuda'  # ray transform backend
DATASET = 'ellipses'  # dival standard dataset name
np.random.seed(0)  # seed NumPy's global RNG for reproducibility

# %% data
dataset = get_standard_dataset(DATASET, impl=IMPL)
ray_trafo = dataset.get_ray_trafo(impl=IMPL)
reco_space = ray_trafo.domain  # domain of the ray transform
test_data = dataset.get_data_pairs('test', 100)

# %% task table and reconstructors
eval_tt = TaskTable()

# reference reconstructors for DATASET, one per configuration key
fbp_reconstructor = get_reference_reconstructor('fbp', DATASET)
tvadam_reconstructor = get_reference_reconstructor('tvadam', DATASET)
fbpunet_reconstructor = get_reference_reconstructor('fbpunet', DATASET)
iradonmap_reconstructor = get_reference_reconstructor('iradonmap', DATASET)
learnedgd_reconstructor = get_reference_reconstructor('learnedgd', DATASET)
learnedpd_reconstructor = get_reference_reconstructor('learnedpd', DATASET)

reconstructors = [
from dival.measure import PSNR
from dival.datasets.fbp_dataset import (generate_fbp_cache_files,
                                        get_cached_fbp_dataset)
from dival.reference_reconstructors import (check_for_params,
                                            download_params,
                                            get_hyper_params_path)
from dival.util.plot import plot_images
from self_supervised_ct.n2self import N2SelfReconstructor

# Configuration for an N2SelfReconstructor on the ellipses dataset.
IMPL = 'astra_cpu'  # ray transform backend
LOG_DIR = './logs/ellipses_n2self'
SAVE_BEST_LEARNED_PARAMS_PATH = './params/ellipses_n2self'

dataset = get_standard_dataset('ellipses', impl=IMPL)
ray_trafo = dataset.get_ray_trafo(impl=IMPL)
test_data = dataset.get_data_pairs('test', 100)

reconstructor = N2SelfReconstructor(
    ray_trafo, log_dir=LOG_DIR,
    save_best_learned_params_path=SAVE_BEST_LEARNED_PARAMS_PATH)

#%% obtain reference hyper parameters
# NOTE(review): these are the *fbpunet* reference hyper parameters, loaded
# into the N2Self reconstructor — presumably deliberate reuse; confirm.
if not check_for_params('fbpunet', 'ellipses', include_learned=False):
    download_params('fbpunet', 'ellipses', include_learned=False)
hyper_params_path = get_hyper_params_path('fbpunet', 'ellipses')
reconstructor.load_hyper_params(hyper_params_path)
# NOTE(review): upper-case HYPER_PARAMS is typically the class-level
# hyper parameter specification, not the values just loaded;
# `reconstructor.hyper_params` may have been intended — confirm.
print(reconstructor.HYPER_PARAMS)
# reconstructor.lr = 0.0001
parser = argparse.ArgumentParser(description='Primal Dual stuff')
# NOTE(review): `default` has no effect on a plain positional argument
# (it would need nargs='?'); `model_type` is therefore always required.
parser.add_argument('model_type', default='classic')
parser.add_argument('-restore', type=int, default=None)
# learning-rate exponent: the learning rate below is 10**(-args.lr)
parser.add_argument('-lr', type=int, default=4)
parser.add_argument('--test', action='store_true', default=False)
args = parser.parse_args()

channels_in = 1
channels_out = 1
batch_size = 3
# hyper parameters of the msd_* network (presumably Mixed-Scale Dense) — confirm
msd_width = 1
msd_depth = 20
msd_dilations = [1, 2, 3, 4, 5]
lr = 10**(-args.lr)  # e.g. "-lr 4" gives 1e-4

dataset = get_standard_dataset('lodopab')


class UnsqueezingDataset(torch.utils.data.Dataset):
    """Quickfix for squeezing in extra dim in dataset retrieved from dival's `get_standard_dataset`.

    Wraps a dataset and adds a leading singleton dimension (dim 0) to both
    elements of each ``(sino, recon)`` pair returned by ``ds[item]``.
    """
    def __init__(self, ds):
        # wrapped dataset; assumes ds[item] returns a (sino, recon) pair of
        # torch tensors — TODO confirm
        self._ds = ds

    def __getitem__(self, item):
        # NOTE(review): self._ds[item] is evaluated twice per call; binding
        # it to a local once would avoid the duplicate fetch.
        sino = torch.unsqueeze(self._ds[item][0], dim=0)
        recon = torch.unsqueeze(self._ds[item][1], dim=0)
        return (sino, recon)

    def __add__(self, other):
# Replace the left operand of the innermost dataset's `forward_op` by a newly
# created RayTransform with the same domain, geometry and impl as
# `ray_trafo`, keeping the original right operand.
dataset.dataset.dataset.forward_op = OperatorComp(
    odl.tomo.RayTransform(
        ray_trafo.domain, ray_trafo.geometry, impl=ray_trafo.impl),
    dataset.dataset.dataset.forward_op.right)

if __name__ == '__main__':
    # Use the 'spawn' start method; if setting it fails with RuntimeError
    # (e.g. already set), only abort when the active method is not 'spawn'.
    try:
        set_start_method('spawn')
    except RuntimeError:
        if get_start_method() != 'spawn':
            raise RuntimeError(
                'Could not set multiprocessing start method to \'spawn\'')
    dataset = get_standard_dataset('ellipses', fixed_seeds=False,
                                   fixed_noise_seeds=False, impl=IMPL)
    ray_trafo = dataset.get_ray_trafo(impl=IMPL)
    patch_ray_trafo_for_pickling(ray_trafo)
    test_data = dataset.get_data_pairs('test', 100)

    reconstructor = FBPUNetReconstructor(
        ray_trafo, log_dir=LOG_DIR,
        save_best_learned_params_path=SAVE_BEST_LEARNED_PARAMS_PATH,
        allow_multiple_workers_without_random_access=True,
        num_data_loader_workers=2,  # more workers lead to increased VRAM usage
        worker_init_fn=worker_init_fn,  # for recreating forward_op
    )

    #%% obtain reference hyper parameters
    if not check_for_params('fbpunet', 'ellipses', include_learned=False):
import numpy as np
from dival import get_standard_dataset
# NOTE: in order to run this test there are some requirements:
# - dival library installed
# - public lodopab dataset downloaded and configured with dival
# - lodopab challenge set observations downloaded (path can be adjusted below)
from lodopab_challenge.challenge_set import (
    config, NUM_ANGLES, NUM_DET_PIXELS, MU_MAX, get_observation,
    get_observations, generator, transform_to_pre_log,
    replace_min_photon_count)

# config['data_path'] = '/localdata/lodopab_challenge_set'

# public LoDoPaB dataset; its observation space is used below as an `out`
# target for get_observation
lodopab = get_standard_dataset('lodopab', impl='skimage')


class TestGetObservation(unittest.TestCase):
    def test(self):
        # Compare get_observation(i) with the i-th item of generator() for
        # the first n observations, with and without an `out` argument.
        n = 3
        for i, obs2 in zip(range(n), generator()):
            obs = get_observation(i)
            self.assertEqual(obs.shape, (NUM_ANGLES, NUM_DET_PIXELS))
            self.assertEqual(obs.dtype, np.float32)
            self.assertTrue(np.all(obs == obs2))
            # with a NumPy `out` array, the same object must be returned
            obs = np.zeros((NUM_ANGLES, NUM_DET_PIXELS), dtype=np.float32)
            obs_ = get_observation(i, out=obs)
            self.assertIs(obs_, obs)
            self.assertTrue(np.all(obs == obs2))
            # `out` given as an element of the dataset's observation space
            obs = lodopab.space[0].zero()
            obs_ = get_observation(i, out=obs)