def test_get_encoder_for_error_reg(X, D_hat):
    """``get_z_encoder_for`` must refuse ``reg=None`` with an AssertionError."""
    encoder_kwargs = dict(
        X=X,
        D_hat=D_hat,
        n_atoms=N_ATOMS,
        n_times_atom=N_TIMES_ATOM,
        reg=None,
        n_jobs=2,
    )
    with pytest.raises(AssertionError, match="reg value cannot be None."):
        get_z_encoder_for(**encoder_kwargs)
def test_get_encoder_for_error_solver_kwargs(X, D_hat):
    """Passing ``solver_kwargs=None`` must trip the solver_kwargs assertion."""
    with pytest.raises(AssertionError, match=".*solver_kwargs should.*"):
        get_z_encoder_for(
            X=X,
            D_hat=D_hat,
            n_atoms=N_ATOMS,
            n_times_atom=N_TIMES_ATOM,
            solver_kwargs=None,
            n_jobs=2,
        )
def test_get_encoder_for_error_D_hat(X, D_init):
    """Using the ``D_init`` fixture as ``D_hat`` must fail the D_hat shape check."""
    # NOTE(review): presumably D_init has a shape that is incompatible with
    # N_ATOMS / N_TIMES_ATOM — confirm against the fixture definition.
    bad_dictionary = D_init
    with pytest.raises(AssertionError,
                       match="D_hat should be a valid array of shape.*"):
        get_z_encoder_for(
            X=X,
            D_hat=bad_dictionary,
            n_atoms=N_ATOMS,
            n_times_atom=N_TIMES_ATOM,
            n_jobs=2,
        )
def test_get_encoder_for_dicodile_error_rank1(X, D_hat, requires_dicodile):
    """The dicodile backend must reject the rank-1 configuration."""
    # No ``rank1`` argument is passed, so this relies on the default of
    # get_z_encoder_for (presumably rank1=True) — confirm.
    encoder_kwargs = dict(
        solver='dicodile',
        X=X,
        D_hat=D_hat,
        n_atoms=N_ATOMS,
        n_times_atom=N_TIMES_ATOM,
        n_jobs=2,
    )
    with pytest.raises(AssertionError):
        get_z_encoder_for(**encoder_kwargs)
def test_get_encoder_for_error_uv_constraint(X, D_hat, uv_constraint):
    """An unsupported ``uv_constraint`` value must be rejected."""
    encoder_kwargs = dict(
        X=X,
        D_hat=D_hat,
        uv_constraint=uv_constraint,
        n_atoms=N_ATOMS,
        n_times_atom=N_TIMES_ATOM,
        n_jobs=2,
    )
    with pytest.raises(AssertionError,
                       match="unrecognized uv_constraint type.*"):
        get_z_encoder_for(**encoder_kwargs)
def test_get_encoder_for_error_loss_params(X, D_hat):
    """A non-dict, non-None ``loss_params`` (here an int) must be rejected."""
    not_a_dict = 42
    with pytest.raises(AssertionError,
                       match="loss_params should be a valid dict or None."):
        get_z_encoder_for(
            X=X,
            D_hat=D_hat,
            loss_params=not_a_dict,
            n_atoms=N_ATOMS,
            n_times_atom=N_TIMES_ATOM,
            n_jobs=2,
        )
def test_get_encoder_for_error_loss(X, D_hat, loss):
    """An unknown ``loss`` value must be rejected, with the value in the message."""
    expected_msg = f"unrecognized loss type: {loss}."
    with pytest.raises(AssertionError, match=expected_msg):
        get_z_encoder_for(
            X=X,
            D_hat=D_hat,
            loss=loss,
            n_atoms=N_ATOMS,
            n_times_atom=N_TIMES_ATOM,
            n_jobs=2,
        )
