示例#1
0
def test_update_z_sample_weights():
    """Test z update with sample weights.

    Three properties of ``update_z`` with ``sample_weights`` are checked:

    1. uniform weights equal to ``factor`` combined with a regularization
       ``reg * factor`` give the same codes as ``sample_weights=None`` with
       regularization ``reg``;
    2. with random non-uniform weights, all solvers agree on the entries
       where the simulated ``z`` is non-zero;
    3. dropping the weights changes the solution on those entries.
    """
    rng = check_random_state(42)
    X, ds, z = simulate_data(n_trials, n_times, n_times_atom, n_atoms)
    b_hat_0 = rng.randn(n_atoms * (n_times - n_times_atom + 1))
    solvers = ('l_bfgs', 'ista', 'fista')
    # Entries of the simulated codes used for the solution comparisons.
    support = z != 0

    # Having sample_weights all identical is equivalent to having
    # sample_weights=None and a scaled regularization
    factor = 1.6
    sample_weights = np.ones_like(X) * factor
    for solver in solvers:
        z_0 = update_z(X,
                       ds,
                       reg * factor,
                       n_times_atom,
                       solver=solver,
                       solver_kwargs=dict(factr=1e7, max_iter=50),
                       b_hat_0=b_hat_0.copy(),
                       sample_weights=sample_weights)
        z_1 = update_z(X,
                       ds,
                       reg,
                       n_times_atom,
                       solver=solver,
                       solver_kwargs=dict(factr=1e7, max_iter=50),
                       b_hat_0=b_hat_0.copy(),
                       sample_weights=None)
        assert_allclose(z_0, z_1, rtol=1e-4)

    # All solvers should give the same results
    sample_weights = np.abs(rng.randn(*X.shape))
    sample_weights /= sample_weights.mean()
    z_list = []
    for solver in solvers:
        z_hat = update_z(X,
                         ds,
                         reg,
                         n_times_atom,
                         solver=solver,
                         solver_kwargs=dict(factr=1e7, max_iter=2000),
                         b_hat_0=b_hat_0.copy(),
                         sample_weights=sample_weights)
        z_list.append(z_hat)
    assert_allclose(z_list[0][support], z_list[1][support], rtol=1e-3)
    assert_allclose(z_list[0][support], z_list[2][support], rtol=1e-3)

    # And using no sample weights should give different results. The
    # unweighted run uses the same solver as the reference solution
    # z_list[0], so the weights are the only difference between the two.
    reference_solver = solvers[0]
    z_hat = update_z(X,
                     ds,
                     reg,
                     n_times_atom,
                     solver=reference_solver,
                     solver_kwargs=dict(factr=1e7, max_iter=2000),
                     b_hat_0=b_hat_0.copy(),
                     sample_weights=None)
    assert_raises(AssertionError, assert_allclose, z_list[0][support],
                  z_hat[support], rtol=1e-3)
示例#2
0
def test_learn_codes():
    """Test learning of codes.

    For each solver, the codes estimated by ``update_z`` must reconstruct
    ``X`` closely (high correlation, small absolute residual). In addition,
    every estimated code entry of the first trial that exceeds ``thresh``
    must lie on the support of the simulated codes ``z``.
    """
    thresh = 0.25

    X, ds, z = simulate_data(n_trials, n_times, n_times_atom, n_atoms)

    # Flat positions of the non-zero entries of the simulated codes (first
    # trial). They do not depend on the solver, so compute them once.
    idx = np.ravel_multi_index(z[0].nonzero(), z[0].shape)

    for solver in ('l-bfgs', 'ista', 'fista'):
        z_hat = update_z(X,
                         ds,
                         reg,
                         solver=solver,
                         solver_kwargs=dict(factr=1e11, max_iter=50))

        X_hat = construct_X(z_hat, ds)
        # Use the off-diagonal entry: the diagonal of a correlation matrix
        # is always 1, which would make this check vacuous.
        assert np.corrcoef(X.ravel(), X_hat.ravel())[0, 1] > 0.99
        # Bound the residual in both directions, not only where X > X_hat.
        assert np.max(np.abs(X - X_hat)) < 0.1

        # Flat positions of the estimated codes above threshold (first trial)
        idx_hat = np.ravel_multi_index(np.nonzero(z_hat[0] > thresh),
                                       z_hat[0].shape)
        # make sure that the positions are a subset of the positions
        # in the original z
        assert np.all(np.isin(idx_hat, idx))
示例#3
0
def test_z0_read_only():
    """Check that ``update_z`` runs with a non-writeable ``z0``.

    Regression test for #26: when there is a single atom, the reshape
    performed inside ``update_z`` does not copy the data. The call must
    therefore complete without raising when ``z0`` is marked read-only.
    """
    single_atom = 1
    X, ds, z_init = simulate_data(n_trials, n_times, n_times_atom,
                                  single_atom)
    z_init.setflags(write=False)
    update_z(X, ds, 0.1, z0=z_init, solver='ista')