# Code example #1
# 0
# File: t5_view.py, Project: whr778/simpletransformers
def get_states(model, session_state=None):
    """Synchronise T5 generation settings between ``model.args`` and the session state.

    If ``session_state`` is given (truthy), its generation settings are refreshed
    from ``model.args``. ``max_seq_length`` is not written in that case, so it must
    already be present on ``session_state``. Otherwise a session state is obtained
    from ``get`` seeded with the current ``model.args`` values. In both cases
    ``model.args`` is then updated from the session state.

    ``decoding_algorithm`` is kept as the sidebar label ``"Sampling"`` or
    ``"Beam Search"``. Sampling sets ``do_sample=True`` and clears ``num_beams``.
    Beam search sets ``do_sample=False`` and restores ``num_beams`` from the
    session state.

    Args:
        model: Object whose ``args`` carries the generation settings read below.
        session_state: Optional existing session-state object to refresh.

    Returns:
        ``(session_state, model)``.
    """
    args = model.args
    decoding_label = "Sampling" if args.do_sample else "Beam Search"

    if session_state:
        session_state.max_length = args.max_length
        # Store the label, not the raw bool: the comparison below matches labels.
        session_state.decoding_algorithm = decoding_label
        session_state.length_penalty = args.length_penalty
        session_state.num_beams = args.num_beams
        session_state.early_stopping = args.early_stopping
        session_state.top_k = args.top_k
        session_state.top_p = args.top_p
    else:
        session_state = get(
            max_seq_length=args.max_seq_length,
            max_length=args.max_length,
            decoding_algorithm=decoding_label,
            length_penalty=args.length_penalty,
            early_stopping=args.early_stopping,
            num_beams=args.num_beams,
            top_k=args.top_k,
            top_p=args.top_p,
        )

    args.max_seq_length = session_state.max_seq_length
    args.max_length = session_state.max_length
    args.length_penalty = session_state.length_penalty
    args.early_stopping = session_state.early_stopping
    args.top_k = session_state.top_k
    args.top_p = session_state.top_p

    decoding_algorithm = session_state.decoding_algorithm
    # A session state created elsewhere may still hold the raw ``do_sample`` bool
    # (e.g. one seeded with ``decoding_algorithm=model.args.do_sample``).
    if isinstance(decoding_algorithm, bool):
        decoding_algorithm = "Sampling" if decoding_algorithm else "Beam Search"

    if decoding_algorithm == "Sampling":
        args.do_sample = True
        args.num_beams = None
    elif decoding_algorithm == "Beam Search":
        args.do_sample = False
        args.num_beams = session_state.num_beams

    return session_state, model
# Code example #2
# 0
def ner_viewer(model):
    """Render the NER demo: entity toggles, a max-seq-length slider and predictions.

    One sidebar checkbox is created per label in ``model.args.labels_list``, and
    each label is assigned a colour via ``get_color``. Predictions are formatted
    word by word with ``format_word`` and written as HTML. Inference runs only once
    the user has entered text.

    Args:
        model: NER model whose ``args`` provides ``max_seq_length`` and ``labels_list``.
    """
    session_state = get(max_seq_length=model.args.max_seq_length, )
    model.args.max_seq_length = session_state.max_seq_length

    entity_list = model.args.labels_list

    st.sidebar.subheader("Entities")
    entity_checkboxes = {
        entity: st.sidebar.checkbox(entity, value=True)
        for entity in entity_list
    }
    entity_color_map = {
        entity: get_color(i)
        for i, entity in enumerate(entity_list)
    }

    st.sidebar.subheader("Parameters")
    model.args.max_seq_length = st.sidebar.slider(
        "Max Seq Length",
        min_value=1,
        max_value=512,
        value=model.args.max_seq_length)

    st.subheader("Enter text: ")
    input_text = st.text_area("")

    # Skip inference until there is text, as the other viewers do.
    if not input_text:
        return

    prediction = get_prediction(model, input_text)[0]

    # Each item of ``prediction`` is iterated as a mapping of word -> entity label.
    to_write = " ".join(
        format_word(word, entity, entity_checkboxes, entity_color_map)
        for pred in prediction
        for word, entity in pred.items()
    )

    st.subheader("Predictions")
    st.write(to_write, unsafe_allow_html=True)
