def setUp(self):
    """Prepare the arrays and reweighting instance used by the tests.

    ``data1`` is the 3x3 float array holding 1..9, ``data2`` a fixed 3x3
    float reference array, and ``rw`` a ``cwbReweight`` built from
    ``data1`` that has already been reweighted once with ``data1``.
    """
    base = np.arange(9).reshape(3, 3).astype(float)
    self.data1 = base + 1

    reference_rows = [
        [0.5, 1.0, 1.5],
        [2.0, 2.5, 3.0],
        [3.5, 4.0, 4.5],
    ]
    self.data2 = np.array(reference_rows)

    self.rw = reweight.cwbReweight(self.data1)
    self.rw.reweight(self.data1)
def set_sparse_weights(data_shape, psf, **kwargs):
    """Set the sparsity weights.

    This method defines the weights for thresholding in the sparse domain
    and adds them to the keyword arguments. It additionally defines the
    shape of the dual variable.

    Parameters
    ----------
    data_shape : tuple
        Shape of the input data array; ``data_shape[0]`` is the number of
        images and ``data_shape[1:]`` the shape of a single image
    psf : np.ndarray
        PSF data: a single 2D PSF when ``kwargs['psf_type'] == 'fixed'``,
        otherwise a 3D stack with one PSF per image (assumed layout
        ``(n_images, ny, nx)`` -- TODO confirm against callers)
    **kwargs
        Must contain ``'psf_type'``, ``'wavelet_filters'``,
        ``'convolve_method'``, ``'wave_thresh_factor'`` and ``'noise_est'``

    Returns
    -------
    dict
        Updated keyword arguments, with ``'reweight'`` (a ``cwbReweight``
        instance) and ``'dual_shape'`` (a list) added
    """
    wavelet_filters = kwargs['wavelet_filters']
    thresh_factor = kwargs['wave_thresh_factor']
    image_shape = data_shape[1:]

    # Convolve the PSF, rotated by 180 degrees, with the wavelet filters
    if kwargs['psf_type'] == 'fixed':
        filter_conv = filter_convolve(np.rot90(psf, 2), wavelet_filters,
                                      method=kwargs['convolve_method'])

        filter_norm = np.array([norm(a) * b * np.ones(image_shape)
                                for a, b in zip(filter_conv, thresh_factor)])

        # A single PSF: the same weights are used for every image
        filter_norm = np.array([filter_norm for _ in range(data_shape[0])])

    else:
        # Rotate every PSF of the stack in its own image plane. A plain
        # np.rot90(psf, 2) acts on axes (0, 1): it would reverse the order
        # of the stack (pairing each image with another image's PSF) and
        # flip only the y axis.
        filter_conv = filter_convolve_stack(np.rot90(psf, 2, axes=(1, 2)),
                                            wavelet_filters,
                                            method=kwargs['convolve_method'])

        filter_norm = np.array([[norm(b) * c * np.ones(image_shape)
                                 for b, c in zip(a, thresh_factor)]
                                for a in filter_conv])

    # Define a reweighting instance
    kwargs['reweight'] = cwbReweight(kwargs['noise_est'] * filter_norm)

    # Dual variable shape: [n_images, n_filters, *image_shape]
    kwargs['dual_shape'] = ([data_shape[0], wavelet_filters.shape[0]] +
                            list(data_shape[1:]))

    return kwargs
def set_sparse_weights(data_shape, psf, **kwargs):
    """Set the sparsity weights.

    This method defines the weights for thresholding in the sparse domain
    and adds them to the keyword arguments. It additionally defines the
    shape of the dual variable.

    Parameters
    ----------
    data_shape : tuple
        Shape of the input data array; ``data_shape[0]`` is the number of
        images and ``data_shape[1:]`` the shape of a single image
    psf : np.ndarray
        PSF data: a single 2D PSF when ``kwargs['psf_type'] == 'fixed'``,
        otherwise a 3D stack with one PSF per image (assumed layout
        ``(n_images, ny, nx)`` -- TODO confirm against callers)
    **kwargs
        Must contain ``'psf_type'``, ``'wavelet_filters'``,
        ``'convolve_method'``, ``'wave_thresh_factor'`` and ``'noise_est'``

    Returns
    -------
    dict
        Updated keyword arguments, with ``'reweight'`` (a ``cwbReweight``
        instance) and ``'dual_shape'`` (a list) added
    """
    wavelet_filters = kwargs['wavelet_filters']
    thresh_factor = kwargs['wave_thresh_factor']
    image_shape = data_shape[1:]

    # Convolve the PSF, rotated by 180 degrees, with the wavelet filters
    if kwargs['psf_type'] == 'fixed':
        filter_conv = filter_convolve(np.rot90(psf, 2), wavelet_filters,
                                      method=kwargs['convolve_method'])

        filter_norm = np.array([norm(a) * b * np.ones(image_shape)
                                for a, b in zip(filter_conv, thresh_factor)])

        # A single PSF: the same weights are used for every image
        filter_norm = np.array([filter_norm for _ in range(data_shape[0])])

    else:
        # Rotate every PSF of the stack in its own image plane. A plain
        # np.rot90(psf, 2) acts on axes (0, 1): it would reverse the order
        # of the stack (pairing each image with another image's PSF) and
        # flip only the y axis.
        filter_conv = filter_convolve_stack(np.rot90(psf, 2, axes=(1, 2)),
                                            wavelet_filters,
                                            method=kwargs['convolve_method'])

        filter_norm = np.array([[norm(b) * c * np.ones(image_shape)
                                 for b, c in zip(a, thresh_factor)]
                                for a in filter_conv])

    # Define a reweighting instance
    kwargs['reweight'] = cwbReweight(kwargs['noise_est'] * filter_norm)

    # Dual variable shape: [n_images, n_filters, *image_shape]
    kwargs['dual_shape'] = ([data_shape[0], wavelet_filters.shape[0]] +
                            list(data_shape[1:]))

    return kwargs
