def test_log_with_annotation(monkeypatch):
    """Log one annotated record twice and check the stored status each time.

    The first log passes no ``status``, and the loaded record reads back as
    "Validated". The second log reuses ``id=0`` with ``status="Discarded"``.
    The dataset must still contain exactly one record, now "Discarded".
    """
    mocking_client(monkeypatch)
    dataset_name = "test_log_with_annotation"
    client.delete(f"/api/datasets/{dataset_name}")

    # Each pass: extra record kwargs, then the status expected after loading.
    passes = (
        ({}, "Validated"),
        ({"status": "Discarded"}, "Discarded"),
    )
    for extra_kwargs, expected_status in passes:
        rubrix.log(
            TextClassificationRecord(
                id=0,
                inputs={"text": "The text data"},
                annotation="T",
                annotation_agent="test",
                **extra_kwargs,
            ),
            name=dataset_name,
        )
        loaded = rubrix.load(dataset_name).to_dict(orient="records")
        assert len(loaded) == 1
        assert loaded[0]["status"] == expected_status
def test_text_classification(mock_response_200, mock_response_text):
    """Log a single text-classification record with ``rubrix.log``.

    The call must return the ``BulkResponse`` built from the mocked sync
    response.

    Parameters
    ----------
    mock_response_200
        Fixture mocking a successful HTTP response for API initialization.
    mock_response_text
        Fixture mocking the response of the sync call that logs the data.
    """
    record = TextClassificationRecord(
        inputs={"review_body": "increible test"},
        prediction=[("test", 0.9), ("test2", 0.1)],
        annotation="test",
        metadata={"product_category": "test de pytest"},
        id="test",
    )
    dataset_tags = {"type": "sentiment classifier", "lang": "spanish"}

    response = rubrix.log(name="test", records=[record], tags=dataset_tags)

    expected = BulkResponse(dataset="test", processed=500, failed=0)
    assert response == expected
def test_snapshots(monkeypatch):
    """Create a dataset snapshot and load data both live and from the snapshot.

    Steps:

    - Populate a fresh dataset with ``expected_data`` records.
    - Trigger a snapshot through the REST API.
    - Check every snapshot reported by ``rubrix.snapshots``.
    - Load the live dataset and check that each row converts back into a
      ``TextClassificationRecord``.
    - Load the first snapshot as a DataFrame.
    """
    mocking_client(monkeypatch)
    dataset = "test_create_dataset_snapshot"
    api_ds_prefix = f"/api/datasets/{dataset}"
    client.delete(api_ds_prefix)
    # NOTE(review): presumably gives the backend time to finish the delete.
    sleep(1)

    expected_data = 100
    create_some_data_for_text_classification(dataset, n=expected_data)

    response = client.post(f"{api_ds_prefix}/snapshots")
    assert response.status_code == 200

    snapshots = rubrix.snapshots(dataset)
    assert len(snapshots) > 0
    for snapshot in snapshots:
        assert snapshot.task == TaskType.text_classification
        assert snapshot.id
        assert snapshot.creation_date

    ds = rubrix.load(name=dataset)
    assert isinstance(ds, pandas.DataFrame)
    assert len(ds) == expected_data
    # Each loaded row must convert back into a record. The original computed
    # this list but never checked it; assert on it so the intent is explicit.
    records = [
        TextClassificationRecord(**row) for row in ds.to_dict(orient="records")
    ]
    assert len(records) == expected_data

    ds = rubrix.load(name=dataset, snapshot=snapshots[0].id)
    assert isinstance(ds, pandas.DataFrame)
def text_classification_mapper(inputs, outputs):
    """Build a ``TextClassificationRecord`` from model inputs and outputs.

    ``outputs`` is read with ``.get`` for the ``"labels"`` and
    ``"probabilities"`` keys, each defaulting to an empty list. The two are
    paired positionally into ``(label, score)`` tuples. Because ``zip`` is
    used, pairing stops at the shorter of the two sequences. The record is
    stamped with ``datetime.datetime.now()`` (naive local time).
    """
    labels = outputs.get("labels", [])
    probabilities = outputs.get("probabilities", [])
    prediction = list(zip(labels, probabilities))
    return TextClassificationRecord(
        inputs=inputs,
        prediction=prediction,
        event_timestamp=datetime.datetime.now(),
    )
def test_single_record(monkeypatch):
    """Logging a single record (not wrapped in a list) completes without error."""
    mocking_client(monkeypatch)
    dataset_name = "test_log_single_records"
    client.delete(f"/api/datasets/{dataset_name}")
    rubrix.log(
        TextClassificationRecord(
            inputs={"text": "This is a single record. Only this. No more."}
        ),
        name=dataset_name,
    )
def test_log_records_with_too_long_text(monkeypatch):
    """Logging a record whose text repeats a line 10000 times does not raise."""
    mocking_client(monkeypatch)
    dataset_name = "test_log_records_with_too_long_text"
    client.delete(f"/api/datasets/{dataset_name}")
    long_text = "This is a toooooo long text\n" * 10000
    oversized_record = TextClassificationRecord(inputs={"text": long_text})
    rubrix.log([oversized_record], name=dataset_name)
