예제 #1
0
    def init_model(self):
        """Create the FBP operator, the U-Net and the masker.

        Sets ``self.fbp_op`` (built from ``self.op`` with ``self.filter_type``
        and ``self.frequency_scaling``), ``self.model`` (a single-channel
        U-Net using the first ``self.scales`` entries of ``self.channels``)
        and ``self.masker``. If ``self.init_bias_zero`` is set, the biases of
        all ``Conv2d`` layers are zeroed. If ``self.use_cuda`` is set, the
        model is wrapped in ``nn.DataParallel`` and moved to ``self.device``.
        """
        self.fbp_op = fbp_op(self.op,
                             filter_type=self.filter_type,
                             frequency_scaling=self.frequency_scaling)
        self.model = UNet(in_ch=1,
                          out_ch=1,
                          channels=self.channels[:self.scales],
                          # same number of skip channels at every scale
                          skip_channels=[self.skip_channels] * self.scales,
                          use_sigmoid=self.use_sigmoid)

        self.masker = Masker(width=4, mode='interpolate')

        if self.init_bias_zero:

            def weights_init(m):
                # Convolutions created with ``bias=False`` have
                # ``m.bias is None``; skip them instead of crashing.
                if isinstance(m, torch.nn.Conv2d) and m.bias is not None:
                    torch.nn.init.zeros_(m.bias)

            self.model.apply(weights_init)

        if self.use_cuda:
            self.model = nn.DataParallel(self.model).to(self.device)
예제 #2
0
 def __init__(self,
              ray_trafo,
              scales=None,
              epochs=None,
              batch_size=None,
              num_data_loader_workers=8,
              use_cuda=True,
              show_pbar=True,
              fbp_impl='astra_cuda',
              **kwargs):
     """
     Parameters
     ----------
     ray_trafo : :class:`odl.tomo.RayTransform`
         Ray transform from which the FBP operator is constructed.
     scales : int, optional
         Number of scales in the U-Net (a hyper parameter).
     epochs : int, optional
         Number of epochs to train (a hyper parameter).
     batch_size : int, optional
         Batch size (a hyper parameter).
     num_data_loader_workers : int, optional
         Number of parallel workers to use for loading data.
     use_cuda : bool, optional
         Whether to use cuda for the U-Net.
     show_pbar : bool, optional
         Whether to show tqdm progress bars during the epochs.
     fbp_impl : str, optional
         The backend implementation passed to
         :class:`odl.tomo.RayTransform` in case no `ray_trafo` is specified.
         Then ``dataset.get_ray_trafo(impl=fbp_impl)`` is used to get the
         ray transform and FBP operator.
     **kwargs
         Passed on to the parent constructor (e.g. ``hyper_params``).
         Non-``None`` values of `scales`, `epochs` and `batch_size` take
         precedence over the corresponding ``hyper_params`` entries; a
         warning is issued if both are given.
     """
     self.ray_trafo = ray_trafo
     self.fbp_op = fbp_op(self.ray_trafo)
     self.num_data_loader_workers = num_data_loader_workers
     self.use_cuda = use_cuda
     self.show_pbar = show_pbar
     self.fbp_impl = fbp_impl
     super().__init__(reco_space=self.ray_trafo.domain,
                      observation_space=self.ray_trafo.range,
                      **kwargs)
     # ``hyper_params`` may be passed explicitly as ``None``; treat that like
     # "not given" instead of failing on ``None.get``.
     given_hyper_params = kwargs.get('hyper_params') or {}
     # Same assignment order as before: epochs, batch_size, scales.
     for name, value in (('epochs', epochs),
                         ('batch_size', batch_size),
                         ('scales', scales)):
         if value is not None:
             setattr(self, name, value)
             if given_hyper_params.get(name) is not None:
                 warn("hyper parameter '{}' overridden by constructor "
                      "argument".format(name))
