def main(spacy_file, collection_name, doc_limit):
    """Load spaCy docs from a serialized DocBin and log their entities to Rubrix.

    Every document becomes one ``rb.TokenClassificationRecord`` whose
    prediction is the list of ``(label, start_char, end_char)`` tuples taken
    from ``doc.ents``. All records are sent in a single ``rb.log`` call.

    Parameters
    ----------
    spacy_file
        Path handed to ``spacy.tokens.DocBin.from_disk``.
    collection_name
        Name of the Rubrix dataset the records are logged into.
    doc_limit
        Maximum number of documents to log, or ``None`` to log all of them.
    """
    nlp = spacy.load("en_core_web_sm")
    doc_bin = spacy.tokens.DocBin()
    doc_bin.from_disk(spacy_file)

    records = []
    for count, doc in enumerate(doc_bin.get_docs(vocab=nlp.vocab), start=1):
        spans = [(ent.label_, ent.start_char, ent.end_char) for ent in doc.ents]
        records.append(
            rb.TokenClassificationRecord(
                text=doc.text,
                tokens=[token.text for token in doc],
                prediction=spans,
                prediction_agent="ohnlp.custom_rules.provider",
                metadata=doc.user_data,
            )
        )
        # Stop once `doc_limit` documents have been collected.
        if doc_limit is not None and count >= doc_limit:
            break

    rb.log(records=records, name=collection_name)
def test_log_with_annotation(monkeypatch):
    """Annotated records are stored as "Validated" unless a status is given.

    The same record id is logged twice. The dataset must still hold exactly
    one record, and its status must match the most recent log call.
    """
    mocking_client(monkeypatch)
    dataset_name = "test_log_with_annotation"
    client.delete(f"/api/datasets/{dataset_name}")

    rounds = (
        ({}, "Validated"),
        ({"status": "Discarded"}, "Discarded"),
    )
    for extra_fields, expected_status in rounds:
        rubrix.log(
            TextClassificationRecord(
                id=0,
                inputs={"text": "The text data"},
                annotation="T",
                annotation_agent="test",
                **extra_fields,
            ),
            name=dataset_name,
        )
        stored = rubrix.load(dataset_name).to_dict(orient="records")
        assert len(stored) == 1
        assert stored[0]["status"] == expected_status
def test_single_record(monkeypatch):
    """A bare record, not wrapped in a list, can be passed to ``rubrix.log``."""
    mocking_client(monkeypatch)
    dataset_name = "test_log_single_records"
    client.delete(f"/api/datasets/{dataset_name}")
    rubrix.log(
        TextClassificationRecord(
            inputs={"text": "This is a single record. Only this. No more."}
        ),
        name=dataset_name,
    )
def test_log_records_with_too_long_text(monkeypatch):
    """Logging a record whose text input is very long must not fail."""
    mocking_client(monkeypatch)
    dataset_name = "test_log_records_with_too_long_text"
    client.delete(f"/api/datasets/{dataset_name}")
    long_text = "This is a toooooo long text\n" * 10000
    record = TextClassificationRecord(inputs={"text": long_text})
    rubrix.log([record], name=dataset_name)
def _log_to_rubrix(
    self,
    inputs: List[Dict[str, Any]],
    outputs: List[Dict[str, Any]],
    url: str,
    **tags,
):
    """Map paired inputs/outputs to Rubrix records and log the non-empty ones.

    ``self._records_mapper`` is called once per ``(inputs, outputs)`` pair.
    Pairing follows ``zip``, so extra items in the longer list are ignored.
    Falsy mapper results are dropped. Each kept record gets
    ``prediction_agent`` set to ``url`` and is logged into ``self._dataset``
    with the given ``tags``. Nothing is logged when no record survives.
    """
    records = []
    for single_inputs, single_outputs in zip(inputs, outputs):
        mapped = self._records_mapper(single_inputs, single_outputs)
        if mapped:
            records.append(mapped)

    if not records:
        return

    for record in records:
        record.prediction_agent = url
    rubrix.log(records=records, name=self._dataset, tags=tags)
