def get_reference_reconstructor(reconstructor_key_name_or_type,
                                dataset_name,
                                pretrained=True,
                                **kwargs):
    """
    Return a reference reconstructor.

    Parameters
    ----------
    reconstructor_key_name_or_type : str or type
        Key name of configuration or reconstructor type.
    dataset_name : str
        Standard dataset name.
    pretrained : bool, optional
        Whether learned parameters should be loaded (if any).
        Default: `True`.
    kwargs : dict
        Keyword arguments (passed to :func:`construct_reconstructor`).
        For CT configurations this includes the ``'impl'`` used by
        :class:`odl.tomo.RayTransform`.

    Raises
    ------
    RuntimeError
        If parameter files are missing and the user chooses not to download.

    Returns
    -------
    reconstructor : :class:`Reconstructor`
        The reference reconstructor.
    """
    r_key_name, r_type = validate_reconstructor_key_name_or_type(
        reconstructor_key_name_or_type, dataset_name)
    params_exist, missing = check_for_params(r_key_name,
                                             dataset_name,
                                             include_learned=pretrained,
                                             return_missing=True)
    if not params_exist:
        print("Reference configuration '{}' for dataset '{}' not found at the "
              "configured path '{}'. You can change this path with "
              "``dival.config.set_config('reference_params/datapath', ...)``.".
              format(r_key_name, dataset_name, DATA_PATH))
        print('Missing files are: {}.'.format(missing))
        print('Do you want to download it now? (y: download, n: cancel)')
        download = input_yes_no()
        if not download:
            raise RuntimeError('Reference configuration missing, cancelled')
        download_params(r_key_name, dataset_name)
    reconstructor = construct_reconstructor(r_key_name, dataset_name, **kwargs)
    params_path = get_params_path(r_key_name, dataset_name)
    reconstructor.load_hyper_params(params_path + '_hyper_params.json')
    if pretrained and issubclass(r_type, LearnedReconstructor):
        reconstructor.load_learned_params(params_path)
    return reconstructor
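A minimal usage sketch of this function (the key name 'fbp' and dataset name 'lodopab' are illustrative assumptions, and ``impl`` is passed through ``kwargs`` as described in the docstring):

from dival.reference_reconstructors import get_reference_reconstructor

# May prompt for a download if the parameter files are missing locally.
reconstructor = get_reference_reconstructor('fbp', 'lodopab',
                                            impl='astra_cpu')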
Example #2
def download_zenodo_record(record_id, base_path='', md5sum_check=True):
    """
    Download a zenodo record.

    Parameters
    ----------
    record_id : str
        Record id.
    base_path : str, optional
        Path to store the downloaded files in. Default is the current folder.
    md5sum_check : bool, optional
        Whether to check the MD5 sum of each downloaded file.

    Returns
    -------
    success : bool
        If ``md5sum_check=True``, whether all sums matched. Otherwise the
        returned value is always ``True``.
    """
    r = requests.get('https://zenodo.org/api/records/{}'.format(record_id))
    files = r.json()['files']
    success = True
    for i, f in enumerate(files):
        url = f['links']['self']
        filename = f['key']
        size_kb = f['size'] / 1000
        checksum = f['checksum']
        print("downloading file {:d}/{:d}: '{}', {}KB".format(
            i+1, len(files), filename, size_kb))
        if md5sum_check:
            retry = True
            md5sum_matches = False
            while retry and not md5sum_matches:
                md5sum = download_file(url, os.path.join(base_path, filename),
                                       md5sum=True)
                md5sum_matches = (md5sum == checksum.split(':')[1])
                if not md5sum_matches:
                    print("md5 checksum does not match for file '{}'. Retry "
                          "downloading? (y)/n".format(filename))
                    retry = input_yes_no()
            if not md5sum_matches:
                success = False
                print('record download aborted')
                break
        else:
            download_file(url, os.path.join(base_path, filename), md5sum=False)
    return success
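A hedged usage sketch (the record id below is a placeholder, not a real Zenodo record):

import os

os.makedirs('downloads', exist_ok=True)
# Downloads every file of the record and verifies the MD5 sum of each.
success = download_zenodo_record('1234567', base_path='downloads')
print('all checksums matched' if success else 'download aborted')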
Example #3
def download_zenodo_record(record_id, base_path='', md5sum_check=True,
                           auto_yes=False):
    """
    Download a zenodo record.
    Unfortunately, downloads cannot be resumed, so this method is only
    recommended for stable internet connections.

    Parameters
    ----------
    record_id : str
        Record id.
    base_path : str, optional
        Path to store the downloaded files in. Default is the current folder.
    md5sum_check : bool, optional
        Whether to check the MD5 sum of each downloaded file.
        Default: `True`
    auto_yes : bool, optional
        Whether to automatically answer user input questions with "y".
        The user input questions are: if ``md5sum_check=True``, whether to
        retry downloading after a checksum mismatch; if
        ``md5sum_check=False``, whether to re-download a file that already
        exists with the correct size.
        Default: `False`

    Returns
    -------
    success : bool
        If ``md5sum_check=True``, whether all sums matched. Otherwise the
        returned value is always ``True``.
    """
    r = requests.get('https://zenodo.org/api/records/{}'.format(record_id))
    files = r.json()['files']
    success = True
    for i, f in enumerate(files):
        url = f['links']['self']
        filename = f['key']
        path = os.path.join(base_path, filename)
        size = f['size']
        size_kb = size / 1000
        checksum = f['checksum']
        try:
            size_existing = os.stat(path).st_size
        except OSError:
            size_existing = -1
        if size_existing == size:
            if md5sum_check:
                print("File {:d}/{:d}, '{}', {}KB already exists with correct "
                      "size. Will check md5 sum now.".format(
                          i+1, len(files), filename, size_kb))
                md5sum_existing = compute_md5sum(path)
                md5sum_matches = (md5sum_existing == checksum.split(':')[1])
                if md5sum_matches:
                    print("skipping file {}, md5 checksum matches".format(
                        filename))
                    continue
                else:
                    print("existing file {} will be overwritten, md5 checksum "
                          "does not match".format(filename))
            else:
                print("File {:d}/{:d}, '{}', {}KB already exists with correct "
                      "size. Re-download this file? (y)/n".format(
                          i+1, len(files), filename, size_kb))
                if auto_yes:
                    print("y")
                    download = True
                else:
                    download = input_yes_no()
                if not download:
                    print("skipping existing file {}".format(filename))
                    continue
        print("downloading file {:d}/{:d}: '{}', {}KB".format(
            i+1, len(files), filename, size_kb))
        if md5sum_check:
            retry = True
            md5sum_matches = False
            while retry and not md5sum_matches:
                md5sum = download_file(url, path, md5sum=True)
                md5sum_matches = (md5sum == checksum.split(':')[1])
                if not md5sum_matches:
                    print("md5 checksum does not match for file '{}'. Retry "
                          "downloading? (y)/n".format(filename))
                    if auto_yes:
                        print("y")
                        retry = True
                    else:
                        retry = input_yes_no()
            if not md5sum_matches:
                success = False
                print('record download aborted')
                break
        else:
            download_file(url, path, md5sum=False)
    return success
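The helper `compute_md5sum` is not shown in this snippet; a minimal sketch of such a helper, reading the file in chunks so large files need not fit in memory (the chunk size is an arbitrary choice):

import hashlib

def compute_md5sum(path, chunk_size=1024 * 1024):
    """Return the hex MD5 digest of the file at `path`."""
    md5 = hashlib.md5()
    with open(path, 'rb') as f:
        # iter() with a sentinel yields chunks until EOF.
        for chunk in iter(lambda: f.read(chunk_size), b''):
            md5.update(chunk)
    return md5.hexdigest()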
Example #4
    def __init__(self,
                 min_pt=None,
                 max_pt=None,
                 observation_model='post-log',
                 min_photon_count=None,
                 sorted_by_patient=False,
                 impl='astra_cuda'):
        """
        Parameters
        ----------
        min_pt : [float, float], optional
            Minimum values of the lp space. Default: ``[-0.13, -0.13]``.
        max_pt : [float, float], optional
            Maximum values of the lp space. Default: ``[0.13, 0.13]``.
        observation_model : {'post-log', 'pre-log'}, optional
            The observation model to use.
            The default is ``'post-log'``.

            ``'post-log'``
                Observations are linearly related to the normalized ground
                truth via the ray transform, ``obs = ray_trafo(gt) + noise``.
                Note that the scaling of the observations matches the
                normalized ground truth, i.e., they are divided by the linear
                attenuation of 3071 HU.
            ``'pre-log'``
                Observations are non-linearly related to the ground truth, as
                given by the Beer-Lambert law.
                The model is
                ``obs = exp(-ray_trafo(gt * MU(3071 HU))) + noise``,
                where `MU(3071 HU)` is the factor by which the ground truth
                was normalized.
        min_photon_count : float, optional
            Replacement value for a simulated photon count of zero.
            If ``observation_model == 'post-log'``, a value greater than zero
            is required in order to avoid undefined values. The default is
            0.1 for both the ``'post-log'`` and ``'pre-log'`` models.
        sorted_by_patient : bool, optional
            Whether to sort the samples by patient id.
            Useful to resplit the dataset.
            See also :meth:`get_indices_for_patient`.
            Note that in any case the slices of each patient are in random
            order with respect to the z-location.
            Default: ``False``.
        impl : {``'skimage'``, ``'astra_cpu'``, ``'astra_cuda'``},\
                optional
            Implementation passed to :class:`odl.tomo.RayTransform` to
            construct :attr:`ray_trafo`.
        """
        global DATA_PATH
        NUM_ANGLES = 1000
        NUM_DET_PIXELS = 513
        self.shape = ((NUM_ANGLES, NUM_DET_PIXELS), (362, 362))
        self.num_elements_per_sample = 2
        if min_pt is None:
            min_pt = MIN_PT
        if max_pt is None:
            max_pt = MAX_PT
        domain = uniform_discr(min_pt, max_pt, self.shape[1], dtype=np.float32)
        if observation_model == 'post-log':
            self.post_log = True
        elif observation_model == 'pre-log':
            self.post_log = False
        else:
            raise ValueError("`observation_model` must be 'post-log' or "
                             "'pre-log', not '{}'".format(observation_model))
        if min_photon_count is None or min_photon_count <= 1.:
            self.min_photon_count = min_photon_count
        else:
            self.min_photon_count = 1.
            warn('`min_photon_count` changed from {} to 1.'.format(
                min_photon_count))
        self.sorted_by_patient = sorted_by_patient
        self.train_len = LEN['train']
        self.validation_len = LEN['validation']
        self.test_len = LEN['test']
        self.random_access = True

        while not LoDoPaBDataset.check_for_lodopab():
            print('The LoDoPaB-CT dataset could not be found under the '
                  "configured path '{}'.".format(
                      CONFIG['lodopab_dataset']['data_path']))
            print('Do you want to download it now? (y: download, n: input '
                  'other path)')
            download = input_yes_no()
            if download:
                success = download_lodopab()
                if not success:
                    raise RuntimeError('lodopab dataset not available, '
                                       'download failed')
            else:
                print('Path to LoDoPaB dataset:')
                DATA_PATH = input()
                set_config('lodopab_dataset/data_path', DATA_PATH)

        self.rel_patient_ids = None
        try:
            self.rel_patient_ids = LoDoPaBDataset.get_patient_ids()
        except OSError as e:
            if self.sorted_by_patient:
                raise RuntimeError(
                    'Cannot load patient ids, which are required for '
                    'sorting. OSError: {}'.format(e))
            warn(
                'Cannot load patient ids (OSError: {}). Therefore sorting '
                'by patient is not possible, so please keep the attribute '
                '`sorted_by_patient = False` for the LoDoPaBDataset.'
                .format(e))
        if self.rel_patient_ids is not None:
            self._idx_sorted_by_patient = (
                LoDoPaBDataset.get_idx_sorted_by_patient(self.rel_patient_ids))

        self.geometry = odl.tomo.parallel_beam_geometry(
            domain, num_angles=NUM_ANGLES, det_shape=(NUM_DET_PIXELS, ))
        range_ = uniform_discr(self.geometry.partition.min_pt,
                               self.geometry.partition.max_pt,
                               self.shape[0],
                               dtype=np.float32)
        super().__init__(space=(range_, domain))
        self.ray_trafo = self.get_ray_trafo(impl=impl)
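A minimal usage sketch (the import path is assumed from the dival layout, and `get_sample` is the generic dival dataset accessor):

from dival.datasets.lodopab_dataset import LoDoPaBDataset

# May prompt for a download if the dataset files are not found locally.
dataset = LoDoPaBDataset(impl='astra_cpu', observation_model='post-log')
obs, gt = dataset.get_sample(0, part='train')  # one (observation, ground truth) pair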
Example #5
    ax[1].set_title('Ground truth')
    psnr = PSNR(reco, gt)
    ssim = SSIM(reco, gt)
    ax[0].set_xlabel('PSNR: {:.2f}dB, SSIM: {:.3f}'.format(psnr, ssim))
    print('metrics for FBP reconstruction on sample {:d}:'.format(i))
    print('PSNR: {:.2f}dB, SSIM: {:.3f}'.format(psnr, ssim))
plt.show()

# %% simulate and store fan beam observations
SKIP_SIMULATION = False

if not SKIP_SIMULATION:
    from dival.util.input import input_yes_no
    print('Start simulating and storing fan beam observations for all '
          'LoDoPaB ground truth samples? [y]/n')
    if not input_yes_no():
        raise RuntimeError('cancelled by user')

    obs_shape = dataset.ray_trafo.range.shape
    for part in ['train', 'validation', 'test']:
        for i, (obs, gt) in enumerate(
                tqdm(dataset.generator(part),
                     desc='simulating part \'{}\''.format(part),
                     total=dataset.get_len(part))):
            filenumber = i // NUM_SAMPLES_PER_FILE
            idx_in_file = i % NUM_SAMPLES_PER_FILE
            obs_filename = os.path.join(
                DATA_PATH, '{}_{}_{:03d}.hdf5'.format(OBSERVATION_NAME, part,
                                                      filenumber))
            with h5py.File(obs_filename, 'a') as observation_file:
                observation_dataset = observation_file.require_dataset(