def condatvu(gradient_op, linear_op, dual_regularizer, cost_op,
             max_nb_of_iter=150, tau=None, sigma=None, relaxation_factor=1.0,
             x_init=None, std_est=None, std_est_method=None, std_thr=2.,
             nb_of_reweights=1, metric_call_period=5, metrics=None,
             verbose=0):
    """ The Condat-Vu sparse reconstruction with reweightings.

    Parameters
    ----------
    gradient_op: instance of class GradBase
        the gradient operator.
    linear_op: instance of LinearBase
        the linear operator: seek the sparsity, ie. a wavelet transform.
    dual_regularizer: instance of ProximityParent
        the dual regularization operator; its ``weights`` attribute is set
        when a std estimation method is used.
    cost_op: instance of costObj
        the cost function used to check for convergence during the
        optimization.
    max_nb_of_iter: int, default 150
        the maximum number of iterations in the Condat-Vu proximal-dual
        splitting algorithm (used for the first run and for each
        reweighting run).
    tau, sigma: float, default None
        parameters of the Condat-Vu proximal-dual splitting algorithm.
        If sigma is None it is set to 0.5; if tau is None it is derived
        from the convergence bound 1/tau - sigma||L||^2 >= beta/2.
    relaxation_factor: float, default 1.0
        parameter of the Condat-Vu proximal-dual splitting algorithm.
        If 1, no relaxation.
    x_init: np.ndarray (optional, default None)
        the initial guess of image. If None, a complex zero array of shape
        (n_coils, *fourier_op.shape), squeezed, is used.
    std_est: float, default None
        the noise std estimate. If None, it is estimated with the MAD of
        the adjoint of the observed data ('primal') or starts at 0 and is
        updated by the reweighting ('dual').
    std_est_method: str, default None
        if the standard deviation is not set, estimate this parameter using
        the mad routine in the image ('primal') or in the sparse wavelet
        decomposition ('dual') domain. If None, no reweighting is done.
    std_thr: float, default 2.
        use this treshold expressed as a number of sigma in the residual
        proximity operator during the thresholding.
    nb_of_reweights: int, default 1
        the number of reweightings (forced to 0 when std_est_method is
        None).
    metric_call_period: int (default 5)
        the period on which the metrics are computed.
    metrics: dict (optional, default None)
        the list of desired convergence metrics:
        {'metric_name': [@metric, metric_parameter]}.
        See modopt for the metrics API. None is treated as an empty dict.
    verbose: int, default 0
        the verbosity level.

    Returns
    -------
    x_final: ndarray
        the estimated CONDAT-VU solution.
    costs: list of float or None
        the cost function values (None if the cost object has no cost).
    metrics: dict
        the requested metrics values during the optimization.
    y_final: ndarray
        the estimated dual CONDAT-VU solution

    Raises
    ------
    ValueError
        if std_est_method is not None, 'primal' or 'dual'.
    """
    # Check inputs
    # (time.clock was removed in Python 3.8; perf_counter is its replacement
    # for measuring elapsed time)
    start = time.perf_counter()
    if std_est_method not in (None, "primal", "dual"):
        raise ValueError(
            "Unrecognize std estimation method '{0}'.".format(std_est_method))
    # Avoid sharing one mutable default dict between calls
    if metrics is None:
        metrics = {}

    # Define the initial primal and dual solutions
    # (np.complex was an alias of complex, removed in NumPy 1.24)
    if x_init is None:
        x_init = np.squeeze(
            np.zeros((linear_op.n_coils, *gradient_op.fourier_op.shape),
                     dtype=np.complex128))
    primal = x_init
    dual = linear_op.op(primal)
    # The weights need their own buffer: filling them in place must not
    # overwrite the initial dual solution handed to the optimizer.
    weights = np.copy(dual)

    # Define the weights used during the thresholding in the dual domain,
    # the reweighting strategy, and the prox dual operator

    # Case1: estimate the noise std in the image domain
    if std_est_method == "primal":
        if std_est is None:
            std_est = sigma_mad(gradient_op.MtX(gradient_op.obs_data))
        weights[...] = std_thr * std_est
        reweight_op = cwbReweight(weights)
        dual_regularizer.weights = reweight_op.weights

    # Case2: estimate the noise std in the sparse wavelet domain
    elif std_est_method == "dual":
        if std_est is None:
            std_est = 0.0
        weights[...] = std_thr * std_est
        reweight_op = mReweight(weights, linear_op, thresh_factor=std_thr)
        dual_regularizer.weights = reweight_op.weights

    # Case3: manual regularization mode, no reweighting
    else:
        reweight_op = None
        nb_of_reweights = 0

    # Define the Condat Vu optimizer: define the tau and sigma in the
    # Condat-Vu proximal-dual splitting algorithm if not already provided.
    # Check also that the combination of values will lead to convergence.
    norm = linear_op.l2norm(x_init.shape)
    lipschitz_cst = gradient_op.spec_rad
    if sigma is None:
        sigma = 0.5
    if tau is None:
        # to avoid numerics troubles with the convergence bound
        eps = 1.0e-8
        # due to the convergence bound
        tau = 1.0 / (lipschitz_cst / 2 + sigma * norm**2 + eps)
    convergence_test = (
        1.0 / tau - sigma * norm**2 >= lipschitz_cst / 2.0)

    # Welcome message
    if verbose > 0:
        print(" - mu: ", dual_regularizer.weights)
        print(" - lipschitz constant: ", gradient_op.spec_rad)
        print(" - tau: ", tau)
        print(" - sigma: ", sigma)
        print(" - rho: ", relaxation_factor)
        print(" - std: ", std_est)
        print(" - 1/tau - sigma||L||^2 >= beta/2: ", convergence_test)
        print(" - data: ", gradient_op.fourier_op.shape)
        if hasattr(linear_op, "nb_scale"):
            print(" - wavelet: ", linear_op, "-", linear_op.nb_scale)
        print(" - max iterations: ", max_nb_of_iter)
        print(" - number of reweights: ", nb_of_reweights)
        print(" - primal variable shape: ", primal.shape)
        print(" - dual variable shape: ", dual.shape)
        print("-" * 40)

    prox_op = Identity()

    # Define the optimizer
    opt = Condat(
        x=primal,
        y=dual,
        grad=gradient_op,
        prox=prox_op,
        prox_dual=dual_regularizer,
        linear=linear_op,
        cost=cost_op,
        rho=relaxation_factor,
        sigma=sigma,
        tau=tau,
        rho_update=None,
        sigma_update=None,
        tau_update=None,
        auto_iterate=False,
        metric_call_period=metric_call_period,
        metrics=metrics)
    # Use the cost object actually held by the optimizer from here on
    cost_op = opt._cost_func

    # Perform the first reconstruction
    if verbose > 0:
        print("Starting optimization...")
    opt.iterate(max_iter=max_nb_of_iter)

    # Loop through the number of reweightings
    for reweight_index in range(nb_of_reweights):

        # Generate the new weights following reweighting prescription
        if std_est_method == "primal":
            reweight_op.reweight(linear_op.op(opt._x_new))
        else:
            std_est = reweight_op.reweight(opt._x_new)

        # Welcome message
        if verbose > 0:
            print(" - reweight: ", reweight_index + 1)
            print(" - std: ", std_est)

        # Update the weights in the dual proximity operator
        dual_regularizer.weights = reweight_op.weights

        # Perform optimisation with new weights
        opt.iterate(max_iter=max_nb_of_iter)

    # Goodbye message
    end = time.perf_counter()
    if verbose > 0:
        if hasattr(cost_op, "cost"):
            print(" - final iteration number: ", cost_op._iteration)
            print(" - final cost value: ", cost_op.cost)
        print(" - converged: ", opt.converge)
        print("Done.")
        print("Execution time: ", end - start, " seconds")
        print("-" * 40)

    # Get the final solution
    x_final = opt.x_final
    y_final = opt.y_final
    if hasattr(cost_op, "cost"):
        costs = cost_op._cost_list
    else:
        costs = None

    return x_final, costs, opt.metrics, y_final
