def test_metadata_with_point_in_field_name():
    """Metadata keys containing dots must survive indexing and show up in aggregations."""
    dataset = "test_metadata_with_point_in_field_name"
    assert client.delete(f"/api/datasets/{dataset}").status_code == 200

    texts = ["Esto es un ejemplo de texto", "This is an simple text example"]
    bulk = TextClassificationBulkData(
        records=[
            TextClassificationRecord(
                id=idx,
                inputs={"text": text},
                metadata={"field.one": 1, "field.two": 2},
            )
            for idx, text in enumerate(texts)
        ],
    )
    response = client.post(
        f"/api/datasets/{dataset}/TextClassification:bulk",
        data=bulk.json(by_alias=True),
    )

    response = client.post(
        f"/api/datasets/{dataset}/TextClassification:search?limit=0",
        json={},
    )
    results = TextClassificationSearchResults.parse_obj(response.json())
    assert "field.one" in results.aggregations.metadata
    assert results.aggregations.metadata.get("field.one", {})["1"] == 2
    assert results.aggregations.metadata.get("field.two", {})["2"] == 2
def test_confidence_integrity():
    """Score distributions must be well-formed for single-label records.

    Fix: the original wrapped ``parse_obj`` in ``try/except ValidationError``
    and asserted on the message inside the handler — if no exception was
    raised, the test passed silently. ``pytest.raises`` makes the expected
    failure mandatory.
    """
    data = {
        "multi_label": False,
        "inputs": {"data": "My cool data"},
        "prediction": {
            "agent": "test",
            "labels": [
                {"class": "A", "confidence": 0.3},
                {"class": "B", "confidence": 0.9},
            ],
        },
    }

    # Scores summing past 1.0 on a single-label record must be rejected.
    with pytest.raises(ValidationError, match="Wrong score distributions"):
        TextClassificationRecord.parse_obj(data)

    # Multi-label records carry no sum-to-one constraint.
    data["multi_label"] = True
    assert TextClassificationRecord.parse_obj(data) is not None

    # A single well-formed score is accepted for single-label.
    data["multi_label"] = False
    data["prediction"]["labels"] = [{"class": "B", "confidence": 0.9}]
    assert TextClassificationRecord.parse_obj(data) is not None

    # Tiny float noise around 1.0 is tolerated.
    data["prediction"]["labels"] = [
        {"class": "B", "confidence": 0.10000000012},
        {"class": "B", "confidence": 0.90000000002},
    ]
    assert TextClassificationRecord.parse_obj(data) is not None
def test_too_long_label():
    """A predicted class name longer than the allowed maximum is rejected."""
    payload = {
        "inputs": {"text": "bogh"},
        "prediction": {
            "agent": "test",
            "labels": [{"class": "a" * 1000}],
        },
    }
    with pytest.raises(ValidationError, match="exceeds max length"):
        TextClassificationRecord.parse_obj(payload)
def test_wrong_text_query():
    """An unparseable text query yields a 400 with a descriptive detail message."""
    dataset = "test_wrong_text_query"
    assert client.delete(f"/api/datasets/{dataset}").status_code == 200

    bulk = TextClassificationBulkData(
        records=[
            TextClassificationRecord(
                id=0,
                inputs={"text": "Esto es un ejemplo de texto"},
                metadata={"field.one": 1, "field.two": 2},
            )
        ],
    )
    response = client.post(
        f"/api/datasets/{dataset}/TextClassification:bulk",
        data=bulk.json(by_alias=True),
    )

    # "!" alone is not a valid query expression.
    search_request = TextClassificationSearchRequest(
        query=TextClassificationQuery(query_text="!")
    )
    response = client.post(
        f"/api/datasets/{dataset}/TextClassification:search",
        json=search_request.dict(),
    )
    assert response.status_code == 400
    assert response.json()["detail"] == "Failed to parse query [!]"