def test_log_with_generator(monkeypatch):
    """``rubrix.log`` accepts a generator of records, not only a list."""
    mocking_client(monkeypatch)
    dataset_name = "test_log_with_generator"
    client.delete(f"/api/datasets/{dataset_name}")

    def generator(items: int = 10) -> Iterable[TextClassificationRecord]:
        yield from (
            TextClassificationRecord(id=record_id, inputs={"text": "The text data"})
            for record_id in range(items)
        )

    rubrix.log(generator(), name=dataset_name)
def main(config, batch_size=50):
    """Query long-COVID outpatient physician notes and log them to Rubrix.

    Each returned row becomes a multi-label ``rb.TextClassificationRecord``
    whose text is ``file_text``, whose id is ``document_id``, and whose
    metadata copies a fixed set of row columns. Records are logged into the
    ``long_covid_with_metadata`` dataset in batches.

    Parameters
    ----------
    config
        Mapping providing ``rubrix_uri``, ``rubrix_api_key``,
        ``rubrix_workspace`` and ``connection_uri`` (SQLAlchemy URL).
    batch_size
        Number of records sent per ``rb.log`` call. Must be positive;
        ``range`` raises ``ValueError`` for 0. Defaults to 50.
    """
    rb.init(api_url=config["rubrix_uri"], api_key=config["rubrix_api_key"])
    rb.set_workspace(config["rubrix_workspace"])
    engine = sa.create_engine(config["connection_uri"])

    # Long Covid Patients; notes 1 month after DX or test, Outpatient physician notes
    query = """
        select pn.*, e.service_delivery_location
        from sbm_covid19_documents.physician_notes pn
        join sbm_covid19_hi.PH_F_Encounter e
            on pn.encounter_id = e.encounter_number
        where document_id in (
            select distinct document_id
            from sbm_covid19_documents.physician_notes p
            join sbm_covid19_analytics_build.pui_covid_result_overview pcro
                on p.patient_id = pcro.mrn
            join sbm_covid19_hi_cdm_build.map2_condition_occurrence co
                on pcro.person_id = co.person_id
                and co.condition_source_concept_code = 'U09.9'
            where p.document_datetime >= pcro.positive_datetime + interval '1 month'
                and pcro.positive_datetime is not null)
        and document_type not in ('Ambulatory Patient Summary')
        and e.classification_display = 'Outpatient'
    """

    # Row columns copied verbatim into each record's metadata. The original
    # list contained "patient_id" twice; a dict key can only appear once.
    metadata_columns = (
        "patient_id",
        "mapped_document_type",
        "document_type",
        "service_delivery_location",
        "encounter_id",
    )

    report_list = []
    with engine.connect() as connection:
        for row in connection.execute(query):
            report_list.append(
                rb.TextClassificationRecord(
                    inputs={"text": row["file_text"]},
                    metadata={column: row[column] for column in metadata_columns},
                    id=row["document_id"],
                    multi_label=True,
                )
            )

    # Uniform batches. The old `i > 0 and i % 50 == 0` check made the first
    # batch 51 records long.
    for start in range(0, len(report_list), batch_size):
        rb.log(
            records=report_list[start:start + batch_size],
            name="long_covid_with_metadata",
        )
def test_create_ds_with_wrong_name(monkeypatch):
    """A dataset name the server's name regex rejects makes ``rubrix.log`` fail."""
    mocking_client(monkeypatch)
    dataset_name = "Test Create_ds_with_wrong_name"
    client.delete(f"/api/datasets/{dataset_name}")
    expected_error = "msg='string does not match regex"
    with pytest.raises(Exception, match=expected_error):
        rubrix.log(
            TextClassificationRecord(inputs={"text": "The text data"}),
            name=dataset_name,
        )
def test_empty_records(mock_response_200):
    """Logging an empty record list raises an explanatory exception.

    Parameters
    ----------
    mock_response_200
        Mocked correct http response, emulating API init
    """
    expected_error = "Empty record list has been passed as argument."
    with pytest.raises(Exception, match=expected_error):
        rubrix.log(name="test", records=[])
def test_unknow_record_type(mock_response_200):
    """Logging something that is not a record raises an explanatory exception.

    Parameters
    ----------
    mock_response_200
        Mocked correct http response, emulating API init
    """
    expected_error = "Unknown record type passed as argument."
    with pytest.raises(Exception, match=expected_error):
        rubrix.log(name="test", records=["12"])
