Ejemplo n.º 1
0
def test_prediction_ok_cases():
    """Check ``record.predicted`` as the prediction and annotation change.

    A multi-label record is taken through four states, asserting the
    expected ``predicted`` value after each one: prediction only,
    annotation added, prediction replaced, and prediction removed.
    """

    def scored_labels(score_a, score_b):
        # Two-class label payload shared by every prediction/annotation here.
        return [
            {"class": "A", "confidence": score_a},
            {"class": "B", "confidence": score_b},
        ]

    payload = dict(
        multi_label=True,
        inputs={"data": "My cool data"},
        prediction={"agent": "test", "labels": scored_labels(0.3, 0.9)},
    )
    record = TextClassificationRecord(**payload)

    # With a prediction but no annotation the status is expected to be unset.
    assert record.predicted is None

    # Annotate both classes; the initial prediction is expected to count as KO.
    record.annotation = TextClassificationAnnotation(
        agent="test",
        labels=scored_labels(1, 1),
    )
    assert record.predicted == PredictionStatus.KO

    # Replace the prediction with high scores for both classes: expected OK.
    record.prediction = TextClassificationAnnotation(
        agent="test",
        labels=scored_labels(0.9, 0.9),
    )
    assert record.predicted == PredictionStatus.OK

    # Dropping the prediction is expected to clear the status again.
    record.prediction = None
    assert record.predicted is None
Ejemplo n.º 2
0
def test_partial_record_update():
    """Re-upload a predicted record with an annotation and check the result.

    Steps:
      1. Delete the dataset named after this test.
      2. Bulk-upload one record carrying only a prediction.
      3. Bulk-upload the same record (same id) with an annotation added.
      4. Search for records whose prediction status is OK and assert that
         the single hit carries both the original prediction and the new
         annotation.
    """
    name = "test_partial_record_update"
    dataset_url = f"/api/datasets/{name}"
    bulk_url = f"{dataset_url}/TextClassification:bulk"

    # Payloads are defined once so the uploaded and the expected record
    # cannot drift apart.
    record_data = {
        "id": 1,
        "inputs": {"text": "This is a text, oh yeah!"},
        "prediction": {
            "agent": "test",
            "labels": [
                {"class": "Positive", "confidence": 0.6},
                {"class": "Negative", "confidence": 0.3},
                {"class": "Other", "confidence": 0.1},
            ],
        },
    }
    annotation_data = {
        "agent": "gold_standard",
        "labels": [{"class": "Positive"}],
    }

    assert client.delete(dataset_url).status_code == 200

    # First upload: prediction only.
    record = TextClassificationRecord(**record_data)
    bulk = TextClassificationBulkData(
        records=[record],
    )
    response = client.post(bulk_url, json=bulk.dict(by_alias=True))

    assert response.status_code == 200
    bulk_response = BulkResponse.parse_obj(response.json())
    assert bulk_response.failed == 0
    assert bulk_response.processed == 1

    # Second upload: same record id, now with an annotation attached.
    record.annotation = TextClassificationAnnotation.parse_obj(annotation_data)
    bulk.records = [record]
    response = client.post(bulk_url, json=bulk.dict(by_alias=True))

    # Check the update request itself so a failed re-upload is reported
    # here rather than as a confusing search mismatch further down.
    assert response.status_code == 200
    assert BulkResponse.parse_obj(response.json()).failed == 0

    response = client.post(
        f"{dataset_url}/TextClassification:search",
        json={
            "query": TextClassificationQuery(predicted=PredictionStatus.OK).dict(
                by_alias=True
            ),
        },
    )

    assert response.status_code == 200
    results = TextClassificationSearchResults.parse_obj(response.json())
    assert results.total == 1
    first_record = results.records[0]
    assert first_record.last_updated is not None

    # last_updated is set by the server; clear it before comparing.
    first_record.last_updated = None
    expected = TextClassificationRecord(
        **{**record_data, "annotation": annotation_data}
    )
    assert (
        TextClassificationRecord(
            **first_record.dict(by_alias=True, exclude_none=True)
        )
        == expected
    )