def init_model(self):
    """Build the FBP operator, the U-Net model and the Noise2Self masker.

    Creates :attr:`fbp_op` from :attr:`op` using the configured filter
    settings, instantiates the U-Net with ``self.scales`` scales, creates
    the interpolation masker, optionally zero-initializes convolution
    biases, and moves the model to the device (wrapped in
    ``DataParallel`` when CUDA is used).
    """
    self.fbp_op = fbp_op(self.op, filter_type=self.filter_type,
                         frequency_scaling=self.frequency_scaling)
    # NOTE: a commented-out alternative U-Net construction (a different
    # UNet implementation with feature_scale/upsample_mode options) was
    # removed here as dead code.
    self.model = UNet(in_ch=1, out_ch=1,
                      channels=self.channels[:self.scales],
                      skip_channels=[self.skip_channels] * (self.scales),
                      use_sigmoid=self.use_sigmoid)
    self.masker = Masker(width=4, mode='interpolate')
    if self.init_bias_zero:
        def weights_init(m):
            if isinstance(m, torch.nn.Conv2d):
                m.bias.data.fill_(0.0)
        self.model.apply(weights_init)
    if self.use_cuda:
        self.model = nn.DataParallel(self.model).to(self.device)
def __init__(self, ray_trafo, scales=None, epochs=None, batch_size=None,
             num_data_loader_workers=8, use_cuda=True, show_pbar=True,
             fbp_impl='astra_cuda', **kwargs):
    """
    Parameters
    ----------
    ray_trafo : :class:`odl.tomo.RayTransform`
        Ray transform from which the FBP operator is constructed.
    scales : int, optional
        Number of scales in the U-Net (a hyper parameter).
    epochs : int, optional
        Number of epochs to train (a hyper parameter).
    batch_size : int, optional
        Batch size (a hyper parameter).
    num_data_loader_workers : int, optional
        Number of parallel workers to use for loading data.
    use_cuda : bool, optional
        Whether to use cuda for the U-Net.
    show_pbar : bool, optional
        Whether to show tqdm progress bars during the epochs.
    fbp_impl : str, optional
        The backend implementation passed to
        :class:`odl.tomo.RayTransform` in case no `ray_trafo` is
        specified. Then ``dataset.get_ray_trafo(impl=fbp_impl)`` is used
        to get the ray transform and FBP operator.
    """
    self.ray_trafo = ray_trafo
    self.fbp_op = fbp_op(self.ray_trafo)
    self.num_data_loader_workers = num_data_loader_workers
    self.use_cuda = use_cuda
    self.show_pbar = show_pbar
    self.fbp_impl = fbp_impl
    super().__init__(reco_space=self.ray_trafo.domain,
                     observation_space=self.ray_trafo.range,
                     **kwargs)
    # Constructor arguments take precedence over values that may already
    # be present in the ``hyper_params`` dict (if any); warn on conflict.
    overrides = (('epochs', epochs),
                 ('batch_size', batch_size),
                 ('scales', scales))
    for param_name, param_value in overrides:
        if param_value is not None:
            setattr(self, param_name, param_value)
            if kwargs.get('hyper_params', {}).get(param_name) is not None:
                warn("hyper parameter '{}' overridden by constructor "
                     "argument".format(param_name))
