Esempio n. 1
0
def test_get_encoder_for_error_reg(X, D_hat):
    """Check that ``reg=None`` is rejected by ``get_z_encoder_for``."""
    encoder_kwargs = dict(X=X, D_hat=D_hat, n_atoms=N_ATOMS,
                          n_times_atom=N_TIMES_ATOM, n_jobs=2)
    expected_msg = "reg value cannot be None."

    with pytest.raises(AssertionError, match=expected_msg):
        get_z_encoder_for(reg=None, **encoder_kwargs)
Esempio n. 2
0
def test_get_encoder_for_error_solver_kwargs(X, D_hat):
    """Check that ``solver_kwargs=None`` is rejected."""
    encoder_kwargs = dict(X=X, D_hat=D_hat, n_atoms=N_ATOMS,
                          n_times_atom=N_TIMES_ATOM, n_jobs=2)

    with pytest.raises(AssertionError, match=".*solver_kwargs should.*"):
        get_z_encoder_for(solver_kwargs=None, **encoder_kwargs)
Esempio n. 3
0
def test_get_encoder_for_error_D_hat(X, D_init):
    """Check that a ``D_hat`` of unexpected shape is rejected.

    The ``D_init`` fixture is passed in place of ``D_hat``; presumably its
    shape differs from what the encoder expects — confirm against the
    fixture definition.
    """
    expected_msg = "D_hat should be a valid array of shape.*"

    with pytest.raises(AssertionError, match=expected_msg):
        get_z_encoder_for(X=X, D_hat=D_init, n_atoms=N_ATOMS,
                          n_times_atom=N_TIMES_ATOM, n_jobs=2)
Esempio n. 4
0
def test_get_encoder_for_dicodile_error_rank1(X, D_hat, requires_dicodile):
    """Check that the dicodile solver rejects this ``D_hat``.

    NOTE(review): no ``match`` is given, so any AssertionError raised by
    ``get_z_encoder_for`` satisfies this test. For example, if the default
    ``feasible_evaluation`` is True, the dicodile backend may reject that
    instead of the rank-1 dictionary. Consider pinning the expected message.
    """
    encoder_kwargs = dict(X=X, D_hat=D_hat, n_atoms=N_ATOMS,
                          n_times_atom=N_TIMES_ATOM, n_jobs=2)

    with pytest.raises(AssertionError):
        get_z_encoder_for(solver='dicodile', **encoder_kwargs)
Esempio n. 5
0
def test_get_encoder_for_error_uv_constraint(X, D_hat, uv_constraint):
    """Check that an unrecognized ``uv_constraint`` is rejected."""
    encoder_kwargs = dict(X=X, D_hat=D_hat, n_atoms=N_ATOMS,
                          n_times_atom=N_TIMES_ATOM, n_jobs=2)

    with pytest.raises(AssertionError,
                       match="unrecognized uv_constraint type.*"):
        get_z_encoder_for(uv_constraint=uv_constraint, **encoder_kwargs)
Esempio n. 6
0
def test_get_encoder_for_error_loss_params(X, D_hat):
    """Check that a non-dict, non-None ``loss_params`` is rejected."""
    bad_loss_params = 42
    expected_msg = "loss_params should be a valid dict or None."

    with pytest.raises(AssertionError, match=expected_msg):
        get_z_encoder_for(X=X, D_hat=D_hat, n_atoms=N_ATOMS,
                          n_times_atom=N_TIMES_ATOM,
                          loss_params=bad_loss_params, n_jobs=2)
Esempio n. 7
0
def test_get_encoder_for_error_loss(X, D_hat, loss):
    """Check that an unrecognized ``loss`` is rejected, naming the value."""
    expected_msg = f"unrecognized loss type: {loss}."
    encoder_kwargs = dict(X=X, D_hat=D_hat, n_atoms=N_ATOMS,
                          n_times_atom=N_TIMES_ATOM, n_jobs=2)

    with pytest.raises(AssertionError, match=expected_msg):
        get_z_encoder_for(loss=loss, **encoder_kwargs)
