Exemplo n.º 1
0
def ner_viewer(model):
    """Render the Streamlit UI for a named-entity-recognition model.

    Restores ``max_seq_length`` from the shared session state and shows one
    sidebar checkbox per label in ``model.args.labels_list``. It also shows a
    max-sequence-length slider. Once the user has entered text, it writes the
    per-word predictions, each formatted by ``format_word`` and rendered with
    ``unsafe_allow_html=True``.

    Args:
        model: Model exposing ``args.max_seq_length`` and ``args.labels_list``.
            ``model.args.max_seq_length`` is updated in place from the session
            state and then from the slider.
    """
    session_state = get(max_seq_length=model.args.max_seq_length, )
    model.args.max_seq_length = session_state.max_seq_length

    entity_list = model.args.labels_list

    st.sidebar.subheader("Entities")
    # label -> whether the user wants it highlighted (all enabled by default)
    entity_checkboxes = {
        entity: st.sidebar.checkbox(entity, value=True)
        for entity in entity_list
    }
    # label -> colour, picked by the label's position in labels_list
    entity_color_map = {
        entity: get_color(i)
        for i, entity in enumerate(entity_list)
    }

    st.sidebar.subheader("Parameters")
    model.args.max_seq_length = st.sidebar.slider(
        "Max Seq Length",
        min_value=1,
        max_value=512,
        value=model.args.max_seq_length)

    st.subheader("Enter text: ")
    input_text = st.text_area("")

    # Like the other viewers, only run the model once there is text to tag.
    if not input_text:
        return

    prediction = get_prediction(model, input_text)[0]

    # `prediction` is iterated as a sequence of {word: entity} dicts.
    to_write = " ".join([
        format_word(word, entity, entity_checkboxes, entity_color_map)
        for pred in prediction for word, entity in pred.items()
    ])

    st.subheader("Predictions")
    st.write(to_write, unsafe_allow_html=True)
Exemplo n.º 2
0
def get_states(model, session_state=None):
    """Synchronise the QA length settings between the session state and the model.

    If ``session_state`` is falsy, a session state is obtained from ``get``,
    seeded with ``max_seq_length``, ``max_answer_length`` and
    ``max_query_length`` from ``model.args``. Otherwise the given state's
    ``max_answer_length`` and ``max_query_length`` are overwritten from
    ``model.args``.

    In both cases the three values are then copied back from the session
    state onto ``model.args``.

    Args:
        model: Model whose ``args`` carries the three length attributes.
        session_state: Optional existing session-state object to update.

    Returns:
        ``(session_state, model)``.

    Raises:
        AttributeError: If the session state lacks one of the attributes read.
    """
    args = model.args
    if not session_state:
        session_state = get(
            max_seq_length=args.max_seq_length,
            max_answer_length=args.max_answer_length,
            max_query_length=args.max_query_length,
        )
    else:
        for attr_name in ("max_answer_length", "max_query_length"):
            setattr(session_state, attr_name, getattr(args, attr_name))

    for attr_name in ("max_seq_length", "max_answer_length", "max_query_length"):
        setattr(args, attr_name, getattr(session_state, attr_name))

    return session_state, model
Exemplo n.º 3
0
def get_states(model, session_state=None):
    """Synchronise the classification settings between the session state and the model.

    If ``session_state`` is falsy, a session state is obtained from ``get``,
    seeded with ``max_seq_length``, ``sliding_window`` and ``stride`` from
    ``model.args``. Otherwise the given state's ``sliding_window`` and
    ``stride`` are overwritten from ``model.args``.

    The values are then copied back onto ``model.args``. ``sliding_window``
    may be stored either as the boolean taken from ``model.args`` or as the
    radio-widget string ``"Enable"``/``"Disable"``. Both forms are converted to
    a bool.

    Args:
        model: Model whose ``args`` carries ``max_seq_length``,
            ``sliding_window`` and ``stride``.
        session_state: Optional existing session-state object to update.

    Returns:
        ``(session_state, model)``.

    Raises:
        AttributeError: If the session state lacks one of the attributes read.
    """
    if session_state:
        setattr(session_state, "sliding_window", model.args.sliding_window)
        setattr(session_state, "stride", model.args.stride)
    else:
        session_state = get(
            max_seq_length=model.args.max_seq_length,
            sliding_window=model.args.sliding_window,
            stride=model.args.stride,
        )

    stored_window = session_state.sliding_window
    if isinstance(stored_window, str):
        # Radio-style value, as offered by the viewer's ("Enable", "Disable").
        model.args.sliding_window = stored_window == "Enable"
    else:
        # Both branches above store the boolean from model.args as-is.
        # Comparing that bool to "Enable" (as before) always yielded False.
        model.args.sliding_window = bool(stored_window)

    model.args.max_seq_length = session_state.max_seq_length
    model.args.stride = session_state.stride

    return session_state, model
Exemplo n.º 4
0
def _raw_value_columns(model, n_columns):
    """Return ``n_columns`` DataFrame column names for the model's outputs.

    Uses ``model.args.labels_list`` when it is iterable and has exactly one
    entry per output column. Otherwise it falls back to positional names
    ``"Label 0"`` .. ``"Label n-1"``.
    """
    labels = getattr(model.args, "labels_list", None)
    try:
        names = [f"Label {label}" for label in labels]
    except TypeError:  # labels_list is None or not iterable
        names = []
    if len(names) == n_columns:
        return names
    return [f"Label {i}" for i in range(n_columns)]


