예제 #1
0
    def message_update_beta(self, msg):
        """Apply a coordinate update received as a message.

        ``msg`` is unpacked as ``(k0, *pt_global, dz)``: an atom index, a
        coordinate expressed in the global frame, and the update value. The
        coordinate is converted to this worker's local frame, where it must
        NOT fall inside the inner part of the local segment, and the update
        is then delegated to ``coordinate_update``.

        When ``flags.CHECK_BETA`` is set, about 1% of the calls also
        recompute beta from scratch with ``_init_beta`` and assert that it
        matches the incrementally maintained ``self.beta`` on the inner
        region.
        """
        atom, *global_coord, delta = msg

        atom = int(atom)
        global_coord = tuple(int(c) for c in global_coord)
        local_coord = self.workers_segments.get_local_coordinate(
            self.rank, global_coord)

        # Received updates are expected outside this worker's inner segment.
        # (Kept inside the assert so the check is skipped under -O.)
        assert not self.workers_segments.is_contained_coordinate(
            self.rank, local_coord, inner=True), (global_coord, local_coord)

        in_worker = self.workers_segments.is_contained_coordinate(
            self.rank, local_coord, inner=False)
        self.coordinate_update(atom, local_coord, delta,
                               coordinate_exist=in_worker)

        if not flags.CHECK_BETA:
            return
        # Recomputing beta is expensive: only verify ~1% of the updates.
        if np.random.rand() <= 0.99:
            return

        inner_slice = (Ellipsis, ) + tuple(
            slice(lo, hi) for lo, hi in self.local_segments.inner_bounds)
        fresh_beta, *_ = _init_beta(self.X_worker,
                                    self.D,
                                    self.reg,
                                    z_i=self.z_hat,
                                    constants=self.constants,
                                    z_positive=self.z_positive)
        assert np.allclose(fresh_beta[inner_slice], self.beta[inner_slice])
예제 #2
0
def test_warm_start(valid_support, atom_support, reg):
    """Check that ``dicod`` honours a warm start ``z0``.

    First, with ``reg=0`` and a signal ``X`` reconstructed exactly from a
    sparse code ``z``, warm-starting from ``z`` must return ``z``. Then, on
    a fresh random signal, warm-starting from the same ``z`` must still
    converge: the optimal coordinate updates ``dz_opt`` recomputed with
    ``_init_beta`` at the returned solution must all be at most ``tol`` in
    absolute value.
    """
    tol = 1
    n_atoms = 7
    n_channels = 5
    random_state = 36

    rng = check_random_state(random_state)

    D = rng.randn(n_atoms, n_channels, *atom_support)
    # Normalise each atom over all of its non-atom axes (channels and every
    # atom dimension), whatever the dimensionality of ``atom_support``.
    sum_axis = tuple(range(1, D.ndim))
    D /= np.sqrt(np.sum(D * D, axis=sum_axis, keepdims=True))
    z = rng.randn(n_atoms, *valid_support)
    # Keep roughly 30% of the activations to get a sparse code.
    z *= (rng.rand(n_atoms, *valid_support) > .7)

    X = reconstruct(z, D)

    z_hat, *_ = dicod(X, D, reg=0, z0=z, tol=tol, n_workers=N_WORKERS,
                      max_iter=10000, verbose=VERBOSE)
    assert np.allclose(z_hat, z)

    X = rng.randn(*X.shape)

    z_hat, *_ = dicod(X, D, reg, z0=z, tol=tol, n_workers=N_WORKERS,
                      max_iter=100000, verbose=VERBOSE)
    _, dz_opt, _ = _init_beta(X, D, reg, z_i=z_hat)
    # Updates can be negative: the stopping criterion bounds their magnitude.
    assert np.all(np.abs(dz_opt) <= tol)