예제 #3
0
 def __init__(self,
              dataset,
              ray_trafo,
              filter_type='Hann',
              frequency_scaling=1.0):
     """
     Pair the ground truths of a CT dataset with FBPs of its observations.

     Each sample has two elements, both living in the dataset's
     ground-truth space (index ``1`` of ``dataset.space``).

     Parameters
     ----------
     dataset : :class:`.Dataset`
         CT dataset. FBPs are computed from the observations, the ground
         truth is taken directly from the dataset.
     ray_trafo : :class:`odl.tomo.RayTransform`
         Ray transform from which the FBP operator is constructed.
     filter_type : str, optional
         Filter type accepted by :func:`odl.tomo.fbp_op`.
         Default: ``'Hann'``.
     frequency_scaling : float, optional
         Relative cutoff frequency passed to :func:`odl.tomo.fbp_op`.
         Default: ``1.0``.
     """
     self.dataset = dataset
     self.ray_trafo = ray_trafo
     self.fbp_op = fbp_op(ray_trafo,
                          filter_type=filter_type,
                          frequency_scaling=frequency_scaling)
     # Query the length of every data part once, in train/validation/test
     # order.
     self.train_len, self.validation_len, self.test_len = (
         dataset.get_len(part_name)
         for part_name in ('train', 'validation', 'test'))
     gt_shape = dataset.shape[1]
     self.shape = (gt_shape, gt_shape)
     self.num_elements_per_sample = 2
     self.random_access = dataset.supports_random_access()
     gt_space = dataset.space[1]
     super().__init__(space=(gt_space, gt_space))
예제 #4
0
def generate_fbp_cache(dataset, part, filename, ray_trafo, size=None,
                       filter_type=None, frequency_scaling=None):
    """
    Write filtered back-projections for a CT dataset part to file.

    Parameters
    ----------
    dataset : :class:`.Dataset`
        CT dataset from which the observations are used.
    part : {``'train'``, ``'validation'``, ``'test'``}
        The data part.
    filename : str
        The filename to store the FBP cache at (ending ``.npy``).
    ray_trafo : :class:`odl.tomo.RayTransform`
        Ray transform from which the FBP operator is constructed.
    size : int, optional
        Number of samples to use from the dataset.
        By default, all samples are used.
    filter_type : str, optional
        Filter type passed to :func:`odl.tomo.fbp_op`. If ``None`` (the
        default), the default of :func:`odl.tomo.fbp_op` is used.
    frequency_scaling : float, optional
        Relative cutoff frequency passed to :func:`odl.tomo.fbp_op`. If
        ``None`` (the default), the default of :func:`odl.tomo.fbp_op` is
        used.
    """
    # Forward only the options that were given, so that calls without them
    # keep using the defaults of ``fbp_op`` exactly as before.
    fbp_kwargs = {}
    if filter_type is not None:
        fbp_kwargs['filter_type'] = filter_type
    if frequency_scaling is not None:
        fbp_kwargs['frequency_scaling'] = frequency_scaling
    fbp = fbp_op(ray_trafo, **fbp_kwargs)

    num_samples = dataset.get_len(part=part) if size is None else size
    reco_fbps = np.empty((num_samples,) + dataset.shape[1], dtype=np.float32)
    # Reusable buffers for the observation and the FBP result.
    obs = np.empty(dataset.shape[0], dtype=np.float32)
    tmp_fbp = fbp.range.element()
    for i in tqdm(range(num_samples), desc='generating FBP cache'):
        # Fill ``obs`` in place; NOTE(review): ``False`` presumably means the
        # ground truth is not requested -- confirm against ``get_sample``.
        dataset.get_sample(i, part=part, out=(obs, False))
        fbp(obs, out=tmp_fbp)
        reco_fbps[i] = tmp_fbp
    np.save(filename, reco_fbps)
