コード例 #1
0
def test_segmentation_coverage():
    """Check that the segments tile the signal exactly once.

    For several grid sizes, every segment is visited by following
    ``increment_seg`` from segment 0 until it comes back to 0. The test checks
    that each reported segment support matches the shape of its slice and
    that every pixel is covered once. Finally, a segmentation restricted to
    ``inner_bounds`` must cover the inner area exactly once and nothing else.
    """
    sig_support = (108, 53)

    for h_seg in [5, 7, 9, 13, 17]:
        for w_seg in [3, 11]:
            coverage = np.zeros(sig_support)
            segments = Segmentation(n_seg=(h_seg, w_seg),
                                    signal_support=sig_support)
            assert tuple(segments.n_seg_per_axis) == (h_seg, w_seg)

            # Visit segments starting from 0 until increment_seg brings us
            # back to segment 0.
            current = 0
            while True:
                seg_slice = segments.get_seg_slice(current)
                seg_support = segments.get_seg_support(current)
                assert seg_support == coverage[seg_slice].shape
                coverage[seg_slice] += 1
                current = segments.increment_seg(current)
                if current == 0:
                    break

            # Each pixel must be covered by exactly one segment.
            assert np.all(coverage == 1)

    # With inner_bounds, the segments must only cover the inner area.
    coverage = np.zeros(sig_support)
    inner_bounds = [(8, 100), (3, 50)]
    inner_slice = tuple(slice(start, end) for start, end in inner_bounds)
    segments = Segmentation(n_seg=7, inner_bounds=inner_bounds,
                            full_support=sig_support)
    for i_seg in range(segments.effective_n_seg):
        coverage[segments.get_seg_slice(i_seg)] += 1

    assert np.all(coverage[inner_slice] == 1)
    coverage[inner_slice] = 0
    assert np.all(coverage == 0)
コード例 #2
0
    def recv_signal(self):
        """Receive the part of the signal assigned to this worker.

        Reads the signal metadata broadcast by the parent communicator and
        builds the segmentation of the workers. It then receives this worker's
        part of X, plus the warm start z0 when the main process sends one.
        Finally it builds the local segmentation used by LGCD and
        synchronizes with the workers and the main process.

        Returns
        -------
        X_worker : ndarray
            Part of the signal handled by this worker, requested from
            ``recv_array`` with shape
            ``(n_channels, *get_full_support(worker_support, atom_support))``.
        z0 : ndarray or None
            Warm start requested with shape ``(n_atoms, *worker_support)``,
            or None when ``has_z0`` is False.
        """
        n_atoms, n_channels, *atom_support = self.D.shape

        comm = MPI.Comm.Get_parent()
        X_info = comm.bcast(None, root=0)
        self.has_z0 = X_info['has_z0']
        self.valid_support = X_info['valid_support']
        self.workers_topology = X_info['workers_topology']
        self.size_msg = len(self.workers_topology) + 2

        # NOTE(review): self.overlap must be set before this method is called.
        self.workers_segments = Segmentation(n_seg=self.workers_topology,
                                             signal_support=self.valid_support,
                                             overlap=self.overlap)

        # Receive X and z from the master node. worker_support is reused
        # below for the local segmentation.
        worker_support = self.workers_segments.get_seg_support(self.rank)
        X_shape = (n_channels, ) + get_full_support(worker_support,
                                                    atom_support)
        z0_shape = (n_atoms, ) + worker_support
        if self.has_z0:
            z0 = self.recv_array(z0_shape)
        else:
            z0 = None
        X_worker = self.recv_array(X_shape)

        # Compute the local segmentation for LGCD algorithm

        # If n_seg is not specified, compute the shape of the local segments
        # as the size of an interfering zone.
        n_seg = self.n_seg
        local_seg_support = None
        if self.n_seg == 'auto':
            n_seg = None
            local_seg_support = 2 * np.array(atom_support) - 1

        # Get local inner bounds. First, compute the seg_bound without overlap
        # in global coordinates and then convert the bounds in the local
        # coordinate system.
        inner_bounds = self.workers_segments.get_seg_bounds(self.rank,
                                                            inner=True)
        inner_bounds = np.transpose([
            self.workers_segments.get_local_coordinate(self.rank, bound)
            for bound in np.transpose(inner_bounds)
        ])

        self.local_segments = Segmentation(n_seg=n_seg,
                                           seg_support=local_seg_support,
                                           inner_bounds=inner_bounds,
                                           full_support=worker_support)

        self.synchronize_workers(with_main=True)

        return X_worker, z0
