def test_one_hot_encoder_handle_unknown():
    """Check how each ``handle_unknown`` setting treats unseen categories."""
    X_train = np.array([[0, 2, 1], [1, 0, 3], [1, 0, 2]])
    # The first two columns hold values (4 and 1) absent from X_train.
    X_unseen = np.array([[4, 1, 1]])

    # 'error': transforming data with unknown categories must raise.
    strict_encoder = OneHotEncoder(handle_unknown='error')
    strict_encoder.fit(X_train)
    with pytest.raises(ValueError, match='Found unknown categories'):
        strict_encoder.transform(X_unseen)

    # 'ignore': unknown categories are encoded as all zeros for that feature.
    lenient_encoder = OneHotEncoder(handle_unknown='ignore')
    lenient_encoder.fit(X_train)
    X_unseen_copy = X_unseen.copy()
    encoded = lenient_encoder.transform(X_unseen_copy).toarray()
    expected = np.array([[0., 0., 0., 0., 1., 0., 0.]])
    assert_array_equal(encoded, expected)
    # The array handed to transform must not have been altered in place.
    assert_allclose(X_unseen, X_unseen_copy)

    # Any value other than 'error' or 'ignore' is rejected when fitting.
    invalid_encoder = OneHotEncoder(handle_unknown='42')
    with pytest.raises(ValueError, match='handle_unknown should be either'):
        invalid_encoder.fit(X_train)
def test_one_hot_encoder_not_fitted():
    """Calling ``transform`` before ``fit`` must raise ``NotFittedError``."""
    data = np.array([['a'], ['b']])
    unfitted_encoder = OneHotEncoder(categories=['a', 'b'])
    expected_msg = (
        "This OneHotEncoder instance is not fitted yet. "
        "Call 'fit' with appropriate arguments before using this "
        "estimator."
    )
    # No fit() call on purpose: transform is the first method invoked.
    with pytest.raises(NotFittedError, match=expected_msg):
        unfitted_encoder.transform(data)
def test_one_hot_encoder_diff_n_features():
    """Transforming data whose column count differs from fit must fail."""
    X_fit = np.array([[0, 2, 1], [1, 0, 3], [1, 0, 2]])  # three columns
    X_narrow = np.array([[1, 0]])  # only two columns

    encoder = OneHotEncoder()
    encoder.fit(X_fit)

    expected_msg = (
        "The number of features in X is different to the number of "
        "features of the fitted data."
    )
    with pytest.raises(ValueError, match=expected_msg):
        encoder.transform(X_narrow)
def test_encoder_dtypes_pandas():
    """Check that fitted category dtypes follow the DataFrame column dtypes.

    DataFrame counterpart of the array-based dtype test (the original
    comment refers to it as ``test_categorical_encoder_dtypes``). The test
    is skipped when pandas is not installed.
    """
    pd = pytest.importorskip('pandas')

    enc = OneHotEncoder(categories='auto')
    # Three features with two categories each -> six output columns.
    exp = np.array([[1., 0., 1., 0., 1., 0.],
                    [0., 1., 0., 1., 0., 1.]], dtype='float64')

    # Homogeneous integer frame: every feature's categories stay int64.
    X = pd.DataFrame({'A': [1, 2], 'B': [3, 4], 'C': [5, 6]}, dtype='int64')
    enc.fit(X)
    # Check all three features; the index range used previously stopped
    # after the second one and left column 'C' unverified.
    assert len(enc.categories_) == X.shape[1]
    assert all(cats.dtype == 'int64' for cats in enc.categories_)
    assert_array_equal(enc.transform(X).toarray(), exp)

    # Heterogeneous frame: each feature keeps its own column dtype.
    X = pd.DataFrame({'A': [1, 2], 'B': ['a', 'b'], 'C': [3., 4.]})
    X_type = [X['A'].dtype, X['B'].dtype, X['C'].dtype]
    enc.fit(X)
    assert len(enc.categories_) == len(X_type)
    assert all(cats.dtype == col_dtype
               for cats, col_dtype in zip(enc.categories_, X_type))
    assert_array_equal(enc.transform(X).toarray(), exp)
