def test_ranking():
    """Check that FitBert puts the expected filler first for each masked sentence."""
    fb = FitBert(model_name="distilbert-base-uncased")
    assert callable(fb.fitb)

    # (masked sentence, candidate fillers, filler expected in first place)
    cases = [
        (
            "When she started talking about her ex-boyfriends, he looked like a ***mask*** out of water",
            ["frog", "fish"],
            "fish",
        ),
        (
            "The boy was warned that if he misbehaved in the class, he would have to pay ***mask***.",
            ["the drummer", "the flutist", "the piper"],
            "the piper",
        ),
        (
            "I am surprised that you have ***mask*** patience.",
            ["such a", "so", "such"],
            "such",
        ),
    ]
    for masked_sentence, candidates, expected in cases:
        best_first = fb.rank(masked_sentence, candidates)
        assert best_first[0] == expected, "It should rank options"

    # The third positional argument is passed as True here; the expected
    # answer ("feeling") is a different inflection of the only option given
    # ("feelings"), so the ranker has to consider related word forms.
    related_sentence = "Psychology includes the study of conscious and unconscious phenomena, as well as ***mask*** and thought."
    best_first = fb.rank(related_sentence, ["feelings"], True)
    assert best_first[0] == "feeling", "It should find and rank related options"
def rank_by_spanbert(phrase_cand, sgs, drug_formal):
    """Score candidate phrases by how well a masked language model fits them.

    Every entry of ``sgs`` that mentions one of ``drug_formal`` (or the word
    ``'drug'``) is ranked with ``FitBert.rank_multi`` against ``phrase_cand``
    plus four fixed extra phrases. The returned probabilities are divided by
    their maximum, passed through ``fb.softmax``, and the scores of the top 50
    candidates are accumulated per candidate across all entries.

    Args:
        phrase_cand: list of candidate phrases to rank.
        sgs: sized collection of masked sentences (presumably strings that
            contain ``[MASK]`` -- confirm against the caller).
        drug_formal: list of formal drug names; used both to select relevant
            sentences and to exclude those names from the result.

    Returns:
        A pair ``(out, out_tuple)``: ``out`` is the list of surviving
        candidates sorted by descending accumulated score, and ``out_tuple``
        is the same ordering as ``[candidate, score]`` pairs.
    """
    from transformers import BertForMaskedLM, BertTokenizer

    bert_tokenizer = BertTokenizer.from_pretrained(
        'data/BERT_model_reddit/vocab.txt')
    bert_model = BertForMaskedLM.from_pretrained(
        'data/BERT_model_reddit').to(device)
    fb = FitBert(model=bert_model, tokenizer=bert_tokenizer,
                 mask_token='[MASK]')

    # Loop-invariant data, built once instead of once per candidate word.
    extra_candidates = ['cbd oil', 'hash oil', 'charlie horse', 'lunch money']
    context_keywords = drug_formal + ['drug']
    english_stopwords = set(stopwords.words('english'))

    MLM_score = defaultdict(float)
    # Only show a progress bar when there is enough work to warrant one.
    sentences = sgs if len(sgs) < 10 else tqdm(sgs)
    for sgs_i in sentences:
        if not any(x in sgs_i for x in context_keywords):
            continue

        ranking = fb.rank_multi(sgs_i, phrase_cand + extra_candidates)
        ranked_words, ranked_probs = ranking[0], ranking[1]

        # Normalise by the largest probability (computed once) before the
        # softmax; ``default`` keeps an empty ranking from raising.
        peak = max(ranked_probs, default=1.0)
        scores = [x / peak for x in ranked_probs]
        scores = fb.softmax(torch.tensor(scores).unsqueeze(0)).tolist()[0]

        top_words = [[word, score]
                     for word, score in zip(ranked_words[:50], scores)]
        for word, score in top_words:
            if word in string.punctuation:
                continue
            if word in english_stopwords:
                continue
            if word in drug_formal:
                continue
            if word in ('drug', 'drugs'):
                continue
            if word[:2] == '##':
                # '##' marks a BERT word-piece continuation, not a whole word.
                continue
            MLM_score[word] += score

        print(sgs_i)
        print([x[0] for x in top_words[:20]])

    out = sorted(MLM_score, key=lambda x: MLM_score[x], reverse=True)
    out_tuple = [[x, MLM_score[x]] for x in out]
    return out, out_tuple