def setUp(self):
    """Set test parameter values.

    Builds three 3x3 float arrays, a set of operator instances (gradient,
    proximity, linear, reweighting, cost) that is shared by all the
    algorithms, and one instance of each optimisation algorithm under test.
    The ``*_all_iter`` instances are created with ``auto_iterate=False``
    and then explicitly iterated ``self.max_iter`` times here.
    """
    self.data1 = np.arange(9).reshape(3, 3).astype(float)
    # data1 plus a tiny standard-normal perturbation (no seed is set here,
    # so the exact values differ between runs)
    self.data2 = self.data1 + np.random.randn(*self.data1.shape) * 1e-6
    self.data3 = np.arange(9).reshape(3, 3).astype(float) + 1

    # Operator instances reused by every algorithm below. Because the same
    # objects are shared and some algorithms iterate here, keep the
    # construction order unchanged.
    grad_inst = gradient.GradBasic(
        self.data1,
        func_identity,
        func_identity,
    )
    prox_inst = proximity.Positivity()
    prox_dual_inst = proximity.IdentityProx()
    linear_inst = linear.Identity()
    reweight_inst = reweight.cwbReweight(self.data3)
    cost_inst = cost.costObj([grad_inst, prox_inst, prox_dual_inst])
    self.setup = algorithms.SetUp()
    self.max_iter = 20

    # Forward-backward variants
    self.fb_all_iter = algorithms.ForwardBackward(
        self.data1,
        grad=grad_inst,
        prox=prox_inst,
        cost=None,
        auto_iterate=False,
        beta_update=func_identity,
    )
    self.fb_all_iter.iterate(self.max_iter)

    self.fb1 = algorithms.ForwardBackward(
        self.data1,
        grad=grad_inst,
        prox=prox_inst,
        beta_update=func_identity,
    )

    self.fb2 = algorithms.ForwardBackward(
        self.data1,
        grad=grad_inst,
        prox=prox_inst,
        cost=cost_inst,
        lambda_update=None,
    )

    self.fb3 = algorithms.ForwardBackward(
        self.data1,
        grad=grad_inst,
        prox=prox_inst,
        beta_update=func_identity,
        a_cd=3,
    )

    self.fb4 = algorithms.ForwardBackward(
        self.data1,
        grad=grad_inst,
        prox=prox_inst,
        beta_update=func_identity,
        r_lazy=3,
        p_lazy=0.7,
        q_lazy=0.7,
    )

    # Restart strategies
    self.fb5 = algorithms.ForwardBackward(
        self.data1,
        grad=grad_inst,
        prox=prox_inst,
        restart_strategy='adaptive',
        xi_restart=0.9,
    )

    self.fb6 = algorithms.ForwardBackward(
        self.data1,
        grad=grad_inst,
        prox=prox_inst,
        restart_strategy='greedy',
        xi_restart=0.9,
        min_beta=1.0,
        s_greedy=1.1,
    )

    # Generalised forward-backward variants
    self.gfb_all_iter = algorithms.GenForwardBackward(
        self.data1,
        grad=grad_inst,
        prox_list=[prox_inst, prox_dual_inst],
        cost=None,
        auto_iterate=False,
        gamma_update=func_identity,
        beta_update=func_identity,
    )
    self.gfb_all_iter.iterate(self.max_iter)

    self.gfb1 = algorithms.GenForwardBackward(
        self.data1,
        grad=grad_inst,
        prox_list=[prox_inst, prox_dual_inst],
        gamma_update=func_identity,
        lambda_update=func_identity,
    )

    self.gfb2 = algorithms.GenForwardBackward(
        self.data1,
        grad=grad_inst,
        prox_list=[prox_inst, prox_dual_inst],
        cost=cost_inst,
    )

    self.gfb3 = algorithms.GenForwardBackward(
        self.data1,
        grad=grad_inst,
        prox_list=[prox_inst, prox_dual_inst],
        cost=cost_inst,
        step_size=2,
    )

    # Condat-Vu variants (primal data1, dual data2)
    self.condat_all_iter = algorithms.Condat(
        self.data1,
        self.data2,
        grad=grad_inst,
        prox=prox_inst,
        cost=None,
        prox_dual=prox_dual_inst,
        sigma_update=func_identity,
        tau_update=func_identity,
        rho_update=func_identity,
        auto_iterate=False,
    )
    self.condat_all_iter.iterate(self.max_iter)

    self.condat1 = algorithms.Condat(
        self.data1,
        self.data2,
        grad=grad_inst,
        prox=prox_inst,
        prox_dual=prox_dual_inst,
        sigma_update=func_identity,
        tau_update=func_identity,
        rho_update=func_identity,
    )

    self.condat2 = algorithms.Condat(
        self.data1,
        self.data2,
        grad=grad_inst,
        prox=prox_inst,
        prox_dual=prox_dual_inst,
        linear=linear_inst,
        cost=cost_inst,
        reweight=reweight_inst,
    )

    self.condat3 = algorithms.Condat(
        self.data1,
        self.data2,
        grad=grad_inst,
        prox=prox_inst,
        prox_dual=prox_dual_inst,
        linear=Dummy(),
        cost=cost_inst,
        auto_iterate=False,
    )

    # POGM variants
    self.pogm_all_iter = algorithms.POGM(
        u=self.data1,
        x=self.data1,
        y=self.data1,
        z=self.data1,
        grad=grad_inst,
        prox=prox_inst,
        auto_iterate=False,
        cost=None,
    )
    self.pogm_all_iter.iterate(self.max_iter)

    self.pogm1 = algorithms.POGM(
        u=self.data1,
        x=self.data1,
        y=self.data1,
        z=self.data1,
        grad=grad_inst,
        prox=prox_inst,
    )

    # Generic gradient-descent optimisers
    self.vanilla_grad = algorithms.VanillaGenericGradOpt(
        self.data1,
        grad=grad_inst,
        prox=prox_inst,
        cost=cost_inst,
    )
    self.ada_grad = algorithms.AdaGenericGradOpt(
        self.data1,
        grad=grad_inst,
        prox=prox_inst,
        cost=cost_inst,
    )
    self.adam_grad = algorithms.ADAMGradOpt(
        self.data1,
        grad=grad_inst,
        prox=prox_inst,
        cost=cost_inst,
    )
    self.momentum_grad = algorithms.MomentumGradOpt(
        self.data1,
        grad=grad_inst,
        prox=prox_inst,
        cost=cost_inst,
    )
    self.rms_grad = algorithms.RMSpropGradOpt(
        self.data1,
        grad=grad_inst,
        prox=prox_inst,
        cost=cost_inst,
    )
    self.saga_grad = algorithms.SAGAOptGradOpt(
        self.data1,
        grad=grad_inst,
        prox=prox_inst,
        cost=cost_inst,
    )

    # Dummy object whose ``cost`` is passed to SetUp._check_operator
    self.dummy = Dummy()
    self.dummy.cost = func_identity
    self.setup._check_operator(self.dummy.cost)