예제 #3
0
def test_init_beta():
    """Check ``_init_beta`` on a random 2D convolutional problem.

    The shapes of ``beta`` and ``dz_opt`` must match ``z``. Then, for 50
    random coordinates ``(k, h, w)``:

    * ``beta[k, h, w]`` must be unchanged when only ``z[k, h, w]`` changes;
    * setting ``z[k, h, w] = z_old + dz_opt[k, h, w]`` must give an
      objective no larger than at nearby values of that single coordinate.
    """
    n_atoms = 5
    n_channels = 2
    height, width = 31, 37
    height_atom, width_atom = 11, 13
    height_valid = height - height_atom + 1
    width_valid = width - width_atom + 1

    rng = np.random.RandomState(42)

    X = rng.randn(n_channels, height, width)
    D = rng.randn(n_atoms, n_channels, height_atom, width_atom)
    D /= np.sqrt(np.sum(D * D, axis=(1, 2, 3), keepdims=True))
    z = rng.randn(n_atoms, height_valid, width_valid)

    lmbd = 1
    beta, dz_opt, _ = _init_beta(X, D, lmbd, z_i=z)

    assert beta.shape == z.shape
    assert dz_opt.shape == z.shape

    # Step used to probe the objective around the proposed optimum.
    eps = 1e-5
    for _ in range(50):
        k = rng.randint(n_atoms)
        h = rng.randint(height_valid)
        w = rng.randint(width_valid)

        # beta at (k, h, w) must not depend on the current z[k, h, w].
        z_old = z[k, h, w]
        z[k, h, w] = rng.randn()
        beta_new, *_ = _init_beta(X, D, lmbd, z_i=z)
        assert np.isclose(beta_new[k, h, w], beta[k, h, w])

        # The proposed value must be optimal along this coordinate: probe
        # offsets -2.5, -1.5, -0.5, +0.5 and +1.5 times eps around it.
        z[k, h, w] = z_old + dz_opt[k, h, w]
        c0 = compute_objective(X, z, D, lmbd)

        z[k, h, w] -= 3.5 * eps
        for _probe in range(5):
            z[k, h, w] += eps
            assert c0 <= compute_objective(X, z, D, lmbd)

        # Restore the coordinate so later iterations see the original z.
        z[k, h, w] = z_old
예제 #4
0
def test_stopping_criterion(n_workers, signal_support, atom_support):
    """Check that ``dicod`` stops only when every update is below ``tol``.

    After solving a random problem, the optimal coordinate updates
    recomputed with ``_init_beta`` at the returned solution must all be
    strictly smaller than ``tol`` in absolute value.
    """
    tol, reg = 1, 1
    n_atoms, n_channels = 10, 3

    rng = check_random_state(42)

    X = rng.randn(n_channels, *signal_support)
    D = rng.randn(n_atoms, n_channels, *atom_support)
    # Unit-norm atoms: normalise over every axis except the atom axis.
    atom_axes = tuple(range(1, D.ndim))
    D /= np.sqrt((D * D).sum(axis=atom_axes, keepdims=True))

    z_hat, *_ = dicod(X, D, reg, tol=tol, n_workers=n_workers,
                      verbose=VERBOSE)

    _, dz_opt, _ = _init_beta(X, D, reg, z_i=z_hat)
    largest_update = np.abs(dz_opt).max()
    assert largest_update < tol
