def test_robust_ordinal_encoding_inverse_transform(unknown_as_nan):
    """Check inverse_transform round-trips fitted rows and decodes an extra row to None.

    NOTE(review): a later function in this module is defined with this exact
    name and begins with these same checks. Because of that, this definition is
    rebound at import time and pytest never collects it. Consider deleting it.

    ``unknown_as_nan`` is presumably supplied by a pytest parametrize mark or a
    fixture outside this view; confirm.
    """
    encoder = RobustOrdinalEncoder(unknown_as_nan=unknown_as_nan)
    encoder.fit(ordinal_data)
    # Append one row whose values are assumed not to occur in ordinal_data.
    test_data = np.concatenate([ordinal_data, np.array([["waffle", 1213, None]])], axis=0)
    encoded = encoder.transform(test_data)
    reverse = encoder.inverse_transform(encoded)
    # Every fitted row must survive the encode/decode round trip unchanged...
    assert np.array_equal(ordinal_data, reverse[:-1])
    # ...while each value of the appended row decodes to None.
    assert all(x is None for x in reverse[-1])
def test_robust_ordinal_encoding_inverse_transform_floatkeys():
    """Check inverse_transform with float32 category values and a shifted test set.

    The fitted data is the 3x3 grid of values 0..8. Shifting it by 3 makes the
    first two test rows equal the last two fitted rows. In the final test row
    (9..11), every column holds a value that was not seen during fit.
    """
    encoder = RobustOrdinalEncoder()
    data = np.arange(9).astype(np.float32).reshape((3, 3))
    encoder.fit(data)
    test_data = data + 3
    encoded = encoder.transform(test_data)
    reverse = encoder.inverse_transform(encoded)
    # The decoded array must be object-typed; its last row holds None values.
    assert reverse.dtype == object
    assert np.array_equal(data[1:], reverse[:-1])
    assert all(x is None for x in reverse[-1])
def test_robust_ordinal_encoding_inverse_transform(unknown_as_nan):
    """Check inverse_transform round-trips fitted data and decodes unknown/rare values to None.

    The same test set is used for three encoder configurations. That set is the
    module-level ``ordinal_data`` plus one extra row.

    * No ``threshold`` argument: fitted rows round-trip exactly, and the extra
      row decodes to all None.
    * ``threshold=2``: a fixed number of values per column decode to None.
    * ``threshold=10``: every value decodes to None.

    ``unknown_as_nan`` is presumably supplied by a pytest parametrize mark or a
    fixture outside this view; confirm.

    NOTE(review): an earlier function in this module has the same name; this
    later definition replaces it at import time.
    """
    # Append one row whose values are assumed not to occur in ordinal_data.
    test_data = np.concatenate([ordinal_data, np.array([["waffle", 1213, None]])], axis=0)

    def roundtrip(**encoder_kwargs):
        # Fit on ordinal_data only, then encode and decode the extended test set.
        encoder = RobustOrdinalEncoder(unknown_as_nan=unknown_as_nan, **encoder_kwargs)
        encoder.fit(ordinal_data)
        encoded = encoder.transform(test_data)
        return encoder.inverse_transform(encoded)

    reverse = roundtrip()
    assert np.array_equal(ordinal_data, reverse[:-1])
    assert all(x is None for x in reverse[-1])

    # Test where some categories are below the threshold
    reverse = roundtrip(threshold=2)
    # Expected per-column None counts depend on the contents of ordinal_data
    # (not visible here). They presumably include the appended row.
    assert sum(i is None for i in reverse[:, 0]) == 3
    assert sum(i is None for i in reverse[:, 1]) == 2
    assert sum(i is None for i in reverse[:, 2]) == 2

    # Test where all categories are below the threshold
    reverse = roundtrip(threshold=10)
    assert all(i is None for i in reverse.flat)