Ejemplo n.º 1
0
def test_get_cost(solver, X, D_hat, requires_dicodile):
    """Check that encoding lowers the cost and agrees with the objective.

    The cost reported by the z-encoder is read before and after
    ``compute_z``; the latter must be strictly smaller. The final cost is
    then compared with ``compute_objective`` evaluated on the signal
    reconstructed from the encoder's ``z_hat`` and ``D_hat``.

    ``requires_dicodile`` is not used in the body; presumably a pytest
    fixture that gates the test on dicodile availability (TODO confirm).
    """
    encoder_kwargs = dict(
        solver=solver,
        X=X,
        D_hat=D_hat,
        n_atoms=N_ATOMS,
        n_times_atom=N_TIMES_ATOM,
        n_jobs=2,
    )

    with get_z_encoder_for(**encoder_kwargs) as z_encoder:
        cost_before = z_encoder.get_cost()

        z_encoder.compute_z()
        z_hat = z_encoder.get_z_hat()
        cost_after = z_encoder.get_cost()

        # Computing the activations must strictly decrease the cost.
        assert cost_after < cost_before

        # Recompute the cost independently from the reconstructed signal.
        # NOTE(review): reg=0.1 presumably mirrors the encoder's own
        # regularization value -- confirm against get_z_encoder_for.
        reconstruction = construct_X_multi(
            z_hat, D_hat, n_channels=N_CHANNELS)
        expected_cost = compute_objective(
            X=X, X_hat=reconstruction, z_hat=z_hat, reg=0.1, D=D_hat)

        assert np.isclose(expected_cost, cost_after)
Ejemplo n.º 2
0
def test_fast_cost():
    """Check that the constants-based objective matches the full one.

    ``compute_objective`` evaluated from the constants returned by
    ``_get_d_update_constants(X, z)`` must agree with the explicit
    ``0.5 * sum((X - X_hat) ** 2)`` for several random ``uv`` arrays.
    """
    # Problem dimensions.
    n_times_atom, n_times = 10, 40
    n_channels = 3
    n_atoms = 2
    n_trials = 4

    # Fixed seed: the data are identical on every run, so any failure can
    # be reproduced exactly instead of depending on OS entropy.
    rng = np.random.RandomState(42)
    X = rng.normal(size=(n_trials, n_channels, n_times))
    z = rng.normal(size=(n_trials, n_atoms, n_times - n_times_atom + 1))

    constants = _get_d_update_constants(X, z)

    def objective(uv):
        """Return half the squared residual of X against its reconstruction."""
        X_hat = construct_X_multi(z, D=uv, n_channels=n_channels)
        res = X - X_hat
        return .5 * np.sum(res * res)

    # Compare both computations on a handful of random dictionaries.
    for _ in range(5):
        uv = rng.normal(size=(n_atoms, n_channels + n_times_atom))

        cost_fast = compute_objective(D=uv, constants=constants)
        cost_full = objective(uv)
        assert np.isclose(cost_full, cost_fast)
Ejemplo n.º 3
0
 def func(d0):
     """Evaluate the objective for the dictionary stored in ``d0``.

     ``d0`` is reshaped to ``(n_atoms, n_channels, n_times_atom)``, the
     signal is rebuilt from the enclosing ``z`` with it, and the cost of
     that reconstruction against ``X`` is returned.
     """
     atoms = d0.reshape(n_atoms, n_channels, n_times_atom)
     reconstruction = construct_X_multi(z, D=atoms)
     return compute_objective(
         X, reconstruction, loss=loss, loss_params=loss_params)
Ejemplo n.º 4
0
 def objective(uv):
     """Return the ``'l2'`` objective of ``X`` for the dictionary ``uv``.

     The signal is reconstructed from the enclosing ``z`` and ``uv`` before
     being passed to ``compute_objective``.
     """
     reconstruction = construct_X_multi(z, D=uv, n_channels=n_channels)
     l2_cost = compute_objective(X, reconstruction, loss='l2')
     return l2_cost
Ejemplo n.º 5
0
 def func(uv0):
     """Evaluate the objective for the dictionary stored in ``uv0``.

     ``uv0`` is reshaped to ``(n_atoms, n_channels + n_times_atom)``, the
     signal is rebuilt from the enclosing ``z`` with it, and the cost of
     that reconstruction against ``X`` is returned.
     """
     uv = uv0.reshape(n_atoms, n_channels + n_times_atom)
     reconstruction = construct_X_multi(z, D=uv, n_channels=n_channels)
     return compute_objective(
         X, reconstruction, loss=loss, loss_params=loss_params)