def test_confidence_integrity():
    data = {
        "multi_label": False,
        "inputs": {"data": "My cool data"},
        "prediction": {
            "agent": "test",
            "labels": [
                {"class": "A", "confidence": 0.3},
                {"class": "B", "confidence": 0.9},
            ],
        },
    }

    # Single-label predictions whose confidences do not form a valid
    # score distribution are rejected
    try:
        TextClassificationRecord.parse_obj(data)
    except ValidationError as e:
        assert "Wrong score distributions" in e.json()

    # The same confidences are accepted once the record is multi-label
    data["multi_label"] = True
    record = TextClassificationRecord.parse_obj(data)
    assert record is not None

    # A single label with confidence <= 1.0 is a valid distribution
    data["multi_label"] = False
    data["prediction"]["labels"] = [{"class": "B", "confidence": 0.9}]
    record = TextClassificationRecord.parse_obj(data)
    assert record is not None

    # Tiny floating point excesses over 1.0 are tolerated
    data["prediction"]["labels"] = [
        {"class": "B", "confidence": 0.10000000012},
        {"class": "B", "confidence": 0.90000000002},
    ]
    record = TextClassificationRecord.parse_obj(data)
    assert record is not None
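
# Illustrative sketch only (not part of the test suite): the kind of
# tolerance-aware check the validation above appears to apply to single-label
# confidences. The helper name and the exact tolerance value are assumptions,
# not the library's actual implementation.
def _looks_like_valid_distribution(confidences, tolerance=1e-6):
    """Return True if the confidences could form a single-label distribution."""
    return sum(confidences) <= 1.0 + tolerance


# e.g. _looks_like_valid_distribution([0.3, 0.9]) -> False
#      _looks_like_valid_distribution([0.10000000012, 0.90000000002]) -> True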
def test_too_long_label():
    with pytest.raises(ValidationError, match="exceeds max length"):
        TextClassificationRecord.parse_obj(
            {
                "inputs": {"text": "bogh"},
                "prediction": {
                    "agent": "test",
                    "labels": [{"class": "a" * 1000}],
                },
            }
        )
def test_created_record_with_default_status():
    data = {"inputs": {"data": "My cool data"}}

    record = TextClassificationRecord.parse_obj(data)
    assert record.status == TaskStatus.default
def test_too_long_metadata():
    record = TextClassificationRecord.parse_obj(
        {
            "inputs": {"text": "bogh"},
            "metadata": {"too_long": "a" * 1000},
        }
    )

    assert len(record.metadata["too_long"]) == MAX_KEYWORD_LENGTH
def test_flatten_inputs():
    data = {
        "inputs": {
            "mail": {
                "subject": "The mail subject",
                "body": "This is a large text body",
            }
        }
    }

    record = TextClassificationRecord.parse_obj(data)
    assert list(record.inputs.keys()) == ["mail.subject", "mail.body"]
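
# Illustrative sketch only (not part of the test suite): a minimal
# dict-flattening helper that produces the dot-separated keys asserted above.
# This is an assumption about the behaviour, not the library's actual
# implementation.
def _flatten(nested, parent_key=""):
    """Flatten a nested dict into {"outer.inner": value} pairs."""
    flat = {}
    for key, value in nested.items():
        new_key = f"{parent_key}.{key}" if parent_key else key
        if isinstance(value, dict):
            flat.update(_flatten(value, new_key))
        else:
            flat[new_key] = value
    return flat


# e.g. _flatten({"mail": {"subject": "s", "body": "b"}})
#      -> {"mail.subject": "s", "mail.body": "b"}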
def test_create_records_for_text_classification_with_multi_label():
    dataset = "test_create_records_for_text_classification_with_multi_label"
    assert client.delete(f"/api/datasets/{dataset}").status_code == 200

    records = [
        TextClassificationRecord.parse_obj(data)
        for data in [
            {
                "id": 0,
                "inputs": {"data": "my data"},
                "multi_label": True,
                "metadata": {"field_one": "value one", "field_two": "value 2"},
                "prediction": {
                    "agent": "test",
                    "labels": [
                        {"class": "Test", "confidence": 0.6},
                        {"class": "Mocking", "confidence": 0.7},
                        {"class": "NoClass", "confidence": 0.2},
                    ],
                },
            },
            {
                "id": 1,
                "inputs": {"data": "my data"},
                "multi_label": True,
                "metadata": {"field_one": "another value one", "field_two": "value 2"},
                "prediction": {
                    "agent": "test",
                    "labels": [
                        {"class": "Test", "confidence": 0.6},
                        {"class": "Mocking", "confidence": 0.7},
                        {"class": "NoClass", "confidence": 0.2},
                    ],
                },
            },
        ]
    ]

    # First bulk creates the dataset with its initial tags and metadata
    response = client.post(
        f"/api/datasets/{dataset}/TextClassification:bulk",
        json=TextClassificationBulkData(
            tags={"env": "test", "class": "text classification"},
            metadata={"config": {"the": "config"}},
            records=records,
        ).dict(by_alias=True),
    )

    assert response.status_code == 200, response.json()
    bulk_response = BulkResponse.parse_obj(response.json())
    assert bulk_response.dataset == dataset
    assert bulk_response.failed == 0
    assert bulk_response.processed == 2

    # A second bulk merges new tags and metadata into the existing dataset
    response = client.post(
        f"/api/datasets/{dataset}/TextClassification:bulk",
        json=TextClassificationBulkData(
            tags={"new": "tag"},
            metadata={"new": {"metadata": "value"}},
            records=records,
        ).dict(by_alias=True),
    )

    get_dataset = Dataset.parse_obj(client.get(f"/api/datasets/{dataset}").json())
    assert get_dataset.tags == {
        "env": "test",
        "class": "text classification",
        "new": "tag",
    }
    assert get_dataset.metadata == {
        "config": {"the": "config"},
        "new": {"metadata": "value"},
    }
    assert response.status_code == 200, response.json()

    # Searching with an empty query returns both records and their aggregations
    response = client.post(
        f"/api/datasets/{dataset}/TextClassification:search", json={}
    )

    assert response.status_code == 200
    results = TextClassificationSearchResults.parse_obj(response.json())
    assert results.total == 2
    assert results.aggregations.predicted_as == {"Mocking": 2, "Test": 2}
    assert results.records[0].predicted is None
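
# Illustrative sketch only (not part of the test suite): the dataset tags and
# metadata asserted above behave like a shallow dict merge across successive
# bulk requests. The helper below is an assumption about that behaviour, not
# the API's actual implementation.
def _merge_dataset_settings(existing, incoming):
    """Shallow merge: keys from both dicts end up in the stored settings."""
    return {**existing, **incoming}


# e.g. _merge_dataset_settings({"env": "test"}, {"new": "tag"})
#      -> {"env": "test", "new": "tag"}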