Exemplo n.º 1
0
def test_metadata_with_point_in_field_name():
    """Metadata keys containing dots must be indexed and aggregated correctly."""
    dataset = "test_metadata_with_point_in_field_name"
    assert client.delete(f"/api/datasets/{dataset}").status_code == 200

    records = [
        TextClassificationRecord(
            id=0,
            inputs={"text": "Esto es un ejemplo de texto"},
            metadata={"field.one": 1, "field.two": 2},
        ),
        TextClassificationRecord(
            id=1,
            inputs={"text": "This is an simple text example"},
            metadata={"field.one": 1, "field.two": 2},
        ),
    ]
    response = client.post(
        f"/api/datasets/{dataset}/TextClassification:bulk",
        data=TextClassificationBulkData(records=records).json(by_alias=True),
    )

    response = client.post(
        f"/api/datasets/{dataset}/TextClassification:search?limit=0",
        json={},
    )

    results = TextClassificationSearchResults.parse_obj(response.json())
    # Both records share the same metadata values, so each bucket counts 2.
    assert "field.one" in results.aggregations.metadata
    assert results.aggregations.metadata.get("field.one", {})["1"] == 2
    assert results.aggregations.metadata.get("field.two", {})["2"] == 2
Exemplo n.º 2
0
def test_confidence_integrity():
    """Single-label predictions must carry a valid score distribution.

    Fix: the original wrapped ``parse_obj`` in a ``try/except`` and only
    asserted inside the handler, so the test silently passed when no
    ``ValidationError`` was raised at all.  ``pytest.raises`` makes the
    expected failure explicit.
    """
    data = {
        "multi_label": False,
        "inputs": {"data": "My cool data"},
        "prediction": {
            "agent": "test",
            "labels": [
                {"class": "A", "confidence": 0.3},
                {"class": "B", "confidence": 0.9},
            ],
        },
    }

    # Scores summing to 1.2 are invalid for a single-label record.
    with pytest.raises(ValidationError) as error_info:
        TextClassificationRecord.parse_obj(data)
    assert "Wrong score distributions" in error_info.value.json()

    # Multi-label records accept independent per-label scores.
    data["multi_label"] = True
    record = TextClassificationRecord.parse_obj(data)
    assert record is not None

    # A single label with a partial score is accepted for single-label.
    data["multi_label"] = False
    data["prediction"]["labels"] = [{"class": "B", "confidence": 0.9}]
    record = TextClassificationRecord.parse_obj(data)
    assert record is not None

    # Tiny floating-point noise around 1.0 must be tolerated.
    data["prediction"]["labels"] = [
        {"class": "B", "confidence": 0.10000000012},
        {"class": "B", "confidence": 0.90000000002},
    ]
    record = TextClassificationRecord.parse_obj(data)
    assert record is not None
Exemplo n.º 3
0
def test_too_long_label():
    """Prediction labels longer than the allowed maximum are rejected."""
    data = {
        "inputs": {"text": "bogh"},
        "prediction": {
            "agent": "test",
            "labels": [{"class": "a" * 1000}],
        },
    }
    with pytest.raises(ValidationError, match="exceeds max length"):
        TextClassificationRecord.parse_obj(data)
Exemplo n.º 4
0
def test_wrong_text_query():
    """An unparsable free-text query must produce an HTTP 400 with a parse error."""
    dataset = "test_wrong_text_query"
    assert client.delete(f"/api/datasets/{dataset}").status_code == 200

    record = TextClassificationRecord(
        id=0,
        inputs={"text": "Esto es un ejemplo de texto"},
        metadata={"field.one": 1, "field.two": 2},
    )
    response = client.post(
        f"/api/datasets/{dataset}/TextClassification:bulk",
        data=TextClassificationBulkData(records=[record]).json(by_alias=True),
    )

    # "!" is not a valid query expression on its own.
    search_request = TextClassificationSearchRequest(
        query=TextClassificationQuery(query_text="!")
    )
    response = client.post(
        f"/api/datasets/{dataset}/TextClassification:search",
        json=search_request.dict(),
    )

    assert response.status_code == 400
    assert response.json()["detail"] == "Failed to parse query [!]"