class FBReplacer:
    """Pick the best-fitting replacement word for a masked sentence via FitBert."""

    def __init__(self, model_name):
        """Build the FitBert ranker for ``model_name`` and store the mask marker."""
        self.fb = FitBert(model_name=model_name)
        self.mask = "***mask***"

    def find_new_word(self, sent, options):
        """Return whichever of ``options`` the ranker places first for ``sent``.

        ``sent`` is handed to ``self.fb.rank`` unchanged, together with
        ``options``; the first element of the returned ranking is the result.
        """
        return self.fb.rank(sent, options=options)[0]
from fitbert import FitBert

# Supported models at the time of writing: bert-large-uncased and
# distilbert-base-uncased. Constructing FitBert takes a while and loads a
# whole big BERT into memory.
fb = FitBert()

# Disabled earlier example, kept for reference:
#
#   masked_string = "Why ***mask***, you're looking ***mask*** today!"
#   options = ['buff', 'handsome', 'strong']
#   ranked_options = fb.rank(masked_string, options=options)
#   print(ranked_options)
#   # >>> ['handsome', 'strong', 'buff']
#   # or
#   filled_in = fb.fitb(masked_string, options=options)
#   # >>> "Why Bert, you're looking handsome today!"
#   print(filled_in)

# Sentence with two blanks, used to exercise the multi-mask ranking calls.
masked_string = "Hello ***mask*** test ***mask*** today!"
options1 = ["looking", "catching", "master", "handsome"]
options2 = ["rank", "book", "strong"]

filled_in = fb.rank_multi(masked_string, options=options1)
print("rank_multi", filled_in)

filled_in1 = fb.new_rank_multi(masked_string, words=options1)
import os
from flask import Flask, abort, render_template, request
from predict import FillBert
from fitbert import FitBert
import requests
import re

app = Flask(__name__)
fb = FitBert()

# A submission is one question line followed by this many option lines.
NUM_OPTIONS = 4

# Matches anything wrapped in angle brackets; compiled once for all requests.
_BRACKETED = re.compile(r'</?.*?>')


@app.route('/')
def home():
    """Serve the question entry form."""
    return render_template('index.html')


@app.route('/result', methods=['POST'])
def result():
    """Answer a fill-in-the-blank question posted in the ``question`` field.

    The field holds the question on its first line (blank written as
    ``___``) followed by four option lines. Parenthesised or angle-bracketed
    parts of each option (for example a leading ``(A)``) are removed before
    the options are ranked with FitBert.

    Returns:
        The rendered ``result.html`` with ``row`` holding the original text,
        the question, options ``'1'``..``'4'`` and the chosen ``answer``.

    Raises:
        400 Bad Request when fewer than four option lines are supplied.
    """
    raw_question = request.form['question']
    questions = raw_question.replace('\r', '').split('\n')
    print(questions)

    # Reject short submissions up front instead of failing with an
    # IndexError (HTTP 500) while reading the option lines.
    if len(questions) < NUM_OPTIONS + 1:
        abort(400, description='Expected a question line followed by '
                               '{} option lines.'.format(NUM_OPTIONS))

    row = {}
    # NOTE(review): user text is combined with <br> markup here; if the
    # template renders this value unescaped it should be escaped first.
    row['origin'] = raw_question.replace('\n', '<br>')
    row['question'] = questions[0].strip()
    for i in range(1, NUM_OPTIONS + 1):
        # Turn "(...)" into "<...>" so one pattern strips both kinds of
        # bracketed text from the option.
        bracketed = questions[i].replace('(', '<').replace(')', '>')
        row[str(i)] = _BRACKETED.sub('', bracketed).strip()

    masked_string = row['question'].replace('___', '***mask***')
    options = [row[str(i)] for i in range(1, NUM_OPTIONS + 1)]
    row['answer'] = fb.rank_with_prob(masked_string, options)[0][0]
    return render_template('result.html', row=row)


@app.route('/example')
def example():
    """Serve the static example page."""
    return render_template('example.html')


if __name__ == "__main__":
    # The interactive debugger allows arbitrary code execution, so it must
    # not be enabled by default on a server bound to every interface.
    # Set FLASK_DEBUG=1 to opt in during local development.
    debug_enabled = os.environ.get('FLASK_DEBUG', '').lower() in ('1', 'true')
    app.run(host='0.0.0.0', port=4040, debug=debug_enabled, threaded=True)