예제 #5
0
    def __init__(self, ray_trafo, hyper_params=None, iterations=None, gamma=None, **kwargs):
        """
        Parameters
        ----------
        ray_trafo : `odl.tomo.operators.RayTransform`
            The forward operator
        hyper_params : dict, optional
            Hyper parameters, passed to the parent constructor.
        iterations : int, optional
            Value for the ``iterations`` hyper parameter. Takes precedence
            over ``hyper_params['iterations']`` (a warning is issued if both
            are given).
        gamma : float, optional
            Value for the ``gamma`` hyper parameter. Takes precedence over
            ``hyper_params['gamma']`` (a warning is issued if both are
            given).
        **kwargs
            Passed to the parent constructor.
        """

        super().__init__(
            reco_space=ray_trafo.domain, observation_space=ray_trafo.range,
            hyper_params=hyper_params, **kwargs)

        self.ray_trafo = ray_trafo
        self.domain_shape = ray_trafo.domain.shape
        self.opnorm = odl.power_method_opnorm(ray_trafo)
        self.fbp_op = fbp_op(
            ray_trafo, frequency_scaling=0.1, filter_type='Hann')

        # ``hyper_params`` is an explicit parameter, so it is never contained
        # in ``kwargs``; inspect it directly (it may also be ``None``).
        given_hyper_params = hyper_params or {}
        for name, value in (('iterations', iterations), ('gamma', gamma)):
            if value is not None:
                setattr(self, name, value)
                if given_hyper_params.get(name) is not None:
                    warn("hyper parameter '{}' overridden by constructor "
                         "argument".format(name))
예제 #6
0
    def init_model(self):
        """Build the primal-dual network ``self.model``.

        If ``hyper_params['init_fbp']`` is set, an FBP operator built from
        ``self.non_normed_ray_trafo`` is wrapped as ``self.init_mod`` and
        passed to the network as its initialization operator. If
        ``self.normalize_by_opnorm`` is set, that FBP is first combined with
        ``self.opnorm`` via ``OperatorRightScalarMult``. Otherwise
        ``self.init_mod`` is ``None``. All ``Conv2d`` layers get zero bias
        (when they have one) and Xavier-uniform weights. With
        ``self.use_cuda`` the model is moved to ``self.device``.
        """
        if self.hyper_params['init_fbp']:
            fbp = fbp_op(
                self.non_normed_ray_trafo,
                filter_type=self.hyper_params['init_filter_type'],
                frequency_scaling=self.hyper_params['init_frequency_scaling'])
            if self.normalize_by_opnorm:
                # NOTE(review): presumably compensates for the ray transform
                # being scaled by its operator norm -- confirm.
                fbp = OperatorRightScalarMult(fbp, self.opnorm)
            self.init_mod = OperatorModule(fbp)
        else:
            self.init_mod = None
        self.model = PrimalDualNet(
            n_iter=self.niter,
            op=self.ray_trafo_mod,
            op_adj=self.ray_trafo_adj_mod,
            op_init=self.init_mod,
            n_primal=self.hyper_params['nprimal'],
            n_dual=self.hyper_params['ndual'],
            use_sigmoid=self.hyper_params['use_sigmoid'],
            internal_ch=self.hyper_params['internal_ch'],
            kernel_size=self.hyper_params['kernel_size'],
            batch_norm=self.hyper_params['batch_norm'],
            prelu=self.hyper_params['prelu'],
            lrelu_coeff=self.hyper_params['lrelu_coeff'])

        def weights_init(m):
            if isinstance(m, torch.nn.Conv2d):
                # Convolutions created with ``bias=False`` have
                # ``m.bias is None``; only zero an existing bias.
                if m.bias is not None:
                    torch.nn.init.zeros_(m.bias)
                torch.nn.init.xavier_uniform_(m.weight)

        self.model.apply(weights_init)

        if self.use_cuda:
            # WARNING: using data-parallel here doesn't work because of astra-gpu
            self.model = self.model.to(self.device)