コード例 #3
0
def test_touched_segments():
    """Test detection of touched segments and records of active segments

    For random points around the signal, every segment touched by a box of
    the given radius is marked inactive. The active status of each segment,
    and the count of active segments, is then compared with a brute-force
    check on a mask of the touched area.
    """
    rng = np.random.RandomState(42)

    sig_support = (108, 53)
    H, W = sig_support
    n_seg = (9, 3)
    for h_radius in [5, 7, 9]:
        for w_radius in [3, 11]:
            for _ in range(20):
                h0 = rng.randint(-h_radius, H + h_radius)
                w0 = rng.randint(-w_radius, W + w_radius)
                mask = np.zeros(sig_support)
                segments = Segmentation(n_seg, signal_support=sig_support)

                # Box of radius (h_radius, w_radius) around (h0, w0), clipped
                # to the signal.
                h_slice = slice(max(0, h0 - h_radius),
                                min(H, h0 + h_radius + 1))
                w_slice = slice(max(0, w0 - w_radius),
                                min(W, w0 + w_radius + 1))
                mask[h_slice, w_slice] = 1

                touched = segments.get_touched_segments(
                    (h0, w0), (h_radius, w_radius))
                segments.set_inactive_segments(touched)
                n_active = segments._n_active_segments

                n_expected = segments.effective_n_seg
                for i_seg in range(segments.effective_n_seg):
                    seg_slice = segments.get_seg_slice(i_seg)
                    is_touched = np.any(mask[seg_slice] == 1)
                    n_expected -= is_touched
                    assert segments.is_active_segment(i_seg) != is_touched
                assert n_active == n_expected

    # Check an error is returned when touched radius is larger than seg_size
    segments = Segmentation(n_seg, signal_support=sig_support)
    with pytest.raises(ValueError, match="too large"):
        segments.get_touched_segments((0, 0), (30, 2))
コード例 #4
0
def test_padding_to_overlap():
    """Each segment padded with get_padding_to_overlap must have the shape
    of segment ``n_seg[1] + 1``.

    With row-major segment indexing, that is presumably the segment at row 1,
    column 1, an interior segment with overlap on all sides.
    """
    n_seg = (4, 4)
    sig_shape = (504, 504)
    overlap = (12, 7)

    seg = Segmentation(n_seg=n_seg, signal_shape=sig_shape, overlap=overlap)
    reference_shape = seg.get_seg_shape(n_seg[1] + 1)
    for i_seg in range(np.prod(n_seg)):
        seg_shape = seg.get_seg_shape(i_seg)
        block = np.empty(seg_shape)
        # Use a distinct name so the `overlap` setting is not shadowed.
        padding = seg.get_padding_to_overlap(i_seg)
        padded = np.pad(block, padding, mode='constant')
        assert padded.shape == reference_shape
コード例 #5
0
def test_touched_overlap_area():
    """Check get_touched_overlap_slices against a brute-force mask.

    For several points inside each segment, the slices returned by
    ``get_touched_overlap_slices`` are given in local coordinates. They must
    select only coordinates of the overlap area touched by an update at that
    point, and together they must cover all of those coordinates.
    """
    sig_shape = (505, 407)
    overlap = (11, 9)
    n_seg = (8, 4)
    segments = Segmentation(n_seg=n_seg,
                            signal_shape=sig_shape,
                            overlap=overlap)
    h_ov, w_ov = overlap

    for i_seg in range(segments.effective_n_seg):
        seg_shape = segments.get_seg_shape(i_seg)
        seg_slice = segments.get_seg_slice(i_seg)
        seg_inner_slice = segments.get_seg_slice(i_seg, inner=True)
        if i_seg != 0:
            # Except for the first segment, the area around the local
            # origin is expected not to be contained in the segment.
            with pytest.raises(AssertionError):
                segments.check_area_contained(i_seg, (0, 0), overlap)

        h_last = seg_shape[0] - h_ov - 1
        w_last = seg_shape[1] - w_ov - 1
        test_points = [
            overlap, (h_ov, 25), (25, w_ov), (25, 25),
            (h_last, 25), (25, w_last), (h_last, w_last),
        ]
        for pt0 in test_points:
            assert segments.is_contained_coordinate(i_seg, pt0, inner=True)
            segments.check_area_contained(i_seg, pt0, overlap)

            # Mark the coordinates touched by an update at pt0 that lie
            # outside of the inner part of the segment.
            mask = np.zeros(sig_shape)
            pt_global = segments.get_global_coordinate(i_seg, pt0)
            update_slice = tuple(
                slice(max(center - radius, 0), center + radius + 1)
                for center, radius in zip(pt_global, overlap)
            )
            mask[update_slice] += 1
            mask[seg_inner_slice] = 0

            # The returned slices are in local coordinates, so work on the
            # view of the mask restricted to the segment.
            mask_seg = mask[seg_slice]

            touched_slices = segments.get_touched_overlap_slices(
                i_seg, pt0, overlap)
            # Every selected coordinate lies in the touched overlap area.
            for local_slice in touched_slices:
                assert np.all(mask_seg[local_slice] == 1)

            # Every touched coordinate is selected by at least one slice:
            # clearing them through the view must empty the whole mask.
            for local_slice in touched_slices:
                mask_seg[local_slice] *= 0
            assert np.all(mask == 0)
