def test_update_z_sample_weights():
    """Test z update with weights.

    Checks three properties of ``update_z`` when ``sample_weights`` is given:

    1. Uniform weights equal to ``factor`` give the same codes as no weights
       with the regularization multiplied by ``factor``, for every solver.
    2. With random non-uniform weights, the 'l_bfgs', 'ista' and 'fista'
       solvers agree on the entries where the ground-truth codes ``z`` are
       non-zero.
    3. Dropping the non-uniform weights changes those entries.
    """
    solvers = ('l_bfgs', 'ista', 'fista')
    rng = check_random_state(42)
    X, ds, z = simulate_data(n_trials, n_times, n_times_atom, n_atoms)
    # ``b_hat_0`` is copied into every call below so that all runs start
    # from identical values and none can mutate the shared array.
    b_hat_0 = rng.randn(n_atoms * (n_times - n_times_atom + 1))

    # Having sample_weights all identical is equivalent to having
    # sample_weights=None and a scaled regularization
    factor = 1.6
    sample_weights = np.ones_like(X) * factor
    for solver in solvers:
        # Like the other update_z calls in this module, only (X, ds, reg)
        # are passed positionally; an extra positional ``n_times_atom``
        # would be bound to the next positional parameter of update_z.
        z_0 = update_z(X, ds, reg * factor, solver=solver,
                       solver_kwargs=dict(factr=1e7, max_iter=50),
                       b_hat_0=b_hat_0.copy(), sample_weights=sample_weights)
        z_1 = update_z(X, ds, reg, solver=solver,
                       solver_kwargs=dict(factr=1e7, max_iter=50),
                       b_hat_0=b_hat_0.copy(), sample_weights=None)
        assert_allclose(z_0, z_1, rtol=1e-4)

    # All solvers should give the same results
    sample_weights = np.abs(rng.randn(*X.shape))
    sample_weights /= sample_weights.mean()
    # Many more iterations than above so that the different solvers can be
    # compared against each other at a tight tolerance.
    z_list = [
        update_z(X, ds, reg, solver=solver,
                 solver_kwargs=dict(factr=1e7, max_iter=2000),
                 b_hat_0=b_hat_0.copy(), sample_weights=sample_weights)
        for solver in solvers
    ]
    # Only compare entries on the support of the ground-truth codes.
    support = z != 0
    assert_allclose(z_list[0][support], z_list[1][support], rtol=1e-3)
    assert_allclose(z_list[0][support], z_list[2][support], rtol=1e-3)

    # And using no sample weights should give different results
    z_hat = update_z(X, ds, reg, solver='fista',
                     solver_kwargs=dict(factr=1e7, max_iter=2000),
                     b_hat_0=b_hat_0.copy(), sample_weights=None)
    # Third positional argument of assert_allclose is rtol.
    assert_raises(AssertionError, assert_allclose,
                  z_list[0][support], z_hat[support], 1e-3)
def test_learn_codes():
    """Test learning of codes.

    For each solver, the codes estimated by ``update_z`` must reconstruct
    ``X`` closely (via ``construct_X``), and the entries of the first
    trial's estimated codes that exceed ``thresh`` must all lie on the
    support of the first trial's ground-truth codes ``z[0]``.
    """
    thresh = 0.25

    X, ds, z = simulate_data(n_trials, n_times, n_times_atom, n_atoms)

    # 'l_bfgs' (underscore) is the solver spelling used elsewhere in this
    # module.
    for solver in ('l_bfgs', 'ista', 'fista'):
        z_hat = update_z(X, ds, reg, solver=solver,
                         solver_kwargs=dict(factr=1e11, max_iter=50))

        X_hat = construct_X(z_hat, ds)
        # Use the off-diagonal entry: it is the correlation between X and
        # X_hat. The diagonal entries of a correlation matrix are always 1,
        # so checking [1, 1] would never fail.
        assert np.corrcoef(X.ravel(), X_hat.ravel())[0, 1] > 0.99
        # NOTE(review): this bound is one-sided (only X > X_hat is limited);
        # use np.abs(X - X_hat) if a symmetric bound is intended.
        assert np.max(X - X_hat) < 0.1

        # Flat indices of the non-zero entries of the true codes (trial 0)
        idx = np.ravel_multi_index(z[0].nonzero(), z[0].shape)
        # Flat indices of the estimated codes (trial 0) above the threshold
        loc_x, loc_y = np.where(z_hat[0] > thresh)
        idx_hat = np.ravel_multi_index((loc_x, loc_y), z_hat[0].shape)
        # make sure that the positions are a subset of the positions
        # in the original z
        mask = np.isin(idx_hat, idx)
        assert mask.all()
def test_z0_read_only():
    """update_z must accept a read-only ``z0`` without writing into it.

    If n_atoms == 1, the reshape in update_z does not copy the data
    (cf #26), so an in-place write inside update_z would reach the caller's
    array. Marking that array read-only turns any such write into an error.
    """
    n_atoms = 1
    X, ds, z_init = simulate_data(n_trials, n_times, n_times_atom, n_atoms)
    z_init.setflags(write=False)
    update_z(X, ds, 0.1, z0=z_init, solver='ista')