Ejemplo n.º 1
0
def run_without_soft_lock(n_atoms=25,
                          atom_support=(12, 12),
                          reg=.01,
                          tol=5e-2,
                          n_workers=100,
                          random_state=60):
    """Run greedy DICOD on the mandril image with soft locking disabled.

    Parameters
    ----------
    n_atoms : int
        Number of atoms requested from ``init_dictionary``.
    atom_support : tuple of int
        Support of each atom, forwarded to ``init_dictionary``.
    reg : float
        Regularization expressed as a fraction of ``get_lambda_max(...).max()``.
    tol : float
        Stopping tolerance forwarded to ``dicod``.
    n_workers : int
        Number of workers forwarded to ``dicod``.
    random_state : int
        Seed of the ``np.random.RandomState`` used to initialize the
        dictionary.

    Returns
    -------
    X_hat : ndarray
        Reconstruction from the activations clipped to [-1e3, 1e3],
        itself clipped to [0, 1].
    pobj : float
        Objective computed on the *unclipped* activations.
    """
    rng = np.random.RandomState(random_state)
    image = get_mandril()
    dictionary = init_dictionary(image, n_atoms, atom_support,
                                 random_state=rng)
    scaled_reg = reg * get_lambda_max(image, dictionary).max()

    solver_kwargs = dict(max_iter=1000000, n_workers=n_workers, tol=tol,
                         strategy='greedy', verbose=1, soft_lock='none',
                         z_positive=False, timing=False)
    z_hat, *_ = dicod(image, dictionary, scaled_reg, **solver_kwargs)

    # The objective is evaluated on the raw activations, before any clipping.
    pobj = compute_objective(image, z_hat, dictionary, scaled_reg)
    print("[DICOD] final cost : {}".format(pobj))

    # NOTE(review): clipping presumably guards against activations blowing
    # up when soft_lock='none' — confirm.
    z_clipped = np.clip(z_hat, -1e3, 1e3)
    X_hat = np.clip(reconstruct(z_clipped, dictionary), 0, 1)
    return X_hat, pobj
Ejemplo n.º 2
0
def run_one_grid(n_atoms, atom_support, reg, n_workers, grid, tol, soft_lock,
                 dicod_args, random_state):
    """Solve one greedy DICOD problem on the mandril image and collect stats.

    Parameters
    ----------
    n_atoms, atom_support : int, tuple of int
        Forwarded to ``init_dictionary``.
    reg : float
        Regularization as a fraction of ``get_lambda_max(...).max()``.
    n_workers : int
        Number of workers forwarded to ``dicod``.
    grid : bool
        If truthy, ``w_world='auto'``; otherwise ``w_world=n_workers``.
    tol : float
        Stopping tolerance forwarded to ``dicod``.
    soft_lock : str
        Soft-lock mode forwarded to ``dicod`` and shown in the log tag.
    dicod_args : dict
        Extra keyword arguments forwarded to ``dicod``.
    random_state : sequence
        Element 0 is only displayed in the log tag; element 1 is the seed
        used for ``init_dictionary`` and stored in the result.

    Returns
    -------
    ResultItem
        Run parameters, activation sparsity, and every entry of the
        ``run_statistics`` returned by ``dicod`` (which must contain
        ``'runtime'``).
    """
    tag = f"[{soft_lock} - {reg:.0e} - {random_state[0]}]"
    seed = random_state[1]
    banner = "=" * 79

    print(colorify(banner + f"\n{tag} Start with {n_workers} workers\n"
                   + banner))

    # Build the problem: image, random dictionary, and scaled regularization.
    X = get_mandril()
    D = init_dictionary(X, n_atoms, atom_support, random_state=seed)
    lmbd_max = get_lambda_max(X, D).max()

    z_hat, *_, run_statistics = dicod(
        X, D, reg=reg * lmbd_max, n_seg='auto', strategy='greedy',
        w_world='auto' if grid else n_workers, n_workers=n_workers,
        timing=False, tol=tol, soft_lock=soft_lock, **dicod_args)

    runtime = run_statistics['runtime']
    n_nonzero = len(z_hat.nonzero()[0])
    sparsity = n_nonzero / z_hat.size

    print(colorify(banner + f"\n{tag} End for {n_workers} workers "
                   f"in {runtime:.1e}\n" + banner, color=GREEN))

    return ResultItem(n_atoms=n_atoms, atom_support=atom_support, reg=reg,
                      n_workers=n_workers, grid=grid, tol=tol,
                      soft_lock=soft_lock, random_state=seed,
                      dicod_args=dicod_args, sparsity=sparsity,
                      **run_statistics)
Ejemplo n.º 3
0
def get_problem(n_atoms, atom_support, seed):
    """Build a problem whose atoms are random patches of the mandril image.

    Parameters
    ----------
    n_atoms : int
        Number of atoms (patches) to draw.
    atom_support : tuple of int
        Spatial support of each atom.
    seed : None, int or RandomState
        Anything accepted by ``check_random_state``.

    Returns
    -------
    X : ndarray
        The image returned by ``get_mandril()``; its first axis is treated
        as the channel axis.
    D : ndarray, shape (n_atoms, n_channels, *atom_support)
        Patches of ``X`` scaled to unit L2 norm over all non-atom axes.
        An all-zero patch is left as zeros instead of becoming NaN.
    """
    X = get_mandril()

    rng = check_random_state(seed)

    n_channels, *sig_shape = X.shape
    valid_shape = get_valid_shape(sig_shape, atom_support)

    # One random patch origin per atom, drawn axis by axis:
    # shape (n_atoms, len(valid_shape)).
    indices = np.stack(
        [rng.randint(size_ax, size=n_atoms) for size_ax in valid_shape],
        axis=1)

    D = np.empty(shape=(n_atoms, n_channels, *atom_support))
    for k, origin in enumerate(indices):
        # Keep every channel (Ellipsis) and cut the spatial window.
        patch = (Ellipsis, *(slice(v, v + size_ax)
                             for v, size_ax in zip(origin, atom_support)))
        D[k] = X[patch]

    sum_axis = tuple(range(1, D.ndim))
    norms = np.sqrt(np.sum(D * D, axis=sum_axis, keepdims=True))
    # A constant-zero patch would otherwise produce a NaN atom (0 / 0).
    norms[norms == 0] = 1
    D /= norms

    return X, D