예제 #7
0
 def _reconstruct(self, observation, out):
     if self.pre_processor is not None:
         observation = self.pre_processor(observation)
     if self.recompute_fbp_op:
         self.fbp_op = fbp_op(self.ray_trafo, padding=self.padding,
                              **self.hyper_params)
     self.fbp_op(observation, out=out)
     if self.post_processor is not None:
         out[:] = self.post_processor(out)
예제 #8
0
    def _reconstruct(self, observation, *args, **kwargs):
        """Reconstruct by Adam-based minimization of data fit + TV penalty.

        Starting from an FBP of `observation`, minimizes
        ``criterion(ray_trafo_module(x), y) + gamma * tv_loss(x)`` for
        ``self.iterations`` steps. Returns the iterate with the lowest
        observed loss as an element of ``self.reco_space``.
        """
        self.fbp_op = fbp_op(self.ray_trafo,
                             filter_type=self.init_filter_type,
                             frequency_scaling=self.init_frequency_scaling)

        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

        # The FBP initial guess gets a leading singleton axis (``[None]``),
        # matching ``y_delta.view(1, ...)`` below.
        self.output = torch.tensor(self.fbp_op(observation))[None].to(device)
        self.output.requires_grad = True

        self.optimizer = Adam([self.output], lr=self.lr)

        y_delta = torch.tensor(np.asarray(observation), dtype=torch.float32)
        y_delta = y_delta.view(1, *y_delta.shape)
        y_delta = y_delta.to(device)

        if self.loss_function == 'mse':
            criterion = MSELoss()
        elif self.loss_function == 'poisson':
            criterion = partial(poisson_loss,
                                photons_per_pixel=self.photons_per_pixel,
                                mu_max=self.mu_max)
        else:
            warn('Unknown loss function, falling back to MSE')
            criterion = MSELoss()

        # ``np.infty`` was removed in NumPy 2.0; ``np.inf`` works everywhere.
        best_loss = np.inf
        best_output = self.output.detach().clone()

        for i in tqdm(range(self.iterations),
                      desc='TV',
                      disable=not self.show_pbar):
            self.optimizer.zero_grad()
            loss = criterion(self.ray_trafo_module(self.output),
                             y_delta) + self.gamma * tv_loss(self.output)
            loss.backward()

            # ``loss`` was evaluated at the current (pre-step) iterate, so
            # snapshot that iterate before the optimizer modifies it;
            # otherwise ``best_output`` would not match ``best_loss``.
            loss_value = loss.item()
            if loss_value < best_loss:
                best_loss = loss_value
                best_output = self.output.detach().clone()

            self.optimizer.step()

            if (self.callback_func is not None
                    and (i % self.callback_func_interval == 0
                         or i == self.iterations - 1)):
                self.callback_func(
                    iteration=i,
                    reconstruction=best_output[0, ...].cpu().numpy(),
                    loss=best_loss)

            if self.callback is not None:
                self.callback(
                    self.reco_space.element(best_output[0, ...].cpu().numpy()))

        return self.reco_space.element(best_output[0, ...].cpu().numpy())
