def test_update_n_basis_modes_unfit_basis(basis, data_binary_classification):
    """Updating the mode count on a never-fit model yields that many modes.

    The model wraps a freshly constructed ``basis()`` and ``fit`` is never
    called before ``update_n_basis_modes``. The training data is handed to
    the update through ``xy``. Afterwards ``basis_matrix_inverse_`` must have
    exactly ``target_modes`` rows.
    """
    features, labels, _ = data_binary_classification
    target_modes = 5
    xy = (features, labels)

    model = SSPOC(basis=basis())
    model.update_n_basis_modes(target_modes, xy=xy, quiet=True)

    n_rows = model.basis_matrix_inverse_.shape[0]
    assert n_rows == target_modes
def test_update_n_basis_modes_shape(basis, data_binary_classification):
    """Shrinking the mode count after fitting only trims the inverse matrix.

    After fitting with ``initial_modes`` modes, both ``basis.n_basis_modes``
    and the row count of ``basis_matrix_inverse_`` equal that value. Lowering
    the count to ``reduced_modes`` must shrink ``basis_matrix_inverse_`` to
    match, while ``basis.n_basis_modes`` keeps its original value.
    """
    features, labels, _ = data_binary_classification
    initial_modes = 10
    reduced_modes = 5
    xy = (features, labels)

    model = SSPOC(basis=basis(n_basis_modes=initial_modes))
    model.fit(features, labels, quiet=True)
    assert model.basis.n_basis_modes == initial_modes
    assert model.basis_matrix_inverse_.shape[0] == initial_modes

    model.update_n_basis_modes(reduced_modes, xy=xy, quiet=True)
    # NOTE(review): the basis keeps its original mode count here. Presumably
    # it is not refit when the requested count is smaller; confirm in SSPOC.
    assert model.basis.n_basis_modes == initial_modes
    assert model.basis_matrix_inverse_.shape[0] == reduced_modes
def test_update_n_basis_modes_errors(basis, data_binary_classification):
    """Invalid mode counts passed to ``update_n_basis_modes`` raise ValueError.

    The model is first fit with a 5-mode basis. Three rejected values are
    then tried in turn:

    - zero;
    - the string ``"5"`` (not an int);
    - one more than the length of the data's first axis.

    Each must raise ``ValueError``.
    """
    features, labels, _ = data_binary_classification
    model = SSPOC(basis=basis(n_basis_modes=5))
    model.fit(features, labels, quiet=True)

    invalid_counts = (0, "5", features.shape[0] + 1)
    for bad_count in invalid_counts:
        with pytest.raises(ValueError):
            model.update_n_basis_modes(bad_count, xy=(features, labels))