def test_wrong_response(mock_response_200, mock_wrong_bulk_response):
    """A failing bulk endpoint surfaces as a connection error from ``rubrix.log``."""
    # Reset the cached client so it is initialized again against the mocks.
    rubrix._client = None
    expected_error = "Connection error: API is not responding. The API answered with"
    with pytest.raises(Exception, match=expected_error):
        rubrix.log(
            name="dataset",
            records=[TextClassificationRecord(inputs={"text": "The textual info"})],
            tags={"env": "Test"},
        )
def test_token_classification(mock_response_200, mock_response_token):
    """Logging a single token-classification record returns the mocked bulk response.

    Parameters
    ----------
    mock_response_200
        Mocked correct http response, emulating API init
    mock_response_token
        Mocked response given by the sync method, emulating the log of data
    """
    record = TokenClassificationRecord(
        text="Super test",
        tokens=["Super", "test"],
        prediction=[("test", 6, 10)],
        annotation=[("test", 6, 10)],
        prediction_agent="spacy",
        annotation_agent="recognai",
        metadata={"model": "spacy_es_core_news_sm"},
        id=1,
    )
    tags = {"type": "sentiment classifier", "lang": "spanish"}
    response = rubrix.log(name="test", records=record, tags=tags)
    assert response == BulkResponse(dataset="test", processed=500, failed=0)
def test_text_classification(mock_response_200, mock_response_text):
    """Logging a list of text-classification records returns the mocked bulk response.

    Parameters
    ----------
    mock_response_200
        Mocked correct http response, emulating API init
    mock_response_text
        Mocked response given by the sync method, emulating the log of data
    """
    record = TextClassificationRecord(
        inputs={"review_body": "increible test"},
        prediction=[("test", 0.9), ("test2", 0.1)],
        annotation="test",
        metadata={"product_category": "test de pytest"},
        id="test",
    )
    tags = {"type": "sentiment classifier", "lang": "spanish"}
    response = rubrix.log(name="test", records=[record], tags=tags)
    assert response == BulkResponse(dataset="test", processed=500, failed=0)
def test_text_classifier_with_inputs_list(monkeypatch):
    """A list passed as ``inputs`` is loaded back under the ``"text"`` key."""
    mocking_client(monkeypatch)
    dataset = "test_text_classifier_with_inputs_list"
    client.delete(f"/api/datasets/{dataset}")

    values = ["A", "List", "of", "values"]
    rubrix.log(
        TextClassificationRecord(
            id=0,
            inputs=values,
            annotation_agent="test",
            annotation=["T"],
        ),
        name=dataset,
    )

    loaded = rubrix.load(name=dataset).to_dict(orient="records")
    assert len(loaded) == 1
    assert loaded[0]["inputs"]["text"] == values
def test_delete_dataset(monkeypatch):
    """After ``rubrix.delete`` the dataset can no longer be loaded (404)."""
    mocking_client(monkeypatch)
    dataset_name = "test_delete_dataset"
    client.delete(f"/api/datasets/{dataset_name}")

    record = TextClassificationRecord(
        id=0,
        inputs={"text": "The text data"},
        annotation_agent="test",
        annotation=["T"],
    )
    rubrix.log(record, name=dataset_name)

    # The dataset must be loadable before it is deleted.
    rubrix.load(name=dataset_name)
    rubrix.delete(name=dataset_name)
    # NOTE(review): presumably gives the backend time to apply the deletion.
    sleep(1)

    not_found = "Not found error. The API answered with a 404 code"
    with pytest.raises(Exception, match=not_found):
        rubrix.load(name=dataset_name)
def test_info_message(mock_response_200, mock_response_text, caplog):
    """Testing initialization info message

    Logging with no cached client must initialize one and emit the
    "Rubrix has been initialized on ..." INFO message.

    Parameters
    ----------
    mock_response_200
        Mocked correct http response, emulating API init
    mock_response_text
        Mocked response given by the sync method, emulating the log of data
    caplog
        Captures the logging output
    """
    rubrix._client = None  # Force client initialization
    caplog.set_level(logging.INFO)
    records = [
        TextClassificationRecord(
            inputs={"review_body": "increible test"},
            prediction=[("test", 0.9), ("test2", 0.1)],
            annotation="test",
            metadata={"product_category": "test de pytest"},
            id="test",
        )
    ]
    rubrix.log(
        name="test",
        records=records,
        tags={"type": "sentiment classifier", "lang": "spanish"},
    )
    # The leftover debug `print(caplog.text)` was removed. pytest already
    # shows captured logs when the assertion fails.
    assert "Rubrix has been initialized on http://localhost:6900" in caplog.text
