def get_states(model, session_state=None):
    """Synchronize seq2seq decoding settings between the Streamlit session state and ``model.args``.

    Args:
        model: Model whose ``args`` carries the decoding options
            (``max_length``, ``do_sample``, ``num_beams``, ...).
        session_state: Existing session state to refresh from ``model.args``.
            When falsy, a fresh state is created via ``get(...)``.

    Returns:
        Tuple ``(session_state, model)``, both updated.
    """
    if session_state:
        setattr(session_state, "max_length", model.args.max_length)
        # BUG FIX: the original stored the raw ``do_sample`` bool here, while the
        # else-branch (and the comparisons below) use the "Sampling"/"Beam Search"
        # strings — so on this path neither comparison could match and
        # do_sample/num_beams were silently left stale. Store the string form.
        setattr(
            session_state,
            "decoding_algorithm",
            "Sampling" if model.args.do_sample else "Beam Search",
        )
        setattr(session_state, "length_penalty", model.args.length_penalty)
        setattr(session_state, "num_beams", model.args.num_beams)
        setattr(session_state, "early_stopping", model.args.early_stopping)
        setattr(session_state, "top_k", model.args.top_k)
        setattr(session_state, "top_p", model.args.top_p)
    else:
        # First call: seed a brand-new session state from the model defaults.
        session_state = get(
            max_seq_length=model.args.max_seq_length,
            max_length=model.args.max_length,
            decoding_algorithm="Sampling" if model.args.do_sample else "Beam Search",
            length_penalty=model.args.length_penalty,
            early_stopping=model.args.early_stopping,
            num_beams=model.args.num_beams,
            top_k=model.args.top_k,
            top_p=model.args.top_p,
        )
    # Push the (possibly user-adjusted) session values back onto the model.
    model.args.max_seq_length = session_state.max_seq_length
    model.args.max_length = session_state.max_length
    model.args.length_penalty = session_state.length_penalty
    model.args.early_stopping = session_state.early_stopping
    model.args.top_k = session_state.top_k
    model.args.top_p = session_state.top_p
    if session_state.decoding_algorithm == "Sampling":
        model.args.do_sample = True
        # Beam search is disabled while sampling.
        model.args.num_beams = None
    elif session_state.decoding_algorithm == "Beam Search":
        model.args.do_sample = False
        model.args.num_beams = session_state.num_beams
    return session_state, model
def ner_viewer(model):
    """Streamlit page for NER models: per-entity toggles, a length slider, and highlighted predictions."""
    # Persist max_seq_length across Streamlit reruns via the session state.
    session_state = get(max_seq_length=model.args.max_seq_length, )
    model.args.max_seq_length = session_state.max_seq_length

    labels = model.args.labels_list

    # Sidebar: one checkbox per entity type, plus a stable color per entity.
    st.sidebar.subheader("Entities")
    checkbox_states = {label: st.sidebar.checkbox(label, value=True) for label in labels}
    color_lookup = {}
    for position, label in enumerate(labels):
        color_lookup[label] = get_color(position)

    st.sidebar.subheader("Parameters")
    model.args.max_seq_length = st.sidebar.slider(
        "Max Seq Length", min_value=1, max_value=512, value=model.args.max_seq_length
    )

    st.subheader("Enter text: ")
    text = st.text_area("")

    # Predictions come back as a list of {word: entity} mappings; render each
    # word through format_word so enabled entities get their highlight color.
    tagged_words = get_prediction(model, text)[0]
    rendered = []
    for mapping in tagged_words:
        for word, entity in mapping.items():
            rendered.append(format_word(word, entity, checkbox_states, color_lookup))

    st.subheader(f"Predictions")
    st.write(" ".join(rendered), unsafe_allow_html=True)