def load(cls, language: str = "en", device: str = "cpu"):
    """Build an instance backed by a FitBert model for ``language``.

    Args:
        language: language code of the model to load; only ``"en"`` is
            handled.
        device: target device; the GPU is disabled when this is ``"cpu"``.

    Returns:
        ``cls(fitbert=...)`` wrapping a ``distilbert-base-uncased`` FitBert.

    Raises:
        ValueError: if ``language`` is anything other than ``"en"``.
    """
    # Fail with a clear message instead of reaching the constructor call
    # without a model for languages that have no branch here.
    if language != "en":
        raise ValueError(
            "Unsupported language {!r}; only 'en' is available.".format(language)
        )

    fitbert = FitBert(
        model_name="distilbert-base-uncased", disable_gpu=(device == "cpu")
    )
    return cls(fitbert=fitbert)
# Fill-in-the-blank ranking (BERT)
from fitbert import FitBert
# Saved text-generation models (RNN)
from tensorflow import keras
# Sibling modules from this directory
import constructor as con
from preprocessor import preprocess
from bert import fill_in
from rnn_text_generator import text_generation

# Flask application and the REST Api wrapper built on top of it
app = Flask(__name__)
api = Api(app)

# One shared FitBERT instance, created with the GPU disabled
fb = FitBert(model_name="bert-large-uncased", disable_gpu=True)

# One saved Keras model per text style
model_1 = keras.models.load_model("./models/Shakespeare")
model_2 = keras.models.load_model("./models/Michael_Jackson")
model_3 = keras.models.load_model("./models/Maid_of_Honor")

# Template stages are kept in module-level state so that the individual
# steps can be requested through separate (partial) calls
template_text = ""
preprocessed_text = ""
current_options = {}

# --------------------------------------------------------------------------
# REST Endpoints for BERT requests
def _run_chat_mode(session_state, chatlog_holder):
    """Free-form chat: send the typed text to Blenderbot and log the reply."""
    # To run from a local checkpoint instead, download the model files from
    # https://huggingface.co/facebook/blenderbot_small-90M/tree/main into a
    # local folder and point the path at it (e.g. './blenderbot-400M-distill'
    # or './blenderbot_small-90M'). The hub identifier below is the one to
    # use when running on the Streamlit server.
    Bbot_PATH = 'facebook/blenderbot-400M-distill'
    BbotModel, BbotTokenizer = load_Blenderbot(Bbot_PATH)

    text = st.text_input("You:")
    if len(text) != 0:
        inputs = BbotTokenizer([text], return_tensors='pt')
        reply_ids = BbotModel.generate(**inputs)
        bot_reply = BbotTokenizer.batch_decode(reply_ids,
                                               skip_special_tokens=True)[0]
        session_state.user_chat_log.append(text)
        session_state.bot_reply_log.append(bot_reply)
    chat_log(session_state.user_chat_log, session_state.bot_reply_log,
             chatlog_holder)


def _run_toeic_part5_mode(session_state, chatlog_holder):
    """TOEIC part 5: fill a blank, conjugate a word, or pick among choices."""
    # To run from a local checkpoint instead, download the model files from
    # https://huggingface.co/google/electra-small-generator/tree/main into
    # './electra-small-generator'. The hub identifier below is the one to
    # use when running on the Streamlit server.
    ELECTRA_PATH = 'google/electra-small-generator'
    ELECTRAmodel, ELECTRAtokenizer = load_ELECTRAsmall(ELECTRA_PATH)
    fb = FitBert(model=ELECTRAmodel, tokenizer=ELECTRAtokenizer)

    num_choices = st.sidebar.slider(label="Number of choices",
                                    min_value=0,
                                    max_value=4)
    st.sidebar.markdown(""" 0: Fill in the blank \\ 1: Grammatical conjugation \\ 2-4: Sentence completion""")

    if num_choices == 0:
        # Fill in the blank: let the model propose the word itself.
        question = st.text_input(label='Sentence:')
        if len(question) != 0:
            session_state.user_chat_log.append(question)
            question = question.replace('_', '[MASK]')
            mlm = pipeline('fill-mask',
                           model=ELECTRAmodel,
                           tokenizer=ELECTRAtokenizer)
            result = mlm(question)[0]['token_str'].replace(' ', '')
            session_state.bot_reply_log.append(result)
        chat_log(session_state.user_chat_log, session_state.bot_reply_log,
                 chatlog_holder)
    elif num_choices == 1:
        # Grammatical conjugation of a single given word.
        question = st.text_input(label='Sentence:')
        choices = [st.text_input(label='Word needs to conjugate')]
        # Test the entered word itself: the one-element list is always
        # truthy, even when the input box is empty.
        if st.button("Conjugate") and len(question) != 0 and len(
                choices[0]) != 0:
            session_state.user_chat_log.append(question)
            question = question.replace('_', '***mask***')
            bot_choice = fb.fitb(question, options=choices)
            session_state.bot_reply_log.append(bot_choice)
        chat_log(session_state.user_chat_log, session_state.bot_reply_log,
                 chatlog_holder)
    else:
        # Sentence completion: rank the two to four supplied choices.
        labels = ['A.', 'B.', 'C.', 'D.']
        question = st.text_input(label='Question:')
        choices = [st.text_input(label=labels[i]) for i in range(num_choices)]
        if st.button("Solve") and len(question) != 0 and len(
                choices[0]) != 0 and len(choices[1]) != 0:
            session_state.user_chat_log.append(question)
            question = question.replace('_', '***mask***')
            bot_choice = fb.rank(question, options=choices)[0]
            session_state.bot_reply_log.append(bot_choice)
        chat_log(session_state.user_chat_log, session_state.bot_reply_log,
                 chatlog_holder)


