def test_robust_label_encoder_fill_label_value():
    """Values outside ``labels`` are encoded as the extra code and decoded back to ``fill_label_value``."""
    raw = np.array([1, 1, 0, 1, 1])
    expected_codes = [0, 0, 1, 0, 0]

    # With include_unseen_class=True the fill label is reported after the real classes.
    encoder = RobustLabelEncoder(labels=[1], fill_label_value=0, include_unseen_class=True)
    encoder.fit(raw)
    np.testing.assert_array_equal(encoder.get_classes(), [1, 0])
    codes = encoder.transform(raw)
    np.testing.assert_array_equal(codes, expected_codes)
    np.testing.assert_array_equal(encoder.inverse_transform(codes), raw)

    # fit_transform produces the same codes and round-trips the same way; this
    # encoder leaves include_unseen_class at its default, so get_classes() does
    # not list the fill label.
    encoder = RobustLabelEncoder(labels=[1], fill_label_value=0)
    codes = encoder.fit_transform(raw)
    np.testing.assert_array_equal(encoder.get_classes(), [1])
    np.testing.assert_array_equal(codes, expected_codes)
    np.testing.assert_array_equal(encoder.inverse_transform(codes), raw)


def test_robust_label_encoder_unsorted_labels_warning(labels):
    """Unsorted ``labels`` emit a UserWarning and end up stored in sorted order.

    ``labels`` needs at least three entries (indices 0-2 are used); it is
    presumably supplied unsorted by a parametrization outside this view — confirm.
    """
    expected_classes = sorted(labels)
    unseen = "173"

    enc = RobustLabelEncoder(labels=labels)
    with pytest.warns(UserWarning):
        enc.fit([labels[2], labels[0]])
    np.testing.assert_array_equal(list(enc.classes_), expected_classes)
    np.testing.assert_array_equal(enc.get_classes(), expected_classes)
    np.testing.assert_array_equal(enc.transform([labels[1], labels[2], unseen]), [2, 1, 3])

    # fit_transform must warn as well and encode exactly like fit + transform.
    enc = RobustLabelEncoder(labels=labels)
    with pytest.warns(UserWarning):
        encoded = enc.fit_transform([labels[1], labels[2], unseen])
    np.testing.assert_array_equal(list(enc.classes_), expected_classes)
    np.testing.assert_array_equal(encoded, [2, 1, 3])

    # The fill label is appended after the sorted classes, not sorted in with them.
    enc = RobustLabelEncoder(labels=labels, fill_label_value="-99", include_unseen_class=True)
    with pytest.warns(UserWarning):
        enc.fit([labels[2], labels[0]])
    np.testing.assert_array_equal(enc.get_classes(), expected_classes + ["-99"])


def test_robust_label_encoder():
    """Default encoder: unseen values get one extra code and decode to "<unseen_label>"."""
    fruits = ["apple", "banana", "hot dog"]
    enc = RobustLabelEncoder()
    enc.fit(X[:, 0])

    np.testing.assert_array_equal(enc.classes_, fruits)
    np.testing.assert_array_equal(enc.get_classes(), fruits)

    # Encoding: empty input stays empty; "llama" was never fitted and maps to code 3.
    for raw, codes in (
        ([], []),
        (["hot dog", "banana", "hot dog"], [2, 1, 2]),
        (["hot dog", "llama"], [2, 3]),
    ):
        np.testing.assert_array_equal(enc.transform(raw), codes)

    # Decoding: an out-of-range code (10) decodes to the default fill label.
    for codes, decoded in (
        ([0, 2], ["apple", "hot dog"]),
        ([0, 10], ["apple", "<unseen_label>"]),
    ):
        np.testing.assert_array_equal(enc.inverse_transform(codes), decoded)

    np.testing.assert_array_equal(enc.fit_transform(X[:, 0]), [2, 2, 0, 2, 2, 1])


def test_robust_label_encoder_sorted_labels(labels):
    """Already-sorted ``labels`` are kept as given; classes come from ``labels``, not the fit data.

    ``labels`` needs at least three entries; it is presumably supplied sorted by
    a parametrization outside this view — confirm.
    """
    unseen = "173"

    # Fit only sees labels[0] and labels[1], yet labels[2] still gets its own code.
    fitted = RobustLabelEncoder(labels=labels)
    fitted.fit([labels[1], labels[0]])
    assert_array_equal(list(fitted.classes_), labels)
    assert_array_equal(fitted.get_classes(), labels)
    assert_array_equal(fitted.transform([labels[2], labels[1], unseen]), [2, 1, 3])

    # fit_transform on a fresh encoder must agree with fit + transform.
    fresh = RobustLabelEncoder(labels=labels)
    encoded = fresh.fit_transform([labels[2], labels[1], unseen])
    assert_array_equal(list(fresh.classes_), labels)
    assert_array_equal(encoded, [2, 1, 3])


def test_robust_label_encoder_error_unknown():
    """With fill_unseen_labels=False, transforming a value not seen during fit raises ValueError."""
    enc = RobustLabelEncoder(fill_unseen_labels=False)
    enc.fit(X[:, 0])
    assert_array_equal(enc.get_classes(), ["apple", "banana", "hot dog"])

    # Only the transform call may raise. Keeping construction, fit and the classes
    # check outside pytest.raises means a ValueError from them fails the test
    # instead of being mistaken for the expected error. It also guarantees the
    # classes assertion actually runs.
    with pytest.raises(ValueError):
        enc.transform(["eggplant"])