def __init__(self, ray_trafo, niter=None, **kwargs):
    """
    Set up the operator modules used by the iterative network.

    Parameters
    ----------
    ray_trafo : :class:`odl.tomo.RayTransform`
        Ray transform from which the FBP operator is constructed.
    niter : int, optional
        Number of iteration blocks. If not ``None``, it is assigned to
        ``self.niter``; a warning is issued if the ``hyper_params``
        passed via `kwargs` also contain a non-``None`` ``'niter'``.
    **kwargs
        Forwarded unchanged to the parent constructor.
    """
    super().__init__(ray_trafo, **kwargs)

    # NOTE: self.ray_trafo is possibly normalized, while ray_trafo is not
    self.non_normed_ray_trafo = ray_trafo

    if niter is not None:
        self.niter = niter
        # ``hyper_params`` may be missing *or* explicitly ``None``; the
        # ``dict.get`` default only covers the former, so fall back to an
        # empty dict for both cases before looking up 'niter'.
        hyper_params = kwargs.get('hyper_params') or {}
        if hyper_params.get('niter') is not None:
            warn("hyper parameter 'niter' overridden by constructor "
                 "argument")

    # Wrap the (possibly normalized) forward operator and its adjoint.
    self.ray_trafo_mod = OperatorModule(self.ray_trafo)
    self.ray_trafo_adj_mod = OperatorModule(self.ray_trafo.adjoint)

    # Regularization operator: sum over both image axes of D_i^* D_i,
    # where D_i is the partial derivative along axis i.
    partial0 = odl.PartialDerivative(self.ray_trafo.domain, axis=0)
    partial1 = odl.PartialDerivative(self.ray_trafo.domain, axis=1)
    self.reg_mod = OperatorModule(partial0.adjoint * partial0 +
                                  partial1.adjoint * partial1)
def __init__(self, ray_trafo, **kwargs):
    """
    Store the given ray transform and wrap the forward operator and its
    adjoint in operator modules.

    Parameters
    ----------
    ray_trafo : :class:`odl.tomo.RayTransform`
        Ray transform from which the FBP operator is constructed.
    **kwargs
        Forwarded unchanged to the parent constructor.
    """
    super().__init__(ray_trafo, **kwargs)

    # Keep a handle on the transform exactly as passed in:
    # self.ray_trafo is possibly normalized, while ray_trafo is not.
    self.non_normed_ray_trafo = ray_trafo

    forward_op = self.ray_trafo
    self.ray_trafo_mod = OperatorModule(forward_op)
    self.ray_trafo_adj_mod = OperatorModule(forward_op.adjoint)
def init_model(self):
    """
    Build ``self.model`` (a ``PrimalDualNet``) and initialize its weights.

    Steps:

    1. If the hyper parameter ``'init_fbp'`` is set, create an FBP
       operator from the non-normalized ray transform and wrap it in
       ``self.init_mod``; otherwise ``self.init_mod`` is ``None``.
    2. Instantiate ``PrimalDualNet`` from the operator modules and the
       hyper parameters.
    3. For every ``torch.nn.Conv2d`` in the model, zero the bias (if the
       layer has one) and apply Xavier-uniform initialization to the
       weight.
    4. Move the model to ``self.device`` if ``self.use_cuda`` is set.
    """
    hp = self.hyper_params

    if hp['init_fbp']:
        fbp = fbp_op(
            self.non_normed_ray_trafo,
            filter_type=hp['init_filter_type'],
            frequency_scaling=hp['init_frequency_scaling'])
        if self.normalize_by_opnorm:
            # The FBP is built from the non-normalized transform, so it is
            # combined with ``self.opnorm`` here (presumably to match
            # observations produced by the normalized operator -- confirm).
            fbp = OperatorRightScalarMult(fbp, self.opnorm)
        self.init_mod = OperatorModule(fbp)
    else:
        self.init_mod = None

    self.model = PrimalDualNet(
        n_iter=self.niter,
        op=self.ray_trafo_mod,
        op_adj=self.ray_trafo_adj_mod,
        op_init=self.init_mod,
        n_primal=hp['nprimal'],
        n_dual=hp['ndual'],
        use_sigmoid=hp['use_sigmoid'],
        internal_ch=hp['internal_ch'],
        kernel_size=hp['kernel_size'],
        batch_norm=hp['batch_norm'],
        prelu=hp['prelu'],
        lrelu_coeff=hp['lrelu_coeff'])

    def weights_init(m):
        if isinstance(m, torch.nn.Conv2d):
            # ``bias`` is None for layers created with ``bias=False``;
            # skip the fill in that case instead of raising.
            if m.bias is not None:
                m.bias.data.fill_(0.0)
            torch.nn.init.xavier_uniform_(m.weight)
    self.model.apply(weights_init)

    if self.use_cuda:
        # WARNING: using data-parallel here doesn't work because of astra-gpu
        self.model = self.model.to(self.device)
