示例#1
0
def test_snapshots(monkeypatch):
    """Create a snapshot for a populated dataset and load it back.

    The test fills a dataset with text classification records, creates a
    snapshot through the HTTP API, and then checks that:

    * ``rubrix.snapshots`` lists at least one snapshot, each with a task
      type of text classification, an id and a creation date;
    * ``rubrix.load`` returns a ``pandas.DataFrame`` with every record;
    * each loaded row can be turned into a ``TextClassificationRecord``;
    * loading by snapshot id also returns a ``pandas.DataFrame``.
    """
    mocking_client(monkeypatch)

    dataset = "test_create_dataset_snapshot"
    dataset_url = f"/api/datasets/{dataset}"

    # Remove any dataset left over from a previous run.
    client.delete(dataset_url)
    # NOTE(review): presumably gives the backend time to apply the
    # deletion before new data is written -- confirm.
    sleep(1)

    n_records = 100
    create_some_data_for_text_classification(dataset, n=n_records)

    creation = client.post(f"{dataset_url}/snapshots")
    assert creation.status_code == 200

    available = rubrix.snapshots(dataset)
    assert len(available) > 0
    for item in available:
        assert item.task == TaskType.text_classification
        assert item.id
        assert item.creation_date

    frame = rubrix.load(name=dataset)
    assert isinstance(frame, pandas.DataFrame)
    assert len(frame) == n_records

    # Building a record from every row checks that the loaded columns are
    # accepted by the record constructor; the instances are not needed.
    for row in frame.to_dict(orient="records"):
        TextClassificationRecord(**row)

    snapshot_frame = rubrix.load(name=dataset, snapshot=available[0].id)
    assert isinstance(snapshot_frame, pandas.DataFrame)
示例#2
0
def test_log_with_annotation(monkeypatch):
    """Check the ``status`` stored for annotated records.

    A record logged with an annotation and no explicit status is expected
    to be loaded back with status ``"Validated"``. Logging a record with
    the same id and an explicit ``"Discarded"`` status is expected to leave
    a single record whose status is ``"Discarded"``.
    """
    mocking_client(monkeypatch)
    dataset_name = "test_log_with_annotation"
    client.delete(f"/api/datasets/{dataset_name}")

    def load_rows():
        # Fetch the dataset and return it as a list of row dictionaries.
        return rubrix.load(dataset_name).to_dict(orient="records")

    annotated = TextClassificationRecord(
        id=0,
        inputs={"text": "The text data"},
        annotation="T",
        annotation_agent="test",
    )
    rubrix.log(annotated, name=dataset_name)

    rows = load_rows()
    assert len(rows) == 1
    assert rows[0]["status"] == "Validated"

    # Same id as before, so the dataset must still hold exactly one record.
    discarded = TextClassificationRecord(
        id=0,
        inputs={"text": "The text data"},
        annotation="T",
        annotation_agent="test",
        status="Discarded",
    )
    rubrix.log(discarded, name=dataset_name)

    rows = load_rows()
    assert len(rows) == 1
    assert rows[0]["status"] == "Discarded"
示例#3
0
def test_not_found_response(monkeypatch):
    """Check that requests for missing datasets raise a 404 error message.

    Listing snapshots, loading a dataset and loading a dataset snapshot are
    each expected to raise an exception whose message matches the
    not-found text.
    """
    mocking_client(monkeypatch)
    expected_message = "Not found error. The API answered with a 404 code"

    # Each entry is one client call that targets a dataset that is not
    # created anywhere in this test.
    failing_calls = (
        lambda: rubrix.snapshots(name="not_found"),
        lambda: rubrix.load(name="not-found"),
        lambda: rubrix.load(name="not-found", snapshot="blabla"),
    )

    for call in failing_calls:
        with pytest.raises(Exception, match=expected_message):
            call()
