def __init__(self, data, fourier_op):
    """Initialise the gradient operator around a Fourier operator.

    Parameters
    ----------
    data
        Forwarded unchanged as the data argument of ``GradBasic.__init__``.
    fourier_op
        Object exposing ``op``, ``adj_op`` and ``shape``; stored on the
        instance as ``self.fourier_op``.
    """
    GradBasic.__init__(self, data, fourier_op.op, fourier_op.adj_op)
    self.fourier_op = fourier_op
    # ``np.complex`` was only an alias of the builtin ``complex`` and was
    # removed in NumPy 1.24; ``np.complex128`` is the same dtype and exists
    # in every NumPy release.
    PowerMethod.__init__(
        self,
        self.trans_op_op,
        self.fourier_op.shape,
        data_type=np.complex128,
        auto_run=False,
    )
    # auto_run is disabled above so that the run can be triggered here with
    # an explicit extra factor.
    self.get_spec_rad(extra_factor=1.1)
def __init__(self, data, linear_op, fourier_op):
    """Initialise the gradient operator from a linear and a Fourier operator.

    Parameters
    ----------
    data
        Forwarded unchanged as the data argument of ``GradBasic.__init__``.
    linear_op
        Object exposing ``op``; stored as ``self.linear_op``.
    fourier_op
        Object exposing ``shape``; stored as ``self.fourier_op``.
    """
    GradBasic.__init__(self, data, self._op_method, self._trans_op_method)
    self.fourier_op = fourier_op
    self.linear_op = linear_op
    # Push a zero image through the linear operator only to discover the
    # shape of its output, which is the shape the power method works on.
    # ``np.complex`` was removed in NumPy 1.24; ``np.complex128`` is the
    # dtype that alias resolved to.
    coef = linear_op.op(np.zeros(fourier_op.shape, dtype=np.complex128))
    PowerMethod.__init__(
        self,
        self.trans_op_op,
        coef.shape,
        data_type=np.complex128,
        auto_run=False,
    )
    # auto_run is disabled above so that the run can be triggered here with
    # an explicit extra factor.
    self.get_spec_rad(extra_factor=1.1)
def __init__(self, data, fourier_op, S):
    """Initialise the gradient operator.

    Parameters
    ----------
    data
        Forwarded unchanged as the data argument of ``GradBasic.__init__``.
    fourier_op
        Object exposing ``shape``; stored as ``self.fourier_op``.
    S
        Stored as ``self.S`` (NOTE(review): meaning not visible here --
        confirm against ``_analy_op_method``).
    """
    # Both attributes must exist before the parent initialisers run.
    self.fourier_op, self.S = fourier_op, S

    forward = self._analy_op_method
    adjoint = self._analy_rsns_op_method
    GradBasic.__init__(self, data, forward, adjoint)

    power_options = {"data_type": "complex128", "auto_run": False}
    PowerMethod.__init__(
        self,
        self.trans_op_op,
        self.fourier_op.shape,
        **power_options
    )
    # auto_run is off, so trigger the run explicitly with the extra factor.
    self.get_spec_rad(extra_factor=1.1)
def set_objective(self, X, y, lmbd):
    """Store the problem data and build the ``ForwardBackward`` solver.

    Parameters
    ----------
    X
        Design matrix; ``X.shape[1]`` gives the number of coefficients.
    y
        Observation vector handed to ``GradBasic`` as ``input_data``.
    lmbd
        Threshold weight handed to ``SparseThreshold``.
    """
    self.X, self.y, self.lmbd = X, y, lmbd
    n_coef = self.X.shape[1]

    # Restart-dependent settings, keyed by the ForwardBackward argument name.
    if self.restart_strategy == 'greedy':
        restart_options = {
            'min_beta': 1.0,
            's_greedy': 1.1,
            'p_lazy': 1.0,
            'q_lazy': 1.0,
        }
    else:
        restart_options = {
            'min_beta': None,
            's_greedy': None,
            'p_lazy': 1 / 30,
            'q_lazy': 1 / 10,
        }

    def forward(w):
        # Reads self.X at call time, exactly like the original lambda.
        return self.X @ w

    def adjoint(res):
        return self.X.T @ res

    self.fb = ForwardBackward(
        x=np.zeros(n_coef),  # this is the coefficient w
        grad=GradBasic(input_data=y, op=forward, trans_op=adjoint),
        prox=SparseThreshold(Identity(), lmbd),
        beta_param=1.0,
        metric_call_period=None,
        restart_strategy=self.restart_strategy,
        xi_restart=0.96,
        auto_iterate=False,
        progress=False,
        cost=None,
        **restart_options
    )