def test_create_records_for_text_classification():
    """A bulk upload stores records plus the dataset tags and metadata."""
    dataset = "test_create_records_for_text_classification"
    assert client.delete(f"/api/datasets/{dataset}").status_code == 200

    tags = {"env": "test", "class": "text classification"}
    metadata = {"config": {"the": "config"}}
    record = TextClassificationRecord(
        id=0,
        inputs={"data": "my data"},
        prediction={
            "agent": "test",
            "labels": [
                {"class": "Test", "confidence": 0.3},
                {"class": "Mocking", "confidence": 0.7},
            ],
        },
    )
    classification_bulk = TextClassificationBulkData(
        tags=tags,
        metadata=metadata,
        records=[record],
    )

    response = client.post(
        f"/api/datasets/{dataset}/TextClassification:bulk",
        json=classification_bulk.dict(by_alias=True),
    )
    assert response.status_code == 200
    bulk_response = BulkResponse.parse_obj(response.json())
    assert bulk_response.dataset == dataset
    assert bulk_response.failed == 0
    assert bulk_response.processed == 1

    # Tags and metadata were persisted on the dataset itself.
    response = client.get(f"/api/datasets/{dataset}")
    assert response.status_code == 200
    created_dataset = Dataset.parse_obj(response.json())
    assert created_dataset.tags == tags
    assert created_dataset.metadata == metadata

    response = client.post(
        f"/api/datasets/{dataset}/TextClassification:search", json={}
    )
    assert response.status_code == 200
    results = TextClassificationSearchResults.parse_obj(response.json())
    assert results.total == 1
    assert results.aggregations.predicted_as == {"Mocking": 1}
    assert results.aggregations.status == {"Default": 1}
    assert results.aggregations.score
    assert results.aggregations.predicted == {}
def test_created_record_with_default_status():
    """Records created without an explicit status get ``TaskStatus.default``."""
    record = TextClassificationRecord.parse_obj(
        {"inputs": {"data": "My cool data"}}
    )
    assert record.status == TaskStatus.default
def test_include_event_timestamp():
    """Event timestamps set at creation time come back in search results."""
    dataset = "test_include_event_timestamp"
    assert client.delete(f"/api/datasets/{dataset}").status_code == 200

    bulk = TextClassificationBulkData(
        tags={"env": "test", "class": "text classification"},
        metadata={"config": {"the": "config"}},
        records=[
            TextClassificationRecord(
                id=i,
                inputs={"data": "my data"},
                event_timestamp=datetime.utcnow(),
                prediction={
                    "agent": "test",
                    "labels": [
                        {"class": "Test", "confidence": 0.3},
                        {"class": "Mocking", "confidence": 0.7},
                    ],
                },
            )
            for i in range(0, 100)
        ],
    )
    response = client.post(
        f"/api/datasets/{dataset}/TextClassification:bulk",
        data=bulk.json(by_alias=True),
    )
    assert BulkResponse.parse_obj(response.json()).processed == 100

    response = client.post(
        f"/api/datasets/{dataset}/TextClassification:search?from=10",
        json={},
    )
    results = TextClassificationSearchResults.parse_obj(response.json())
    assert results.total == 100
    assert all(record.event_timestamp is not None for record in results.records)
def test_flatten_inputs():
    """Nested input dicts are flattened into dot-separated keys."""
    nested_inputs = {
        "mail": {
            "subject": "The mail subject",
            "body": "This is a large text body",
        }
    }
    record = TextClassificationRecord.parse_obj({"inputs": nested_inputs})
    assert list(record.inputs.keys()) == ["mail.subject", "mail.body"]
def test_too_long_metadata():
    """Over-long metadata string values are truncated to ``MAX_KEYWORD_LENGTH``."""
    payload = {
        "inputs": {"text": "bogh"},
        "metadata": {"too_long": "a" * 1000},
    }
    record = TextClassificationRecord.parse_obj(payload)
    assert len(record.metadata["too_long"]) == MAX_KEYWORD_LENGTH
