예제 #1
0
    def _get_best_prediction(self, start_logits, end_logits, relevance_logits,
                             samples_batch: List[ReaderSample], passage_thresholds: List[int] = None) \
            -> List[ReaderQuestionPredictions]:
        """Select the best answer span for each question in the batch.

        Passages are visited in decreasing relevance order; for each passage the
        top candidate spans are extracted with ``get_best_spans``.

        :param start_logits: per-token answer-start logits, indexed [question, passage]
        :param end_logits: per-token answer-end logits, indexed [question, passage]
        :param relevance_logits: passage relevance scores of shape
            (questions_num, passages_per_question)
        :param samples_batch: one ReaderSample per question
        :param passage_thresholds: if given, spans from ALL passages are collected
            and, for every threshold n, the best span among the top-n ranked
            passages is reported; otherwise only the first passage that yields
            any span is used
        :return: one ReaderQuestionPredictions per question
        """
        args = self.args
        max_answer_length = args.max_answer_length
        questions_num, passages_per_question = relevance_logits.size()

        # rank passages per question by relevance, best first
        _, idxs = torch.sort(relevance_logits, dim=1, descending=True, )

        batch_results = []
        for q in range(questions_num):
            sample = samples_batch[q]

            non_empty_passages_num = len(sample.passages)
            nbest = []
            for p in range(passages_per_question):
                passage_idx = idxs[q, p].item()
                if passage_idx >= non_empty_passages_num:  # empty passage selected, skip
                    continue
                reader_passage = sample.passages[passage_idx]
                sequence_ids = reader_passage.sequence_ids
                sequence_len = sequence_ids.size(0)
                # assuming question & title information is at the beginning of the sequence
                passage_offset = reader_passage.passage_offset

                p_start_logits = start_logits[q, passage_idx].tolist()[passage_offset:sequence_len]
                p_end_logits = end_logits[q, passage_idx].tolist()[passage_offset:sequence_len]

                ctx_ids = sequence_ids.tolist()[passage_offset:]
                best_spans = get_best_spans(self.tensorizer, p_start_logits, p_end_logits, ctx_ids, max_answer_length,
                                            passage_idx, relevance_logits[q, passage_idx].item(), top_spans=10)
                nbest.extend(best_spans)
                # without thresholds, the first passage producing any span wins
                if len(nbest) > 0 and not passage_thresholds:
                    break

            if passage_thresholds:
                passage_rank_matches = {}
                for n in passage_thresholds:
                    curr_nbest = [pred for pred in nbest if pred.passage_index < n]
                    # BUG FIX: curr_nbest may be empty (e.g. every top-n passage was
                    # an empty/padding passage) — previously this raised IndexError.
                    if curr_nbest:
                        passage_rank_matches[n] = curr_nbest[0]
                    else:
                        passage_rank_matches[n] = SpanPrediction('', -1, -1, -1, '')
                predictions = passage_rank_matches
            else:
                if len(nbest) == 0:
                    # no span found at all — emit an empty placeholder prediction
                    predictions = {passages_per_question: SpanPrediction('', -1, -1, -1, '')}
                else:
                    predictions = {passages_per_question: nbest[0]}
            batch_results.append(ReaderQuestionPredictions(sample.question, predictions, sample.answers))
        return batch_results