def __init__(self, dataset, ray_trafo, filter_type='Hann',
             frequency_scaling=1.0):
    """
    Parameters
    ----------
    dataset : :class:`.Dataset`
        CT dataset. FBPs are computed from the observations, the ground
        truth is taken directly from the dataset.
    ray_trafo : :class:`odl.tomo.RayTransform`
        Ray transform from which the FBP operator is constructed.
    filter_type : str, optional
        Filter type accepted by :func:`odl.tomo.fbp_op`.
        Default: ``'Hann'``.
    frequency_scaling : float, optional
        Relative cutoff frequency passed to :func:`odl.tomo.fbp_op`.
        Default: ``1.0``.
    """
    self.dataset = dataset
    self.ray_trafo = ray_trafo
    self.fbp_op = fbp_op(self.ray_trafo, filter_type=filter_type,
                         frequency_scaling=frequency_scaling)
    # cache the lengths of the individual parts of the wrapped dataset
    for part in ('train', 'validation', 'test'):
        setattr(self, part + '_len', self.dataset.get_len(part))
    # each sample is a pair (fbp, ground_truth), both living in the
    # ground-truth space of the wrapped dataset
    gt_shape = self.dataset.shape[1]
    self.shape = (gt_shape, gt_shape)
    self.num_elements_per_sample = 2
    self.random_access = dataset.supports_random_access()
    gt_space = self.dataset.space[1]
    super().__init__(space=(gt_space, gt_space))
def generate_fbp_cache(dataset, part, filename, ray_trafo, size=None):
    """
    Write filtered back-projections for a CT dataset part to file.

    Parameters
    ----------
    dataset : :class:`.Dataset`
        CT dataset from which the observations are used.
    part : {``'train'``, ``'validation'``, ``'test'``}
        The data part.
    filename : str
        The filename to store the FBP cache at (ending ``.npy``).
    ray_trafo : :class:`odl.tomo.RayTransform`
        Ray transform from which the FBP operator is constructed.
    size : int, optional
        Number of samples to use from the dataset.
        By default, all samples are used.
    """
    fbp = fbp_op(ray_trafo)
    n = dataset.get_len(part=part) if size is None else size
    fbps = np.empty((n,) + dataset.shape[1], dtype=np.float32)
    # reusable buffers: one for the observation, one for the FBP result
    obs_buf = np.empty(dataset.shape[0], dtype=np.float32)
    fbp_buf = fbp.range.element()
    for i in tqdm(range(n), desc='generating FBP cache'):
        # only the observation is needed, so the ground truth is skipped
        dataset.get_sample(i, part=part, out=(obs_buf, False))
        fbp(obs_buf, out=fbp_buf)
        fbps[i][:] = fbp_buf
    np.save(filename, fbps)
def __init__(self, ray_trafo, hyper_params=None, iterations=None,
             gamma=None, **kwargs):
    """
    Parameters
    ----------
    ray_trafo : `odl.tomo.operators.RayTransform`
        The forward operator
    hyper_params : dict, optional
        Initial hyper parameter values, passed to the base class.
        Individual entries are overridden by the `iterations` and `gamma`
        arguments when those are given.
    iterations : int, optional
        Number of iterations (a hyper parameter).
    gamma : float, optional
        Regularization parameter (a hyper parameter).
    """
    super().__init__(
        reco_space=ray_trafo.domain,
        observation_space=ray_trafo.range,
        hyper_params=hyper_params, **kwargs)
    self.ray_trafo = ray_trafo
    self.domain_shape = ray_trafo.domain.shape
    self.opnorm = odl.power_method_opnorm(ray_trafo)
    self.fbp_op = fbp_op(
        ray_trafo, frequency_scaling=0.1, filter_type='Hann')
    # BUGFIX: ``hyper_params`` is an explicit parameter, so it can never
    # appear in ``kwargs``; the original checked
    # ``kwargs.get('hyper_params', {})``, which was always empty and thus
    # never warned. Check the parameter itself instead.
    given_hyper_params = hyper_params if hyper_params is not None else {}
    if iterations is not None:
        self.iterations = iterations
        if given_hyper_params.get('iterations') is not None:
            warn("hyper parameter 'iterations' overridden by constructor argument")
    if gamma is not None:
        self.gamma = gamma
        if given_hyper_params.get('gamma') is not None:
            warn("hyper parameter 'gamma' overridden by constructor argument")
