def test_segmentation_coverage():
    """Check that segments tile the signal exactly once.

    First, for several grid sizes, the segments are visited cyclically with
    ``increment_seg`` (starting from and coming back to segment 0) and every
    pixel must be covered exactly once. Then, a segmentation built from
    ``inner_bounds`` must cover exactly the inner region of the full support
    and nothing outside of it.
    """
    sig_support = (108, 53)
    for h_seg in [5, 7, 9, 13, 17]:
        for w_seg in [3, 11]:
            coverage = np.zeros(sig_support)
            segments = Segmentation(n_seg=(h_seg, w_seg),
                                    signal_support=sig_support)
            assert tuple(segments.n_seg_per_axis) == (h_seg, w_seg)

            # Visit every segment once: increment_seg wraps back to 0 after
            # the last segment, which ends the walk.
            i_seg = 0
            while True:
                seg_slice = segments.get_seg_slice(i_seg)
                seg_support = segments.get_seg_support(i_seg)
                assert seg_support == coverage[seg_slice].shape
                coverage[seg_slice] += 1
                i_seg = segments.increment_seg(i_seg)
                if i_seg == 0:
                    break
            assert np.all(coverage == 1)

    # Segments restricted to inner bounds of a larger support.
    coverage = np.zeros(sig_support)
    inner_bounds = [(8, 100), (3, 50)]
    inner_slice = tuple(slice(start, end) for start, end in inner_bounds)
    segments = Segmentation(n_seg=7, inner_bounds=inner_bounds,
                            full_support=sig_support)
    for i_seg in range(segments.effective_n_seg):
        coverage[segments.get_seg_slice(i_seg)] += 1

    # Each inner pixel is covered once and nothing outside is touched.
    assert np.all(coverage[inner_slice] == 1)
    coverage[inner_slice] = 0
    assert np.all(coverage == 0)
def recv_signal(self):
    """Receive this worker's part of the signal from the master node.

    The master broadcasts a dict describing the problem (``has_z0``,
    ``valid_support`` and ``workers_topology``), then sends the worker's
    warm start ``z0`` (only if ``has_z0``) followed by its chunk of ``X``.
    This method also builds the worker segmentation
    (``self.workers_segments``) and the local LGCD segmentation
    (``self.local_segments``), then synchronizes with the other workers and
    the main process.

    Returns
    -------
    X_worker : ndarray, shape (n_channels, *full worker support)
        Chunk of the signal needed by this worker (its segment extended by
        ``get_full_support`` with the atom support).
    z0 : ndarray, shape (n_atoms, *worker_support) or None
        Warm start for this worker's segment, or None if the master has none.
    """
    # D is laid out as (n_atoms, n_channels, *atom_support).
    n_atoms, n_channels, *atom_support = self.D.shape

    comm = MPI.Comm.Get_parent()
    X_info = comm.bcast(None, root=0)
    self.has_z0 = X_info['has_z0']
    self.valid_support = X_info['valid_support']
    self.workers_topology = X_info['workers_topology']
    self.size_msg = len(self.workers_topology) + 2

    self.workers_segments = Segmentation(n_seg=self.workers_topology,
                                         signal_support=self.valid_support,
                                         overlap=self.overlap)

    # Receive X and z from the master node. The receive order (z0 first,
    # then X) is kept as is: presumably it must mirror the master's sends.
    worker_support = self.workers_segments.get_seg_support(self.rank)
    X_shape = (n_channels, ) + get_full_support(worker_support, atom_support)
    z0_shape = (n_atoms, ) + worker_support
    if self.has_z0:
        z0 = self.recv_array(z0_shape)
    else:
        z0 = None
    X_worker = self.recv_array(X_shape)

    # Compute the local segmentation for the LGCD algorithm.
    # If n_seg is 'auto', use local segments with the size of an interfering
    # zone (2 * atom_support - 1) instead of a fixed number of segments.
    # The isinstance guard avoids an ambiguous element-wise comparison when
    # n_seg is given as an array.
    n_seg = self.n_seg
    local_seg_support = None
    if isinstance(n_seg, str) and n_seg == 'auto':
        n_seg = None
        local_seg_support = 2 * np.array(atom_support) - 1

    # Inner bounds of this worker's segment (i.e. without the overlap) are
    # returned in global coordinates; convert each bound point to the
    # worker's local coordinate system.
    inner_bounds = self.workers_segments.get_seg_bounds(self.rank, inner=True)
    inner_bounds = np.transpose([
        self.workers_segments.get_local_coordinate(self.rank, bound)
        for bound in np.transpose(inner_bounds)
    ])

    self.local_segments = Segmentation(n_seg=n_seg,
                                       seg_support=local_seg_support,
                                       inner_bounds=inner_bounds,
                                       full_support=worker_support)

    self.synchronize_workers(with_main=True)

    return X_worker, z0
