コード例 #1
0
ファイル: test_client.py プロジェクト: recognai/rubrix
def test_log_with_annotation(monkeypatch):
    """Log an annotated record, then re-log the same id as discarded.

    The first load is expected to return one record with status
    "Validated"; after re-logging id 0 with ``status="Discarded"`` the
    dataset is expected to still hold exactly one record, now "Discarded".
    """
    mocking_client(monkeypatch)
    name = "test_log_with_annotation"
    client.delete(f"/api/datasets/{name}")

    def log_then_load(**extra_fields):
        # Log the fixed annotated record (plus any extra fields) and
        # return the reloaded dataset as a list of row dicts.
        rubrix.log(
            TextClassificationRecord(
                id=0,
                inputs={"text": "The text data"},
                annotation="T",
                annotation_agent="test",
                **extra_fields,
            ),
            name=name,
        )
        return rubrix.load(name).to_dict(orient="records")

    loaded = log_then_load()
    assert len(loaded) == 1
    assert loaded[0]["status"] == "Validated"

    loaded = log_then_load(status="Discarded")
    assert len(loaded) == 1
    assert loaded[0]["status"] == "Discarded"
コード例 #2
0
ファイル: test_log.py プロジェクト: javispp/rubrix
def test_text_classification(mock_response_200, mock_response_text):
    """Check that ``rubrix.log`` returns the expected ``BulkResponse``.

    The expected counts (500 processed, 0 failed) presumably come from the
    mocked sync response rather than from the single record logged here.

    Parameters
    ----------
    mock_response_200
        Mocked correct http response, emulating API init
    mock_response_text
        Mocked response given by the sync method, emulating the log of data
    """
    record = TextClassificationRecord(
        inputs={"review_body": "increible test"},
        prediction=[("test", 0.9), ("test2", 0.1)],
        annotation="test",
        metadata={"product_category": "test de pytest"},
        id="test",
    )
    dataset_tags = {"type": "sentiment classifier", "lang": "spanish"}

    response = rubrix.log(name="test", records=[record], tags=dataset_tags)
    assert response == BulkResponse(dataset="test", processed=500, failed=0)
コード例 #3
0
ファイル: test_client.py プロジェクト: recognai/rubrix
def test_snapshots(monkeypatch):
    """Create a dataset snapshot and load both the live data and the snapshot.

    Seeds ``expected_data`` text-classification records, requests a snapshot
    through the REST endpoint, checks the metadata returned by
    ``rubrix.snapshots``, verifies the live dataset loads as a DataFrame of
    the right size whose rows round-trip into records, and finally loads the
    first snapshot.
    """
    mocking_client(monkeypatch)
    dataset = "test_create_dataset_snapshot"
    api_ds_prefix = f"/api/datasets/{dataset}"
    client.delete(api_ds_prefix)
    # Presumably gives the backend time to finish the deletion before reseeding.
    sleep(1)

    expected_data = 100
    create_some_data_for_text_classification(dataset, n=expected_data)
    response = client.post(f"{api_ds_prefix}/snapshots")
    assert response.status_code == 200
    snapshots = rubrix.snapshots(dataset)
    assert len(snapshots) > 0
    for snapshot in snapshots:
        assert snapshot.task == TaskType.text_classification
        assert snapshot.id
        assert snapshot.creation_date

    ds = rubrix.load(name=dataset)
    assert isinstance(ds, pandas.DataFrame)
    assert len(ds) == expected_data
    # Round-trip every loaded row back into a record; this was previously
    # computed and silently discarded, so make the check explicit.
    records = [
        TextClassificationRecord(**row) for row in ds.to_dict(orient="records")
    ]
    assert len(records) == expected_data

    ds = rubrix.load(name=dataset, snapshot=snapshots[0].id)
    assert isinstance(ds, pandas.DataFrame)
コード例 #4
0
ファイル: asgi.py プロジェクト: recognai/rubrix
def text_classification_mapper(inputs, outputs):
    """Build a TextClassificationRecord from model inputs and outputs.

    ``outputs["labels"]`` and ``outputs["probabilities"]`` are paired
    positionally into ``(label, score)`` predictions. A missing key is
    treated as an empty list, and ``zip`` stops at the shorter sequence.
    The record is stamped with ``datetime.datetime.now()`` (naive local time).
    """
    labels = outputs.get("labels", [])
    scores = outputs.get("probabilities", [])
    predictions = list(zip(labels, scores))
    return TextClassificationRecord(
        inputs=inputs,
        prediction=predictions,
        event_timestamp=datetime.datetime.now(),
    )
コード例 #5
0
ファイル: test_client.py プロジェクト: recognai/rubrix
def test_single_record(monkeypatch):
    """Log one bare record (not wrapped in a list); passes if no error is raised."""
    mocking_client(monkeypatch)
    ds_name = "test_log_single_records"
    client.delete(f"/api/datasets/{ds_name}")
    rubrix.log(
        TextClassificationRecord(
            inputs={"text": "This is a single record. Only this. No more."}
        ),
        name=ds_name,
    )
コード例 #6
0
ファイル: test_client.py プロジェクト: recognai/rubrix
def test_log_records_with_too_long_text(monkeypatch):
    """Log a record whose text repeats one line 10,000 times.

    There is no explicit assertion: the test passes as long as
    ``rubrix.log`` does not raise.
    """
    mocking_client(monkeypatch)
    ds_name = "test_log_records_with_too_long_text"
    client.delete(f"/api/datasets/{ds_name}")
    long_text = "This is a toooooo long text\n" * 10000
    batch = [TextClassificationRecord(inputs={"text": long_text})]
    rubrix.log(batch, name=ds_name)