def test_create_ds_with_wrong_name(monkeypatch):
    """Logging to a dataset named with a space and capitals raises a regex error."""
    mocking_client(monkeypatch)
    dataset_name = "Test Create_ds_with_wrong_name"
    client.delete(f"/api/datasets/{dataset_name}")
    expected_message = "msg='string does not match regex"
    with pytest.raises(Exception, match=expected_message):
        # The record is built inside the context so any error it raises is
        # checked against the same expectation.
        rubrix.log(
            TextClassificationRecord(inputs={"text": "The text data"}),
            name=dataset_name,
        )
def test_wrong_response(mock_response_200, mock_wrong_bulk_response):
    """With ``mock_wrong_bulk_response`` active, ``rubrix.log`` raises a connection error."""
    # Drop any cached client so it is re-initialized for this test.
    rubrix._client = None
    expected_error = "Connection error: API is not responding. The API answered with"
    with pytest.raises(Exception, match=expected_error):
        rubrix.log(
            name="dataset",
            records=[TextClassificationRecord(inputs={"text": "The textual info"})],
            tags={"env": "Test"},
        )
def test_text_classifier_with_inputs_list(monkeypatch):
    """A list passed as ``inputs`` loads back unchanged under ``inputs["text"]``."""
    mocking_client(monkeypatch)
    dataset = "test_text_classifier_with_inputs_list"
    client.delete(f"/api/datasets/{dataset}")
    expected_inputs = ["A", "List", "of", "values"]

    record = TextClassificationRecord(
        id=0,
        inputs=expected_inputs,
        annotation_agent="test",
        annotation=["T"],
    )
    rubrix.log(record, name=dataset)

    loaded = rubrix.load(name=dataset).to_dict(orient="records")
    assert len(loaded) == 1
    assert loaded[0]["inputs"]["text"] == expected_inputs
def test_delete_dataset(monkeypatch):
    """After ``rubrix.delete``, loading the dataset fails with a 404 error."""
    mocking_client(monkeypatch)
    dataset_name = "test_delete_dataset"
    client.delete(f"/api/datasets/{dataset_name}")

    record = TextClassificationRecord(
        id=0,
        inputs={"text": "The text data"},
        annotation_agent="test",
        annotation=["T"],
    )
    rubrix.log(record, name=dataset_name)

    # The dataset must be loadable before it is deleted.
    rubrix.load(name=dataset_name)
    rubrix.delete(name=dataset_name)
    # NOTE(review): presumably lets the deletion take effect before re-loading.
    sleep(1)

    not_found = "Not found error. The API answered with a 404 code"
    with pytest.raises(Exception, match=not_found):
        rubrix.load(name=dataset_name)
def test_info_message(mock_response_200, mock_response_text, caplog):
    """Check the INFO message logged when the client is initialized.

    The cached client is reset so that ``rubrix.log`` must initialize it
    again. The captured log must then contain the initialization message
    with the API URL.

    Parameters
    ----------
    mock_response_200
        Mocked correct HTTP response, emulating API init.
    mock_response_text
        Mocked response given by the sync method, emulating the log of data.
    caplog
        Captures the logging output.
    """
    rubrix._client = None  # Force client initialization
    caplog.set_level(logging.INFO)
    records = [
        TextClassificationRecord(
            inputs={"review_body": "increible test"},
            prediction=[("test", 0.9), ("test2", 0.1)],
            annotation="test",
            metadata={"product_category": "test de pytest"},
            id="test",
        )
    ]
    rubrix.log(
        name="test",
        records=records,
        tags={"type": "sentiment classifier", "lang": "spanish"},
    )
    # The leftover debug print(caplog.text) was removed. pytest already shows
    # captured logs when this assertion fails.
    assert "Rubrix has been initialized on http://localhost:6900" in caplog.text
def test_text_classification_record_to_sdk(annotation):
    """Convert a fully populated record to its SDK model and check the timestamp.

    ``annotation`` is supplied by the caller. It is presumably a fixture or a
    parametrization defined outside this view (TODO confirm).
    """
    event_time = datetime.datetime(2000, 1, 1)
    explanation = {
        "text": [
            TokenAttributions(
                token="test",
                attributions={"label1": 1.0, "label2": 2.0},
            )
        ]
    }
    record = TextClassificationRecord(
        inputs={"text": "test"},
        prediction=[("label1", 0.5), ("label2", 0.5)],
        annotation=annotation,
        prediction_agent="test_model",
        annotation_agent="test_annotator",
        multi_label=True,
        explanation=explanation,
        id=1,
        metadata={"metadata": "test"},
        status="Default",
        event_timestamp=event_time,
    )

    sdk_record = RubrixClient._text_classification_record_to_sdk(record)
    assert sdk_record.event_timestamp == event_time
def generator(items: int = 10) -> Iterable[TextClassificationRecord]:
    """Lazily yield ``items`` records with ids ``0 .. items - 1`` and fixed text."""
    yield from (
        TextClassificationRecord(id=record_id, inputs={"text": "The text data"})
        for record_id in range(items)
    )