def test_touched_segments():
    """Test detection of touched segments and records of active segments.

    A random center (possibly outside the signal) and a rectangular radius
    define a touched area. The segments reported by
    ``get_touched_segments`` are set inactive. Their status and the active
    counter are then compared with a brute-force map of the touched pixels.
    """
    rng = np.random.RandomState(42)
    H, W = sig_support = (108, 53)
    n_seg = (9, 3)
    for h_radius in [5, 7, 9]:
        for w_radius in [3, 11]:
            for _ in range(20):
                h0 = rng.randint(-h_radius, H + h_radius)
                w0 = rng.randint(-w_radius, W + w_radius)

                touched_map = np.zeros(sig_support)
                segments = Segmentation(n_seg, signal_support=sig_support)

                # Clip the touched rectangle to the signal support.
                h_lo, h_hi = max(0, h0 - h_radius), min(H, h0 + h_radius + 1)
                w_lo, w_hi = max(0, w0 - w_radius), min(W, w0 + w_radius + 1)
                touched_map[h_lo:h_hi, w_lo:w_hi] = 1

                touched_segments = segments.get_touched_segments(
                    (h0, w0), (h_radius, w_radius))
                segments.set_inactive_segments(touched_segments)
                n_active_segments = segments._n_active_segments

                n_total = segments.effective_n_seg
                n_touched = 0
                for i_seg in range(n_total):
                    seg_slice = segments.get_seg_slice(i_seg)
                    is_touched = np.any(touched_map[seg_slice] == 1)
                    n_touched += is_touched
                    assert segments.is_active_segment(i_seg) != is_touched

                assert n_active_segments == n_total - n_touched

    # A touched radius larger than the segment size must be rejected.
    segments = Segmentation(n_seg, signal_support=sig_support)
    with pytest.raises(ValueError, match="too large"):
        segments.get_touched_segments((0, 0), (30, 2))
def test_padding_to_overlap():
    """Padding any segment to its overlap yields an interior segment shape.

    Segments are indexed row-major, so index ``n_seg[1] + 1`` is segment
    (1, 1). On this 4x4 grid it does not touch the signal border, so it is
    used as the reference shape.
    """
    n_seg = (4, 4)
    sig_shape = (504, 504)
    overlap = (12, 7)
    seg = Segmentation(n_seg=n_seg, signal_shape=sig_shape, overlap=overlap)

    reference_shape = seg.get_seg_shape(n_seg[1] + 1)
    for i_seg in range(np.prod(n_seg)):
        seg_shape = seg.get_seg_shape(i_seg)
        padding = seg.get_padding_to_overlap(i_seg)
        padded = np.pad(np.empty(seg_shape), padding, mode='constant')
        assert padded.shape == reference_shape
def test_touched_overlap_area():
    """Check ``get_touched_overlap_slices`` against a brute-force update map.

    For several points of each segment, the area updated around the point
    (radius ``overlap``) is drawn on a global map. The inner part of the
    segment is then erased from that map. The slices returned by
    ``get_touched_overlap_slices`` (in local coordinates) must select only
    updated pixels, and together they must cover all of them.
    """
    sig_shape = (505, 407)
    overlap = (11, 9)
    n_seg = (8, 4)
    segments = Segmentation(n_seg=n_seg, signal_shape=sig_shape,
                            overlap=overlap)

    for i_seg in range(segments.effective_n_seg):
        seg_shape = segments.get_seg_shape(i_seg)
        seg_slice = segments.get_seg_slice(i_seg)
        seg_inner_slice = segments.get_seg_slice(i_seg, inner=True)

        if i_seg != 0:
            with pytest.raises(AssertionError):
                segments.check_area_contained(i_seg, (0, 0), overlap)

        h_far = seg_shape[0] - overlap[0] - 1
        w_far = seg_shape[1] - overlap[1] - 1
        probe_points = [
            overlap,
            (overlap[0], 25),
            (25, overlap[1]),
            (25, 25),
            (h_far, 25),
            (25, w_far),
            (h_far, w_far),
        ]
        for pt0 in probe_points:
            assert segments.is_contained_coordinate(i_seg, pt0, inner=True)
            segments.check_area_contained(i_seg, pt0, overlap)

            update_map = np.zeros(sig_shape)
            pt_global = segments.get_global_coordinate(i_seg, pt0)
            update_slice = tuple(
                slice(max(center - radius, 0), center + radius + 1)
                for center, radius in zip(pt_global, overlap)
            )
            update_map[update_slice] += 1
            update_map[seg_inner_slice] = 0

            # The returned slices are in local coordinates: work on a view
            # of the segment so writes propagate to the global map.
            map_seg = update_map[seg_slice]
            updated_slices = segments.get_touched_overlap_slices(
                i_seg, pt0, overlap)

            # Every selected coordinate is indeed in the updated area.
            for u_slice in updated_slices:
                assert np.all(map_seg[u_slice] == 1)

            # Every updated coordinate of the overlap area is selected by
            # at least one slice.
            for u_slice in updated_slices:
                map_seg[u_slice] *= 0
            assert np.all(update_map == 0)