Esempio n. 8
0
def test_get_encoder_for_error_solver(X, D_hat, solver):
    """Check that an unrecognized ``solver`` raises ValueError.

    Unlike the other argument checks, which raise AssertionError, the
    solver check is expected to raise ValueError.
    """
    expected_msg = f"unrecognized solver type: {solver}."
    encoder_kwargs = dict(X=X, D_hat=D_hat, n_atoms=N_ATOMS,
                          n_times_atom=N_TIMES_ATOM, n_jobs=2)

    with pytest.raises(ValueError, match=expected_msg):
        get_z_encoder_for(solver=solver, **encoder_kwargs)
Esempio n. 9
0
def test_get_encoder_for_dicodile_error_n_trials(solver, X, D_hat,
                                                 requires_dicodile):
    """Check that the dicodile backend rejects this ``X``.

    The expected failure concerns the shape of ``X``. Presumably the ``X``
    fixture holds more trials than the dicodile backend accepts — confirm
    against the fixture and the backend.
    """
    # ``pytest.raises(match=...)`` uses ``re.search``: a bare ``shape*`` means
    # "shap" followed by zero or more "e", so use ``.*`` for "anything after".
    with pytest.raises(AssertionError,
                       match="X should be a valid array of shape.*"):
        get_z_encoder_for(solver=solver,
                          X=X,
                          D_hat=D_hat,
                          n_atoms=N_ATOMS,
                          n_times_atom=N_TIMES_ATOM,
                          n_jobs=2)
Esempio n. 10
0
def test_get_encoder_for_dicodile_error_feasible_ev(solver, X, D_hat,
                                                    requires_dicodile):
    """Check that the dicodile backend rejects ``feasible_evaluation=True``."""
    encoder_kwargs = dict(solver=solver, X=X, D_hat=D_hat, n_atoms=N_ATOMS,
                          n_times_atom=N_TIMES_ATOM, n_jobs=2)
    expected_msg = "DiCoDiLe requires feasible_evaluation=False."

    with pytest.raises(AssertionError, match=expected_msg):
        get_z_encoder_for(feasible_evaluation=True, **encoder_kwargs)
Esempio n. 11
0
def test_get_encoder_for_dicodile_error_uv_constraint(solver, X, D_hat,
                                                      requires_dicodile):
    """Check that the dicodile backend rejects a non-'auto' uv_constraint."""
    encoder_kwargs = dict(solver=solver, X=X, D_hat=D_hat, n_atoms=N_ATOMS,
                          n_times_atom=N_TIMES_ATOM, n_jobs=2)
    expected_msg = "DiCoDiLe requires uv_constraint=auto."

    with pytest.raises(AssertionError, match=expected_msg):
        get_z_encoder_for(uv_constraint='separate', **encoder_kwargs)
Esempio n. 12
0
def test_get_encoder_for_dicodile_error_loss_params(solver, X, D_hat,
                                                    requires_dicodile):
    """Check that the dicodile backend rejects any non-None loss_params.

    Even an empty dict is expected to be refused.
    """
    encoder_kwargs = dict(solver=solver, X=X, D_hat=D_hat, n_atoms=N_ATOMS,
                          n_times_atom=N_TIMES_ATOM, n_jobs=2)
    expected_msg = "DiCoDiLe requires loss_params=None."

    with pytest.raises(AssertionError, match=expected_msg):
        get_z_encoder_for(loss_params={}, **encoder_kwargs)
Esempio n. 13
0
def test_get_encoder_for_dicodile_error_loss(solver, X, D_hat, loss,
                                             requires_dicodile):
    """Check that the dicodile backend rejects a non-l2 ``loss``.

    The error message is expected to echo the rejected loss value.
    """
    # Build the expected message from the ``loss`` fixture rather than
    # hard-coding 'dtw', so the test stays valid for any rejected loss.
    # The parentheses are escaped because ``match`` is a regular expression.
    expected_msg = f"DiCoDiLe requires a l2 loss \\('{loss}' passed\\)."

    with pytest.raises(AssertionError, match=expected_msg):
        get_z_encoder_for(solver=solver,
                          X=X,
                          D_hat=D_hat,
                          loss=loss,
                          n_atoms=N_ATOMS,
                          n_times_atom=N_TIMES_ATOM,
                          n_jobs=2)