Ejemplo n.º 4
0
def run_one_grid(n_atoms, atom_support, reg, n_jobs, grid, tol, random_state,
                 verbose):
    """Solve one greedy DICOD problem (soft_lock='corner') on the mandril image.

    Parameters
    ----------
    n_atoms, atom_support : int, tuple of int
        Forwarded to ``init_dictionary``.
    reg : float
        Regularization as a fraction of ``get_lambda_max(...).max()``.
    n_jobs : int
        Number of jobs forwarded to ``dicod``.
    grid : bool
        If truthy, ``w_world='auto'``; otherwise ``w_world=n_jobs``.
    tol : float
        Stopping tolerance forwarded to ``dicod``.
    random_state : int or RandomState
        Forwarded to ``init_dictionary``.
    verbose : int
        Verbosity forwarded to ``dicod``.

    Returns
    -------
    ResultItem
        Run parameters, activation sparsity and ``pobj``.  ``pobj`` is the
        second-to-last value returned by ``dicod`` with ``timing=True``.
        NOTE(review): presumably the objective trace — confirm.
    """
    X = get_mandril()
    D = init_dictionary(X, n_atoms, atom_support, random_state=random_state)
    lmbd_max = get_lambda_max(X, D).max()

    # The last returned value is not used here.
    z_hat, *_, pobj, _ = dicod(
        X, D, reg=reg * lmbd_max, n_seg='auto', strategy='greedy',
        w_world='auto' if grid else n_jobs, n_jobs=n_jobs, timing=True,
        tol=tol, verbose=verbose, z_positive=False, soft_lock='corner',
        timeout=None, max_iter=int(1e8))

    n_nonzero = len(z_hat.nonzero()[0])
    sparsity = n_nonzero / z_hat.size

    return ResultItem(n_atoms=n_atoms, atom_support=atom_support, reg=reg,
                      n_jobs=n_jobs, grid=grid, tol=tol,
                      random_state=random_state, sparsity=sparsity,
                      pobj=pobj)
Ejemplo n.º 5
0
def run_one_scaling_2d(n_atoms, atom_support, reg, n_workers, strategy, tol,
                       dicod_args, random_state):
    """Solve one DICOD problem on the mandril image for a scaling benchmark.

    Parameters
    ----------
    n_atoms, atom_support : int, tuple of int
        Forwarded to ``init_dictionary``.
    reg : float
        Regularization as a fraction of ``get_lambda_max(...).max()``.
    n_workers : int
        Number of workers forwarded to ``dicod``.
    strategy : str
        Coordinate-selection strategy forwarded to ``dicod`` and shown in
        the log tag.
    tol : float
        Stopping tolerance forwarded to ``dicod``.
    dicod_args : dict
        Extra keyword arguments forwarded to ``dicod``.
    random_state : sequence
        Element 0 is only displayed in the log tag; element 1 is the seed
        used for ``init_dictionary`` and stored in the result.

    Returns
    -------
    ResultItem
        Run parameters, activation sparsity, and every entry of the
        ``run_statistics`` returned by ``dicod`` (which must contain
        ``'runtime'``).
    """
    tag = f"[{strategy} - {reg:.0e} - {random_state[0]}]"
    seed = random_state[1]
    banner = "=" * 79

    print(colorify(banner + f"\n{tag} Start with {n_workers} workers\n"
                   + banner))

    # Build the problem: image, random dictionary, and scaled regularization.
    X = get_mandril()
    D = init_dictionary(X, n_atoms, atom_support, random_state=seed)
    lmbd_max = get_lambda_max(X, D).max()

    z_hat, *_, run_statistics = dicod(X, D, reg=reg * lmbd_max,
                                      strategy=strategy, n_workers=n_workers,
                                      tol=tol, **dicod_args)

    runtime = run_statistics['runtime']
    n_nonzero = len(z_hat.nonzero()[0])
    sparsity = n_nonzero / z_hat.size

    print(colorify(banner + f"\n{tag} End with {n_workers} workers for reg="
                   f"{reg:.0e} in {runtime:.1e}\n" + banner, color=GREEN))

    return ResultItem(n_atoms=n_atoms, atom_support=atom_support, reg=reg,
                      n_workers=n_workers, strategy=strategy, tol=tol,
                      dicod_args=dicod_args, random_state=seed,
                      sparsity=sparsity, **run_statistics)
Ejemplo n.º 6
0
def test_fetch_mandril():
    """``get_mandril`` returns an array of shape (3, 512, 512)."""
    data = get_mandril()
    # Compare the shape directly: asserting a parenthesized tuple such as
    # ``(3, 512, 512 == data.shape)`` is always truthy and can never fail.
    assert data.shape == (3, 512, 512)