def test_predicted_as_with_no_labels():
    """A prediction with an empty label list yields an empty ``predicted_as``."""
    record = TextClassificationRecord(
        inputs={"text": "The input text"},
        prediction={"agent": "test", "labels": []},
    )
    assert record.predicted_as == []
def test_disable_aggregations_when_scroll():
    """Aggregations are omitted when paginating past the first page."""
    dataset = "test_disable_aggregations_when_scroll"
    assert client.delete(f"/api/datasets/{dataset}").status_code == 200

    bulk = TextClassificationBulkData(
        tags={"env": "test", "class": "text classification"},
        metadata={"config": {"the": "config"}},
        records=[
            TextClassificationRecord(
                id=i,
                inputs={"data": "my data"},
                prediction={
                    "agent": "test",
                    "labels": [
                        {"class": "Test", "confidence": 0.3},
                        {"class": "Mocking", "confidence": 0.7},
                    ],
                },
            )
            for i in range(0, 100)
        ],
    )
    response = client.post(
        f"/api/datasets/{dataset}/TextClassification:bulk",
        json=bulk.dict(by_alias=True),
    )
    assert BulkResponse.parse_obj(response.json()).processed == 100

    # Scrolling (from=10) should skip the aggregations computation entirely.
    response = client.post(
        f"/api/datasets/{dataset}/TextClassification:search?from=10",
        json={},
    )
    results = TextClassificationSearchResults.parse_obj(response.json())
    assert results.total == 100
    assert results.aggregations is None
def test_words_cloud():
    """Search aggregations include a words cloud for the indexed texts.

    Fix: the dataset name was ``"test_language_detection"`` — a copy-paste
    leftover that broke the file's convention of dataset == test name and
    could collide with another test's dataset.
    """
    dataset = "test_words_cloud"
    assert client.delete(f"/api/datasets/{dataset}").status_code == 200

    texts = [
        "Esto es un ejemplo de texto",
        "This is an simple text example",
        "C'est nes pas une pipe",
    ]
    bulk = TextClassificationBulkData(
        records=[
            TextClassificationRecord(id=idx, inputs={"text": text})
            for idx, text in enumerate(texts)
        ],
    )
    response = client.post(
        f"/api/datasets/{dataset}/TextClassification:bulk",
        data=bulk.json(by_alias=True),
    )
    BulkResponse.parse_obj(response.json())

    response = client.post(
        f"/api/datasets/{dataset}/TextClassification:search",
        json={},
    )
    results = TextClassificationSearchResults.parse_obj(response.json())
    assert results.aggregations.words is not None
def create_some_data_for_text_classification(name: str, n: int):
    """Seed dataset *name* with ``n or 10`` validated multi-label records.

    Records are created in pairs: even ids carry only an annotation, odd ids
    carry both a prediction and an annotation.
    """
    records = []
    for idx in range(0, n or 10, 2):
        records.append(
            TextClassificationRecord(
                id=idx,
                inputs={"data": "my data"},
                multi_label=True,
                metadata={"field_one": "value one", "field_two": "value 2"},
                status=TaskStatus.validated,
                annotation={
                    "agent": "test",
                    "labels": [{"class": "Test"}, {"class": "Mocking"}],
                },
            )
        )
        records.append(
            TextClassificationRecord(
                id=idx + 1,
                inputs={"data": "my data"},
                multi_label=True,
                metadata={"field_one": "another value one", "field_two": "value 2"},
                status=TaskStatus.validated,
                prediction={
                    "agent": "test",
                    "labels": [{"class": "NoClass"}],
                },
                annotation={
                    "agent": "test",
                    "labels": [{"class": "Test"}],
                },
            )
        )

    client.post(
        f"/api/datasets/{name}/{TaskType.text_classification}:bulk",
        json=TextClassificationBulkData(
            tags={"env": "test", "class": "text classification"},
            metadata={"config": {"the": "config"}},
            records=records,
        ).dict(by_alias=True),
    )