def classification_viewer(model, model_class):
    """Render the Streamlit UI for a text-classification model.

    Shows a text area and sidebar parameters. For ``"ClassificationModel"``
    these are max sequence length, a sliding-window toggle and a stride. For
    ``"MultiLabelClassificationModel"`` there is only max sequence length.
    Once text has been entered, it displays the predicted label, the raw
    model outputs and their softmax.

    Args:
        model: Model whose ``args`` are updated in place from the session
            state and the sidebar widgets.
        model_class: Name of the model class; selects which parameters are
            shown.

    Returns:
        The (possibly updated) ``model``.
    """
    st.subheader("Enter text: ")
    input_text = st.text_area("")
    st.sidebar.subheader("Parameters")

    if model_class == "ClassificationModel":
        try:
            session_state, model = get_states(model)
        except AttributeError:
            # Presumably raised when an existing session state lacks some of
            # these attributes; seed them and retry (TODO confirm).
            session_state = get(
                max_seq_length=model.args.max_seq_length,
                sliding_window=model.args.sliding_window,
                stride=model.args.stride,
            )
            session_state, model = get_states(model, session_state)

        model.args.max_seq_length = st.sidebar.slider(
            "Max Seq Length",
            min_value=1,
            max_value=512,
            value=model.args.max_seq_length)

        sliding_window = st.sidebar.radio(
            "Sliding Window", ("Enable", "Disable"),
            index=0 if model.args.sliding_window else 1)
        model.args.sliding_window = sliding_window == "Enable"

        if model.args.sliding_window:
            model.args.stride = st.sidebar.slider(
                "Stride (Fraction of Max Seq Length)",
                min_value=0.0,
                max_value=1.0,
                value=model.args.stride)
    elif model_class == "MultiLabelClassificationModel":
        try:
            session_state, model = get_states(model)
        except AttributeError:
            session_state = get(max_seq_length=model.args.max_seq_length, )
            session_state, model = get_states(model, session_state)

        model.args.max_seq_length = st.sidebar.slider(
            "Max Seq Length",
            min_value=1,
            max_value=512,
            value=model.args.max_seq_length)

    if input_text:
        prediction, raw_values = get_prediction(model, input_text)
        # Wrap as a single row so the DataFrames below have one row.
        raw_values = [list(np.squeeze(raw_values))]

        # With a sliding window each entry can itself be an array (one per
        # window); average over windows to get one row of outputs.
        if model.args.sliding_window and isinstance(raw_values[0][0],
                                                    np.ndarray):
            raw_values = np.mean(raw_values, axis=1)

        st.subheader("Predictions")
        st.text(f"Predicted label: {prediction[0]}")

        columns = _raw_value_columns(model, len(raw_values[0]))

        st.subheader("Model outputs")
        st.text("Raw values: ")
        st.dataframe(pd.DataFrame(raw_values, columns=columns))

        st.text("Probabilities: ")
        # NOTE(review): softmax assumes mutually exclusive classes. For
        # MultiLabelClassificationModel outputs this may not be the right
        # normalisation — confirm.
        st.dataframe(pd.DataFrame(softmax(raw_values, axis=1), columns=columns))

    return model
Exemplo n.º 5
0
def qa_viewer(model):
    """Render the Streamlit UI for a question-answering model.

    Shows sidebar sliders for max sequence, answer and query length and for
    the number of answers to generate. It also shows text areas for the
    context and the question. When both are filled in, it highlights the top
    answer inside the context and lists every returned answer with its
    probability.

    Args:
        model: Model whose ``args`` are updated in place from the session
            state and the sidebar widgets.

    Returns:
        The (possibly updated) ``model``.
    """
    st.sidebar.subheader("Parameters")
    try:
        session_state, model = get_states(model)
    except AttributeError:
        # Presumably raised when an existing session state lacks the QA
        # attributes; seed them and retry (TODO confirm).
        session_state = get(
            max_seq_length=model.args.max_seq_length,
            max_answer_length=model.args.max_answer_length,
            max_query_length=model.args.max_query_length,
        )
        session_state, model = get_states(model, session_state)

    model.args.max_seq_length = st.sidebar.slider(
        "Max Seq Length",
        min_value=1,
        max_value=512,
        value=model.args.max_seq_length)

    model.args.max_answer_length = st.sidebar.slider(
        "Max Answer Length",
        min_value=1,
        max_value=512,
        value=model.args.max_answer_length)

    model.args.max_query_length = st.sidebar.slider(
        "Max Query Length",
        min_value=1,
        max_value=512,
        value=model.args.max_query_length)

    model.args.n_best_size = st.sidebar.slider("Number of answers to generate",
                                               min_value=1,
                                               max_value=20)

    st.subheader("Enter context: ")
    context_text = st.text_area("", key="context")

    st.subheader("Enter question: ")
    question_text = st.text_area("", key="question")

    if context_text and question_text:
        answers, probabilities = get_prediction(model, context_text,
                                                question_text)

        st.subheader("Predictions")
        # Answer strings of the first prediction entry; index 0 is shown as
        # the top answer.
        answers = answers[0]["answer"]
        best_answer = answers[0]

        # "empty" is the no-answer marker checked here. An empty string is
        # treated the same way, because str.split/partition reject an empty
        # separator.
        if best_answer and best_answer != "empty":
            # Highlight the first occurrence of the answer. If it does not
            # occur, partition() returns (context, "", "") and the answer is
            # shown after the whole context.
            before, _, after = context_text.partition(best_answer)
            st.write(QA_ANSWER_WRAPPER.format(before, best_answer, after),
                     unsafe_allow_html=True)
        else:
            st.write(QA_EMPTY_ANSWER_WRAPPER.format("", best_answer, ""),
                     unsafe_allow_html=True)

        probabilities = probabilities[0]["probability"]

        st.subheader("Confidence")
        output_df = pd.DataFrame({
            "Answer": answers,
            "Confidence": probabilities
        })
        st.dataframe(output_df)

    return model