예제 #5
0
    def init_cd_variables(self):
        """Initialise this worker's local coordinate-descent state.

        The steps run in this order:

        * store the atom norms and ``DtD`` in ``self.constants``;
        * reset the pending-message list, the update log and the progress
          throttle;
        * compute ``beta``, ``dz_opt`` and ``dE`` with ``_init_beta``
          (``dE`` is requested only for the ``"gs-q"`` strategy);
        * reactivate every local segment;
        * call ``correct_beta_z0`` when ``self.z0`` is provided;
        * when ``flags.CHECK_WARM_BETA`` is set, send via ``return_array``
          the sum of ``beta`` in a window around the point given by worker
          0's inner segment shape, if that point lies in this worker;
        * set up the frozen support, which requires ``z0``;
        * synchronize with the other workers and log the elapsed time.
        """
        start = time.time()

        # Dictionary-only quantities, computed once up front.
        self.constants = constants = {
            'norm_atoms': compute_norm_atoms(self.D),
            'DtD': compute_DtD(self.D),
        }

        # Pending messages sent by this worker.
        self.messages = []
        # Record of every update, kept for logging purposes.
        self._log_updates = []
        # Used to avoid printing progress too often.
        self._last_progress = 0

        # Auxiliary variables for LGCD; dE is only needed by "gs-q".
        self.beta, self.dz_opt, self.dE = _init_beta(
            self.X_worker,
            self.D,
            self.reg,
            z_i=self.z0,
            constants=constants,
            z_positive=self.z_positive,
            return_dE=(self.strategy == "gs-q"))

        # Every segment starts out active.
        self.local_segments.reset()

        if self.z0 is not None:
            self.freezed_support = None
            self.correct_beta_z0()

        if flags.CHECK_WARM_BETA:
            segments = self.workers_segments
            corner = segments.get_seg_shape(0, inner=True)
            local_pt = segments.get_local_coordinate(self.rank, corner)
            if segments.is_contained_coordinate(self.rank, local_pt):
                _, _, *atom_support = self.D.shape
                window = tuple(
                    slice(v - size + 1, v + size - 1)
                    for v, size in zip(local_pt, atom_support))
                window_sum = self.beta[(Ellipsis, ) + window].sum()
                self.return_array(np.array(window_sum, dtype='d'))

        if not self.freeze_support:
            self.freezed_support = None
        else:
            assert self.z0 is not None
            self.freezed_support = self.z0 == 0
            self.dz_opt[self.freezed_support] = 0

        self.synchronize_workers()

        elapsed = time.time() - start
        self.info("End local initialization in {:.2f}s",
                  elapsed,
                  global_msg=True)
예제 #6
0
    def init_cd_variables(self):
        """Initialise this worker's local coordinate-descent state.

        The steps run in this order:

        * build ``self.constants`` from the atom norms and ``DtD``, reusing
          ``self.DtD`` when ``self.precomputed_DtD`` is set;
        * reset the message list, the update log and the progress throttle;
        * when ``self.warm_start`` is set and a previous ``z_hat`` exists,
          use a copy of it as ``z0``;
        * compute ``beta``, ``dz_opt`` and ``dE`` with ``_init_beta``;
        * initialise ``self.z_hat`` from ``z0``, or with zeros when there
          is no ``z0``;
        * apply the optional warm-beta check and the frozen support;
        * synchronize with the other workers, without the main process.

        Returns
        -------
        t_local_init : float
            Wall-clock seconds spent in this local initialization.
        """
        start = time.time()

        # Dictionary-only quantities; DtD may have been provided up front.
        norm_atoms = compute_norm_atoms(self.D)
        DtD = self.DtD if self.precomputed_DtD else compute_DtD(self.D)
        self.constants = constants = {'norm_atoms': norm_atoms, 'DtD': DtD}

        # Pending messages, update log and progress-print throttle.
        self.messages = []
        self._log_updates = []
        self._last_progress = 0

        # Warm start: resume from the solution of a previous run, if any.
        if self.warm_start and hasattr(self, 'z_hat'):
            self.z0 = self.z_hat.copy()

        # Auxiliary variables for LGCD; dE is only needed by "gs-q".
        need_dE = self.strategy == "gs-q"
        self.beta, self.dz_opt, self.dE = _init_beta(
            self.X_worker, self.D, self.reg,
            z_i=self.z0, constants=constants,
            z_positive=self.z_positive, return_dE=need_dE)

        # Every segment starts out active.
        self.local_segments.reset()

        if self.z0 is None:
            self.z_hat = np.zeros(self.beta.shape)
        else:
            self.freezed_support = None
            self.z_hat = self.z0.copy()
            self.correct_beta_z0()

        if flags.CHECK_WARM_BETA:
            worker_check_beta(self.rank, self.workers_segments, self.beta,
                              self.D.shape)

        if not self.freeze_support:
            self.freezed_support = None
        else:
            assert self.z0 is not None
            self.freezed_support = self.z0 == 0
            self.dz_opt[self.freezed_support] = 0

        self.synchronize_workers(with_main=False)

        elapsed = time.time() - start
        self.debug("End local initialization in {:.2f}s", elapsed,
                   global_msg=True)

        self.info(
            "Start DICOD with {} workers, strategy '{}', "
            "soft_lock={} and n_seg={}({})",
            self.n_workers, self.strategy, self.soft_lock, self.n_seg,
            self.local_segments.effective_n_seg,
            global_msg=True)
        return elapsed