def test_change_coordinate():
    """Local <-> global conversions map segment corners onto each other.

    For every segment, the local origin ``(0, 0)`` maps to the segment's
    start bounds. The local point equal to the segment support maps to its
    end bounds, and both conversions are inverses on these points.
    """
    sig_support = (505, 407)
    overlap = (12, 7)
    n_seg = (4, 4)
    segments = Segmentation(n_seg=n_seg, signal_support=sig_support,
                            overlap=overlap)

    for i_seg in range(segments.effective_n_seg):
        seg_bound = segments.get_seg_bounds(i_seg)
        seg_support = segments.get_seg_support(i_seg)
        origin = tuple(start for start, _ in seg_bound)
        corner = tuple(end for _, end in seg_bound)

        assert segments.get_global_coordinate(i_seg, (0, 0)) == origin
        assert segments.get_local_coordinate(i_seg, origin) == (0, 0)

        assert segments.get_global_coordinate(i_seg, seg_support) == corner
        assert segments.get_local_coordinate(i_seg, corner) == seg_support
def test_segmentation_coverage_overlap():
    """Check the pixel coverage of overlapping segmentations.

    With ``inner=True``, the segments must tile the signal exactly once.
    With the full (overlapping) segments, the number of pixels covered once,
    twice and four times must match the counts derived below from the
    number of segments per axis and the overlap.
    """
    sig_support = (505, 407)
    n_pixels = np.prod(sig_support)
    for overlap in [(3, 0), (0, 5), (3, 5), (12, 7)]:
        h_ov, w_ov = overlap
        for h_seg in [5, 7, 9, 13, 15, 17]:
            for w_seg in [3, 11]:
                segments = Segmentation(n_seg=(h_seg, w_seg),
                                        signal_support=sig_support,
                                        overlap=overlap)

                # Inner segments (without overlap) tile the signal once.
                z = np.zeros(sig_support)
                for i_seg in range(segments.effective_n_seg):
                    z[segments.get_seg_slice(i_seg, inner=True)] += 1
                assert np.sum(z == 1) == n_pixels

                # Full segments (with overlap).
                z = np.zeros(sig_support)
                for i_seg in range(segments.effective_n_seg):
                    z[segments.get_seg_slice(i_seg)] += 1

                # Use the effective number of segments per axis under new
                # names: reassigning h_seg/w_seg here would clobber the
                # outer loop variable h_seg for the next w_seg iterations.
                n_h, n_w = segments.n_seg_per_axis

                # Pixels covered by 4 segments: one (2 * h_ov, 2 * w_ov)
                # block at each interior corner of the grid.
                corner_overlap = 4 * (n_h - 1) * (n_w - 1) * h_ov * w_ov

                # Pixels covered by exactly 2 segments: the strips between
                # adjacent rows (height 2 * h_ov) and columns (width
                # 2 * w_ov), minus the corner blocks, which are counted once
                # in a horizontal and once in a vertical strip.
                expected_overlap = (n_h - 1) * sig_support[1] * 2 * h_ov
                expected_overlap += (n_w - 1) * sig_support[0] * 2 * w_ov
                expected_overlap -= 2 * corner_overlap

                expected_single = n_pixels - expected_overlap - corner_overlap

                assert expected_single == np.sum(z == 1)
                assert expected_overlap == np.sum(z == 2)
                assert corner_overlap == np.sum(z == 4)
def recv_task(self):
    """Receive the task parameters and this worker's data from the master.

    Reads the parameters (``self.get_params``), the dictionary
    (``self.get_D``), the optional warm start ``z0`` and the worker's chunk
    of the signal (``self.get_signal``). It then builds the worker and local
    segmentations, initializes ``self.z_hat`` and synchronizes the workers.

    Sets ``self.random_state``, ``self.size_msg``, ``self.D``,
    ``self.overlap``, ``self.workers_segments``, ``self.z0``,
    ``self.X_worker``, ``self.local_segments`` and ``self.z_hat``.
    NOTE(review): the call order of get_params / get_D / get_signal
    presumably has to match the order of the master's sends — confirm
    before reordering anything here.
    """
    # Retrieve different constants from the base communicator and store
    # them in the class.
    params = self.get_params()
    if self.timeout:
        # NOTE(review): the worker timeout is tripled here; the reason is
        # not visible from this method — confirm it is intended.
        self.timeout *= 3
    self.random_state = params['random_state']
    if isinstance(self.random_state, int):
        # Offset an integer seed by the worker rank, presumably so that
        # workers do not draw identical random streams.
        self.random_state += self.rank
    self.size_msg = len(params['valid_shape']) + 2

    # Compute the shape of the worker segment.
    self.D = self.get_D()
    # D is laid out as (n_atoms, n_channels, *atom_support).
    n_atoms, n_channels, *atom_support = self.D.shape
    self.overlap = np.array(atom_support) - 1
    self.workers_segments = Segmentation(
        n_seg=params['workers_topology'],
        signal_shape=params['valid_shape'],
        overlap=self.overlap)

    # Receive X and z from the master node (z0 first, only if the master
    # has one, then X).
    worker_shape = self.workers_segments.get_seg_shape(self.rank)
    X_shape = (n_channels, ) + get_full_shape(worker_shape, atom_support)
    if params['has_z0']:
        z0_shape = (n_atoms, ) + worker_shape
        self.z0 = self.get_signal(z0_shape, params['debug'])
    else:
        self.z0 = None
    self.X_worker = self.get_signal(X_shape, params['debug'])

    # Compute the local segmentation for LGCD algorithm

    # If n_seg is 'auto', use local segments with the size of an
    # interfering zone (2 * atom_support - 1) instead of a fixed number of
    # segments.
    n_seg = self.n_seg
    local_seg_shape = None
    if self.n_seg == 'auto':
        n_seg = None
        local_seg_shape = 2 * np.array(atom_support) - 1

    # Get local inner bounds. First, compute the seg_bound without overlap
    # (inner=True), which is given in global coordinates, and then convert
    # each bound point to the worker's local coordinate system.
    inner_bounds = self.workers_segments.get_seg_bounds(self.rank, inner=True)
    inner_bounds = np.transpose([
        self.workers_segments.get_local_coordinate(self.rank, bound)
        for bound in np.transpose(inner_bounds)
    ])

    self.local_segments = Segmentation(n_seg=n_seg,
                                       seg_shape=local_seg_shape,
                                       inner_bounds=inner_bounds,
                                       full_shape=worker_shape)

    # Initialize the solution
    n_atoms = self.D.shape[0]
    seg_shape = self.workers_segments.get_seg_shape(self.rank)
    if self.z0 is None:
        self.z_hat = np.zeros((n_atoms, ) + seg_shape)
    else:
        # NOTE(review): no copy is made, so z_hat and z0 are the same array;
        # if z_hat is later updated in place, self.z0 changes too — confirm
        # nothing relies on the original z0 afterwards.
        self.z_hat = self.z0

    self.info(
        "Start DICOD with {} workers, strategy '{}', soft_lock"
        "={} and n_seg={}({})",
        self.n_jobs, self.strategy, self.soft_lock, self.n_seg,
        self.local_segments.effective_n_seg, global_msg=True)
    self.synchronize_workers()