def qa_viewer(model):
    """Streamlit page for question answering: length sliders, context/question inputs, highlighted answer."""
    st.sidebar.subheader("Parameters")
    try:
        session_state, model = get_states(model)
    except AttributeError:
        # First run: the session state has no QA attributes yet, so seed it.
        session_state = get(
            max_seq_length=model.args.max_seq_length,
            max_answer_length=model.args.max_answer_length,
            max_query_length=model.args.max_query_length,
        )
        session_state, model = get_states(model, session_state)

    model.args.max_seq_length = st.sidebar.slider(
        "Max Seq Length", min_value=1, max_value=512, value=model.args.max_seq_length
    )
    model.args.max_answer_length = st.sidebar.slider(
        "Max Answer Length", min_value=1, max_value=512, value=model.args.max_answer_length
    )
    model.args.max_query_length = st.sidebar.slider(
        "Max Query Length", min_value=1, max_value=512, value=model.args.max_query_length
    )
    model.args.n_best_size = st.sidebar.slider(
        "Number of answers to generate", min_value=1, max_value=20
    )

    st.subheader("Enter context: ")
    context_text = st.text_area("", key="context")
    st.subheader("Enter question: ")
    question_text = st.text_area("", key="question")

    if context_text and question_text:
        raw_answers, raw_probabilities = get_prediction(model, context_text, question_text)
        st.subheader(f"Predictions")

        answers = raw_answers[0]["answer"]
        best_answer = answers[0]
        context_pieces = context_text.split(best_answer)
        if best_answer == "empty":
            # Model predicted "no answer"; nothing in the context to highlight.
            st.write(QA_EMPTY_ANSWER_WRAPPER.format("", best_answer, ""), unsafe_allow_html=True)
        elif len(context_pieces) == 2:
            # The answer occurs exactly once in the context.
            st.write(
                QA_ANSWER_WRAPPER.format(context_pieces[0], best_answer, context_pieces[-1]),
                unsafe_allow_html=True,
            )
        else:
            # The answer occurs multiple times; highlight the first occurrence
            # and stitch the remaining pieces back together.
            st.write(
                QA_ANSWER_WRAPPER.format(
                    context_pieces[0], best_answer, best_answer.join(context_pieces[1:])
                ),
                unsafe_allow_html=True,
            )

        st.subheader("Confidence")
        output_df = pd.DataFrame(
            {"Answer": answers, "Confidence": raw_probabilities[0]["probability"]}
        )
        st.dataframe(output_df)
    return model
def get_states(model, session_state=None):
    """Keep QA sequence-length settings in sync between the session state and ``model.args``.

    Refreshes an existing ``session_state`` from ``model.args`` (or seeds a new
    one via ``get``), then writes the session values back onto the model.
    Returns the ``(session_state, model)`` pair.
    """
    if session_state:
        # Refresh the stored lengths from the current model configuration.
        session_state.max_answer_length = model.args.max_answer_length
        session_state.max_query_length = model.args.max_query_length
    else:
        # First call: seed a new session state from the model defaults.
        session_state = get(
            max_seq_length=model.args.max_seq_length,
            max_answer_length=model.args.max_answer_length,
            max_query_length=model.args.max_query_length,
        )
    # Push the (possibly user-adjusted) session values back onto the model.
    for attribute in ("max_seq_length", "max_answer_length", "max_query_length"):
        setattr(model.args, attribute, getattr(session_state, attribute))
    return session_state, model
def get_states(model, session_state=None):
    """Synchronize classification sliding-window settings between the session state and ``model.args``.

    Args:
        model: Model whose ``args`` carries ``max_seq_length``, ``sliding_window``
            and ``stride``.
        session_state: Existing session state to refresh from ``model.args``.
            When falsy, a fresh state is created via ``get(...)``.

    Returns:
        Tuple ``(session_state, model)``, both updated.
    """
    if session_state:
        setattr(session_state, "sliding_window", model.args.sliding_window)
        setattr(session_state, "stride", model.args.stride)
    else:
        session_state = get(
            max_seq_length=model.args.max_seq_length,
            sliding_window=model.args.sliding_window,
            stride=model.args.stride,
        )
    # BUG FIX: everywhere this function assigns session_state.sliding_window it
    # stores a bool, yet the original compared it to the radio label "Enable" —
    # that comparison was always False, so sliding_window was force-reset to
    # False here on every call. Accept both the bool and the label form.
    model.args.sliding_window = session_state.sliding_window in (True, "Enable")
    model.args.max_seq_length = session_state.max_seq_length
    model.args.stride = session_state.stride
    return session_state, model