def init_model(self):
    """Instantiate :attr:`model` as a ``PrimalDualNet``.

    Optionally builds an FBP operator used as the initializer inside the
    network, then zero-initializes the biases and Xavier-initializes the
    weights of all convolution layers.
    """
    hp = self.hyper_params
    if hp['init_fbp']:
        init_op = fbp_op(
            self.non_normed_ray_trafo,
            filter_type=hp['init_filter_type'],
            frequency_scaling=hp['init_frequency_scaling'])
        if self.normalize_by_opnorm:
            # scale the FBP to match the opnorm-normalized ray transform
            init_op = OperatorRightScalarMult(init_op, self.opnorm)
        self.init_mod = OperatorModule(init_op)
    else:
        self.init_mod = None
    self.model = PrimalDualNet(n_iter=self.niter,
                               op=self.ray_trafo_mod,
                               op_adj=self.ray_trafo_adj_mod,
                               op_init=self.init_mod,
                               n_primal=hp['nprimal'],
                               n_dual=hp['ndual'],
                               use_sigmoid=hp['use_sigmoid'],
                               internal_ch=hp['internal_ch'],
                               kernel_size=hp['kernel_size'],
                               batch_norm=hp['batch_norm'],
                               prelu=hp['prelu'],
                               lrelu_coeff=hp['lrelu_coeff'])

    def _init_conv(layer):
        # zero bias and Xavier-uniform weight for every Conv2d
        if isinstance(layer, torch.nn.Conv2d):
            layer.bias.data.fill_(0.0)
            torch.nn.init.xavier_uniform_(layer.weight)

    self.model.apply(_init_conv)
    if self.use_cuda:
        # WARNING: using data-parallel here doesn't work because of
        # astra-gpu
        self.model = self.model.to(self.device)
def _reconstruct(self, observation, out):
    """Compute the FBP reconstruction of ``observation`` into ``out``.

    Applies :attr:`pre_processor` to the observation (if set), rebuilds
    :attr:`fbp_op` when :attr:`recompute_fbp_op` is ``True``, runs the
    FBP, and finally applies :attr:`post_processor` (if set) in place.
    """
    if self.pre_processor is not None:
        observation = self.pre_processor(observation)
    if self.recompute_fbp_op:
        self.fbp_op = fbp_op(self.ray_trafo, padding=self.padding,
                             **self.hyper_params)
    # BUGFIX/consistency: like the sibling implementation, support ``out``
    # not being an element of the reconstruction space (e.g. a numpy
    # array), which cannot be passed to ``fbp_op`` as an output element.
    if out in self.reco_space:
        self.fbp_op(observation, out=out)
    else:
        out[:] = self.fbp_op(observation)
    if self.post_processor is not None:
        out[:] = self.post_processor(out)
def _reconstruct(self, observation, *args, **kwargs):
    """Reconstruct by Adam descent on a TV-regularized data discrepancy.

    Starts from the FBP of ``observation`` and minimizes
    ``criterion(ray_trafo(x), y_delta) + gamma * tv_loss(x)``,
    tracking the iterate with the lowest loss seen so far.

    Returns
    -------
    reconstruction : element of :attr:`reco_space`
        The reconstruction with the lowest loss encountered.
    """
    self.fbp_op = fbp_op(self.ray_trafo, filter_type=self.init_filter_type,
                         frequency_scaling=self.init_frequency_scaling)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    # optimization variable, initialized with the FBP reconstruction
    self.output = torch.tensor(self.fbp_op(observation))[None].to(device)
    self.output.requires_grad = True
    self.optimizer = Adam([self.output], lr=self.lr)
    y_delta = torch.tensor(np.asarray(observation), dtype=torch.float32)
    y_delta = y_delta.view(1, *y_delta.shape)
    y_delta = y_delta.to(device)
    if self.loss_function == 'mse':
        criterion = MSELoss()
    elif self.loss_function == 'poisson':
        criterion = partial(poisson_loss,
                            photons_per_pixel=self.photons_per_pixel,
                            mu_max=self.mu_max)
    else:
        warn('Unknown loss function, falling back to MSE')
        criterion = MSELoss()
    # BUGFIX: ``np.infty`` was removed in NumPy 2.0; use ``np.inf``
    best_loss = np.inf
    best_output = self.output.detach().clone()
    for i in tqdm(range(self.iterations), desc='TV',
                  disable=not self.show_pbar):
        self.optimizer.zero_grad()
        loss = (criterion(self.ray_trafo_module(self.output), y_delta)
                + self.gamma * tv_loss(self.output))
        loss.backward()
        self.optimizer.step()
        if loss.item() < best_loss:
            best_loss = loss.item()
            best_output = self.output.detach().clone()
        if (self.callback_func is not None and
                (i % self.callback_func_interval == 0
                 or i == self.iterations - 1)):
            self.callback_func(
                iteration=i,
                reconstruction=best_output[0, ...].cpu().numpy(),
                loss=best_loss)
        if self.callback is not None:
            self.callback(
                self.reco_space.element(best_output[0, ...].cpu().numpy()))
    return self.reco_space.element(best_output[0, ...].cpu().numpy())
