Example #1
    def recv_signal(self):

        n_atoms, n_channels, *atom_support = self.D.shape

        comm = MPI.Comm.Get_parent()
        X_info = comm.bcast(None, root=0)
        self.has_z0 = X_info['has_z0']
        self.valid_support = X_info['valid_support']
        self.workers_topology = X_info['workers_topology']
        self.size_msg = len(self.workers_topology) + 2
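        # Each inter-worker message packs [k0, *pt_global, dz]: one entry per
        # axis of the worker topology, plus the atom index and the update
        # value, hence the "+ 2".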

        self.workers_segments = Segmentation(n_seg=self.workers_topology,
                                             signal_support=self.valid_support,
                                             overlap=self.overlap)

        # Receive X and z from the master node.
        worker_support = self.workers_segments.get_seg_support(self.rank)
        X_shape = (n_channels, ) + get_full_support(worker_support,
                                                    atom_support)
        z0_shape = (n_atoms, ) + worker_support
        if self.has_z0:
            z0 = self.recv_array(z0_shape)
        else:
            z0 = None
        X_worker = self.recv_array(X_shape)

        # Compute the local segmentation for the LGCD algorithm

        # If n_seg is 'auto', set the size of the local segments to the size
        # of an interference zone (2 * atom_support - 1).
        n_atoms, _, *atom_support = self.D.shape
        n_seg = self.n_seg
        local_seg_support = None
        if self.n_seg == 'auto':
            n_seg = None
            local_seg_support = 2 * np.array(atom_support) - 1

        # Get the local inner bounds: first compute the segment bounds without
        # overlap (in global coordinates), then convert them to the worker's
        # local coordinate system.
        inner_bounds = self.workers_segments.get_seg_bounds(self.rank,
                                                            inner=True)
        inner_bounds = np.transpose([
            self.workers_segments.get_local_coordinate(self.rank, bound)
            for bound in np.transpose(inner_bounds)
        ])

        worker_support = self.workers_segments.get_seg_support(self.rank)
        self.local_segments = Segmentation(n_seg=n_seg,
                                           seg_support=local_seg_support,
                                           inner_bounds=inner_bounds,
                                           full_support=worker_support)

        self.synchronize_workers(with_main=True)

        return X_worker, z0
Example #2
def test_padding_to_overlap():
    n_seg = (4, 4)
    sig_shape = (504, 504)
    overlap = (12, 7)

    seg = Segmentation(n_seg=n_seg, signal_shape=sig_shape, overlap=overlap)
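    # Segment n_seg[1] + 1 is an interior segment of the grid, so its shape
    # includes the overlap on every side; all other segments should reach
    # this shape once padded.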
    seg_shape_all = seg.get_seg_shape(n_seg[1] + 1)
    for i_seg in range(np.prod(n_seg)):
        seg_shape = seg.get_seg_shape(i_seg)
        z = np.empty(seg_shape)
        overlap = seg.get_padding_to_overlap(i_seg)
        z = np.pad(z, overlap, mode='constant')
        assert z.shape == seg_shape_all
Example #3
def test_change_coordinate():
    sig_support = (505, 407)
    overlap = (12, 7)
    n_seg = (4, 4)
    segments = Segmentation(n_seg=n_seg, signal_support=sig_support,
                            overlap=overlap)

    for i_seg in range(segments.effective_n_seg):
        seg_bound = segments.get_seg_bounds(i_seg)
        seg_support = segments.get_seg_support(i_seg)
        origin = tuple([start for start, _ in seg_bound])
        assert segments.get_global_coordinate(i_seg, (0, 0)) == origin
        assert segments.get_local_coordinate(i_seg, origin) == (0, 0)

        corner = tuple([end for _, end in seg_bound])
        assert segments.get_global_coordinate(i_seg, seg_support) == corner
        assert segments.get_local_coordinate(i_seg, corner) == seg_support
Example #4
def test_segmentation_coverage_overlap():
    sig_support = (505, 407)

    for overlap in [(3, 0), (0, 5), (3, 5), (12, 7)]:
        for h_seg in [5, 7, 9, 13, 15, 17]:
            for w_seg in [3, 11]:
                segments = Segmentation(n_seg=(h_seg, w_seg),
                                        signal_support=sig_support,
                                        overlap=overlap)
                z = np.zeros(sig_support)
                for i_seg in range(segments.effective_n_seg):
                    seg_slice = segments.get_seg_slice(i_seg, inner=True)
                    z[seg_slice] += 1
                    i_seg = segments.increment_seg(i_seg)
                non_overlapping = np.prod(sig_support)
                assert np.sum(z == 1) == non_overlapping

                z = np.zeros(sig_support)
                for i_seg in range(segments.effective_n_seg):
                    seg_slice = segments.get_seg_slice(i_seg)
                    z[seg_slice] += 1
                    i_seg = segments.increment_seg(i_seg)

                h_ov, w_ov = overlap
                h_seg, w_seg = segments.n_seg_per_axis
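                # Pixels covered by exactly two segments form strips of width
                # 2 * overlap around each internal boundary of the grid; the
                # boundary crossings are covered by four segments and must be
                # removed from both strip counts.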
                expected_overlap = ((h_seg - 1) * sig_support[1] * 2 * h_ov)
                expected_overlap += ((w_seg - 1) * sig_support[0] * 2 * w_ov)

                # Number of pixels where more than two segments overlap, i.e.
                # the boundary crossings covered by four segments.
                corner_overlap = 4 * (h_seg - 1) * (w_seg - 1) * h_ov * w_ov
                expected_overlap -= 2 * corner_overlap

                non_overlapping -= expected_overlap + corner_overlap
                assert non_overlapping == np.sum(z == 1)
                assert expected_overlap == np.sum(z == 2)
                assert corner_overlap == np.sum(z == 4)
Example #5
    def recv_task(self):
        # Retrieve the different constants from the base communicator and
        # store them as attributes of the class.
        params = self.get_params()

        if self.timeout:
            self.timeout *= 3

        self.random_state = params['random_state']
        if isinstance(self.random_state, int):
            self.random_state += self.rank

        self.size_msg = len(params['valid_shape']) + 2

        # Compute the shape of the worker segment.
        self.D = self.get_D()
        n_atoms, n_channels, *atom_support = self.D.shape
        self.overlap = np.array(atom_support) - 1
        self.workers_segments = Segmentation(
            n_seg=params['workers_topology'],
            signal_shape=params['valid_shape'],
            overlap=self.overlap)

        # Receive X and z from the master node.
        worker_shape = self.workers_segments.get_seg_shape(self.rank)
        X_shape = (n_channels, ) + get_full_shape(worker_shape, atom_support)
        if params['has_z0']:
            z0_shape = (n_atoms, ) + worker_shape
            self.z0 = self.get_signal(z0_shape, params['debug'])
        else:
            self.z0 = None
        self.X_worker = self.get_signal(X_shape, params['debug'])

        # Compute the local segmentation for the LGCD algorithm

        # If n_seg is 'auto', set the size of the local segments to the size
        # of an interference zone (2 * atom_support - 1).
        n_seg = self.n_seg
        local_seg_shape = None
        if self.n_seg == 'auto':
            n_seg = None
            local_seg_shape = 2 * np.array(atom_support) - 1

        # Get the local inner bounds: first compute the segment bounds without
        # overlap (in global coordinates), then convert them to the worker's
        # local coordinate system.
        inner_bounds = self.workers_segments.get_seg_bounds(self.rank,
                                                            inner=True)
        inner_bounds = np.transpose([
            self.workers_segments.get_local_coordinate(self.rank, bound)
            for bound in np.transpose(inner_bounds)
        ])

        self.local_segments = Segmentation(n_seg=n_seg,
                                           seg_shape=local_seg_shape,
                                           inner_bounds=inner_bounds,
                                           full_shape=worker_shape)

        # Initialize the solution
        n_atoms = self.D.shape[0]
        seg_shape = self.workers_segments.get_seg_shape(self.rank)
        if self.z0 is None:
            self.z_hat = np.zeros((n_atoms, ) + seg_shape)
        else:
            self.z_hat = self.z0

        self.info(
            "Start DICOD with {} workers, strategy '{}', soft_lock"
            "={} and n_seg={}({})",
            self.n_jobs,
            self.strategy,
            self.soft_lock,
            self.n_seg,
            self.local_segments.effective_n_seg,
            global_msg=True)

        self.synchronize_workers()
Example #6
class DICODWorker:
    """Worker for DICOD, running LGCD locally and using MPI for communications

    Parameters
    ----------
    backend: str
        Backend used to communicate between workers. Available backends are
        { 'mpi' }.
    """
    def __init__(self, backend):
        self._backend = backend

    def recv_task(self):
        # Retrieve the different constants from the base communicator and
        # store them as attributes of the class.
        params = self.get_params()

        if self.timeout:
            self.timeout *= 3

        self.random_state = params['random_state']
        if isinstance(self.random_state, int):
            self.random_state += self.rank

        self.size_msg = len(params['valid_shape']) + 2

        # Compute the shape of the worker segment.
        self.D = self.get_D()
        n_atoms, n_channels, *atom_support = self.D.shape
        self.overlap = np.array(atom_support) - 1
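        # Neighboring workers interact over a band of atom_support - 1
        # coordinates (the reach of one atom), so worker segments overlap by
        # this amount.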
        self.workers_segments = Segmentation(
            n_seg=params['workers_topology'],
            signal_shape=params['valid_shape'],
            overlap=self.overlap)

        # Receive X and z from the master node.
        worker_shape = self.workers_segments.get_seg_shape(self.rank)
        X_shape = (n_channels, ) + get_full_shape(worker_shape, atom_support)
        if params['has_z0']:
            z0_shape = (n_atoms, ) + worker_shape
            self.z0 = self.get_signal(z0_shape, params['debug'])
        else:
            self.z0 = None
        self.X_worker = self.get_signal(X_shape, params['debug'])

        # Compute the local segmentation for the LGCD algorithm

        # If n_seg is 'auto', set the size of the local segments to the size
        # of an interference zone (2 * atom_support - 1).
        n_seg = self.n_seg
        local_seg_shape = None
        if self.n_seg == 'auto':
            n_seg = None
            local_seg_shape = 2 * np.array(atom_support) - 1

        # Get the local inner bounds: first compute the segment bounds without
        # overlap (in global coordinates), then convert them to the worker's
        # local coordinate system.
        inner_bounds = self.workers_segments.get_seg_bounds(self.rank,
                                                            inner=True)
        inner_bounds = np.transpose([
            self.workers_segments.get_local_coordinate(self.rank, bound)
            for bound in np.transpose(inner_bounds)
        ])

        self.local_segments = Segmentation(n_seg=n_seg,
                                           seg_shape=local_seg_shape,
                                           inner_bounds=inner_bounds,
                                           full_shape=worker_shape)

        # Initialize the solution
        n_atoms = self.D.shape[0]
        seg_shape = self.workers_segments.get_seg_shape(self.rank)
        if self.z0 is None:
            self.z_hat = np.zeros((n_atoms, ) + seg_shape)
        else:
            self.z_hat = self.z0

        self.info(
            "Start DICOD with {} workers, strategy '{}', soft_lock"
            "={} and n_seg={}({})",
            self.n_jobs,
            self.strategy,
            self.soft_lock,
            self.n_seg,
            self.local_segments.effective_n_seg,
            global_msg=True)

        self.synchronize_workers()

    def compute_z_hat(self):

        # Initialization of the algorithm variables
        random_state = check_random_state(self.random_state)
        i_seg = -1
        n_coordinate_updates = 0
        accumulator = 0
        k0, pt0 = 0, None
        self.n_paused_worker = 0

        # Compute the number of coordinates handled by this worker
        n_atoms, *_ = self.D.shape
        seg_in_shape = self.workers_segments.get_seg_shape(self.rank,
                                                           inner=True)
        n_coordinates = n_atoms * np.prod(seg_in_shape)

        self.init_cd_variables()

        diverging = False
        if flags.INTERACTIVE_PROCESSES and self.n_jobs == 1:
            import ipdb
            ipdb.set_trace()  # noqa: E702
        self.t_start = t_start = time.time()
        if self.timeout is not None:
            deadline = t_start + self.timeout
        else:
            deadline = None
        for ii in range(self.max_iter):
            # Display the progress of the algorithm
            self.progress(ii, max_ii=self.max_iter, unit="iterations")

            # Process incoming messages
            self.process_messages()

            # Increment the segment and select the coordinate to update
            try:
                i_seg = self.local_segments.increment_seg(i_seg)
            except ZeroDivisionError:
                print(self.local_segments.signal_shape,
                      self.local_segments.n_seg_per_axis)
                raise
            if self.local_segments.is_active_segment(i_seg):
                k0, pt0, dz = _select_coordinate(self.dz_opt,
                                                 self.dE,
                                                 self.local_segments,
                                                 i_seg,
                                                 strategy=self.strategy,
                                                 random_state=random_state)

                assert self.workers_segments.is_contained_coordinate(
                    self.rank, pt0, inner=True), pt0
            else:
                k0, pt0, dz = None, None, 0

            # update the accumulator for 'random' strategy
            accumulator = max(abs(dz), accumulator)

            # If requested, check that the chosen update only has an impact on
            # the segment and its overlap area.
            if flags.CHECK_UPDATE_CONTAINED and pt0 is not None:
                self.workers_segments.check_area_contained(
                    self.rank, pt0, self.overlap)

            # Check if the coordinate is soft-locked or not.
            soft_locked = False
            if (pt0 is not None and abs(dz) > self.tol
                    and self.soft_lock != 'none'):
                n_lock = 1 if self.soft_lock == "corner" else 0
                lock_slices = self.workers_segments.get_touched_overlap_slices(
                    self.rank, pt0,
                    np.array(self.overlap) + 1)
                # Only soft lock in the corners
                if len(lock_slices) > n_lock:
                    max_on_lock = 0
                    for u_slice in lock_slices:
                        max_on_lock = max(
                            abs(self.dz_opt[u_slice]).max(), max_on_lock)
                    soft_locked = max_on_lock > abs(dz)

            # Update the selected coordinate and beta, only if the update is
            # greater than the convergence tolerance and is contained in the
            # worker. If the update is not in the worker, this effectively
            # acts as a soft lock to prevent interferences.
            if abs(dz) > self.tol and not soft_locked:
                n_coordinate_updates += 1

                # update the selected coordinate and beta
                self.coordinate_update(k0, pt0, dz)

                # Notify neighboring workers of the update if needed.
                pt_global = self.workers_segments.get_global_coordinate(
                    self.rank, pt0)
                workers = self.workers_segments.get_touched_segments(
                    pt=pt_global, radius=np.array(self.overlap) + 1)
                msg = np.array([k0, *pt_global, dz], 'd')

                self.notify_neighbors(msg, workers)

                if self.timing:
                    t_update = time.time() - t_start
                    self._log_updates.append(
                        (t_update, ii, self.rank, k0, pt_global, dz))

            # Mark the current segment as inactive if the magnitude of the
            # update is too small. This only works when using LGCD.
            if abs(dz) <= self.tol and self.strategy == "greedy":
                self.local_segments.set_inactive_segments(i_seg)

            # When the worker is diverging, stop it right away instead of
            # waiting until max_iter to stop the algorithm.
            if abs(dz) >= 1e3:
                self.info("diverging worker")
                self.wait_status_changed(status=constants.STATUS_FINISHED)
                diverging = True
                break

            # Check the stopping criterion and if we have locally converged,
            # wait either for an incoming message or for full convergence.
            if _check_convergence(self.local_segments,
                                  self.tol,
                                  ii,
                                  self.dz_opt,
                                  n_coordinates,
                                  self.strategy,
                                  accumulator=accumulator):

                if flags.CHECK_ACTIVE_SEGMENTS:
                    inner_slice = (Ellipsis, ) + tuple([
                        slice(start, end)
                        for start, end in self.local_segments.inner_bounds
                    ])
                    assert np.all(abs(self.dz_opt[inner_slice]) <= self.tol)
                if self.check_no_transitting_message():
                    status = self.wait_status_changed()
                    if status == constants.STATUS_STOP:
                        self.debug(
                            "LGCD converged with {} coordinate "
                            "updates", n_coordinate_updates)
                        break

            # Check if we reached the timeout
            if deadline is not None and time.time() >= deadline:
                self.stop_before_convergence("Reached timeout",
                                             n_coordinate_updates)
                break
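        # The `else` clause of the for loop runs only if max_iter is reached
        # without breaking out of the loop above.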
        else:
            self.stop_before_convergence("Reached max_iter",
                                         n_coordinate_updates)

        self.synchronize_workers()
        assert diverging or self.check_no_transitting_message()
        runtime = time.time() - t_start

        comm = MPI.Comm.Get_parent()
        comm.gather([n_coordinate_updates, runtime], root=0)

        return n_coordinate_updates, runtime

    def run(self):
        self.recv_task()
        n_coordinate_updates, runtime = self.compute_z_hat()
        self.send_result(n_coordinate_updates, runtime)

    def stop_before_convergence(self, msg, n_coordinate_updates):
        self.info("{}. Done {} coordinate updates. Max of |dz|={}.", msg,
                  n_coordinate_updates,
                  abs(self.dz_opt).max())
        self.wait_status_changed(status=constants.STATUS_FINISHED)

    def init_cd_variables(self):
        t_start = time.time()

        # Pre-compute some quantities
        constants = {}
        constants['norm_atoms'] = compute_norm_atoms(self.D)
        constants['DtD'] = compute_DtD(self.D)
        self.constants = constants

        # List of all pending messages sent
        self.messages = []

        # Record all updates for logging purposes
        self._log_updates = []

        # Avoid printing progress too often
        self._last_progress = 0

        # Initialization of the auxiliary variables for LGCD
        return_dE = self.strategy == "gs-q"
        self.beta, self.dz_opt, self.dE = _init_beta(
            self.X_worker,
            self.D,
            self.reg,
            z_i=self.z0,
            constants=constants,
            z_positive=self.z_positive,
            return_dE=return_dE)

        # Make sure all segments are activated
        self.local_segments.reset()

        if self.z0 is not None:
            self.freezed_support = None
            self.correct_beta_z0()

        if flags.CHECK_WARM_BETA:
            pt_global = self.workers_segments.get_seg_shape(0, inner=True)
            pt = self.workers_segments.get_local_coordinate(
                self.rank, pt_global)
            if self.workers_segments.is_contained_coordinate(self.rank, pt):
                _, _, *atom_support = self.D.shape
                beta_slice = (Ellipsis, ) + tuple([
                    slice(v - size_ax + 1, v + size_ax - 1)
                    for v, size_ax in zip(pt, atom_support)
                ])
                sum_beta = np.array(self.beta[beta_slice].sum(), dtype='d')
                self.return_array(sum_beta)

        if self.freeze_support:
            assert self.z0 is not None
            self.freezed_support = self.z0 == 0
            self.dz_opt[self.freezed_support] = 0
        else:
            self.freezed_support = None

        self.synchronize_workers()

        t_local_init = time.time() - t_start
        self.info("End local initialization in {:.2f}s",
                  t_local_init,
                  global_msg=True)

    def coordinate_update(self, k0, pt0, dz, coordinate_exist=True):
        self.beta, self.dz_opt, self.dE = coordinate_update(
            k0,
            pt0,
            dz,
            beta=self.beta,
            dz_opt=self.dz_opt,
            dE=self.dE,
            z_hat=self.z_hat,
            D=self.D,
            reg=self.reg,
            constants=self.constants,
            z_positive=self.z_positive,
            freezed_support=self.freezed_support,
            coordinate_exist=coordinate_exist)

        # Re-activate the segments where beta has been updated to ensure
        # convergence.
        touched_segments = self.local_segments.get_touched_segments(
            pt=pt0, radius=self.overlap)
        n_changed_status = self.local_segments.set_active_segments(
            touched_segments)

        # If requested, check that no inactive segment has a coefficient
        # update above the tolerance.
        if flags.CHECK_ACTIVE_SEGMENTS and n_changed_status:
            self.local_segments.test_active_segments(self.dz_opt, self.tol)

    def process_messages(self, worker_status=constants.STATUS_RUNNING):
        mpi_status = MPI.Status()
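        # Drain all pending messages: the status tags are handled from the
        # probed tag first, then the message body is received and beta
        # updates are applied.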
        while MPI.COMM_WORLD.Iprobe(status=mpi_status):
            src = mpi_status.source
            tag = mpi_status.tag
            if tag == constants.TAG_DICOD_UPDATE_BETA:
                if worker_status == constants.STATUS_PAUSED:
                    self.notify_worker_status(
                        constants.TAG_DICOD_RUNNING_WORKER, wait=True)
                    worker_status = constants.STATUS_RUNNING
            elif tag == constants.TAG_DICOD_STOP:
                worker_status = constants.STATUS_STOP
            elif tag == constants.TAG_DICOD_PAUSED_WORKER:
                self.n_paused_worker += 1
                assert self.n_paused_worker <= self.n_jobs
            elif tag == constants.TAG_DICOD_RUNNING_WORKER:
                self.n_paused_worker -= 1
                assert self.n_paused_worker >= 0

            msg = np.empty(self.size_msg, 'd')
            MPI.COMM_WORLD.Recv([msg, MPI.DOUBLE], source=src, tag=tag)

            if tag == constants.TAG_DICOD_UPDATE_BETA:
                self.message_update_beta(msg)

        if self.n_paused_worker == self.n_jobs:
            worker_status = constants.STATUS_STOP
        return worker_status

    def message_update_beta(self, msg):
        k0, *pt_global, dz = msg

        k0 = int(k0)
        pt_global = tuple([int(v) for v in pt_global])
        pt0 = self.workers_segments.get_local_coordinate(self.rank, pt_global)
        assert not self.workers_segments.is_contained_coordinate(
            self.rank, pt0, inner=True), (pt_global, pt0)
        coordinate_exist = self.workers_segments.is_contained_coordinate(
            self.rank, pt0, inner=False)
        self.coordinate_update(k0, pt0, dz, coordinate_exist=coordinate_exist)

        if flags.CHECK_BETA and np.random.rand() > 0.99:
            # Only check beta 1% of the time to avoid the check being too long
            inner_slice = (Ellipsis, ) + tuple([
                slice(start, end)
                for start, end in self.local_segments.inner_bounds
            ])
            beta, *_ = _init_beta(self.X_worker,
                                  self.D,
                                  self.reg,
                                  z_i=self.z_hat,
                                  constants=self.constants,
                                  z_positive=self.z_positive)
            assert np.allclose(beta[inner_slice], self.beta[inner_slice])

    def notify_neighbors(self, msg, neighbors):
        assert self.rank in neighbors
        for i_neighbor in neighbors:
            if i_neighbor != self.rank:
                req = self.send_message(msg,
                                        constants.TAG_DICOD_UPDATE_BETA,
                                        i_neighbor,
                                        wait=False)
                self.messages.append(req)

    def notify_worker_status(self, tag, i_worker=0, wait=False):
        # Handle the messages from worker 0 to itself.
        if self.rank == 0 and i_worker == 0:
            if tag == constants.TAG_DICOD_PAUSED_WORKER:
                self.n_paused_worker += 1
                assert self.n_paused_worker <= self.n_jobs
            elif tag == constants.TAG_DICOD_RUNNING_WORKER:
                self.n_paused_worker -= 1
                assert self.n_paused_worker >= 0
            elif tag == constants.TAG_DICOD_INIT_DONE:
                pass
            else:
                raise ValueError("Got tag {}".format(tag))
            return

        # Else send the message to the required destination
        msg = np.empty(self.size_msg, 'd')
        self.send_message(msg, tag, i_worker, wait=wait)

    def wait_status_changed(self, status=constants.STATUS_PAUSED):
        if status == constants.STATUS_FINISHED:
            # Make sure to flush the messages
            while not self.check_no_transitting_message():
                self.process_messages(worker_status=status)
                time.sleep(0.001)

        self.notify_worker_status(constants.TAG_DICOD_PAUSED_WORKER)
        self.debug("paused worker")

        # Wait for all sent messages to be processed
        count = 0
        while status not in [constants.STATUS_RUNNING, constants.STATUS_STOP]:
            time.sleep(.005)
            status = self.process_messages(worker_status=status)
            if (count % 500) == 0:
                self.progress(self.n_paused_worker,
                              max_ii=self.n_jobs,
                              unit="done workers")

        if self.rank == 0 and status == constants.STATUS_STOP:
            for i_worker in range(1, self.n_jobs):
                self.notify_worker_status(constants.TAG_DICOD_STOP,
                                          i_worker,
                                          wait=True)
        elif status == constants.STATUS_RUNNING:
            self.debug("wake up")
        else:
            assert status == constants.STATUS_STOP
        return status

    def compute_sufficient_statistics(self):
        _, _, *atom_support = self.D.shape
        z_slice = (Ellipsis, ) + tuple([
            slice(start, end)
            for start, end in self.local_segments.inner_bounds
        ])
        X_slice = (Ellipsis, ) + tuple([
            slice(start, end + size_atom_ax - 1)
            for (start, end), size_atom_ax in zip(
                self.local_segments.inner_bounds, atom_support)
        ])

        ztX = compute_ztX(self.z_hat[z_slice], self.X_worker[X_slice])

        padding_shape = self.workers_segments.get_padding_to_overlap(self.rank)
        ztz = compute_ztz(self.z_hat,
                          atom_support,
                          padding_shape=padding_shape)
        return np.array(ztz, dtype='d'), np.array(ztX, dtype='d')

    def correct_beta_z0(self):
        # Send coordinate updates to neighbors for all nonzero coordinates in
        # z0
        msg_send, msg_recv = [0] * self.n_jobs, [0] * self.n_jobs
        for k0, *pt0 in zip(*self.z0.nonzero()):
            # Notify neighboring workers of the update if needed.
            pt_global = self.workers_segments.get_global_coordinate(
                self.rank, pt0)
            workers = self.workers_segments.get_touched_segments(
                pt=pt_global, radius=np.array(self.overlap) + 1)
            msg = np.array([k0, *pt_global, self.z0[(k0, *pt0)]], 'd')
            self.notify_neighbors(msg, workers)
            for i in workers:
                msg_send[i] += 1

        n_init_done = 0
        done_pt = set()
        no_msg, init_done = False, False
        mpi_status = MPI.Status()
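        # Handshake: once a worker has flushed all its outgoing updates, it
        # signals TAG_DICOD_INIT_DONE to worker 0. Worker 0 counts these
        # signals and, when all workers are done, sends the same tag back to
        # every worker so they can all leave this loop.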
        while not init_done:
            if n_init_done == self.n_jobs:
                for i_worker in range(1, self.n_jobs):
                    self.notify_worker_status(constants.TAG_DICOD_INIT_DONE,
                                              i_worker=i_worker)
                init_done = True
            if not no_msg:
                if self.check_no_transitting_message(check_incoming=False):
                    self.notify_worker_status(constants.TAG_DICOD_INIT_DONE)
                    if self.rank == 0:
                        n_init_done += 1
                    assert len(self.messages) == 0
                    no_msg = True

            if MPI.COMM_WORLD.Iprobe(status=mpi_status):
                tag = mpi_status.tag
                src = mpi_status.source
                if tag == constants.TAG_DICOD_INIT_DONE:
                    if self.rank == 0:
                        n_init_done += 1
                    else:
                        init_done = True

                msg = np.empty(self.size_msg, 'd')
                MPI.COMM_WORLD.Recv([msg, MPI.DOUBLE], source=src, tag=tag)

                if tag == constants.TAG_DICOD_UPDATE_BETA:
                    msg_recv[src] += 1
                    k0, *pt_global, dz = msg
                    k0 = int(k0)
                    pt_global = tuple([int(v) for v in pt_global])
                    pt0 = self.workers_segments.get_local_coordinate(
                        self.rank, pt_global)
                    pt_exist = self.workers_segments.is_contained_coordinate(
                        self.rank, pt0, inner=False)
                    if not pt_exist and (k0, *pt0) not in done_pt:
                        done_pt.add((k0, *pt0))
                        self.coordinate_update(k0,
                                               pt0,
                                               dz,
                                               coordinate_exist=False)

            else:
                time.sleep(.001)

    def compute_cost(self):
        X_hat_worker = reconstruct(self.z_hat, self.D)
        inner_bounds = self.local_segments.inner_bounds
        inner_slice = tuple(
            [Ellipsis] +
            [slice(start_ax, end_ax) for start_ax, end_ax in inner_bounds])
        X_hat_slice = list(inner_slice)
        i_seg = self.rank
        ax_rank_offset = self.workers_segments.effective_n_seg
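        # For the last segment along each axis, extend the slice up to the
        # signal border so the reconstruction error covers the whole signal.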
        for ax, n_seg_ax in enumerate(self.workers_segments.n_seg_per_axis):
            ax_rank_offset //= n_seg_ax
            ax_i_seg = i_seg // ax_rank_offset
            i_seg %= ax_rank_offset
            if (ax_i_seg + 1) % n_seg_ax == 0:
                s = inner_slice[ax + 1]
                X_hat_slice[ax + 1] = slice(s.start, None)
        X_hat_slice = tuple(X_hat_slice)
        diff = (X_hat_worker[X_hat_slice] - self.X_worker[X_hat_slice]).ravel()
        cost = .5 * np.dot(diff, diff)
        return cost + self.reg * abs(self.z_hat[inner_slice]).sum()

    def return_z_hat(self):
        if flags.GET_OVERLAP_Z_HAT:
            res_slice = (Ellipsis, )
        else:
            res_slice = (Ellipsis, ) + tuple([
                slice(start, end)
                for start, end in self.local_segments.inner_bounds
            ])
        z_worker = self.z_hat[res_slice].ravel()
        self.return_array(z_worker)

    def return_z_nnz(self):
        res_slice = (Ellipsis, ) + tuple([
            slice(start, end)
            for start, end in self.local_segments.inner_bounds
        ])
        z_nnz = self.z_hat[res_slice] != 0
        z_nnz = np.sum(z_nnz, axis=tuple(range(1, z_nnz.ndim)))
        self.reduce_sum_array(z_nnz)

    def return_sufficient_statistics(self):
        ztz, ztX = self.compute_sufficient_statistics()
        self.reduce_sum_array(ztz)
        self.reduce_sum_array(ztX)

    def return_cost(self):
        cost = self.compute_cost()
        cost = np.array(cost, dtype='d')
        self.reduce_sum_array(cost)

    ###########################################################################
    #     Display utilities
    ###########################################################################

    def progress(self, ii, max_ii, unit):
        t_progress = time.time()
        if t_progress - self._last_progress < 1:
            return
        self._last_progress = t_progress
        self._log("{:.0f}s - progress : {:7.2%} {}",
                  time.time() - self.t_start,
                  ii / max_ii,
                  unit,
                  level=1,
                  level_name="PROGRESS",
                  global_msg=True,
                  endline=False)

    def info(self, msg, *fmt_args, global_msg=False, **fmt_kwargs):
        self._log(msg,
                  *fmt_args,
                  level=1,
                  level_name="INFO",
                  global_msg=global_msg,
                  **fmt_kwargs)

    def debug(self, msg, *fmt_args, global_msg=False, **fmt_kwargs):
        self._log(msg,
                  *fmt_args,
                  level=5,
                  level_name="DEBUG",
                  global_msg=global_msg,
                  **fmt_kwargs)

    def _log(self,
             msg,
             *fmt_args,
             level=0,
             level_name="None",
             global_msg=False,
             endline=True,
             **fmt_kwargs):
        if self.verbose >= level:
            if global_msg:
                if self.rank != 0:
                    return
                msg_fmt = constants.GLOBAL_OUTPUT_TAG + msg
                identity = self.n_jobs
            else:
                msg_fmt = constants.WORKER_OUTPUT_TAG + msg
                identity = self.rank
            if endline:
                kwargs = {}
            else:
                kwargs = {'end': '', 'flush': True}
            msg_fmt = msg_fmt.ljust(80)
            print(
                msg_fmt.format(
                    *fmt_args,
                    identity=identity,
                    level_name=level_name,
                    **fmt_kwargs,
                ), **kwargs)

    ###########################################################################
    #     Communication primitives
    ###########################################################################

    def synchronize_workers(self):
        if self._backend == "mpi":
            self._synchronize_workers_mpi()
        else:
            raise NotImplementedError("Backend {} is not implemented".format(
                self._backend))

    def get_params(self):
        """Receive the parameter of the algorithm from the master node."""
        if self._backend == "mpi":
            self.rank, self.n_jobs, params = self._get_params_mpi()
        else:
            raise NotImplementedError("Backend {} is not implemented".format(
                self._backend))

        self.tol = params['tol']
        self.reg = params['reg']
        self.n_seg = params['n_seg']
        self.timing = params['timing']
        self.timeout = params['timeout']
        self.verbose = params['verbose']
        self.strategy = params['strategy']
        self.max_iter = params['max_iter']
        self.soft_lock = params['soft_lock']
        self.z_positive = params['z_positive']
        self.return_ztz = params['return_ztz']
        self.freeze_support = params['freeze_support']

        self.debug("tol updated to {:.2e}", self.tol, global_msg=True)
        return params

    def get_signal(self, X_shape, debug=False):
        """Receive the part of the signal to encode from the master node."""
        if self._backend == "mpi":
            return self._get_signal_mpi(X_shape, debug=debug)
        else:
            raise NotImplementedError("Backend {} is not implemented".format(
                self._backend))

    def send_message(self, msg, tag, i_worker, wait=False):
        """Send a message to a specified worker."""
        assert self.rank != i_worker
        if self._backend == "mpi":
            return self._send_message_mpi(msg, tag, i_worker, wait=wait)
        else:
            raise NotImplementedError("Backend {} is not implemented".format(
                self._backend))

    def send_result(self, iterations, runtime):
        if self._backend == "mpi":
            self._send_result_mpi(iterations, runtime)
        else:
            raise NotImplementedError("Backend {} is not implemented".format(
                self._backend))

    def get_D(self):
        """Receive a dictionary D"""
        if self._backend == "mpi":
            comm = MPI.Comm.Get_parent()
            self.D = recv_broadcasted_array(comm)
            return self.D
        else:
            raise NotImplementedError("Backend {} is not implemented".format(
                self._backend))

    def return_array(self, sig):
        if self._backend == "mpi":
            self._return_array_mpi(sig)
        else:
            raise NotImplementedError("Backend {} is not implemented".format(
                self._backend))

    def reduce_sum_array(self, arr):
        if self._backend == "mpi":
            self._reduce_sum_array_mpi(arr)
        else:
            raise NotImplementedError("Backend {} is not implemented".format(
                self._backend))

    def shutdown(self):
        if self._backend == "mpi":
            self._shutdown_mpi()
        else:
            raise NotImplementedError("Backend {} is not implemented".format(
                self._backend))

    ###########################################################################
    #     mpi4py implementation
    ###########################################################################

    def _synchronize_workers_mpi(self):
        comm = MPI.Comm.Get_parent()
        comm.Barrier()

    def check_no_transitting_message(self, check_incoming=True):
        """Check no message is in waiting to complete to or from this worker"""
        if check_incoming and MPI.COMM_WORLD.Iprobe():
            return False
        while self.messages:
            if not self.messages[0].Test() or (check_incoming
                                               and MPI.COMM_WORLD.Iprobe()):
                return False
            self.messages.pop(0)
        assert len(self.messages) == 0, len(self.messages)
        return True

    def _get_params_mpi(self):
        comm = MPI.Comm.Get_parent()

        rank = comm.Get_rank()
        n_jobs = comm.Get_size()
        params = comm.bcast(None, root=0)
        return rank, n_jobs, params

    def _get_signal_mpi(self, sig_shape, debug):
        comm = MPI.Comm.Get_parent()
        rank = comm.Get_rank()

        sig_worker = np.empty(sig_shape, dtype='d')
        comm.Recv([sig_worker.ravel(), MPI.DOUBLE],
                  source=0,
                  tag=constants.TAG_ROOT + rank)

        if debug:
            X_alpha = 0.25 * np.ones(sig_shape)
            self.return_array(X_alpha)

        return sig_worker

    def _send_message_mpi(self, msg, tag, i_worker, wait=False):
        if wait:
            return MPI.COMM_WORLD.Ssend([msg, MPI.DOUBLE], i_worker, tag=tag)
        else:
            return MPI.COMM_WORLD.Issend([msg, MPI.DOUBLE], i_worker, tag=tag)

    def _send_result_mpi(self, iterations, runtime):
        comm = MPI.Comm.Get_parent()
        self.info("Reducing the distributed results", global_msg=True)

        self.return_z_hat()

        if self.return_ztz:
            self.return_sufficient_statistics()

        self.return_cost()

        if self.timing:
            comm.send(self._log_updates, dest=0)

        comm.Barrier()

    def _return_array_mpi(self, arr):
        comm = MPI.Comm.Get_parent()
        arr = arr.astype('d')
        comm.Send([arr, MPI.DOUBLE],
                  dest=0,
                  tag=constants.TAG_ROOT + self.rank)

    def _reduce_sum_array_mpi(self, arr):
        comm = MPI.Comm.Get_parent()
        arr = np.array(arr, dtype='d')
        comm.Reduce([arr, MPI.DOUBLE], None, op=MPI.SUM, root=0)

    def _shutdown_mpi(self):
        comm = MPI.Comm.Get_parent()
        self.debug("clean shutdown")
        comm.Barrier()
        comm.Disconnect()
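
A minimal sketch of how such a worker might be launched in an MPI-spawned
process (hypothetical entry point, not part of the original example; it only
uses the methods defined above):

if __name__ == "__main__":
    # Receive the task from the master process, run the local LGCD solver,
    # send the results back and shut down cleanly.
    worker = DICODWorker(backend='mpi')
    worker.run()
    worker.shutdown()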
Example #7
def test_touched_segments():
    """Test detection of touched segments and records of active segments
    """
    rng = np.random.RandomState(42)

    H, W = sig_support = (108, 53)
    n_seg = (9, 3)
    for h_radius in [5, 7, 9]:
        for w_radius in [3, 11]:
            for _ in range(20):
                h0 = rng.randint(-h_radius, sig_support[0] + h_radius)
                w0 = rng.randint(-w_radius, sig_support[1] + w_radius)
                z = np.zeros(sig_support)
                segments = Segmentation(n_seg, signal_support=sig_support)

                touched_slice = (
                    slice(max(0, h0 - h_radius), min(H, h0 + h_radius + 1)),
                    slice(max(0, w0 - w_radius), min(W, w0 + w_radius + 1))
                )
                z[touched_slice] = 1

                touched_segments = segments.get_touched_segments(
                    (h0, w0), (h_radius, w_radius))
                segments.set_inactive_segments(touched_segments)
                n_active_segments = segments._n_active_segments

                expected_n_active_segments = segments.effective_n_seg
                for i_seg in range(segments.effective_n_seg):
                    seg_slice = segments.get_seg_slice(i_seg)
                    is_touched = np.any(z[seg_slice] == 1)
                    expected_n_active_segments -= is_touched

                    assert segments.is_active_segment(i_seg) != is_touched
                assert n_active_segments == expected_n_active_segments

    # Check that an error is raised when the touched radius is larger than
    # the segment size.
    segments = Segmentation(n_seg, signal_support=sig_support)
    with pytest.raises(ValueError, match="too large"):
        segments.get_touched_segments((0, 0), (30, 2))
Example #8
def test_segmentation_coverage():
    sig_support = (108, 53)

    for h_seg in [5, 7, 9, 13, 17]:
        for w_seg in [3, 11]:
            z = np.zeros(sig_support)
            segments = Segmentation(n_seg=(h_seg, w_seg),
                                    signal_support=sig_support)
            assert tuple(segments.n_seg_per_axis) == (h_seg, w_seg)
            seg_slice = segments.get_seg_slice(0)
            seg_support = segments.get_seg_support(0)
            assert seg_support == z[seg_slice].shape
            z[seg_slice] += 1
            i_seg = segments.increment_seg(0)
            while i_seg != 0:
                seg_slice = segments.get_seg_slice(i_seg)
                seg_support = segments.get_seg_support(i_seg)
                assert seg_support == z[seg_slice].shape
                z[seg_slice] += 1
                i_seg = segments.increment_seg(i_seg)

            assert np.all(z == 1)

    z = np.zeros(sig_support)
    inner_bounds = [(8, 100), (3, 50)]
    inner_slice = tuple([slice(start, end) for start, end in inner_bounds])
    segments = Segmentation(n_seg=7, inner_bounds=inner_bounds,
                            full_support=sig_support)
    for i_seg in range(segments.effective_n_seg):
        seg_slice = segments.get_seg_slice(i_seg)
        z[seg_slice] += 1

    assert np.all(z[inner_slice] == 1)
    z[inner_slice] = 0
    assert np.all(z == 0)
Example #9
def test_touched_overlap_area():
    sig_support = (505, 407)
    overlap = (11, 9)
    n_seg = (8, 4)
    segments = Segmentation(n_seg=n_seg, signal_support=sig_support,
                            overlap=overlap)

    for i_seg in range(segments.effective_n_seg):
        seg_support = segments.get_seg_support(i_seg)
        seg_slice = segments.get_seg_slice(i_seg)
        seg_inner_slice = segments.get_seg_slice(i_seg, inner=True)
        if i_seg != 0:
            with pytest.raises(AssertionError):
                segments.check_area_contained(i_seg, (0, 0), overlap)
        for pt0 in [overlap, (overlap[0], 25), (25, overlap[1]), (25, 25),
                    (seg_support[0] - overlap[0] - 1, 25),
                    (25, seg_support[1] - overlap[1] - 1),
                    (seg_support[0] - overlap[0] - 1,
                     seg_support[1] - overlap[1] - 1)
                    ]:
            assert segments.is_contained_coordinate(i_seg, pt0, inner=True)
            segments.check_area_contained(i_seg, pt0, overlap)
            z = np.zeros(sig_support)
            pt_global = segments.get_global_coordinate(i_seg, pt0)
            update_slice = tuple([
                slice(max(v - r, 0), v + r + 1)
                for v, r in zip(pt_global, overlap)])

            z[update_slice] += 1
            z[seg_inner_slice] = 0

            # The returned slices are given in local coordinates. Take the
            # segment of z to work in local coordinates.
            z_seg = z[seg_slice]

            updated_slices = segments.get_touched_overlap_slices(i_seg, pt0,
                                                                 overlap)
            # Assert that all selected coordinates are indeed in the update area
            for u_slice in updated_slices:
                assert np.all(z_seg[u_slice] == 1)

            # Assert that all coordinates updated in the overlap area are
            # covered by at least one slice.
            for u_slice in updated_slices:
                z_seg[u_slice] *= 0
            assert np.all(z == 0)
Example #10
def test_inner_coordinate():
    sig_support = (505, 407)
    overlap = (11, 11)
    n_seg = (4, 4)
    segments = Segmentation(n_seg=n_seg, signal_support=sig_support,
                            overlap=overlap)

    for h_rank in range(n_seg[0]):
        for w_rank in range(n_seg[1]):
            i_seg = h_rank * n_seg[1] + w_rank
            seg_support = segments.get_seg_support(i_seg)
            assert segments.is_contained_coordinate(i_seg, overlap,
                                                    inner=True)

            if h_rank == 0:
                assert segments.is_contained_coordinate(i_seg, (0, overlap[1]),
                                                        inner=True)
            else:
                assert not segments.is_contained_coordinate(
                    i_seg, (overlap[0] - 1, overlap[1]), inner=True)

            if w_rank == 0:
                assert segments.is_contained_coordinate(i_seg, (overlap[0], 0),
                                                        inner=True)
            else:
                assert not segments.is_contained_coordinate(
                    i_seg, (overlap[0], overlap[1] - 1), inner=True)

            if h_rank == 0 and w_rank == 0:
                assert segments.is_contained_coordinate(i_seg, (0, 0),
                                                        inner=True)
            else:
                assert not segments.is_contained_coordinate(
                    i_seg, (overlap[0] - 1, overlap[1] - 1), inner=True)

            if h_rank == n_seg[0] - 1:
                assert segments.is_contained_coordinate(
                    i_seg,
                    (seg_support[0] - 1, seg_support[1] - overlap[1] - 1),
                    inner=True)
            else:
                assert not segments.is_contained_coordinate(
                    i_seg, (seg_support[0] - overlap[0],
                            seg_support[1] - overlap[1] - 1), inner=True)

            if w_rank == n_seg[1] - 1:
                assert segments.is_contained_coordinate(
                   i_seg,
                   (seg_support[0] - overlap[0] - 1, seg_support[1] - 1),
                   inner=True)
            else:
                assert not segments.is_contained_coordinate(
                    i_seg, (seg_support[0] - overlap[0] - 1,
                            seg_support[1] - overlap[1]), inner=True)

            if h_rank == n_seg[0] - 1 and w_rank == n_seg[1] - 1:
                assert segments.is_contained_coordinate(
                    i_seg, (seg_support[0] - 1, seg_support[1] - 1),
                    inner=True)
            else:
                assert not segments.is_contained_coordinate(
                    i_seg, (seg_support[0] - overlap[0],
                            seg_support[1] - overlap[1]), inner=True)
Example #11
    atom_support = (16, 16)

    run_args = (n_atoms, atom_support, reg, tol, n_workers, random_state)
    if args.no_cache:
        X_hat, pobj = run_without_soft_lock.call(*run_args)[0]
    else:
        X_hat, pobj = run_without_soft_lock(*run_args)

    file_name = f"soft_lock_M{n_workers}_support{atom_support[0]}"
    np.save(f"benchmarks_results/{file_name}_X_hat.npy", X_hat)

    # Compute the worker segmentation for the image.
    n_channels, *sig_support = X_hat.shape
    valid_support = get_valid_support(sig_support, atom_support)
    workers_segments = Segmentation(n_seg=(w_world, w_world),
                                    signal_support=valid_support,
                                    overlap=0)

    fig = plt.figure("recovery")
    fig.patch.set_alpha(0)

    ax = plt.subplot()
    ax.imshow(X_hat.swapaxes(0, 2))
    for i_seg in range(workers_segments.effective_n_seg):
        seg_bounds = np.array(workers_segments.get_seg_bounds(i_seg))
        seg_bounds = seg_bounds + np.array(atom_support) / 2
        ax.vlines(seg_bounds[1], *seg_bounds[0], linestyle='--')
        ax.hlines(seg_bounds[0], *seg_bounds[1], linestyle='--')
    ax.axis('off')
    plt.tight_layout()
Example #12
def coordinate_descent(X_i, D, reg, z0=None, DtD=None, n_seg='auto',
                       strategy='greedy', tol=1e-5, max_iter=100000,
                       timeout=None, z_positive=False, freeze_support=False,
                       return_ztz=False, timing=False,
                       random_state=None, verbose=0):
    """Coordinate Descent Algorithm for 2D convolutional sparse coding.

    Parameters
    ----------
    X_i : ndarray, shape (n_channels, *sig_support)
        Image to encode on the dictionary D
    D : ndarray, shape (n_atoms, n_channels, *atom_support)
        Current dictionary for the sparse coding
    reg : float
        Regularization parameter
    z0 : ndarray, shape (n_atoms, *valid_support) or None
        Warm start value for z_hat. If not present, z_hat is initialized to 0.
    DtD : ndarray, shape (n_atoms, n_atoms, 2 * valid_support - 1) or None
        Warm start value for DtD. If not present, it is computed on init.
    n_seg : int or 'auto'
        Number of segments to use for each dimension. If set to 'auto', use
        segments of twice the size of the dictionary.
    strategy : str in {strategies}
        Coordinate selection scheme for the coordinate descent. If set to
        'greedy' or 'gs-r', the coordinate with the largest value of dz_opt is
        selected. If set to 'random', the coordinate is chosen uniformly on
        the segment. If set to 'gs-q', the coordinate that reduces the cost
        function the most is selected. In this case, dE must hold the value of
        this cost reduction.
    tol : float
        Tolerance for the minimal update size in this algorithm.
    max_iter : int
        Maximal number of iterations run by this algorithm.
    z_positive : boolean
        If set to True, the activations are constrained to be positive.
    freeze_support : boolean
        If set to True, only update the coefficients that are non-zero in z0.
    return_ztz : boolean
        If True, returns the constants ztz and ztX, used to compute D-updates.
    timing : boolean
        If set to True, log the cost and timing information.
    random_state : None or int or RandomState
        Current random state used to seed the random number generator.
    verbose : int
        Verbosity level of the algorithm.

    Returns
    -------
    z_hat : ndarray, shape (n_atoms, *valid_support)
        Activation associated to X_i for the given dictionary D
    """
    n_channels, *sig_support = X_i.shape
    n_atoms, n_channels, *atom_support = D.shape
    valid_support = get_valid_support(sig_support, atom_support)

    if strategy not in STRATEGIES:
        raise ValueError("'The coordinate selection strategy should be in "
                         "{}. Got '{}'.".format(STRATEGIES, strategy))

    # compute sizes for the segments for LGCD. Auto gives segments of size
    # twice the support of the atoms.
    if n_seg == 'auto':
        n_seg = np.array(valid_support) // (2 * np.array(atom_support) - 1)
        n_seg = tuple(np.maximum(1, n_seg))
    segments = Segmentation(n_seg, signal_support=valid_support)

    # Pre-compute constants used to maintain the auxiliary variable beta and
    # to compute the coordinate update values.
    constants = {}
    constants['norm_atoms'] = compute_norm_atoms(D)
    if DtD is None:
        constants['DtD'] = compute_DtD(D)
    else:
        constants['DtD'] = DtD

    # Initialization of the algorithm variables
    i_seg = -1
    accumulator = 0
    if z0 is None:
        z_hat = np.zeros((n_atoms,) + valid_support)
    else:
        z_hat = np.copy(z0)
    n_coordinates = z_hat.size

    # Get a random number generator from the given random_state
    rng = check_random_state(random_state)
    order = None
    if strategy in ['cyclic', 'cyclic-r', 'random']:
        order = get_order_iterator(z_hat.shape, strategy=strategy,
                                   random_state=rng)

    t_start_init = time.time()
    return_dE = strategy == "gs-q"
    beta, dz_opt, dE = _init_beta(X_i, D, reg, z_i=z0, constants=constants,
                                  z_positive=z_positive, return_dE=return_dE)
    if strategy == "gs-q":
        raise NotImplementedError("This is still WIP")

    if freeze_support:
        freezed_support = z0 == 0
        dz_opt[freezed_support] = 0
    else:
        freezed_support = None

    p_obj, next_log_iter = [], 1
    t_init = time.time() - t_start_init
    if timing:
        p_obj.append((0, t_init, 0, compute_objective(X_i, z_hat, D, reg)))

    n_coordinate_updates = 0
    t_run = 0
    t_select_coord, t_update_coord = [], []
    t_start = time.time()
    if timeout is not None:
        deadline = t_start + timeout
    else:
        deadline = None
    for ii in range(max_iter):
        if ii % 1000 == 0 and verbose > 0:
            print("\r[LGCD:PROGRESS] {:.0f}s - {:7.2%} iterations"
                  .format(t_run, ii / max_iter), end='', flush=True)

        i_seg = segments.increment_seg(i_seg)
        if segments.is_active_segment(i_seg):
            t_start_selection = time.time()
            k0, pt0, dz = _select_coordinate(dz_opt, dE, segments, i_seg,
                                             strategy=strategy, order=order)
            selection_duration = time.time() - t_start_selection
            t_select_coord.append(selection_duration)
            t_run += selection_duration
        else:
            dz = 0

        accumulator = max(abs(dz), accumulator)

        # Update the selected coordinate and beta, only if the update is
        # greater than the convergence tolerance.
        if abs(dz) > tol:
            t_start_update = time.time()

            # update the current solution estimate and beta
            beta, dz_opt, dE = coordinate_update(
                k0, pt0, dz, beta=beta, dz_opt=dz_opt, dE=dE, z_hat=z_hat, D=D,
                reg=reg, constants=constants, z_positive=z_positive,
                freezed_support=freezed_support)
            touched_segs = segments.get_touched_segments(
                pt=pt0, radius=atom_support)
            n_changed_status = segments.set_active_segments(touched_segs)

            # Logging of the time and the cost function if necessary
            update_duration = time.time() - t_start_update
            n_coordinate_updates += 1
            t_run += update_duration
            t_update_coord.append(update_duration)
            if timing and ii + 1 >= next_log_iter:
                p_obj.append((ii + 1, t_run, np.sum(t_select_coord),
                              compute_objective(X_i, z_hat, D, reg)))
                next_log_iter = next_log_iter * 1.3
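                # The logging schedule is geometric (x1.3) so that the costly
                # compute_objective call is made only O(log(max_iter)) times.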

            # If the debug flag CHECK_ACTIVE_SEGMENTS is set, check that all
            # inactive segments are indeed inactive.
            if flags.CHECK_ACTIVE_SEGMENTS and n_changed_status:
                segments.test_active_segment(dz_opt, tol)

        elif strategy in ["greedy", 'gs-r']:
            segments.set_inactive_segments(i_seg)
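            # The segment stays inactive until an update in a neighbouring
            # segment touches it again (via set_active_segments above).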

        # check stopping criterion
        if _check_convergence(segments, tol, ii, dz_opt, n_coordinates,
                              strategy, accumulator=accumulator):
            assert np.all(abs(dz_opt) <= tol)
            if verbose > 0:
                print("\r[LGCD:INFO] converged in {} iterations ({} updates)"
                      .format(ii + 1, n_coordinate_updates))

            break

        # Check if we have reached the timeout
        if deadline is not None and time.time() >= deadline:
            if verbose > 0:
                print("\r[LGCD:INFO] Reached timeout. Done {} iterations "
                      "({} updates). Max of |dz|={}."
                      .format(ii + 1, n_coordinate_updates, abs(dz_opt).max()))
            break
    else:
        if verbose > 0:
            print("\r[LGCD:INFO] Reached max_iter. Done {} coordinate "
                  "updates. Max of |dz|={}."
                  .format(n_coordinate_updates, abs(dz_opt).max()))

    print(f"\r[LGCD:{strategy}] "
          f"t_select={np.mean(t_select_coord):.3e}s  "
          f"t_update={np.mean(t_update_coord):.3e}s"
          )

    runtime = time.time() - t_start
    if verbose > 0:
        print("\r[LGCD:INFO] done in {:.3f}s ({:.3f}s)"
              .format(runtime, t_run))

    ztz, ztX = None, None
    if return_ztz:
        ztz = compute_ztz(z_hat, atom_support)
        ztX = compute_ztX(z_hat, X_i)

    p_obj.append([n_coordinate_updates, t_run,
                  compute_objective(X_i, z_hat, D, reg)])

    run_statistics = dict(iterations=ii + 1, runtime=runtime, t_init=t_init,
                          t_run=t_run, n_updates=n_coordinate_updates,
                          t_select=np.mean(t_select_coord),
                          t_update=np.mean(t_update_coord))

    return z_hat, ztz, ztX, p_obj, run_statistics
Beispiel #13
0
def coordinate_descent(X_i,
                       D,
                       reg,
                       z0=None,
                       n_seg='auto',
                       strategy='greedy',
                       tol=1e-5,
                       max_iter=100000,
                       timeout=None,
                       z_positive=False,
                       freeze_support=False,
                       return_ztz=False,
                       timing=False,
                       random_state=None,
                       verbose=0):
    """Coordinate Descent Algorithm for 2D convolutional sparse coding.

    Parameters
    ----------
    X_i : ndarray, shape (n_channels, *sig_shape)
        Image to encode on the dictionary D
    D : ndarray, shape (n_atoms, n_channels, *atom_shape)
        Current dictionary for the sparse coding
    reg : float
        Regularization parameter
    z0 : ndarray, shape (n_atoms, *valid_shape) or None
        Warm start value for z_hat. If None, z_hat is initialized to zero.
    n_seg : int or { 'auto' }
        Number of segments to use for each dimension. If set to 'auto', use
        segments whose size is roughly twice the size of the atoms.
    tol : float
        Tolerance for the minimal update size in this algorithm.
    strategy : str in { 'greedy' | 'random' | 'gs-r' | 'gs-q' }
        Coordinate selection scheme for the coordinate descent. If set to
        'greedy' or 'gs-r', the coordinate with the largest value of dz_opt is
        selected. If set to 'random', the coordinate is chosen uniformly in the
        segment. If set to 'gs-q', the coordinate that reduces the cost
        function the most is selected. In this case, dE must hold the value of
        this cost reduction.
    max_iter : int
        Maximal number of iterations run by this algorithm.
    timeout : int or None
        Maximal runtime in seconds for this algorithm. If None, no timeout is
        applied.
    z_positive : boolean
        If set to True, the activations are constrained to be positive.
    freeze_support : boolean
        If set to True, only update the coefficients that are non-zero in z0.
    return_ztz : boolean
        If set to True, also compute and return the statistics ztz and ztX.
    timing : boolean
        If set to True, log the cost and timing information.
    random_state : None or int or RandomState
        Seed or RandomState used to seed the random number generator.
    verbose : int
        Verbosity level of the algorithm.

    Returns
    -------
    z_hat : ndarray, shape (n_atoms, *valid_shape)
        Activations associated with X_i for the given dictionary D
    ztz, ztX : ndarray or None
        Statistics computed from z_hat with compute_ztz and compute_ztX,
        returned only if return_ztz is True (None otherwise).
    p_obj : list
        Cost and timing information logged during the run when timing is True.
    """
    n_channels, *sig_shape = X_i.shape
    n_atoms, n_channels, *atom_shape = D.shape
    valid_shape = tuple([
        size_ax - size_atom_ax + 1
        for size_ax, size_atom_ax in zip(sig_shape, atom_shape)
    ])
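    # e.g. sig_shape = (64, 64) and atom_shape = (8, 8) give
    # valid_shape = (57, 57), since 64 - 8 + 1 = 57.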

    # compute sizes for the segments for LGCD
    if n_seg == 'auto':
        n_seg = []
        for axis_size, atom_size in zip(valid_shape, atom_shape):
            n_seg.append(max(axis_size // (2 * atom_size - 1), 1))
    segments = Segmentation(n_seg, signal_shape=valid_shape)
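    # Continuing the example: valid_shape = (57, 57) with atom_shape = (8, 8)
    # gives 2 * 8 - 1 = 15 and n_seg = [57 // 15, 57 // 15] = [3, 3].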

    # Pre-compute some quantities
    constants = {}
    constants['norm_atoms'] = compute_norm_atoms(D)
    constants['DtD'] = compute_DtD(D)

    # Initialization of the algorithm variables
    i_seg = -1
    p_obj, next_cost = [], 1
    accumulator = 0
    if z0 is None:
        z_hat = np.zeros((n_atoms, ) + valid_shape)
    else:
        z_hat = np.copy(z0)
    n_coordinates = z_hat.size

    t_update = 0
    t_start_update = time.time()
    return_dE = strategy == "gs-q"
    beta, dz_opt, dE = _init_beta(X_i,
                                  D,
                                  reg,
                                  z_i=z0,
                                  constants=constants,
                                  z_positive=z_positive,
                                  return_dE=return_dE)
    if strategy == "gs-q":
        raise NotImplementedError("This is still WIP")

    if freeze_support:
        freezed_support = z0 == 0
        dz_opt[freezed_support] = 0
    else:
        freezed_support = None

    t_start = time.time()
    n_coordinate_updates = 0
    if timeout is not None:
        deadline = t_start + timeout
    else:
        deadline = None
    for ii in range(max_iter):
        if ii % 1000 == 0 and verbose > 0:
            print("\r[LGCD:PROGRESS] {:.0f}s - {:7.2%} iterations".format(
                t_update, ii / max_iter),
                  end='',
                  flush=True)

        i_seg = segments.increment_seg(i_seg)
        if segments.is_active_segment(i_seg):
            k0, pt0, dz = _select_coordinate(dz_opt,
                                             dE,
                                             segments,
                                             i_seg,
                                             strategy=strategy,
                                             random_state=random_state)
        else:
            k0, pt0, dz = None, None, 0

        accumulator = max(abs(dz), accumulator)

        # Update the selected coordinate and beta, only if the update is
        # greater than the convergence tolerance.
        if abs(dz) > tol:
            n_coordinate_updates += 1

            # update the current solution estimate and beta
            beta, dz_opt, dE = coordinate_update(
                k0,
                pt0,
                dz,
                beta=beta,
                dz_opt=dz_opt,
                dE=dE,
                z_hat=z_hat,
                D=D,
                reg=reg,
                constants=constants,
                z_positive=z_positive,
                freezed_support=freezed_support)
            touched_segs = segments.get_touched_segments(pt=pt0,
                                                         radius=atom_shape)
            n_changed_status = segments.set_active_segments(touched_segs)

            if flags.CHECK_ACTIVE_SEGMENTS and n_changed_status:
                segments.test_active_segment(dz_opt, tol)

            t_update += time.time() - t_start_update
            if timing:
                if ii >= next_cost:
                    p_obj.append(
                        (ii, t_update, compute_objective(X_i, z_hat, D, reg)))
                    next_cost = next_cost * 2
            t_start_update = time.time()
        elif strategy in ["greedy", 'gs-r']:
            segments.set_inactive_segments(i_seg)

        # check stopping criterion
        if _check_convergence(segments,
                              tol,
                              ii,
                              dz_opt,
                              n_coordinates,
                              strategy,
                              accumulator=accumulator):
            assert np.all(abs(dz_opt) <= tol)
            if verbose > 0:
                print("\r[LGCD:INFO] converged after {} iterations".format(ii +
                                                                           1))

            break

        # Check if we have reached the timeout
        if deadline is not None and time.time() >= deadline:
            if verbose > 0:
                print("\r[LGCD:INFO] Reached timeout. Done {} coordinate "
                      "updates. Max of |dz|={}.".format(
                          n_coordinate_updates,
                          abs(dz_opt).max()))
            break
    else:
        if verbose > 0:
            print("\r[LGCD:INFO] Reached max_iter. Done {} coordinate "
                  "updates. Max of |dz|={}.".format(n_coordinate_updates,
                                                    abs(dz_opt).max()))

    runtime = time.time() - t_start
    if verbose > 0:
        print("\r[LGCD:INFO] done in {:.3}s".format(runtime))

    ztz, ztX = None, None
    if return_ztz:
        ztz = compute_ztz(z_hat, atom_shape)
        ztX = compute_ztX(z_hat, X_i)

    p_obj.append([n_coordinate_updates, t_update, None])

    return z_hat, ztz, ztX, p_obj, None
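
A minimal usage sketch for this second variant (not part of the original
listing): it assumes coordinate_descent and its helpers are importable in the
current scope, the array shapes follow the docstring above, and the reg, tol
and max_iter values are only illustrative.

import numpy as np

rng = np.random.RandomState(0)
X_i = rng.randn(3, 64, 64)        # (n_channels, *sig_shape)
D = rng.randn(10, 3, 8, 8)        # (n_atoms, n_channels, *atom_shape)

# Hypothetical call with illustrative hyper-parameters.
z_hat, ztz, ztX, p_obj, _ = coordinate_descent(
    X_i, D, reg=0.5, strategy='greedy', tol=1e-4, max_iter=50000,
    return_ztz=True, verbose=1)

print(z_hat.shape)   # (10, 57, 57): (n_atoms, *valid_shape) with 64 - 8 + 1 = 57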