def test_inner_coordinate():
    """Check ``is_contained_coordinate(..., inner=True)`` near segment edges.

    A point in the overlap area of a segment is an inner coordinate only when
    that side of the segment lies on the signal border. Otherwise the point
    belongs to the overlap shared with a neighbor and must not be inner.
    Segments are indexed row-major: ``h_rank * n_seg[1] + w_rank``.
    """
    sig_support = (505, 407)
    overlap = (11, 11)
    n_seg = (4, 4)
    segments = Segmentation(n_seg=n_seg, signal_support=sig_support,
                            overlap=overlap)
    h_ov, w_ov = overlap
    h_last, w_last = n_seg[0] - 1, n_seg[1] - 1

    def check_side(i_seg, on_border, pt_on_border, pt_otherwise):
        # On a signal border, the probed point must be inner; away from it,
        # the point just past the inner area must not be.
        if on_border:
            assert segments.is_contained_coordinate(i_seg, pt_on_border,
                                                    inner=True)
        else:
            assert not segments.is_contained_coordinate(i_seg, pt_otherwise,
                                                        inner=True)

    for h_rank in range(n_seg[0]):
        for w_rank in range(n_seg[1]):
            i_seg = h_rank * n_seg[1] + w_rank
            seg_support = segments.get_seg_support(i_seg)
            h_size, w_size = seg_support[0], seg_support[1]

            assert segments.is_contained_coordinate(i_seg, overlap,
                                                    inner=True)

            is_top, is_left = h_rank == 0, w_rank == 0
            is_bottom, is_right = h_rank == h_last, w_rank == w_last

            # Top side, left side and top-left corner.
            check_side(i_seg, is_top, (0, w_ov), (h_ov - 1, w_ov))
            check_side(i_seg, is_left, (h_ov, 0), (h_ov, w_ov - 1))
            check_side(i_seg, is_top and is_left,
                       (0, 0), (h_ov - 1, w_ov - 1))

            # Bottom side, right side and bottom-right corner.
            check_side(i_seg, is_bottom,
                       (h_size - 1, w_size - w_ov - 1),
                       (h_size - h_ov, w_size - w_ov - 1))
            check_side(i_seg, is_right,
                       (h_size - h_ov - 1, w_size - 1),
                       (h_size - h_ov - 1, w_size - w_ov))
            check_side(i_seg, is_bottom and is_right,
                       (h_size - 1, w_size - 1),
                       (h_size - h_ov, w_size - w_ov))
atom_support = (16, 16)
run_args = (n_atoms, atom_support, reg, tol, n_workers, random_state)
if args.no_cache:
    # Force the computation, bypassing the cache. Assuming
    # run_without_soft_lock is a joblib Memory-cached function,
    # ``MemorizedFunc.call`` returns ``(output, metadata)``: keep the output.
    X_hat, pobj = run_without_soft_lock.call(*run_args)[0]
else:
    X_hat, pobj = run_without_soft_lock(*run_args)

file_name = f"soft_lock_M{n_workers}_support{atom_support[0]}"
np.save(f"benchmarks_results/{file_name}_X_hat.npy", X_hat)

# Compute the worker segmentation for the image, so the boundaries between
# workers can be drawn on top of the reconstruction.
n_channels, *sig_support = X_hat.shape
valid_support = get_valid_support(sig_support, atom_support)
workers_segments = Segmentation(n_seg=(w_world, w_world),
                                signal_support=valid_support,
                                overlap=0)

fig = plt.figure("recovery")
fig.patch.set_alpha(0)

