def test_learn_atoms():
    """Atoms estimated by update_d must match the simulated ones.

    The signal rebuilt from the estimated atoms and the simulated
    activations must also match the simulated signal.
    """
    X, true_atoms, activations = simulate_data(
        n_trials, n_times, n_times_atom, n_atoms
    )

    # update_d returns a pair; only the first element is checked here.
    learned_atoms, _unused = update_d(X, activations, n_times_atom)
    atoms_match = np.allclose(true_atoms, learned_atoms)
    assert atoms_match

    # Reconstruct the signal from the learned atoms and compare with X.
    rebuilt = construct_X(activations, learned_atoms)
    signal_match = np.allclose(X, rebuilt, rtol=1e-05, atol=1e-12)
    assert signal_match
def test_update_d():
    """Compare update_d and update_d_block from the same random init."""
    random_state = check_random_state(42)
    X, _true_atoms, activations = simulate_data(
        n_trials, n_times, n_times_atom, n_atoms
    )
    initial_atoms = random_state.randn(n_atoms, n_times_atom)

    # Normally a single iteration is used for the block update; more are
    # needed here so that its result can be compared with update_d.
    n_iter_d_block = 5

    # Keyword arguments shared by both solvers.
    shared_kwargs = dict(lambd0=None, ds_init=initial_atoms)

    # Every solver is expected to produce the same result.
    atoms_vanilla, _ = update_d(X, activations, n_times_atom, **shared_kwargs)
    atoms_block, _ = update_d_block(
        X, activations, n_times_atom, n_iter=n_iter_d_block, **shared_kwargs
    )

    solvers_agree = np.allclose(atoms_vanilla, atoms_block, rtol=1e-5)
    assert solvers_agree