コード例 #1
0
def test_create_records_for_text_classification():
    """Bulk-load one predicted record, then check dataset info and search results.

    The record carries a prediction with two labels ("Test" at 0.3 and
    "Mocking" at 0.7). After the bulk call the test verifies the bulk
    response counters, the tags/metadata stored on the dataset, and the
    aggregations returned by an unfiltered search.
    """
    dataset = "test_create_records_for_text_classification"
    dataset_url = f"/api/datasets/{dataset}"

    # Remove any dataset left over from a previous run.
    deletion = client.delete(dataset_url)
    assert deletion.status_code == 200

    tags = {"env": "test", "class": "text classification"}
    metadata = {"config": {"the": "config"}}
    single_record = TextClassificationRecord(
        id=0,
        inputs={"data": "my data"},
        prediction={
            "agent": "test",
            "labels": [
                {"class": "Test", "confidence": 0.3},
                {"class": "Mocking", "confidence": 0.7},
            ],
        },
    )
    classification_bulk = TextClassificationBulkData(
        tags=tags,
        metadata=metadata,
        records=[single_record],
    )

    # Upload the record through the bulk endpoint.
    response = client.post(
        f"{dataset_url}/TextClassification:bulk",
        json=classification_bulk.dict(by_alias=True),
    )
    assert response.status_code == 200
    bulk_result = BulkResponse.parse_obj(response.json())
    assert bulk_result.dataset == dataset
    assert bulk_result.failed == 0
    assert bulk_result.processed == 1

    # The dataset must expose the tags and metadata sent with the bulk.
    response = client.get(dataset_url)
    assert response.status_code == 200
    stored_dataset = Dataset.parse_obj(response.json())
    assert stored_dataset.tags == tags
    assert stored_dataset.metadata == metadata

    # An empty search body returns the single record plus aggregations.
    response = client.post(f"{dataset_url}/TextClassification:search", json={})
    assert response.status_code == 200
    results = TextClassificationSearchResults.parse_obj(response.json())
    assert results.total == 1

    aggregations = results.aggregations
    assert aggregations.predicted_as == {"Mocking": 1}
    assert aggregations.status == {"Default": 1}
    assert aggregations.score
    assert aggregations.predicted == {}
コード例 #2
0
def test_include_event_timestamp():
    """Check that records uploaded with an event timestamp keep it in search results.

    One hundred records are sent in a single bulk request, each with
    ``event_timestamp`` set to the current UTC time. The search is issued
    with ``from=10`` and every returned record must expose a timestamp.
    """
    dataset = "test_include_event_timestamp"
    dataset_url = f"/api/datasets/{dataset}"
    assert client.delete(dataset_url).status_code == 200

    # Build the records one by one; each gets its own timestamp.
    records = []
    for record_id in range(0, 100):
        records.append(
            TextClassificationRecord(
                id=record_id,
                inputs={"data": "my data"},
                event_timestamp=datetime.utcnow(),
                prediction={
                    "agent": "test",
                    "labels": [
                        {"class": "Test", "confidence": 0.3},
                        {"class": "Mocking", "confidence": 0.7},
                    ],
                },
            )
        )

    bulk = TextClassificationBulkData(
        tags={"env": "test", "class": "text classification"},
        metadata={"config": {"the": "config"}},
        records=records,
    )
    # The payload is sent as a pre-serialized JSON string via ``data=``.
    response = client.post(
        f"{dataset_url}/TextClassification:bulk",
        data=bulk.json(by_alias=True),
    )
    bulk_result = BulkResponse.parse_obj(response.json())
    assert bulk_result.processed == 100

    response = client.post(
        f"{dataset_url}/TextClassification:search?from=10",
        json={},
    )
    results = TextClassificationSearchResults.parse_obj(response.json())
    assert results.total == 100
    assert all(found.event_timestamp is not None for found in results.records)
