def get_reference_reconstructor(reconstructor_key_name_or_type, dataset_name,
                                pretrained=True, **kwargs):
    """
    Return a reference reconstructor.

    If the required parameter files are not present at the configured path,
    the user is interactively asked whether to download them.

    Parameters
    ----------
    reconstructor_key_name_or_type : str or type
        Key name of configuration or reconstructor type.
    dataset_name : str
        Standard dataset name.
    pretrained : bool, optional
        Whether learned parameters should be loaded (if any).
        Default: `True`.
    kwargs : dict
        Keyword arguments (passed to :func:`construct_reconstructor`).
        For CT configurations this includes the ``'impl'`` used by
        :class:`odl.tomo.RayTransform`.

    Raises
    ------
    RuntimeError
        If parameter files are missing and the user chooses not to download.

    Returns
    -------
    reconstructor : :class:`Reconstructor`
        The reference reconstructor.
    """
    key_name, reco_type = validate_reconstructor_key_name_or_type(
        reconstructor_key_name_or_type, dataset_name)
    params_found, missing = check_for_params(
        key_name, dataset_name, include_learned=pretrained,
        return_missing=True)
    if not params_found:
        # Parameter files are missing; offer an interactive download.
        print("Reference configuration '{}' for dataset '{}' not found at the "
              "configured path '{}'. You can change this path with "
              "``dival.config.set_config('reference_params/datapath', ...)``."
              .format(key_name, dataset_name, DATA_PATH))
        print('Missing files are: {}.'.format(missing))
        print('Do you want to download it now? (y: download, n: cancel)')
        if not input_yes_no():
            raise RuntimeError('Reference configuration missing, cancelled')
        download_params(key_name, dataset_name)
    # Build the reconstructor, then load stored hyper-parameters and
    # (optionally) learned parameters from the configured params path.
    reconstructor = construct_reconstructor(key_name, dataset_name, **kwargs)
    base_path = get_params_path(key_name, dataset_name)
    reconstructor.load_hyper_params(base_path + '_hyper_params.json')
    if pretrained and issubclass(reco_type, LearnedReconstructor):
        reconstructor.load_learned_params(base_path)
    return reconstructor
def download_zenodo_record(record_id, base_path='', md5sum_check=True):
    """
    Download a zenodo record.

    Parameters
    ----------
    record_id : str
        Record id.
    base_path : str, optional
        Path to store the downloaded files in. Default is the current
        folder.
    md5sum_check : bool, optional
        Whether to check the MD5 sum of each downloaded file.

    Returns
    -------
    success : bool
        If ``md5sum_check=True``, whether all sums matched. Otherwise
        the returned value is always ``True``.
    """
    response = requests.get(
        'https://zenodo.org/api/records/{}'.format(record_id))
    file_entries = response.json()['files']
    success = True
    for index, entry in enumerate(file_entries):
        url = entry['links']['self']
        filename = entry['key']
        size_kb = entry['size'] / 1000
        checksum = entry['checksum']
        print("downloading file {:d}/{:d}: '{}', {}KB".format(
            index + 1, len(file_entries), filename, size_kb))
        if not md5sum_check:
            # No verification requested: download once and move on.
            download_file(url, os.path.join(base_path, filename),
                          md5sum=False)
            continue
        # Checksums in the record look like 'md5:<hexdigest>'.
        expected = checksum.split(':')[1]
        matched = False
        keep_trying = True
        while keep_trying and not matched:
            actual = download_file(url, os.path.join(base_path, filename),
                                   md5sum=True)
            matched = (actual == expected)
            if not matched:
                print("md5 checksum does not match for file '{}'. Retry "
                      "downloading? (y)/n".format(filename))
                keep_trying = input_yes_no()
        if not matched:
            # User gave up after one or more mismatches; abort the record.
            success = False
            print('record download aborted')
            break
    return success