def generate_dataset_cache(dataset, part, file_names, ray_trafo,
                           filter_type='Hann', frequency_scaling=1.0,
                           size=None, only_fbp=False):
    """
    Write data pairs for a CT dataset part to file.

    Parameters
    ----------
    dataset : :class:`.Dataset`
        CT dataset from which the observations are used.
    part : {``'train'``, ``'validation'``, ``'test'``}
        The data part.
    file_names : sequence of str
        The filenames to store the cache at (ending ``.npy``), in the
        order ``[ground_truths, observations, fbps]``. If `only_fbp` is
        ``True``, only ``file_names[0]`` is used (for the FBPs).
    ray_trafo : :class:`odl.tomo.RayTransform`
        Ray transform from which the FBP operator is constructed.
    filter_type : str, optional
        Filter type accepted by :func:`odl.tomo.fbp_op`.
        Default: ``'Hann'``.
    frequency_scaling : float, optional
        Relative cutoff frequency passed to :func:`odl.tomo.fbp_op`.
        Default: ``1.0``.
    size : int, optional
        Number of samples to use from the dataset.
        By default, all samples are used.
    only_fbp : bool, optional
        Whether to store only the FBPs, skipping ground truths and
        observations. Default: ``False``.
    """
    fbp = fbp_op(ray_trafo, filter_type=filter_type,
                 frequency_scaling=frequency_scaling)
    num_samples = dataset.get_len(part=part) if size is None else size
    if not only_fbp:
        gts_data = np.empty((num_samples, ) + ray_trafo.domain.shape,
                            dtype=np.float32)
        obs_data = np.empty((num_samples, ) + ray_trafo.range.shape,
                            dtype=np.float32)
    fbp_data = np.empty((num_samples, ) + ray_trafo.domain.shape,
                        dtype=np.float32)
    tmp_fbp = fbp.range.element()
    for i, (obs, gt) in zip(tqdm(range(num_samples),
                                 desc='generating cache'),
                            dataset.generator(part)):
        fbp(obs, out=tmp_fbp)
        fbp_data[i, ...] = tmp_fbp
        if not only_fbp:
            obs_data[i, ...] = obs
            gts_data[i, ...] = gt
    if only_fbp:
        np.save(file_names[0], fbp_data)
    else:
        for filename, data in zip(file_names,
                                  [gts_data, obs_data, fbp_data]):
            np.save(filename, data)