コード例 #7
0
ファイル: test_client.py プロジェクト: recognai/rubrix
def test_create_ds_with_wrong_name(monkeypatch):
    """Logging to an invalidly named dataset must fail name validation.

    The name ``"Test Create_ds_with_wrong_name"`` is expected to be rejected
    with an error whose text contains "string does not match regex".
    """
    mocking_client(monkeypatch)
    bad_name = "Test Create_ds_with_wrong_name"
    client.delete(f"/api/datasets/{bad_name}")

    with pytest.raises(Exception, match="msg='string does not match regex"):
        record = TextClassificationRecord(inputs={"text": "The text data"})
        rubrix.log(record, name=bad_name)
コード例 #8
0
ファイル: test_log.py プロジェクト: javispp/rubrix
def test_wrong_response(mock_response_200, mock_wrong_bulk_response):
    """A wrong bulk response from the mocked API must surface as a connection error."""
    # Force client initialization against the mocked endpoints.
    rubrix._client = None
    expected_error = "Connection error: API is not responding. The API answered with"
    payload = [TextClassificationRecord(inputs={"text": "The textual info"})]
    with pytest.raises(Exception, match=expected_error):
        rubrix.log(name="dataset", records=payload, tags={"env": "Test"})
コード例 #9
0
ファイル: test_client.py プロジェクト: recognai/rubrix
def test_text_classifier_with_inputs_list(monkeypatch):
    """Inputs given as a plain list come back under the ``"text"`` key."""
    mocking_client(monkeypatch)
    dataset = "test_text_classifier_with_inputs_list"
    client.delete(f"/api/datasets/{dataset}")

    words = ["A", "List", "of", "values"]
    record = TextClassificationRecord(
        id=0,
        inputs=words,
        annotation_agent="test",
        annotation=["T"],
    )
    rubrix.log(record, name=dataset)

    rows = rubrix.load(name=dataset).to_dict(orient="records")
    assert len(rows) == 1
    assert rows[0]["inputs"]["text"] == words
コード例 #10
0
ファイル: test_client.py プロジェクト: recognai/rubrix
def test_delete_dataset(monkeypatch):
    """After ``rubrix.delete``, loading the dataset fails with a 404 error."""
    mocking_client(monkeypatch)
    dataset_name = "test_delete_dataset"
    client.delete(f"/api/datasets/{dataset_name}")

    seed = TextClassificationRecord(
        id=0,
        inputs={"text": "The text data"},
        annotation_agent="test",
        annotation=["T"],
    )
    rubrix.log(seed, name=dataset_name)
    # Loading succeeds while the dataset still exists.
    rubrix.load(name=dataset_name)
    rubrix.delete(name=dataset_name)
    # Presumably allows the deletion to propagate before re-checking.
    sleep(1)

    not_found = "Not found error. The API answered with a 404 code"
    with pytest.raises(Exception, match=not_found):
        rubrix.load(name=dataset_name)
コード例 #11
0
ファイル: test_log.py プロジェクト: javispp/rubrix
def test_info_message(mock_response_200, mock_response_text, caplog):
    """Check that the first ``rubrix.log`` call logs the initialization message.

    Parameters
    ----------
    mock_response_200
        Mocked correct http response, emulating API init
    mock_response_text
        Mocked response given by the sync method, emulating the log of data
    caplog
        Captures the logging output
    """
    rubrix._client = None  # Force client initialization
    caplog.set_level(logging.INFO)

    records = [
        TextClassificationRecord(
            inputs={"review_body": "increible test"},
            prediction=[("test", 0.9), ("test2", 0.1)],
            annotation="test",
            metadata={"product_category": "test de pytest"},
            id="test",
        )
    ]

    rubrix.log(
        name="test",
        records=records,
        tags={
            "type": "sentiment classifier",
            "lang": "spanish"
        },
    )

    # No debug print of caplog.text: pytest already reports captured logs
    # when an assertion fails.
    assert "Rubrix has been initialized on http://localhost:6900" in caplog.text
コード例 #12
0
ファイル: test_client.py プロジェクト: recognai/rubrix
def test_text_classification_record_to_sdk(annotation):
    """Converting a fully populated record to the SDK model keeps its timestamp.

    ``annotation`` is supplied by the test harness (presumably a parametrize
    decorator or fixture outside this view).
    """
    explanation = {
        "text": [
            TokenAttributions(
                token="test",
                attributions={"label1": 1.0, "label2": 2.0},
            )
        ]
    }
    timestamp = datetime.datetime(2000, 1, 1)
    record = TextClassificationRecord(
        inputs={"text": "test"},
        prediction=[("label1", 0.5), ("label2", 0.5)],
        annotation=annotation,
        prediction_agent="test_model",
        annotation_agent="test_annotator",
        multi_label=True,
        explanation=explanation,
        id=1,
        metadata={"metadata": "test"},
        status="Default",
        event_timestamp=timestamp,
    )

    converted = RubrixClient._text_classification_record_to_sdk(record)
    assert converted.event_timestamp == timestamp
コード例 #13
0
ファイル: test_client.py プロジェクト: recognai/rubrix
 def generator(items: int = 10) -> Iterable[TextClassificationRecord]:
     """Lazily yield ``items`` records with ids ``0 .. items - 1``.

     Every record carries the same input text, ``{"text": "The text data"}``.
     """
     yield from (
         TextClassificationRecord(id=idx, inputs={"text": "The text data"})
         for idx in range(items)
     )