def _build_full_dataset(self):
    samples = []
    for question in self.questions:
        pars_and_scores = list(
            zip(question.supporting_facts + question.distractors,
                question.gold_scores + question.distractor_scores))
        higher_gold = question.supporting_facts[0] \
            if question.gold_scores[0] >= question.gold_scores[1] else question.supporting_facts[1]
        # enumerate every ordered pair of paragraphs (gold and distractor alike),
        # so each question contributes (num_golds + num_distractors) ** 2 samples
        for p1, score1 in pars_and_scores:
            for p2, score2 in pars_and_scores:
                first_label, second_label = self._get_labels(
                    is_gold1=p1 in question.supporting_facts,
                    is_gold2=p2 in question.supporting_facts,
                    q_type=question.q_type,
                    are_same=p1 == p2,
                    is_first_higher=higher_gold == p1)
                samples.append(
                    IterativeQuestionAndParagraphs(
                        question=question.question_tokens,
                        paragraphs=[
                            flatten_iterable(p1.sentences),
                            flatten_iterable(p2.sentences)
                        ],
                        first_label=first_label,
                        second_label=second_label,
                        question_id=question.question_id,
                        q_type=question.q_type,
                        sentence_segments=[
                            get_segments_from_sentences(s)
                            for s in [p1.sentences, p2.sentences]
                        ]))
    return samples
def reformulate_questions_from_texts(self, tokenized_questions: List[List[str]],
                                     tokenized_pars: List[List[str]],
                                     return_search_vectors: bool,
                                     show_progress=False, max_batch=None):
    # pair each question with its first-hop paragraph; the second paragraph slot is
    # filled with a dummy, since only the reformulation is computed here
    dummy_par = "Hello Hello".split()
    batch = [
        IterativeQuestionAndParagraphs(question=q,
                                       paragraphs=[p, dummy_par],
                                       first_label=1,
                                       second_label=1,
                                       question_id='dummy',
                                       sentence_segments=None)
        for q, p in zip(tokenized_questions, tokenized_pars)
    ]
    final_encs = []
    batch_size = self.model.max_batch_size if self.model.max_batch_size else len(batch)
    batch_size = batch_size if max_batch is None else min(batch_size, max_batch)
    for _ in (range(0, len(batch), batch_size) if not show_progress
              else tqdm(range(0, len(batch), batch_size))):
        feed_dict = self.model.encode(batch[:batch_size], False)
        reformulations = self.sess.run(self.reformulation_name, feed_dict=feed_dict)
        final_encs.append(reformulations)
        batch = batch[batch_size:]
    reformulations = np.concatenate(final_encs, axis=0)
    if return_search_vectors:
        return self.question_rep_to_search_vector(reformulations, context_idx=2)
    return reformulations
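# A minimal usage sketch for the method above (not part of this module): it assumes
# an instance named `reformulator` built elsewhere with a loaded model and TF session,
# and pre-tokenized inputs; names and values are illustrative only.
#
#     q_tokens = ["who", "directed", "the", "film", "?"]
#     par_tokens = ["The", "film", "was", "directed", "by", "Robert", "Zemeckis", "."]
#     search_vecs = reformulator.reformulate_questions_from_texts(
#         [q_tokens], [par_tokens], return_search_vectors=True, max_batch=64)
#     # one reformulated search vector per (question, paragraph) pair, ready to be
#     # matched against precomputed paragraph encodings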
def _sample_false_3(self, qid):
    """ False sample of type 3: both paragraphs are the golds of another question """
    question = self.qid2question[qid]
    rand_q_idx = self.random.randint(len(self.gold_samples))
    while self.gold_samples[rand_q_idx].question_id == question.question_id:
        rand_q_idx = self.random.randint(len(self.gold_samples))
    selected_q = self.gold_samples[rand_q_idx]
    return IterativeQuestionAndParagraphs(
        question=question.question_tokens,
        paragraphs=[x for x in selected_q.paragraphs],
        first_label=0,
        second_label=0,
        question_id=question.question_id,
        q_type=question.q_type,
        sentence_segments=[x for x in selected_q.sentence_segments])
def encode_text_questions(self, tokenized_questions: List[List[str]],
                          return_search_vectors: bool, show_progress=False):
    dummy_par = "Hello Hello".split()
    samples = [
        IterativeQuestionAndParagraphs(question=q,
                                       paragraphs=[dummy_par, dummy_par],
                                       first_label=1,
                                       second_label=1,
                                       question_id='dummy',
                                       sentence_segments=None)
        for q in tokenized_questions
    ]
    return self.encode_questions(samples,
                                 return_search_vectors=return_search_vectors,
                                 show_progress=show_progress)
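# A hedged sketch of first-hop usage for the method above: questions are encoded
# against dummy paragraphs to obtain initial search vectors. `reformulator` and the
# token lists are assumptions for illustration, not names defined in this module.
#
#     questions = [["where", "was", "the", "author", "of", "Dracula", "born", "?"]]
#     first_hop_vecs = reformulator.encode_text_questions(questions,
#                                                         return_search_vectors=True)
#     # each row of `first_hop_vecs` retrieves a first paragraph; the retrieved
#     # paragraph can then be passed to reformulate_questions_from_texts for the second hop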
def _sample_false_1(self, qid):
    """ False sample of type 1: both paragraphs are distractors.
    No sampling from another question here, as I think it's less effective in this case """
    question = self.qid2question[qid]
    two_distractors = self.random.choice(question.distractors, size=2, replace=False)
    pars = [flatten_iterable(x.sentences) for x in two_distractors]
    segs = [
        get_segments_from_sentences(x.sentences) for x in two_distractors
    ]
    return IterativeQuestionAndParagraphs(
        question=question.question_tokens,
        paragraphs=pars,
        first_label=0,
        second_label=0,
        question_id=question.question_id,
        q_type=question.q_type,
        sentence_segments=segs)
def _sample_first_gold_second_false(self, qid):
    question = self.qid2question[qid]
    rand_par_other_q = self._sample_rand_par_other_q(qid)
    # pick the gold paragraph that should appear first
    if question.q_type == 'comparison' or self.bridge_as_comparison:
        first_gold_par = question.supporting_facts[self.random.randint(2)]
    else:
        if not self.label_by_span:
            first_gold_idx = 0 if question.gold_scores[0] > question.gold_scores[1] else 1
            first_gold_par = question.supporting_facts[first_gold_idx]
        else:
            gold_idxs = self._get_no_span_containing_golds(question.question_id)
            if len(gold_idxs) == 1:
                first_gold_par = question.supporting_facts[gold_idxs[0]]
            else:
                first_gold_par = question.supporting_facts[self.random.randint(2)]
    # second paragraph: mostly a distractor, sometimes the first gold repeated,
    # and rarely a paragraph from another question
    rand_par = self.random.choice([
        rand_par_other_q, first_gold_par,
        self.random.choice(question.distractors)
    ], p=[0.05, 0.1, 0.85])
    pars = [
        flatten_iterable(first_gold_par.sentences),
        flatten_iterable(rand_par.sentences)
    ]
    segs = [
        get_segments_from_sentences(first_gold_par.sentences),
        get_segments_from_sentences(rand_par.sentences)
    ]
    return IterativeQuestionAndParagraphs(
        question=question.question_tokens,
        paragraphs=pars,
        first_label=1,
        second_label=0,
        question_id=question.question_id,
        q_type=question.q_type,
        sentence_segments=segs)
def _sample_false_2(self, qid):
    """ False sample of type 2: first paragraph is a distractor, second is one of the supporting facts """
    question = self.qid2question[qid]
    rand_par_other_q = self._sample_rand_par_other_q(qid)
    distractor = self.random.choice(
        [self.random.choice(question.distractors), rand_par_other_q],
        p=[0.9, 0.1])
    gold = self.random.choice(question.supporting_facts)
    pars = [flatten_iterable(x.sentences) for x in [distractor, gold]]
    segs = [
        get_segments_from_sentences(x.sentences)
        for x in [distractor, gold]
    ]
    return IterativeQuestionAndParagraphs(
        question=question.question_tokens,
        paragraphs=pars,
        first_label=0,
        second_label=0,
        question_id=question.question_id,
        q_type=question.q_type,
        sentence_segments=segs)
def _build_gold_samples(self):
    gold_samples = []
    for question in self.questions:
        if question.q_type == 'comparison' or self.bridge_as_comparison:
            pars_order = [0, 1]
            self.random.shuffle(pars_order)
        else:
            if not self.label_by_span:
                pars_order = [0, 1] if question.gold_scores[0] > question.gold_scores[1] else [1, 0]
            else:
                gold_idxs = self._get_no_span_containing_golds(question.question_id)
                pars_order = [0, 1] if 0 in gold_idxs else [1, 0]
                if len(gold_idxs) != 1:
                    # either both golds contain the answer span or neither does,
                    # so they are regarded as equal and the order is random
                    self.random.shuffle(pars_order)
        pars = [
            flatten_iterable(question.supporting_facts[i].sentences)
            for i in pars_order
        ]
        sentence_segs = [
            get_segments_from_sentences(question.supporting_facts[i].sentences)
            for i in pars_order
        ]
        gold_samples.append(
            IterativeQuestionAndParagraphs(
                question.question_tokens,
                pars,
                first_label=1,
                second_label=1,
                question_id=question.question_id,
                q_type=question.q_type,
                sentence_segments=sentence_segs))
    return gold_samples
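# Worked example of the ordering above (a sketch, assuming _get_no_span_containing_golds
# returns the indices of gold paragraphs that do NOT contain the answer span):
#
#     gold_idxs == [1]          -> pars_order = [1, 0]  # answer-bearing gold placed second
#     gold_idxs == [0]          -> pars_order = [0, 1]
#     gold_idxs in ([], [0, 1]) -> order shuffled       # golds regarded as equal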
def encode_from_file(docs_file, questions_file, encodings_dir, encoder_model,
                     num_workers, hotpot: bool, long_batch: int, short_batch: int,
                     use_chars: bool, use_ema: bool, checkpoint: str,
                     document_chunk_size=1000, samples=None, encode_all_db=False):
    """
    :param docs_file: path to a json file whose structure is {title: list of paragraphs, ...}
    :param encodings_dir: directory to dump the encodings to
    :return:
    """
    doc_encs_handler = DocumentEncodingHandler(encodings_dir)
    # Setup worker pool
    workers = ProcessPool(num_workers, initializer=init, initargs=[])

    if docs_file is not None:
        with open(docs_file, 'r') as f:
            documents = json.load(f)
        documents = {
            k: v
            for k, v in documents.items()
            if k not in doc_encs_handler.titles2filenames
        }
        tokenized_documents = {}
        tupled_doc_list = [(title, pars) for title, pars in documents.items()]
        if samples is not None:
            print(f"sampling {samples} samples")
            tupled_doc_list = tupled_doc_list[:samples]
        print("Tokenizing from file...")
        with tqdm(total=len(tupled_doc_list), ncols=80) as pbar:
            for tok_doc in tqdm(
                    workers.imap_unordered(tokenize_document, tupled_doc_list)):
                tokenized_documents.update(tok_doc)
                pbar.update()
    else:
        if questions_file is not None:
            with open(questions_file, 'r') as f:
                questions = json.load(f)
            all_titles = list(
                set([title for q in questions for title in q['top_titles']]))
        else:
            print("encoding all DB!")
            all_titles = DocDB().get_doc_titles()
        if samples is not None:
            print(f"sampling {samples} samples")
            all_titles = all_titles[:samples]
        all_titles = [
            t for t in all_titles
            if t not in doc_encs_handler.titles2filenames
        ]
        tokenized_documents = {}
        print("Tokenizing from DB...")
        with tqdm(total=len(all_titles), ncols=80) as pbar:
            for tok_doc in tqdm(
                    workers.imap_unordered(tokenize_from_db, all_titles)):
                tokenized_documents.update(tok_doc)
                pbar.update()

    workers.close()
    workers.join()

    voc = set()
    for paragraphs in tokenized_documents.values():
        for par in paragraphs:
            voc.update(par)

    if not hotpot:
        spec = QuestionAndParagraphsSpec(batch_size=None,
                                         max_num_contexts=1,
                                         max_num_question_words=None,
                                         max_num_context_words=None)
        encoder = SentenceEncoderSingleContext(model_dir_path=encoder_model,
                                               vocabulary=voc,
                                               spec=spec,
                                               loader=ResourceLoader(),
                                               use_char_inputs=use_chars,
                                               use_ema=use_ema,
                                               checkpoint=checkpoint)
    else:
        spec = QuestionAndParagraphsSpec(batch_size=None,
                                         max_num_contexts=2,
                                         max_num_question_words=None,
                                         max_num_context_words=None)
        encoder = SentenceEncoderIterativeModel(model_dir_path=encoder_model,
                                                vocabulary=voc,
                                                spec=spec,
                                                loader=ResourceLoader(),
                                                use_char_inputs=use_chars,
                                                use_ema=use_ema,
                                                checkpoint=checkpoint)

    tokenized_documents_items = list(tokenized_documents.items())
    for tokenized_doc_chunk in tqdm([
            tokenized_documents_items[i:i + document_chunk_size]
            for i in range(0, len(tokenized_documents_items), document_chunk_size)
    ], ncols=80):
        flattened_pars_with_names = [(f"{title}_{i}", par)
                                     for title, pars in tokenized_doc_chunk
                                     for i, par in enumerate(pars)]

        # filter out empty paragraphs (probably had some short string the tokenization removed);
        # note that the filtered paragraphs will have no representation, but they still exist
        # in the numbering of paragraphs for consistency with the docs.
        flattened_pars_with_names = [(name, par)
                                     for name, par in flattened_pars_with_names
                                     if len(par) > 0]

        # sort such that longer paragraphs come first, to identify OOMs early on
        flattened_pars_with_names = sorted(flattened_pars_with_names,
                                           key=lambda x: len(x[1]),
                                           reverse=True)
        long_paragraphs_ids = [
            i for i, name_par in enumerate(flattened_pars_with_names)
            if len(name_par[1]) >= 900
        ]
        short_paragraphs_ids = [
            i for i, name_par in enumerate(flattened_pars_with_names)
            if len(name_par[1]) < 900
        ]

        # print(f"Encoding {len(flattened_pars_with_names)} paragraphs...")
        name2enc = {}
        dummy_question = "Hello Hello".split()
        if not hotpot:
            model_paragraphs = [
                BinaryQuestionAndParagraphs(question=dummy_question,
                                            paragraphs=[x],
                                            label=1,
                                            num_distractors=0,
                                            question_id='dummy')
                for _, x in flattened_pars_with_names
            ]
        else:
            # todo: allow precomputed sentence segments
            model_paragraphs = [
                IterativeQuestionAndParagraphs(question=dummy_question,
                                               paragraphs=[x, dummy_question],
                                               first_label=1,
                                               second_label=1,
                                               question_id='dummy',
                                               sentence_segments=None)
                for _, x in flattened_pars_with_names
            ]

        # print("Encoding long paragraphs...")
        long_pars = [model_paragraphs[i] for i in long_paragraphs_ids]
        name2enc.update({
            flattened_pars_with_names[long_paragraphs_ids[i]][0]: enc
            for i, enc in enumerate(
                encoder.encode_paragraphs(long_pars, batch_size=long_batch, show_progress=True)
                if not hotpot else
                encoder.encode_first_paragraphs(long_pars, batch_size=long_batch, show_progress=True))
        })

        # print("Encoding short paragraphs...")
        short_pars = [model_paragraphs[i] for i in short_paragraphs_ids]
        name2enc.update({
            flattened_pars_with_names[short_paragraphs_ids[i]][0]: enc
            for i, enc in enumerate(
                encoder.encode_paragraphs(short_pars, batch_size=short_batch, show_progress=True)
                if not hotpot else
                encoder.encode_first_paragraphs(short_pars, batch_size=short_batch, show_progress=True))
        })

        doc_encs_handler.save_multiple_documents(name2enc)
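# A hedged end-to-end sketch of calling the function above. The file contents and paths
# are illustrative assumptions; only the parameter names come from encode_from_file.
#
#     # docs.json maps each title to its list of paragraphs:
#     #     {"Barack Obama": ["Barack Obama is ...", "He served as ..."],
#     #      "Chicago": ["Chicago is a city ..."]}
#     encode_from_file(docs_file='docs.json', questions_file=None,
#                      encodings_dir='encodings/', encoder_model='models/encoder',
#                      num_workers=8, hotpot=True, long_batch=32, short_batch=128,
#                      use_chars=False, use_ema=True, checkpoint=None)
#
# Each paragraph encoding is stored under the key "<title>_<paragraph index>" via
# DocumentEncodingHandler, so empty paragraphs keep their index but get no vector.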