def message_update_beta(self, msg):
    """Apply a coordinate update received in a message from another worker.

    ``msg`` is unpacked as ``(k0, *pt_global, dz)``: the atom index, the
    global coordinates of the updated point and the update value. The point
    is converted to this worker's local coordinates. It must not lie inside
    this worker's inner region; this is asserted. The update is then applied
    through ``coordinate_update``.

    When ``flags.CHECK_BETA`` is set, about 1% of calls also recompute
    ``beta`` from scratch with ``_init_beta``. On the inner region, the
    recomputed value must match the incrementally maintained ``self.beta``.
    """
    k0_raw, *coords, dz = msg
    k0 = int(k0_raw)
    pt_global = tuple(int(c) for c in coords)

    segments = self.workers_segments
    pt0 = segments.get_local_coordinate(self.rank, pt_global)
    assert not segments.is_contained_coordinate(
        self.rank, pt0, inner=True), (pt_global, pt0)
    coordinate_exist = segments.is_contained_coordinate(
        self.rank, pt0, inner=False)
    self.coordinate_update(k0, pt0, dz, coordinate_exist=coordinate_exist)

    # Only check beta 1% of the time to avoid the check being too long
    if not flags.CHECK_BETA or np.random.rand() <= 0.99:
        return

    inner_slice = (Ellipsis, ) + tuple(
        slice(lo, hi) for lo, hi in self.local_segments.inner_bounds)
    beta_ref, *_ = _init_beta(
        self.X_worker, self.D, self.reg, z_i=self.z_hat,
        constants=self.constants, z_positive=self.z_positive)
    assert np.allclose(beta_ref[inner_slice], self.beta[inner_slice])
def test_warm_start(valid_support, atom_support, reg):
    """Check that ``dicod`` honours a warm start ``z0``.

    1. With ``reg=0`` and ``X`` built from ``z`` by ``reconstruct``,
       starting from ``z0=z`` must return ``z`` itself.
    2. On a random ``X`` warm-started from ``z``, the returned solution must
       satisfy the stopping criterion: every optimal coordinate update
       ``dz_opt`` is at most ``tol`` in absolute value.
    """
    tol = 1
    n_atoms = 7
    n_channels = 5
    random_state = 36

    rng = check_random_state(random_state)

    D = rng.randn(n_atoms, n_channels, *atom_support)
    # Normalize each atom over ALL of its non-leading axes (channels plus
    # every atom dimension). A hard-coded ``axis=(1, 2)`` only covers 1D
    # atoms and leaves 2D atoms un-normalized along their last dimension.
    sum_axis = tuple(range(1, D.ndim))
    D /= np.sqrt(np.sum(D * D, axis=sum_axis, keepdims=True))

    # Sparse activations: each coefficient is kept with probability 0.3.
    z = rng.randn(n_atoms, *valid_support)
    z *= (rng.rand(n_atoms, *valid_support) > .7)

    X = reconstruct(z, D)

    z_hat, *_ = dicod(X, D, reg=0, z0=z, tol=tol, n_workers=N_WORKERS,
                      max_iter=10000, verbose=VERBOSE)
    assert np.allclose(z_hat, z)

    X = rng.randn(*X.shape)

    z_hat, *_ = dicod(X, D, reg, z0=z, tol=tol, n_workers=N_WORKERS,
                      max_iter=100000, verbose=VERBOSE)

    _, dz_opt, _ = _init_beta(X, D, reg, z_i=z_hat)
    # dz_opt is signed: compare its magnitude, otherwise any negative update,
    # however large, would pass this check.
    assert np.all(abs(dz_opt) <= tol)
def test_init_beta():
    """Check ``_init_beta`` on a random 2D problem.

    First, ``beta`` and ``dz_opt`` must have the same shape as ``z``. Then,
    for 50 random coordinates ``(k, h, w)``:

    * ``beta[k, h, w]`` must not change when ``z[k, h, w]`` alone is changed;
    * setting ``z[k, h, w]`` to ``z[k, h, w] + dz_opt[k, h, w]`` must give
      an objective no larger than at small perturbations of that coordinate
      (offsets -2.5, -1.5, -0.5, 0.5 and 1.5 times ``eps``).
    """
    n_atoms, n_channels = 5, 2
    height, width = 31, 37
    height_atom, width_atom = 11, 13
    valid_shape = (n_atoms, height - height_atom + 1, width - width_atom + 1)
    _, n_rows, n_cols = valid_shape

    rng = np.random.RandomState(42)
    X = rng.randn(n_channels, height, width)
    D = rng.randn(n_atoms, n_channels, height_atom, width_atom)
    D /= np.sqrt(np.sum(D * D, axis=(1, 2, 3), keepdims=True))
    z = rng.randn(*valid_shape)
    lmbd = 1

    beta, dz_opt, _dE = _init_beta(X, D, lmbd, z_i=z)
    assert beta.shape == z.shape
    assert dz_opt.shape == z.shape

    eps = 1e-5
    for _trial in range(50):
        k = rng.randint(n_atoms)
        h = rng.randint(n_rows)
        w = rng.randint(n_cols)

        # beta at a coordinate must not depend on that coordinate's value
        z_saved = z[k, h, w]
        z[k, h, w] = rng.randn()
        beta_new, *_ = _init_beta(X, D, lmbd, z_i=z)
        assert np.isclose(beta_new[k, h, w], beta[k, h, w])

        # the proposed update must be locally optimal along that coordinate
        z[k, h, w] = z_saved + dz_opt[k, h, w]
        cost_opt = compute_objective(X, z, D, lmbd)
        z[k, h, w] -= 3.5 * eps
        for _step in range(5):
            z[k, h, w] += eps
            assert cost_opt <= compute_objective(X, z, D, lmbd)

        z[k, h, w] = z_saved