コード例 #3
0
def test_disable_aggregations_when_scroll():
    """Check that a search issued with ``from=10`` returns no aggregations.

    One hundred predicted records are bulk-loaded first; the search response
    must still report the full total while ``aggregations`` is ``None``.
    """
    dataset = "test_disable_aggregations_when_scroll"
    dataset_url = f"/api/datasets/{dataset}"
    assert client.delete(dataset_url).status_code == 200

    records = [
        TextClassificationRecord(
            id=record_id,
            inputs={"data": "my data"},
            prediction={
                "agent": "test",
                "labels": [
                    {"class": "Test", "confidence": 0.3},
                    {"class": "Mocking", "confidence": 0.7},
                ],
            },
        )
        for record_id in range(0, 100)
    ]
    bulk = TextClassificationBulkData(
        tags={"env": "test", "class": "text classification"},
        metadata={"config": {"the": "config"}},
        records=records,
    )

    response = client.post(
        f"{dataset_url}/TextClassification:bulk",
        json=bulk.dict(by_alias=True),
    )
    bulk_result = BulkResponse.parse_obj(response.json())
    assert bulk_result.processed == 100

    # Search starting at offset 10.
    response = client.post(
        f"{dataset_url}/TextClassification:search?from=10",
        json={},
    )
    results = TextClassificationSearchResults.parse_obj(response.json())
    assert results.total == 100
    assert results.aggregations is None
コード例 #4
0
ファイル: test_api.py プロジェクト: recognai/rubrix
def test_words_cloud():
    """Check that a search over text records returns a ``words`` aggregation.

    Three records with short texts are bulk-loaded without predictions;
    the unfiltered search must come back with ``aggregations.words`` set.
    """
    # NOTE(review): the dataset name does not match the test name — confirm
    # this is intentional and does not clash with another test's dataset.
    dataset = "test_language_detection"
    dataset_url = f"/api/datasets/{dataset}"
    assert client.delete(dataset_url).status_code == 200

    texts = [
        "Esto es un ejemplo de texto",
        "This is an simple text example",
        "C'est nes pas une pipe",
    ]
    bulk = TextClassificationBulkData(
        records=[
            TextClassificationRecord(id=record_id, inputs={"text": text})
            for record_id, text in enumerate(texts)
        ],
    )

    # The payload is sent as a pre-serialized JSON string via ``data=``.
    response = client.post(
        f"{dataset_url}/TextClassification:bulk",
        data=bulk.json(by_alias=True),
    )
    # Parsing validates the response shape; the value itself is not needed.
    BulkResponse.parse_obj(response.json())

    response = client.post(
        f"{dataset_url}/TextClassification:search",
        json={},
    )
    results = TextClassificationSearchResults.parse_obj(response.json())
    assert results.aggregations.words is not None
コード例 #5
0
def bulk_records(
        name: str,
        bulk: TokenClassificationBulkData,
        service: TokenClassificationService = Depends(
            token_classification_service),
        datasets: DatasetsService = Depends(create_dataset_service),
        current_user: User = Security(auth.get_user, scopes=[]),
) -> BulkResponse:
    """
    Store a batch of records in a dataset, upserting the dataset first

    The dataset is upserted from the bulk fields plus the given name, using
    the user's current group as owner. The records are then handed to the
    service and its processed/failed counters are returned.

    Parameters
    ----------
    name:
        The dataset name
    bulk:
        The bulk data (dataset information and records)
    service:
        The service used to add the records
    datasets:
        The dataset service used for the upsert
    current_user:
        Current request user

    Returns
    -------
        Bulk response with the dataset name and the processed/failed counters
    """
    owner = current_user.current_group

    # The bulk fields double as the dataset definition; the name comes from
    # the call arguments and overrides any "name" key in the bulk data.
    dataset_fields = dict(bulk.dict(), name=name)
    datasets.upsert(
        CreationDatasetRequest(**dataset_fields),
        owner=owner,
        task=TASK_TYPE,
    )

    outcome = service.add_records(dataset=name, owner=owner, records=bulk.records)
    return BulkResponse(
        dataset=name,
        processed=outcome.processed,
        failed=outcome.failed,
    )
