Code Example #1
 def _build_full_dataset(self):
     samples = []
     for question in self.questions:
         # Pair each paragraph (supporting facts followed by distractors) with its score
         pars_and_scores = list(
             zip(question.supporting_facts + question.distractors,
                 question.gold_scores + question.distractor_scores))
         # The supporting fact with the higher gold score is treated as the first-hop paragraph
         higher_gold = question.supporting_facts[0] \
             if question.gold_scores[0] >= question.gold_scores[1] else question.supporting_facts[1]
         # Enumerate every ordered pair of paragraphs and label each position
         for p1, score1 in pars_and_scores:
             for p2, score2 in pars_and_scores:
                 first_label, second_label = self._get_labels(
                     is_gold1=p1 in question.supporting_facts,
                     is_gold2=p2 in question.supporting_facts,
                     q_type=question.q_type,
                     are_same=p1 == p2,
                     is_first_higher=higher_gold == p1)
                 samples.append(
                     IterativeQuestionAndParagraphs(
                         question=question.question_tokens,
                         paragraphs=[
                             flatten_iterable(p1.sentences),
                             flatten_iterable(p2.sentences)
                         ],
                         first_label=first_label,
                         second_label=second_label,
                         question_id=question.question_id,
                         q_type=question.q_type,
                         sentence_segments=[
                             get_segments_from_sentences(s)
                             for s in [p1.sentences, p2.sentences]
                         ]))
     return samples
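
Every example on this page constructs IterativeQuestionAndParagraphs objects, but the class itself is not shown. Judging only from how its fields are used in these call sites, it is likely a simple record type along the lines of the sketch below; the NamedTuple form, the field types, and the defaults are assumptions inferred from the calls, not the actual definition.

from typing import List, NamedTuple, Optional

class IterativeQuestionAndParagraphs(NamedTuple):
    """Hypothetical reconstruction of the sample record used in these examples."""
    question: List[str]                  # tokenized question
    paragraphs: List[List[str]]          # two paragraphs, each a flat list of tokens
    first_label: int                     # 1 if the first paragraph is a valid first hop, else 0
    second_label: int                    # 1 if the second paragraph is a valid second hop, else 0
    question_id: str
    q_type: Optional[str] = None         # omitted at some call sites, so it must have a default
    sentence_segments: Optional[List[List[int]]] = None  # per-token sentence ids, when available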
Code Example #2
 def reformulate_questions_from_texts(self,
                                      tokenized_questions: List[List[str]],
                                      tokenized_pars: List[List[str]],
                                      return_search_vectors: bool,
                                      show_progress=False,
                                      max_batch=None):
     # Placeholder second paragraph; each sample must contain two paragraphs
     dummy_par = "Hello Hello".split()
     batch = [
         IterativeQuestionAndParagraphs(question=q,
                                        paragraphs=[p, dummy_par],
                                        first_label=1,
                                        second_label=1,
                                        question_id='dummy',
                                        sentence_segments=None)
         for q, p in zip(tokenized_questions, tokenized_pars)
     ]
     final_encs = []
     # Use the model's maximum batch size (or the whole batch), optionally capped by max_batch
     batch_size = self.model.max_batch_size or len(batch)
     batch_size = batch_size if max_batch is None else min(batch_size, max_batch)
     # Consume the batch front-to-back in slices of batch_size
     batch_starts = range(0, len(batch), batch_size)
     for _ in (tqdm(batch_starts) if show_progress else batch_starts):
         feed_dict = self.model.encode(batch[:batch_size], False)
         reformulations = self.sess.run(self.reformulation_name,
                                        feed_dict=feed_dict)
         final_encs.append(reformulations)
         batch = batch[batch_size:]
     reformulations = np.concatenate(final_encs, axis=0)
     if return_search_vectors:
         return self.question_rep_to_search_vector(reformulations,
                                                   context_idx=2)
     return reformulations
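
A minimal usage sketch for the method above. The reformulator instance, the example strings, and the whitespace tokenization are illustrative assumptions; only the method signature is taken from the code.

# Hypothetical caller; `reformulator` is an already constructed object exposing the method above.
questions = ["Who wrote the novel that the film was based on?"]
paragraphs = ["The film was adapted from a novel written by ..."]

search_vectors = reformulator.reformulate_questions_from_texts(
    tokenized_questions=[q.split() for q in questions],
    tokenized_pars=[p.split() for p in paragraphs],
    return_search_vectors=True,   # vectors ready for nearest-neighbor search
    show_progress=True,
    max_batch=64)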
Code Example #3
 def _sample_false_3(self, qid):
     """ False sample of type 2: gold from other question """
     question = self.qid2question[qid]
     rand_q_idx = self.random.randint(len(self.gold_samples))
     # Resample until the chosen gold sample comes from a different question
     while self.gold_samples[
             rand_q_idx].question_id == question.question_id:
         rand_q_idx = self.random.randint(len(self.gold_samples))
     selected_q = self.gold_samples[rand_q_idx]
     return IterativeQuestionAndParagraphs(
         question=question.question_tokens,
         paragraphs=[x for x in selected_q.paragraphs],
         first_label=0,
         second_label=0,
         question_id=question.question_id,
         q_type=question.q_type,
         sentence_segments=[x for x in selected_q.sentence_segments])
Code Example #4
 def encode_text_questions(self,
                           tokenized_questions: List[List[str]],
                           return_search_vectors: bool,
                           show_progress=False):
     # Both paragraph slots get a placeholder; only the questions are being encoded here
     dummy_par = "Hello Hello".split()
     samples = [
         IterativeQuestionAndParagraphs(question=q,
                                        paragraphs=[dummy_par, dummy_par],
                                        first_label=1,
                                        second_label=1,
                                        question_id='dummy',
                                        sentence_segments=None)
         for q in tokenized_questions
     ]
     return self.encode_questions(
         samples,
         return_search_vectors=return_search_vectors,
         show_progress=show_progress)
Code Example #5
 def _sample_false_1(self, qid):
     """ False sample of type 1: all distractors.
     No sampling from other question here, as I think it's less effective in this case"""
     question = self.qid2question[qid]
     two_distractors = self.random.choice(question.distractors,
                                          size=2,
                                          replace=False)
     pars = [flatten_iterable(x.sentences) for x in two_distractors]
     segs = [
         get_segments_from_sentences(x.sentences) for x in two_distractors
     ]
     return IterativeQuestionAndParagraphs(
         question=question.question_tokens,
         paragraphs=pars,
         first_label=0,
         second_label=0,
         question_id=question.question_id,
         q_type=question.q_type,
         sentence_segments=segs)
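
The helpers flatten_iterable and get_segments_from_sentences recur throughout these samples but are not defined on this page. Based on how they are used (a paragraph object exposes .sentences, a list of token lists), they plausibly behave like the following sketch, which is an assumption rather than the actual library code.

from typing import List, TypeVar

T = TypeVar('T')

def flatten_iterable(sentences: List[List[T]]) -> List[T]:
    # Concatenate a paragraph's sentences into one flat token list
    return [token for sentence in sentences for token in sentence]

def get_segments_from_sentences(sentences: List[List[str]]) -> List[int]:
    # One segment id per token, marking which sentence each token came from
    return [idx for idx, sentence in enumerate(sentences) for _ in sentence]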
Code Example #6
    def _sample_first_gold_second_false(self, qid):
        question = self.qid2question[qid]
        rand_par_other_q = self._sample_rand_par_other_q(qid)
        if question.q_type == 'comparison' or self.bridge_as_comparison:
            first_gold_par = question.supporting_facts[self.random.randint(2)]
        else:
            if not self.label_by_span:
                first_gold_idx = 0 if question.gold_scores[
                    0] > question.gold_scores[1] else 1
                first_gold_par = question.supporting_facts[first_gold_idx]
            else:
                gold_idxs = self._get_no_span_containing_golds(
                    question.question_id)
                if len(gold_idxs) == 1:
                    first_gold_par = question.supporting_facts[gold_idxs[0]]
                else:
                    first_gold_par = question.supporting_facts[
                        self.random.randint(2)]

        # Second paragraph: usually a distractor (0.85), sometimes the same gold again (0.1)
        # or a paragraph from another question (0.05)
        rand_par = self.random.choice([
            rand_par_other_q, first_gold_par,
            self.random.choice(question.distractors)
        ],
                                      p=[0.05, 0.1, 0.85])

        pars = [
            flatten_iterable(first_gold_par.sentences),
            flatten_iterable(rand_par.sentences)
        ]
        segs = [
            get_segments_from_sentences(first_gold_par.sentences),
            get_segments_from_sentences(rand_par.sentences)
        ]
        return IterativeQuestionAndParagraphs(
            question=question.question_tokens,
            paragraphs=pars,
            first_label=1,
            second_label=0,
            question_id=question.question_id,
            q_type=question.q_type,
            sentence_segments=segs)
Code Example #7
 def _sample_false_2(self, qid):
     """ False sample of type 2: first distractor, second one of supporting facts """
     question = self.qid2question[qid]
     rand_par_other_q = self._sample_rand_par_other_q(qid)
     distractor = self.random.choice(
         [self.random.choice(question.distractors), rand_par_other_q],
         p=[0.9, 0.1])
     gold = self.random.choice(question.supporting_facts)
     pars = [flatten_iterable(x.sentences) for x in [distractor, gold]]
     segs = [
         get_segments_from_sentences(x.sentences)
         for x in [distractor, gold]
     ]
     return IterativeQuestionAndParagraphs(
         question=question.question_tokens,
         paragraphs=pars,
         first_label=0,
         second_label=0,
         question_id=question.question_id,
         q_type=question.q_type,
         sentence_segments=segs)
Code Example #8
 def _build_gold_samples(self):
     gold_samples = []
     for question in self.questions:
         if question.q_type == 'comparison' or self.bridge_as_comparison:
             pars_order = [0, 1]
             self.random.shuffle(pars_order)
         else:
             if not self.label_by_span:
                 pars_order = [0, 1] if question.gold_scores[
                     0] > question.gold_scores[1] else [1, 0]
             else:
                 gold_idxs = self._get_no_span_containing_golds(
                     question.question_id)
                 pars_order = [0, 1] if 0 in gold_idxs else [1, 0]
                 # either both golds contain the answer or neither does, so their order is regarded as equal
                 if len(gold_idxs) != 1:
                     self.random.shuffle(pars_order)
         pars = [
             flatten_iterable(question.supporting_facts[i].sentences)
             for i in pars_order
         ]
         sentence_segs = [
             get_segments_from_sentences(
                 question.supporting_facts[i].sentences) for i in pars_order
         ]
         gold_samples.append(
             IterativeQuestionAndParagraphs(
                 question.question_tokens,
                 pars,
                 first_label=1,
                 second_label=1,
                 question_id=question.question_id,
                 q_type=question.q_type,
                 sentence_segments=sentence_segs))
     return gold_samples
Code Example #9
def encode_from_file(docs_file,
                     questions_file,
                     encodings_dir,
                     encoder_model,
                     num_workers,
                     hotpot: bool,
                     long_batch: int,
                     short_batch: int,
                     use_chars: bool,
                     use_ema: bool,
                     checkpoint: str,
                     document_chunk_size=1000,
                     samples=None,
                     encode_all_db=False):
    """

    :param out_file: .npz file to dump the encodings
    :param docs_file: path to json file whose structure is [{title: list of paragraphs}, ...]
    :return:
    """
    doc_encs_handler = DocumentEncodingHandler(encodings_dir)
    # Setup worker pool
    workers = ProcessPool(num_workers, initializer=init, initargs=[])

    if docs_file is not None:
        with open(docs_file, 'r') as f:
            documents = json.load(f)
        documents = {
            k: v
            for k, v in documents.items()
            if k not in doc_encs_handler.titles2filenames
        }

        tokenized_documents = {}
        tupled_doc_list = [(title, pars) for title, pars in documents.items()]

        if samples is not None:
            print(f"sampling {samples} samples")
            tupled_doc_list = tupled_doc_list[:samples]

        print("Tokenizing from file...")
        with tqdm(total=len(tupled_doc_list), ncols=80) as pbar:
            for tok_doc in workers.imap_unordered(tokenize_document,
                                                  tupled_doc_list):
                tokenized_documents.update(tok_doc)
                pbar.update()
    else:
        if questions_file is not None:
            with open(questions_file, 'r') as f:
                questions = json.load(f)
            all_titles = list(
                set([title for q in questions for title in q['top_titles']]))
        else:
            print("encoding all DB!")
            all_titles = DocDB().get_doc_titles()

        if samples is not None:
            print(f"sampling {samples} samples")
            all_titles = all_titles[:samples]

        all_titles = [
            t for t in all_titles if t not in doc_encs_handler.titles2filenames
        ]
        tokenized_documents = {}

        print("Tokenizing from DB...")
        with tqdm(total=len(all_titles), ncols=80) as pbar:
            for tok_doc in workers.imap_unordered(tokenize_from_db,
                                                  all_titles):
                tokenized_documents.update(tok_doc)
                pbar.update()

    workers.close()
    workers.join()

    voc = set()
    for paragraphs in tokenized_documents.values():
        for par in paragraphs:
            voc.update(par)

    if not hotpot:
        spec = QuestionAndParagraphsSpec(batch_size=None,
                                         max_num_contexts=1,
                                         max_num_question_words=None,
                                         max_num_context_words=None)
        encoder = SentenceEncoderSingleContext(model_dir_path=encoder_model,
                                               vocabulary=voc,
                                               spec=spec,
                                               loader=ResourceLoader(),
                                               use_char_inputs=use_chars,
                                               use_ema=use_ema,
                                               checkpoint=checkpoint)
    else:
        spec = QuestionAndParagraphsSpec(batch_size=None,
                                         max_num_contexts=2,
                                         max_num_question_words=None,
                                         max_num_context_words=None)
        encoder = SentenceEncoderIterativeModel(model_dir_path=encoder_model,
                                                vocabulary=voc,
                                                spec=spec,
                                                loader=ResourceLoader(),
                                                use_char_inputs=use_chars,
                                                use_ema=use_ema,
                                                checkpoint=checkpoint)

    tokenized_documents_items = list(tokenized_documents.items())
    doc_chunks = [
        tokenized_documents_items[i:i + document_chunk_size]
        for i in range(0, len(tokenized_documents_items), document_chunk_size)
    ]
    for tokenized_doc_chunk in tqdm(doc_chunks, ncols=80):
        flattened_pars_with_names = [(f"{title}_{i}", par)
                                     for title, pars in tokenized_doc_chunk
                                     for i, par in enumerate(pars)]

        # filtering out empty paragraphs (probably had some short string the tokenization removed)
        # important to notice that the filtered paragraphs will have no representation,
        # but they still exist in the numbering of paragraphs for consistency with the docs.
        flattened_pars_with_names = [(name, par)
                                     for name, par in flattened_pars_with_names
                                     if len(par) > 0]

        # sort such that longer paragraphs are first to identify OOMs early on
        flattened_pars_with_names = sorted(flattened_pars_with_names,
                                           key=lambda x: len(x[1]),
                                           reverse=True)
        long_paragraphs_ids = [
            i for i, name_par in enumerate(flattened_pars_with_names)
            if len(name_par[1]) >= 900
        ]
        short_paragraphs_ids = [
            i for i, name_par in enumerate(flattened_pars_with_names)
            if len(name_par[1]) < 900
        ]

        # print(f"Encoding {len(flattened_pars_with_names)} paragraphs...")
        name2enc = {}
        dummy_question = "Hello Hello".split()
        if not hotpot:
            model_paragraphs = [
                BinaryQuestionAndParagraphs(question=dummy_question,
                                            paragraphs=[x],
                                            label=1,
                                            num_distractors=0,
                                            question_id='dummy')
                for _, x in flattened_pars_with_names
            ]
        else:
            # todo allow precomputed sentence segments
            model_paragraphs = [
                IterativeQuestionAndParagraphs(question=dummy_question,
                                               paragraphs=[x, dummy_question],
                                               first_label=1,
                                               second_label=1,
                                               question_id='dummy',
                                               sentence_segments=None)
                for _, x in flattened_pars_with_names
            ]

        # print("Encoding long paragraphs...")
        long_pars = [model_paragraphs[i] for i in long_paragraphs_ids]
        name2enc.update({
            flattened_pars_with_names[long_paragraphs_ids[i]][0]: enc
            for i, enc in enumerate(
                encoder.encode_paragraphs(
                    long_pars, batch_size=long_batch, show_progress=True
                ) if not hotpot else encoder.encode_first_paragraphs(
                    long_pars, batch_size=long_batch, show_progress=True))
        })

        # print("Encoding short paragraphs...")
        short_pars = [model_paragraphs[i] for i in short_paragraphs_ids]
        name2enc.update({
            flattened_pars_with_names[short_paragraphs_ids[i]][0]: enc
            for i, enc in enumerate(
                encoder.encode_paragraphs(
                    short_pars, batch_size=short_batch, show_progress=True
                ) if not hotpot else encoder.encode_first_paragraphs(
                    short_pars, batch_size=short_batch, show_progress=True))
        })

        doc_encs_handler.save_multiple_documents(name2enc)
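
Code Example #9 splits the tokenized documents into fixed-size chunks with the standard slice-by-stride idiom. Pulled out on its own, the pattern looks like the helper below; the function name is illustrative, not part of the original code.

from typing import Iterator, List, Sequence, TypeVar

T = TypeVar('T')

def chunked(items: Sequence[T], size: int) -> Iterator[List[T]]:
    # Yield consecutive slices of at most `size` items, preserving order
    for start in range(0, len(items), size):
        yield list(items[start:start + size])

# e.g. for tokenized_doc_chunk in chunked(list(tokenized_documents.items()), document_chunk_size): ...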