def predict():
    gc.collect()
    document = request.json["document"]
    question = request.json["question"]
    try:
        out = QA.getAnswer(question, document)
        gc.collect()
        return jsonify(out)
    except Exception as e:
        print(e)
        return jsonify({"result": "Model Failed"})
import en_core_web_sm
from nltk.corpus import stopwords

from bert import QA


def get_model_api():
    model = QA('model')
    nlp = en_core_web_sm.load()
    stop_words = set(stopwords.words('english'))

    def model_api(question):
        try:
            question = " ".join(w.capitalize() for w in question.split(" "))
            doc = nlp(question)
            search = []
            for chunk in doc.noun_chunks:
                query = chunk.text
                check_query = nlp(query.lower())
                if 'PROPN' in [token.pos_ for token in check_query]:
                    querywords = query.split()
                    query_sentence = ' '.join(
                        w.lower() for w in querywords if w.lower() not in stop_words)
                    search.append(query_sentence)
            search = [w for w in search if w != '']
            print(search)
            all_content = ''
            if len(search) != 0:
                for i in search:
                    i = " ".join(w.capitalize() for w in i.split(" "))
                    for j in wiki_search(i.capitalize()):
                        all_content = all_content + j + '.'
            answer = model.predict(all_content, question)
            return answer['answer']
        except Exception:
            return ("Sorry, I don't know. Can you be a bit more specific? "
                    "(Or the Wikipedia server is busy and did not respond.)")

    return model_api
def generate():
    context = request.json["context"]
    # Fall back to defaults when the optional keys are absent. (The original
    # checked `if max_length in request.json`, i.e. whether the *value* 20 was
    # a key, so the defaults were always used.)
    max_length = request.json.get("max_length", 20)
    do_sample = request.json.get("do_sample", False)
    try:
        out = QA.generateText(context, max_length, do_sample)
        return jsonify(out)
    except Exception as e:
        print(e)
        return jsonify({"result": "Model Failed"})
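# A minimal client sketch for the handler above. The "/generate" route path
# and port are assumptions: this fragment does not show the app setup or the
# route decorator.
import requests

resp = requests.post(
    "http://localhost:8000/generate",
    json={"context": "Once upon a time", "max_length": 40, "do_sample": True},
)
print(resp.json())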
class World:
    def __init__(self, locs, objs, relations, args):
        self.graph = nx.Graph()
        self.graph.add_nodes_from(locs, type='location',
                                  fillcolor="yellow", style="filled")
        self.graph.add_nodes_from(objs, type='object')
        self.graph.add_edges_from(relations)
        self.locations = {v for v in locs}
        self.objects = {v for v in objs}
        self.edge_labels = {}
        self.args = args
        # init GPT-2
        with open(args.input_text) as f:
            self.input_text = f.read()
        self.model = QA('model/albert-large-squad')

    def is_connected(self):
        return len(list(nx.connected_components(self.graph))) == 1

    def query(self, query, nsamples=10, cutoff=8):
        return self.model.predictTopK(self.input_text, query, nsamples, cutoff)

    def generateNeighbors(self, nsamples=100):
        self.candidates = {}
        for u in self.graph.nodes:
            self.candidates[u] = {}
            if self.graph.nodes[u]['type'] == "location":
                self.candidates[u]['location'] = self.query(
                    random.choice(loc2loc_templates).format(u), nsamples)
                self.candidates[u]['object'] = self.query(
                    random.choice(loc2obj_templates).format(u), nsamples)
                self.candidates[u]['character'] = self.query(
                    random.choice(loc2char_templates).format(u), nsamples)
            if self.graph.nodes[u]['type'] == "object":
                self.candidates[u]['location'] = self.query(
                    random.choice(obj2loc_templates).format(u), nsamples)
            if self.graph.nodes[u]['type'] == "character":
                self.candidates[u]['location'] = self.query(
                    random.choice(char2loc_templates).format(u), nsamples)

    def relatedness(self, u, v, type='location'):
        s = 0
        u2v, probs = self.candidates[u][type]
        if u2v is not None:
            for c, p in zip(u2v, probs):
                a = set(c.text.split()).difference(articles)
                b = set(v.split()).difference(articles)
                # find best intersect
                best_intersect = 0
                for x in self.graph.nodes:
                    xx = set(x.split()).difference(articles)
                    best_intersect = max(best_intersect, len(a.intersection(xx)))
                # increment if answer is best match BoW
                if len(a.intersection(b)) == best_intersect:
                    s += len(a.intersection(b)) * p
                # naive method
                # s += len(a.intersection(b)) * p
        v2u, probs = self.candidates[v]['location']
        if v2u is not None:
            for c, p in zip(v2u, probs):
                a = set(c.text.split()).difference(articles)
                b = set(u.split()).difference(articles)
                # find best intersect
                best_intersect = 0
                for x in self.graph.nodes:
                    xx = set(x.split()).difference(articles)
                    best_intersect = max(best_intersect, len(a.intersection(xx)))
                # increment if answer is best match BoW
                if len(a.intersection(b)) == best_intersect:
                    s += len(a.intersection(b)) * p
                # naive method
                # s += len(a.intersection(b)) * p
        return s

    def extractEntity(self, query, threshold=0.05, cutoff=0):
        preds, probs = self.query(query, 50, cutoff)
        if preds is None:
            return None, 0
        for pred, prob in zip(preds, probs):
            t = pred.text
            p = prob
            print('> ', t, p)
            if len(t) < 1:
                continue
            if p > threshold and "MASK" not in t:
                # find a more minimal candidate if possible
                for pred, prob in zip(preds, probs):
                    if (t != pred.text and pred.text in t
                            and prob > threshold and len(pred.text) > 2):
                        t = pred.text
                        p = prob
                        break
                t = t.strip(string.punctuation)
                remove = t
                # take out leading articles for cleaning
                words = t.split()
                if words[0].lower() in articles:
                    remove = " ".join(words[1:])
                    words[0] = words[0].lower()
                    t = " ".join(words[1:])
                print(remove)
                self.input_text = self.input_text.replace(
                    remove, '[MASK]').replace('  ', ' ').replace(' .', '.')
                return t, p
            # else:
            #     # find a more minimal candidate if possible
            #     for pred, prob in zip(preds, probs):
            #         if prob > threshold and "MASK" not in pred.text and len(pred.text) > 2 and pred.text in t:
            #             t = pred.text.strip(string.punctuation)
            #             p = prob
            #             self.input_text = self.input_text.replace(t, '[MASK]').replace('  ', ' ').replace(' .', '.')
            #             print(t, p)
            #             return t, p
        return None, 0

    def generate(self):
        locs = []
        objs = []
        chars = []
        # set thresholds/cutoffs
        threshold = 0.05
        if self.args.cutoffs == 'fairy':
            cutoffs = [6.5, -7, -5]  # fairy
        elif self.args.cutoffs == 'mystery':
            cutoffs = [3.5, -7.5, -6]  # mystery
        else:
            cutoffs = [int(i) for i in self.args.cutoffs.split()]
            assert len(cutoffs) == 3
        # save input text
        tmp = self.input_text[:]

        # add chars
        print("=" * 20 + "\tcharacters\t" + "=" * 20)
        self.input_text = tmp
        primer = "Who is somebody in the story?"
        cutoff = 10
        t, p = self.extractEntity(primer, threshold=threshold, cutoff=cutoff)
        while t is not None and len(t) > 1:
            if len(chars) > 1:
                cutoff = cutoffs[0]
            chars.append(t)
            t, p = self.extractEntity(primer, threshold=threshold, cutoff=cutoff)

        print("=" * 20 + "\tlocations\t" + "=" * 20)
        # add locations
        self.input_text = tmp
        primer = "Where is the location in the story?"
        cutoff = 10
        t, p = self.extractEntity(primer, threshold=threshold, cutoff=cutoff)
        while t is not None and len(t) > 1:
            locs.append(t)
            if len(locs) > 1:
                cutoff = cutoffs[1]
            t, p = self.extractEntity(primer, threshold=threshold, cutoff=cutoff)

        print("=" * 20 + "\tobjects\t\t" + "=" * 20)
        # add objects
        self.input_text = tmp
        primer = "What is an object in the story?"
        cutoff = 10
        t, p = self.extractEntity(primer, threshold=threshold, cutoff=cutoff)
        while t is not None and len(t) > 1:
            if len(objs) > 1:
                cutoff = cutoffs[2]
            objs.append(t)
            t, p = self.extractEntity(primer, threshold=threshold, cutoff=cutoff)

        self.input_text = tmp
        self.graph.add_nodes_from(locs, type='location',
                                  fillcolor="yellow", style="filled")
        self.graph.add_nodes_from(chars, type='character',
                                  fillcolor="orange", style="filled")
        self.graph.add_nodes_from(objs, type='object',
                                  fillcolor="white", style="filled")
        # with open('stats.txt', 'a') as f:
        #     f.write(args.input_text + "\n")
        #     f.write(str(len(locs)) + "\n")
        #     f.write(str(len(chars)) + "\n")
        #     f.write(str(len(objs)) + "\n")
        self.autocomplete()

    def autocomplete(self):
        self.generateNeighbors(self.args.nsamples)
        print("=" * 20 + "\trelations\t" + "=" * 20)
        while not self.is_connected():
            components = list(nx.connected_components(self.graph))
            best = (-1, next(iter(components[0])), next(iter(components[1])))
            main = components[0]
            loc_done = True
            for c in components[1:]:
                for v in c:
                    if self.graph.nodes[v]['type'] == 'location':
                        loc_done = False
            for u in main:
                if self.graph.nodes[u]['type'] != 'location':
                    continue
                for c in components[1:]:
                    for v in c:
                        if not loc_done and self.graph.nodes[v]['type'] != 'location':
                            continue
                        best = max(best, (self.relatedness(
                            u, v, self.graph.nodes[v]['type']), u, v))
            score, u, v = best
            # attach randomly if empty or specified
            if score == 0 or self.args.random:
                candidates = []
                for c in components[0]:
                    if self.graph.nodes[c]['type'] == 'location':
                        candidates.append(c)
                u = random.choice(candidates)
            if (self.graph.nodes[u]['type'] == 'location'
                    and self.graph.nodes[v]['type'] == 'location'):
                type = "connected to"
            else:
                type = "located in"
            print("{} {} {}".format(v, type, u))
            self.graph.add_edge(v, u, label=type)
            self.edge_labels[(v, u)] = type

    def export(self, filename="graph.dot"):
        nx.nx_pydot.write_dot(self.graph, filename)
        nx.write_gml(self.graph, "graph.gml", stringizer=None)

    def draw(self, filename="./graph.svg"):
        self.export()
        if self.args.write_sfdp:
            cmd = "sfdp -x -Goverlap=False -Tsvg graph.dot"
            returned_value = subprocess.check_output(cmd, shell=True)
            with open(filename, 'wb') as f:
                f.write(returned_value)
            cmd = "inkscape -z -e {}.png {}.svg".format(
                filename[:-4], filename[:-4])
            returned_value = subprocess.check_output(cmd, shell=True)
        else:
            nx.draw(self.graph, with_labels=True)
            plt.savefig(filename[:-4] + '.png')
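# A minimal usage sketch for the World class above. The prompt-template lists
# (loc2loc_templates, ...), `articles`, and the project's real CLI live
# elsewhere; the argparse fields below are inferred from the methods above and
# are assumptions, not the project's actual arguments.
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--input_text', default='story.txt')
    parser.add_argument('--cutoffs', default='fairy')
    parser.add_argument('--nsamples', type=int, default=100)
    parser.add_argument('--random', action='store_true')
    parser.add_argument('--write_sfdp', action='store_true')
    args = parser.parse_args()

    world = World(locs=[], objs=[], relations=[], args=args)
    world.generate()  # extract characters/locations/objects, then link them
    world.draw()      # writes graph.dot / graph.gml and renders an image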
from flask import Flask, request, jsonify
from flask_cors import CORS

from bert import QA

app = Flask(__name__)
CORS(app)

modelName = "mrm8488/bert-small-finetuned-squadv2"
model = QA(modelName)


@app.route("/predict", methods=['POST'])
def predict():
    doc = request.json["document"]
    q = request.json["question"]
    try:
        out = model.predict(doc, q)
        return jsonify({"result": out})
    except Exception as e:
        print(e)
        return jsonify({"result": "Model Failed"})


if __name__ == "__main__":
    app.run('0.0.0.0', port=8000)
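# A client sketch for the service above (assumes it is running locally on
# port 8000, as in app.run above):
import requests

payload = {
    "document": "Victoria has a written constitution enacted in 1975.",
    "question": "When did Victoria enact its constitution?",
}
print(requests.post("http://localhost:8000/predict", json=payload).json())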
from flask import Flask, request, jsonify
from flask_cors import CORS

from bert import QA

app = Flask(__name__)
CORS(app)

model = QA("bert-large-cased-whole-word-masking-finetuned-squad")


@app.route("/predict", methods=['POST'])
def predict():
    doc = request.json["document"]
    q = request.json["question"]
    try:
        out = model.predict(doc, q)
        return jsonify({"result": out})
    except Exception as e:
        print(e)
        return jsonify({"result": "Model Failed"})


if __name__ == "__main__":
    app.run('0.0.0.0', port=8000)
import os

from flask import Flask, request, jsonify
from flask_cors import CORS
from dotenv import load_dotenv

from bert import QA

app = Flask(__name__)
CORS(app)

load_dotenv()
model = QA(os.getenv("OUTPUT_DIR"))


@app.route("/predict", methods=['POST'])
def predict():
    doc = request.json["document"]
    q = request.json["question"]
    try:
        out = model.predict(doc, q)
        return jsonify({"result": out})
    except Exception as e:
        app.logger.warning(e)
        return jsonify({"result": "Model Failed"})


if __name__ == "__main__":
    app.run('0.0.0.0', port=8000, debug=True)
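# This variant reads the model directory from a .env file next to the script.
# A sketch of the expected contents (the value shown is an assumption; point
# it at your fine-tuned model's output directory):
#
#   # .env
#   OUTPUT_DIR=model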
from bert import QA

model = QA('model')

# The later assignments override the earlier ones; only the second
# doc/question pair is actually used below.
doc = "Victoria has a written constitution enacted in 1975, but based on the 1855 colonial constitution, passed by the United Kingdom Parliament as the Victoria Constitution Act 1855, which establishes the Parliament as the state's law-making body for matters coming under state responsibility. The Victorian Constitution can be amended by the Parliament of Victoria, except for certain 'entrenched' provisions that require either an absolute majority in both houses, a three-fifths majority in both houses, or the approval of the Victorian people in a referendum, depending on the provision."
doc = "According to the Indian census of 2001, there were 30,803,747 speakers of Malayalam in Kerala, making up 93.2% of the total number of Malayalam speakers in India, and 96.7% of the total population of the state. There were a further 701,673 (2.1% of the total number) in Karnataka, 557,705 (1.7%) in Tamil Nadu and 406,358 (1.2%) in Maharashtra. The number of Malayalam speakers in Lakshadweep is 51,100, which is only 0.15% of the total number, but is as much as about 84% of the population of Lakshadweep. In all, Malayalis made up 3.22% of the total Indian population in 2001. Of the total 33,066,392 Malayalam speakers in India in 2001, 33,015,420 spoke the standard dialects, 19,643 spoke the Yerava dialect and 31,329 spoke non-standard regional variations like Eranadan. As per the 1991 census data, 28.85% of all Malayalam speakers in India spoke a second language and 19.64% of the total knew three or more languages. Large numbers of Malayalis have settled in Bangalore, Mangalore, Delhi, Coimbatore, Hyderabad, Mumbai (Bombay), Ahmedabad, Pune, and Chennai (Madras). A large number of Malayalis have also emigrated to the Middle East, the United States, and Europe. Accessed November 22, 2014.</ref> including a large number of professionals. There were 7,093 Malayalam speakers in Australia in 2006. The 2001 Canadian census reported 7,070 people who listed Malayalam as their mother tongue, mostly in the Greater Toronto Area and Southern Ontario. In 2010, the Census of Population of Singapore reported that there were 26,348 Malayalees in Singapore. The 2006 New Zealand census reported 2,139 speakers. 134 Malayalam speaking households were reported in 1956 in Fiji. There is also a considerable Malayali population in the Persian Gulf regions, especially in Bahrain, Muscat, Doha, Dubai, Abu Dhabi, Kuwait and European region mainly in London. World Malayalee Council, the organisation working with the Malayali diaspora across the Globe has embarked upon a project for making a data bank of the diaspora. CANNOTANSWER"

q = 'When did Victoria enact its constitution?'
q = "What other languages are spoken there?"

answer = model.predict(doc, q)
print(answer['answer'])

# For the first doc/question pair the answer would be "1975".
# predict() returns dict_keys(['answer', 'start', 'end', 'confidence', 'document'])
import os

import nltk
import numpy as np
from nltk.corpus import stopwords
from gensim.models import KeyedVectors  # needed for KeyedVectors.load below

from word2vec_repo.DocSim import DocSim

# some of the code is from https://stackoverflow.com/a/8897648
model_path = 'word2vec_repo/model.bin'
stopwords = stopwords.words('english')
model = KeyedVectors.load(model_path)
ds = DocSim(model, stopwords=stopwords)

from bert import QA

model = QA('BERTap/model')


# fetch the top 3 docs using the tf-idf weighting method (for a large corpus,
# set it to 10)
def get_doc(qu):
    final_docs = matching_score(10, qu)
    # print(final_docs)
    nd = len(final_docs)
    if len(final_docs) > 3:
        nd = 3
    answer_doc = []
    for i in range(nd):
        file = open(final_docs[i], 'r', encoding='cp1250')
        text = file.read().strip()
        file.close()
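# A sketch of how the DocSim ranker above might be combined with the QA model
# once get_doc is complete. It assumes calculate_similarity(source, targets)
# returns dicts with 'score' and 'doc' keys sorted by score, as in the common
# document-similarity implementation this repo appears to vendor; treat both
# the call shape and the return shape as assumptions.
def answer_question(question):
    docs = get_doc(question)  # assumed to return the candidate passage texts
    sims = ds.calculate_similarity(question, docs)
    best_doc = sims[0]['doc'] if sims else docs[0]
    return model.predict(best_doc, question)['answer']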
from bert import QA

model = QA('model')

story = input('Enter story: ')
flag = 'y'
while flag == 'y' or flag == 'yes':
    ques = input('Enter question: ')
    answer1 = model.predict(story, ques)
    print('Answer is ', answer1['answer'])
    flag = input('Ask more: y/n ')

'''
doc = "Hey! I am Mehak. I love dancing. I live in Punjab. I have one brother and one sister. My father is a businessman. My mother is a teacher. I love my family."
q1 = 'What is my name?'
q2 = 'What are my hobbies?'
q3 = 'How many siblings do I have?'
q4 = 'What is the occupation of my mother?'
q5 = 'What is the occupation of my father?'
answer1 = model.predict(doc,q1)
answer2 = model.predict(doc,q2)
answer3 = model.predict(doc,q3)
answer4 = model.predict(doc,q4)
answer5 = model.predict(doc,q5)
print(answer1['answer'])
print(answer2['answer'])
print(answer3['answer'])
print(answer4['answer'])
print(answer5['answer'])
'''
import time

import nltk

from bert import QA

model = QA('model')

doc = open("data.txt").read()
q = "Which department filed an antitrust lawsuit against Microsoft in 1998 ?"
s = ""

start_time = time.time()
words = nltk.word_tokenize(doc)
print("Words in the document are " + str(len(words)) + " words")

if len(words) > 512:
    answer = model.predict(doc, q, 150)
else:
    answer = model.predict(doc, q, len(words))

if len(answer) == 0:
    s += "No"
else:
    s += "Yes, "
    for i in range(len(answer)):
        t = answer[i]
        s += t['answer'] + ", "

print(q)
print("Final answer: ", s)
print("Time taken in seconds: ", (time.time() - start_time))
def load_model(_):
    model = QA('model')
    return model
import requests
from nltk.tag import StanfordPOSTagger
from nltk.tag import StanfordNERTagger
from nltk.stem.porter import PorterStemmer
from nltk.stem.snowball import SnowballStemmer
from nltk.stem.wordnet import WordNetLemmatizer
from sklearn import linear_model
from sklearn import svm
from sklearn.metrics import fbeta_score, accuracy_score
from scipy.sparse import hstack

from bert import QA

model = QA('model')

context = ' Bits Pilani is a private institute of higher education and a deemed university under Section 3 of the UGC Act 1956. \
The institute was established in its present form in 1964. It is established across 4 campuses and has 15 academic departments. \
Pilani is located 220 kilometres far from Delhi in Rajasthan. We can reach bits pilani via train or bus from delhi. \
Bits Pilani has its campuses in Pilani , Goa , Hyderabad , Dubai. There are multiple scholarships available at BITS namely Merit Scholarships, Merit Cum Need Scholarships and BITSAA Scholarships. \
BITS Model United Nations Conference (BITSMUN) is one of the largest MUN conferences in the country. BITS conducts the All-India computerized entrance examination, BITSAT (BITS Admission Test). \
Admission is merit-based, as assessed by the BITSAT examination. \
We can reach bits pilani through bus or train from delhi or jaipur. \
Mr. Ashoke Kumar Sarkar is the director of Bits Pilani, pilani campus. \
Founder of Bits pilani was Ghanshyam Das Birla.'


def get_answer(context: str, ques: str):
    answer = model.predict(context, ques)
    return answer['answer']
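# Example call; the question is illustrative and answerable from the context
# above:
print(get_answer(context, 'Who was the founder of Bits Pilani?'))
# expected, given the context: Ghanshyam Das Birla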
import os

from flask import Flask, request, jsonify
import flask_cors

from bert import QA
from gcp import GCP

os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = "fyp-qa-eb7816dfb87e.json"
os.environ["CUDA_VISIBLE_DEVICES"] = ""

app = Flask(__name__, static_url_path='/static')
flask_cors.CORS(app)

model = QA()
gcp = GCP()


@app.route("/file/raw", methods=['POST'])
def file_upload_raw():
    try:
        filename = request.json["filename"]
        text = request.json["text"]
        if filename == '' or text == '':
            return 'No file text or name empty.', 400
        doc_id = gcp.upload_raw(filename, text)
        return jsonify({'success': True, 'id': doc_id})
    except Exception as e:
        print(e)
import os
import time

from flask import Flask, request
from sklearn.feature_extraction.text import TfidfVectorizer

from bert import QA
# load_data and standardize_data are project-local helpers

basedir = os.path.abspath(os.path.dirname(__file__))
app = Flask(__name__)
SECRET_KEY = os.urandom(32)
app.config['SECRET_KEY'] = SECRET_KEY

MODEL_PATH = os.path.join(basedir, 'model')
DATA_PATH = os.path.join(basedir, 'dataUIT')
data = load_data(DATA_PATH)
stopwords = set(open(os.path.join(basedir, 'stopwords.txt'),
                     encoding="utf-8").read().split(' ')[:-1])

# load model
print("Load model...")
start = time.time()
model = QA(MODEL_PATH)  # path to model
end = time.time()
print("time load model: " + str(round((end - start), 2)))

# Building index
print('Building index...')
start = time.time()
data_standard = standardize_data(data, stopwords)
vect = TfidfVectorizer(min_df=1, max_df=0.8, max_features=5000,
                       sublinear_tf=True, ngram_range=(1, 3))
vect.fit(data_standard)
end = time.time()
print("Time building index: " + str(round((end - start), 2)))


@app.route('/', methods=['GET', 'POST'])
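# The route body above is truncated. A sketch of the retrieval step it would
# plausibly perform with the TF-IDF index: shortlist passages for the QA model
# by cosine similarity. The standardize_data call shape is an assumption based
# on its use above.
from sklearn.metrics.pairwise import cosine_similarity

doc_matrix = vect.transform(data_standard)


def retrieve(question, top_k=3):
    q_vec = vect.transform(standardize_data([question], stopwords))
    scores = cosine_similarity(q_vec, doc_matrix)[0]
    return [data[i] for i in scores.argsort()[::-1][:top_k]]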
# see original BERT-SQuAD:
# https://github.com/kamalkraj/BERT-SQuAD
#############################################################
# Copyright 2019 Pierre Rouarch
# License GPL V3
#############################################################

myKeyword = "When Abraham Lincoln died and how?"

from bert import QA

n_best_size = 20

# list of pretrained models:
# https://huggingface.co/transformers/pretrained_models.html
# !!!!! instantiate a BERT model fine-tuned on SQuAD
# Choose your model:
# 'bert-large-uncased-whole-word-masking-finetuned-squad'
# 'bert-large-cased-whole-word-masking-finetuned-squad'
model = QA('bert-large-uncased-whole-word-masking-finetuned-squad', n_best_size)

# import needed libraries
import pandas as pd
import numpy as np

# pip install google
# to install Google Search by Mario Vilas, see
# https://python-googlesearch.readthedocs.io/en/latest/
import googlesearch  # scrape SERPs

import random  # to randomize pauses
import time
from datetime import date  # to calculate page download time
import sys  # for sys variables
import requests  # to read URL contents
from bs4 import BeautifulSoup  # to decode HTML
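# A sketch of the search-then-read loop this script builds toward: fetch the
# top result URLs for myKeyword, extract their visible text, and ask the
# model. googlesearch.search is the Mario Vilas package linked above; the
# counts and pauses are illustrative, and answer['answer'] assumes the
# single-answer dict shape used elsewhere in this repo (with n_best_size set,
# the return shape may differ).
for url in googlesearch.search(myKeyword, num=5, stop=5, pause=random.uniform(2, 5)):
    try:
        html = requests.get(url, timeout=10).text
        text = BeautifulSoup(html, 'html.parser').get_text(separator=' ')
        answer = model.predict(text, myKeyword)
        print(url, '->', answer['answer'])
    except Exception as e:
        print(url, 'failed:', e)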
import csv
import re
from datetime import datetime

import pandas as pd

from bert import QA
from utils import *

time_str = datetime.now().strftime('%Y%m%d%H%M%S')
bank_filepath = 'itembank.csv'
model = QA('model')

# Load the database
content_code = 'A'
df = pd.read_csv(bank_filepath)

# Data cleanup
df = df.replace(r"[\[].*?[\]]", "", regex=True)  # remove all bracket stylings
df = df.replace(r"\(select.*?\)", "", regex=True)  # remove all select prompts
df = df.replace(r'\s+', ' ', regex=True)  # truncate all multiple spaces

# Data filters
context = (df['ContentCode'] == content_code) \
    & (~df['Stem'].str.contains(r'arrow|points to|the line|line #|1. |video|label|cursor|click|\d{3}')  # Attempt to filter out questions that refer to images
       & (df['FullKey'] > 'D')  # Filter out questions with more than 4 options
       & (df['FullKey'].str.len() == 1))  # Filter out questions with multiple answers; will not work with model prediction

# Inits
iterator = 0
length = df[context].shape[0]
total = int((length * (length + 1)) / 2)
import os

from bert import QA

model = QA('model')


def read_all_data(word2idx, max_words, max_sentences):
    # stories[story_ind] = [[sentence1], [sentence2], ..., [sentenceN]]
    # questions[question_ind] = {'question': [question], 'answer': [answer],
    #                            'story_index': #, 'sentence_index': #}
    stories = dict()
    questions = dict()

    if len(word2idx) == 0:
        word2idx['<null>'] = 0

    data_dir = "./tasksv11/en-valid"
    for babi_task in range(1, 21):
        fname = '{}/qa{}_test.txt'.format(data_dir, babi_task)
        if os.path.isfile(fname):
            with open(fname) as f:
                lines = f.readlines()
        else:
            raise Exception("[!] Data {file} not found".format(file=fname))

        for line in lines:
            words = line.split()
            max_words = max(max_words, len(words))

            # Determine whether the line indicates the start of a new story
            if words[0] == '1':
                story_ind = len(stories)
                sentence_ind = 0
import os
import sys

import gradio as gr

sys.path.insert(0, os.path.join(os.path.dirname(os.path.abspath(__file__)), "utils"))

from bert import QA

model = QA('bert-large-uncased-whole-word-masking-finetuned-squad')


def qa_func(paragraph, question):
    return model.predict(paragraph, question)["answer"]


iface = gr.Interface(qa_func,
    [
        gr.inputs.Textbox(lines=7, label="Context", default="Victoria has a written constitution enacted in 1975, but based on the 1855 colonial constitution, passed by the United Kingdom Parliament as the Victoria Constitution Act 1855, which establishes the Parliament as the state's law-making body for matters coming under state responsibility. The Victorian Constitution can be amended by the Parliament of Victoria, except for certain 'entrenched' provisions that require either an absolute majority in both houses, a three-fifths majority in both houses, or the approval of the Victorian people in a referendum, depending on the provision."),
        gr.inputs.Textbox(lines=1, label="Question", default="When did Victoria enact its constitution?"),
    ],
    gr.outputs.Textbox(label="Answer"))

if __name__ == "__main__":
    iface.launch()
from flask import Flask, request, jsonify
from flask_cors import CORS

from bert import QA

app = Flask(__name__)
CORS(app)

# model = QA("model")
model = QA("/home/k3ijo/bert/nlp_model/multi_cased_L-12_H-768_A-12/")


@app.route("/predict", methods=['POST'])
def predict():
    doc = request.json["document"]
    q = request.json["question"]
    try:
        out = model.predict(doc, q)
        return jsonify({"result": out})
    except Exception as e:
        print(e)
        return jsonify({"result": "Model Failed"})


if __name__ == "__main__":
    app.run('0.0.0.0', port=8000)
def test(use_jit=False, fp16=False, onnx_runtime=False, export_onnx=False,
         tf_onnx=False, tf_version=False, vsl='none', min_batch=0,
         max_batch=1, num_predicts=300):
    document1 = 'Two partially reusable launch systems were developed, the Space Shuttle and Falcon 9. ' \
                'The Space Shuttle was partially reusable: the orbiter (which included the Space Shuttle ' \
                'main engines and the Orbital Maneuvering System engines), and the two solid rocket boosters ' \
                'were reused after several months of refitting work for each launch. The external tank was ' \
                'discarded after each flight. and the two solid rocket boosters were reused after several ' \
                'months of refitting work for each launch. The external tank was discarded after each flight.'
    document2 = 'This contrasts with expendable launch systems, where each launch vehicle is launched once ' \
                'and then discarded. No completely reusable orbital launch system has ever been created.'
    document3 = 'A reusable launch system (RLS, or reusable launch vehicle, RLV) is a launch system which is ' \
                'capable of launching a payload into space more than once. This contrasts with expendable ' \
                'launch systems, where each launch vehicle is launched once and then discarded. No completely ' \
                'reusable orbital launch system has ever been created.'
    question = 'How many partially reusable launch systems were developed?'

    # passages = [document1, document2, document3, document1, document2, document3, document1, document2, document3]
    # passages = [document1, document2, document3]
    # passages = [document1]

    if tf_onnx or tf_version:
        from multiprocessing import Pool

        convert_onnx_to_tf = False
        if tf_onnx and convert_onnx_to_tf:
            onnx_model = onnx.load(ONNX_PATH)
            # prepare tf representation
            tf_exp = onnx_tf.backend.prepare(onnx_model)
            # export the model
            tf_exp.export_graph(ONNX_TF_PB_PATH)

        onnx_pb_graph = tf.Graph()
        with onnx_pb_graph.as_default():
            tf_pb_path = ONNX_TF_PB_PATH if tf_onnx else TF_PB_PATH
            onnx_pb_graph_def = tf.GraphDef()
            with tf.gfile.GFile(tf_pb_path, 'rb') as fid:
                serialized_graph = fid.read()
                onnx_pb_graph_def.ParseFromString(serialized_graph)
                tf.import_graph_def(onnx_pb_graph_def, name='')

            config = tf.ConfigProto()
            config.gpu_options.allow_growth = True
            if use_jit:
                # config.gpu_options.per_process_gpu_memory_fraction = 0.5
                config.log_device_placement = False
                config.graph_options.optimizer_options.global_jit_level = tf.OptimizerOptions.ON_1

            with tf.Session(config=config) as sess:
                # INFERENCE using session.run
                model = QA(MODEL_PATH, use_jit=use_jit, fp16=fp16,
                           onnx=onnx_runtime, sess=sess, vsl=vsl, tf_onnx=tf_onnx)
                print('-- BENCHMARKING: JIT={} | FP16={} | ONNX_RUNTIME={} | '
                      'TF_ONNX_VERSION={} | TF_VERSION={} | EXACT_VSL={} --'.format(
                          use_jit, fp16, onnx_runtime, tf_onnx, tf_version, vsl))
                for passage_batch in range(min_batch, max_batch):
                    passage_batch = pow(3, passage_batch - 1)
                    if passage_batch < 1:
                        passages = [document1]
                    else:
                        passages = []
                        for i in range(passage_batch):
                            passages.append(document1)
                            passages.append(document2)
                            passages.append(document3)
                    if max_batch > 2:
                        num_predicts = 50
                    time_taken, rps = measure_inference(
                        model, passages, question, num_predicts)
                    # print('Time taken for test: {} s'.format(time_taken))
                    print('RPS: {}'.format(rps))
                sess.close()
                del model, sess
    else:
        model = QA(MODEL_PATH, use_jit=use_jit, fp16=fp16, onnx=onnx_runtime,
                   export_onnx=export_onnx, vsl=vsl, onnx_path=ONNX_PATH)
        if not export_onnx:
            print('-- BENCHMARKING: JIT={} | FP16={} | ONNX_RUNTIME={} | '
                  'TF_ONNX_VERSION={} | TF_VERSION={} | EXACT_VSL={} --'.format(
                      use_jit, fp16, onnx_runtime, tf_onnx, tf_version, vsl))
            for passage_batch in range(min_batch, max_batch):
                passage_batch = pow(3, passage_batch - 1)
                if passage_batch < 1:
                    passages = [document1]
                else:
                    passages = []
                    for i in range(passage_batch):
                        passages.append(document1)
                        passages.append(document2)
                        passages.append(document3)
                if max_batch > 2:
                    num_predicts = 50
                time_taken, rps = measure_inference(model, passages,
                                                    question, num_predicts)
                # print('Time taken for test: {} s'.format(time_taken))
                print('RPS: {}'.format(rps))
            del model
            torch.cuda.empty_cache()
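# The benchmark above relies on a measure_inference helper not shown in this
# fragment. A minimal sketch of what it plausibly does (the batched
# model.predict(passages, question) call shape is an assumption):
import time


def measure_inference(model, passages, question, num_predicts):
    start = time.time()
    for _ in range(num_predicts):
        model.predict(passages, question)
    time_taken = time.time() - start
    rps = num_predicts / time_taken  # requests (predictions) per second
    return time_taken, rps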
""" Make sure bert.py exists in the same directory as this script and the model files downloaded form the dropbox link is placed under the model folder """ from bert import QA model = QA('model') """ Now, let us implement comprehending a passage from google with our existing zero shot model """ passage = 'There was a princess called Maggie, she was tall and as white as snow. She had red lips like a \ rose and her hair was brown. She had light blue eyes and was very nice and kind. \ She was in love with a bricklayer called Kevin. He was tall, had brown hair and tanned skin. \ He was strong, had dark eyes, was kind and he was never angry' ques = 'What is the name of the princess?' """ Generally the predict method gives out a dictionary with other values as well, but since our interest lies with answer we subset the same. """ answer = model.predict(passage, ques)['answer'] print("{} : {}".format(ques, answer)) """ Let us try few more questions to see the edge case scenarios to understand where the model might break """ answer = model.predict(passage, 'Who was Maggie?')['answer'] print("{} : {}".format('Who was Maggie?', answer))
import warnings

warnings.filterwarnings("ignore")

from flask import Flask, render_template, request, send_file

from bert import QA

DEBUG = False
app = Flask(__name__)
app.config.from_object(__name__)

model = QA("model")


@app.route('/')
def single_people_code():
    return render_template('index.html')


@app.route("/", methods=['POST'])  # POST only: GET on '/' is handled above
def predict():
    if request.method == 'POST':
        results = request.form
        doc = results['passage']
        q = results['question']
        try:
            out = model.predict(doc, q)
import sys

sys.path.append('./BERT-SQuAD')

import json

import torch

from bert import QA


def predict(doc, q):
    answer = model.predict(doc, q)
    content = json.dumps(answer, separators=(',', ':'))
    print(content, flush=True)


print("Loading model!", flush=True)
model = QA('model')
print("Model loading complete!", flush=True, end='')

# warm-up / smoke-test prediction
model.predict("Whales are a kind of animal called a mammal.",
              "What is a whale?")