コード例 #6
0
def test_create_records_for_token_classification():
    """Bulk-load two token classification records and check the bulk counters.

    Both records share the same text, tokens (the text split on single
    spaces) and metadata. The bulk response must report two processed
    records and no failures.
    """
    dataset = "test_create_records_for_token_classification"
    dataset_url = f"/api/datasets/{dataset}"
    assert client.delete(dataset_url).status_code == 200

    raw_text = "This is a text"
    # A fresh payload dict is built for every record.
    records = [
        TokenClassificationRecord.parse_obj(
            {
                "tokens": raw_text.split(" "),
                "raw_text": raw_text,
                "metadata": {"field_one": "value one", "field_two": "value 2"},
            }
        )
        for _ in range(2)
    ]
    bulk = TokenClassificationBulkData(
        tags={"env": "test", "class": "text classification"},
        metadata={"config": {"the": "config"}},
        records=records,
    )

    response = client.post(
        f"{dataset_url}/TokenClassification:bulk",
        json=bulk.dict(by_alias=True),
    )

    # Include the response body in the failure message for easier debugging.
    assert response.status_code == 200, response.json()
    bulk_result = BulkResponse.parse_obj(response.json())
    assert bulk_result.dataset == dataset
    assert bulk_result.failed == 0
    assert bulk_result.processed == 2
コード例 #7
0
ファイル: service.py プロジェクト: recognai/rubrix
    def add_records(
        self,
        dataset: str,
        owner: Optional[str],
        records: List[CreationTokenClassificationRecord],
    ):
        """
        Store the given records in a dataset found by name and owner

        Each incoming record is parsed into a ``TokenClassificationRecord``,
        stamped with one shared ``last_updated`` time and exported as a dict
        without ``None`` fields before being passed to the DAO.

        Parameters
        ----------
        dataset:
            The dataset name
        owner:
            The dataset owner used for the lookup
        records:
            The records to store

        Returns
        -------
            Bulk response with the dataset name, the number of records
            received and the failure value reported by the DAO
        """
        found_dataset = self.__datasets__.find_by_name(dataset, owner=owner)

        # One timestamp for the whole batch.
        timestamp = datetime.datetime.now()

        def _to_db_dict(record):
            """Convert one incoming record to its stored dict form."""
            db_record = TokenClassificationRecord.parse_obj(record)
            db_record.last_updated = timestamp
            return db_record.dict(exclude_none=True)

        failed = self.__dao__.add_records(
            dataset=found_dataset,
            records=[_to_db_dict(record) for record in records],
        )
        return BulkResponse(
            dataset=found_dataset.name,
            processed=len(records),
            failed=failed,
        )
コード例 #8
0
ファイル: test_api.py プロジェクト: recognai/rubrix
def test_create_records_for_text_classification_with_multi_label():
    """Bulk-load multi-label records twice and check merged dataset info and search.

    The same two records are posted in two bulk requests carrying different
    tags and metadata. The dataset is then expected to hold the union of
    both tag sets and both metadata dicts, and the search must return two
    records with the asserted ``predicted_as`` aggregation.
    """
    dataset = "test_create_records_for_text_classification_with_multi_label"
    dataset_url = f"/api/datasets/{dataset}"
    bulk_url = f"{dataset_url}/TextClassification:bulk"
    assert client.delete(dataset_url).status_code == 200

    def _record_payload(record_id, field_one):
        """Build a fresh multi-label payload; only id and field_one vary."""
        return {
            "id": record_id,
            "inputs": {"data": "my data"},
            "multi_label": True,
            "metadata": {"field_one": field_one, "field_two": "value 2"},
            "prediction": {
                "agent": "test",
                "labels": [
                    {"class": "Test", "confidence": 0.6},
                    {"class": "Mocking", "confidence": 0.7},
                    {"class": "NoClass", "confidence": 0.2},
                ],
            },
        }

    records = [
        TextClassificationRecord.parse_obj(_record_payload(record_id, field_one))
        for record_id, field_one in ((0, "value one"), (1, "another value one"))
    ]

    # First bulk: creates the dataset with the initial tags and metadata.
    first_bulk = TextClassificationBulkData(
        tags={"env": "test", "class": "text classification"},
        metadata={"config": {"the": "config"}},
        records=records,
    )
    response = client.post(bulk_url, json=first_bulk.dict(by_alias=True))

    assert response.status_code == 200, response.json()
    bulk_result = BulkResponse.parse_obj(response.json())
    assert bulk_result.dataset == dataset
    assert bulk_result.failed == 0
    assert bulk_result.processed == 2

    # Second bulk: same records, new tags and metadata.
    second_bulk = TextClassificationBulkData(
        tags={"new": "tag"},
        metadata={"new": {"metadata": "value"}},
        records=records,
    )
    response = client.post(bulk_url, json=second_bulk.dict(by_alias=True))

    # The dataset is expected to contain both sets of tags and metadata.
    stored_dataset = Dataset.parse_obj(client.get(dataset_url).json())
    assert stored_dataset.tags == {
        "env": "test",
        "class": "text classification",
        "new": "tag",
    }
    assert stored_dataset.metadata == {
        "config": {"the": "config"},
        "new": {"metadata": "value"},
    }

    assert response.status_code == 200, response.json()

    response = client.post(f"{dataset_url}/TextClassification:search", json={})

    assert response.status_code == 200
    results = TextClassificationSearchResults.parse_obj(response.json())
    assert results.total == 2
    assert results.aggregations.predicted_as == {"Mocking": 2, "Test": 2}
    assert results.records[0].predicted is None