def download_zenodo_record(record_id, base_path='', md5sum_check=True,
                           auto_yes=False):
    """
    Download a zenodo record. Unfortunately, downloads cannot be resumed,
    so this method is only recommended for stable internet connections.

    Parameters
    ----------
    record_id : str
        Record id.
    base_path : str, optional
        Path to store the downloaded files in. Default is the current
        folder.
    md5sum_check : bool, optional
        Whether to check the MD5 sum of each downloaded file.
        Default: `True`
    auto_yes : bool, optional
        Whether to answer user input questions with "y" by default.
        User input questions are:
        If ``md5sum_check=True``, in case of a checksum mismatch: whether
        to retry downloading.
        If ``md5sum_check=False``, in case of an existing file of correct
        size: whether to re-download.
        Default: `False`

    Returns
    -------
    success : bool
        If ``md5sum_check=True``, whether all sums matched. Otherwise
        the returned value is always ``True``.
    """
    r = requests.get('https://zenodo.org/api/records/{}'.format(record_id))
    files = r.json()['files']
    success = True
    for i, f in enumerate(files):
        url = f['links']['self']
        filename = f['key']
        path = os.path.join(base_path, filename)
        size = f['size']
        size_kb = size / 1000
        checksum = f['checksum']
        # Determine size of any existing local copy (-1 if none exists).
        try:
            size_existing = os.stat(path).st_size
        except OSError:
            size_existing = -1
        if size_existing == size:
            # A file of the correct size already exists; either verify it
            # via md5 or ask the user whether to re-download it.
            if md5sum_check:
                print("File {:d}/{:d}, '{}', {}KB already exists with correct "
                      "size. Will check md5 sum now.".format(
                          i+1, len(files), filename, size_kb))
                md5sum_existing = compute_md5sum(path)
                md5sum_matches = (md5sum_existing == checksum.split(':')[1])
                if md5sum_matches:
                    print("skipping file {}, md5 checksum matches".format(
                        filename))
                    continue
                else:
                    print("existing file {} will be overwritten, md5 checksum "
                          "does not match".format(filename))
            else:
                print("File {:d}/{:d}, '{}', {}KB already exists with correct "
                      "size. Re-download this file? (y)/n".format(
                          i+1, len(files), filename, size_kb))
                if auto_yes:
                    print("y")
                    download = True
                else:
                    download = input_yes_no()
                if not download:
                    print("skipping existing file {}".format(filename))
                    continue
        print("downloading file {:d}/{:d}: '{}', {}KB".format(
            i+1, len(files), filename, size_kb))
        if md5sum_check:
            # Re-download until the md5 sum matches or the user gives up.
            retry = True
            md5sum_matches = False
            while retry and not md5sum_matches:
                md5sum = download_file(url, path, md5sum=True)
                md5sum_matches = (md5sum == checksum.split(':')[1])
                if not md5sum_matches:
                    print("md5 checksum does not match for file '{}'. Retry "
                          "downloading? (y)/n".format(filename))
                    if auto_yes:
                        print("y")
                        retry = True
                    else:
                        retry = input_yes_no()
            if not md5sum_matches:
                success = False
                print('record download aborted')
                break
        else:
            # Use `path` (precomputed above) consistently with the
            # md5-checked branch instead of re-joining base_path/filename.
            download_file(url, path, md5sum=False)
    return success