Exemplo n.º 5
0
def test_create_records_for_text_classification():
    """Bulk creation stores records plus dataset-level tags and metadata."""
    dataset = "test_create_records_for_text_classification"
    assert client.delete(f"/api/datasets/{dataset}").status_code == 200

    tags = {"env": "test", "class": "text classification"}
    metadata = {"config": {"the": "config"}}
    record = TextClassificationRecord(
        id=0,
        inputs={"data": "my data"},
        prediction={
            "agent": "test",
            "labels": [
                {"class": "Test", "confidence": 0.3},
                {"class": "Mocking", "confidence": 0.7},
            ],
        },
    )
    classification_bulk = TextClassificationBulkData(
        tags=tags,
        metadata=metadata,
        records=[record],
    )
    response = client.post(
        f"/api/datasets/{dataset}/TextClassification:bulk",
        json=classification_bulk.dict(by_alias=True),
    )

    assert response.status_code == 200
    bulk_response = BulkResponse.parse_obj(response.json())
    assert bulk_response.dataset == dataset
    assert bulk_response.failed == 0
    assert bulk_response.processed == 1

    # The dataset-level attributes must round-trip through the GET endpoint.
    response = client.get(f"/api/datasets/{dataset}")
    assert response.status_code == 200
    created_dataset = Dataset.parse_obj(response.json())
    assert created_dataset.tags == tags
    assert created_dataset.metadata == metadata

    response = client.post(
        f"/api/datasets/{dataset}/TextClassification:search", json={})

    assert response.status_code == 200
    results = TextClassificationSearchResults.parse_obj(response.json())
    assert results.total == 1
    # "Mocking" wins with confidence 0.7; no annotation -> predicted is empty.
    assert results.aggregations.predicted_as == {"Mocking": 1}
    assert results.aggregations.status == {"Default": 1}
    assert results.aggregations.score
    assert results.aggregations.predicted == {}
Exemplo n.º 6
0
def test_created_record_with_default_status():
    """Records created without an explicit status fall back to ``TaskStatus.default``."""
    record = TextClassificationRecord.parse_obj(
        {"inputs": {"data": "My cool data"}}
    )
    assert record.status == TaskStatus.default
Exemplo n.º 7
0
def test_include_event_timestamp():
    """Event timestamps set on ingestion are returned by the search endpoint."""
    dataset = "test_include_event_timestamp"
    assert client.delete(f"/api/datasets/{dataset}").status_code == 200

    records = [
        TextClassificationRecord(
            id=idx,
            inputs={"data": "my data"},
            event_timestamp=datetime.utcnow(),
            prediction={
                "agent": "test",
                "labels": [
                    {"class": "Test", "confidence": 0.3},
                    {"class": "Mocking", "confidence": 0.7},
                ],
            },
        )
        for idx in range(0, 100)
    ]
    bulk = TextClassificationBulkData(
        tags={"env": "test", "class": "text classification"},
        metadata={"config": {"the": "config"}},
        records=records,
    )
    response = client.post(
        f"/api/datasets/{dataset}/TextClassification:bulk",
        data=bulk.json(by_alias=True),
    )
    bulk_response = BulkResponse.parse_obj(response.json())
    assert bulk_response.processed == 100

    response = client.post(
        f"/api/datasets/{dataset}/TextClassification:search?from=10",
        json={},
    )

    results = TextClassificationSearchResults.parse_obj(response.json())
    assert results.total == 100
    # Every returned record must carry the timestamp it was created with.
    assert all(record.event_timestamp is not None for record in results.records)
Exemplo n.º 8
0
def test_flatten_inputs():
    """Nested input dictionaries are flattened into dot-separated keys."""
    record = TextClassificationRecord.parse_obj(
        {
            "inputs": {
                "mail": {
                    "subject": "The mail subject",
                    "body": "This is a large text body",
                }
            }
        }
    )
    assert list(record.inputs.keys()) == ["mail.subject", "mail.body"]
Exemplo n.º 9
0
def test_too_long_metadata():
    """Oversized metadata values are truncated to MAX_KEYWORD_LENGTH, not rejected."""
    data = {
        "inputs": {"text": "bogh"},
        "metadata": {"too_long": "a" * 1000},
    }
    record = TextClassificationRecord.parse_obj(data)
    assert len(record.metadata["too_long"]) == MAX_KEYWORD_LENGTH
