Example #1
def predict():
    gc.collect()
    document = request.json["document"]
    question = request.json["question"]
    try:
        out = QA.getAnswer(question, document)
        gc.collect()
        return jsonify(out)
    except Exception as e:
        print(e)
        return jsonify({"result": "Model Failed"})
Example #3
def get_model_api():

    model = QA('model')
    nlp = en_core_web_sm.load()
    stop_words = set(stopwords.words('english'))

    def model_api(question):
        try:
            question = [w.capitalize() for w in question.split(" ")]
            question = " ".join(question)
            doc = nlp(question)
            search = []

            for chunk in doc.noun_chunks:
                query = chunk.text
                check_query = nlp(query.lower())

                if 'PROPN' in [token.pos_ for token in check_query]:
                    querywords = query.split()
                    query_sentence = [
                        w.lower() for w in querywords
                        if not w.lower() in stop_words
                    ]
                    query_sentence = ' '.join(query_sentence)
                    search.append(query_sentence)

            search = [w for w in search if w != '']
            print(search)

            all_content = ''
            if len(search) != 0:
                for i in search:
                    i = [w.capitalize() for w in i.split(" ")]
                    i = " ".join(i)

                    for j in wiki_search(i.capitalize()):
                        all_content = all_content + j + '.'

            answer = model.predict(all_content, question)

            return answer['answer']

        except Exception:
            return ("Sorry, I don't know. Can you be more specific? "
                    "(Or the Wikipedia server is busy and did not respond.)")

    return model_api
Example #4
def generate():
    context = request.json["context"]

    # optional generation parameters with defaults
    max_length = request.json.get("max_length", 20)
    do_sample = request.json.get("do_sample", False)
    try:
        out = QA.generateText(context, max_length, do_sample)
        return jsonify(out)
    except Exception as e:
        print(e)
        return jsonify({"result": "Model Failed"})
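A usage sketch for the endpoint above (assuming the handler is registered under a /generate route on localhost:8000; the route decorator is not shown in the snippet):

import requests  # hypothetical client, not part of the original example

resp = requests.post(
    "http://localhost:8000/generate",
    json={"context": "Once upon a time", "max_length": 40, "do_sample": True},
)
print(resp.json())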
Example #5
class World:
    def __init__(self, locs, objs, relations, args):
        self.graph = nx.Graph()
        self.graph.add_nodes_from(locs,
                                  type='location',
                                  fillcolor="yellow",
                                  style="filled")
        self.graph.add_nodes_from(objs, type='object')
        self.graph.add_edges_from(relations)

        self.locations = set(locs)
        self.objects = set(objs)
        self.edge_labels = {}

        self.args = args

        # init GPT-2
        with open(args.input_text) as f:
            self.input_text = f.read()

        self.model = QA('model/albert-large-squad')

    def is_connected(self):
        return len(list(nx.connected_components(self.graph))) == 1

    def query(self, query, nsamples=10, cutoff=8):

        return self.model.predictTopK(self.input_text, query, nsamples, cutoff)

    def generateNeighbors(self, nsamples=100):
        self.candidates = {}
        for u in self.graph.nodes:
            self.candidates[u] = {}
            if self.graph.nodes[u]['type'] == "location":
                self.candidates[u]['location'] = self.query(
                    random.choice(loc2loc_templates).format(u), nsamples)
                self.candidates[u]['object'] = self.query(
                    random.choice(loc2obj_templates).format(u), nsamples)
                self.candidates[u]['character'] = self.query(
                    random.choice(loc2char_templates).format(u), nsamples)
            if self.graph.nodes[u]['type'] == "object":
                self.candidates[u]['location'] = self.query(
                    random.choice(obj2loc_templates).format(u), nsamples)
            if self.graph.nodes[u]['type'] == "character":
                self.candidates[u]['location'] = self.query(
                    random.choice(char2loc_templates).format(u), nsamples)

    def relatedness(self, u, v, node_type='location'):

        s = 0
        u2v, probs = self.candidates[u][node_type]

        if u2v is not None:
            for c, p in zip(u2v, probs):
                a = set(c.text.split()).difference(articles)
                b = set(v.split()).difference(articles)

                # find best intersect
                best_intersect = 0
                for x in self.graph.nodes:
                    xx = set(x.split()).difference(articles)
                    best_intersect = max(best_intersect,
                                         len(a.intersection(xx)))

                # increment if answer is best match BoW
                if len(a.intersection(b)) == best_intersect:
                    s += len(a.intersection(b)) * p

                # naive method
                # s += len(a.intersection(b)) * p

        v2u, probs = self.candidates[v]['location']

        if v2u is not None:
            for c, p in zip(v2u, probs):
                a = set(c.text.split()).difference(articles)
                b = set(u.split()).difference(articles)

                # find best intersect
                best_intersect = 0
                for x in self.graph.nodes:
                    xx = set(x.split()).difference(articles)
                    best_intersect = max(best_intersect,
                                         len(a.intersection(xx)))

                # increment if answer is best match BoW
                if len(a.intersection(b)) == best_intersect:
                    s += len(a.intersection(b)) * p

                # naive method
                # s += len(a.intersection(b)) * p

        return s

    def extractEntity(self, query, threshold=0.05, cutoff=0):

        preds, probs = self.query(query, 50, cutoff)

        if preds is None:
            return None, 0

        for pred, prob in zip(preds, probs):
            t = pred.text
            p = prob
            print('> ', t, p)
            if len(t) < 1:
                continue
            if p > threshold and "MASK" not in t:

                # find a more minimal candidate if possible
                for pred, prob in zip(preds, probs):
                    if t != pred.text and pred.text in t and prob > threshold and len(
                            pred.text) > 2:
                        t = pred.text
                        p = prob
                        break

                t = t.strip(string.punctuation)
                remove = t

                # take out leading articles for cleaning
                words = t.split()
                if words[0].lower() in articles:
                    remove = " ".join(words[1:])
                    words[0] = words[0].lower()
                    t = " ".join(words[1:])
                print(remove)

                self.input_text = self.input_text.replace(
                    remove, '[MASK]').replace('  ', ' ').replace(' .', '.')
                return t, p
            # else:
            # find a more minimal candidate if possible
            # for pred, prob in zip(preds, probs):
            #     if prob > threshold and "MASK" not in pred.text and len(pred.text) > 2 and pred.text in t:
            #         t = pred.text.strip(string.punctuation)
            #         p = prob
            #         self.input_text = self.input_text.replace(t, '[MASK]').replace('  ', ' ').replace(' .', '.')
            #         print(t, p)
            #         return t, p

        return None, 0

    def generate(self):

        locs = []
        objs = []
        chars = []

        # set thresholds/cutoffs
        threshold = 0.05

        if self.args.cutoffs == 'fairy':
            cutoffs = [6.5, -7, -5]  # fairy
        elif self.args.cutoffs == 'mystery':
            cutoffs = [3.5, -7.5, -6]  # mystery
        else:
            cutoffs = [float(i) for i in self.args.cutoffs.split()]
            assert len(cutoffs) == 3

        # save input text
        tmp = self.input_text[:]

        # add chars
        print("=" * 20 + "\tcharacters\t" + "=" * 20)
        self.input_text = tmp
        primer = "Who is somebody in the story?"
        cutoff = 10
        t, p = self.extractEntity(primer, threshold=threshold, cutoff=cutoff)
        while t is not None and len(t) > 1:
            if len(chars) > 1:
                cutoff = cutoffs[0]
            chars.append(t)
            t, p = self.extractEntity(primer,
                                      threshold=threshold,
                                      cutoff=cutoff)

        print("=" * 20 + "\tlocations\t" + "=" * 20)

        # add locations
        self.input_text = tmp
        primer = "Where is the location in the story?"
        cutoff = 10
        t, p = self.extractEntity(primer, threshold=threshold, cutoff=cutoff)
        while t is not None and len(t) > 1:
            locs.append(t)

            if len(locs) > 1:
                cutoff = cutoffs[1]

            t, p = self.extractEntity(primer,
                                      threshold=threshold,
                                      cutoff=cutoff)

        print("=" * 20 + "\tobjects\t\t" + "=" * 20)

        # add objects
        self.input_text = tmp
        primer = "What is an object in the story?"
        cutoff = 10
        t, p = self.extractEntity(primer, threshold=threshold, cutoff=cutoff)
        while t is not None and len(t) > 1:
            if len(objs) > 1:
                cutoff = cutoffs[2]
            objs.append(t)
            t, p = self.extractEntity(primer,
                                      threshold=threshold,
                                      cutoff=cutoff)
        self.input_text = tmp

        self.graph.add_nodes_from(locs,
                                  type='location',
                                  fillcolor="yellow",
                                  style="filled")
        self.graph.add_nodes_from(chars,
                                  type='character',
                                  fillcolor="orange",
                                  style="filled")
        self.graph.add_nodes_from(objs,
                                  type='object',
                                  fillcolor="white",
                                  style="filled")

        # with open('stats.txt', 'a') as f:
        # f.write(args.input_text + "\n")
        # f.write(str(len(locs)) + "\n")
        # f.write(str(len(chars)) + "\n")
        # f.write(str(len(objs)) + "\n")
        self.autocomplete()

    def autocomplete(self):

        self.generateNeighbors(self.args.nsamples)

        print("=" * 20 + "\trelations\t" + "=" * 20)
        while not self.is_connected():
            components = list(nx.connected_components(self.graph))
            best = (-1, next(iter(components[0])), next(iter(components[1])))

            main = components[0]

            loc_done = True
            for c in components[1:]:
                for v in c:
                    if self.graph.nodes[v]['type'] == 'location':
                        loc_done = False

            for u in main:
                if self.graph.nodes[u]['type'] != 'location':
                    continue

                for c in components[1:]:
                    for v in c:
                        if not loc_done and self.graph.nodes[v][
                                'type'] != 'location':
                            continue
                        best = max(best, (self.relatedness(
                            u, v, self.graph.nodes[v]['type']), u, v))

            score, u, v = best

            # attach randomly if empty or specified
            if score == 0 or self.args.random:
                candidates = []
                for c in components[0]:
                    if self.graph.nodes[c]['type'] == 'location':
                        candidates.append(c)
                u = random.choice(candidates)
            if (self.graph.nodes[u]['type'] == 'location'
                    and self.graph.nodes[v]['type'] == 'location'):
                relation = "connected to"
            else:
                relation = "located in"
            print("{} {} {}".format(v, relation, u))
            self.graph.add_edge(v, u, label=relation)
            self.edge_labels[(v, u)] = relation

    def export(self, filename="graph.dot"):
        nx.nx_pydot.write_dot(self.graph, filename)
        nx.write_gml(self.graph, "graph.gml", stringizer=None)

    def draw(self, filename="./graph.svg"):
        self.export()

        if self.args.write_sfdp:
            cmd = "sfdp -x -Goverlap=False -Tsvg graph.dot"
            returned_value = subprocess.check_output(cmd, shell=True)
            with open(filename, 'wb') as f:
                f.write(returned_value)
            cmd = "inkscape -z -e {}.png {}.svg".format(
                filename[:-4], filename[:-4])
            returned_value = subprocess.check_output(cmd, shell=True)
        else:
            nx.draw(self.graph, with_labels=True)
            plt.savefig(filename[:-4] + '.png')
Example #6
from flask import Flask, request, jsonify
from flask_cors import CORS

from bert import QA

app = Flask(__name__)
CORS(app)
modelName = "mrm8488/bert-small-finetuned-squadv2"
model = QA(modelName)


@app.route("/predict", methods=['POST'])
def predict():
    doc = request.json["document"]
    q = request.json["question"]
    try:
        out = model.predict(doc, q)
        return jsonify({"result": out})
    except Exception as e:
        print(e)
        return jsonify({"result": "Model Failed"})


if __name__ == "__main__":
    app.run('0.0.0.0', port=8000)
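A minimal client sketch for the API above (an illustration; assumes the server is running locally on port 8000):

import requests  # hypothetical client, not part of the original example

resp = requests.post(
    "http://localhost:8000/predict",
    json={"document": "The capital of France is Paris.",
          "question": "What is the capital of France?"},
)
print(resp.json()["result"])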
Example #7
File: api.py Project: davidlenz/BERT-SQuAD
from flask import Flask, request, jsonify
from flask_cors import CORS

from bert import QA

app = Flask(__name__)
CORS(app)

model = QA("bert-large-cased-whole-word-masking-finetuned-squad")


@app.route("/predict", methods=['POST'])
def predict():
    doc = request.json["document"]
    q = request.json["question"]
    try:
        out = model.predict(doc, q)
        return jsonify({"result": out})
    except Exception as e:
        print(e)
        return jsonify({"result": "Model Failed"})


if __name__ == "__main__":
    app.run('0.0.0.0', port=8000)
Example #8
File: api.py Project: artreven/BERT-SQuAD
import os

from flask import Flask, request, jsonify
from flask_cors import CORS
from dotenv import load_dotenv

from bert import QA

app = Flask(__name__)
CORS(app)

load_dotenv()
model = QA(os.getenv("OUTPUT_DIR"))


@app.route("/predict", methods=['POST'])
def predict():
    doc = request.json["document"]
    q = request.json["question"]
    try:
        out = model.predict(doc, q)
        return jsonify({"result": out})
    except Exception as e:
        app.logger.warning(e)
        return jsonify({"result": "Model Failed"})


if __name__ == "__main__":
    app.run('0.0.0.0', port=8000, debug=True)
Example #9
from bert import QA

model = QA('model')

doc = "Victoria has a written constitution enacted in 1975, but based on the 1855 colonial constitution, passed by the United Kingdom Parliament as the Victoria Constitution Act 1855, which establishes the Parliament as the state's law-making body for matters coming under state responsibility. The Victorian Constitution can be amended by the Parliament of Victoria, except for certain 'entrenched' provisions that require either an absolute majority in both houses, a three-fifths majority in both houses, or the approval of the Victorian people in a referendum, depending on the provision."
doc = "According to the Indian census of 2001, there were 30,803,747 speakers of Malayalam in Kerala, making up 93.2% of the total number of Malayalam speakers in India, and 96.7% of the total population of the state. There were a further 701,673 (2.1% of the total number) in Karnataka, 557,705 (1.7%) in Tamil Nadu and 406,358 (1.2%) in Maharashtra. The number of Malayalam speakers in Lakshadweep is 51,100, which is only 0.15% of the total number, but is as much as about 84% of the population of Lakshadweep. In all, Malayalis made up 3.22% of the total Indian population in 2001. Of the total 33,066,392 Malayalam speakers in India in 2001, 33,015,420 spoke the standard dialects, 19,643 spoke the Yerava dialect and 31,329 spoke non-standard regional variations like Eranadan. As per the 1991 census data, 28.85% of all Malayalam speakers in India spoke a second language and 19.64% of the total knew three or more languages.  Large numbers of Malayalis have settled in Bangalore, Mangalore, Delhi, Coimbatore, Hyderabad, Mumbai (Bombay), Ahmedabad, Pune, and Chennai (Madras). A large number of Malayalis have also emigrated to the Middle East, the United States, and Europe. Accessed November 22, 2014.</ref> including a large number of professionals. There were 7,093 Malayalam speakers in Australia in 2006. The 2001 Canadian census reported 7,070 people who listed Malayalam as their mother tongue, mostly in the Greater Toronto Area and Southern Ontario. In 2010, the Census of Population of Singapore reported that there were 26,348 Malayalees in Singapore. The 2006 New Zealand census reported 2,139 speakers. 134 Malayalam speaking households were reported in 1956 in Fiji. There is also a considerable Malayali population in the Persian Gulf regions, especially in Bahrain, Muscat, Doha, Dubai, Abu Dhabi, Kuwait and European region mainly in London.  World Malayalee Council, the organisation working with the Malayali diaspora across the Globe has embarked upon a project for making a data bank of the diaspora. CANNOTANSWER"
q = 'When did Victoria enact its constitution?'
q = "What other languages are spoken there?"
answer = model.predict(doc, q)

print(answer['answer'])
# 1975

# dict_keys(['answer', 'start', 'end', 'confidence', 'document'])
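Besides the answer text, the other fields of the prediction dict can be inspected directly (a sketch based on the keys listed above; exact field names depend on the QA wrapper):

print(answer['confidence'])            # model confidence for the extracted span
print(answer['start'], answer['end'])  # offsets of the span within the document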
Example #10
from word2vec_repo.DocSim import DocSim
from gensim.models import KeyedVectors
from nltk.corpus import stopwords
import nltk
import os
import numpy as np

# some of the code is from https://stackoverflow.com/a/8897648

model_path = 'word2vec_repo/model.bin'
stopwords = stopwords.words('english')

model = KeyedVectors.load(model_path)
ds = DocSim(model, stopwords=stopwords)

from bert import QA
model = QA('BERTap/model')


# fetch the top 3 docs using the tf-idf weighting method (for a large corpus, set it to 10)
def get_doc(qu):
    final_docs = matching_score(10, qu)
    #print(final_docs)
    nd = len(final_docs)
    if len(final_docs) > 3:
        nd = 3
    answer_doc = []
    for i in range(nd):
        file = open(final_docs[i], 'r', encoding='cp1250')
        text = file.read().strip()
        file.close()
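The matching_score helper is not shown in this snippet; a minimal sketch of tf-idf document ranking (an assumed stand-in built on sklearn, not the original implementation) could look like this:

from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity

def rank_documents(query, doc_texts, k=10):
    # rank documents by tf-idf cosine similarity to the query
    vect = TfidfVectorizer(stop_words='english')
    doc_matrix = vect.fit_transform(doc_texts)
    scores = cosine_similarity(vect.transform([query]), doc_matrix).ravel()
    ranked = sorted(range(len(doc_texts)), key=lambda i: scores[i], reverse=True)
    return ranked[:k]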
Example #11
from bert import QA

model = QA('model')

story = input('Enter story: ')
flag = 'y'
while flag == 'y' or flag == 'yes':
    ques = input('Enter question: ')
    answer1 = model.predict(story, ques)
    print('Answer is ', answer1['answer'])
    flag = input('Ask more: y/n ')
'''
doc = "Hey! I am Mehak. I love dancing. I live in Punjab. I have one brother and one sister. My father is a businessman. My mother is a teacher. I love my family."

q1 = 'What is my name?'
q2 = 'What are my hobbies?'
q3 = 'How many siblings do I have?'
q4 = 'What is the occupation of my mother?'
q5 = 'What is the occupation of my father?'

answer1 = model.predict(doc,q1)
answer2 = model.predict(doc,q2)
answer3 = model.predict(doc,q3)
answer4 = model.predict(doc,q4)
answer5 = model.predict(doc,q5)

print(answer1['answer'])
print(answer2['answer'])
print(answer3['answer'])
print(answer4['answer'])
print(answer5['answer'])
'''
Example #12
File: run.py Project: shrey-bansal/Bert-QA
import time
import nltk

from bert import QA
model = QA('model')

doc = open("data.txt").read()
q = "Which department filed an antitrust lawsuit against Microsoft in 1998 ?"
s = ""
start_time = time.time()
words = nltk.word_tokenize(doc)
print("Words in the document are "+ str(len(words))+" words")
if(len(words)>512):
    answer = model.predict(doc,q,150)
else:
    answer = model.predict(doc,q,len(words))


if len(answer) == 0:
    s += "No"
else:
    s += "Yes, "
    for i in range(len(answer)):
        t = answer[i]
        s += t['answer'] + ", "
print(q)
print("Final answer: ", s)
print("Time taken in seconds: ", time.time() - start_time)
Example #13
def load_model(_):

    model = QA('model')
    return model
Example #14
from nltk.tag import StanfordPOSTagger
from nltk.tag import StanfordNERTagger
from sklearn import linear_model
from sklearn import svm
from sklearn.metrics import fbeta_score, accuracy_score
from scipy.sparse import hstack

from nltk.stem.porter import PorterStemmer
from nltk.stem.snowball import SnowballStemmer
from nltk.stem.wordnet import WordNetLemmatizer

import requests

from bert import QA

model = QA('model')

context = ' Bits Pilani is a private institute of higher education and a deemed university under Section 3 of the UGC Act 1956. \
The institute was established in its present form in 1964. It is established across 4 campuses and has 15 academic departments. \
Pilani is located 220 kilometres far from Delhi in Rajasthan. We can reach bits pilani via train or bus from delhi. \
Bits Pilani has its campuses in Pilani , Goa , Hyderabad , Dubai. There are multiple scholarships available at BITS namely Merit Scholarships, Merit Cum Need Scholarships and BITSAA Scholarships. \
BITS Model United Nations Conference (BITSMUN) is one of the largest MUN conferences in the country. BITS conducts the All-India computerized entrance examination, BITSAT (BITS Admission Test). \
Admission is merit-based, as assessed by the BITSAT examination. \
We can reach bits pilani through bus or train from delhi or jaipur. \
Mr. Ashoke Kumar Sarkar is the director of Bits Pilani, pilani campus. \
Founder of Bits pilani was Ghanshyam Das Birla.'


def get_answer(context: str, ques: str):
    answer = model.predict(context, ques)
    return answer['answer']
Example #15
from flask import Flask, request, jsonify
import flask_cors
import os

from bert import QA
from gcp import GCP

os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = "fyp-qa-eb7816dfb87e.json"
os.environ["CUDA_VISIBLE_DEVICES"] = ""

app = Flask(__name__, static_url_path='/static')
flask_cors.CORS(app)

model = QA()
gcp = GCP()


@app.route("/file/raw", methods=['POST'])
def file_upload_raw():
    try:
        filename = request.json["filename"]
        text = request.json["text"]

        if filename == '' or text == '':
            return 'Filename or text is empty.', 400

        doc_id = gcp.upload_raw(filename, text)

        return jsonify({'success': True, 'id': doc_id})
    except Exception as e:
        print(e)
        return jsonify({'success': False, 'error': str(e)}), 500
Example #16
import os
import time

from flask import Flask
from sklearn.feature_extraction.text import TfidfVectorizer

from bert import QA

basedir = os.path.abspath(os.path.dirname(__file__))
app = Flask(__name__)
SECRET_KEY = os.urandom(32)
app.config['SECRET_KEY'] = SECRET_KEY

MODEL_PATH = os.path.join(basedir, 'model')
DATA_PATH = os.path.join(basedir, 'dataUIT')
# load_data and standardize_data are project-local helpers (not shown here)
data = load_data(DATA_PATH)
stopwords = set(
    open(os.path.join(basedir, 'stopwords.txt'),
         encoding="utf-8").read().split(' ')[:-1])

# load model
print("Loading model...")
start = time.time()
model = QA(MODEL_PATH)  # path to the model
end = time.time()
print("Time to load model: " + str(round(end - start, 2)))

# build the tf-idf index
print('Building index...')
start = time.time()
data_standard = standardize_data(data, stopwords)
vect = TfidfVectorizer(min_df=1, max_df=0.8, max_features=5000,
                       sublinear_tf=True, ngram_range=(1, 3))
vect.fit(data_standard)
end = time.time()
print("Time to build index: " + str(round(end - start, 2)))



@app.route('/', methods=['GET','POST'])
Example #17
# see the original BERT-SQuAD: https://github.com/kamalkraj/BERT-SQuAD
#############################################################
# Copyright 2019 Pierre Rouarch
# License GPL V3
#############################################################
myKeyword = "When Abraham Lincoln died and how?"
from bert import QA

n_best_size = 20
# list of pretrained models:
# https://huggingface.co/transformers/pretrained_models.html
# instantiate a BERT model fine-tuned on SQuAD; choose your model:
# 'bert-large-uncased-whole-word-masking-finetuned-squad'
# 'bert-large-cased-whole-word-masking-finetuned-squad'
model = QA('bert-large-uncased-whole-word-masking-finetuned-squad',
           n_best_size)

# import needed libraries
import pandas as pd
import numpy as np
# pip install google  # installs Google Search by Mario Vilas, see
# https://python-googlesearch.readthedocs.io/en/latest/
import googlesearch  # scrape SERPs
# to randomize pauses
import random
import time  # to measure page download time
from datetime import date
import sys  # for sys variables

import requests  # to read URL contents
from bs4 import BeautifulSoup  # to parse HTML
Example #18
from bert import QA
import pandas as pd
import re
import csv
from datetime import datetime
from utils import *
time_str = datetime.now().strftime('%Y%m%d%H%M%S')

bank_filepath = 'itembank.csv'
model = QA('model')

# Load the database
content_code = 'A'
df = pd.read_csv(bank_filepath)

# Data cleanup
df = df.replace(r"[\[].*?[\]]", "", regex=True)  # remove all bracket stylings
df = df.replace(r"\(select.*?\)", "", regex=True)  # remove all select prompts
df = df.replace(r'\s+', ' ', regex=True)  # collapse runs of whitespace

# Data filters
context = (df['ContentCode'] == content_code) \
    & (~df['Stem'].str.contains('arrow|points to|the line|line #|1. |video|label|cursor|click|\d{3}') # Attempt to filter out questions that refer to images
    & (df['FullKey'] > 'D')                             # Filter out questions with more than 4 options
    & (df['FullKey'].str.len() == 1))                   # Filter out questions with multiple answers, will not work with model prediction

# Inits
iterator = 0
length = df[context].shape[0]
total = int((length * (length + 1)) / 2)
Example #19
from bert import QA
import os
model = QA('model')

def read_all_data(word2idx, max_words, max_sentences):
    # stories[story_ind] = [[sentence1], [sentence2], ..., [sentenceN]]
    # questions[question_ind] = {'question': [question], 'answer': [answer], 'story_index': #, 'sentence_index': #}
    stories = dict()
    questions = dict()
    
    
    if len(word2idx) == 0:
        word2idx['<null>'] = 0

    data_dir = "./tasksv11/en-valid"
    for babi_task in range(1,21):
        fname = '{}/qa{}_test.txt'.format(data_dir, babi_task)    
        if os.path.isfile(fname):
            with open(fname) as f:
                lines = f.readlines()
        else:
            raise Exception("[!] Data {file} not found".format(file=fname))

        for line in lines:
            words = line.split()
            max_words = max(max_words, len(words))
            
            # Determine whether the line indicates the start of a new story
            if words[0] == '1':
                story_ind = len(stories)
                sentence_ind = 0
Example #20
import gradio as gr
import os, sys
sys.path.insert(0, os.path.join(os.path.dirname(os.path.abspath(__file__)), "utils"))
from bert import QA

model = QA('bert-large-uncased-whole-word-masking-finetuned-squad')
def qa_func(paragraph, question):
    return model.predict(paragraph, question)["answer"]

iface = gr.Interface(qa_func, 
    [
        gr.inputs.Textbox(lines=7, label="Context", default="Victoria has a written constitution enacted in 1975, but based on the 1855 colonial constitution, passed by the United Kingdom Parliament as the Victoria Constitution Act 1855, which establishes the Parliament as the state's law-making body for matters coming under state responsibility. The Victorian Constitution can be amended by the Parliament of Victoria, except for certain 'entrenched' provisions that require either an absolute majority in both houses, a three-fifths majority in both houses, or the approval of the Victorian people in a referendum, depending on the provision."), 
        gr.inputs.Textbox(lines=1, label="Question", default="When did Victoria enact its constitution?"), 
    ], 
    gr.outputs.Textbox(label="Answer"))
if __name__ == "__main__":
    iface.launch()
Example #21
from flask import Flask, request, jsonify
from flask_cors import CORS

from bert import QA

app = Flask(__name__)
CORS(app)

#model = QA("model")
model = QA("/home/k3ijo/bert/nlp_model/multi_cased_L-12_H-768_A-12/")

@app.route("/predict",methods=['POST'])
def predict():
    doc = request.json["document"]
    q = request.json["question"]
    try:
        out = model.predict(doc,q)
        return jsonify({"result":out})
    except Exception as e:
        print(e)
        return jsonify({"result":"Model Failed"})

if __name__ == "__main__":
    app.run('0.0.0.0',port=8000)
Example #22
def test(use_jit=False,
         fp16=False,
         onnx_runtime=False,
         export_onnx=False,
         tf_onnx=False,
         tf_version=False,
         vsl='none',
         min_batch=0,
         max_batch=1,
         num_predicts=300):
    document1 = 'Two partially reusable launch systems were developed, the Space Shuttle and Falcon 9. ' \
               'The Space Shuttle was partially reusable: the orbiter (which included the Space Shuttle ' \
               'main engines and the Orbital Maneuvering System engines), and the two solid rocket boosters ' \
               'were reused after several months of refitting work for each launch. The external tank was ' \
               'discarded after each flight. and the two solid rocket boosters were reused after several ' \
               'months of refitting work for each launch. The external tank was discarded after each flight.'
    document2 = 'This contrasts with expendable launch systems, where each launch vehicle is launched once ' \
                'and then discarded. No completely reusable orbital launch system has ever been created.'
    document3 = 'A reusable launch system (RLS, or reusable launch vehicle, RLV) is a launch system which is ' \
                'capable of launching a payload into space more than once. This contrasts with expendable ' \
                'launch systems, where each launch vehicle is launched once and then discarded. No completely ' \
                'reusable orbital launch system has ever been created.'
    question = 'How many partially reusable launch systems were developed?'
    # passages = [document1, document2, document3, document1, document2, document3, document1, document2, document3]
    # passages = [document1, document2, document3]
    # passages = [document1]

    if tf_onnx or tf_version:
        from multiprocessing import Pool

        convert_onnx_to_tf = False
        if tf_onnx and convert_onnx_to_tf:
            onnx_model = onnx.load(ONNX_PATH)
            # prepare tf representation
            tf_exp = onnx_tf.backend.prepare(onnx_model)
            # export the model
            tf_exp.export_graph(ONNX_TF_PB_PATH)

        onnx_pb_graph = tf.Graph()
        with onnx_pb_graph.as_default():
            tf_pb_path = ONNX_TF_PB_PATH if tf_onnx else TF_PB_PATH
            onnx_pb_graph_def = tf.GraphDef()
            with tf.gfile.GFile(tf_pb_path, 'rb') as fid:
                serialized_graph = fid.read()

            onnx_pb_graph_def.ParseFromString(serialized_graph)
            tf.import_graph_def(onnx_pb_graph_def, name='')

            config = tf.ConfigProto()
            config.gpu_options.allow_growth = True
            if use_jit:
                # config.gpu_options.per_process_gpu_memory_fraction = 0.5
                config.log_device_placement = False
                config.graph_options.optimizer_options.global_jit_level = tf.OptimizerOptions.ON_1

            with tf.Session(config=config) as sess:
                # INFERENCE using session.run
                model = QA(MODEL_PATH,
                           use_jit=use_jit,
                           fp16=fp16,
                           onnx=onnx_runtime,
                           sess=sess,
                           vsl=vsl,
                           tf_onnx=tf_onnx)

                print('-- BENCHMARKING: JIT={} | FP16={} | ONNX_RUNTIME={} | '
                      'TF_ONNX_VERSION={} | TF_VERSION={} | EXACT_VSL={} --'.
                      format(use_jit, fp16, onnx_runtime, tf_onnx, tf_version,
                             vsl))
                for passage_batch in range(min_batch, max_batch):
                    passage_batch = pow(3, passage_batch - 1)
                    if passage_batch < 1:
                        passages = [document1]
                    else:
                        passages = []
                        for i in range(passage_batch):
                            passages.append(document1)
                            passages.append(document2)
                            passages.append(document3)

                    if max_batch > 2:
                        num_predicts = 50
                    time_taken, rps = measure_inference(
                        model, passages, question, num_predicts)
                    # print('Time taken for test: {} s'.format(time_taken))
                    print('RPS: {}'.format(rps))

                sess.close()

            del model, sess
    else:
        model = QA(MODEL_PATH,
                   use_jit=use_jit,
                   fp16=fp16,
                   onnx=onnx_runtime,
                   export_onnx=export_onnx,
                   vsl=vsl,
                   onnx_path=ONNX_PATH)

        if not export_onnx:
            print(
                '-- BENCHMARKING: JIT={} | FP16={} | ONNX_RUNTIME={} | '
                'TF_ONNX_VERSION={} | TF_VERSION={} | EXACT_VSL={} --'.format(
                    use_jit, fp16, onnx_runtime, tf_onnx, tf_version, vsl))
            for passage_batch in range(min_batch, max_batch):
                passage_batch = pow(3, passage_batch - 1)
                if passage_batch < 1:
                    passages = [document1]
                else:
                    passages = []
                    for i in range(passage_batch):
                        passages.append(document1)
                        passages.append(document2)
                        passages.append(document3)

                if max_batch > 2:
                    num_predicts = 50
                time_taken, rps = measure_inference(model, passages, question,
                                                    num_predicts)
                # print('Time taken for test: {} s'.format(time_taken))
                print('RPS: {}'.format(rps))
        del model
        torch.cuda.empty_cache()
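The measure_inference helper used above is not shown in the snippet; a minimal sketch of what such a helper might look like (a hypothetical implementation, assuming model.predict accepts a batch of passages and a question) is:

import time

def measure_inference(model, passages, question, num_predicts):
    # time repeated batched predictions and report requests per second
    start = time.time()
    for _ in range(num_predicts):
        model.predict(passages, question)
    elapsed = time.time() - start
    return elapsed, (num_predicts * len(passages)) / elapsed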
Example #23
"""
Make sure bert.py exists in the same directory as this script and the model files downloaded form the dropbox link is placed under the model folder
"""

from bert import QA

model = QA('model')
"""
Now, let us implement comprehending a passage from google with our existing zero shot model
"""

passage = 'There was a princess called Maggie, she was tall and as white as snow. She had red lips like a \
rose and her hair was brown. She had light blue eyes and was very nice and kind. \
She was in love with a bricklayer called Kevin. He was tall, had brown hair and tanned skin. \
He was strong, had dark eyes, was kind and he was never angry'

ques = 'What is the name of the princess?'
"""
Generally the predict method gives out a dictionary with other values as well, but since our interest lies with answer we subset the same.
"""
answer = model.predict(passage, ques)['answer']

print("{} : {}".format(ques, answer))
"""
Let us try few more questions to see the edge case scenarios to understand where the model might break
"""

answer = model.predict(passage, 'Who was Maggie?')['answer']

print("{} : {}".format('Who was Maggie?', answer))
Example #24
from flask import Flask, render_template, request, send_file
import warnings
warnings.filterwarnings("ignore")
from bert import QA

DEBUG = False
app = Flask(__name__)
app.config.from_object(__name__)


model = QA("model")

@app.route('/')
def single_people_code():
    return render_template('index.html')

@app.route("/",methods=['POST','GET'])
def predict():
    if request.method == 'POST':

        results = request.form

        doc = results['passage']

        q = results['question']

        try:
            out = model.predict(doc, q)
Example #25
import sys
sys.path.append('./BERT-SQuAD')
import json
import torch
from bert import QA


def predict(doc, q):
    answer = model.predict(doc, q)
    content = json.dumps(answer, separators=(',', ':'))
    print(content, flush=True)


print("Loading model!", flush=True)
model = QA('model')
print("Model loading complete!", flush=True, end='')
model.predict("Whales are a kind of animal called a mammal.",
              "What is a whale?")