def __init__(self, min_pt=None, max_pt=None, observation_model='post-log',
             min_photon_count=None, sorted_by_patient=False,
             impl='astra_cuda'):
    """
    Parameters
    ----------
    min_pt : [float, float], optional
        Minimum values of the lp space. Default: ``[-0.13, -0.13]``.
    max_pt : [float, float], optional
        Maximum values of the lp space. Default: ``[0.13, 0.13]``.
    observation_model : {'post-log', 'pre-log'}, optional
        The observation model to use.
        The default is ``'post-log'``.

        ``'post-log'``
            Observations are linearly related to the normalized ground
            truth via the ray transform, ``obs = ray_trafo(gt) + noise``.
            Note that the scaling of the observations matches the
            normalized ground truth, i.e., they are divided by the linear
            attenuation of 3071 HU.
        ``'pre-log'``
            Observations are non-linearly related to the ground truth, as
            given by the Beer-Lambert law.
            The model is
            ``obs = exp(-ray_trafo(gt * MU(3071 HU))) + noise``,
            where `MU(3071 HU)` is the factor, by which the ground truth
            was normalized.
    min_photon_count : float, optional
        Replacement value for a simulated photon count of zero.
        If ``observation_model == 'post-log'``, a value greater than zero
        is required in order to avoid undefined values. The default is 0.1,
        both for ``'post-log'`` and ``'pre-log'`` model.
    sorted_by_patient : bool, optional
        Whether to sort the samples by patient id.
        Useful to resplit the dataset.
        See also :meth:`get_indices_for_patient`.
        Note that the slices of each patient are ordered randomly wrt.
        the z-location in any case.
        Default: ``False``.
    impl : {``'skimage'``, ``'astra_cpu'``, ``'astra_cuda'``},\
            optional
        Implementation passed to :class:`odl.tomo.RayTransform` to
        construct :attr:`ray_trafo`.
    """
    # DATA_PATH may be reassigned below if the user enters a new path
    # interactively, hence the `global` declaration.
    global DATA_PATH
    NUM_ANGLES = 1000
    NUM_DET_PIXELS = 513
    # shape = (observation shape, ground truth shape)
    self.shape = ((NUM_ANGLES, NUM_DET_PIXELS), (362, 362))
    self.num_elements_per_sample = 2
    if min_pt is None:
        min_pt = MIN_PT
    if max_pt is None:
        max_pt = MAX_PT
    # Reconstruction (image) space.
    domain = uniform_discr(min_pt, max_pt, self.shape[1], dtype=np.float32)
    if observation_model == 'post-log':
        self.post_log = True
    elif observation_model == 'pre-log':
        self.post_log = False
    else:
        raise ValueError("`observation_model` must be 'post-log' or "
                         "'pre-log', not '{}'".format(observation_model))
    # Clamp min_photon_count to at most 1.; larger values would not make
    # sense as a replacement for a photon count of zero.
    if min_photon_count is None or min_photon_count <= 1.:
        self.min_photon_count = min_photon_count
    else:
        self.min_photon_count = 1.
        warn('`min_photon_count` changed from {} to 1.'.format(
            min_photon_count))
    self.sorted_by_patient = sorted_by_patient
    self.train_len = LEN['train']
    self.validation_len = LEN['validation']
    self.test_len = LEN['test']
    self.random_access = True
    # Loop until the dataset files are found: either download them or let
    # the user enter an alternative path (stored in the config).
    while not LoDoPaBDataset.check_for_lodopab():
        print('The LoDoPaB-CT dataset could not be found under the '
              "configured path '{}'.".format(
                  CONFIG['lodopab_dataset']['data_path']))
        print('Do you want to download it now? (y: download, n: input '
              'other path)')
        download = input_yes_no()
        if download:
            success = download_lodopab()
            if not success:
                raise RuntimeError('lodopab dataset not available, '
                                   'download failed')
        else:
            print('Path to LoDoPaB dataset:')
            DATA_PATH = input()
            set_config('lodopab_dataset/data_path', DATA_PATH)
    # Patient ids are only required for sorting; failing to load them is
    # fatal only when `sorted_by_patient` was requested.
    self.rel_patient_ids = None
    try:
        self.rel_patient_ids = LoDoPaBDataset.get_patient_ids()
    except OSError as e:
        if self.sorted_by_patient:
            raise RuntimeError(
                'Can not load patient ids, required for sorting. '
                'OSError: {}'.format(e))
        warn(
            'Can not load patient ids (OSError: {}). '
            'Therefore sorting is not possible, so please keep the '
            'attribute `sorted_by_patient = False` for the LoDoPaBDataset.'
            .format(e))
    if self.rel_patient_ids is not None:
        self._idx_sorted_by_patient = (
            LoDoPaBDataset.get_idx_sorted_by_patient(self.rel_patient_ids))
    # Parallel beam geometry matching the stored observations.
    self.geometry = odl.tomo.parallel_beam_geometry(
        domain, num_angles=NUM_ANGLES, det_shape=(NUM_DET_PIXELS, ))
    # Observation (sinogram) space.
    range_ = uniform_discr(self.geometry.partition.min_pt,
                           self.geometry.partition.max_pt,
                           self.shape[0], dtype=np.float32)
    super().__init__(space=(range_, domain))
    self.ray_trafo = self.get_ray_trafo(impl=impl)
ax[1].set_title('Ground truth') psnr = PSNR(reco, gt) ssim = SSIM(reco, gt) ax[0].set_xlabel('PSNR: {:.2f}dB, SSIM: {:.3f}'.format(psnr, ssim)) print('metrics for FBP reconstruction on sample {:d}:'.format(i)) print('PSNR: {:.2f}dB, SSIM: {:.3f}'.format(psnr, ssim)) plt.show() # %% simulate and store fan beam observations SKIP_SIMULATION = False if not SKIP_SIMULATION: from dival.util.input import input_yes_no print('start simulating and storing fan beam observations for all lodopab ' 'ground truth samples? [y]/n') if not input_yes_no(): raise RuntimeError('cancelled by user') obs_shape = dataset.ray_trafo.range.shape for part in ['train', 'validation', 'test']: for i, (obs, gt) in enumerate( tqdm(dataset.generator(part), desc='simulating part \'{}\''.format(part), total=dataset.get_len(part))): filenumber = i // NUM_SAMPLES_PER_FILE idx_in_file = i % NUM_SAMPLES_PER_FILE obs_filename = os.path.join( DATA_PATH, '{}_{}_{:03d}.hdf5'.format(OBSERVATION_NAME, part, filenumber)) with h5py.File(obs_filename, 'a') as observation_file: observation_dataset = observation_file.require_dataset(