def sparse_rec_condatvu(gradient_op, linear_op, std_est=None,
                        std_est_method=None, std_thr=2., mu=1e-6, tau=None,
                        sigma=None, relaxation_factor=1.0, nb_of_reweights=1,
                        max_nb_of_iter=150, add_positivity=False, atol=1e-4,
                        verbose=0):
    """ The Condat-Vu sparse reconstruction with reweightings.

    .. note:: At the moment, supports only 2D data.

    Parameters
    ----------
    gradient_op: instance of class GradBase
        the gradient operator; its ``fourier_op.shape`` defines the image
        shape and ``obs_data`` holds the observations (Fourier space).
    linear_op: instance of LinearBase
        the linear operator: seek the sparsity, ie. a wavelet transform.
        Its ``transform`` attribute receives the final coefficients.
    std_est: float, default None
        the noise std estimate. If None, it is estimated with the MAD of
        the adjoint of the observed data ('primal') or starts at 0 and is
        updated by the reweighting ('dual').
    std_est_method: str, default None
        if the standard deviation is not set, estimate this parameter using
        the mad routine in the image ('primal') or in the sparse wavelet
        decomposition ('dual') domain. If None, the constant threshold
        ``mu`` is used and no reweighting is done.
    std_thr: float, default 2.
        use this treshold expressed as a number of sigma in the residual
        proximity operator during the thresholding.
    mu: float, default 1e-6
        regularization hyperparameter (used when std_est_method is None).
    tau, sigma: float, default None
        parameters of the Condat-Vu proximal-dual splitting algorithm.
        If sigma is None it is set to 0.5; if tau is None it is derived
        from the convergence bound 1/tau - sigma||L||^2 >= beta/2.
    relaxation_factor: float, default 1.0
        parameter of the Condat-Vu proximal-dual splitting algorithm.
        If 1, no relaxation.
    nb_of_reweights: int, default 1
        the number of reweightings (forced to 0 when std_est_method is
        None).
    max_nb_of_iter: int, default 150
        the maximum number of iterations in the Condat-Vu proximal-dual
        splitting algorithm.
    add_positivity: bool, default False
        by setting this option, set the proximity operator to identity or
        positive.
    atol: float, default 1e-4
        tolerance threshold for convergence, passed to the cost function.
    verbose: int, default 0
        the verbosity level.

    Returns
    -------
    x_final: ndarray
        the estimated CONDAT-VU solution.
    transform: a WaveletTransformBase derived instance
        the wavelet transformation instance, with ``analysis_data`` set to
        the final dual solution.

    Raises
    ------
    ValueError
        if std_est_method is not None, 'primal' or 'dual'.
    """
    # Check inputs
    # (time.clock was removed in Python 3.8; perf_counter replaces it)
    start = time.perf_counter()
    if std_est_method not in (None, "primal", "dual"):
        raise ValueError(
            "Unrecognize std estimation method '{0}'.".format(std_est_method))

    # Allocate the thresholding weights with the shape of the analysis
    # coefficients (np.complex was removed in NumPy 1.24)
    x_init = np.zeros(gradient_op.fourier_op.shape, dtype=np.complex128)
    weights = linear_op.op(x_init)

    # Define the weights used during the thresholding in the dual domain,
    # the reweighting strategy, and the prox dual operator

    # Case1: estimate the noise std in the image domain
    if std_est_method == "primal":
        if std_est is None:
            std_est = sigma_mad(gradient_op.MtX(gradient_op.obs_data))
        weights[...] = std_thr * std_est
        reweight_op = cwbReweight(weights)
        prox_dual_op = Threshold(reweight_op.weights)

    # Case2: estimate the noise std in the sparse wavelet domain
    elif std_est_method == "dual":
        if std_est is None:
            std_est = 0.0
        weights[...] = std_thr * std_est
        reweight_op = mReweight(weights, linear_op, thresh_factor=std_thr)
        prox_dual_op = Threshold(reweight_op.weights)

    # Case3: manual regularization mode, no reweighting
    else:
        weights[...] = mu
        reweight_op = None
        prox_dual_op = Threshold(weights)
        nb_of_reweights = 0

    # Define the Condat Vu optimizer: define the tau and sigma in the
    # Condat-Vu proximal-dual splitting algorithm if not already provided.
    # Check also that the combination of values will lead to convergence.
    norm = linear_op.l2norm(gradient_op.fourier_op.shape)
    lipschitz_cst = gradient_op.spec_rad
    if sigma is None:
        sigma = 0.5
    if tau is None:
        # to avoid numerics troubles with the convergence bound
        eps = 1.0e-8
        # due to the convergence bound
        tau = 1.0 / (lipschitz_cst / 2 + sigma * norm**2 + eps)
    convergence_test = (
        1.0 / tau - sigma * norm**2 >= lipschitz_cst / 2.0)

    # Define initial primal and dual solutions (both zero)
    primal = np.zeros(gradient_op.fourier_op.shape, dtype=np.complex128)
    dual = linear_op.op(primal)
    dual[...] = 0.0

    # Welcome message
    if verbose > 0:
        print(condatvu_logo())
        print(" - mu: ", mu)
        print(" - lipschitz constant: ", gradient_op.spec_rad)
        print(" - tau: ", tau)
        print(" - sigma: ", sigma)
        print(" - rho: ", relaxation_factor)
        print(" - std: ", std_est)
        print(" - 1/tau - sigma||L||^2 >= beta/2: ", convergence_test)
        print(" - data: ", gradient_op.obs_data.shape)
        print(" - max iterations: ", max_nb_of_iter)
        print(" - number of reweights: ", nb_of_reweights)
        print(" - primal variable shape: ", primal.shape)
        print(" - dual variable shape: ", dual.shape)
        print("-" * 40)

    # Define the proximity operator
    if add_positivity:
        prox_op = Positivity()
    else:
        prox_op = Identity()

    # Define the cost function; ``atol`` is the convergence tolerance
    cost_op = DualGapCost(
        linear_op=linear_op,
        initial_cost=1e6,
        tolerance=atol,
        cost_interval=1,
        test_range=4,
        verbose=0,
        plot_output=None)

    # Define the optimizer
    opt = Condat(
        x=primal,
        y=dual,
        grad=gradient_op,
        prox=prox_op,
        prox_dual=prox_dual_op,
        linear=linear_op,
        cost=cost_op,
        rho=relaxation_factor,
        sigma=sigma,
        tau=tau,
        rho_update=None,
        sigma_update=None,
        tau_update=None,
        auto_iterate=False)

    # Perform the first reconstruction: a fixed number of raw updates,
    # without going through iterate()
    if verbose > 0:
        print("Starting optimization...")
    for _ in range(max_nb_of_iter):
        opt._update()
    opt.x_final = opt._x_new
    opt.y_final = opt._y_new

    # Loop through the number of reweightings
    for reweight_index in range(nb_of_reweights):

        # Generate the new weights following reweighting prescription
        if std_est_method == "primal":
            reweight_op.reweight(linear_op.op(opt._x_new))
        else:
            std_est = reweight_op.reweight(opt._x_new)

        # Welcome message
        if verbose > 0:
            print(" - reweight: ", reweight_index + 1)
            print(" - std: ", std_est)

        # Update the weights in the dual proximity operator
        prox_dual_op.weights = reweight_op.weights

        # Perform optimisation with new weights
        opt.iterate(max_iter=max_nb_of_iter)

    # Goodbye message
    end = time.perf_counter()
    if verbose > 0:
        print(" - final iteration number: ", cost_op._iteration)
        print(" - final cost value: ", cost_op.cost)
        print(" - converged: ", opt.converge)
        print("Done.")
        print("Execution time: ", end - start, " seconds")
        print("-" * 40)

    # Get the final solution
    x_final = opt.x_final
    linear_op.transform.analysis_data = unflatten(
        opt.y_final, linear_op.coeffs_shape)

    return x_final, linear_op.transform