def test_get_encoder_for_error_solver(X, D_hat, solver):
    """An unknown ``solver`` must raise ValueError naming the bad solver."""
    expected_msg = f"unrecognized solver type: {solver}."
    encoder_kwargs = dict(
        solver=solver,
        X=X,
        D_hat=D_hat,
        n_atoms=N_ATOMS,
        n_times_atom=N_TIMES_ATOM,
        n_jobs=2,
    )
    with pytest.raises(ValueError, match=expected_msg):
        get_z_encoder_for(**encoder_kwargs)
def test_get_encoder_for_dicodile_error_n_trials(solver, X, D_hat,
                                                  requires_dicodile):
    """Test for invalid n_trials value for dicodile backend.

    ``n_trials`` is not passed explicitly; the invalid value is presumably
    carried by the ``X`` fixture (e.g. more than one trial) — confirm
    against the fixture parametrization.
    """
    # ``shape.*`` (not ``shape*``): the original pattern meant "shap" followed
    # by zero or more "e", a regex typo. This matches the D_hat check in
    # test_get_encoder_for_error_D_hat.
    with pytest.raises(AssertionError,
                       match="X should be a valid array of shape.*"):
        get_z_encoder_for(
            solver=solver,
            X=X,
            D_hat=D_hat,
            n_atoms=N_ATOMS,
            n_times_atom=N_TIMES_ATOM,
            n_jobs=2,
        )
def test_get_encoder_for_dicodile_error_feasible_ev(solver, X, D_hat,
                                                    requires_dicodile):
    """The dicodile backend must refuse ``feasible_evaluation=True``."""
    encoder_kwargs = dict(
        solver=solver,
        X=X,
        D_hat=D_hat,
        feasible_evaluation=True,
        n_atoms=N_ATOMS,
        n_times_atom=N_TIMES_ATOM,
        n_jobs=2,
    )
    with pytest.raises(AssertionError,
                       match="DiCoDiLe requires feasible_evaluation=False."):
        get_z_encoder_for(**encoder_kwargs)
def test_get_encoder_for_dicodile_error_uv_constraint(solver, X, D_hat,
                                                      requires_dicodile):
    """The dicodile backend must refuse any ``uv_constraint`` other than auto."""
    non_auto_constraint = 'separate'
    with pytest.raises(AssertionError,
                       match="DiCoDiLe requires uv_constraint=auto."):
        get_z_encoder_for(
            solver=solver,
            X=X,
            D_hat=D_hat,
            uv_constraint=non_auto_constraint,
            n_atoms=N_ATOMS,
            n_times_atom=N_TIMES_ATOM,
            n_jobs=2,
        )
def test_get_encoder_for_dicodile_error_loss_params(solver, X, D_hat,
                                                    requires_dicodile):
    """The dicodile backend must refuse a non-None ``loss_params`` (even ``{}``)."""
    encoder_kwargs = dict(
        solver=solver,
        X=X,
        D_hat=D_hat,
        loss_params={},
        n_atoms=N_ATOMS,
        n_times_atom=N_TIMES_ATOM,
        n_jobs=2,
    )
    with pytest.raises(AssertionError,
                       match="DiCoDiLe requires loss_params=None."):
        get_z_encoder_for(**encoder_kwargs)
def test_get_encoder_for_dicodile_error_loss(solver, X, D_hat, loss,
                                             requires_dicodile):
    """Test for invalid loss value for dicodile backend.

    DiCoDiLe only supports the l2 loss. The expected message is built from
    the ``loss`` fixture instead of hard-coding ``'dtw'``, so the test stays
    correct if ``loss`` is parametrized with other non-l2 values. For
    ``loss='dtw'`` the pattern is identical to the previous hard-coded one.
    """
    import re

    # Escape the loss name so any regex metacharacters in it match literally.
    expected_msg = (
        rf"DiCoDiLe requires a l2 loss \('{re.escape(loss)}' passed\)."
    )
    with pytest.raises(AssertionError, match=expected_msg):
        get_z_encoder_for(
            solver=solver,
            X=X,
            D_hat=D_hat,
            loss=loss,
            n_atoms=N_ATOMS,
            n_times_atom=N_TIMES_ATOM,
            n_jobs=2,
        )