ax = plt.subplot()
# NOTE(review): X_hat is (n_channels, H, W), so swapaxes(0, 2) gives
# (W, H, n_channels), i.e. a transposed image, while the grid lines below use
# axis 0 of the bounds as the y-axis. Confirm this is intended;
# np.moveaxis(X_hat, 0, -1) would keep the (H, W) orientation.
ax.imshow(X_hat.swapaxes(0, 2))
for i_seg in range(workers_segments.effective_n_seg):
    # Row i of seg_bounds holds the (start, end) bounds along axis i. Shift
    # both bounds of axis i by half the atom support along that same axis
    # (the column vector broadcasts over start/end). This is correct for
    # non-square atoms too.
    seg_bounds = np.array(workers_segments.get_seg_bounds(i_seg))
    seg_bounds = seg_bounds + np.array(atom_support)[:, None] / 2
    ax.vlines(seg_bounds[1], *seg_bounds[0], linestyle='--')
    ax.hlines(seg_bounds[0], *seg_bounds[1], linestyle='--')
ax.axis('off')
plt.tight_layout()
def coordinate_descent(X_i, D, reg, z0=None, DtD=None, n_seg='auto',
                       strategy='greedy', tol=1e-5, max_iter=100000,
                       timeout=None, z_positive=False, freeze_support=False,
                       return_ztz=False, timing=False,
                       random_state=None, verbose=0):
    """Coordinate Descent Algorithm for 2D convolutional sparse coding.

    Parameters
    ----------
    X_i : ndarray, shape (n_channels, *sig_support)
        Image to encode on the dictionary D
    D : ndarray, shape (n_atoms, n_channels, *atom_support)
        Current dictionary for the sparse coding
    reg : float
        Regularization parameter
    z0 : ndarray, shape (n_atoms, *valid_support) or None
        Warm start value for z_hat. If not present, z_hat is initialized to 0.
    DtD : ndarray, shape (n_atoms, n_atoms, *(2 * atom_support - 1)) or None
        Warm start value for DtD. If not present, it is computed on init.
    n_seg : int or 'auto'
        Number of segments to use for each dimension. If set to 'auto', use
        segments of size roughly 2 * atom_support - 1 (at least one segment
        per dimension).
    strategy : str in {strategies}
        Coordinate selection scheme for the coordinate descent. If set to
        'greedy'|'gs-r', the coordinate with the largest value for dz_opt
        is selected. If set to 'random', the coordinate is chosen uniformly
        on the segment. If set to 'gs-q', the value that reduce the most the
        cost function is selected. In this case, dE must holds the value of
        this cost reduction.
    tol : float
        Tolerance for the minimal update size in this algorithm.
    max_iter : int
        Maximal number of iteration run by this algorithm.
    timeout : float or None
        If not None, maximal running time in seconds. The loop stops at the
        end of the first iteration that exceeds it.
    z_positive : boolean
        If set to true, the activations are constrained to be positive.
    freeze_support : boolean
        If set to True, only update the coefficient that are non-zero in z0.
    return_ztz : boolean
        If True, returns the constants ztz and ztX, used to compute D-updates.
    timing : boolean
        If set to True, log the cost and timing information.
    random_state : None or int or RandomState
        current random state to seed the random number generator.
    verbose : int
        Verbosity level of the algorithm.

    Return
    ------
    z_hat : ndarray, shape (n_atoms, *valid_support)
        Activation associated to X_i for the given dictionary D
    ztz, ztX : ndarray or None
        Constants for the D-update, only computed if return_ztz is True.
    p_obj : list
        If timing is True, entries (iteration, t_run, total selection time,
        cost) logged along the run. The list always ends with an entry
        [n_coordinate_updates, t_run, cost].
    run_statistics : dict
        Number of iterations, runtime, initialization and run times, number
        of updates and mean selection/update durations (nan when empty).
    """
    n_channels, *sig_support = X_i.shape
    n_atoms, n_channels, *atom_support = D.shape
    valid_support = get_valid_support(sig_support, atom_support)

    if strategy not in STRATEGIES:
        raise ValueError("The coordinate selection strategy should be in "
                         "{}. Got '{}'.".format(STRATEGIES, strategy))

    # Compute sizes for the segments for LGCD. 'auto' gives segments of size
    # twice the support of the atoms. The isinstance guard avoids an
    # ambiguous element-wise comparison when n_seg is given as an array.
    if isinstance(n_seg, str) and n_seg == 'auto':
        n_seg = np.array(valid_support) // (2 * np.array(atom_support) - 1)
        n_seg = tuple(np.maximum(1, n_seg))
    segments = Segmentation(n_seg, signal_support=valid_support)

    # Pre-compute constants for maintaining the auxiliary variable beta and
    # compute the coordinate update values.
    constants = {}
    constants['norm_atoms'] = compute_norm_atoms(D)
    if DtD is None:
        constants['DtD'] = compute_DtD(D)
    else:
        constants['DtD'] = DtD

    # Initialization of the algorithm variables
    i_seg = -1
    accumulator = 0
    if z0 is None:
        z_hat = np.zeros((n_atoms,) + valid_support)
    else:
        z_hat = np.copy(z0)
    n_coordinates = z_hat.size

    # Get a random number generator from the given random_state
    rng = check_random_state(random_state)
    order = None
    if strategy in ['cyclic', 'cyclic-r', 'random']:
        order = get_order_iterator(z_hat.shape, strategy=strategy,
                                   random_state=rng)

    t_start_init = time.time()
    return_dE = strategy == "gs-q"
    beta, dz_opt, dE = _init_beta(X_i, D, reg, z_i=z0, constants=constants,
                                  z_positive=z_positive, return_dE=return_dE)
    if strategy == "gs-q":
        raise NotImplementedError("This is still WIP")

    if freeze_support:
        freezed_support = z0 == 0
        dz_opt[freezed_support] = 0
    else:
        freezed_support = None

    p_obj, next_log_iter = [], 1
    t_init = time.time() - t_start_init
    if timing:
        p_obj.append((0, t_init, 0, compute_objective(X_i, z_hat, D, reg)))
    n_coordinate_updates = 0
    t_run = 0
    t_select_coord, t_update_coord = [], []

    t_start = time.time()
    deadline = None if timeout is None else t_start + timeout

    # Defined up-front so run_statistics is valid even when max_iter == 0.
    ii = -1
    for ii in range(max_iter):
        if ii % 1000 == 0 and verbose > 0:
            print("\r[LGCD:PROGRESS] {:.0f}s - {:7.2%} iterations"
                  .format(t_run, ii / max_iter), end='', flush=True)

        i_seg = segments.increment_seg(i_seg)
        if segments.is_active_segment(i_seg):
            t_start_selection = time.time()
            k0, pt0, dz = _select_coordinate(dz_opt, dE, segments, i_seg,
                                             strategy=strategy, order=order)
            selection_duration = time.time() - t_start_selection
            t_select_coord.append(selection_duration)
            t_run += selection_duration
        else:
            # Inactive segment: no update; k0/pt0 are only read when
            # abs(dz) > tol, so their stale values are never used.
            dz = 0

        accumulator = max(abs(dz), accumulator)

        # Update the selected coordinate and beta, only if the update is
        # greater than the convergence tolerance.
        if abs(dz) > tol:
            t_start_update = time.time()

            # update the current solution estimate and beta
            beta, dz_opt, dE = coordinate_update(
                k0, pt0, dz, beta=beta, dz_opt=dz_opt, dE=dE, z_hat=z_hat,
                D=D, reg=reg, constants=constants, z_positive=z_positive,
                freezed_support=freezed_support)
            touched_segs = segments.get_touched_segments(
                pt=pt0, radius=atom_support)
            n_changed_status = segments.set_active_segments(touched_segs)

            # Logging of the time and the cost function if necessary
            update_duration = time.time() - t_start_update
            n_coordinate_updates += 1
            t_run += update_duration
            t_update_coord.append(update_duration)
            if timing and ii + 1 >= next_log_iter:
                p_obj.append((ii + 1, t_run, np.sum(t_select_coord),
                              compute_objective(X_i, z_hat, D, reg)))
                next_log_iter = next_log_iter * 1.3

            # If debug flag CHECK_ACTIVE_SEGMENTS is set, check that all
            # inactive segments should be inactive
            if flags.CHECK_ACTIVE_SEGMENTS and n_changed_status:
                segments.test_active_segment(dz_opt, tol)

        elif strategy in ["greedy", 'gs-r']:
            segments.set_inactive_segments(i_seg)

        # check stopping criterion
        if _check_convergence(segments, tol, ii, dz_opt, n_coordinates,
                              strategy, accumulator=accumulator):
            assert np.all(abs(dz_opt) <= tol)
            if verbose > 0:
                print("\r[LGCD:INFO] converged in {} iterations ({} updates)"
                      .format(ii + 1, n_coordinate_updates))
            break

        # Check if we reached the timeout
        if deadline is not None and time.time() >= deadline:
            if verbose > 0:
                print("\r[LGCD:INFO] Reached timeout. Done {} iterations "
                      "({} updates). Max of |dz|={}."
                      .format(ii + 1, n_coordinate_updates,
                              abs(dz_opt).max()))
            break
    else:
        if verbose > 0:
            print("\r[LGCD:INFO] Reached max_iter. Done {} coordinate "
                  "updates. Max of |dz|={}."
                  .format(n_coordinate_updates, abs(dz_opt).max()))

    runtime = time.time() - t_start

    # np.mean of an empty list warns and returns nan; return nan explicitly
    # when no selection/update happened (e.g. z0 already optimal).
    mean_t_select = np.mean(t_select_coord) if t_select_coord else np.nan
    mean_t_update = np.mean(t_update_coord) if t_update_coord else np.nan

    if verbose > 0:
        # Timing breakdown, previously printed regardless of verbosity.
        print(f"\r[LGCD:{strategy}] "
              f"t_select={mean_t_select:.3e}s "
              f"t_update={mean_t_update:.3e}s")
        print("\r[LGCD:INFO] done in {:.3f}s ({:.3f}s)"
              .format(runtime, t_run))

    ztz, ztX = None, None
    if return_ztz:
        ztz = compute_ztz(z_hat, atom_support)
        ztX = compute_ztX(z_hat, X_i)

    # NOTE(review): this final entry has 3 fields while the timing entries
    # above have 4; kept as is for callers that read it positionally.
    p_obj.append([n_coordinate_updates, t_run,
                  compute_objective(X_i, z_hat, D, reg)])

    run_statistics = dict(iterations=ii + 1, runtime=runtime, t_init=t_init,
                          t_run=t_run, n_updates=n_coordinate_updates,
                          t_select=mean_t_select, t_update=mean_t_update)

    return z_hat, ztz, ztX, p_obj, run_statistics