def sparse_deconv_condatvu(data, psf, n_iter=300, n_reweights=1):
    """Sparse Deconvolution with Condat-Vu

    Runs a positivity-constrained, wavelet-sparsity-regularised Condat-Vu
    deconvolution, followed by ``n_reweights`` reweighting passes.
    Progress is printed to stdout.

    Parameters
    ----------
    data : np.ndarray
        Input data, 2D image
    psf : np.ndarray
        Input PSF, 2D image
    n_iter : int, optional
        Maximum number of iterations per optimisation run (default is 300)
    n_reweights : int, optional
        Number of reweightings (default is 1)

    Returns
    -------
    np.ndarray
        deconvolved image

    """
    # Print the algorithm set-up
    print(condatvu_logo())

    # Define the wavelet filters
    filters = get_cospy_filters(
        data.shape,
        transform_name='LinearWaveletTransformATrousAlgorithm')

    # Set the reweighting scheme
    reweight_op = cwbReweight(get_weights(data, psf, filters))

    # Set the initial variable values
    primal = np.ones(data.shape)
    dual = np.ones(filters.shape)

    # Set the gradient operators
    grad_op = GradBasic(data, lambda x: psf_convolve(x, psf),
                        lambda x: psf_convolve(x, psf, psf_rot=True))

    # Set the linear operator
    linear_op = WaveletConvolve2(filters)

    # Set the proximity operators
    prox_op = Positivity()
    prox_dual_op = SparseThreshold(linear_op, reweight_op.weights)

    # Set the cost function
    cost_op = costObj([grad_op, prox_op, prox_dual_op], tolerance=1e-6,
                      cost_interval=1, plot_output=True, verbose=False)

    # Set the optimisation algorithm
    alg = Condat(primal, dual, grad_op, prox_op, prox_dual_op, linear_op,
                 cost_op, rho=0.8, sigma=0.5, tau=0.5, auto_iterate=False)

    # Run the algorithm
    alg.iterate(max_iter=n_iter)

    # Implement reweighting
    for rw_num in range(n_reweights):
        print(' - Reweighting: {}'.format(rw_num + 1))
        reweight_op.reweight(linear_op.op(alg.x_final))
        # Hand the new weights to the threshold explicitly rather than
        # relying on reweight() updating the shared array in place.
        prox_dual_op.weights = reweight_op.weights
        alg.iterate(max_iter=n_iter)

    # Return the final result
    return alg.x_final