Esempio n. 14
0
def test_get_cost(solver, X, D_hat, requires_dicodile):
    """Check ``get_cost`` before and after ``compute_z``.

    The cost must decrease once ``compute_z`` has run. The final cost must
    also agree with ``compute_objective`` evaluated on the signal rebuilt
    from the estimated activations.
    """
    # Use one regularization value for both the encoder and the reference
    # objective, instead of relying on the encoder's default matching 0.1.
    reg = 0.1

    with get_z_encoder_for(solver=solver,
                           X=X,
                           D_hat=D_hat,
                           n_atoms=N_ATOMS,
                           n_times_atom=N_TIMES_ATOM,
                           reg=reg,
                           n_jobs=2) as z_encoder:
        initial_cost = z_encoder.get_cost()

        z_encoder.compute_z()
        z_hat = z_encoder.get_z_hat()
        final_cost = z_encoder.get_cost()

        assert final_cost < initial_cost

        # Reference objective computed from the reconstruction of X.
        X_hat = construct_X_multi(z_hat, D_hat, n_channels=N_CHANNELS)
        cost = compute_objective(X=X,
                                 X_hat=X_hat,
                                 z_hat=z_hat,
                                 reg=reg,
                                 D=D_hat)

        assert np.isclose(cost, final_cost)
Esempio n. 15
0
def test_get_encoder_for_dicodile(X, D_hat, solver, requires_dicodile):
    """Check that an encoder can be built for ``solver`` with default options."""
    encoder_kwargs = dict(solver=solver, X=X, D_hat=D_hat, n_atoms=N_ATOMS,
                          n_times_atom=N_TIMES_ATOM, n_jobs=2)

    with get_z_encoder_for(**encoder_kwargs) as encoder:
        assert encoder is not None
Esempio n. 16
0
def test_compute_z(solver, X, D_hat, requires_dicodile):
    """Check that ``compute_z`` produces at least one non-zero activation."""
    encoder_kwargs = dict(solver=solver, X=X, D_hat=D_hat, n_atoms=N_ATOMS,
                          n_times_atom=N_TIMES_ATOM, n_jobs=2)

    with get_z_encoder_for(**encoder_kwargs) as encoder:
        encoder.compute_z()
        activations = encoder.get_z_hat()
        assert activations.any()
Esempio n. 17
0
def test_add_one_atom(X, D_hat):
    """Check that ``add_one_atom`` grows the dictionary by exactly one atom."""
    with get_z_encoder_for(X=X, D_hat=D_hat, n_atoms=N_ATOMS,
                           n_times_atom=N_TIMES_ATOM, n_jobs=2) as encoder:
        # The atom has length n_channels + n_times_atom. This presumably
        # matches a rank-1 (u, v) layout of D_hat — confirm with the fixture.
        atom = np.random.rand(N_CHANNELS + N_TIMES_ATOM)
        encoder.add_one_atom(atom)

        assert encoder.D_hat.shape[0] == N_ATOMS + 1
Esempio n. 18
0
def test_compute_z_partial(X, D_hat, n_trials, rng):
    """Check that ``compute_z_partial`` on one random trial gives activations."""
    encoder_kwargs = dict(X=X, D_hat=D_hat, n_atoms=N_ATOMS,
                          n_times_atom=N_TIMES_ATOM, n_jobs=2)

    with get_z_encoder_for(**encoder_kwargs) as encoder:
        # Draw a single trial index to encode.
        trial_idx = rng.choice(n_trials, 1, replace=False)
        encoder.compute_z_partial(trial_idx)

        assert encoder.get_z_hat().any()