def test_get_cost(solver, X, D_hat, requires_dicodile):
    """Solving for z must lower the cost, and the reported cost must match
    an independently recomputed objective."""
    encoder_kwargs = dict(
        solver=solver,
        X=X,
        D_hat=D_hat,
        n_atoms=N_ATOMS,
        n_times_atom=N_TIMES_ATOM,
        n_jobs=2,
    )
    with get_z_encoder_for(**encoder_kwargs) as z_encoder:
        cost_before = z_encoder.get_cost()
        z_encoder.compute_z()
        z_hat = z_encoder.get_z_hat()
        cost_after = z_encoder.get_cost()
        assert cost_after < cost_before

        # Rebuild the signal from z_hat and recompute the objective.
        # NOTE(review): reg=0.1 presumably mirrors the encoder's default
        # regularization — confirm against get_z_encoder_for.
        reconstruction = construct_X_multi(z_hat, D_hat,
                                           n_channels=N_CHANNELS)
        reference_cost = compute_objective(X=X, X_hat=reconstruction,
                                           z_hat=z_hat, reg=0.1, D=D_hat)
        assert np.isclose(reference_cost, cost_after)
def test_get_encoder_for_dicodile(X, D_hat, solver, requires_dicodile):
    """A valid configuration must yield an encoder (dicodile backend)."""
    encoder_kwargs = dict(
        solver=solver,
        X=X,
        D_hat=D_hat,
        n_atoms=N_ATOMS,
        n_times_atom=N_TIMES_ATOM,
        n_jobs=2,
    )
    with get_z_encoder_for(**encoder_kwargs) as z_encoder:
        assert z_encoder is not None
def test_compute_z(solver, X, D_hat, requires_dicodile):
    """After ``compute_z`` the activations must contain a non-zero entry."""
    with get_z_encoder_for(
        solver=solver,
        X=X,
        D_hat=D_hat,
        n_atoms=N_ATOMS,
        n_times_atom=N_TIMES_ATOM,
        n_jobs=2,
    ) as z_encoder:
        z_encoder.compute_z()
        activations = z_encoder.get_z_hat()
        assert activations.any()
def test_add_one_atom(X, D_hat):
    """``add_one_atom`` must grow the encoder's dictionary by exactly one atom."""
    encoder_kwargs = dict(
        X=X,
        D_hat=D_hat,
        n_atoms=N_ATOMS,
        n_times_atom=N_TIMES_ATOM,
        n_jobs=2,
    )
    with get_z_encoder_for(**encoder_kwargs) as z_encoder:
        # Atom of length N_CHANNELS + N_TIMES_ATOM; presumably a rank-1
        # atom stored as the concatenation [u, v] — confirm.
        atom_length = N_CHANNELS + N_TIMES_ATOM
        z_encoder.add_one_atom(np.random.rand(atom_length))
        assert z_encoder.D_hat.shape[0] == N_ATOMS + 1
def test_compute_z_partial(X, D_hat, n_trials, rng):
    """Updating z on one random trial must produce non-zero activations."""
    with get_z_encoder_for(
        X=X,
        D_hat=D_hat,
        n_atoms=N_ATOMS,
        n_times_atom=N_TIMES_ATOM,
        n_jobs=2,
    ) as z_encoder:
        # Pick a single trial index (returned as a length-1 array).
        trial_subset = rng.choice(n_trials, 1, replace=False)
        z_encoder.compute_z_partial(trial_subset)
        assert z_encoder.get_z_hat().any()