# Code example #3
# 0
# File: qa_view.py, Project: phychaos/ConvBert-PyTorch
def qa_viewer(model):
    """Render the question-answering demo and return the (possibly reconfigured) model.

    Sequence/answer/query length limits are restored through ``get_states`` and then
    exposed as sidebar sliders. Once both a context and a question are entered, the
    top answer is highlighted inside the context. A table of all returned answers
    with their confidences is then shown.

    Args:
        model: QA model whose ``args`` provides ``max_seq_length``,
            ``max_answer_length``, ``max_query_length`` and ``n_best_size``.

    Returns:
        The same ``model`` with ``args`` updated from the sidebar widgets.
    """
    st.sidebar.subheader("Parameters")
    try:
        session_state, model = get_states(model)
    except AttributeError:
        session_state = get(
            max_seq_length=model.args.max_seq_length,
            max_answer_length=model.args.max_answer_length,
            max_query_length=model.args.max_query_length,
        )
        session_state, model = get_states(model, session_state)

    model.args.max_seq_length = st.sidebar.slider(
        "Max Seq Length", min_value=1, max_value=512, value=model.args.max_seq_length
    )

    model.args.max_answer_length = st.sidebar.slider(
        "Max Answer Length", min_value=1, max_value=512, value=model.args.max_answer_length
    )

    model.args.max_query_length = st.sidebar.slider(
        "Max Query Length", min_value=1, max_value=512, value=model.args.max_query_length
    )

    # NOTE(review): no ``value=`` is passed, so this slider does not start from
    # model.args.n_best_size — confirm that is intended.
    model.args.n_best_size = st.sidebar.slider("Number of answers to generate", min_value=1, max_value=20)

    st.subheader("Enter context: ")
    context_text = st.text_area("", key="context")

    st.subheader("Enter question: ")
    question_text = st.text_area("", key="question")

    if context_text and question_text:
        answers, probabilities = get_prediction(model, context_text, question_text)

        st.subheader("Predictions")
        answers = answers[0]["answer"]
        top_answer = answers[0]

        # "empty" marks a no-answer prediction. An empty string cannot be used as a
        # separator either, so both fall through to the empty-answer rendering.
        if top_answer and top_answer != "empty":
            # Highlight the first occurrence; everything after it is kept verbatim.
            before, _, after = context_text.partition(top_answer)
            st.write(
                QA_ANSWER_WRAPPER.format(before, top_answer, after), unsafe_allow_html=True
            )
        else:
            st.write(QA_EMPTY_ANSWER_WRAPPER.format("", top_answer, ""), unsafe_allow_html=True)

        probabilities = probabilities[0]["probability"]

        st.subheader("Confidence")
        output_df = pd.DataFrame({"Answer": answers, "Confidence": probabilities})
        st.dataframe(output_df)

    return model
# Code example #4
# 0
def get_states(model, session_state=None):
    """Synchronise QA length limits between ``model.args`` and the session state.

    With a (truthy) ``session_state``, its ``max_answer_length`` and
    ``max_query_length`` are refreshed from ``model.args``. Without one, a session
    state is obtained from ``get`` seeded with the three length settings. Afterwards
    ``model.args`` takes ``max_seq_length``, ``max_answer_length`` and
    ``max_query_length`` from the session state.

    Returns:
        ``(session_state, model)``.
    """
    args = model.args
    if not session_state:
        session_state = get(
            max_seq_length=args.max_seq_length,
            max_answer_length=args.max_answer_length,
            max_query_length=args.max_query_length,
        )
    else:
        for field in ("max_answer_length", "max_query_length"):
            setattr(session_state, field, getattr(args, field))

    for field in ("max_seq_length", "max_answer_length", "max_query_length"):
        setattr(args, field, getattr(session_state, field))

    return session_state, model
