def test_snapshots(monkeypatch):
    """Snapshot a text-classification dataset and load it back.

    Logs ``expected_data`` records and asks the API to snapshot the dataset.
    It then checks the metadata returned by ``rubrix.snapshots``. Finally it
    verifies that both the live dataset and the first listed snapshot load
    as pandas DataFrames.
    """
    mocking_client(monkeypatch)
    dataset = "test_create_dataset_snapshot"
    client.delete(f"/api/datasets/{dataset}")
    # Presumably gives the backend time to finish the deletion before the
    # dataset is re-created.
    sleep(1)

    api_ds_prefix = f"/api/datasets/{dataset}"
    expected_data = 100
    create_some_data_for_text_classification(dataset, n=expected_data)
    response = client.post(f"{api_ds_prefix}/snapshots")
    assert response.status_code == 200

    snapshots = rubrix.snapshots(dataset)
    assert len(snapshots) > 0
    for snapshot in snapshots:
        assert snapshot.task == TaskType.text_classification
        assert snapshot.id
        assert snapshot.creation_date

    ds = rubrix.load(name=dataset)
    assert isinstance(ds, pandas.DataFrame)
    assert len(ds) == expected_data
    # Every loaded row must round-trip back into a TextClassificationRecord.
    # Previously the result was built and discarded; now it is asserted on.
    records = [
        TextClassificationRecord(**row) for row in ds.to_dict(orient="records")
    ]
    assert len(records) == expected_data

    ds = rubrix.load(name=dataset, snapshot=snapshots[0].id)
    assert isinstance(ds, pandas.DataFrame)
def test_log_with_annotation(monkeypatch):
    """An annotated record is stored as Validated unless a status is given.

    The same record id is logged twice. The first time it carries no
    explicit status and is expected to load back as ``"Validated"``. The
    second time it carries ``status="Discarded"``. The dataset must still
    hold a single record, now with status ``"Discarded"``.
    """
    mocking_client(monkeypatch)
    ds_name = "test_log_with_annotation"
    client.delete(f"/api/datasets/{ds_name}")

    def _log_then_fetch(**extra_fields):
        # Log the shared annotated record (plus any extra fields) and return
        # the dataset contents as a list of row dicts.
        rubrix.log(
            TextClassificationRecord(
                id=0,
                inputs={"text": "The text data"},
                annotation="T",
                annotation_agent="test",
                **extra_fields,
            ),
            name=ds_name,
        )
        return rubrix.load(ds_name).to_dict(orient="records")

    rows = _log_then_fetch()
    assert len(rows) == 1
    assert rows[0]["status"] == "Validated"

    rows = _log_then_fetch(status="Discarded")
    assert len(rows) == 1
    assert rows[0]["status"] == "Discarded"
def test_not_found_response(monkeypatch):
    """Listing snapshots of, or loading, an unknown dataset raises a 404 error."""
    mocking_client(monkeypatch)
    expected_message = "Not found error. The API answered with a 404 code"
    # Each callable performs one request that must fail with the 404 message.
    failing_calls = (
        lambda: rubrix.snapshots(name="not_found"),
        lambda: rubrix.load(name="not-found"),
        lambda: rubrix.load(name="not-found", snapshot="blabla"),
    )
    for failing_call in failing_calls:
        with pytest.raises(Exception, match=expected_message):
            failing_call()
def test_load_with_ids_list(monkeypatch):
    """Load a dataset filtered by record ids, with and without a snapshot.

    Against the live dataset, the ``ids`` filter returns only the requested
    records. When loading from a snapshot, the test expects the full snapshot
    contents, i.e. ``ids`` has no effect there.
    """
    mocking_client(monkeypatch)
    dataset = "test_load_with_ids_list"
    client.delete(f"/api/datasets/{dataset}")
    # Presumably gives the backend time to finish the deletion before the
    # dataset is re-created.
    sleep(1)

    api_ds_prefix = f"/api/datasets/{dataset}"
    expected_data = 100
    create_some_data_for_text_classification(dataset, n=expected_data)
    response = client.post(f"{api_ds_prefix}/snapshots")
    assert response.status_code == 200
    snapshot = DatasetSnapshot.from_dict(response.json())

    ds = rubrix.load(name=dataset, ids=[3, 5])
    assert len(ds) == 2

    ds = rubrix.load(name=dataset, ids=[3, 5], snapshot=snapshot.id)
    # NOTE(review): the ids filter is not applied to snapshot loads; the whole
    # snapshot comes back. Tie the expectation to the number of records
    # created instead of a duplicated literal.
    assert len(ds) == expected_data
def test_load_limits(monkeypatch):
    """``rubrix.load`` honours ``limit`` for both live datasets and snapshots."""
    mocking_client(monkeypatch)
    dataset = "test_load_limits"
    api_ds_prefix = f"/api/datasets/{dataset}"
    client.delete(api_ds_prefix)
    create_some_data_for_text_classification(dataset, 50)

    response = client.post(f"{api_ds_prefix}/snapshots")
    # Fail here with a clear message instead of an IndexError on
    # ``rubrix.snapshots(dataset)[0]`` if snapshot creation failed.
    assert response.status_code == 200

    limit_data_to = 10
    ds = rubrix.load(name=dataset, limit=limit_data_to)
    assert isinstance(ds, pandas.DataFrame)
    assert len(ds) == limit_data_to

    snapshot = rubrix.snapshots(dataset)[0]
    ds = rubrix.load(name=dataset, snapshot=snapshot.id, limit=limit_data_to)
    assert isinstance(ds, pandas.DataFrame)
    assert len(ds) == limit_data_to
def test_delete_dataset(monkeypatch):
    """A dataset that was logged and then deleted can no longer be loaded."""
    mocking_client(monkeypatch)
    dataset_name = "test_delete_dataset"
    client.delete(f"/api/datasets/{dataset_name}")

    record = TextClassificationRecord(
        id=0,
        inputs={"text": "The text data"},
        annotation_agent="test",
        annotation=["T"],
    )
    rubrix.log(record, name=dataset_name)
    # Sanity check: the dataset is loadable before it is deleted.
    rubrix.load(name=dataset_name)

    rubrix.delete(name=dataset_name)
    sleep(1)
    not_found_message = "Not found error. The API answered with a 404 code"
    with pytest.raises(Exception, match=not_found_message):
        rubrix.load(name=dataset_name)
def test_text_classifier_with_inputs_list(monkeypatch):
    """A list given as ``inputs`` is loaded back under the ``"text"`` key."""
    mocking_client(monkeypatch)
    dataset = "test_text_classifier_with_inputs_list"
    client.delete(f"/api/datasets/{dataset}")

    expected_inputs = ["A", "List", "of", "values"]
    record = TextClassificationRecord(
        id=0,
        inputs=expected_inputs,
        annotation_agent="test",
        annotation=["T"],
    )
    rubrix.log(record, name=dataset)

    loaded = rubrix.load(name=dataset)
    rows = loaded.to_dict(orient="records")
    assert len(rows) == 1
    only_row = rows[0]
    assert only_row["inputs"]["text"] == expected_inputs