def test_predicted_ok_for_multilabel_unordered():
    """Multi-label predicted status must not depend on label ordering."""
    prediction = TextClassificationAnnotation(
        agent="test",
        labels=[
            ClassPrediction(class_label="B"),
            ClassPrediction(class_label="C", score=0.3),
            ClassPrediction(class_label="A"),
        ],
    )
    annotation = TextClassificationAnnotation(
        agent="test",
        labels=[
            ClassPrediction(class_label="A"),
            ClassPrediction(class_label="B"),
        ],
    )
    record = TextClassificationRecord(
        inputs={"text": "The text"},
        prediction=prediction,
        annotation=annotation,
        multi_label=True,
    )
    assert record.predicted == PredictionStatus.OK
def test_sort_by_id_as_default():
    """Without an explicit sort, results come back ordered by id.

    Note the lexicographic order of the expected ids (0, 1, 10, 11, ...).
    """
    dataset = "test_sort_by_id_as_default"
    assert client.delete(f"/api/datasets/{dataset}").status_code == 200

    bulk = TextClassificationBulkData(
        records=[
            TextClassificationRecord(
                id=i,
                inputs={"data": "my data"},
                metadata={"s": "value"},
            )
            for i in range(0, 100)
        ],
    )
    response = client.post(
        f"/api/datasets/{dataset}/TextClassification:bulk",
        json=bulk.dict(by_alias=True),
    )

    response = client.post(
        f"/api/datasets/{dataset}/TextClassification:search?from=0&limit=10",
        json={},
    )
    results = TextClassificationSearchResults.parse_obj(response.json())
    assert results.total == 100
    assert [r.id for r in results.records] == [0, 1, 10, 11, 12, 13, 14, 15, 16, 17]
def test_partial_record_update():
    """Re-posting a record with an added annotation updates it in place."""
    name = "test_partial_record_update"
    assert client.delete(f"/api/datasets/{name}").status_code == 200

    prediction_data = {
        "agent": "test",
        "labels": [
            {"class": "Positive", "confidence": 0.6},
            {"class": "Negative", "confidence": 0.3},
            {"class": "Other", "confidence": 0.1},
        ],
    }
    record = TextClassificationRecord(
        id=1,
        inputs={"text": "This is a text, oh yeah!"},
        prediction=prediction_data,
    )
    bulk = TextClassificationBulkData(records=[record])

    response = client.post(
        f"/api/datasets/{name}/TextClassification:bulk",
        json=bulk.dict(by_alias=True),
    )
    assert response.status_code == 200
    bulk_response = BulkResponse.parse_obj(response.json())
    assert bulk_response.failed == 0
    assert bulk_response.processed == 1

    # Second bulk: same record id, now carrying a gold annotation.
    annotation_data = {
        "agent": "gold_standard",
        "labels": [{"class": "Positive"}],
    }
    record.annotation = TextClassificationAnnotation.parse_obj(annotation_data)
    bulk.records = [record]
    client.post(
        f"/api/datasets/{name}/TextClassification:bulk",
        json=bulk.dict(by_alias=True),
    )

    response = client.post(
        f"/api/datasets/{name}/TextClassification:search",
        json={
            "query": TextClassificationQuery(predicted=PredictionStatus.OK).dict(
                by_alias=True
            ),
        },
    )
    assert response.status_code == 200
    results = TextClassificationSearchResults.parse_obj(response.json())
    assert results.total == 1

    found = results.records[0]
    assert found.last_updated is not None
    # Clear the server-managed timestamp before the structural comparison.
    found.last_updated = None
    assert TextClassificationRecord(
        **found.dict(by_alias=True, exclude_none=True)
    ) == TextClassificationRecord(
        id=1,
        inputs={"text": "This is a text, oh yeah!"},
        prediction=prediction_data,
        annotation=annotation_data,
    )