示例#4
0
def test_load_with_ids_list(monkeypatch):
    """Load a dataset restricted to a list of record ids.

    Loading the dataset with ``ids=[3, 5]`` is expected to return two rows.
    Loading with the same ids plus a snapshot id is expected to return
    every record, i.e. the test asserts that ``ids`` does not reduce the
    result when a snapshot is requested.
    """
    mocking_client(monkeypatch)

    dataset = "test_load_with_ids_list"
    dataset_url = f"/api/datasets/{dataset}"

    # Remove any dataset left over from a previous run.
    client.delete(dataset_url)
    # NOTE(review): presumably gives the backend time to apply the
    # deletion before new data is written -- confirm.
    sleep(1)

    expected_data = 100
    create_some_data_for_text_classification(dataset, n=expected_data)

    creation = client.post(f"{dataset_url}/snapshots")
    assert creation.status_code == 200
    snapshot = DatasetSnapshot.from_dict(creation.json())

    wanted_ids = [3, 5]
    by_ids = rubrix.load(name=dataset, ids=wanted_ids)
    assert len(by_ids) == 2

    # A fresh list is passed so both calls receive independent arguments,
    # exactly as two separate list literals would.
    from_snapshot = rubrix.load(
        name=dataset,
        ids=list(wanted_ids),
        snapshot=snapshot.id,
    )
    assert len(from_snapshot) == expected_data
示例#5
0
def test_load_limits(monkeypatch):
    """Check that ``limit`` caps the number of rows returned by ``load``.

    The dataset is filled with 50 records and a snapshot is created. Both
    a plain load and a load by snapshot id, each with ``limit=10``, are
    expected to return a ``pandas.DataFrame`` with exactly 10 rows.
    """
    mocking_client(monkeypatch)
    dataset = "test_load_limits"
    api_ds_prefix = f"/api/datasets/{dataset}"
    # Remove any dataset left over from a previous run.
    client.delete(api_ds_prefix)

    create_some_data_for_text_classification(dataset, 50)
    response = client.post(f"{api_ds_prefix}/snapshots")
    # Fail here if the snapshot could not be created, instead of later
    # with an unrelated error when the first snapshot is looked up.
    assert response.status_code == 200

    limit_data_to = 10
    ds = rubrix.load(name=dataset, limit=limit_data_to)
    assert isinstance(ds, pandas.DataFrame)
    assert len(ds) == limit_data_to

    snapshot = rubrix.snapshots(dataset)[0]
    ds = rubrix.load(name=dataset, snapshot=snapshot.id, limit=limit_data_to)
    assert isinstance(ds, pandas.DataFrame)
    assert len(ds) == limit_data_to
示例#6
0
def test_delete_dataset(monkeypatch):
    """Check that a deleted dataset can no longer be loaded.

    A single record is logged, the dataset is loaded once, then removed
    with ``rubrix.delete``. A subsequent ``rubrix.load`` is expected to
    raise an exception whose message matches the not-found text.
    """
    mocking_client(monkeypatch)
    dataset_name = "test_delete_dataset"
    client.delete(f"/api/datasets/{dataset_name}")

    record = TextClassificationRecord(
        id=0,
        inputs={"text": "The text data"},
        annotation_agent="test",
        annotation=["T"],
    )
    rubrix.log(record, name=dataset_name)

    # The load result is not inspected; the call only has to succeed
    # before the dataset is deleted.
    rubrix.load(name=dataset_name)
    rubrix.delete(name=dataset_name)
    # NOTE(review): presumably gives the backend time to apply the
    # deletion before the dataset is requested again -- confirm.
    sleep(1)

    not_found_match = "Not found error. The API answered with a 404 code"
    with pytest.raises(Exception, match=not_found_match):
        rubrix.load(name=dataset_name)
示例#7
0
def test_text_classifier_with_inputs_list(monkeypatch):
    """Log a record whose ``inputs`` is a list and load it back.

    The loaded dataset is expected to contain one record whose
    ``inputs["text"]`` equals the list that was logged.
    """
    mocking_client(monkeypatch)
    dataset = "test_text_classifier_with_inputs_list"
    client.delete(f"/api/datasets/{dataset}")

    expected_inputs = ["A", "List", "of", "values"]
    record = TextClassificationRecord(
        id=0,
        inputs=expected_inputs,
        annotation_agent="test",
        annotation=["T"],
    )
    rubrix.log(record, name=dataset)

    rows = rubrix.load(name=dataset).to_dict(orient="records")
    assert len(rows) == 1

    # The list passed as ``inputs`` is expected under the "text" key.
    only_row = rows[0]
    assert only_row["inputs"]["text"] == expected_inputs