Esempio n. 19
0
def test_get_sufficient_statistics_partial_error(X, D_hat):
    """Check that partial statistics require a prior ``compute_z_partial``."""
    # Use the encoder as a context manager, like the other tests, so that
    # it is released even when the assertion below fails.
    with get_z_encoder_for(X=X,
                           D_hat=D_hat,
                           n_atoms=N_ATOMS,
                           n_times_atom=N_TIMES_ATOM,
                           n_jobs=2) as z_encoder:

        # compute_z_partial has not been called yet.
        with pytest.raises(AssertionError,
                           match="compute_z_partial should be called.*"):
            z_encoder.get_sufficient_statistics_partial()
Esempio n. 20
0
def test_get_sufficient_statistics_error(solver, X, D_hat, requires_dicodile):
    """Check that sufficient statistics require a prior ``compute_z``."""
    # Use the encoder as a context manager so that its resources (e.g. the
    # dicodile workers for n_jobs=2) are released even if the test fails.
    with get_z_encoder_for(solver=solver,
                           X=X,
                           D_hat=D_hat,
                           n_atoms=N_ATOMS,
                           n_times_atom=N_TIMES_ATOM,
                           n_jobs=2) as z_encoder:

        # compute_z has not been called yet.
        with pytest.raises(AssertionError,
                           match="compute_z should be called.*"):
            z_encoder.get_sufficient_statistics()
Esempio n. 21
0
def test_get_sufficient_statistics_partial(X, D_hat, n_trials, rng):
    """Check that partial statistics are available after ``compute_z_partial``."""
    # Use the encoder as a context manager, like the other tests, so that
    # it is released even when an assertion fails.
    with get_z_encoder_for(X=X,
                           D_hat=D_hat,
                           n_atoms=N_ATOMS,
                           n_times_atom=N_TIMES_ATOM,
                           n_jobs=2) as z_encoder:

        # Encode a single randomly chosen trial.
        i0 = rng.choice(n_trials, 1, replace=False)
        z_encoder.compute_z_partial(i0)

        ztz_i0, ztX_i0 = z_encoder.get_sufficient_statistics_partial()
        assert ztz_i0 is not None and ztX_i0 is not None
Esempio n. 22
0
def test_get_encoder_for_alphacsc(X, solver, D_hat, loss, uv_constraint,
                                  feasible_evaluation):
    """Check that an encoder can be built for each valid option combination."""
    encoder_kwargs = dict(
        solver=solver,
        X=X,
        D_hat=D_hat,
        n_atoms=N_ATOMS,
        n_times_atom=N_TIMES_ATOM,
        loss=loss,
        uv_constraint=uv_constraint,
        feasible_evaluation=feasible_evaluation,
        n_jobs=2,
    )

    with get_z_encoder_for(**encoder_kwargs) as encoder:
        assert encoder is not None
Esempio n. 23
0
def test_get_sufficient_statistics(solver, X, D_hat, requires_dicodile):
    """Check the encoder's sufficient statistics against direct computation.

    After ``compute_z``, ``get_sufficient_statistics`` must return values
    close to ``compute_ztz`` and ``compute_ztX`` evaluated on the estimated
    activations.
    """
    # Use the encoder as a context manager so that its resources (e.g. the
    # dicodile workers for n_jobs=2) are released even if an assertion fails.
    with get_z_encoder_for(solver=solver,
                           X=X,
                           D_hat=D_hat,
                           n_atoms=N_ATOMS,
                           n_times_atom=N_TIMES_ATOM,
                           n_jobs=2) as z_encoder:

        z_encoder.compute_z()
        z_hat = z_encoder.get_z_hat()

        ztz, ztX = z_encoder.get_sufficient_statistics()
        assert ztz is not None and np.allclose(ztz, compute_ztz(
            z_hat, N_TIMES_ATOM))

        assert ztX is not None and np.allclose(ztX, compute_ztX(z_hat, X))