コード例 #9
0
ファイル: test_api.py プロジェクト: recognai/rubrix
def test_partial_record_update():
    """Re-send a stored record with an annotation and check the searched result.

    A record with only a prediction is bulk-loaded, then posted again with a
    "gold_standard" annotation added. A search filtered by
    ``predicted=PredictionStatus.OK`` must return exactly one record that
    equals the original fields plus the annotation, ignoring
    ``last_updated``.
    """
    name = "test_partial_record_update"
    dataset_url = f"/api/datasets/{name}"
    bulk_url = f"{dataset_url}/TextClassification:bulk"
    assert client.delete(dataset_url).status_code == 200

    def _base_fields():
        """Return a fresh copy of the record fields sent in the first bulk."""
        return {
            "id": 1,
            "inputs": {"text": "This is a text, oh yeah!"},
            "prediction": {
                "agent": "test",
                "labels": [
                    {"class": "Positive", "confidence": 0.6},
                    {"class": "Negative", "confidence": 0.3},
                    {"class": "Other", "confidence": 0.1},
                ],
            },
        }

    def _annotation_fields():
        """Return a fresh copy of the annotation added in the second bulk."""
        return {
            "agent": "gold_standard",
            "labels": [{"class": "Positive"}],
        }

    record = TextClassificationRecord(**_base_fields())
    bulk = TextClassificationBulkData(records=[record])

    # First upload: prediction only.
    response = client.post(bulk_url, json=bulk.dict(by_alias=True))
    assert response.status_code == 200
    bulk_result = BulkResponse.parse_obj(response.json())
    assert bulk_result.failed == 0
    assert bulk_result.processed == 1

    # Second upload: same record id, now carrying an annotation.
    record.annotation = TextClassificationAnnotation.parse_obj(_annotation_fields())
    bulk.records = [record]
    client.post(bulk_url, json=bulk.dict(by_alias=True))

    query = TextClassificationQuery(predicted=PredictionStatus.OK)
    response = client.post(
        f"{dataset_url}/TextClassification:search",
        json={"query": query.dict(by_alias=True)},
    )

    assert response.status_code == 200
    results = TextClassificationSearchResults.parse_obj(response.json())
    assert results.total == 1

    first_record = results.records[0]
    assert first_record.last_updated is not None
    # Clear the server-set timestamp so it does not affect the comparison.
    first_record.last_updated = None

    actual = TextClassificationRecord(
        **first_record.dict(by_alias=True, exclude_none=True)
    )
    expected = TextClassificationRecord(
        **{**_base_fields(), "annotation": _annotation_fields()}
    )
    assert actual == expected