def init_model(self):
    """
    Build the operator modules and ``self.model`` (an ``IterativeNet``),
    then initialize the model weights.

    Steps:

    1. Wrap ``self.op`` and its adjoint in operator modules, and build
       the regularization module from the partial derivatives along
       axes 0 and 1 of the operator domain.
    2. If the hyper parameter ``'init_fbp'`` is set, create an FBP
       operator from ``self.non_normed_op`` and wrap it in
       ``self.init_mod``; otherwise ``self.init_mod`` is ``None``.
    3. Instantiate ``IterativeNet`` from the modules and hyper
       parameters.
    4. For every ``torch.nn.Conv2d`` in the model, zero the bias (if the
       layer has one) and, if ``'init_weight_xavier_normal'`` is set,
       apply Xavier-normal initialization with gain
       ``'init_weight_gain'`` to the weight.
    5. Move the model to ``self.device`` if ``self.use_cuda`` is set.
    """
    hp = self.hyper_params

    self.op_mod = OperatorModule(self.op)
    self.op_adj_mod = OperatorModule(self.op.adjoint)

    # Regularization operator: sum over both axes of D_i^* D_i.
    partial0 = odl.PartialDerivative(self.op.domain, axis=0)
    partial1 = odl.PartialDerivative(self.op.domain, axis=1)
    self.reg_mod = OperatorModule(partial0.adjoint * partial0 +
                                  partial1.adjoint * partial1)

    if hp['init_fbp']:
        fbp = fbp_op(
            self.non_normed_op,
            filter_type=hp['init_filter_type'],
            frequency_scaling=hp['init_frequency_scaling'])
        if self.normalize_by_opnorm:
            # The FBP is built from the non-normalized operator, so it is
            # combined with ``self.opnorm`` here (presumably to match
            # observations produced by the normalized operator -- confirm).
            fbp = OperatorRightScalarMult(fbp, self.opnorm)
        self.init_mod = OperatorModule(fbp)
    else:
        self.init_mod = None

    self.model = IterativeNet(
        n_iter=self.niter,
        n_memory=5,
        op=self.op_mod,
        op_adj=self.op_adj_mod,
        op_init=self.init_mod,
        op_reg=self.reg_mod,
        use_sigmoid=hp['use_sigmoid'],
        n_layer=hp['nlayer'],
        internal_ch=hp['internal_ch'],
        kernel_size=hp['kernel_size'],
        batch_norm=hp['batch_norm'],
        prelu=hp['prelu'],
        lrelu_coeff=hp['lrelu_coeff'])

    def weights_init(m):
        if isinstance(m, torch.nn.Conv2d):
            # ``bias`` is None for layers created with ``bias=False``;
            # skip the fill in that case instead of raising.
            if m.bias is not None:
                m.bias.data.fill_(0.0)
            if self.hyper_params['init_weight_xavier_normal']:
                torch.nn.init.xavier_normal_(
                    m.weight, gain=self.hyper_params['init_weight_gain'])
    self.model.apply(weights_init)

    if self.use_cuda:
        # WARNING: using data-parallel here doesn't work, probably
        # astra_cuda is not thread-safe
        self.model = self.model.to(self.device)