def test_stopping_criterion(n_workers, signal_support, atom_support):
    """Check that ``dicod`` stops only once every update is below ``tol``.

    After solving, the optimal coordinate updates recomputed with
    ``_init_beta`` from the returned ``z_hat`` must all be smaller than
    ``tol`` in absolute value.
    """
    tol = reg = 1
    n_atoms, n_channels = 10, 3

    rng = check_random_state(42)
    X = rng.randn(n_channels, *signal_support)
    D = rng.randn(n_atoms, n_channels, *atom_support)

    # unit-norm atoms, whatever the number of atom dimensions
    atom_axes = tuple(range(1, D.ndim))
    D /= np.sqrt(np.sum(D * D, axis=atom_axes, keepdims=True))

    z_hat, *_ = dicod(X, D, reg, tol=tol, n_workers=n_workers,
                      verbose=VERBOSE)

    _, dz_opt, _ = _init_beta(X, D, reg, z_i=z_hat)
    largest_update = np.abs(dz_opt).max()
    assert largest_update < tol
def init_cd_variables(self): t_start = time.time() # Pre-compute some quantities constants = {} constants['norm_atoms'] = compute_norm_atoms(self.D) constants['DtD'] = compute_DtD(self.D) self.constants = constants # List of all pending messages sent self.messages = [] # Log all updates for logging purpose self._log_updates = [] # Avoid printing progress too often self._last_progress = 0 # Initialization of the auxillary variable for LGCD return_dE = self.strategy == "gs-q" self.beta, self.dz_opt, self.dE = _init_beta( self.X_worker, self.D, self.reg, z_i=self.z0, constants=constants, z_positive=self.z_positive, return_dE=return_dE) # Make sure all segments are activated self.local_segments.reset() if self.z0 is not None: self.freezed_support = None self.correct_beta_z0() if flags.CHECK_WARM_BETA: pt_global = self.workers_segments.get_seg_shape(0, inner=True) pt = self.workers_segments.get_local_coordinate( self.rank, pt_global) if self.workers_segments.is_contained_coordinate(self.rank, pt): _, _, *atom_support = self.D.shape beta_slice = (Ellipsis, ) + tuple([ slice(v - size_ax + 1, v + size_ax - 1) for v, size_ax in zip(pt, atom_support) ]) sum_beta = np.array(self.beta[beta_slice].sum(), dtype='d') self.return_array(sum_beta) if self.freeze_support: assert self.z0 is not None self.freezed_support = self.z0 == 0 self.dz_opt[self.freezed_support] = 0 else: self.freezed_support = None self.synchronize_workers() t_local_init = time.time() - t_start self.info("End local initialization in {:.2f}s", t_local_init, global_msg=True)
def init_cd_variables(self):
    """Initialize the local coordinate-descent state of this worker.

    Pre-computes the atom norms and ``DtD``. ``DtD`` is reused from
    ``self.DtD`` when ``self.precomputed_DtD`` is set. The method then
    resets the messaging and logging bookkeeping and, when
    ``self.warm_start`` is set, restarts from the previous ``z_hat``. It
    computes ``beta``, ``dz_opt`` and ``dE`` with ``_init_beta``,
    initializes ``z_hat``, applies optional support freezing and
    synchronizes with the other workers (not the main process).

    Returns
    -------
    t_local_init : float
        Wall-clock seconds spent in this local initialization.
    """
    t_start = time.time()

    # Pre-compute some quantities
    constants = {'norm_atoms': compute_norm_atoms(self.D)}
    constants['DtD'] = (self.DtD if self.precomputed_DtD
                        else compute_DtD(self.D))
    self.constants = constants

    # Bookkeeping: pending sent messages, update log, progress throttling
    self.messages = []
    self._log_updates = []
    self._last_progress = 0

    # Warm start from the solution of a previous run, if there is one
    if self.warm_start and hasattr(self, 'z_hat'):
        self.z0 = self.z_hat.copy()

    # Auxiliary variables for LGCD; dE is only needed by "gs-q"
    needs_dE = self.strategy == "gs-q"
    self.beta, self.dz_opt, self.dE = _init_beta(
        self.X_worker, self.D, self.reg, z_i=self.z0, constants=constants,
        z_positive=self.z_positive, return_dE=needs_dE)

    # Make sure all segments are activated
    self.local_segments.reset()

    if self.z0 is None:
        self.z_hat = np.zeros(self.beta.shape)
    else:
        self.freezed_support = None
        self.z_hat = self.z0.copy()
        self.correct_beta_z0()

    if flags.CHECK_WARM_BETA:
        worker_check_beta(self.rank, self.workers_segments, self.beta,
                          self.D.shape)

    if not self.freeze_support:
        self.freezed_support = None
    else:
        assert self.z0 is not None
        self.freezed_support = self.z0 == 0
        self.dz_opt[self.freezed_support] = 0

    self.synchronize_workers(with_main=False)

    t_local_init = time.time() - t_start
    self.debug("End local initialization in {:.2f}s", t_local_init,
               global_msg=True)
    self.info(
        "Start DICOD with {} workers, strategy '{}', soft_lock"
        "={} and n_seg={}({})", self.n_workers, self.strategy,
        self.soft_lock, self.n_seg, self.local_segments.effective_n_seg,
        global_msg=True)
    return t_local_init