def _reconstruct(self, observation, out):
    """Perform the FBP reconstruction, writing the result into ``out``."""
    if self.pre_processor is not None:
        observation = self.pre_processor(observation)
    if self.recompute_fbp_op:
        self.fbp_op = fbp_op(self.ray_trafo, padding=self.padding,
                             **self.hyper_params)
    if out not in self.reco_space:
        # ``out`` is e.g. a numpy array, which cannot be passed to
        # ``fbp_op`` as an output element
        out[:] = self.fbp_op(observation)
    else:
        self.fbp_op(observation, out=out)
    if self.post_processor is not None:
        out[:] = self.post_processor(out)
def init_model(self):
    """Create :attr:`fbp_op` and the post-processing U-Net :attr:`model`."""
    self.fbp_op = fbp_op(self.op, filter_type=self.filter_type,
                         frequency_scaling=self.frequency_scaling)
    n_scales = self.scales
    self.model = UNet(in_ch=1, out_ch=1,
                      channels=self.channels[:n_scales],
                      skip_channels=[self.skip_channels] * n_scales,
                      use_sigmoid=self.use_sigmoid)
    if self.init_bias_zero:
        def _zero_bias(layer):
            # zero out the bias of every convolution layer
            if isinstance(layer, torch.nn.Conv2d):
                layer.bias.data.fill_(0.0)
        self.model.apply(_zero_bias)
    if self.use_cuda:
        self.model = nn.DataParallel(self.model).to(self.device)
def init_model(self):
    """Instantiate :attr:`model` as an ``IterativeNet``.

    Builds torch modules wrapping the forward operator, its adjoint and a
    Laplacian-type regularization operator, optionally an FBP initializer,
    then applies the configured weight initialization to all convolution
    layers.
    """
    hp = self.hyper_params
    self.op_mod = OperatorModule(self.op)
    self.op_adj_mod = OperatorModule(self.op.adjoint)
    partial0 = odl.PartialDerivative(self.op.domain, axis=0)
    partial1 = odl.PartialDerivative(self.op.domain, axis=1)
    self.reg_mod = OperatorModule(partial0.adjoint * partial0
                                  + partial1.adjoint * partial1)
    if hp['init_fbp']:
        init_op = fbp_op(
            self.non_normed_op,
            filter_type=hp['init_filter_type'],
            frequency_scaling=hp['init_frequency_scaling'])
        if self.normalize_by_opnorm:
            # scale the FBP to match the opnorm-normalized operator
            init_op = OperatorRightScalarMult(init_op, self.opnorm)
        self.init_mod = OperatorModule(init_op)
    else:
        self.init_mod = None
    self.model = IterativeNet(n_iter=self.niter,
                              n_memory=5,
                              op=self.op_mod,
                              op_adj=self.op_adj_mod,
                              op_init=self.init_mod,
                              op_reg=self.reg_mod,
                              use_sigmoid=hp['use_sigmoid'],
                              n_layer=hp['nlayer'],
                              internal_ch=hp['internal_ch'],
                              kernel_size=hp['kernel_size'],
                              batch_norm=hp['batch_norm'],
                              prelu=hp['prelu'],
                              lrelu_coeff=hp['lrelu_coeff'])

    def _init_conv(layer):
        # zero bias; optionally Xavier-normal weights with configured gain
        if isinstance(layer, torch.nn.Conv2d):
            layer.bias.data.fill_(0.0)
            if hp['init_weight_xavier_normal']:
                torch.nn.init.xavier_normal_(
                    layer.weight, gain=hp['init_weight_gain'])

    self.model.apply(_init_conv)
    if self.use_cuda:
        # WARNING: using data-parallel here doesn't work, probably
        # astra_cuda is not thread-safe
        self.model = self.model.to(self.device)