def __init__(self, data, fourier_op, linear_op, S):
    """Initialise the gradient operator.

    Parameters
    ----------
    data
        Forwarded unchanged as the data argument of ``GradBasic.__init__``.
    fourier_op
        Object exposing ``shape``; stored as ``self.fourier_op``.
    linear_op
        Object exposing ``op``; stored as ``self.linear_op``.
    S
        Stored as ``self.S`` (NOTE(review): meaning not visible here --
        confirm against ``_synth_op_method``).
    """
    self.fourier_op = fourier_op
    self.linear_op = linear_op
    self.S = S
    GradBasic.__init__(self, data, self._synth_op_method,
                       self._synth_trans_op_method)
    # Push a zero image through the linear operator only to discover the
    # shape of its output. ``np.complex`` was removed in NumPy 1.24;
    # ``np.complex128`` is the dtype that alias resolved to.
    coef = linear_op.op(np.zeros(fourier_op.shape, dtype=np.complex128))
    self.linear_op_coeffs_shape = coef.shape
    PowerMethod.__init__(
        self,
        self.trans_op_op,
        coef.shape,
        data_type="complex128",
        auto_run=False,
    )
    # auto_run is disabled above so that the run can be triggered here with
    # an explicit extra factor.
    self.get_spec_rad(extra_factor=1.1)
def set_objective(self, X, y, lmbd):
    """Store the problem data and build the ``POGM`` solver.

    Parameters
    ----------
    X
        Design matrix; ``X.shape[1]`` gives the number of coefficients.
    y
        Observation vector handed to ``GradBasic`` as its data argument.
    lmbd
        Threshold weight handed to ``SparseThreshold``.
    """
    self.X, self.y, self.lmbd = X, y, lmbd
    n_features = self.X.shape[1]
    sigma_bar = 0.96
    var_init = np.zeros(n_features)
    self.pogm = POGM(
        x=var_init,  # this is the coefficient w
        u=var_init,
        y=var_init,
        z=var_init,
        grad=GradBasic(
            # Passed positionally: the keyword for the observation differs
            # between ModOpt releases (``data`` vs ``input_data``), whereas
            # the positional order (data, op, trans_op) is the one already
            # relied on by the other GradBasic calls in this file.
            y,
            lambda w: self.X @ w,
            lambda res: self.X.T @ res,
        ),
        prox=SparseThreshold(Identity(), lmbd),
        beta_param=1.0,
        metric_call_period=None,
        sigma_bar=sigma_bar,
        auto_iterate=False,
        progress=False,
        cost=None,
    )
def sparse_deconv_condatvu(data, psf, n_iter=300, n_reweights=1):
    """Run sparse deconvolution with the Condat-Vu algorithm.

    Parameters
    ----------
    data : np.ndarray
        Input data, 2D image
    psf : np.ndarray
        Input PSF, 2D image
    n_iter : int, optional
        Maximum number of iterations for each call to ``iterate``
    n_reweights : int, optional
        Number of reweighting passes run after the first solve

    Returns
    -------
    np.ndarray
        Deconvolved image (``x_final`` of the solver)

    """
    # Announce the algorithm set-up.
    print(condatvu_logo())

    # Wavelet filters and the reweighting scheme built from them.
    wavelet_filters = get_cospy_filters(
        data.shape,
        transform_name='LinearWaveletTransformATrousAlgorithm',
    )
    reweighter = cwbReweight(get_weights(data, psf, wavelet_filters))

    # Starting values of the primal and dual variables.
    primal_init = np.ones(data.shape)
    dual_init = np.ones(wavelet_filters.shape)

    # Gradient operator: convolution with the PSF and its rotated variant.
    def convolve(x):
        return psf_convolve(x, psf)

    def convolve_rot(x):
        return psf_convolve(x, psf, psf_rot=True)

    gradient = GradBasic(data, convolve, convolve_rot)

    # Linear operator and the two proximity operators.
    wavelet_op = WaveletConvolve2(wavelet_filters)
    primal_prox = Positivity()
    dual_prox = SparseThreshold(wavelet_op, reweighter.weights)

    # Cost function over the gradient and both proximity operators.
    cost = costObj(
        [gradient, primal_prox, dual_prox],
        tolerance=1e-6,
        cost_interval=1,
        plot_output=True,
        verbose=False,
    )

    # Build the solver without iterating, then run the first solve.
    solver = Condat(
        primal_init,
        dual_init,
        gradient,
        primal_prox,
        dual_prox,
        wavelet_op,
        cost,
        rho=0.8,
        sigma=0.5,
        tau=0.5,
        auto_iterate=False,
    )
    solver.iterate(max_iter=n_iter)

    # Reweighting passes: update the weights from the current solution and
    # iterate again.
    for pass_number in range(1, n_reweights + 1):
        print(' - Reweighting: {}'.format(pass_number))
        reweighter.reweight(wavelet_op.op(solver.x_final))
        solver.iterate(max_iter=n_iter)

    return solver.x_final