# Code example #5
# 0
def get_states(model, session_state=None):
    """Synchronise classification settings between ``model.args`` and the session state.

    With a (truthy) ``session_state``, its ``sliding_window`` and ``stride`` are
    refreshed from ``model.args``. Without one, a session state is obtained from
    ``get`` seeded with ``max_seq_length``, ``sliding_window`` and ``stride``.
    ``model.args`` is then updated from the session state.

    ``session_state.sliding_window`` may hold either a bool (as written here, and as
    passed to ``get`` by ``classification_viewer``) or the sidebar label
    ``"Enable"`` / ``"Disable"``. Both forms are accepted.

    Returns:
        ``(session_state, model)``.
    """
    if session_state:
        setattr(session_state, "sliding_window", model.args.sliding_window)
        setattr(session_state, "stride", model.args.stride)
    else:
        session_state = get(
            max_seq_length=model.args.max_seq_length,
            sliding_window=model.args.sliding_window,
            stride=model.args.stride,
        )

    sliding_window = session_state.sliding_window
    # Comparing a bool to "Enable" is always False, which silently disabled the
    # sliding window, so only compare when the stored value is the label.
    if isinstance(sliding_window, str):
        model.args.sliding_window = sliding_window == "Enable"
    else:
        model.args.sliding_window = bool(sliding_window)

    model.args.max_seq_length = session_state.max_seq_length
    model.args.stride = session_state.stride

    return session_state, model
# Code example #6
# 0
# File: t5_view.py, Project: whr778/simpletransformers
def t5_viewer(model):
    """Render the T5 text-generation demo.

    Generation settings are restored through ``get_states`` and exposed as sidebar
    widgets. The chosen values are written back onto ``model.args``. Usage
    instructions follow, and ``get_prediction`` is called with the optional prefix
    once input text is entered.

    Args:
        model: T5 model whose ``args`` carries the generation settings read below
            (``max_seq_length``, ``max_length``, ``length_penalty``,
            ``early_stopping``, ``do_sample``, ``num_beams``, ``top_k``, ``top_p``).
    """
    try:
        session_state, model = get_states(model)
    except AttributeError:
        session_state = get(
            max_seq_length=model.args.max_seq_length,
            max_length=model.args.max_length,
            # Store the label (matching the "Decoding Algorithm" radio options below)
            # rather than the raw ``do_sample`` bool.
            decoding_algorithm="Sampling" if model.args.do_sample else "Beam Search",
        )
        session_state, model = get_states(model, session_state)

    st.sidebar.subheader("Parameters")
    model.args.max_seq_length = st.sidebar.slider(
        "Max Seq Length", min_value=1, max_value=512, value=model.args.max_seq_length
    )

    st.sidebar.subheader("Decoding")

    model.args.max_length = st.sidebar.slider(
        "Max Generated Text Length",
        min_value=1,
        max_value=512,
        value=model.args.max_length,
    )

    model.args.length_penalty = st.sidebar.number_input(
        "Length Penalty", value=model.args.length_penalty
    )

    # The radio returns the selected label. Convert it to a bool, because the
    # non-empty string "False" would otherwise be truthy and early stopping could
    # never be disabled.
    early_stopping_choice = st.sidebar.radio(
        "Early Stopping", ("True", "False"), index=0 if model.args.early_stopping else 1
    )
    model.args.early_stopping = early_stopping_choice == "True"

    decoding_algorithm = st.sidebar.radio(
        "Decoding Algorithm",
        ("Sampling", "Beam Search"),
        index=0 if model.args.do_sample else 1,
    )

    if decoding_algorithm == "Sampling":
        model.args.do_sample = True
        model.args.num_beams = None
    elif decoding_algorithm == "Beam Search":
        model.args.do_sample = False
        # Keep a configured beam count; fall back to 1 only when none is set
        # (sampling clears it to None).
        model.args.num_beams = model.args.num_beams or 1

    if model.args.do_sample:
        model.args.top_k = st.sidebar.number_input(
            "Top-k", value=model.args.top_k if model.args.top_k else 50
        )

        model.args.top_p = st.sidebar.slider(
            "Top-p",
            min_value=0.0,
            max_value=1.0,
            value=model.args.top_p if model.args.top_p else 0.95,
        )
    else:
        model.args.num_beams = st.sidebar.number_input(
            "Number of Beams", value=model.args.num_beams
        )

    st.markdown("## Instructions: ")
    st.markdown("The input to a T5 model can be provided in two ways.")
    st.markdown("### Using Prefix")
    st.markdown(
        "If you provide a value for the `prefix`, Simple Viewer will automatically insert `: ` between the `prefix` text and the `input` text."
    )
    st.markdown("### Blank prefix")
    st.markdown(
        "You may also leave the `prefix` blank. In this case, you can provide a prefix and a separator at the start of the `input` text (if your model requires a prefix)."
    )

    st.subheader("Enter prefix: ")
    prefix_text = st.text_input("")

    st.subheader("Enter input text: ")
    input_text = st.text_area("")

    if input_text:
        prediction = get_prediction(model, input_text, prefix_text)[0]

        st.subheader("Generated output: ")
        st.write(prediction)