atom_support = (16, 16)
run_args = (n_atoms, atom_support, reg, tol, n_jobs, random_state)
if args.no_cache:
    # Force the computation, bypassing the cache. Assuming
    # run_without_soft_lock is a joblib Memory-cached function,
    # ``MemorizedFunc.call`` returns ``(output, metadata)``: keep only the
    # output so it unpacks like the cached call in the ``else`` branch.
    X_hat, pobj = run_without_soft_lock.call(*run_args)[0]
else:
    X_hat, pobj = run_without_soft_lock(*run_args)

file_name = f"soft_lock_M{n_jobs}_support{atom_support[0]}"
np.save(f"benchmarks_results/{file_name}_X_hat.npy", X_hat)

# Compute the worker segmentation for the image, so the boundaries between
# workers can be drawn on top of the reconstruction.
n_channels, *sig_shape = X_hat.shape
valid_shape = get_valid_shape(sig_shape, atom_support)
workers_segments = Segmentation(n_seg=(w_world, w_world),
                                signal_shape=valid_shape,
                                overlap=0)

fig = plt.figure("recovery")
fig.patch.set_alpha(0)

ax = plt.subplot()
# NOTE(review): X_hat is (n_channels, H, W), so swapaxes(0, 2) gives
# (W, H, n_channels), i.e. a transposed image, while the grid lines below use
# axis 0 of the bounds as the y-axis. Confirm this is intended;
# np.moveaxis(X_hat, 0, -1) would keep the (H, W) orientation.
ax.imshow(X_hat.swapaxes(0, 2))
for i_seg in range(workers_segments.effective_n_seg):
    # Row i of seg_bounds holds the (start, end) bounds along axis i. Shift
    # both bounds of axis i by half the atom support along that same axis
    # (the column vector broadcasts over start/end). This is correct for
    # non-square atoms too.
    seg_bounds = np.array(workers_segments.get_seg_bounds(i_seg))
    seg_bounds = seg_bounds + np.array(atom_support)[:, None] / 2
    ax.vlines(seg_bounds[1], *seg_bounds[0], linestyle='--')
    ax.hlines(seg_bounds[0], *seg_bounds[1], linestyle='--')