예제 #2
0
    def _get_best_prediction(self, start_logits, end_logits, relevance_logits,
                             samples_batch: List[ReaderSample], passage_thresholds: List[int] = None) \
            -> List[ReaderQuestionPredictions]:
        """Select the best answer span for each question in the batch.

        Passages are visited in decreasing relevance order and candidate spans
        are extracted with ``get_best_spans``.  The collected n-best list can be
        re-ranked by span score or by relevance+span score depending on
        ``args.rank_method``.

        :param start_logits: per-token answer-start logits, indexed [question, passage]
        :param end_logits: per-token answer-end logits, indexed [question, passage]
        :param relevance_logits: passage relevance scores of shape
            (questions_num, passages_per_question)
        :param samples_batch: one ReaderSample per question
        :param passage_thresholds: if given, spans from ALL passages are collected
            and, per threshold n, the best span within the top-n passages is
            reported; otherwise the first passage yielding any span is used
        :return: one ReaderQuestionPredictions per question
        """
        args = self.args
        max_answer_length = args.max_answer_length
        questions_num, passages_per_question = relevance_logits.size()

        # rank passages per question by relevance, best first
        _, idxs = torch.sort(
            relevance_logits,
            dim=1,
            descending=True,
        )

        batch_results = []
        for q in range(questions_num):
            sample = samples_batch[q]

            non_empty_passages_num = len(sample.passages)
            nbest = []
            for p in range(passages_per_question):
                passage_idx = idxs[q, p].item()
                if passage_idx >= non_empty_passages_num:  # empty passage selected, skip
                    continue
                reader_passage = sample.passages[passage_idx]
                sequence_ids = reader_passage.sequence_ids
                sequence_len = sequence_ids.size(0)
                # assuming question & title information is at the beginning of the sequence
                passage_offset = reader_passage.passage_offset

                p_start_logits = start_logits[
                    q, passage_idx].tolist()[passage_offset:sequence_len]
                p_end_logits = end_logits[
                    q, passage_idx].tolist()[passage_offset:sequence_len]

                ctx_ids = sequence_ids.tolist()[passage_offset:]

                # raw passage text + token->original-char mapping let the span
                # extractor recover the surface-form answer string
                context, tok_to_orig_index = reader_passage.passage_text, reader_passage.tok_to_orig_index
                best_spans = get_best_spans(
                    self.tensorizer,
                    p_start_logits,
                    p_end_logits,
                    ctx_ids,
                    max_answer_length,
                    passage_idx,
                    relevance_logits[q, passage_idx].item(),
                    context,
                    tok_to_orig_index,
                    top_spans=10)

                nbest.extend(best_spans)
                # without thresholds, the first passage producing any span wins
                if len(nbest) > 0 and not passage_thresholds:
                    break

            # optional global re-ranking of the collected spans
            if args.rank_method == 'span':
                nbest.sort(key=lambda x: x.span_score, reverse=True)
            elif args.rank_method == 'rel+span':
                nbest.sort(key=lambda x: x.span_score + x.relevance_score,
                           reverse=True)

            if passage_thresholds:
                passage_rank_matches = {}
                for n in passage_thresholds:
                    curr_nbest = [
                        pred for pred in nbest if pred.passage_index < n
                    ]
                    # BUG FIX: narrowed a bare `except:` that swallowed every
                    # exception (incl. KeyboardInterrupt) to the IndexError the
                    # empty-list lookup can actually raise.
                    try:
                        passage_rank_matches[n] = curr_nbest[0]
                    except IndexError:
                        print("No answer for {}".format(sample.question))
                        passage_rank_matches[n] = SpanPrediction(
                            '', -1, -1, -1, '')
                predictions = passage_rank_matches
            else:
                if len(nbest) == 0:
                    # no span found at all — emit an empty placeholder prediction
                    predictions = {
                        passages_per_question:
                        SpanPrediction('', -1, -1, -1, '')
                    }
                else:
                    predictions = {passages_per_question: nbest[0]}

            # per-passage gold-answer flags are carried through for evaluation
            has_answer = [p.has_answer for p in sample.passages]
            batch_results.append(
                ReaderQuestionPredictions(sample.question, sample.question_id,
                                          predictions, sample.answers,
                                          has_answer))
        return batch_results
예제 #3
0
    def _get_best_prediction(
        self,
        start_logits,
        end_logits,
        relevance_logits,
        samples_batch: List[ReaderSample],
        passage_thresholds: List[int] = None,
    ) -> List[ReaderQuestionPredictions]:
        """Aggregate answer spans from ALL passages of each question.

        Unlike the first-passage-wins variant, this one collects candidate
        spans from every non-empty passage, softmax-normalizes their scores,
        merges duplicate answer strings by summing their normalized scores and
        returns the full sorted n-best list.

        :param start_logits: per-token answer-start logits, indexed [question, passage]
        :param end_logits: per-token answer-end logits, indexed [question, passage]
        :param relevance_logits: passage relevance scores of shape
            (questions_num, passages_per_question)
        :param samples_batch: one ReaderSample per question
        :param passage_thresholds: accepted for interface compatibility but
            ignored in this variant
        :return: one ReaderQuestionPredictions per question; its predictions
            dict maps passages_per_question -> sorted list of SpanPrediction
        """
        args = self.args
        max_answer_length = args.max_answer_length
        questions_num, passages_per_question = relevance_logits.size()

        # rank passages per question by relevance, best first
        _, idxs = torch.sort(
            relevance_logits,
            dim=1,
            descending=True,
        )

        batch_results = []
        for q in range(questions_num):
            sample = samples_batch[q]

            non_empty_passages_num = len(sample.passages)
            nbest = []
            for p in range(passages_per_question):
                passage_idx = idxs[q, p].item()
                if (
                    passage_idx >= non_empty_passages_num
                ):  # empty passage selected, skip
                    continue
                reader_passage = sample.passages[passage_idx]
                sequence_ids = reader_passage.sequence_ids
                sequence_len = sequence_ids.size(0)
                # assuming question & title information is at the beginning of the sequence
                passage_offset = reader_passage.passage_offset

                p_start_logits = start_logits[q, passage_idx].tolist()[
                    passage_offset:sequence_len
                ]
                p_end_logits = end_logits[q, passage_idx].tolist()[
                    passage_offset:sequence_len
                ]

                ctx_ids = sequence_ids.tolist()[passage_offset:]
                best_spans = get_best_spans(
                    self.tensorizer,
                    p_start_logits,
                    p_end_logits,
                    ctx_ids,
                    max_answer_length,
                    passage_idx,
                    relevance_logits[q, passage_idx].item(),
                    top_spans=10,
                )
                nbest.extend(best_spans)
                # NOTE: the early break on the first non-empty nbest (used by the
                # other variants) is intentionally disabled here — spans from ALL
                # passages are aggregated below.

            # TODO: consider applying the softmax earlier in the scoring pipeline
            # (incl. across multiple paragraphs); also revisit answer casing.

            # Softmax-normalize all span scores so they are comparable across
            # passages (explicit dim=0 — the implicit-dim call is deprecated).
            scores = torch.tensor([pred.span_score for pred in nbest], dtype=torch.float)
            smax_scores = F.softmax(scores, dim=0).tolist()
            nbest = [
                pred._replace(span_score=score)
                for pred, score in zip(nbest, smax_scores)
            ]

            # Merge duplicate answer strings, summing their normalized scores
            # (was a list.pop(0)-per-item loop — O(n^2) — before).
            curr_nbest_dict = {}
            for pred in nbest:
                existing = curr_nbest_dict.get(pred.prediction_text)
                if existing is not None:
                    curr_nbest_dict[pred.prediction_text] = existing._replace(
                        span_score=pred.span_score + existing.span_score
                    )
                else:
                    curr_nbest_dict[pred.prediction_text] = pred

            curr_nbest = sorted(curr_nbest_dict.values(), key=lambda x: -x.span_score)
            predictions = {passages_per_question: curr_nbest}
            batch_results.append(
                ReaderQuestionPredictions(sample.question, predictions, sample.answers)
            )
        return batch_results
