def test_learn_codes():
    """Test learning of codes.

    For each solver, estimate the activations with ``update_z`` from
    simulated data, rebuild the signal with ``construct_X`` and check that:

    * the rebuilt signal is strongly correlated with the simulated one,
    * the largest absolute reconstruction error stays below 0.1,
    * every entry of ``z_hat[0]`` above ``thresh`` sits at a position that
      is non-zero in ``z[0]`` of the simulated activations.
    """
    thresh = 0.25

    X, ds, z = simulate_data(n_trials, n_times, n_times_atom, n_atoms)

    # Flat positions of the non-zero entries of the simulated activations.
    # They do not depend on the solver, so compute them once.
    idx = np.flatnonzero(z[0])

    for solver in ('l-bfgs', 'ista', 'fista'):
        z_hat = update_z(X, ds, reg, solver=solver,
                         solver_kwargs=dict(factr=1e11, max_iter=50))

        X_hat = construct_X(z_hat, ds)

        # Use the off-diagonal entry: the diagonal of a correlation matrix
        # is always 1, which would make this assertion vacuous.
        assert np.corrcoef(X.ravel(), X_hat.ravel())[0, 1] > 0.99
        # Bound the error in both directions, not only where X > X_hat.
        assert np.max(np.abs(X - X_hat)) < 0.1

        # Flat positions of the estimated activations above the threshold.
        idx_hat = np.flatnonzero(z_hat[0] > thresh)

        # make sure that the positions are a subset of the positions
        # in the original z
        assert np.setdiff1d(idx_hat, idx).size == 0
def test_learn_atoms():
    """Test learning of atoms.

    Runs ``update_d`` on simulated data using the simulated activations and
    checks that the returned atoms are close to the simulated ones, then
    checks that the signal rebuilt from those atoms is close to the
    simulated signal.
    """
    signal, true_atoms, activations = simulate_data(
        n_trials, n_times, n_times_atom, n_atoms)

    # Only the first returned value (the estimated atoms) is inspected.
    d_hat, _unused = update_d(signal, activations, n_times_atom)

    atoms_match = np.allclose(true_atoms, d_hat)
    assert atoms_match

    rebuilt = construct_X(activations, d_hat)
    signal_match = np.allclose(signal, rebuilt, rtol=1e-05, atol=1e-12)
    assert signal_match