コード例 #6
0
def test_change_coordinate():
    """Local and global coordinate conversions must map the two corners of
    each segment onto each other."""
    sig_support = (505, 407)
    overlap = (12, 7)
    n_seg = (4, 4)
    segments = Segmentation(n_seg=n_seg, signal_support=sig_support,
                            overlap=overlap)

    for i_seg in range(segments.effective_n_seg):
        seg_bound = segments.get_seg_bounds(i_seg)
        seg_support = segments.get_seg_support(i_seg)
        lower = tuple(start for start, _ in seg_bound)
        upper = tuple(end for _, end in seg_bound)

        # The local origin is the start corner of the segment...
        assert segments.get_global_coordinate(i_seg, (0, 0)) == lower
        assert segments.get_local_coordinate(i_seg, lower) == (0, 0)

        # ... and the local point `seg_support` is its end corner.
        assert segments.get_global_coordinate(i_seg, seg_support) == upper
        assert segments.get_local_coordinate(i_seg, upper) == seg_support
コード例 #7
0
def test_segmentation_coverage_overlap():
    """Check how segments with overlap cover the signal.

    The inner parts of the segments must tile the signal exactly once. With
    overlap, each boundary between two segments is covered twice on a band
    of width ``2 * overlap``. The corners shared by four segments are covered
    four times.
    """
    sig_support = (505, 407)

    for overlap in [(3, 0), (0, 5), (3, 5), (12, 7)]:
        for h_seg in [5, 7, 9, 13, 15, 17]:
            for w_seg in [3, 11]:
                segments = Segmentation(n_seg=(h_seg, w_seg),
                                        signal_support=sig_support,
                                        overlap=overlap)

                # The inner segments must cover each pixel exactly once.
                z = np.zeros(sig_support)
                for i_seg in range(segments.effective_n_seg):
                    z[segments.get_seg_slice(i_seg, inner=True)] += 1
                n_pixels = np.prod(sig_support)
                assert np.sum(z == 1) == n_pixels

                # Count how many (overlapping) segments cover each pixel.
                z = np.zeros(sig_support)
                for i_seg in range(segments.effective_n_seg):
                    z[segments.get_seg_slice(i_seg)] += 1

                h_ov, w_ov = overlap
                # Use dedicated names: rebinding h_seg / w_seg here would
                # leak into the next iterations of the loops above.
                h_n_seg, w_n_seg = segments.n_seg_per_axis
                expected_overlap = (h_n_seg - 1) * sig_support[1] * 2 * h_ov
                expected_overlap += (w_n_seg - 1) * sig_support[0] * 2 * w_ov

                # Pixels covered by 4 segments, at the corners shared by
                # four neighboring segments.
                corner_overlap = (4 * (h_n_seg - 1) * (w_n_seg - 1)
                                  * h_ov * w_ov)
                # The corner zones were counted in both bands above, but they
                # are covered 4 times, not twice: remove both counts.
                expected_overlap -= 2 * corner_overlap

                non_overlapping = n_pixels - expected_overlap - corner_overlap
                assert non_overlapping == np.sum(z == 1)
                assert expected_overlap == np.sum(z == 2)
                assert corner_overlap == np.sum(z == 4)
コード例 #8
0
ファイル: dicod_worker.py プロジェクト: pierreHmbt/dicodile
    def recv_task(self):
        """Receive the task from the main process and initialize the worker.

        Retrieves the parameters and the dictionary D, then builds the
        segmentation of the workers and receives this worker's part of the
        signal (and of z0 when ``has_z0`` is set). It also builds the local
        segmentation used by LGCD, initializes ``z_hat`` and synchronizes
        with the other workers.
        """
        # Retrieve different constants from the base communicator and store
        # them in the class.
        params = self.get_params()

        if self.timeout:
            # NOTE(review): the factor 3 is not explained here; presumably it
            # gives slack relative to the requested timeout — confirm.
            self.timeout *= 3

        self.random_state = params['random_state']
        if isinstance(self.random_state, int):
            # Offset the seed by the rank so each worker gets its own seed.
            self.random_state += self.rank

        self.size_msg = len(params['valid_shape']) + 2

        # Compute the shape of the worker segment.
        self.D = self.get_D()
        n_atoms, n_channels, *atom_support = self.D.shape
        self.overlap = np.array(atom_support) - 1
        self.workers_segments = Segmentation(
            n_seg=params['workers_topology'],
            signal_shape=params['valid_shape'],
            overlap=self.overlap)

        # Receive X and z from the master node.
        worker_shape = self.workers_segments.get_seg_shape(self.rank)
        X_shape = (n_channels, ) + get_full_shape(worker_shape, atom_support)
        if params['has_z0']:
            z0_shape = (n_atoms, ) + worker_shape
            self.z0 = self.get_signal(z0_shape, params['debug'])
        else:
            self.z0 = None
        self.X_worker = self.get_signal(X_shape, params['debug'])

        # Compute the local segmentation for LGCD algorithm

        # If n_seg is not specified, compute the shape of the local segments
        # as the size of an interfering zone.
        n_seg = self.n_seg
        local_seg_shape = None
        if self.n_seg == 'auto':
            n_seg = None
            local_seg_shape = 2 * np.array(atom_support) - 1

        # Get local inner bounds. First, compute the seg_bound without overlap
        # in global coordinates and then convert the bounds in the local
        # coordinate system.
        inner_bounds = self.workers_segments.get_seg_bounds(self.rank,
                                                            inner=True)
        inner_bounds = np.transpose([
            self.workers_segments.get_local_coordinate(self.rank, bound)
            for bound in np.transpose(inner_bounds)
        ])

        self.local_segments = Segmentation(n_seg=n_seg,
                                           seg_shape=local_seg_shape,
                                           inner_bounds=inner_bounds,
                                           full_shape=worker_shape)

        # Initialize the solution on this worker's segment (worker_shape and
        # n_atoms were computed above; no need to recompute them).
        if self.z0 is None:
            self.z_hat = np.zeros((n_atoms, ) + worker_shape)
        else:
            self.z_hat = self.z0

        self.info(
            "Start DICOD with {} workers, strategy '{}', soft_lock"
            "={} and n_seg={}({})",
            self.n_jobs,
            self.strategy,
            self.soft_lock,
            self.n_seg,
            self.local_segments.effective_n_seg,
            global_msg=True)

        self.synchronize_workers()