def _fit(self):
    """Alternate source (eigenPSF) and weight updates for ``self.nb_iter`` iterations.

    Each outer iteration runs a Condat-Vu solver on the Starlet transform
    of the sources (optionally repeated ``self.nb_reweight`` times with
    reweighted thresholds). On all but the last iteration, it then runs a
    forward-backward solver on ``alpha``, from which the weights are
    obtained as ``alpha.dot(self.VT)`` and renormalised. Results are
    stored in ``self.A``, ``self.S``, ``self.alpha`` and
    ``self.current_rec``.
    """
    weights = self.A
    comp = self.S
    alpha = self.alpha

    #### Source updates set-up ####
    # initialize dual variable and compute Starlet filters for Condat source updates
    dual_var = np.zeros(self.im_hr_shape)
    if self.default_filters:
        self.Phi_filters = get_mr_filters(self.im_hr_shape[:2], opt=self.opt,
                                          coarse=True)
    # Needed for the step sizes whatever the origin of the filters (default
    # or user-supplied), so it is computed outside the branch above.
    rho_phi = np.sqrt(
        np.sum(np.sum(np.abs(self.Phi_filters), axis=(1, 2))**2))

    # Set up source updates, starting with the gradient
    source_grad = grads.SourceGrad(self.obs_data, self.obs_weights, weights,
                                   self.flux, self.sigs, self.shift_ker_stack,
                                   self.shift_ker_stack_adj, self.upfact,
                                   self.Phi_filters)

    # sparsity in Starlet domain prox (this is actually assuming synthesis form)
    # threshold 0 is a placeholder; the real thresholds are set in the loop
    sparsity_prox = rca_prox.StarletThreshold(0)

    # and the linear recombination for the positivity constraint
    lin_recombine = rca_prox.LinRecombine(weights, self.Phi_filters)

    #### Weight updates set-up ####
    # gradient
    weight_grad = grads.CoeffGrad(self.obs_data, self.obs_weights, comp,
                                  self.VT, self.flux, self.sigs,
                                  self.shift_ker_stack,
                                  self.shift_ker_stack_adj, self.upfact)

    # cost functions
    weight_cost = costObj([weight_grad], verbose=self.modopt_verb)
    source_cost = costObj([source_grad], verbose=self.modopt_verb)

    # k-thresholding for spatial constraint; iter_func is handed to
    # KThreshold (presumably it maps the iteration count to k -- see
    # rca_prox.KThreshold)
    def iter_func(x):
        return np.floor(np.sqrt(x)) + 1

    coeff_prox = rca_prox.KThreshold(iter_func)

    for k in range(self.nb_iter):

        #### Eigenpsf update ####
        # update gradient instance with new weights...
        source_grad.update_A(weights)

        # ... update linear recombination weights...
        lin_recombine.update_A(weights)

        # ... set optimization parameters...
        beta = source_grad.spec_rad + rho_phi
        tau = 1. / beta
        sigma = 1. / lin_recombine.norm * beta / 2

        # ... update sparsity prox thresholds...
        thresh = utils.reg_format(
            utils.acc_sig_maps(self.shap, self.shift_ker_stack_adj,
                               self.sigs, self.flux, self.flux_ref,
                               self.upfact, weights,
                               sig_data=np.ones((self.shap[2], )) *
                               self.sig_min))
        thresholds = self.ksig * np.sqrt(
            np.array([
                filter_convolve(Sigma_k**2, self.Phi_filters**2)
                for Sigma_k in thresh
            ]))

        sparsity_prox.update_threshold(tau * thresholds)

        # and run source update:
        transf_comp = utils.apply_transform(comp, self.Phi_filters)
        if self.nb_reweight:
            reweighter = cwbReweight(thresholds)
            for _ in range(self.nb_reweight):
                source_optim = optimalg.Condat(
                    transf_comp, dual_var, source_grad, sparsity_prox,
                    Positivity(), linear=lin_recombine, cost=source_cost,
                    max_iter=self.nb_subiter_S, tau=tau, sigma=sigma)
                transf_comp = source_optim.x_final
                reweighter.reweight(transf_comp)
                thresholds = reweighter.weights
                # The prox was given a scaled copy (tau * thresholds), so it
                # never sees the reweighter's weights: push them explicitly
                # or the next Condat pass ignores the reweighting.
                sparsity_prox.update_threshold(tau * thresholds)
        else:
            source_optim = optimalg.Condat(
                transf_comp, dual_var, source_grad, sparsity_prox,
                Positivity(), linear=lin_recombine, cost=source_cost,
                max_iter=self.nb_subiter_S, tau=tau, sigma=sigma)
            transf_comp = source_optim.x_final

        comp = utils.rca_format(
            np.array([
                filter_convolve(transf_compj, self.Phi_filters, True)
                for transf_compj in transf_comp
            ]))

        # TODO: component selection (Fred's, to be extracted from
        # `low_rank_global_src_est_comb`); for now all components are kept.

        #### Weight update ####
        if k < self.nb_iter - 1:
            # update sources and reset iteration counter for K-thresholding
            weight_grad.update_S(comp)
            coeff_prox.reset_iter()
            weight_optim = optimalg.ForwardBackward(
                alpha, weight_grad, coeff_prox, cost=weight_cost,
                beta_param=weight_grad.inv_spec_rad, auto_iterate=False)
            weight_optim.iterate(max_iter=self.nb_subiter_weights)
            alpha = weight_optim.x_final
            weights_k = alpha.dot(self.VT)

            # renormalize to break scale invariance
            weight_norms = np.sqrt(np.sum(weights_k**2, axis=1))
            comp *= weight_norms
            weights_k /= weight_norms.reshape(-1, 1)

            # TODO: replace line below with Fred's component selection
            ind_select = range(weights.shape[0])
            weights = weights_k[ind_select, :]

    self.A = weights
    self.S = comp
    self.alpha = alpha
    source_grad.MX(transf_comp)
    self.current_rec = source_grad._current_rec