示例#1
0
def download_lodopab():
    """Interactively download and unpack the LoDoPaB-CT dataset.

    Asks on stdin for the target directory. An empty answer keeps the
    current module-level ``DATA_PATH``; any other answer replaces it and is
    persisted with ``set_config('lodopab_dataset/data_path', ...)``. The
    directory is created if needed, Zenodo record ``3384092`` is downloaded
    into it, and each of the six split archives is extracted there and then
    deleted.

    Returns
    -------
    bool
        ``True`` once all archives have been extracted; ``False`` if
        ``download_zenodo_record`` reports a failure (nothing is extracted
        in that case).
    """
    global DATA_PATH
    print('Before downloading, please make sure to have enough free disk '
          'space (~150GB). After unpacking, 114.7GB will be used.')
    print("path to store LoDoPaB-CT dataset (default '{}'):".format(DATA_PATH))
    answer = input()
    if answer:
        # Remember a user-chosen location for subsequent sessions.
        DATA_PATH = answer
        set_config('lodopab_dataset/data_path', DATA_PATH)
    os.makedirs(DATA_PATH, exist_ok=True)

    record_id = '3384092'
    ok = download_zenodo_record(record_id, DATA_PATH)
    status = 'successful' if ok else 'failed'
    print('download of LoDoPaB-CT dataset {}'.format(status))
    if not ok:
        return False

    archives = (
        'observation_train.zip',
        'ground_truth_train.zip',
        'observation_validation.zip',
        'ground_truth_validation.zip',
        'observation_test.zip',
        'ground_truth_test.zip',
    )
    print('unzipping zip files, this can take several minutes', flush=True)
    for archive in tqdm(archives, desc='unzip'):
        archive_path = os.path.join(DATA_PATH, archive)
        with ZipFile(archive_path, 'r') as zf:
            zf.extractall(DATA_PATH)
        # Delete each archive as soon as it is unpacked, so the zipped and
        # unzipped copies of all splits never coexist on disk.
        os.remove(archive_path)
    return True