コード例 #9
0
def test_inner_coordinate():
    """Check which coordinates near the segment borders are inner.

    Coordinates in the overlap area of a segment must not be inner, except
    on the borders of the signal, where the segment has no overlap.
    """
    sig_support = (505, 407)
    overlap = (11, 11)
    n_seg = (4, 4)
    segments = Segmentation(n_seg=n_seg, signal_support=sig_support,
                            overlap=overlap)
    h_ov, w_ov = overlap

    def is_inner(i_seg, pt):
        return segments.is_contained_coordinate(i_seg, pt, inner=True)

    def check_side(i_seg, on_border, border_pt, overlap_pt):
        # On a signal border, border_pt must be inner. Otherwise overlap_pt,
        # which sits in the overlap area, must not be inner.
        if on_border:
            assert is_inner(i_seg, border_pt)
        else:
            assert not is_inner(i_seg, overlap_pt)

    for h_rank in range(n_seg[0]):
        for w_rank in range(n_seg[1]):
            i_seg = h_rank * n_seg[1] + w_rank
            seg_support = segments.get_seg_support(i_seg)
            h_size, w_size = seg_support[0], seg_support[1]

            first_row, first_col = h_rank == 0, w_rank == 0
            last_row = h_rank == n_seg[0] - 1
            last_col = w_rank == n_seg[1] - 1

            assert is_inner(i_seg, overlap)

            # Top, left and top-left borders.
            check_side(i_seg, first_row, (0, w_ov), (h_ov - 1, w_ov))
            check_side(i_seg, first_col, (h_ov, 0), (h_ov, w_ov - 1))
            check_side(i_seg, first_row and first_col,
                       (0, 0), (h_ov - 1, w_ov - 1))

            # Bottom, right and bottom-right borders.
            check_side(i_seg, last_row,
                       (h_size - 1, w_size - w_ov - 1),
                       (h_size - h_ov, w_size - w_ov - 1))
            check_side(i_seg, last_col,
                       (h_size - h_ov - 1, w_size - 1),
                       (h_size - h_ov - 1, w_size - w_ov))
            check_side(i_seg, last_row and last_col,
                       (h_size - 1, w_size - 1),
                       (h_size - h_ov, w_size - w_ov))
コード例 #10
0
    # Benchmark fragment: compute (or load from cache) the recovered image
    # X_hat, save it, then plot it with the worker segment borders on top.
    atom_support = (16, 16)

    run_args = (n_atoms, atom_support, reg, tol, n_workers, random_state)
    if args.no_cache:
        # Force the computation instead of using the cached result.
        # NOTE(review): if run_without_soft_lock is a joblib MemorizedFunc,
        # `.call` returns an (output, metadata) pair in recent joblib
        # versions, hence the `[0]` — verify against the installed joblib.
        X_hat, pobj = run_without_soft_lock.call(*run_args)[0]
    else:
        X_hat, pobj = run_without_soft_lock(*run_args)

    file_name = f"soft_lock_M{n_workers}_support{atom_support[0]}"
    np.save(f"benchmarks_results/{file_name}_X_hat.npy", X_hat)

    # Compute the worker segmentation for the image.
    n_channels, *sig_support = X_hat.shape
    valid_support = get_valid_support(sig_support, atom_support)
    workers_segments = Segmentation(n_seg=(w_world, w_world),
                                    signal_support=valid_support,
                                    overlap=0)

    fig = plt.figure("recovery")
    fig.patch.set_alpha(0)

    ax = plt.subplot()
    # X_hat is channel-first; swapaxes(0, 2) moves the channels last for
    # imshow but also swaps the two spatial axes.
    # NOTE(review): the displayed image is therefore transposed w.r.t. the
    # segment bounds drawn below; this is harmless only for a square image
    # with the symmetric (w_world, w_world) topology — confirm.
    ax.imshow(X_hat.swapaxes(0, 2))
    for i_seg in range(workers_segments.effective_n_seg):
        # seg_bounds[axis] holds the (start, end) bounds along that axis.
        seg_bounds = np.array(workers_segments.get_seg_bounds(i_seg))
        # Shift by half an atom, presumably to map valid coordinates onto
        # the image — TODO confirm.
        seg_bounds = seg_bounds + np.array(atom_support) / 2
        ax.vlines(seg_bounds[1], *seg_bounds[0], linestyle='--')
        ax.hlines(seg_bounds[0], *seg_bounds[1], linestyle='--')
    ax.axis('off')
    plt.tight_layout()