예제 #9
0
def generate_dataset_cache(dataset,
                           part,
                           file_names,
                           ray_trafo,
                           filter_type='Hann',
                           frequency_scaling=1.0,
                           size=None,
                           only_fbp=False):
    """
    Write data pairs and FBPs for a CT dataset part to ``.npy`` files.

    Parameters
    ----------
    dataset : :class:`.Dataset`
        CT dataset from which the samples are taken.
    part : {``'train'``, ``'validation'``, ``'test'``}
        The data part.
    file_names : sequence of str
        The filenames to store the cache at (ending ``.npy``). If
        `only_fbp` is ``False``, the ground truths, observations and FBPs
        are saved to the first, second and third name, respectively. If
        `only_fbp` is ``True``, only ``file_names[0]`` is used (for the
        FBPs).
    ray_trafo : :class:`odl.tomo.RayTransform`
        Ray transform from which the FBP operator is constructed.
    filter_type : str, optional
        Filter type passed to :func:`odl.tomo.fbp_op`. Default: ``'Hann'``.
    frequency_scaling : float, optional
        Relative cutoff frequency passed to :func:`odl.tomo.fbp_op`.
        Default: ``1.0``.
    size : int, optional
        Number of samples to use from the dataset.
        By default, all samples are used.
    only_fbp : bool, optional
        Whether to store only the FBPs. Default: ``False``.

    Raises
    ------
    ValueError
        If the dataset part yields fewer samples than requested. Nothing is
        written in that case, because the unfilled rows would contain
        uninitialized memory.
    """
    from itertools import islice

    fbp = fbp_op(ray_trafo,
                 filter_type=filter_type,
                 frequency_scaling=frequency_scaling)
    num_samples = dataset.get_len(part=part) if size is None else size

    if not only_fbp:
        gts_data = np.empty((num_samples,) + ray_trafo.domain.shape,
                            dtype=np.float32)
        obs_data = np.empty((num_samples,) + ray_trafo.range.shape,
                            dtype=np.float32)

    fbp_data = np.empty((num_samples,) + ray_trafo.domain.shape,
                        dtype=np.float32)

    tmp_fbp = fbp.range.element()
    num_written = 0
    samples = islice(dataset.generator(part), num_samples)
    for i, (obs, gt) in enumerate(tqdm(samples, total=num_samples,
                                       desc='generating cache')):
        fbp(obs, out=tmp_fbp)

        fbp_data[i, ...] = tmp_fbp

        if not only_fbp:
            obs_data[i, ...] = obs
            gts_data[i, ...] = gt
        num_written = i + 1

    # ``np.empty`` leaves rows that were not filled uninitialized; never save
    # such garbage silently.
    if num_written < num_samples:
        raise ValueError(
            'dataset part {!r} yielded only {:d} samples, but {:d} were '
            'requested'.format(part, num_written, num_samples))

    if only_fbp:
        np.save(file_names[0], fbp_data)
    else:
        for filename, data in zip(file_names, [gts_data, obs_data, fbp_data]):
            np.save(filename, data)
예제 #10
0
 def _reconstruct(self, observation, out):
     if self.pre_processor is not None:
         observation = self.pre_processor(observation)
     if self.recompute_fbp_op:
         self.fbp_op = fbp_op(self.ray_trafo,
                              padding=self.padding,
                              **self.hyper_params)
     if out in self.reco_space:
         self.fbp_op(observation, out=out)
     else:  # out is e.g. numpy array, cannot be passed to fbp_op
         out[:] = self.fbp_op(observation)
     if self.post_processor is not None:
         out[:] = self.post_processor(out)
예제 #11
0
    def init_model(self):
        """Create the FBP operator and the U-Net model.

        Sets ``self.fbp_op`` (built from ``self.op`` with ``self.filter_type``
        and ``self.frequency_scaling``) and ``self.model`` (a single-channel
        U-Net using the first ``self.scales`` entries of ``self.channels``).
        If ``self.init_bias_zero`` is set, the biases of all ``Conv2d``
        layers are zeroed. If ``self.use_cuda`` is set, the model is wrapped
        in ``nn.DataParallel`` and moved to ``self.device``.
        """
        self.fbp_op = fbp_op(self.op, filter_type=self.filter_type,
                             frequency_scaling=self.frequency_scaling)
        self.model = UNet(in_ch=1, out_ch=1,
                          channels=self.channels[:self.scales],
                          # same number of skip channels at every scale
                          skip_channels=[self.skip_channels] * self.scales,
                          use_sigmoid=self.use_sigmoid)
        if self.init_bias_zero:
            def weights_init(m):
                # Convolutions created with ``bias=False`` have
                # ``m.bias is None``; skip them instead of crashing.
                if isinstance(m, torch.nn.Conv2d) and m.bias is not None:
                    torch.nn.init.zeros_(m.bias)
            self.model.apply(weights_init)

        if self.use_cuda:
            self.model = nn.DataParallel(self.model).to(self.device)