def _run_toeic_part6_mode(session_state, chatlog_holder):
    """TOEIC part 6: fill up to four blanks of a paragraph, one at a time."""
    # A local checkpoint ('./electra-small-generator') can be used instead;
    # the hub identifier below is the one to use on the Streamlit server.
    ELECTRA_PATH = 'google/electra-small-generator'
    ELECTRAmodel, ELECTRAtokenizer = load_ELECTRAsmall(ELECTRA_PATH)
    fb = FitBert(model=ELECTRAmodel, tokenizer=ELECTRAtokenizer)

    labels = ['Question 1.', 'Question 2.', 'Question 3.', 'Question 4.']
    paragraph = st.text_input(label='Paragraph:')
    # Each question box holds its options separated by '/'.
    question = [st.text_input(label=label).split('/') for label in labels]
    st.write(question)

    bot_choices = []
    if st.button("Solve") and len(paragraph) != 0:
        session_state.user_chat_log.append(paragraph)
        paragraph = paragraph.replace('_', '***mask***')
        splited_paragraph = paragraph.split()
        mask_indices = [
            i for i, item in enumerate(splited_paragraph)
            if '***mask***' in item
        ]
        # Solve only as many blanks as there are option lists; comparing
        # against the paragraph's character count never limited anything and
        # let a fifth blank index past the end of ``question``.
        length = min(len(question), len(mask_indices))
        for i in range(length):
            mask_idx = mask_indices[i]
            # Drop the blanks that come later so the model sees exactly one
            # mask; earlier blanks already hold their chosen answers.
            later_masks = set(mask_indices[i + 1:])
            one_masked_paragraph = ' '.join(
                word for idx, word in enumerate(splited_paragraph)
                if idx not in later_masks)
            st.write(one_masked_paragraph)
            bot_choices.append(
                fb.rank(sent=one_masked_paragraph, options=question[i])[0])
            splited_paragraph[mask_idx] = bot_choices[i]
        session_state.bot_reply_log.append(bot_choices)
    chat_log(session_state.user_chat_log, session_state.bot_reply_log,
             chatlog_holder)


def main():
    """Render the app: choose a mode in the sidebar and run its handler."""
    session_state = SessionState.get(user_chat_log=[], bot_reply_log=[])
    mode = st.sidebar.radio("Mode: ",
                            options=['Chat', 'TOEIC_part5', 'TOEIC_part6'])
    # Reserve space for chatlog
    chatlog_holder = st.empty()

    if mode == 'Chat':
        _run_chat_mode(session_state, chatlog_holder)
    elif mode == 'TOEIC_part5':
        _run_toeic_part5_mode(session_state, chatlog_holder)
    elif mode == 'TOEIC_part6':
        _run_toeic_part6_mode(session_state, chatlog_holder)
from fitbert import FitBert

# Build the ranker with its default settings.
fb = FitBert()

# Sentence with one blank and the words competing to fill it.
masked_string = "Why Bert, you're looking ***mask*** today!"
options = ["buff", "handsome", "strong"]

# Print the options in the order returned by the ranker.
ranked_options = fb.rank(masked_string, options=options)
print(ranked_options)
def __init__(self, model_name):
    """Create the FitBert ranker for ``model_name`` and record the mask marker."""
    ranker = FitBert(model_name=model_name)
    self.fb = ranker
    # Placeholder string that marks the blank in a sentence.
    self.mask = "***mask***"