def __init__(self, dataset, ray_trafo):
    """
    Parameters
    ----------
    dataset : :class:`.Dataset`
        CT dataset. FBPs are computed from the observations, the ground
        truth is taken directly from the dataset.
    ray_trafo : :class:`odl.tomo.RayTransform`
        Ray transform from which the FBP operator is constructed.
    """
    self.dataset = dataset
    self.fbp_op = fbp_op(ray_trafo)
    # cache the lengths of the individual parts of the wrapped dataset
    for part in ('train', 'validation', 'test'):
        setattr(self, part + '_len', self.dataset.get_len(part))
    # each sample is a pair (fbp, ground_truth), both living in the
    # ground-truth space of the wrapped dataset
    gt_shape = self.dataset.shape[1]
    self.shape = (gt_shape, gt_shape)
    self.num_elements_per_sample = 2
    self.random_access = dataset.supports_random_access()
    gt_space = self.dataset.space[1]
    super().__init__(space=(gt_space, gt_space))
def init_model(self):
    """Create :attr:`fbp_op` and build the U-Net :attr:`model`."""
    self.fbp_op = fbp_op(self.ray_trafo, filter_type=self.filter_type,
                         frequency_scaling=self.frequency_scaling)
    self.model = get_unet_model(scales=self.scales,
                                skip=self.skip_channels,
                                channels=self.channels,
                                use_sigmoid=self.use_sigmoid)
    if self.init_bias_zero:
        def _zero_bias(layer):
            # zero out the bias of every convolution layer
            if isinstance(layer, torch.nn.Conv2d):
                layer.bias.data.fill_(0.0)
        self.model.apply(_zero_bias)
    if self.use_cuda:
        self.model = nn.DataParallel(self.model).to(self.device)
def __init__(self, ray_trafo, padding=True, hyper_params=None,
             pre_processor=None, post_processor=None,
             recompute_fbp_op=True, **kwargs):
    """
    Parameters
    ----------
    ray_trafo : `odl.tomo.operators.RayTransform`
        The forward operator. See `odl.tomo.fbp_op` for details.
    padding : bool, optional
        Whether to use padding (the default is ``True``).
        See `odl.tomo.fbp_op` for details.
    pre_processor : callable, optional
        Callable that takes the observation and returns the sinogram that
        is passed to the filtered back-projection operator.
    post_processor : callable, optional
        Callable that takes the filtered back-projection and returns the
        final reconstruction.
    recompute_fbp_op : bool, optional
        Whether :attr:`fbp_op` should be recomputed on each call to
        :meth:`reconstruct`. Must be ``True`` (default) if changes to
        :attr:`ray_trafo`, :attr:`hyper_params` or :attr:`padding` are
        planned in order to use the updated values in
        :meth:`reconstruct`. If none of these attributes will change, you
        may specify ``recompute_fbp_op=False``, so :attr:`fbp_op` can be
        computed only once, improving reconstruction time efficiency.
    """
    self.ray_trafo = ray_trafo
    self.padding = padding
    self.pre_processor = pre_processor
    self.post_processor = post_processor
    super().__init__(reco_space=ray_trafo.domain,
                     observation_space=ray_trafo.range,
                     hyper_params=hyper_params, **kwargs)
    # build the operator once up front; it is rebuilt on every
    # reconstruction unless ``recompute_fbp_op`` is False
    self.fbp_op = fbp_op(self.ray_trafo, padding=self.padding,
                         **self.hyper_params)
    self.recompute_fbp_op = recompute_fbp_op
def __init__(self, ray_trafo, hyper_params=None, callback=None,
             callback_func=None, callback_func_interval=100, **kwargs):
    """
    Parameters
    ----------
    ray_trafo : `odl.tomo.operators.RayTransform`
        The forward operator
    hyper_params : dict, optional
        Hyper parameter values, passed to the base class.
    callback : callable, optional
        Callback, passed to the base class.
    callback_func : callable, optional
        Callback function stored for use during reconstruction.
    callback_func_interval : int, optional
        Interval at which ``callback_func`` is intended to be called.
        Default: ``100``.
    """
    super().__init__(
        reco_space=ray_trafo.domain,
        observation_space=ray_trafo.range,
        hyper_params=hyper_params, callback=callback, **kwargs)
    self.fbp_op = fbp_op(
        ray_trafo, frequency_scaling=0.1, filter_type='Hann')
    self.ray_trafo = ray_trafo
    self.ray_trafo_module = OperatorModule(self.ray_trafo)
    self.domain_shape = ray_trafo.domain.shape
    # BUGFIX: ``callback_func`` was assigned twice in the original;
    # a single assignment suffices.
    self.callback_func = callback_func
    self.callback_func_interval = callback_func_interval
