Esempio n. 1
0
    def __init__(self, ray_trafo, niter=None, **kwargs):
        """
        Parameters
        ----------
        ray_trafo : :class:`odl.tomo.RayTransform`
            Ray transform (forward operator). It is passed on to the base
            class, which may store a normalized version as
            ``self.ray_trafo``; the operator as given is kept as
            ``self.non_normed_ray_trafo``.
        niter : int, optional
            Number of iteration blocks. If given, it overrides a ``'niter'``
            entry in ``hyper_params``; a warning is issued when both are
            specified.
        """
        super().__init__(ray_trafo, **kwargs)

        # NOTE: self.ray_trafo is possibly normalized, while ray_trafo is not
        self.non_normed_ray_trafo = ray_trafo

        if niter is not None:
            # ``hyper_params`` may be absent or explicitly ``None``; treat
            # both as "no hyper parameters given". Check for the conflict
            # before assigning, so the check only sees what the caller passed.
            given_hyper_params = kwargs.get('hyper_params') or {}
            if given_hyper_params.get('niter') is not None:
                warn("hyper parameter 'niter' overridden by constructor "
                     "argument")
            self.niter = niter

        # Torch-module wrappers around the (possibly normalized) operator and
        # its adjoint.
        self.ray_trafo_mod = OperatorModule(self.ray_trafo)
        self.ray_trafo_adj_mod = OperatorModule(self.ray_trafo.adjoint)

        # Regularization operator D0^* D0 + D1^* D1, built from partial
        # derivatives along axes 0 and 1 of the reconstruction space.
        partial0 = odl.PartialDerivative(self.ray_trafo.domain, axis=0)
        partial1 = odl.PartialDerivative(self.ray_trafo.domain, axis=1)
        self.reg_mod = OperatorModule(partial0.adjoint * partial0 +
                                      partial1.adjoint * partial1)
Esempio n. 2
0
    def __init__(self, ray_trafo, **kwargs):
        """
        Parameters
        ----------
        ray_trafo : :class:`odl.tomo.RayTransform`
            Forward operator. The base class may keep a normalized version
            as ``self.ray_trafo``; the operator exactly as passed in is
            stored as ``self.non_normed_ray_trafo``.
        """
        super().__init__(ray_trafo, **kwargs)

        # Keep the un-normalized operator; self.ray_trafo may differ from it.
        self.non_normed_ray_trafo = ray_trafo

        # Wrap the (possibly normalized) operator and its adjoint as
        # torch modules.
        forward_op = self.ray_trafo
        self.ray_trafo_mod = OperatorModule(forward_op)
        self.ray_trafo_adj_mod = OperatorModule(forward_op.adjoint)
Esempio n. 3
0
    def init_model(self):
        """Build ``self.model`` (a ``PrimalDualNet``) and place it on the device.

        If ``hyper_params['init_fbp']`` is set, an FBP operator built from
        ``self.non_normed_ray_trafo`` is wrapped as the initial-reconstruction
        module. When ``self.normalize_by_opnorm`` is set, it is composed with
        a right scalar multiplication by ``self.opnorm``. Otherwise no
        initializer is used. Every ``Conv2d`` layer gets a zero bias (if it
        has one) and Xavier-uniform weights.
        """
        if self.hyper_params['init_fbp']:
            fbp = fbp_op(
                self.non_normed_ray_trafo,
                filter_type=self.hyper_params['init_filter_type'],
                frequency_scaling=self.hyper_params['init_frequency_scaling'])
            if self.normalize_by_opnorm:
                fbp = OperatorRightScalarMult(fbp, self.opnorm)
            self.init_mod = OperatorModule(fbp)
        else:
            self.init_mod = None
        self.model = PrimalDualNet(
            n_iter=self.niter,
            op=self.ray_trafo_mod,
            op_adj=self.ray_trafo_adj_mod,
            op_init=self.init_mod,
            n_primal=self.hyper_params['nprimal'],
            n_dual=self.hyper_params['ndual'],
            use_sigmoid=self.hyper_params['use_sigmoid'],
            internal_ch=self.hyper_params['internal_ch'],
            kernel_size=self.hyper_params['kernel_size'],
            batch_norm=self.hyper_params['batch_norm'],
            prelu=self.hyper_params['prelu'],
            lrelu_coeff=self.hyper_params['lrelu_coeff'])

        def weights_init(m):
            if isinstance(m, torch.nn.Conv2d):
                # Conv layers created with ``bias=False`` have ``bias=None``;
                # zeroing it unconditionally would raise AttributeError.
                if m.bias is not None:
                    m.bias.data.fill_(0.0)
                torch.nn.init.xavier_uniform_(m.weight)

        self.model.apply(weights_init)

        if self.use_cuda:
            # WARNING: using data-parallel here doesn't work because of astra-gpu
            self.model = self.model.to(self.device)