def test_get_sufficient_statistics_partial_error(X, D_hat):
    """Test for invalid call to function.

    ``get_sufficient_statistics_partial`` must assert if it is called before
    ``compute_z_partial``. The encoder is used as a context manager, like in
    the other tests of this module, so its resources (n_jobs=2) are released
    even though the expected exception is raised.
    """
    with get_z_encoder_for(
        X=X,
        D_hat=D_hat,
        n_atoms=N_ATOMS,
        n_times_atom=N_TIMES_ATOM,
        n_jobs=2,
    ) as z_encoder:
        # compute_z_partial has deliberately not been called yet.
        with pytest.raises(AssertionError,
                           match="compute_z_partial should be called.*"):
            z_encoder.get_sufficient_statistics_partial()
def test_get_sufficient_statistics_error(solver, X, D_hat, requires_dicodile):
    """Test for invalid call to function.

    ``get_sufficient_statistics`` must assert if it is called before
    ``compute_z``. The encoder is used as a context manager, like in the
    other tests of this module, so backend resources (e.g. dicodile workers)
    are released instead of leaking past the test.
    """
    with get_z_encoder_for(
        solver=solver,
        X=X,
        D_hat=D_hat,
        n_atoms=N_ATOMS,
        n_times_atom=N_TIMES_ATOM,
        n_jobs=2,
    ) as z_encoder:
        # compute_z has deliberately not been called yet.
        with pytest.raises(AssertionError,
                           match="compute_z should be called.*"):
            z_encoder.get_sufficient_statistics()
def test_get_sufficient_statistics_partial(X, D_hat, n_trials, rng):
    """Test for valid values.

    After ``compute_z_partial`` on one random trial, the partial sufficient
    statistics must be available. The encoder is used as a context manager
    so its resources (n_jobs=2) are released at the end of the test.
    """
    with get_z_encoder_for(
        X=X,
        D_hat=D_hat,
        n_atoms=N_ATOMS,
        n_times_atom=N_TIMES_ATOM,
        n_jobs=2,
    ) as z_encoder:
        i0 = rng.choice(n_trials, 1, replace=False)
        z_encoder.compute_z_partial(i0)
        ztz_i0, ztX_i0 = z_encoder.get_sufficient_statistics_partial()
        assert ztz_i0 is not None and ztX_i0 is not None
def test_get_encoder_for_alphacsc(X, solver, D_hat, loss, uv_constraint,
                                  feasible_evaluation):
    """A valid configuration must yield an encoder (alphacsc backend)."""
    encoder_kwargs = dict(
        solver=solver,
        X=X,
        D_hat=D_hat,
        n_atoms=N_ATOMS,
        n_times_atom=N_TIMES_ATOM,
        loss=loss,
        uv_constraint=uv_constraint,
        feasible_evaluation=feasible_evaluation,
        n_jobs=2,
    )
    with get_z_encoder_for(**encoder_kwargs) as z_encoder:
        assert z_encoder is not None
def test_get_sufficient_statistics(solver, X, D_hat, requires_dicodile):
    """Test for valid values.

    After ``compute_z``, the encoder's sufficient statistics must match
    ``compute_ztz`` / ``compute_ztX`` evaluated on its own ``z_hat``. The
    encoder is used as a context manager so backend resources (e.g.
    dicodile workers) are released at the end of the test.
    """
    with get_z_encoder_for(
        solver=solver,
        X=X,
        D_hat=D_hat,
        n_atoms=N_ATOMS,
        n_times_atom=N_TIMES_ATOM,
        n_jobs=2,
    ) as z_encoder:
        z_encoder.compute_z()
        z_hat = z_encoder.get_z_hat()
        ztz, ztX = z_encoder.get_sufficient_statistics()

    # Reference values are computed from z_hat alone, so they do not need
    # the encoder to still be open.
    assert ztz is not None and np.allclose(ztz,
                                           compute_ztz(z_hat, N_TIMES_ATOM))
    assert ztX is not None and np.allclose(ztX, compute_ztX(z_hat, X))