Exemplo n.º 10
0
def test_predicted_as_with_no_labels():
    """A prediction carrying no labels yields an empty ``predicted_as`` list."""
    record = TextClassificationRecord(
        inputs={"text": "The input text"},
        prediction={"agent": "test", "labels": []},
    )
    assert record.predicted_as == []
Exemplo n.º 11
0
def test_disable_aggregations_when_scroll():
    """Scrolling past the first page (``from`` > 0) skips aggregation computation."""
    dataset = "test_disable_aggregations_when_scroll"
    assert client.delete(f"/api/datasets/{dataset}").status_code == 200

    records = [
        TextClassificationRecord(
            id=idx,
            inputs={"data": "my data"},
            prediction={
                "agent": "test",
                "labels": [
                    {"class": "Test", "confidence": 0.3},
                    {"class": "Mocking", "confidence": 0.7},
                ],
            },
        )
        for idx in range(0, 100)
    ]
    bulk = TextClassificationBulkData(
        tags={"env": "test", "class": "text classification"},
        metadata={"config": {"the": "config"}},
        records=records,
    )
    response = client.post(
        f"/api/datasets/{dataset}/TextClassification:bulk",
        json=bulk.dict(by_alias=True),
    )
    bulk_response = BulkResponse.parse_obj(response.json())
    assert bulk_response.processed == 100

    response = client.post(
        f"/api/datasets/{dataset}/TextClassification:search?from=10",
        json={},
    )

    results = TextClassificationSearchResults.parse_obj(response.json())
    assert results.total == 100
    # Aggregations are only computed for first-page requests.
    assert results.aggregations is None
Exemplo n.º 12
0
def test_words_cloud():
    """Search aggregations include a word cloud built from the record texts.

    Fix: the dataset name now matches the test name.  It previously reused
    "test_language_detection", which could clash with that test's data when
    both run against the same backend.
    """
    dataset = "test_words_cloud"
    assert client.delete(f"/api/datasets/{dataset}").status_code == 200

    response = client.post(
        f"/api/datasets/{dataset}/TextClassification:bulk",
        data=TextClassificationBulkData(
            records=[
                TextClassificationRecord(
                    **{
                        "id": 0,
                        "inputs": {"text": "Esto es un ejemplo de texto"},
                    }
                ),
                TextClassificationRecord(
                    **{
                        "id": 1,
                        "inputs": {"text": "This is an simple text example"},
                    }
                ),
                TextClassificationRecord(
                    **{
                        "id": 2,
                        "inputs": {"text": "C'est nes pas une pipe"},
                    }
                ),
            ],
        ).json(by_alias=True),
    )
    BulkResponse.parse_obj(response.json())

    response = client.post(
        f"/api/datasets/{dataset}/TextClassification:search",
        json={},
    )

    results = TextClassificationSearchResults.parse_obj(response.json())
    assert results.aggregations.words is not None
Exemplo n.º 13
0
def create_some_data_for_text_classification(name: str, n: int):
    """Populate dataset ``name`` with ``n`` (or 10 when falsy) multi-label records.

    Even ids carry only an annotation; odd ids carry both a (non-matching)
    prediction and an annotation.
    """
    records = []
    for idx in range(0, n or 10, 2):
        records.append(
            TextClassificationRecord(
                id=idx,
                inputs={"data": "my data"},
                multi_label=True,
                metadata={"field_one": "value one", "field_two": "value 2"},
                status=TaskStatus.validated,
                annotation={
                    "agent": "test",
                    "labels": [{"class": "Test"}, {"class": "Mocking"}],
                },
            )
        )
        records.append(
            TextClassificationRecord(
                id=idx + 1,
                inputs={"data": "my data"},
                multi_label=True,
                metadata={"field_one": "another value one", "field_two": "value 2"},
                status=TaskStatus.validated,
                prediction={"agent": "test", "labels": [{"class": "NoClass"}]},
                annotation={"agent": "test", "labels": [{"class": "Test"}]},
            )
        )
    client.post(
        f"/api/datasets/{name}/{TaskType.text_classification}:bulk",
        json=TextClassificationBulkData(
            tags={"env": "test", "class": "text classification"},
            metadata={"config": {"the": "config"}},
            records=records,
        ).dict(by_alias=True),
    )