Esempio n. 4
0
    def init_model(self):
        """Build ``self.model`` (an ``IterativeNet``) and place it on the device.

        This wraps ``self.op``, its adjoint, and the regularization operator
        ``D0^* D0 + D1^* D1`` as torch modules. ``D0`` and ``D1`` are
        ``odl.PartialDerivative`` along axes 0 and 1.

        If ``hyper_params['init_fbp']`` is set, an FBP operator built from
        ``self.non_normed_op`` is used as the initial-reconstruction module.
        When ``self.normalize_by_opnorm`` is set, it is composed with a right
        scalar multiplication by ``self.opnorm``.

        ``Conv2d`` biases (if present) are zeroed. Weights are re-initialized
        with Xavier-normal only when
        ``hyper_params['init_weight_xavier_normal']`` is set.
        """
        self.op_mod = OperatorModule(self.op)
        self.op_adj_mod = OperatorModule(self.op.adjoint)
        partial0 = odl.PartialDerivative(self.op.domain, axis=0)
        partial1 = odl.PartialDerivative(self.op.domain, axis=1)
        self.reg_mod = OperatorModule(partial0.adjoint * partial0 +
                                      partial1.adjoint * partial1)
        if self.hyper_params['init_fbp']:
            fbp = fbp_op(
                self.non_normed_op,
                filter_type=self.hyper_params['init_filter_type'],
                frequency_scaling=self.hyper_params['init_frequency_scaling'])
            if self.normalize_by_opnorm:
                fbp = OperatorRightScalarMult(fbp, self.opnorm)
            self.init_mod = OperatorModule(fbp)
        else:
            self.init_mod = None
        self.model = IterativeNet(n_iter=self.niter,
                                  n_memory=5,  # fixed, not a hyper parameter
                                  op=self.op_mod,
                                  op_adj=self.op_adj_mod,
                                  op_init=self.init_mod,
                                  op_reg=self.reg_mod,
                                  use_sigmoid=self.hyper_params['use_sigmoid'],
                                  n_layer=self.hyper_params['nlayer'],
                                  internal_ch=self.hyper_params['internal_ch'],
                                  kernel_size=self.hyper_params['kernel_size'],
                                  batch_norm=self.hyper_params['batch_norm'],
                                  prelu=self.hyper_params['prelu'],
                                  lrelu_coeff=self.hyper_params['lrelu_coeff'])

        def weights_init(m):
            if isinstance(m, torch.nn.Conv2d):
                # Conv layers created with ``bias=False`` have ``bias=None``;
                # zeroing it unconditionally would raise AttributeError.
                if m.bias is not None:
                    m.bias.data.fill_(0.0)
                if self.hyper_params['init_weight_xavier_normal']:
                    torch.nn.init.xavier_normal_(
                        m.weight, gain=self.hyper_params['init_weight_gain'])

        self.model.apply(weights_init)

        if self.use_cuda:
            # WARNING: using data-parallel here doesn't work, probably
            # astra_cuda is not thread-safe
            self.model = self.model.to(self.device)