def t5_viewer(model):
    """Streamlit page for T5 models: decoding controls in the sidebar, prefix/input boxes, generated text.

    Returns nothing; renders UI and writes user-selected decoding options back
    onto ``model.args`` as a side effect.
    """
    try:
        session_state, model = get_states(model)
    except AttributeError:
        # First run: the session state has no decoding attributes yet, so seed it.
        session_state = get(
            max_seq_length=model.args.max_seq_length,
            max_length=model.args.max_length,
            decoding_algorithm=model.args.do_sample,
        )
        session_state, model = get_states(model, session_state)

    st.sidebar.subheader("Parameters")
    model.args.max_seq_length = st.sidebar.slider(
        "Max Seq Length", min_value=1, max_value=512, value=model.args.max_seq_length
    )

    st.sidebar.subheader("Decoding")
    model.args.max_length = st.sidebar.slider(
        "Max Generated Text Length",
        min_value=1,
        max_value=512,
        value=model.args.max_length,
    )
    model.args.length_penalty = st.sidebar.number_input(
        "Length Penalty", value=model.args.length_penalty
    )
    early_stopping_choice = st.sidebar.radio(
        "Early Stopping", ("True", "False"), index=0 if model.args.early_stopping else 1
    )
    # BUG FIX: st.sidebar.radio returns the selected *string* ("True"/"False"),
    # and the string "False" is truthy — so early stopping could never actually
    # be disabled. Convert the choice to a real bool before storing it.
    model.args.early_stopping = early_stopping_choice == "True"
    decoding_algorithm = st.sidebar.radio(
        "Decoding Algorithm",
        ("Sampling", "Beam Search"),
        index=0 if model.args.do_sample else 1,
    )
    if decoding_algorithm == "Sampling":
        model.args.do_sample = True
        # Beam search is disabled while sampling.
        model.args.num_beams = None
    elif decoding_algorithm == "Beam Search":
        model.args.do_sample = False
        model.args.num_beams = 1

    if model.args.do_sample:
        # Sampling controls; fall back to common defaults when unset.
        model.args.top_k = st.sidebar.number_input(
            "Top-k", value=model.args.top_k if model.args.top_k else 50
        )
        model.args.top_p = st.sidebar.slider(
            "Top-p",
            min_value=0.0,
            max_value=1.0,
            value=model.args.top_p if model.args.top_p else 0.95,
        )
    else:
        model.args.num_beams = st.sidebar.number_input(
            "Number of Beams", value=model.args.num_beams
        )

    st.markdown("## Instructions: ")
    st.markdown("The input to a T5 model can be providied in two ways.")
    st.markdown("### Using Prefix")
    st.markdown(
        "If you provide a value for the `prefix`, Simple Viewer will automatically insert `: ` between the `prefix` text and the `input` text."
    )
    st.markdown("### Blank prefix")
    st.markdown(
        "You may also leave the `prefix` blank. In this case, you can provide a prefix and a separator at the start of the `input` text (if your model requires a prefix)."
    )

    st.subheader("Enter prefix: ")
    prefix_text = st.text_input("")
    st.subheader("Enter input text: ")
    input_text = st.text_area("")

    if input_text:
        prediction = get_prediction(model, input_text, prefix_text)[0]
        st.subheader(f"Generated output: ")
        st.write(prediction)
def classification_viewer(model, model_class):
    """Streamlit page for (multi-label) classification: text box, parameter sidebar, predictions and raw outputs.

    Returns the (possibly reconfigured) ``model``.
    """
    st.subheader("Enter text: ")
    input_text = st.text_area("")

    st.sidebar.subheader("Parameters")
    if model_class == "ClassificationModel":
        try:
            session_state, model = get_states(model)
        except AttributeError:
            # First run: seed the session state before syncing.
            session_state = get(
                max_seq_length=model.args.max_seq_length,
                sliding_window=model.args.sliding_window,
                stride=model.args.stride,
            )
            session_state, model = get_states(model, session_state)

        model.args.max_seq_length = st.sidebar.slider(
            "Max Seq Length",
            min_value=1,
            max_value=512,
            value=model.args.max_seq_length,
        )
        window_choice = st.sidebar.radio(
            "Sliding Window",
            ("Enable", "Disable"),
            index=0 if model.args.sliding_window else 1,
        )
        model.args.sliding_window = window_choice == "Enable"
        if model.args.sliding_window:
            model.args.stride = st.sidebar.slider(
                "Stride (Fraction of Max Seq Length)",
                min_value=0.0,
                max_value=1.0,
                value=model.args.stride,
            )
    elif model_class == "MultiLabelClassificationModel":
        try:
            session_state, model = get_states(model)
        except AttributeError:
            session_state = get(max_seq_length=model.args.max_seq_length, )
            session_state, model = get_states(model, session_state)

        model.args.max_seq_length = st.sidebar.slider(
            "Max Seq Length",
            min_value=1,
            max_value=512,
            value=model.args.max_seq_length,
        )

    if input_text:
        prediction, raw_values = get_prediction(model, input_text)
        raw_values = [list(np.squeeze(raw_values))]
        if model.args.sliding_window and isinstance(raw_values[0][0], np.ndarray):
            # With sliding windows there is one output per window; average them.
            raw_values = np.mean(raw_values, axis=1)

        st.subheader(f"Predictions")
        st.text(f"Predicted label: {prediction[0]}")

        st.subheader(f"Model outputs")
        st.text("Raw values: ")
        try:
            raw_df = pd.DataFrame(
                raw_values,
                columns=[f"Label {label}" for label in model.args.labels_list],
            )
        except Exception:
            # labels_list may be missing or may not match the output width;
            # fall back to numeric column labels.
            raw_df = pd.DataFrame(
                raw_values,
                columns=[f"Label {label}" for label in range(len(raw_values[0]))],
            )
        st.dataframe(raw_df)

        st.text("Probabilities: ")
        try:
            prob_df = pd.DataFrame(
                softmax(raw_values, axis=1),
                columns=[f"Label {label}" for label in model.args.labels_list],
            )
        except Exception:
            # Same fallback as above for the normalized probabilities.
            prob_df = pd.DataFrame(
                softmax(raw_values, axis=1),
                columns=[f"Label {i}" for i in range(len(raw_values[0]))],
            )
        st.dataframe(prob_df)
    return model