def test_no_name(mock_response_200):
    """Testing classification with no input name

    It checks an Exception is raised, with the corresponding message.

    Parameters
    ----------
    mock_response_200
        Mocked correct http response, emulating API init
    """
    with pytest.raises(
        Exception, match="Empty project name has been passed as argument."
    ):
        # The call itself is expected to raise. The former `assert` on its
        # return value could never be evaluated meaningfully.
        rubrix.log(name="", records=cast(TextClassificationRecord, None))
def test_log_something(monkeypatch):
    """A logged record is counted in the bulk response and found by the search API."""
    mocking_client(monkeypatch)
    dataset_name = "test-dataset"
    client.delete(f"/api/datasets/{dataset_name}")

    log_response = rubrix.log(
        name=dataset_name,
        records=rubrix.TextClassificationRecord(inputs={"text": "This is a test"}),
    )
    assert log_response.processed == 1
    assert log_response.failed == 0

    search_url = f"/api/datasets/{dataset_name}/TextClassification:search"
    search_results = TextClassificationSearchResults.from_dict(
        client.post(search_url).json()
    )
    assert search_results.total == 1
    assert len(search_results.records) == 1
    assert search_results.records[0].inputs["text"] == "This is a test"
def test_token_classification_spans(monkeypatch):
    """Entity spans that do not align with token boundaries are rejected on log."""
    mocking_client(monkeypatch)
    dataset = "test_token_classification_with_consecutive_spans"
    sentence = "Esto es una prueba"
    record = rubrix.TokenClassificationRecord(
        text=sentence,
        tokens=sentence.split(),
        prediction=[("test", 1, 2)],  # start and end offsets are consecutive
        prediction_agent="test",
    )

    # A one-character span inside the first token.
    with pytest.raises(
        Exception, match=r"Defined offset \[s\] is a misaligned entity mention"
    ):
        rubrix.log(record, name=dataset)

    # A span that ends in the middle of the second token.
    record.prediction = [("test", 0, 6)]
    with pytest.raises(
        Exception, match=r"Defined offset \[Esto e\] is a misaligned entity mention"
    ):
        rubrix.log(record, name=dataset)

    # A span covering exactly the first token is accepted.
    record.prediction = [("test", 0, 4)]
    rubrix.log(record, name=dataset)