Esempio n. 5
0
    def __init__(self,
                 ray_trafo,
                 callback_func=None,
                 callback_func_interval=100,
                 show_pbar=True,
                 torch_manual_seed=10,
                 **kwargs):
        """
        Parameters
        ----------
        ray_trafo : `odl.tomo.operators.RayTransform`
            The forward operator. Its domain and range are used as the
            reconstruction and observation spaces.
        callback_func : callable, optional
            Callable with signature
            ``callback_func(iteration, reconstruction, loss)``. It is called
            every `callback_func_interval` iterations, starting after the
            first iteration, and once more after the last iteration.
            It differs from the inherited `IterativeReconstructor.callback`
            (which is also supported). That callback is of type
            :class:`odl.solvers.util.callback.Callback` and only receives the
            reconstruction, so the loss would have to be recomputed.
        callback_func_interval : int, optional
            Number of iterations between calls to `callback_func`.
            Default: `100`.
        show_pbar : bool, optional
            Whether to show a tqdm progress bar during reconstruction.
        torch_manual_seed : int, optional
            Fixed seed to set by ``torch.manual_seed`` before reconstruction.
            The default is `10`. Set it to `None` or `False` to disable the
            manual seed.
        """

        super().__init__(reco_space=ray_trafo.domain,
                         observation_space=ray_trafo.range,
                         **kwargs)

        self.ray_trafo = ray_trafo
        # Torch-module wrapper around the forward operator.
        self.ray_trafo_module = OperatorModule(self.ray_trafo)
        self.callback_func = callback_func
        self.callback_func_interval = callback_func_interval
        self.show_pbar = show_pbar
        self.torch_manual_seed = torch_manual_seed
    def __init__(self, ray_trafo, ini_reco, hyper_params=None, callback=None,
                 callback_func=None, callback_func_interval=100, **kwargs):
        """
        Parameters
        ----------
        ray_trafo : `odl.tomo.operators.RayTransform`
            The forward operator. Its domain and range are used as the
            reconstruction and observation spaces.
        ini_reco : `dival.Reconstructor`
            Reconstructor used for the initial reconstruction.
        hyper_params : dict, optional
            Forwarded unchanged to the base-class constructor.
        callback : optional
            Forwarded unchanged to the base-class constructor.
        callback_func : callable, optional
            Stored as ``self.callback_func``.
        callback_func_interval : int, optional
            Stored as ``self.callback_func_interval``. Default: ``100``.
        """
        super().__init__(reco_space=ray_trafo.domain,
                         observation_space=ray_trafo.range,
                         hyper_params=hyper_params,
                         callback=callback,
                         **kwargs)

        # Forward operator and its torch-module wrapper.
        self.ray_trafo = ray_trafo
        self.ray_trafo_module = OperatorModule(self.ray_trafo)
        self.domain_shape = ray_trafo.domain.shape

        # Initial reconstructor and callback settings.
        self.ini_reco = ini_reco
        self.callback_func = callback_func
        self.callback_func_interval = callback_func_interval
Esempio n. 7
0
    def __init__(self, ray_trafo, hyper_params=None, callback=None,
                 callback_func=None, callback_func_interval=100, **kwargs):
        """
        Parameters
        ----------
        ray_trafo : `odl.tomo.operators.RayTransform`
            The forward operator. Its domain and range are used as the
            reconstruction and observation spaces. An FBP operator (Hann
            filter, frequency scaling 0.1) built from it is stored as
            ``self.fbp_op``.
        hyper_params : dict, optional
            Forwarded unchanged to the base-class constructor.
        callback : optional
            Forwarded unchanged to the base-class constructor.
        callback_func : callable, optional
            Stored as ``self.callback_func``.
        callback_func_interval : int, optional
            Stored as ``self.callback_func_interval``. Default: ``100``.
        """

        super().__init__(
            reco_space=ray_trafo.domain, observation_space=ray_trafo.range,
            hyper_params=hyper_params, callback=callback, **kwargs)

        # NOTE(review): FBP settings are hard-coded here, not hyper params.
        self.fbp_op = fbp_op(
            ray_trafo, frequency_scaling=0.1, filter_type='Hann')
        self.ray_trafo = ray_trafo
        self.ray_trafo_module = OperatorModule(self.ray_trafo)
        self.domain_shape = ray_trafo.domain.shape
        self.callback_func = callback_func
        self.callback_func_interval = callback_func_interval
Esempio n. 8
0
 def __init__(self, ray_trafo, **kwargs):
     """Store the forward operator, wrap it as a module and build the model.

     Parameters
     ----------
     ray_trafo : operator
         Forward operator (presumably an ``odl.tomo.RayTransform`` -- TODO
         confirm). It is assigned to ``self.ray_trafo`` after base-class
         initialization, replacing any value the base class may have set.
     """
     super().__init__(ray_trafo, **kwargs)
     self.ray_trafo = ray_trafo
     # Torch-module wrapper around the stored operator.
     self.ray_trafo_module = OperatorModule(self.ray_trafo)
     # Build the network only after the operator module exists.
     self.init_model()