Exemplo n.º 14
0
def test_predicted_ok_for_multilabel_unordered():
    """Multi-label predicted status must ignore label ordering."""
    prediction = TextClassificationAnnotation(
        agent="test",
        labels=[
            ClassPrediction(class_label="B"),
            ClassPrediction(class_label="C", score=0.3),
            ClassPrediction(class_label="A"),
        ],
    )
    annotation = TextClassificationAnnotation(
        agent="test",
        labels=[
            ClassPrediction(class_label="A"),
            ClassPrediction(class_label="B"),
        ],
    )
    record = TextClassificationRecord(
        inputs={"text": "The text"},
        prediction=prediction,
        annotation=annotation,
        multi_label=True,
    )

    assert record.predicted == PredictionStatus.OK
Exemplo n.º 15
0
def test_sort_by_id_as_default():
    """Without an explicit sort, results are ordered by id as keyword (lexicographic)."""
    dataset = "test_sort_by_id_as_default"
    assert client.delete(f"/api/datasets/{dataset}").status_code == 200
    response = client.post(
        f"/api/datasets/{dataset}/TextClassification:bulk",
        json=TextClassificationBulkData(
            records=[
                TextClassificationRecord(
                    id=idx,
                    inputs={"data": "my data"},
                    metadata={"s": "value"},
                )
                for idx in range(0, 100)
            ],
        ).dict(by_alias=True),
    )
    response = client.post(
        f"/api/datasets/{dataset}/TextClassification:search?from=0&limit=10",
        json={},
    )

    results = TextClassificationSearchResults.parse_obj(response.json())
    assert results.total == 100
    # Lexicographic ordering of ids: 0, 1, 10, 11, ... rather than numeric.
    assert [record.id for record in results.records] == [
        0,
        1,
        10,
        11,
        12,
        13,
        14,
        15,
        16,
        17,
    ]
Exemplo n.º 16
0
def test_partial_record_update():
    """Re-uploading a record with the same id merges the new fields (here the
    annotation) into the stored record instead of discarding the old ones."""
    name = "test_partial_record_update"
    assert client.delete(f"/api/datasets/{name}").status_code == 200

    record = TextClassificationRecord(
        **{
            "id": 1,
            "inputs": {"text": "This is a text, oh yeah!"},
            "prediction": {
                "agent": "test",
                "labels": [
                    {"class": "Positive", "confidence": 0.6},
                    {"class": "Negative", "confidence": 0.3},
                    {"class": "Other", "confidence": 0.1},
                ],
            },
        }
    )

    bulk = TextClassificationBulkData(
        records=[record],
    )

    # First upload: prediction only, no annotation yet.
    response = client.post(
        f"/api/datasets/{name}/TextClassification:bulk",
        json=bulk.dict(by_alias=True),
    )

    assert response.status_code == 200
    bulk_response = BulkResponse.parse_obj(response.json())
    assert bulk_response.failed == 0
    assert bulk_response.processed == 1

    # Second upload: same id, now with an annotation attached.
    record.annotation = TextClassificationAnnotation.parse_obj(
        {
            "agent": "gold_standard",
            "labels": [{"class": "Positive"}],
        }
    )

    bulk.records = [record]

    client.post(
        f"/api/datasets/{name}/TextClassification:bulk",
        json=bulk.dict(by_alias=True),
    )

    # predicted=OK requires both prediction and annotation on the stored
    # record, so a match proves the second upload merged rather than replaced.
    response = client.post(
        f"/api/datasets/{name}/TextClassification:search",
        json={
            "query": TextClassificationQuery(predicted=PredictionStatus.OK).dict(
                by_alias=True
            ),
        },
    )

    assert response.status_code == 200
    results = TextClassificationSearchResults.parse_obj(response.json())
    assert results.total == 1
    first_record = results.records[0]
    # last_updated is set by the server; clear it so the equality check below
    # compares only the client-provided fields.
    assert first_record.last_updated is not None
    first_record.last_updated = None
    assert TextClassificationRecord(
        **first_record.dict(by_alias=True, exclude_none=True)
    ) == TextClassificationRecord(
        **{
            "id": 1,
            "inputs": {"text": "This is a text, oh yeah!"},
            "prediction": {
                "agent": "test",
                "labels": [
                    {"class": "Positive", "confidence": 0.6},
                    {"class": "Negative", "confidence": 0.3},
                    {"class": "Other", "confidence": 0.1},
                ],
            },
            "annotation": {
                "agent": "gold_standard",
                "labels": [{"class": "Positive"}],
            },
        }
    )