예제 #4
0
def get_best_prediction(
    max_answer_length: int,
    tensorizer,
    start_logits,
    end_logits,
    relevance_logits,
    samples_batch: List[ReaderSample],
    passage_thresholds: List[int] = None,
) -> List[ReaderQuestionPredictions]:
    """Select the best answer span for each question in the batch.

    Module-level variant: passages are re-assembled from the sample's
    positive/negative lists and re-sorted by retriever score so their order
    matches the model input (see ``create_reader_input``).

    :param max_answer_length: maximum span length (in tokens) to consider
    :param tensorizer: tensorizer used by ``get_best_spans`` to detokenize
    :param start_logits: per-token answer-start logits, indexed [question, passage]
    :param end_logits: per-token answer-end logits, indexed [question, passage]
    :param relevance_logits: passage relevance scores of shape
        (questions_num, passages_per_question)
    :param samples_batch: one ReaderSample per question
    :param passage_thresholds: if given, spans from ALL passages are collected
        and, per threshold n, the best span within the top-n passages is
        reported; otherwise the first passage yielding any span is used
    :return: one ReaderQuestionPredictions per question
    """
    questions_num, passages_per_question = relevance_logits.size()

    # rank passages per question by relevance, best first
    _, idxs = torch.sort(
        relevance_logits,
        dim=1,
        descending=True,
    )

    batch_results = []
    for q in range(questions_num):
        sample = samples_batch[q]

        # Need to re-sort samples based on their scores; see `create_reader_input` function
        all_passages = sample.positive_passages + sample.negative_passages
        all_passages = sorted(all_passages, key=lambda x: x.score, reverse=True)

        non_empty_passages_num = len(all_passages)
        nbest: List[SpanPrediction] = []
        for p in range(passages_per_question):
            passage_idx = idxs[q, p].item()
            if (
                passage_idx >= non_empty_passages_num
            ):  # empty passage selected, skip
                continue
            reader_passage = all_passages[passage_idx]
            sequence_ids = reader_passage.sequence_ids
            sequence_len = sequence_ids.size(0)
            # assuming question & title information is at the beginning of the sequence
            passage_offset = reader_passage.passage_offset

            p_start_logits = start_logits[q, passage_idx].tolist()[
                passage_offset:sequence_len
            ]
            p_end_logits = end_logits[q, passage_idx].tolist()[
                passage_offset:sequence_len
            ]

            ctx_ids = sequence_ids.tolist()[passage_offset:]
            best_spans = get_best_spans(
                tensorizer,
                p_start_logits,
                p_end_logits,
                ctx_ids,
                max_answer_length,
                passage_idx,
                relevance_logits[q, passage_idx].item(),
                top_spans=10,
            )
            nbest.extend(best_spans)
            # without thresholds, the first passage producing any span wins
            if len(nbest) > 0 and not passage_thresholds:
                break

        if passage_thresholds:
            passage_rank_matches = {}
            for n in passage_thresholds:
                # keep only spans coming from the top-n ranked passages
                curr_nbest = [pred for pred in nbest if pred.passage_index < n]
                # BUG FIX: curr_nbest may be empty (e.g. every top-n passage was
                # an empty/padding passage) — previously this raised IndexError.
                if curr_nbest:
                    passage_rank_matches[n] = curr_nbest[0]
                else:
                    passage_rank_matches[n] = SpanPrediction("", -1, -1, -1, "")
            predictions = passage_rank_matches
        else:
            if len(nbest) == 0:
                # no span found at all — emit an empty placeholder prediction
                predictions = {
                    passages_per_question: SpanPrediction("", -1, -1, -1, "")
                }
            else:
                predictions = {passages_per_question: nbest[0]}
        batch_results.append(
            ReaderQuestionPredictions(sample.question, predictions, sample.answers)
        )
    return batch_results