Example #1
def test_confidence_integrity():
    data = {
        "multi_label": False,
        "inputs": {
            "data": "My cool data"
        },
        "prediction": {
            "agent":
            "test",
            "labels": [
                {
                    "class": "A",
                    "confidence": 0.3
                },
                {
                    "class": "B",
                    "confidence": 0.9
                },
            ],
        },
    }

    # Confidences of a single-label, multi-class prediction must sum to 1.
    with pytest.raises(ValidationError, match="Wrong score distributions"):
        TextClassificationRecord.parse_obj(data)

    data["multi_label"] = True
    record = TextClassificationRecord.parse_obj(data)
    assert record is not None

    data["multi_label"] = False
    data["prediction"]["labels"] = [
        {
            "class": "B",
            "confidence": 0.9
        },
    ]
    record = TextClassificationRecord.parse_obj(data)
    assert record is not None

    data["prediction"]["labels"] = [
        {
            "class": "B",
            "confidence": 0.10000000012
        },
        {
            "class": "B",
            "confidence": 0.90000000002
        },
    ]
    record = TextClassificationRecord.parse_obj(data)
    assert record is not None
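
The four cases above point at a distribution check roughly like the following sketch. All names here (PredictionSketch, ClassPrediction, check_score_distribution) are illustrative assumptions, not Rubrix's actual model code.

import math
from typing import List

from pydantic import BaseModel, root_validator


class ClassPrediction(BaseModel):
    # "class" is a reserved word in Python, so the sketch uses class_label.
    class_label: str
    confidence: float = 1.0


class PredictionSketch(BaseModel):
    multi_label: bool = False
    labels: List[ClassPrediction] = []

    @root_validator
    def check_score_distribution(cls, values):
        labels = values.get("labels") or []
        # Only single-label predictions with several candidate labels must
        # form a probability distribution; a small float tolerance covers
        # the 1.00000000014 sum in the test above.
        if not values.get("multi_label") and len(labels) > 1:
            total = sum(label.confidence for label in labels)
            if not math.isclose(total, 1.0, abs_tol=1e-3):
                raise ValueError("Wrong score distributions")
        return values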
Example #2
def test_too_long_label():
    with pytest.raises(ValidationError, match="exceeds max length"):
        TextClassificationRecord.parse_obj({
            "inputs": {
                "text": "bogh"
            },
            "prediction": {
                "agent": "test",
                "labels": [{
                    "class": "a" * 1000
                }],
            },
        })
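
A validator along these lines would produce the error the test matches on; the MAX_KEYWORD_LENGTH value and the LabelSketch model are assumptions for the sketch, not Rubrix's real code.

from pydantic import BaseModel, validator

MAX_KEYWORD_LENGTH = 128  # assumed value; the real constant lives in Rubrix


class LabelSketch(BaseModel):
    name: str

    @validator("name")
    def check_length(cls, value):
        # Reject labels longer than the keyword limit with a message
        # containing "exceeds max length", as the test expects.
        if len(value) > MAX_KEYWORD_LENGTH:
            raise ValueError(f"Value exceeds max length of {MAX_KEYWORD_LENGTH}")
        return value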
Example #3
def test_created_record_with_default_status():
    data = {
        "inputs": {
            "data": "My cool data"
        },
    }

    record = TextClassificationRecord.parse_obj(data)
    assert record.status == TaskStatus.default
Example #4
def test_too_long_metadata():
    record = TextClassificationRecord.parse_obj({
        "inputs": {
            "text": "bogh"
        },
        "metadata": {
            "too_long": "a" * 1000
        },
    })

    assert len(record.metadata["too_long"]) == MAX_KEYWORD_LENGTH
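
Unlike over-long labels, over-long metadata is truncated rather than rejected. A plausible sketch of that behavior (the helper name is hypothetical):

def limit_metadata_values(metadata: dict, max_length: int) -> dict:
    # Clip string values to the keyword limit; leave other types untouched.
    return {
        key: value[:max_length] if isinstance(value, str) else value
        for key, value in metadata.items()
    }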
Example #5
def test_flatten_inputs():
    data = {
        "inputs": {
            "mail": {
                "subject": "The mail subject",
                "body": "This is a large text body"
            }
        }
    }
    record = TextClassificationRecord.parse_obj(data)
    assert list(record.inputs.keys()) == ["mail.subject", "mail.body"]
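
A minimal flattening helper reproducing the behavior the test asserts (the real implementation lives inside Rubrix and may differ):

def flatten(data: dict, parent_key: str = "") -> dict:
    # Recursively join nested keys with dots: {"a": {"b": 1}} -> {"a.b": 1}.
    items = {}
    for key, value in data.items():
        new_key = f"{parent_key}.{key}" if parent_key else key
        if isinstance(value, dict):
            items.update(flatten(value, new_key))
        else:
            items[new_key] = value
    return items


assert flatten({"mail": {"subject": "s", "body": "b"}}) == {
    "mail.subject": "s",
    "mail.body": "b",
}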
Example #6
def test_create_records_for_text_classification_with_multi_label():
    dataset = "test_create_records_for_text_classification_with_multi_label"
    assert client.delete(f"/api/datasets/{dataset}").status_code == 200

    records = [
        TextClassificationRecord.parse_obj(data)
        for data in [
            {
                "id": 0,
                "inputs": {"data": "my data"},
                "multi_label": True,
                "metadata": {"field_one": "value one", "field_two": "value 2"},
                "prediction": {
                    "agent": "test",
                    "labels": [
                        {"class": "Test", "confidence": 0.6},
                        {"class": "Mocking", "confidence": 0.7},
                        {"class": "NoClass", "confidence": 0.2},
                    ],
                },
            },
            {
                "id": 1,
                "inputs": {"data": "my data"},
                "multi_label": True,
                "metadata": {"field_one": "another value one", "field_two": "value 2"},
                "prediction": {
                    "agent": "test",
                    "labels": [
                        {"class": "Test", "confidence": 0.6},
                        {"class": "Mocking", "confidence": 0.7},
                        {"class": "NoClass", "confidence": 0.2},
                    ],
                },
            },
        ]
    ]
    response = client.post(
        f"/api/datasets/{dataset}/TextClassification:bulk",
        json=TextClassificationBulkData(
            tags={"env": "test", "class": "text classification"},
            metadata={"config": {"the": "config"}},
            records=records,
        ).dict(by_alias=True),
    )

    assert response.status_code == 200, response.json()
    bulk_response = BulkResponse.parse_obj(response.json())
    assert bulk_response.dataset == dataset
    assert bulk_response.failed == 0
    assert bulk_response.processed == 2

    # A second bulk request merges its tags and metadata into the dataset's
    # existing ones.
    response = client.post(
        f"/api/datasets/{dataset}/TextClassification:bulk",
        json=TextClassificationBulkData(
            tags={"new": "tag"},
            metadata={"new": {"metadata": "value"}},
            records=records,
        ).dict(by_alias=True),
    )

    assert response.status_code == 200, response.json()

    get_dataset = Dataset.parse_obj(client.get(f"/api/datasets/{dataset}").json())
    assert get_dataset.tags == {
        "env": "test",
        "class": "text classification",
        "new": "tag",
    }
    assert get_dataset.metadata == {
        "config": {"the": "config"},
        "new": {"metadata": "value"},
    }

    # Search the dataset; NoClass (confidence 0.2) does not appear in the
    # predicted_as aggregation below.
    response = client.post(
        f"/api/datasets/{dataset}/TextClassification:search", json={}
    )

    assert response.status_code == 200
    results = TextClassificationSearchResults.parse_obj(response.json())
    assert results.total == 2
    assert results.aggregations.predicted_as == {"Mocking": 2, "Test": 2}
    assert results.records[0].predicted is None
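
The tag and metadata accumulation asserted above behaves like a plain dict merge; a sketch of those semantics (not the actual server code):

existing_tags = {"env": "test", "class": "text classification"}
incoming_tags = {"new": "tag"}
assert {**existing_tags, **incoming_tags} == {
    "env": "test",
    "class": "text classification",
    "new": "tag",
}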