def test_one_hot_encoder_raise_missing(X, as_data_frame, handle_unknown):
    """NaN input must be rejected by ``fit``, ``fit_transform``, ``transform``.

    NOTE(review): the arguments are presumably supplied by parametrization
    declared outside this block -- confirm against the decorators.

    Parameters
    ----------
    X : 2-D array-like
        Data expected to contain NaN; its first row is presumably NaN-free
        because the encoder is fitted on it without expecting an error.
    as_data_frame : bool
        When true, ``X`` is wrapped in a pandas DataFrame first (the test
        is skipped if pandas is unavailable).
    handle_unknown : str
        Forwarded unchanged to ``OneHotEncoder``.
    """
    if as_data_frame:
        pd = pytest.importorskip('pandas')
        X = pd.DataFrame(X)

    nan_error = "Input contains NaN"
    encoder = OneHotEncoder(categories='auto', handle_unknown=handle_unknown)

    with pytest.raises(ValueError, match=nan_error):
        encoder.fit(X)

    with pytest.raises(ValueError, match=nan_error):
        encoder.fit_transform(X)

    # Fit on the first row only, then make sure transform still validates
    # the full input.
    first_row = X.iloc[:1, :] if as_data_frame else X[:1, :]
    encoder.fit(first_row)

    with pytest.raises(ValueError, match=nan_error):
        encoder.transform(X)
def test_encoder_dtypes():
    """Check that input dtypes are preserved when determining categories."""
    encoder = OneHotEncoder(categories='auto')
    expected = np.array([[1., 0., 1., 0.],
                         [0., 1., 0., 1.]], dtype='float64')

    # ndarray inputs: each feature's categories keep the array's dtype.
    array_inputs = [
        np.array([[1, 2], [3, 4]], dtype='int64'),
        np.array([[1, 2], [3, 4]], dtype='float64'),
        np.array([['a', 'b'], ['c', 'd']]),  # string dtype
        np.array([[1, 'a'], [3, 'b']], dtype='object'),
    ]
    for data in array_inputs:
        encoder.fit(data)
        for feature_idx in range(2):
            assert encoder.categories_[feature_idx].dtype == data.dtype
        assert_array_equal(encoder.transform(data).toarray(), expected)

    # Plain list of ints: categories end up with some integer dtype.
    int_rows = [[1, 2], [3, 4]]
    encoder.fit(int_rows)
    for feature_idx in range(2):
        assert np.issubdtype(encoder.categories_[feature_idx].dtype,
                             np.integer)
    assert_array_equal(encoder.transform(int_rows).toarray(), expected)

    # Plain list mixing ints and strings: categories are object dtype.
    mixed_rows = [[1, 'a'], [3, 'b']]
    encoder.fit(mixed_rows)
    for feature_idx in range(2):
        assert encoder.categories_[feature_idx].dtype == 'object'
    assert_array_equal(encoder.transform(mixed_rows).toarray(), expected)
def test_one_hot_encoder_handle_unknown_strings():
    """Non-regression test for issue #12470.

    Covers ``handle_unknown='ignore'`` when the categories are of numpy
    string dtype, in particular when the known category strings are
    longer than the unknown ones.
    """
    known = np.array(['11111111', '22', '333', '4444']).reshape((-1, 1))
    # '55555' was never seen during fit; '22' is a known category.
    queries = np.array(['55555', '22']).reshape((-1, 1))

    encoder = OneHotEncoder(handle_unknown='ignore')
    encoder.fit(known)

    queries_copy = queries.copy()
    encoded = encoder.transform(queries_copy).toarray()
    expected = np.array([[0., 0., 0., 0.],
                         [0., 1., 0., 0.]])
    assert_array_equal(encoded, expected)
    # The array handed to transform must not have been altered in place.
    assert_array_equal(queries, queries_copy)