Exemplo n.º 17
0
def test_create_records_for_text_classification_with_multi_label():
    """Multi-label records can be bulk-created, and a second bulk upload merges
    (rather than replaces) the dataset-level tags and metadata."""
    dataset = "test_create_records_for_text_classification_with_multi_label"
    assert client.delete(f"/api/datasets/{dataset}").status_code == 200

    records = [
        TextClassificationRecord.parse_obj(data)
        for data in [
            {
                "id": 0,
                "inputs": {"data": "my data"},
                "multi_label": True,
                "metadata": {"field_one": "value one", "field_two": "value 2"},
                "prediction": {
                    "agent": "test",
                    "labels": [
                        {"class": "Test", "confidence": 0.6},
                        {"class": "Mocking", "confidence": 0.7},
                        {"class": "NoClass", "confidence": 0.2},
                    ],
                },
            },
            {
                "id": 1,
                "inputs": {"data": "my data"},
                "multi_label": True,
                "metadata": {"field_one": "another value one", "field_two": "value 2"},
                "prediction": {
                    "agent": "test",
                    "labels": [
                        {"class": "Test", "confidence": 0.6},
                        {"class": "Mocking", "confidence": 0.7},
                        {"class": "NoClass", "confidence": 0.2},
                    ],
                },
            },
        ]
    ]
    # First bulk upload sets the initial tags and metadata.
    response = client.post(
        f"/api/datasets/{dataset}/TextClassification:bulk",
        json=TextClassificationBulkData(
            tags={"env": "test", "class": "text classification"},
            metadata={"config": {"the": "config"}},
            records=records,
        ).dict(by_alias=True),
    )

    assert response.status_code == 200, response.json()
    bulk_response = BulkResponse.parse_obj(response.json())
    assert bulk_response.dataset == dataset
    assert bulk_response.failed == 0
    assert bulk_response.processed == 2

    # Second bulk upload with different tags/metadata: both sets must be
    # present afterwards (merge semantics).
    response = client.post(
        f"/api/datasets/{dataset}/TextClassification:bulk",
        json=TextClassificationBulkData(
            tags={"new": "tag"},
            metadata={"new": {"metadata": "value"}},
            records=records,
        ).dict(by_alias=True),
    )

    get_dataset = Dataset.parse_obj(client.get(f"/api/datasets/{dataset}").json())
    assert get_dataset.tags == {
        "env": "test",
        "class": "text classification",
        "new": "tag",
    }
    assert get_dataset.metadata == {
        "config": {"the": "config"},
        "new": {"metadata": "value"},
    }

    assert response.status_code == 200, response.json()

    response = client.post(
        f"/api/datasets/{dataset}/TextClassification:search", json={}
    )

    assert response.status_code == 200
    results = TextClassificationSearchResults.parse_obj(response.json())
    assert results.total == 2
    # Both labels above the decision threshold are aggregated for each record;
    # without annotations the predicted status stays unset.
    assert results.aggregations.predicted_as == {"Mocking": 2, "Test": 2}
    assert results.records[0].predicted is None
Exemplo n.º 18
0
def test_prediction_ok_cases():
    """Predicted status follows annotation/prediction changes on a multi-label record."""
    data = {
        "multi_label": True,
        "inputs": {"data": "My cool data"},
        "prediction": {
            "agent": "test",
            "labels": [
                {"class": "A", "confidence": 0.3},
                {"class": "B", "confidence": 0.9},
            ],
        },
    }

    record = TextClassificationRecord(**data)
    # Without an annotation there is nothing to compare against.
    assert record.predicted is None

    record.annotation = TextClassificationAnnotation(
        agent="test",
        labels=[
            {"class": "A", "confidence": 1},
            {"class": "B", "confidence": 1},
        ],
    )
    assert record.predicted == PredictionStatus.KO

    record.prediction = TextClassificationAnnotation(
        agent="test",
        labels=[
            {"class": "A", "confidence": 0.9},
            {"class": "B", "confidence": 0.9},
        ],
    )
    assert record.predicted == PredictionStatus.OK

    # Removing the prediction clears the status again.
    record.prediction = None
    assert record.predicted is None