示例#2
0
    def __init__(self,
                 min_pt=None,
                 max_pt=None,
                 observation_model='post-log',
                 min_photon_count=None,
                 sorted_by_patient=False,
                 impl='astra_cuda'):
        """
        Parameters
        ----------
        min_pt : [float, float], optional
            Minimum values of the lp space. Default: ``[-0.13, -0.13]``.
        max_pt : [float, float], optional
            Maximum values of the lp space. Default: ``[0.13, 0.13]``.
        observation_model : {'post-log', 'pre-log'}, optional
            The observation model to use.
            The default is ``'post-log'``.

            ``'post-log'``
                Observations are linearly related to the normalized ground
                truth via the ray transform, ``obs = ray_trafo(gt) + noise``.
                Note that the scaling of the observations matches the
                normalized ground truth, i.e., they are divided by the linear
                attenuation of 3071 HU.
            ``'pre-log'``
                Observations are non-linearly related to the ground truth, as
                given by the Beer-Lambert law.
                The model is
                ``obs = exp(-ray_trafo(gt * MU(3071 HU))) + noise``,
                where `MU(3071 HU)` is the factor, by which the ground truth
                was normalized.
        min_photon_count : float, optional
            Replacement value for a simulated photon count of zero.
            If ``observation_model == 'post-log'``, a value greater than zero
            is required in order to avoid undefined values; a
            :class:`ValueError` is raised otherwise. Values greater than 1
            are replaced by 1 and a warning is issued.
            The default is 0.1, both for ``'post-log'`` and ``'pre-log'``
            model.
        sorted_by_patient : bool, optional
            Whether to sort the samples by patient id.
            Useful to resplit the dataset.
            See also :meth:`get_indices_for_patient`.
            Note that the slices of each patient are ordered randomly with
            respect to the z-location in any case.
            Default: ``False``.
        impl : {``'skimage'``, ``'astra_cpu'``, ``'astra_cuda'``},\
                optional
            Implementation passed to :class:`odl.tomo.RayTransform` to
            construct :attr:`ray_trafo`.
            Default: ``'astra_cuda'``.

        Raises
        ------
        ValueError
            If `observation_model` is neither ``'post-log'`` nor
            ``'pre-log'``, or if ``min_photon_count <= 0`` is passed together
            with the ``'post-log'`` model.
        RuntimeError
            If the dataset is missing and the interactive download fails, or
            if `sorted_by_patient` is ``True`` but the patient ids cannot be
            loaded.

        Notes
        -----
        If the dataset cannot be found under the configured path, the user is
        asked on stdin whether to download it or to enter another path. A
        path entered here is stored via ``set_config`` and the check is
        repeated until the dataset is found.
        """
        global DATA_PATH
        NUM_ANGLES = 1000
        NUM_DET_PIXELS = 513
        # (observation shape, ground-truth image shape)
        self.shape = ((NUM_ANGLES, NUM_DET_PIXELS), (362, 362))
        self.num_elements_per_sample = 2
        if min_pt is None:
            min_pt = MIN_PT
        if max_pt is None:
            max_pt = MAX_PT
        domain = uniform_discr(min_pt, max_pt, self.shape[1], dtype=np.float32)
        if observation_model == 'post-log':
            self.post_log = True
        elif observation_model == 'pre-log':
            self.post_log = False
        else:
            raise ValueError("`observation_model` must be 'post-log' or "
                             "'pre-log', not '{}'".format(observation_model))
        # Enforce the documented contract: the post-log model needs a positive
        # replacement photon count, otherwise values become undefined.
        if (self.post_log and min_photon_count is not None
                and min_photon_count <= 0.):
            raise ValueError(
                "`min_photon_count` must be greater than zero for the "
                "'post-log' observation model, not {}".format(
                    min_photon_count))
        # Counts above 1 are not supported; clamp them (with a warning)
        # instead of failing.
        if min_photon_count is None or min_photon_count <= 1.:
            self.min_photon_count = min_photon_count
        else:
            self.min_photon_count = 1.
            warn('`min_photon_count` changed from {} to 1.'.format(
                min_photon_count))
        self.sorted_by_patient = sorted_by_patient
        self.train_len = LEN['train']
        self.validation_len = LEN['validation']
        self.test_len = LEN['test']
        self.random_access = True

        # Keep prompting until the dataset is found, either by downloading it
        # or by pointing to an existing copy.
        while not LoDoPaBDataset.check_for_lodopab():
            print('The LoDoPaB-CT dataset could not be found under the '
                  "configured path '{}'.".format(
                      CONFIG['lodopab_dataset']['data_path']))
            print('Do you want to download it now? (y: download, n: input '
                  'other path)')
            download = input_yes_no()
            if download:
                success = download_lodopab()
                if not success:
                    raise RuntimeError('lodopab dataset not available, '
                                       'download failed')
            else:
                print('Path to LoDoPaB dataset:')
                DATA_PATH = input()
                set_config('lodopab_dataset/data_path', DATA_PATH)

        # Patient ids are optional unless sorting by patient is requested.
        self.rel_patient_ids = None
        try:
            self.rel_patient_ids = LoDoPaBDataset.get_patient_ids()
        except OSError as e:
            if self.sorted_by_patient:
                raise RuntimeError(
                    'Can not load patient ids, required for sorting. '
                    'OSError: {}'.format(e)) from e
            warn(
                'Can not load patient ids (OSError: {}). '
                'Therefore sorting is not possible, so please keep the '
                'attribute `sorted_by_patient = False` for the LoDoPaBDataset.'
                .format(e))
        if self.rel_patient_ids is not None:
            self._idx_sorted_by_patient = (
                LoDoPaBDataset.get_idx_sorted_by_patient(self.rel_patient_ids))

        self.geometry = odl.tomo.parallel_beam_geometry(
            domain, num_angles=NUM_ANGLES, det_shape=(NUM_DET_PIXELS, ))
        # The observation space spans the geometry's (angle, detector)
        # partition with the observation shape.
        range_ = uniform_discr(self.geometry.partition.min_pt,
                               self.geometry.partition.max_pt,
                               self.shape[0],
                               dtype=np.float32)
        super().__init__(space=(range_, domain))
        self.ray_trafo = self.get_ray_trafo(impl=impl)