コード例 #11
0
def coordinate_descent(X_i, D, reg, z0=None, DtD=None, n_seg='auto',
                       strategy='greedy', tol=1e-5, max_iter=100000,
                       timeout=None, z_positive=False, freeze_support=False,
                       return_ztz=False, timing=False,
                       random_state=None, verbose=0):
    """Coordinate Descent Algorithm for 2D convolutional sparse coding.

    Parameters
    ----------
    X_i : ndarray, shape (n_channels, *sig_support)
        Image to encode on the dictionary D
    D : ndarray, shape (n_atoms, n_channels, *atom_support)
        Current dictionary for the sparse coding
    reg : float
        Regularization parameter
    z0 : ndarray, shape (n_atoms, *valid_support) or None
        Warm start value for z_hat. If not present, z_hat is initialized to 0.
    DtD : ndarray, shape (n_atoms, n_atoms, *(2 * atom_support - 1)) or None
        Warm start value for DtD. If not present, it is computed on init
        from D only.
    n_seg : int or 'auto'
        Number of segments to use for each dimension. If set to 'auto' use
        segments of twice the size of the dictionary.
    strategy : str in {strategies}
        Coordinate selection scheme for the coordinate descent. If set to
        'greedy'|'gs-r', the coordinate with the largest value for dz_opt is
        selected. If set to 'random, the coordinate is chosen uniformly on the
        segment. If set to 'gs-q', the value that reduce the most the cost
        function is selected. In this case, dE must holds the value of this
        cost reduction.
    tol : float
        Tolerance for the minimal update size in this algorithm.
    max_iter : int
        Maximal number of iteration run by this algorithm.
    timeout : float or None
        Maximal running time in seconds. If None, no time limit is applied.
    z_positive : boolean
        If set to true, the activations are constrained to be positive.
    freeze_support : boolean
        If set to True, only update the coefficient that are non-zero in z0.
    return_ztz : boolean
        If True, returns the constants ztz and ztX, used to compute D-updates.
    timing : boolean
        If set to True, log the cost and timing information.
    random_state : None or int or RandomState
        current random state to seed the random number generator.
    verbose : int
        Verbosity level of the algorithm.

    Return
    ------
    z_hat : ndarray, shape (n_atoms, *valid_support)
        Activation associated to X_i for the given dictionary D
    ztz, ztX : ndarray or None
        Constants used to compute D-updates. None unless return_ztz is True.
    p_obj : list
        When timing is True, (iteration, t_run, total selection time, cost)
        entries; always ends with [n_coordinate_updates, t_run, cost].
    run_statistics : dict
        Number of iterations, runtime, init and run times, number of
        updates and mean selection/update times (nan when none happened).
    """
    n_channels, *sig_support = X_i.shape
    n_atoms, n_channels, *atom_support = D.shape
    valid_support = get_valid_support(sig_support, atom_support)

    if strategy not in STRATEGIES:
        raise ValueError("The coordinate selection strategy should be in "
                         "{}. Got '{}'.".format(STRATEGIES, strategy))

    # compute sizes for the segments for LGCD. Auto gives segments of size
    # twice the support of the atoms.
    if n_seg == 'auto':
        n_seg = np.array(valid_support) // (2 * np.array(atom_support) - 1)
        n_seg = tuple(np.maximum(1, n_seg))
    segments = Segmentation(n_seg, signal_support=valid_support)

    # Pre-compute constants for maintaining the auxiliary variable beta and
    # compute the coordinate update values.
    constants = {}
    constants['norm_atoms'] = compute_norm_atoms(D)
    if DtD is None:
        constants['DtD'] = compute_DtD(D)
    else:
        constants['DtD'] = DtD

    # Initialization of the algorithm variables
    i_seg = -1
    accumulator = 0
    if z0 is None:
        z_hat = np.zeros((n_atoms,) + valid_support)
    else:
        z_hat = np.copy(z0)
    n_coordinates = z_hat.size

    # Get a random number generator from the given random_state
    rng = check_random_state(random_state)
    order = None
    if strategy in ['cyclic', 'cyclic-r', 'random']:
        order = get_order_iterator(z_hat.shape, strategy=strategy,
                                   random_state=rng)

    t_start_init = time.time()
    return_dE = strategy == "gs-q"
    beta, dz_opt, dE = _init_beta(X_i, D, reg, z_i=z0, constants=constants,
                                  z_positive=z_positive, return_dE=return_dE)
    if strategy == "gs-q":
        raise NotImplementedError("This is still WIP")

    if freeze_support:
        freezed_support = z0 == 0
        dz_opt[freezed_support] = 0
    else:
        freezed_support = None

    p_obj, next_log_iter = [], 1
    t_init = time.time() - t_start_init
    if timing:
        p_obj.append((0, t_init, 0, compute_objective(X_i, z_hat, D, reg)))

    n_coordinate_updates = 0
    t_run = 0
    t_select_coord, t_update_coord = [], []
    t_start = time.time()
    if timeout is not None:
        deadline = t_start + timeout
    else:
        deadline = None
    # Defined before the loop so that `ii + 1` is 0 when max_iter == 0.
    ii = -1
    for ii in range(max_iter):
        if ii % 1000 == 0 and verbose > 0:
            print("\r[LGCD:PROGRESS] {:.0f}s - {:7.2%} iterations"
                  .format(t_run, ii / max_iter), end='', flush=True)

        i_seg = segments.increment_seg(i_seg)
        if segments.is_active_segment(i_seg):
            t_start_selection = time.time()
            k0, pt0, dz = _select_coordinate(dz_opt, dE, segments, i_seg,
                                             strategy=strategy, order=order)
            selection_duration = time.time() - t_start_selection
            t_select_coord.append(selection_duration)
            t_run += selection_duration
        else:
            dz = 0

        accumulator = max(abs(dz), accumulator)

        # Update the selected coordinate and beta, only if the update is
        # greater than the convergence tolerance.
        if abs(dz) > tol:
            t_start_update = time.time()

            # update the current solution estimate and beta
            beta, dz_opt, dE = coordinate_update(
                k0, pt0, dz, beta=beta, dz_opt=dz_opt, dE=dE, z_hat=z_hat, D=D,
                reg=reg, constants=constants, z_positive=z_positive,
                freezed_support=freezed_support)
            touched_segs = segments.get_touched_segments(
                pt=pt0, radius=atom_support)
            n_changed_status = segments.set_active_segments(touched_segs)

            # Logging of the time and the cost function if necessary
            update_duration = time.time() - t_start_update
            n_coordinate_updates += 1
            t_run += update_duration
            t_update_coord.append(update_duration)
            if timing and ii + 1 >= next_log_iter:
                p_obj.append((ii + 1, t_run, np.sum(t_select_coord),
                              compute_objective(X_i, z_hat, D, reg)))
                next_log_iter = next_log_iter * 1.3

            # If debug flag CHECK_ACTIVE_SEGMENTS is set, check that all
            # inactive segments should be inactive
            if flags.CHECK_ACTIVE_SEGMENTS and n_changed_status:
                segments.test_active_segment(dz_opt, tol)

        elif strategy in ["greedy", 'gs-r']:
            segments.set_inactive_segments(i_seg)

        # check stopping criterion
        if _check_convergence(segments, tol, ii, dz_opt, n_coordinates,
                              strategy, accumulator=accumulator):
            assert np.all(abs(dz_opt) <= tol)
            if verbose > 0:
                print("\r[LGCD:INFO] converged in {} iterations ({} updates)"
                      .format(ii + 1, n_coordinate_updates))

            break

        # Check if we reached the timeout
        if deadline is not None and time.time() >= deadline:
            if verbose > 0:
                print("\r[LGCD:INFO] Reached timeout. Done {} iterations "
                      "({} updates). Max of |dz|={}."
                      .format(ii + 1, n_coordinate_updates, abs(dz_opt).max()))
            break
    else:
        if verbose > 0:
            print("\r[LGCD:INFO] Reached max_iter. Done {} coordinate "
                  "updates. Max of |dz|={}."
                  .format(n_coordinate_updates, abs(dz_opt).max()))

    # Mean timings; nan (without the np.mean RuntimeWarning) when no
    # selection or update happened.
    t_select = np.mean(t_select_coord) if t_select_coord else np.nan
    t_update = np.mean(t_update_coord) if t_update_coord else np.nan

    runtime = time.time() - t_start
    if verbose > 0:
        # This summary used to be printed regardless of verbose.
        print(f"\r[LGCD:{strategy}] "
              f"t_select={t_select:.3e}s  "
              f"t_update={t_update:.3e}s"
              )
        print("\r[LGCD:INFO] done in {:.3f}s ({:.3f}s)"
              .format(runtime, t_run))

    ztz, ztX = None, None
    if return_ztz:
        ztz = compute_ztz(z_hat, atom_support)
        ztX = compute_ztX(z_hat, X_i)

    p_obj.append([n_coordinate_updates, t_run,
                  compute_objective(X_i, z_hat, D, reg)])

    run_statistics = dict(iterations=ii + 1, runtime=runtime, t_init=t_init,
                          t_run=t_run, n_updates=n_coordinate_updates,
                          t_select=t_select, t_update=t_update)

    return z_hat, ztz, ztX, p_obj, run_statistics