예제 #12
0
    def init_model(self):
        """Build the learned iterative network ``self.model``.

        Wraps ``self.op``, its adjoint and the regularization operator
        ``D0^* D0 + D1^* D1`` (partial derivatives along axes 0 and 1 of
        ``self.op.domain``) as torch modules. If ``hyper_params['init_fbp']``
        is set, an FBP built from ``self.non_normed_op`` is used as the
        initialization operator. If ``self.normalize_by_opnorm`` is set, that
        FBP is first combined with ``self.opnorm`` via
        ``OperatorRightScalarMult``. All ``Conv2d`` layers get zero bias
        (when they have one) and optionally Xavier-normal weights. With
        ``self.use_cuda`` the model is moved to ``self.device``.
        """
        self.op_mod = OperatorModule(self.op)
        self.op_adj_mod = OperatorModule(self.op.adjoint)
        partial0 = odl.PartialDerivative(self.op.domain, axis=0)
        partial1 = odl.PartialDerivative(self.op.domain, axis=1)
        self.reg_mod = OperatorModule(partial0.adjoint * partial0 +
                                      partial1.adjoint * partial1)
        if self.hyper_params['init_fbp']:
            fbp = fbp_op(
                self.non_normed_op,
                filter_type=self.hyper_params['init_filter_type'],
                frequency_scaling=self.hyper_params['init_frequency_scaling'])
            if self.normalize_by_opnorm:
                # NOTE(review): presumably compensates for the operator being
                # scaled by its operator norm -- confirm.
                fbp = OperatorRightScalarMult(fbp, self.opnorm)
            self.init_mod = OperatorModule(fbp)
        else:
            self.init_mod = None
        self.model = IterativeNet(n_iter=self.niter,
                                  n_memory=5,
                                  op=self.op_mod,
                                  op_adj=self.op_adj_mod,
                                  op_init=self.init_mod,
                                  op_reg=self.reg_mod,
                                  use_sigmoid=self.hyper_params['use_sigmoid'],
                                  n_layer=self.hyper_params['nlayer'],
                                  internal_ch=self.hyper_params['internal_ch'],
                                  kernel_size=self.hyper_params['kernel_size'],
                                  batch_norm=self.hyper_params['batch_norm'],
                                  prelu=self.hyper_params['prelu'],
                                  lrelu_coeff=self.hyper_params['lrelu_coeff'])

        def weights_init(m):
            if isinstance(m, torch.nn.Conv2d):
                # Convolutions created with ``bias=False`` have
                # ``m.bias is None``; only zero an existing bias.
                if m.bias is not None:
                    torch.nn.init.zeros_(m.bias)
                if self.hyper_params['init_weight_xavier_normal']:
                    torch.nn.init.xavier_normal_(
                        m.weight, gain=self.hyper_params['init_weight_gain'])

        self.model.apply(weights_init)

        if self.use_cuda:
            # WARNING: using data-parallel here doesn't work, probably
            # astra_cuda is not thread-safe
            self.model = self.model.to(self.device)
예제 #13
0
 def __init__(self, dataset, ray_trafo):
     """
     Pair the ground truths of a CT dataset with FBPs of its observations.

     Parameters
     ----------
     dataset : :class:`.Dataset`
         CT dataset. FBPs are computed from the observations, the ground
         truth is taken directly from the dataset.
     ray_trafo : :class:`odl.tomo.RayTransform`
         Ray transform from which the FBP operator is constructed (using
         the default options of :func:`odl.tomo.fbp_op`).
     """
     self.dataset = dataset
     self.fbp_op = fbp_op(ray_trafo)
     part_lengths = [dataset.get_len(name)
                     for name in ('train', 'validation', 'test')]
     self.train_len, self.validation_len, self.test_len = part_lengths
     # Both elements of a sample live in the ground-truth space (index 1).
     self.shape = (dataset.shape[1],) * 2
     self.num_elements_per_sample = 2
     self.random_access = dataset.supports_random_access()
     super().__init__(space=(dataset.space[1],) * 2)