def __init__(self, ray_trafo, filter_type=None, frequency_scaling=None,
             scales=None, epochs=None, batch_size=None, lr=None,
             skip_channels=None, num_data_loader_workers=8, use_cuda=True,
             show_pbar=True, fbp_impl='astra_cuda', hyper_params=None,
             **kwargs):
    """
    Parameters
    ----------
    ray_trafo : :class:`odl.tomo.RayTransform`
        Ray transform from which the FBP operator is constructed.
    filter_type : str, optional
        Filter type for the FBP operator (a hyper parameter).
    frequency_scaling : float, optional
        Relative cutoff frequency for the FBP operator (a hyper
        parameter).
    scales : int, optional
        Number of scales in the U-Net (a hyper parameter).
    epochs : int, optional
        Number of epochs to train (a hyper parameter).
    batch_size : int, optional
        Batch size (a hyper parameter).
    lr : float, optional
        Learning rate (a hyper parameter), passed to the base class.
    skip_channels : int, optional
        Number of skip channels in the U-Net (a hyper parameter).
    num_data_loader_workers : int, optional
        Number of parallel workers to use for loading data.
    use_cuda : bool, optional
        Whether to use cuda for the U-Net.
    show_pbar : bool, optional
        Whether to show tqdm progress bars during the epochs.
    fbp_impl : str, optional
        The backend implementation passed to
        :class:`odl.tomo.RayTransform` in case no `ray_trafo` is
        specified. Then ``dataset.get_ray_trafo(impl=fbp_impl)`` is used
        to get the ray transform and FBP operator.
    hyper_params : dict, optional
        Initial hyper parameter values, passed to the base class.
        Individual entries are overridden by the corresponding
        constructor arguments when those are given.
    """
    super().__init__(ray_trafo, epochs=epochs, batch_size=batch_size,
                     lr=lr,
                     num_data_loader_workers=num_data_loader_workers,
                     use_cuda=use_cuda, show_pbar=show_pbar,
                     fbp_impl=fbp_impl, hyper_params=hyper_params,
                     **kwargs)
    # BUGFIX: ``hyper_params`` is an explicit parameter, so it can never
    # appear in ``kwargs``; the original checked
    # ``kwargs.get('hyper_params', {})``, which was always empty and thus
    # never warned. Check the parameter itself instead.
    given_hyper_params = hyper_params if hyper_params is not None else {}
    if scales is not None:
        self.scales = scales
        if given_hyper_params.get('scales') is not None:
            warn(
                "hyper parameter 'scales' overridden by constructor "
                "argument")
    if skip_channels is not None:
        self.skip_channels = skip_channels
        if given_hyper_params.get('skip_channels') is not None:
            warn(
                "hyper parameter 'skip_channels' overridden by "
                "constructor argument")
    if filter_type is not None:
        self.filter_type = filter_type
        if given_hyper_params.get('filter_type') is not None:
            warn(
                "hyper parameter 'filter_type' overridden by constructor "
                "argument")
    if frequency_scaling is not None:
        self.frequency_scaling = frequency_scaling
        if given_hyper_params.get('frequency_scaling') is not None:
            warn(
                "hyper parameter 'frequency_scaling' overridden by "
                "constructor argument")
    # TODO: update fbp_op when the hyper parameters change?
    self.fbp_op = fbp_op(ray_trafo, filter_type=self.filter_type,
                         frequency_scaling=self.frequency_scaling)