def test_update_z_multi_decrease_cost_function(loss, solver):
    """Check that a single ``update_z_multi`` call lowers the objective.

    Random signals, rank-1 dictionary ``uv`` and activations are drawn from a
    seeded generator. The objective is evaluated before and after updating
    the activations with the requested ``solver`` and ``loss``; it must
    strictly decrease. For the ``'l2'`` loss, the returned ``ztz`` and
    ``ztX`` are also compared with ``compute_ztz`` / ``compute_ztX``.
    """
    n_trials, n_channels, n_times = 2, 3, 100
    n_times_atom, n_atoms = 10, 4
    n_times_valid = n_times - n_times_atom + 1
    reg = 0
    loss_params = {'gamma': 1, 'sakoe_chiba_band': n_times_atom // 2}

    # Draw order (X, then uv, then z) is fixed by the seed; keep it.
    rng = np.random.RandomState(0)
    X = rng.randn(n_trials, n_channels, n_times)
    uv = rng.randn(n_atoms, n_channels + n_times_atom)
    z_init = rng.randn(n_trials, n_atoms, n_times_valid)

    if loss == 'whitening':
        # whitening returns the AR model and the whitened signal.
        ar_model, X = whitening(X, ordar=10)
        loss_params['ar_model'] = ar_model

    def objective(z_current):
        # Objective of the current activations, evaluated on the (possibly
        # whitened) signal X.
        return compute_X_and_objective_multi(
            X=X, z_hat=z_current, D_hat=uv, reg=reg,
            feasible_evaluation=False, loss=loss, loss_params=loss_params)

    cost_before = objective(z_init)
    z_hat, ztz, ztX = update_z_multi(
        X, uv, reg, z0=z_init, solver=solver, loss=loss,
        loss_params=loss_params, return_ztz=True)
    cost_after = objective(z_hat)
    assert cost_after < cost_before

    # The sufficient statistics are only cross-checked for the l2 loss.
    if loss != 'l2':
        return
    assert np.allclose(ztz, compute_ztz(z_hat, n_times_atom))
    assert np.allclose(ztX, compute_ztX(z_hat, X))
# Example #2 (0)
def test_get_sufficient_statistics(solver, X, D_hat, requires_dicodile):
    """Check the z encoder's sufficient statistics against reference helpers.

    An encoder is built for ``solver`` with ``get_z_encoder_for``. After
    ``compute_z``, the ``(ztz, ztX)`` pair from ``get_sufficient_statistics``
    must be non-None and match ``compute_ztz`` / ``compute_ztX`` evaluated on
    the encoder's own ``z_hat``.
    """
    encoder_kwargs = dict(
        solver=solver,
        X=X,
        D_hat=D_hat,
        n_atoms=N_ATOMS,
        n_times_atom=N_TIMES_ATOM,
        n_jobs=2,
    )
    z_encoder = get_z_encoder_for(**encoder_kwargs)

    z_encoder.compute_z()
    z_hat = z_encoder.get_z_hat()
    ztz, ztX = z_encoder.get_sufficient_statistics()

    # Each reference value is only computed once the statistic is known to
    # be present.
    assert ztz is not None
    assert np.allclose(ztz, compute_ztz(z_hat, N_TIMES_ATOM))

    assert ztX is not None
    assert np.allclose(ztX, compute_ztX(z_hat, X))
# Example #3 (0)
def test_sparse_convolve():
    """Compare the Cython sparse ``ztX`` kernel with the dense reference.

    Activations with ~10% non-zero entries are drawn with
    ``scipy.sparse.random``, densified and reshaped per trial/atom. The
    result of ``cython_code._fast_compute_ztX`` on their list-of-LIL form
    must agree with ``compute_ztX`` on the dense array.
    """
    n_trials, n_channels, n_atoms = 4, 2, 3
    n_times, n_times_atom = 128, 21
    n_times_valid = n_times - n_times_atom + 1
    density = 0.1

    rng = np.random.RandomState(0)
    X = rng.randn(n_trials, n_channels, n_times)

    z_flat = sparse.random(n_trials, n_atoms * n_times_valid,
                           density=density, random_state=0)
    z = z_flat.toarray().reshape(n_trials, n_atoms, n_times_valid)
    z_lil = convert_to_list_of_lil(z)

    ztX_sparse = cython_code._fast_compute_ztX(z_lil, X)
    ztX_dense = compute_ztX(z, X)
    assert_allclose(ztX_sparse, ztX_dense, atol=1e-16)
# Example #4 (0)
def test_cd():
    """Check the locally-greedy coordinate descent z-update (dense inputs).

    Three properties are tested:

    * ``update_z_multi`` with ``solver="lgcd"`` and ``return_ztz=True``,
      started from the generating activations ``z_gen``, returns ``ztz`` and
      ``ztX`` that match ``compute_ztz`` / ``compute_ztX``;
    * the objective after that update is not larger than at ``z_gen``;
    * the objective trace returned by ``_coordinate_descent_idx`` (with
      ``debug=True``) is non-increasing.

    NOTE(review): another function named ``test_cd`` appears later in this
    file; if both live in the same module, that later definition shadows
    this one and pytest will not collect it.
    """
    n_trials, n_channels, n_times = 5, 3, 100
    n_times_atom, n_atoms = 10, 4
    n_times_valid = n_times - n_times_atom + 1
    reg = 1

    rng = np.random.RandomState(0)
    uv = rng.randn(n_atoms, n_channels + n_times_atom)
    z = abs(rng.randn(n_trials, n_atoms, n_times_valid))
    z_gen = abs(rng.randn(n_trials, n_atoms, n_times_valid))
    # Sparsify both activation sets: keep only entries >= 1.
    z[z < 1] = 0
    z_gen[z_gen < 1] = 0
    z0 = z[0]

    X = construct_X_multi(z_gen, D=uv, n_channels=n_channels)

    loss_0 = compute_X_and_objective_multi(X=X,
                                           z_hat=z_gen,
                                           D_hat=uv,
                                           reg=reg,
                                           loss='l2',
                                           feasible_evaluation=False)

    constants = {}
    constants['DtD'] = compute_DtD(uv, n_channels)

    # Ensure that the initialization is good, by using a nearly optimal point
    # and verifying that the cost does not go up.
    z_hat, ztz, ztX = update_z_multi(X,
                                     D=uv,
                                     reg=reg,
                                     z0=z_gen,
                                     solver="lgcd",
                                     solver_kwargs={
                                         'max_iter': 5,
                                         'tol': 1e-5
                                     },
                                     return_ztz=True)
    assert np.allclose(ztz, compute_ztz(z_hat, n_times_atom))
    assert np.allclose(ztX, compute_ztX(z_hat, X))

    loss_1 = compute_X_and_objective_multi(X=X,
                                           z_hat=z_hat,
                                           D_hat=uv,
                                           reg=reg,
                                           loss='l2',
                                           feasible_evaluation=False)
    assert loss_1 <= loss_0, "Bad initialization in greedy CD."

    # Only the objective trace is needed here.
    _, pobj, _ = _coordinate_descent_idx(X[0],
                                         uv,
                                         constants,
                                         reg,
                                         debug=True,
                                         timing=True,
                                         z0=z0,
                                         max_iter=10000)

    # The objective must never go up between consecutive iterations. The
    # offending indices into ``pobj`` are reported in the assertion message,
    # instead of opening an interactive matplotlib window, which blocks
    # headless/CI test runs. ``not p1 >= p2`` (rather than ``p1 < p2``)
    # keeps the original semantics of also flagging NaN values.
    increases = [
        i + 1 for i, (p1, p2) in enumerate(zip(pobj[:-1], pobj[1:]))
        if not p1 >= p2
    ]
    assert not increases, (
        "Cost function increased at iterations {}".format(increases))
# Example #5 (0)
def test_cd(use_sparse_lil):
    """Check the locally-greedy coordinate descent z-update, dense or sparse.

    Parameters
    ----------
    use_sparse_lil : bool
        If True, activations are built as lists (one entry per trial) of
        ``scipy.sparse`` LIL matrices. Otherwise they are dense arrays in
        which entries below 1 are set to zero.

    The test checks three things:

    * ``update_z_multi`` with ``solver="lgcd"``, started from the generating
      activations, returns ``ztz`` / ``ztX`` consistent with the reference
      helpers. The Cython kernels are used when sparse inputs are requested
      and Cython is available.
    * The update does not increase the objective.
    * The objective trace of ``_coordinate_descent_idx`` is non-increasing.
    """
    n_trials, n_channels, n_times = 5, 3, 100
    n_times_atom, n_atoms = 10, 4
    n_times_valid = n_times - n_times_atom + 1
    reg = 1

    # Seed every random draw, both the dense arrays and the sparse matrices,
    # so that failures are reproducible.
    rng = np.random.RandomState(0)
    uv = rng.randn(n_atoms, n_channels + n_times_atom)
    if use_sparse_lil:
        density = .1
        z = [
            sparse.random(n_atoms,
                          n_times_valid,
                          format='lil',
                          density=density,
                          random_state=rng) for _ in range(n_trials)
        ]
        z_gen = [
            sparse.random(n_atoms,
                          n_times_valid,
                          format='lil',
                          density=density,
                          random_state=rng) for _ in range(n_trials)
        ]
        z0 = z[0]
    else:
        z = abs(rng.randn(n_trials, n_atoms, n_times_valid))
        z_gen = abs(rng.randn(n_trials, n_atoms, n_times_valid))
        # Sparsify both activation sets: keep only entries >= 1.
        z[z < 1] = 0
        z_gen[z_gen < 1] = 0
        z0 = z[0]

    X = construct_X_multi(z_gen, D=uv, n_channels=n_channels)

    loss_0 = compute_X_and_objective_multi(X=X,
                                           z_hat=z_gen,
                                           D_hat=uv,
                                           reg=reg,
                                           loss='l2',
                                           feasible_evaluation=False)

    constants = {}
    constants['DtD'] = compute_DtD(uv, n_channels)

    # Ensure that the initialization is good, by using a nearly optimal point
    # and verifying that the cost does not go up.
    z_hat, ztz, ztX = update_z_multi(X,
                                     D=uv,
                                     reg=reg,
                                     z0=z_gen,
                                     solver="lgcd",
                                     solver_kwargs={
                                         'max_iter': 5,
                                         'tol': 1e-5
                                     })
    if use_sparse_lil and cython_code._CYTHON_AVAILABLE:
        # Sparse (list of LIL) activations are checked with the Cython
        # kernels.
        from alphacsc.cython_code import _fast_compute_ztz, _fast_compute_ztX
        assert np.allclose(ztz, _fast_compute_ztz(z_hat, n_times_atom))
        assert np.allclose(ztX, _fast_compute_ztX(z_hat, X))
    else:
        assert np.allclose(ztz, compute_ztz(z_hat, n_times_atom))
        assert np.allclose(ztX, compute_ztX(z_hat, X))

    loss_1 = compute_X_and_objective_multi(X=X,
                                           z_hat=z_hat,
                                           D_hat=uv,
                                           reg=reg,
                                           loss='l2',
                                           feasible_evaluation=False)
    assert loss_1 <= loss_0, "Bad initialization in greedy CD."

    # Only the objective trace is needed here.
    _, pobj, _ = _coordinate_descent_idx(X[0],
                                         uv,
                                         constants,
                                         reg,
                                         debug=True,
                                         timing=True,
                                         z0=z0,
                                         max_iter=10000)

    # The objective must never go up between consecutive iterations. The
    # offending indices into ``pobj`` are reported in the assertion message,
    # instead of opening an interactive matplotlib window, which blocks
    # headless/CI test runs. ``not p1 >= p2`` (rather than ``p1 < p2``)
    # keeps the original semantics of also flagging NaN values.
    increases = [
        i + 1 for i, (p1, p2) in enumerate(zip(pobj[:-1], pobj[1:]))
        if not p1 >= p2
    ]
    assert not increases, (
        "Cost function increased at iterations {}".format(increases))