def main(spacy_file, collection_name, doc_limit):

    nlp = spacy.load("en_core_web_sm")
    docs_bin_obj = spacy.tokens.DocBin()
    docs_bin_obj.from_disk(spacy_file)

    doc_iter = docs_bin_obj.get_docs(vocab=nlp.vocab)
    record_list = []
    i = 0
    for doc_obj in doc_iter:
        # Collect predicted entities as (label, start_char, end_char) spans
        labelled_entities = [
            (ent.label_, ent.start_char, ent.end_char) for ent in doc_obj.ents
        ]

        record = rb.TokenClassificationRecord(
            text=doc_obj.text,
            tokens=[token.text for token in doc_obj],
            prediction=labelled_entities,
            prediction_agent="ohnlp.custom_rules.provider",
            metadata=doc_obj.user_data)
        record_list += [record]

        i += 1

        if doc_limit is not None and i >= doc_limit:
            break

    rb.log(records=record_list, name=collection_name)
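For context, here is a minimal sketch of how the .spacy file consumed above could be produced with spaCy's DocBin; the file name and sample texts are illustrative placeholders, not part of the original project.

import spacy
from spacy.tokens import DocBin

nlp = spacy.load("en_core_web_sm")
texts = ["Patient reports persistent fatigue.", "Follow-up visit in one month."]

# store_user_data=True preserves doc.user_data, which main() forwards as record metadata
doc_bin = DocBin(store_user_data=True)
for doc in nlp.pipe(texts):
    doc_bin.add(doc)
doc_bin.to_disk("notes.spacy")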
Example 2
def test_log_with_annotation(monkeypatch):
    mocking_client(monkeypatch)
    dataset_name = "test_log_with_annotation"
    client.delete(f"/api/datasets/{dataset_name}")
    rubrix.log(
        TextClassificationRecord(
            id=0,
            inputs={"text": "The text data"},
            annotation="T",
            annotation_agent="test",
        ),
        name=dataset_name,
    )

    df = rubrix.load(dataset_name)
    records = df.to_dict(orient="records")
    assert len(records) == 1
    assert records[0]["status"] == "Validated"

    rubrix.log(
        TextClassificationRecord(
            id=0,
            inputs={"text": "The text data"},
            annotation="T",
            annotation_agent="test",
            status="Discarded",
        ),
        name=dataset_name,
    )
    df = rubrix.load(dataset_name)
    records = df.to_dict(orient="records")
    assert len(records) == 1
    assert records[0]["status"] == "Discarded"
Example 3
def test_single_record(monkeypatch):
    mocking_client(monkeypatch)
    dataset_name = "test_log_single_records"
    client.delete(f"/api/datasets/{dataset_name}")
    item = TextClassificationRecord(
        inputs={"text": "This is a single record. Only this. No more."})

    rubrix.log(item, name=dataset_name)
Example 4
def test_log_records_with_too_long_text(monkeypatch):
    mocking_client(monkeypatch)
    dataset_name = "test_log_records_with_too_long_text"
    client.delete(f"/api/datasets/{dataset_name}")
    item = TextClassificationRecord(
        inputs={"text": "This is a toooooo long text\n" * 10000})

    rubrix.log([item], name=dataset_name)
Example 5
    def _log_to_rubrix(self, inputs: List[Dict[str, Any]],
                       outputs: List[Dict[str, Any]], url: str, **tags):
        records = [
            record for _inputs, _outputs in zip(inputs, outputs)
            for record in [self._records_mapper(_inputs, _outputs)] if record
        ]

        if records:
            for r in records:
                r.prediction_agent = url
            rubrix.log(records=records, name=self._dataset, tags=tags)
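A hypothetical records_mapper illustrating the contract the method above relies on: it takes one (inputs, outputs) pair and returns a record, or a falsy value to skip that pair. The field names used here are assumptions for illustration, not part of the original monitor.

def records_mapper(inputs, outputs):
    # Skip pairs for which the model produced no labels
    if not outputs.get("labels"):
        return None
    return rubrix.TextClassificationRecord(
        inputs={"text": inputs["text"]},
        prediction=list(zip(outputs["labels"], outputs["scores"])),
    )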
Example 6
def test_log_with_generator(monkeypatch):
    mocking_client(monkeypatch)
    dataset_name = "test_log_with_generator"
    client.delete(f"/api/datasets/{dataset_name}")

    def generator(items: int = 10) -> Iterable[TextClassificationRecord]:
        for i in range(0, items):
            yield TextClassificationRecord(id=i,
                                           inputs={"text": "The text data"})

    rubrix.log(generator(), name=dataset_name)
Example 7
def main(config):

    rb.init(api_url=config["rubrix_uri"], api_key=config["rubrix_api_key"])
    rb.set_workspace(config["rubrix_workspace"])

    connection_uri = config["connection_uri"]
    engine = sa.create_engine(connection_uri)
    result_list = []

    with engine.connect() as connection:

        # Long Covid Patients; notes 1 month after DX or test, Outpatient physician notes
        query = """
        select pn.*, e.service_delivery_location from sbm_covid19_documents.physician_notes pn
    join sbm_covid19_hi.PH_F_Encounter e on pn.encounter_id = e.encounter_number
    where document_id in (
        select distinct document_id from sbm_covid19_documents.physician_notes p
            join sbm_covid19_analytics_build.pui_covid_result_overview pcro
                on p.patient_id = pcro.mrn
            join sbm_covid19_hi_cdm_build.map2_condition_occurrence co on pcro.person_id = co.person_id
            and co.condition_source_concept_code = 'U09.9'
        where p.document_datetime >= pcro.positive_datetime + interval '1 month' and pcro.positive_datetime is not null)
    and document_type not in ('Ambulatory Patient Summary')
    and e.classification_display = 'Outpatient'
        """

        cursor = connection.execute(sa.text(query))
        report_list = []
        for row in cursor:

            result_list += [row["file_text"]]
            meta_data_dict = {}
            columns = ["patient_id", "mapped_document_type", "document_type", "service_delivery_location", "patient_id",
                       "encounter_id"]
            for column in columns:
                meta_data_dict[column] = row[column]

            rb_obj = rb.TextClassificationRecord(inputs={"text": row["file_text"]}, metadata=meta_data_dict,
                                                 id=row["document_id"], multi_label=True)
            report_list += [rb_obj]

        # Log the records to Rubrix in batches of 50, flushing any remainder at the end
        reports_to_commit = []
        for report in report_list:
            reports_to_commit.append(report)
            if len(reports_to_commit) == 50:
                rb.log(records=reports_to_commit, name="long_covid_with_metadata")
                reports_to_commit = []

        if reports_to_commit:
            rb.log(records=reports_to_commit, name="long_covid_with_metadata")
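A minimal sketch of the config dictionary that main() expects; the keys match the ones read inside the function, while the values are placeholders.

config = {
    "rubrix_uri": "http://localhost:6900",
    "rubrix_api_key": "<your-api-key>",
    "rubrix_workspace": "<your-workspace>",
    "connection_uri": "postgresql://user:password@host:5432/dbname",
}
main(config)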
Example 8
def test_create_ds_with_wrong_name(monkeypatch):
    mocking_client(monkeypatch)
    dataset_name = "Test Create_ds_with_wrong_name"
    client.delete(f"/api/datasets/{dataset_name}")

    with pytest.raises(
            Exception,
            match="msg='string does not match regex",
    ):
        rubrix.log(
            TextClassificationRecord(inputs={"text": "The text data"}, ),
            name=dataset_name,
        )
Example 9
def test_empty_records(mock_response_200):
    """Testing classification with empty record list

    It checks an Exception is raised, with the corresponding message.

    Parameters
    ----------
    mock_response_200
        Mocked correct http response, emulating API init
    """

    with pytest.raises(Exception,
                       match="Empty record list has been passed as argument."):
        rubrix.log(name="test", records=[])
Example 10
def test_unknow_record_type(mock_response_200):
    """Testing classification with unknown record type

    It checks an Exception is raised, with the corresponding message.

    Parameters
    ----------
    mock_response_200
        Mocked correct http response, emulating API init
    """

    with pytest.raises(Exception,
                       match="Unknown record type passed as argument."):
        rubrix.log(name="test", records=["12"])
Example 11
def test_wrong_response(mock_response_200, mock_wrong_bulk_response):
    rubrix._client = None
    with pytest.raises(
            Exception,
            match=
            "Connection error: API is not responding. The API answered with",
    ):
        rubrix.log(
            name="dataset",
            records=[
                TextClassificationRecord(inputs={"text": "The textual info"})
            ],
            tags={"env": "Test"},
        )
Example 12
def test_token_classification(mock_response_200, mock_response_token):
    """Testing token classification with log function

    It checks a Response is generated.

    Parameters
    ----------
    mock_response_200
        Mocked correct http response, emulating API init
    mock_response_token
        Mocked response given by the sync method, emulating the log of data
    """
    records = [
        TokenClassificationRecord(
            text="Super test",
            tokens=["Super", "test"],
            prediction=[("test", 6, 10)],
            annotation=[("test", 6, 10)],
            prediction_agent="spacy",
            annotation_agent="recognai",
            metadata={"model": "spacy_es_core_news_sm"},
            id=1,
        )
    ]

    assert (rubrix.log(
        name="test",
        records=records[0],
        tags={
            "type": "sentiment classifier",
            "lang": "spanish"
        },
    ) == BulkResponse(dataset="test", processed=500, failed=0))
Example 13
def test_text_classification(mock_response_200, mock_response_text):
    """Testing text classification with log function

    It checks a Response is generated.

    Parameters
    ----------
    mock_response_200
        Mocked correct http response, emulating API init
    mock_response_text
        Mocked response given by the sync method, emulating the log of data
    """
    records = [
        TextClassificationRecord(
            inputs={"review_body": "increible test"},
            prediction=[("test", 0.9), ("test2", 0.1)],
            annotation="test",
            metadata={"product_category": "test de pytest"},
            id="test",
        )
    ]

    assert (rubrix.log(
        name="test",
        records=records,
        tags={
            "type": "sentiment classifier",
            "lang": "spanish"
        },
    ) == BulkResponse(dataset="test", processed=500, failed=0))
Example 14
def test_text_classifier_with_inputs_list(monkeypatch):
    mocking_client(monkeypatch)
    dataset = "test_text_classifier_with_inputs_list"
    client.delete(f"/api/datasets/{dataset}")

    expected_inputs = ["A", "List", "of", "values"]
    rubrix.log(
        TextClassificationRecord(
            id=0,
            inputs=expected_inputs,
            annotation_agent="test",
            annotation=["T"],
        ),
        name=dataset,
    )

    df = rubrix.load(name=dataset)
    records = df.to_dict(orient="records")
    assert len(records) == 1
    assert records[0]["inputs"]["text"] == expected_inputs
Example 15
def test_delete_dataset(monkeypatch):
    mocking_client(monkeypatch)
    dataset_name = "test_delete_dataset"
    client.delete(f"/api/datasets/{dataset_name}")

    rubrix.log(
        TextClassificationRecord(
            id=0,
            inputs={"text": "The text data"},
            annotation_agent="test",
            annotation=["T"],
        ),
        name=dataset_name,
    )
    rubrix.load(name=dataset_name)
    rubrix.delete(name=dataset_name)
    sleep(1)
    with pytest.raises(
            Exception,
            match="Not found error. The API answered with a 404 code"):
        rubrix.load(name=dataset_name)
Example 16
def test_info_message(mock_response_200, mock_response_text, caplog):
    """Testing initialization info message

    Parameters
    ----------
    mock_response_200
        Mocked correct http response, emulating API init
    mock_response_text
        Mocked response given by the sync method, emulating the log of data
    caplog
        Captures the logging output
    """

    rubrix._client = None  # Force client initialization
    caplog.set_level(logging.INFO)

    records = [
        TextClassificationRecord(
            inputs={"review_body": "increible test"},
            prediction=[("test", 0.9), ("test2", 0.1)],
            annotation="test",
            metadata={"product_category": "test de pytest"},
            id="test",
        )
    ]

    rubrix.log(
        name="test",
        records=records,
        tags={
            "type": "sentiment classifier",
            "lang": "spanish"
        },
    )

    print(caplog.text)

    assert "Rubrix has been initialized on http://localhost:6900" in caplog.text
Example 17
def test_no_name(mock_response_200):
    """Testing classification with no input name

    It checks an Exception is raised, with the corresponding message.

    Parameters
    ----------
    mock_response_200
        Mocked correct http response, emulating API init
    """

    with pytest.raises(
            Exception,
            match="Empty project name has been passed as argument."):
        assert rubrix.log(name="",
                          records=cast(TextClassificationRecord, None))
Example 18
def test_log_something(monkeypatch):
    mocking_client(monkeypatch)
    dataset_name = "test-dataset"
    client.delete(f"/api/datasets/{dataset_name}")

    response = rubrix.log(
        name=dataset_name,
        records=rubrix.TextClassificationRecord(
            inputs={"text": "This is a test"}),
    )

    assert response.processed == 1
    assert response.failed == 0

    response = client.post(
        f"/api/datasets/{dataset_name}/TextClassification:search")
    results = TextClassificationSearchResults.from_dict(response.json())
    assert results.total == 1
    assert len(results.records) == 1
    assert results.records[0].inputs["text"] == "This is a test"
Example 19
def test_token_classification_spans(monkeypatch):
    mocking_client(monkeypatch)
    dataset = "test_token_classification_with_consecutive_spans"
    texto = "Esto es una prueba"
    item = rubrix.TokenClassificationRecord(
        text=texto,
        tokens=texto.split(),
        prediction=[("test", 1, 2)],  # Inicio y fin son consecutivos
        prediction_agent="test",
    )
    with pytest.raises(
            Exception,
            match=r"Defined offset \[s\] is a misaligned entity mention"):
        rubrix.log(item, name=dataset)

    item.prediction = [("test", 0, 6)]
    with pytest.raises(
            Exception,
            match=r"Defined offset \[Esto e\] is a misaligned entity mention"):
        rubrix.log(item, name=dataset)

    item.prediction = [("test", 0, 4)]
    rubrix.log(item, name=dataset)
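As the test above shows, span offsets must line up with token boundaries. A tiny sketch, reusing the same texto string, of how the aligned span accepted in the last call can be derived:

texto = "Esto es una prueba"
tokens = texto.split()
start = 0
end = len(tokens[0])  # 4, so ("test", 0, 4) covers exactly the first token "Esto"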
Example 20
def main():

    classifier = loading_model()  # Cached function, loaded once at the top of the app

    # Header
    title, _, subtitle = st.beta_columns((2.5, 0.3, 0.7))

    title.title("How to log your ML experiments and annotations with Rubrix")

    with subtitle:
        st.write("")

    subtitle.subheader("A Web App by [Recognai](https://www.recogn.ai)")

    # First text body
    st.markdown("")  # empty space
    st.markdown(
        """Hey there, welcome! This demo will show you how to keep track of your
        live model predictions using Rubrix.
        """)
    st.markdown(
        """Lets make a quick experiment: an NLP model that guesses which theme a text is talking about.
        We are using a zero-shot classifier based on
        [*SqueezeBERT*](https://huggingface.co/typeform/squeezebert-mnli)""")

    text_input = st.text_area(
        """For example: 'I love to watch cycling competitions!'""")

    confidence_threshold = (
        0.5  # Starting value of the threshold; it may be changed with the slider
    )

    if text_input:

        # Making model predictions and storing them into a dataframe
        preds = classifier(
            text_input,
            candidate_labels=CANDIDATE_LABELS,
            hypothesis_template="This text is about {}.",
            multi_class=True,
        )

        df = pd.DataFrame({
            "index": preds["labels"],
            "confidence": [s for s in preds["scores"]],
            "score": [s * 100 for s in preds["scores"]],
        }).set_index("index")

        # Confidence threshold slider: it changes the green categories in the chart and
        # the categories shown in the multiselect, so users get the classes above the
        # threshold as pre-annotations
        confidence_threshold = st.slider(
            "Select a threshold: only labels whose confidence reaches it count as predictions.",
            0.0,
            1.0,
            0.5,
            0.01,
        )

        # Predictions according to the threshold
        predictions = populating_predictions(df, confidence_threshold)

        df_table, _, bar_chart = st.beta_columns((1.2, 0.1, 2))

        # Class-Probabilities table
        with df_table:
            # Probabilities field
            st.dataframe(df[["score"]])

        # Class-Probabilities Chart with Confidence
        with bar_chart:
            bar_chart = bar_chart_generator(df, confidence_threshold)
            st.altair_chart(bar_chart, use_container_width=True)

        # Selection of the annotated labels
        selected_labels = st.multiselect(
            label=
            """With the given threshold, these are the categories predicted.
            You can change them, and your final selection will be logged as "user-validated" annotations
             (i.e., ground-truth labels).""",
            options=df.reset_index()["index"].tolist(),
            default=predictions,
        )

        st.markdown(
            """Once you are happy with the input and the categories annotated,
        press the button to log your data into Rubrix.""")

        log_button = center_button(
            "Log {} predictions with {} annotations".format(
                len(df["score"]), len(selected_labels)))

        if log_button:

            # Population of labels
            labels = []
            for _, row in df.reset_index().iterrows():
                labels.append((row["index"], row["confidence"]))

            # Creation of the classification record
            item = rubrix.TextClassificationRecord(
                inputs={"text": text_input},
                prediction=labels,
                prediction_agent="typeform/squeezebert-mnli",
                annotation=selected_labels,
                annotation_agent="streamlit-user",
                multi_label=True,
                event_timestamp=datetime.datetime.now(),
                metadata={"model": "typeform/squeezebert-mnli"})

            dataset_name = "multilabel_text_classification"

            rubrix.log(name=dataset_name, records=item)

            api_url = os.getenv("RUBRIX_API_URL", "http://localhost:6900")
            # Pretty-print of the logged item
            st.markdown(
                f"""Your data has been logged! You can view your dataset at
                 [{api_url}/{dataset_name}]({api_url}/{dataset_name});
                the object that was logged is shown right below:""")
            st.json(item.dict())

            st.markdown("""
            Logging these predictions into Rubrix can be done with a few commands in your Python scripts."""
                        )

            # By default, Rubrix will connect to http://localhost:6900 with no security.
            st.code(
                """
            import rubrix

            item = rubrix.TextClassificationRecord(
                inputs={"text": text_input},
                prediction=labels,
                prediction_agent="typeform/squeezebert-mnli",
                annotation=selected_labels,
                annotation_agent="streamlit-user",
                multi_label=True,
                event_timestamp=datetime.datetime.now(),
                metadata={"model": "typeform/squeezebert-mnli"}
            )

            rubrix.log(name="experiment_name", records=item)

            """,
                language="python",
            )
Example 21
def test_passing_wrong_iterable_data(monkeypatch):
    mocking_client(monkeypatch)
    dataset_name = "test_log_single_records"
    client.delete(f"/api/datasets/{dataset_name}")
    with pytest.raises(Exception, match="Unknown record type passed"):
        rubrix.log({"a": "010", "b": 100}, name=dataset_name)