# Code example #7
# 0
def _labelled_frame(values, model_args):
    """Return ``values`` as a DataFrame with one ``"Label <name>"`` column per class.

    Column names come from ``model_args.labels_list``. If that attribute is missing
    or does not fit ``values``, the columns fall back to ``"Label 0"``,
    ``"Label 1"``, ... sized from ``values[0]``.
    """
    try:
        return pd.DataFrame(
            values,
            columns=[f"Label {label}" for label in model_args.labels_list],
        )
    except Exception:
        # Best-effort fallback so the table still renders without usable label names.
        return pd.DataFrame(
            values,
            columns=[f"Label {i}" for i in range(len(values[0]))],
        )


def classification_viewer(model, model_class):
    """Render the text-classification demo and return the (possibly reconfigured) model.

    For ``"ClassificationModel"`` the sidebar exposes max sequence length, a
    sliding-window toggle and (when enabled) the stride. For
    ``"MultiLabelClassificationModel"`` only the max sequence length is exposed.
    Once text is entered, the predicted label, the raw model outputs and their
    softmax probabilities are displayed.

    Args:
        model: Classification model whose ``args`` provides the settings read below.
        model_class: Class name of ``model``; selects which sidebar controls appear.

    Returns:
        The same ``model`` with ``args`` updated from the sidebar widgets.
    """
    st.subheader("Enter text: ")
    input_text = st.text_area("")
    st.sidebar.subheader("Parameters")

    if model_class == "ClassificationModel":
        try:
            session_state, model = get_states(model)
        except AttributeError:
            session_state = get(
                max_seq_length=model.args.max_seq_length,
                sliding_window=model.args.sliding_window,
                stride=model.args.stride,
            )
            session_state, model = get_states(model, session_state)

        model.args.max_seq_length = st.sidebar.slider(
            "Max Seq Length",
            min_value=1,
            max_value=512,
            value=model.args.max_seq_length,
        )

        sliding_window = st.sidebar.radio(
            "Sliding Window",
            ("Enable", "Disable"),
            index=0 if model.args.sliding_window else 1,
        )
        model.args.sliding_window = sliding_window == "Enable"

        if model.args.sliding_window:
            model.args.stride = st.sidebar.slider(
                "Stride (Fraction of Max Seq Length)",
                min_value=0.0,
                max_value=1.0,
                value=model.args.stride,
            )
    elif model_class == "MultiLabelClassificationModel":
        try:
            session_state, model = get_states(model)
        except AttributeError:
            session_state = get(max_seq_length=model.args.max_seq_length, )
            session_state, model = get_states(model, session_state)

        model.args.max_seq_length = st.sidebar.slider(
            "Max Seq Length",
            min_value=1,
            max_value=512,
            value=model.args.max_seq_length,
        )

    if input_text:
        prediction, raw_values = get_prediction(model, input_text)
        raw_values = [list(np.squeeze(raw_values))]

        # With a sliding window the outputs may be per-window arrays; average them.
        if model.args.sliding_window and isinstance(raw_values[0][0],
                                                    np.ndarray):
            raw_values = np.mean(raw_values, axis=1)

        st.subheader("Predictions")
        st.text(f"Predicted label: {prediction[0]}")

        st.subheader("Model outputs")
        st.text("Raw values: ")
        st.dataframe(_labelled_frame(raw_values, model.args))

        st.text("Probabilities: ")
        st.dataframe(_labelled_frame(softmax(raw_values, axis=1), model.args))

    return model