예제 #14
0
    def init_model(self):
        """Create the FBP operator and the U-Net model.

        Sets ``self.fbp_op`` (built from ``self.ray_trafo`` with
        ``self.filter_type`` and ``self.frequency_scaling``) and
        ``self.model`` (built by ``get_unet_model``). If
        ``self.init_bias_zero`` is set, the biases of all ``Conv2d`` layers
        are zeroed. If ``self.use_cuda`` is set, the model is wrapped in
        ``nn.DataParallel`` and moved to ``self.device``.
        """
        self.fbp_op = fbp_op(self.ray_trafo,
                             filter_type=self.filter_type,
                             frequency_scaling=self.frequency_scaling)
        self.model = get_unet_model(scales=self.scales,
                                    skip=self.skip_channels,
                                    channels=self.channels,
                                    use_sigmoid=self.use_sigmoid)

        if self.init_bias_zero:

            def weights_init(m):
                # Convolutions created with ``bias=False`` have
                # ``m.bias is None``; skip them instead of crashing.
                if isinstance(m, torch.nn.Conv2d) and m.bias is not None:
                    torch.nn.init.zeros_(m.bias)

            self.model.apply(weights_init)

        if self.use_cuda:
            self.model = nn.DataParallel(self.model).to(self.device)
예제 #15
0
 def __init__(self,
              ray_trafo,
              padding=True,
              hyper_params=None,
              pre_processor=None,
              post_processor=None,
              recompute_fbp_op=True,
              **kwargs):
     """
     Parameters
     ----------
     ray_trafo : `odl.tomo.operators.RayTransform`
         The forward operator. See `odl.tomo.fbp_op` for details.
     padding : bool, optional
         Whether to use padding (the default is ``True``).
         See `odl.tomo.fbp_op` for details.
     hyper_params : dict, optional
         Hyper parameters, passed to the parent constructor. The resulting
         :attr:`hyper_params` are forwarded as keyword arguments to
         `odl.tomo.fbp_op`.
     pre_processor : callable, optional
         Callable that takes the observation and returns the sinogram that
         is passed to the filtered back-projection operator.
     post_processor : callable, optional
         Callable that takes the filtered back-projection and returns the
         final reconstruction.
     recompute_fbp_op : bool, optional
         Whether :attr:`fbp_op` should be recomputed on each call to
         :meth:`reconstruct`. Must be ``True`` (default) if changes to
         :attr:`ray_trafo`, :attr:`hyper_params` or :attr:`padding` are
         planned in order to use the updated values in :meth:`reconstruct`.
         If none of these attributes will change, you may specify
         ``recompute_fbp_op=False``, so :attr:`fbp_op` is computed only
         once, improving reconstruction time efficiency.
     **kwargs
         Passed to the parent constructor.
     """
     # Attributes that must exist before the parent constructor runs.
     (self.ray_trafo, self.padding,
      self.pre_processor, self.post_processor) = (
          ray_trafo, padding, pre_processor, post_processor)
     super().__init__(reco_space=ray_trafo.domain,
                      observation_space=ray_trafo.range,
                      hyper_params=hyper_params,
                      **kwargs)
     # Initial operator, built from the hyper parameters as finalized by the
     # parent constructor.
     self.fbp_op = fbp_op(self.ray_trafo, padding=self.padding,
                          **self.hyper_params)
     self.recompute_fbp_op = recompute_fbp_op