def main():
    """Render the Streamlit demo that classifies a text and logs it to Rubrix.

    The user enters a text, which the zero-shot ``classifier`` scores
    against ``CANDIDATE_LABELS``. A slider sets the confidence threshold
    that pre-selects labels (via ``populating_predictions``). A button logs
    the predictions plus the user-selected labels as a multi-label
    ``rubrix.TextClassificationRecord`` in the
    ``multilabel_text_classification`` dataset.
    """
    classifier = loading_model()  # Cached function, loading on top of the app

    # Header
    title, _, subtitle = st.beta_columns((2.5, 0.3, 0.7))
    title.title("How to log your ML experiments and annotations with Rubrix")
    with subtitle:
        st.write("")
    subtitle.subheader("A Web App by [Recognai](https://www.recogn.ai)")

    # First text body
    st.markdown("")  # empty space
    st.markdown(
        """Hey there, welcome! This demo will show you how to keep track of your live model predictions using Rubrix. """
    )
    st.markdown(
        """Lets make a quick experiment: an NLP model that guesses which theme a text is talking about. We are using a zero-shot classifier based on [*SqueezeBERT*](https://huggingface.co/typeform/squeezebert-mnli)"""
    )
    text_input = st.text_area(
        """For example: 'I love to watch cycling competitions!'"""
    )

    # The former `confidence_threshold = 0.5` before this branch was a dead
    # store: the value is only read after the slider below reassigns it.
    if text_input:
        # Making model predictions and storing them into a dataframe
        preds = classifier(
            text_input,
            candidate_labels=CANDIDATE_LABELS,
            hypothesis_template="This text is about {}.",
            multi_class=True,
        )
        df = pd.DataFrame(
            {
                "index": preds["labels"],
                "confidence": list(preds["scores"]),
                "score": [score * 100 for score in preds["scores"]],
            }
        ).set_index("index")

        # Confidence threshold slider. It changes the highlighted categories
        # in the chart and the categories pre-selected in the multiselect, so
        # the user starts from the classes above the threshold as
        # pre-annotations.
        confidence_threshold = st.slider(
            "We can select a threshold to decide which confidence must be obtained to consider it a prediction.",
            0.0,
            1.0,
            0.5,
            0.01,
        )

        # Predictions according to the threshold
        predictions = populating_predictions(df, confidence_threshold)

        df_table, _, chart_column = st.beta_columns((1.2, 0.1, 2))

        # Class-probabilities table
        with df_table:
            st.dataframe(df[["score"]])

        # Class-probabilities chart with the confidence threshold. The chart
        # gets its own name instead of shadowing the column it is drawn in.
        with chart_column:
            chart = bar_chart_generator(df, confidence_threshold)
            st.altair_chart(chart, use_container_width=True)

        # Selection of the annotated labels
        selected_labels = st.multiselect(
            label="""With the given threshold, these are the categories predicted. You can change them, and your final selection will be logged as "user-validated" annotations (i.e., ground-truth labels).""",
            options=df.reset_index()["index"].tolist(),
            default=predictions,
        )
        st.markdown(
            """Once you are happy with the input and the categories annotated, press the button to log your data into Rubrix."""
        )
        log_button = center_button(
            "Log {} predictions with {} annotations".format(
                len(df["score"]), len(selected_labels)
            )
        )

        if log_button:
            # One (label, confidence) pair per candidate label
            labels = [
                (row["index"], row["confidence"])
                for _, row in df.reset_index().iterrows()
            ]

            # Creation of the classification record
            item = rubrix.TextClassificationRecord(
                inputs={"text": text_input},
                prediction=labels,
                prediction_agent="typeform/squeezebert-mnli",
                annotation=selected_labels,
                annotation_agent="streamlit-user",
                multi_label=True,
                event_timestamp=datetime.datetime.now(),
                metadata={"model": "typeform/squeezebert-mnli"},
            )

            dataset_name = "multilabel_text_classification"
            rubrix.log(name=dataset_name, records=item)

            api_url = os.getenv("RUBRIX_API_URL", "http://localhost:6900")
            # Pretty-print of the logged item
            st.markdown(
                f"""Your data has been logged! You can view your dataset in [{api_url}/{dataset_name}]({api_url}/{dataset_name}), which has logged this object right below:"""
            )
            st.json(item.dict())

            # NOTE(review): nesting of this snippet was reconstructed from a
            # whitespace-collapsed source; confirm it belongs after logging.
            st.markdown(
                """ Logging this predictions into Rubrix can be done with a few commands in your Python scripts."""
            )
            # By default, Rubrix will connect to http://localhost:6900 with no security.
            st.code(
                """
import rubrix

item = rubrix.TextClassificationRecord(
    inputs={"text": text_input},
    prediction=labels,
    prediction_agent="typeform/squeezebert-mnli",
    annotation=selected_labels,
    annotation_agent="streamlit-user",
    multi_label=True,
    event_timestamp=datetime.datetime.now(),
    metadata={"model": "typeform/squeezebert-mnli"}
)

rubrix.log(name="experiment_name", records=item)
""",
                language="python",
            )
def test_passing_wrong_iterable_data(monkeypatch):
    """Passing a dict instead of records raises "Unknown record type passed"."""
    mocking_client(monkeypatch)
    # Use a name unique to this test. The previous copy-pasted
    # "test_log_single_records" made this test delete the dataset that
    # test_single_record logs into.
    dataset_name = "test_passing_wrong_iterable_data"
    client.delete(f"/api/datasets/{dataset_name}")
    with pytest.raises(Exception, match="Unknown record type passed"):
        rubrix.log({"a": "010", "b": 100}, name=dataset_name)