def __init__(self, ray_trafo, callback_func=None,
             callback_func_interval=100, show_pbar=True,
             torch_manual_seed=10, **kwargs):
    """
    Parameters
    ----------
    ray_trafo : `odl.tomo.operators.RayTransform`
        The forward operator
    callback_func : callable, optional
        Callable with signature
        ``callback_func(iteration, reconstruction, loss)`` that is called
        after every `callback_func_interval` iterations, starting
        after the first iteration. It is additionally called after the
        last iteration.
        Note that it differs from the inherited
        `IterativeReconstructor.callback` (which is also supported) in
        that the latter is of type
        :class:`odl.solvers.util.callback.Callback`, which only receives
        the reconstruction, such that the loss would have to be
        recomputed.
    callback_func_interval : int, optional
        Number of iterations between calls to `callback_func`.
        Default: `100`.
    show_pbar : bool, optional
        Whether to show a tqdm progress bar during reconstruction.
    torch_manual_seed : int, optional
        Fixed seed to set by ``torch.manual_seed`` before reconstruction.
        The default is `10`. It can be set to `None` or `False` to disable
        the manual seed.
    **kwargs
        Forwarded to the parent constructor together with
        ``reco_space=ray_trafo.domain`` and
        ``observation_space=ray_trafo.range``.
    """
    super().__init__(reco_space=ray_trafo.domain,
                     observation_space=ray_trafo.range, **kwargs)

    self.callback_func = callback_func
    self.ray_trafo = ray_trafo
    # Wrapped forward operator for use in the torch-based reconstruction.
    self.ray_trafo_module = OperatorModule(self.ray_trafo)
    self.callback_func_interval = callback_func_interval
    self.show_pbar = show_pbar
    self.torch_manual_seed = torch_manual_seed
def __init__(self, ray_trafo, ini_reco, hyper_params=None, callback=None,
             callback_func=None, callback_func_interval=100, **kwargs):
    """
    Store the forward operator, the initial reconstructor and the
    callback configuration.

    Parameters
    ----------
    ray_trafo : `odl.tomo.operators.RayTransform`
        The forward operator
    ini_reco : `dival.Reconstructor`
        Reconstructor used for the initial reconstruction
    hyper_params : optional
        Passed on to the parent constructor as ``hyper_params``.
    callback : optional
        Passed on to the parent constructor as ``callback``.
    callback_func : optional
        Stored as ``self.callback_func``.
    callback_func_interval : int, optional
        Stored as ``self.callback_func_interval``. Default: `100`.
    **kwargs
        Further keyword arguments for the parent constructor.
    """
    image_space = ray_trafo.domain
    data_space = ray_trafo.range
    super().__init__(
        reco_space=image_space,
        observation_space=data_space,
        hyper_params=hyper_params,
        callback=callback,
        **kwargs)

    # Forward operator, both raw and wrapped as a module.
    self.ray_trafo = ray_trafo
    self.ray_trafo_module = OperatorModule(self.ray_trafo)
    self.domain_shape = image_space.shape

    # Reconstructor that provides the starting point.
    self.ini_reco = ini_reco

    # User callback configuration.
    self.callback_func = callback_func
    self.callback_func_interval = callback_func_interval
def __init__(self, ray_trafo, hyper_params=None, callback=None,
             callback_func=None, callback_func_interval=100, **kwargs):
    """
    Store the forward operator, a fixed FBP operator and the callback
    configuration.

    Parameters
    ----------
    ray_trafo : `odl.tomo.operators.RayTransform`
        The forward operator
    hyper_params : optional
        Passed on to the parent constructor as ``hyper_params``.
    callback : optional
        Passed on to the parent constructor as ``callback``.
    callback_func : optional
        Stored as ``self.callback_func``.
    callback_func_interval : int, optional
        Stored as ``self.callback_func_interval``. Default: `100`.
    **kwargs
        Further keyword arguments for the parent constructor.
    """
    super().__init__(
        reco_space=ray_trafo.domain, observation_space=ray_trafo.range,
        hyper_params=hyper_params, callback=callback, **kwargs)

    # FBP operator with fixed settings (Hann filter, frequency scaling 0.1).
    self.fbp_op = fbp_op(
        ray_trafo, frequency_scaling=0.1, filter_type='Hann')

    self.callback_func = callback_func
    self.ray_trafo = ray_trafo
    self.ray_trafo_module = OperatorModule(self.ray_trafo)
    self.domain_shape = ray_trafo.domain.shape
    self.callback_func_interval = callback_func_interval
def __init__(self, ray_trafo, **kwargs):
    """
    Initialize the parent, store the forward operator (raw and wrapped
    as a module) and build the model via ``self.init_model()``.

    Parameters
    ----------
    ray_trafo
        The forward operator; passed to the parent constructor and
        stored as ``self.ray_trafo``.
    **kwargs
        Forwarded unchanged to the parent constructor.
    """
    super().__init__(ray_trafo, **kwargs)

    # Assigned after the parent constructor, so this is the value that
    # ``self.ray_trafo`` holds from here on.
    self.ray_trafo = ray_trafo
    wrapped_op = OperatorModule(self.ray_trafo)
    self.ray_trafo_module = wrapped_op

    # The model is built last, once the attributes above are in place.
    self.init_model()