ax.axis('off')
plt.tight_layout()
def coordinate_descent(X_i, D, reg, z0=None, n_seg='auto', strategy='greedy',
                       tol=1e-5, max_iter=100000, timeout=None,
                       z_positive=False, freeze_support=False,
                       return_ztz=False, timing=False, random_state=None,
                       verbose=0):
    """Coordinate Descent Algorithm for 2D convolutional sparse coding.

    Parameters
    ----------
    X_i : ndarray, shape (n_channels, *sig_shape)
        Image to encode on the dictionary D
    D : ndarray, shape (n_atoms, n_channels, *atom_shape)
        Current dictionary for the sparse coding
    reg : float
        Regularization parameter
    z0 : ndarray, shape (n_atoms, *valid_shape) or None
        Warm start value for z_hat. If None, z_hat is initialized to 0.
    n_seg : int or 'auto'
        Number of segments to use for each dimension. If set to 'auto', use
        segments of size roughly 2 * atom_shape - 1 (at least one segment
        per dimension).
    strategy : str in 'greedy' | 'random' | 'gs-r' | 'gs-q'
        Coordinate selection scheme for the coordinate descent. If set to
        'greedy'|'gs-r', the coordinate with the largest value for dz_opt
        is selected. If set to 'random', the coordinate is chosen uniformly
        on the segment. If set to 'gs-q', the value that reduce the most the
        cost function is selected. In this case, dE must holds the value of
        this cost reduction.
    tol : float
        Tolerance for the minimal update size in this algorithm.
    max_iter : int
        Maximal number of iteration run by this algorithm.
    timeout : float or None
        If not None, maximal running time in seconds. The loop stops at the
        end of the first iteration that exceeds it.
    z_positive : boolean
        If set to true, the activations are constrained to be positive.
    freeze_support : boolean
        If set to True, only update the coefficient that are non-zero in z0.
    return_ztz : boolean
        If True, also compute ztz and ztX, used for the D-updates.
    timing : boolean
        If set to True, log the cost and timing information.
    random_state : None or int or RandomState
        current random state to seed the random number generator.
    verbose : int
        Verbosity level of the algorithm.

    Return
    ------
    z_hat : ndarray, shape (n_atoms, *valid_shape)
        Activation associated to X_i for the given dictionary D
    ztz, ztX : ndarray or None
        Constants for the D-update, only computed if return_ztz is True.
    p_obj : list
        If timing is True, entries (iteration, t_update, cost) logged at
        exponentially spaced iterations. The list always ends with
        [n_coordinate_updates, t_update, None].
    run_statistics : None
    """
    n_channels, *sig_shape = X_i.shape
    n_atoms, n_channels, *atom_shape = D.shape
    valid_shape = tuple(
        size_ax - size_atom_ax + 1
        for size_ax, size_atom_ax in zip(sig_shape, atom_shape)
    )

    # Compute sizes for the segments for LGCD. The isinstance guard avoids
    # an ambiguous element-wise comparison when n_seg is an array.
    if isinstance(n_seg, str) and n_seg == 'auto':
        n_seg = [max(axis_size // (2 * atom_size - 1), 1)
                 for axis_size, atom_size in zip(valid_shape, atom_shape)]
    segments = Segmentation(n_seg, signal_shape=valid_shape)

    # Pre-compute some quantities
    constants = {}
    constants['norm_atoms'] = compute_norm_atoms(D)
    constants['DtD'] = compute_DtD(D)

    # Initialization of the algorithm variables
    i_seg = -1
    p_obj, next_cost = [], 1
    accumulator = 0
    if z0 is None:
        z_hat = np.zeros((n_atoms, ) + valid_shape)
    else:
        z_hat = np.copy(z0)
    n_coordinates = z_hat.size

    t_update = 0
    t_start_update = time.time()
    return_dE = strategy == "gs-q"
    beta, dz_opt, dE = _init_beta(X_i, D, reg, z_i=z0, constants=constants,
                                  z_positive=z_positive, return_dE=return_dE)
    if strategy == "gs-q":
        raise NotImplementedError("This is still WIP")

    if freeze_support:
        freezed_support = z0 == 0
        dz_opt[freezed_support] = 0
    else:
        freezed_support = None

    t_start = time.time()
    n_coordinate_updates = 0
    deadline = None if timeout is None else t_start + timeout

    for ii in range(max_iter):
        if ii % 1000 == 0 and verbose > 0:
            print("\r[LGCD:PROGRESS] {:.0f}s - {:7.2%} iterations".format(
                t_update, ii / max_iter), end='', flush=True)

        i_seg = segments.increment_seg(i_seg)
        if segments.is_active_segment(i_seg):
            # NOTE(review): random_state is forwarded unchanged on every
            # call. If _select_coordinate builds a fresh generator from an
            # int seed at each call, the 'random' strategy would always draw
            # the same coordinate — confirm.
            k0, pt0, dz = _select_coordinate(dz_opt, dE, segments, i_seg,
                                             strategy=strategy,
                                             random_state=random_state)
        else:
            k0, pt0, dz = None, None, 0

        accumulator = max(abs(dz), accumulator)

        # Update the selected coordinate and beta, only if the update is
        # greater than the convergence tolerance.
        if abs(dz) > tol:
            n_coordinate_updates += 1

            # update beta
            beta, dz_opt, dE = coordinate_update(
                k0, pt0, dz, beta=beta, dz_opt=dz_opt, dE=dE, z_hat=z_hat,
                D=D, reg=reg, constants=constants, z_positive=z_positive,
                freezed_support=freezed_support)
            touched_segs = segments.get_touched_segments(pt=pt0,
                                                         radius=atom_shape)
            n_changed_status = segments.set_active_segments(touched_segs)

            if flags.CHECK_ACTIVE_SEGMENTS and n_changed_status:
                segments.test_active_segment(dz_opt, tol)

            t_update += time.time() - t_start_update
            if timing and ii >= next_cost:
                p_obj.append(
                    (ii, t_update, compute_objective(X_i, z_hat, D, reg)))
                next_cost = next_cost * 2
            # Restart the clock after every update (and after the optional
            # cost evaluation, which is excluded from t_update) so t_update
            # never re-accumulates a stale interval.
            t_start_update = time.time()
        elif strategy in ["greedy", 'gs-r']:
            segments.set_inactive_segments(i_seg)

        # check stopping criterion
        if _check_convergence(segments, tol, ii, dz_opt, n_coordinates,
                              strategy, accumulator=accumulator):
            assert np.all(abs(dz_opt) <= tol)
            if verbose > 0:
                print("\r[LGCD:INFO] converged after {} iterations"
                      .format(ii + 1))
            break

        # Check if we reached the timeout
        if deadline is not None and time.time() >= deadline:
            if verbose > 0:
                print("\r[LGCD:INFO] Reached timeout. Done {} coordinate "
                      "updates. Max of |dz|={}.".format(
                          n_coordinate_updates, abs(dz_opt).max()))
            break
    else:
        if verbose > 0:
            print("\r[LGCD:INFO] Reached max_iter. Done {} coordinate "
                  "updates. Max of |dz|={}.".format(n_coordinate_updates,
                                                    abs(dz_opt).max()))

    runtime = time.time() - t_start
    if verbose > 0:
        # {:.3f} (3 decimals) rather than {:.3} (3 significant digits,
        # which switches to scientific notation for runs >= 1000s).
        print("\r[LGCD:INFO] done in {:.3f}s".format(runtime))

    ztz, ztX = None, None
    if return_ztz:
        ztz = compute_ztz(z_hat, atom_shape)
        ztX = compute_ztX(z_hat, X_i)

    p_obj.append([n_coordinate_updates, t_update, None])

    return z_hat, ztz, ztX, p_obj, None