예제 #1
0
def test_learn_atoms():
    """Check that ``update_d`` recovers the simulated atoms.

    ``simulate_data`` provides the data, the reference atoms and the
    coefficients passed to ``update_d``. The estimated atoms must match
    the reference atoms. Rebuilding the data with ``construct_X`` from
    those coefficients and the estimated atoms must reproduce the
    simulated data within ``rtol=1e-05`` / ``atol=1e-12``.
    """
    signals, reference_atoms, coefs = simulate_data(
        n_trials, n_times, n_times_atom, n_atoms)
    estimated_atoms, _ = update_d(signals, coefs, n_times_atom)

    assert np.allclose(reference_atoms, estimated_atoms)

    rebuilt = construct_X(coefs, estimated_atoms)
    assert np.allclose(signals, rebuilt, rtol=1e-05, atol=1e-12)
예제 #2
0
def test_update_d():
    """Compare the vanilla ``update_d`` solver with ``update_d_block``.

    Both solvers receive the same simulated data, ``lambd0=None`` and the
    same random initial atoms (drawn from a generator seeded with 42).
    Their estimated atoms must agree to a relative tolerance of 1e-5.
    """
    random_state = check_random_state(42)
    signals, _, coefs = simulate_data(n_trials, n_times, n_times_atom,
                                      n_atoms)
    init_atoms = random_state.randn(n_atoms, n_times_atom)

    # update_d_block normally runs a single iteration; more iterations are
    # needed here so that its output can be compared with update_d's.
    n_block_iterations = 5

    # Every solver is expected to return the same atoms.
    vanilla_atoms, _ = update_d(signals, coefs, n_times_atom, lambd0=None,
                                ds_init=init_atoms)
    block_atoms, _ = update_d_block(signals, coefs, n_times_atom,
                                    lambd0=None, ds_init=init_atoms,
                                    n_iter=n_block_iterations)
    assert np.allclose(vanilla_atoms, block_atoms, rtol=1e-5)