예제 #16
0
    def __init__(self, ray_trafo, hyper_params=None, callback=None,
                 callback_func=None, callback_func_interval=100, **kwargs):
        """
        Parameters
        ----------
        ray_trafo : `odl.tomo.operators.RayTransform`
            The forward operator
        hyper_params : dict, optional
            Hyper parameters, passed to the parent constructor.
        callback : callable, optional
            Passed to the parent constructor.
        callback_func : callable, optional
            Stored as :attr:`callback_func`.
        callback_func_interval : int, optional
            Stored as :attr:`callback_func_interval` (default ``100``).
        **kwargs
            Passed to the parent constructor.
        """

        super().__init__(
            reco_space=ray_trafo.domain, observation_space=ray_trafo.range,
            hyper_params=hyper_params, callback=callback, **kwargs)

        self.fbp_op = fbp_op(
            ray_trafo, frequency_scaling=0.1, filter_type='Hann')
        self.ray_trafo = ray_trafo
        # The ray transform wrapped as a torch module via ``OperatorModule``.
        self.ray_trafo_module = OperatorModule(self.ray_trafo)
        self.domain_shape = ray_trafo.domain.shape
        self.callback_func = callback_func
        self.callback_func_interval = callback_func_interval
예제 #17
0
    def __init__(self,
                 ray_trafo,
                 filter_type=None,
                 frequency_scaling=None,
                 scales=None,
                 epochs=None,
                 batch_size=None,
                 lr=None,
                 skip_channels=None,
                 num_data_loader_workers=8,
                 use_cuda=True,
                 show_pbar=True,
                 fbp_impl='astra_cuda',
                 hyper_params=None,
                 **kwargs):
        """
        Parameters
        ----------
        ray_trafo : :class:`odl.tomo.RayTransform`
            Ray transform from which the FBP operator is constructed.
        filter_type : str, optional
            Filter type passed to :func:`odl.tomo.fbp_op`
            (a hyper parameter).
        frequency_scaling : float, optional
            Relative cutoff frequency passed to :func:`odl.tomo.fbp_op`
            (a hyper parameter).
        scales : int, optional
            Number of scales in the U-Net (a hyper parameter).
        epochs : int, optional
            Number of epochs to train (a hyper parameter).
        batch_size : int, optional
            Batch size (a hyper parameter).
        lr : float, optional
            Passed to the parent constructor (presumably the learning rate).
        skip_channels : int, optional
            Value for the ``skip_channels`` hyper parameter of the U-Net.
        num_data_loader_workers : int, optional
            Number of parallel workers to use for loading data.
        use_cuda : bool, optional
            Whether to use cuda for the U-Net.
        show_pbar : bool, optional
            Whether to show tqdm progress bars during the epochs.
        fbp_impl : str, optional
            The backend implementation passed to
            :class:`odl.tomo.RayTransform` in case no `ray_trafo` is specified.
            Then ``dataset.get_ray_trafo(impl=fbp_impl)`` is used to get the
            ray transform and FBP operator.
        hyper_params : dict, optional
            Hyper parameters, passed to the parent constructor. Non-``None``
            values of `scales`, `skip_channels`, `filter_type` and
            `frequency_scaling` take precedence over the corresponding
            entries; a warning is issued if both are given.
        **kwargs
            Passed to the parent constructor.
        """

        super().__init__(ray_trafo,
                         epochs=epochs,
                         batch_size=batch_size,
                         lr=lr,
                         num_data_loader_workers=num_data_loader_workers,
                         use_cuda=use_cuda,
                         show_pbar=show_pbar,
                         fbp_impl=fbp_impl,
                         hyper_params=hyper_params,
                         **kwargs)

        # ``hyper_params`` is an explicit parameter, so it is never contained
        # in ``kwargs``; inspect it directly (it may also be ``None``).
        given_hyper_params = hyper_params or {}
        for name, value in (('scales', scales),
                            ('skip_channels', skip_channels),
                            ('filter_type', filter_type),
                            ('frequency_scaling', frequency_scaling)):
            if value is not None:
                setattr(self, name, value)
                if given_hyper_params.get(name) is not None:
                    warn("hyper parameter '{}' overridden by constructor "
                         "argument".format(name))

        # TODO: update fbp_op when the hyper parameters change?
        self.fbp_op = fbp_op(ray_trafo,
                             filter_type=self.filter_type,
                             frequency_scaling=self.frequency_scaling)