def read_input_data(model):
    """Read the question/paragraph JSON in ``OPTS.input_file`` into model inputs.

    The file is expected to hold ``{"data": [{"paragraphs": [{"context": str,
    "qas": [{"question": str, "id": ...}, ...]}, ...]}, ...]}``.

    Args:
        model: Object exposing a ``preprocessor`` attribute. When it is not
            ``None``, its ``encode_text`` is applied to every tokenized context.

    Returns:
        A ``(data, vocab)`` tuple. ``data`` holds one
        ``(raw_context_text, context_tokens, [ParagraphAndQuestion])`` entry per
        question; ``vocab`` is the sorted list of every context and question
        token seen.
    """
    tokenizer = NltkAndPunctTokenizer()
    with open(OPTS.input_file) as f:
        json_data = json.load(f)

    data = []
    vocab = set()
    for doc in json_data['data']:
        for paragraph in doc['paragraphs']:
            context = tokenizer.tokenize_with_inverse(paragraph['context'])
            if model.preprocessor is not None:
                # The context is encoded once per paragraph, before any of its
                # questions has been tokenized.  The previous code passed
                # `question` here, which is unbound for the first paragraph and
                # stale (the previous paragraph's last question) afterwards.
                # An empty question is passed instead, as
                # QaSystem._get_span_scores does when calling the preprocessor.
                # NOTE(review): confirm the preprocessor does not need the
                # actual question text.
                context = model.preprocessor.encode_text([], context)
            context = context.get_context()
            vocab.update(context)
            for qa in paragraph['qas']:
                question = tokenizer.tokenize_sentence(qa['question'])
                vocab.update(question)
                ex = [ParagraphAndQuestion(context, question, None, qa['id'])]
                data.append((paragraph['context'], context, ex))
    return data, sorted(vocab)