def test_create_records_for_text_classification_with_multi_label():
    """Multi-label bulks are stored, and tags/metadata merge across bulks."""
    dataset = "test_create_records_for_text_classification_with_multi_label"
    assert client.delete(f"/api/datasets/{dataset}").status_code == 200

    prediction = {
        "agent": "test",
        "labels": [
            {"class": "Test", "confidence": 0.6},
            {"class": "Mocking", "confidence": 0.7},
            {"class": "NoClass", "confidence": 0.2},
        ],
    }
    records = [
        TextClassificationRecord.parse_obj(
            {
                "id": 0,
                "inputs": {"data": "my data"},
                "multi_label": True,
                "metadata": {"field_one": "value one", "field_two": "value 2"},
                "prediction": prediction,
            }
        ),
        TextClassificationRecord.parse_obj(
            {
                "id": 1,
                "inputs": {"data": "my data"},
                "multi_label": True,
                "metadata": {
                    "field_one": "another value one",
                    "field_two": "value 2",
                },
                "prediction": prediction,
            }
        ),
    ]

    response = client.post(
        f"/api/datasets/{dataset}/TextClassification:bulk",
        json=TextClassificationBulkData(
            tags={"env": "test", "class": "text classification"},
            metadata={"config": {"the": "config"}},
            records=records,
        ).dict(by_alias=True),
    )
    assert response.status_code == 200, response.json()
    bulk_response = BulkResponse.parse_obj(response.json())
    assert bulk_response.dataset == dataset
    assert bulk_response.failed == 0
    assert bulk_response.processed == 2

    # A second bulk with fresh tags/metadata merges them into the dataset.
    response = client.post(
        f"/api/datasets/{dataset}/TextClassification:bulk",
        json=TextClassificationBulkData(
            tags={"new": "tag"},
            metadata={"new": {"metadata": "value"}},
            records=records,
        ).dict(by_alias=True),
    )
    get_dataset = Dataset.parse_obj(client.get(f"/api/datasets/{dataset}").json())
    assert get_dataset.tags == {
        "env": "test",
        "class": "text classification",
        "new": "tag",
    }
    assert get_dataset.metadata == {
        "config": {"the": "config"},
        "new": {"metadata": "value"},
    }
    assert response.status_code == 200, response.json()

    response = client.post(
        f"/api/datasets/{dataset}/TextClassification:search", json={}
    )
    assert response.status_code == 200
    results = TextClassificationSearchResults.parse_obj(response.json())
    assert results.total == 2
    assert results.aggregations.predicted_as == {"Mocking": 2, "Test": 2}
    assert results.records[0].predicted is None
def test_prediction_ok_cases():
    """``predicted`` status reflects agreement between prediction and annotation."""
    record = TextClassificationRecord(
        multi_label=True,
        inputs={"data": "My cool data"},
        prediction={
            "agent": "test",
            "labels": [
                {"class": "A", "confidence": 0.3},
                {"class": "B", "confidence": 0.9},
            ],
        },
    )
    # No annotation yet -> no predicted status.
    assert record.predicted is None

    record.annotation = TextClassificationAnnotation(
        agent="test",
        labels=[
            {"class": "A", "confidence": 1},
            {"class": "B", "confidence": 1},
        ],
    )
    # Current prediction disagrees with the annotation.
    assert record.predicted == PredictionStatus.KO

    record.prediction = TextClassificationAnnotation(
        agent="test",
        labels=[
            {"class": "A", "confidence": 0.9},
            {"class": "B", "confidence": 0.9},
        ],
    )
    # High-confidence prediction now matches the annotated labels.
    assert record.predicted == PredictionStatus.OK

    # Removing the prediction clears the status again.
    record.prediction = None
    assert record.predicted is None