コード例 #12
0
ファイル: soft_lock.py プロジェクト: pierreHmbt/dicodile
    # Benchmark fragment: compute (or load from cache) the recovered image
    # X_hat, save it, then plot it with the worker segment borders on top.
    atom_support = (16, 16)

    run_args = (n_atoms, atom_support, reg, tol, n_jobs, random_state)
    if args.no_cache:
        # Force the computation instead of using the cached result.
        # NOTE(review): if run_without_soft_lock is a joblib MemorizedFunc,
        # recent joblib versions return an (output, metadata) pair from
        # `.call`, which would break this unpacking — verify.
        X_hat, pobj = run_without_soft_lock.call(*run_args)
    else:
        X_hat, pobj = run_without_soft_lock(*run_args)

    file_name = f"soft_lock_M{n_jobs}_support{atom_support[0]}"
    np.save(f"benchmarks_results/{file_name}_X_hat.npy", X_hat)

    # Compute the worker segmentation for the image.
    n_channels, *sig_shape = X_hat.shape
    valid_shape = get_valid_shape(sig_shape, atom_support)
    workers_segments = Segmentation(n_seg=(w_world, w_world),
                                    signal_shape=valid_shape,
                                    overlap=0)

    fig = plt.figure("recovery")
    fig.patch.set_alpha(0)

    ax = plt.subplot()
    # X_hat is channel-first; swapaxes(0, 2) moves the channels last for
    # imshow but also swaps the two spatial axes.
    # NOTE(review): the displayed image is transposed w.r.t. the segment
    # bounds drawn below; harmless only for a square image with the
    # symmetric (w_world, w_world) topology — confirm.
    ax.imshow(X_hat.swapaxes(0, 2))
    for i_seg in range(workers_segments.effective_n_seg):
        # seg_bounds[axis] holds the (start, end) bounds along that axis.
        seg_bounds = np.array(workers_segments.get_seg_bounds(i_seg))
        # Shift by half an atom, presumably to map valid coordinates onto
        # the image — TODO confirm.
        seg_bounds = seg_bounds + np.array(atom_support) / 2
        ax.vlines(seg_bounds[1], *seg_bounds[0], linestyle='--')
        ax.hlines(seg_bounds[0], *seg_bounds[1], linestyle='--')
    ax.axis('off')
    plt.tight_layout()