class QaSystem(object):
    """
    End-to-end QA system, uses web-requests to get relevant documents and a model
    to score candidate answer spans.
    """

    # The split pattern includes the surrounding whitespace to avoid empty
    # paragraphs.  Raw string: "\s" is not a valid escape in a normal literal.
    _split_regex = re.compile(r"\s*\n\s*")

    # NOTE(review): the default `loader` instance is created once, when the class
    # body is executed, and is therefore shared by every QaSystem built without
    # an explicit loader.
    def __init__(self, wiki_cache: str, paragraph_splitter: DocumentSplitter,
                 paragraph_selector: ParagraphFilter, vocab: Union[str, Set[str]],
                 model: Union[ParagraphQuestionModel, ModelDir],
                 loader: ResourceLoader = ResourceLoader(),
                 bing_api_key=None, tagme_api_key=None,
                 blacklist_trivia_sites: bool = False, n_dl_threads: int = 5,
                 span_bound: int = 8, tagme_threshold: Optional[float] = 0.2,
                 download_timeout: Optional[int] = None, n_web_docs=10):
        """
        :param wiki_cache: passed to `WikiCorpus` as its cache location
        :param paragraph_splitter: used to re-group tokenized paragraphs
        :param paragraph_selector: used to prune paragraphs for a question
        :param vocab: a set of words, or the path of a file with one word per
                      line; may be None to give the model no preset vocab
        :param model: the model, or a `ModelDir` to load the model from
        :param loader: passed through to `model.set_input_spec`
        :param bing_api_key: if None, no web searcher/text extractor is built
        :param tagme_api_key: sent as the "gcube-token" of TagMe requests
        :param blacklist_trivia_sites: skip results whose display url contains
                                       'quiz', 'trivia' or 'answer'
        :param n_dl_threads: passed to the text extractor
        :param span_bound: passed to `get_best_span` of the model's prediction
        :param tagme_threshold: minimum "rho" of a TagMe annotation to use its
                                article, None disables the TagMe lookup
        :param download_timeout: passed to the text extractor
        :param n_web_docs: number of web search results to request, 0 disables
                           the web search
        """
        self.log = logging.getLogger('qa_system')
        self.tagme_threshold = tagme_threshold
        self.n_web_docs = n_web_docs
        self.blacklist_trivia_sites = blacklist_trivia_sites
        self.tagme_api_key = tagme_api_key

        if bing_api_key is not None:
            self.searcher = AsyncWebSearcher(bing_api_key)
            self.text_extractor = AsyncBoilerpipeCliExtractor(
                n_dl_threads, download_timeout)
        else:
            self.text_extractor = None
            self.searcher = None

        self.wiki_corpus = WikiCorpus(wiki_cache, keep_inverse_mapping=True)
        self.paragraph_splitter = paragraph_splitter
        self.paragraph_selector = paragraph_selector
        self.model_dir = model

        voc = None
        if vocab is not None:
            if isinstance(vocab, str):
                # `vocab` is a path: one word per line
                with open(vocab, "r") as f:
                    voc = {line.strip() for line in f}
            else:
                voc = vocab
            self.log.info("Using preset vocab of size %d", len(voc))

        self.log.info("Setting up model...")
        if isinstance(model, ModelDir):
            self.model = model.get_model()
        else:
            self.model = model

        self.model.set_input_spec(ParagraphAndQuestionSpec(None), voc, loader)

        self.sess = tf.Session()
        with self.sess.as_default():
            pred = self.model.get_prediction()

        # NOTE(review): this is called on the constructor argument, not on
        # `self.model` -- confirm a bare ParagraphQuestionModel supports it.
        model.restore_checkpoint(self.sess)

        self.span_scores = pred.get_span_scores()
        self.span, self.score = pred.get_best_span(span_bound)
        self.tokenizer = NltkAndPunctTokenizer()
        # No further ops may be added to the graph after this point
        self.sess.graph.finalize()

    async def answer_question(
            self, question: str) -> Tuple[np.ndarray, List[WebParagraph]]:
        """ Answer a question using web search """
        context = await self.get_question_context(question)
        question = self.tokenizer.tokenize_paragraph_flat(question)
        t0 = time.perf_counter()
        out = self._get_span_scores(question, context)
        self.log.info("Computing answer spans took %.5f seconds",
                      time.perf_counter() - t0)
        return out

    def answer_with_doc(self, question: str,
                        doc: str) -> Tuple[np.ndarray, List[WebParagraph]]:
        """
        Answer a question using the given text as a document

        :raises ValueError: if no paragraph is left after pruning
        """
        self.log.info("Answering question \"%s\" with a given document", question)

        # Tokenize
        question = self.tokenizer.tokenize_paragraph_flat(question)
        context = [
            self.tokenizer.tokenize_with_inverse(x, False)
            for x in self._split_regex.split(doc)
        ]

        # Split into super-paragraphs
        context = self._split_document(context, "User", None)

        # Select top paragraphs
        context = self.paragraph_selector.prune(question, context)
        if len(context) == 0:
            raise ValueError("Unable to process documents")

        # Select the top answer span
        t0 = time.perf_counter()
        span_scores = self._get_span_scores(question, context)
        self.log.info("Computing answer spans took %.5f seconds",
                      time.perf_counter() - t0)
        return span_scores

    def _get_span_scores(self, question: List[str],
                         paragraphs: List[ParagraphWithInverse]):
        """
        Answer a question using the given paragraphs, returns both the span scores
        and the pre-processed paragraphs the spans are valid for
        """
        if self.model.preprocessor is not None:
            prepped = []
            for para in paragraphs:
                # Not every paragraph type carries token spans
                spans = getattr(para, "spans", None)
                # The preprocessor gets an empty question and an empty
                # (0, 2) answer-span array; only the encoded text and the
                # inverse mapping are kept.
                text, _, inv = self.model.preprocessor.encode_paragraph(
                    [], para.text, para.start == 0,
                    np.zeros((0, 2), dtype=np.int32), spans)
                prepped.append(
                    WebParagraph([text], para.original_text, inv,
                                 para.paragraph_num, para.start, para.end,
                                 para.source_name, para.source_url))
            paragraphs = prepped

        qa_pairs = [
            ParagraphAndQuestion(c.get_context(), question, None, "")
            for c in paragraphs
        ]
        encoded = self.model.encode(qa_pairs, False)
        return self.sess.run(self.span_scores, encoded), paragraphs

    def _split_document(self, para: List[ParagraphWithInverse],
                        source_name: str, source_url: Optional[str]):
        """
        Re-group `para` with the paragraph splitter and wrap each resulting
        paragraph in a `WebParagraph` that records its 1-based paragraph number
        and its [start, end) token offsets within the document
        """
        tokenized_paragraphs = []
        on_token = 0
        # The loop variable is deliberately not called `para`, so it does not
        # shadow the argument.
        for i, merged in enumerate(self.paragraph_splitter.split_inverse(para)):
            n_tokens = merged.n_tokens
            tokenized_paragraphs.append(
                WebParagraph(merged.text, merged.original_text, merged.spans,
                             i + 1, on_token, on_token + n_tokens,
                             source_name, source_url))
            on_token += n_tokens
        return tokenized_paragraphs

    async def _tagme(self, question):
        """
        Query the TagMe endpoint for `question`, returns the annotations of the
        response that have a "title"
        """
        payload = {
            "text": question,
            "long_text": 3,
            "lang": "en",
            "gcube-token": self.tagme_api_key
        }
        async with ClientSession() as sess:
            async with sess.get(url=TAGME_API, params=payload) as resp:
                data = await resp.json()
        return [
            ann_json for ann_json in data["annotations"] if "title" in ann_json
        ]

    async def _wiki_paragraphs(self, question: str) -> List[WebParagraph]:
        """
        Paragraphs of the wiki articles whose TagMe annotation for `question`
        has a "rho" of at least `self.tagme_threshold`
        """
        self.log.info("Query tagme for %s", question)
        tags = await self._tagme(question)
        t0 = time.perf_counter()
        paragraphs = []
        found = set()  # titles already fetched, each article is used once
        for tag in tags:
            if tag["rho"] >= self.tagme_threshold:
                title = tag["title"]
                if title in found:
                    continue
                found.add(title)
                doc = await self.wiki_corpus.get_wiki_article(title)
                paragraphs += self._split_document(
                    doc.paragraphs, "Wikipedia: " + doc.title, doc.url)
        if len(paragraphs) > 0:
            self.log.info("Getting wiki docs took %.5f seconds",
                          time.perf_counter() - t0)
        return paragraphs

    async def _web_paragraphs(self, question: str) -> List[WebParagraph]:
        """
        Paragraphs extracted from the top `self.n_web_docs` web search results
        for `question`; requires `self.searcher` and `self.text_extractor`
        """
        t0 = time.perf_counter()
        self.log.info("Running bing search for %s", question)
        search_results = await self.searcher.run_search(
            question, self.n_web_docs)
        t1 = time.perf_counter()
        self.log.info("Completed bing search, took %.5f seconds", t1 - t0)
        t0 = t1

        url_to_result = {x["url"]: x for x in search_results}
        self.log.info("Extracting text for %d results", len(search_results))
        text_docs = await self.text_extractor.get_text(
            [x["url"] for x in search_results])

        paragraphs = []
        for doc in text_docs:
            if len(doc.text) == 0:
                continue
            search_r = url_to_result[doc.url]
            if self.blacklist_trivia_sites:
                lower = search_r["displayUrl"].lower()
                if 'quiz' in lower or 'trivia' in lower or 'answer' in lower:
                    # heuristic to ignore trivia sites, recommended by Mandar
                    self.log.debug("Skipping trivia site: %s", lower)
                    continue

            paras_text = self._split_regex.split(doc.text.strip())
            paras_tokenized = [
                self.tokenizer.tokenize_with_inverse(x) for x in paras_text
            ]
            paragraphs += self._split_document(
                paras_tokenized, search_r["displayUrl"], doc.url)

        self.log.info("Completed extracting text, took %.5f seconds.",
                      time.perf_counter() - t0)
        return paragraphs

    async def get_question_context(self, question: str) -> List[WebParagraph]:
        """
        Find a set of paragraphs from the web that are relevant to the given question

        Wiki articles are used when `tagme_threshold` is set, web search results
        when `n_web_docs > 0`; the collected paragraphs are pruned with the
        paragraph selector. Returns an empty list if nothing was found.
        """
        tokenized_paragraphs = []
        if self.tagme_threshold is not None:
            tokenized_paragraphs += await self._wiki_paragraphs(question)

        if self.n_web_docs > 0:
            if self.searcher is None:
                # Without a bing_api_key __init__ builds neither a searcher nor
                # a text extractor; skip the search instead of failing on None.
                self.log.warning(
                    "No bing_api_key was given, skipping web search for %s",
                    question)
            else:
                tokenized_paragraphs += await self._web_paragraphs(question)

        self.log.info("Have %d paragraphs", len(tokenized_paragraphs))
        if len(tokenized_paragraphs) == 0:
            return []
        question = self.tokenizer.tokenize_sentence(question)
        return self.paragraph_selector.prune(question, tokenized_paragraphs)