コード例 #13
0
def coordinate_descent(X_i,
                       D,
                       reg,
                       z0=None,
                       n_seg='auto',
                       strategy='greedy',
                       tol=1e-5,
                       max_iter=100000,
                       timeout=None,
                       z_positive=False,
                       freeze_support=False,
                       return_ztz=False,
                       timing=False,
                       random_state=None,
                       verbose=0):
    """Coordinate Descent Algorithm for 2D convolutional sparse coding.

    Parameters
    ----------
    X_i : ndarray, shape (n_channels, *sig_shape)
        Image to encode on the dictionary D
    D : ndarray, shape (n_atoms, n_channels, *atom_shape)
        Current dictionary for the sparse coding
    reg : float
        Regularization parameter
    z0 : ndarray, shape (n_atoms, *valid_shape) or None
        Warm start value for z_hat. If None, z_hat is initialized to 0.
    n_seg : int or 'auto'
        Number of segments to use for each dimension. If set to 'auto' use
        segments of twice the size of the dictionary.
    strategy : str in 'greedy' | 'random' | 'gs-r' | 'gs-q'
        Coordinate selection scheme for the coordinate descent. If set to
        'greedy'|'gs-r', the coordinate with the largest value for dz_opt is
        selected. If set to 'random, the coordinate is chosen uniformly on the
        segment. If set to 'gs-q', the value that reduce the most the cost
        function is selected. In this case, dE must holds the value of this
        cost reduction.
    tol : float
        Tolerance for the minimal update size in this algorithm.
    max_iter : int
        Maximal number of iteration run by this algorithm.
    timeout : float or None
        Maximal running time in seconds. If None, no time limit is applied.
    z_positive : boolean
        If set to true, the activations are constrained to be positive.
    freeze_support : boolean
        If set to True, only update the coefficient that are non-zero in z0.
    return_ztz : boolean
        If True, also compute the constants ztz and ztX.
    timing : boolean
        If set to True, log the cost and timing information.
    random_state : None or int or RandomState
        current random state to seed the random number generator.
    verbose : int
        Verbosity level of the algorithm.

    Return
    ------
    z_hat : ndarray, shape (n_atoms, *valid_shape)
        Activation associated to X_i for the given dictionary D
    ztz, ztX : ndarray or None
        None unless return_ztz is True.
    p_obj : list
        When timing is True, (iteration, t_update, cost) entries; always
        ends with [n_coordinate_updates, t_update, None].
    None
        Placeholder for run statistics.
    """
    n_channels, *sig_shape = X_i.shape
    n_atoms, n_channels, *atom_shape = D.shape
    valid_shape = tuple([
        size_ax - size_atom_ax + 1
        for size_ax, size_atom_ax in zip(sig_shape, atom_shape)
    ])

    # compute sizes for the segments for LGCD
    if n_seg == 'auto':
        n_seg = []
        for axis_size, atom_size in zip(valid_shape, atom_shape):
            n_seg.append(max(axis_size // (2 * atom_size - 1), 1))
    segments = Segmentation(n_seg, signal_shape=valid_shape)

    # Pre-compute some quantities
    constants = {}
    constants['norm_atoms'] = compute_norm_atoms(D)
    constants['DtD'] = compute_DtD(D)

    # Initialization of the algorithm variables
    i_seg = -1
    p_obj, next_cost = [], 1
    accumulator = 0
    if z0 is None:
        z_hat = np.zeros((n_atoms, ) + valid_shape)
    else:
        z_hat = np.copy(z0)
    n_coordinates = z_hat.size

    t_update = 0
    t_start_update = time.time()
    return_dE = strategy == "gs-q"
    beta, dz_opt, dE = _init_beta(X_i,
                                  D,
                                  reg,
                                  z_i=z0,
                                  constants=constants,
                                  z_positive=z_positive,
                                  return_dE=return_dE)
    if strategy == "gs-q":
        raise NotImplementedError("This is still WIP")

    if freeze_support:
        freezed_support = z0 == 0
        dz_opt[freezed_support] = 0
    else:
        freezed_support = None

    t_start = time.time()
    n_coordinate_updates = 0
    if timeout is not None:
        deadline = t_start + timeout
    else:
        deadline = None
    for ii in range(max_iter):
        if ii % 1000 == 0 and verbose > 0:
            print("\r[LGCD:PROGRESS] {:.0f}s - {:7.2%} iterations".format(
                t_update, ii / max_iter),
                  end='',
                  flush=True)

        i_seg = segments.increment_seg(i_seg)
        if segments.is_active_segment(i_seg):
            # NOTE(review): random_state is passed as-is on every call; if
            # it is an int, each call may reseed identically — confirm in
            # _select_coordinate.
            k0, pt0, dz = _select_coordinate(dz_opt,
                                             dE,
                                             segments,
                                             i_seg,
                                             strategy=strategy,
                                             random_state=random_state)
        else:
            k0, pt0, dz = None, None, 0

        accumulator = max(abs(dz), accumulator)

        # Update the selected coordinate and beta, only if the update is
        # greater than the convergence tolerance.
        if abs(dz) > tol:
            n_coordinate_updates += 1

            # update beta
            beta, dz_opt, dE = coordinate_update(
                k0,
                pt0,
                dz,
                beta=beta,
                dz_opt=dz_opt,
                dE=dE,
                z_hat=z_hat,
                D=D,
                reg=reg,
                constants=constants,
                z_positive=z_positive,
                freezed_support=freezed_support)
            touched_segs = segments.get_touched_segments(pt=pt0,
                                                         radius=atom_shape)
            n_changed_status = segments.set_active_segments(touched_segs)

            if flags.CHECK_ACTIVE_SEGMENTS and n_changed_status:
                segments.test_active_segment(dz_opt, tol)

            t_update += time.time() - t_start_update
            if timing:
                if ii >= next_cost:
                    p_obj.append(
                        (ii, t_update, compute_objective(X_i, z_hat, D, reg)))
                    next_cost = next_cost * 2
            t_start_update = time.time()
        elif strategy in ["greedy", 'gs-r']:
            segments.set_inactive_segments(i_seg)

        # check stopping criterion
        if _check_convergence(segments,
                              tol,
                              ii,
                              dz_opt,
                              n_coordinates,
                              strategy,
                              accumulator=accumulator):
            assert np.all(abs(dz_opt) <= tol)
            if verbose > 0:
                print("\r[LGCD:INFO] converged after {} iterations".format(ii +
                                                                           1))

            break

        # Check if we reached the timeout
        if deadline is not None and time.time() >= deadline:
            if verbose > 0:
                print("\r[LGCD:INFO] Reached timeout. Done {} coordinate "
                      "updates. Max of |dz|={}.".format(
                          n_coordinate_updates,
                          abs(dz_opt).max()))
            break
    else:
        if verbose > 0:
            print("\r[LGCD:INFO] Reached max_iter. Done {} coordinate "
                  "updates. Max of |dz|={}.".format(n_coordinate_updates,
                                                    abs(dz_opt).max()))

    runtime = time.time() - t_start
    if verbose > 0:
        # Fixed-point seconds; '{:.3}' gave 3 significant digits and could
        # switch to exponent notation (e.g. '1.23e+03s').
        print("\r[LGCD:INFO] done in {:.3f}s".format(runtime))

    ztz, ztX = None, None
    if return_ztz:
        ztz = compute_ztz(z_hat, atom_shape)
        ztX = compute_ztX(z_hat, X_i)

    p_obj.append([n_coordinate_updates, t_update, None])

    return z_hat, ztz, ztX, p_obj, None