def test_answer_unclosed_span():
    """A span whose end offset is None must be rejected with an AssertionError."""
    answer_kwargs = dict(
        text="test",
        span=(1, None),
        long_text="long_test",
        long_text_span=(3, 9),
        score_reader=0.5,
        score_answer_in_document=0.1,
    )
    with pytest.raises(AssertionError):
        MachineReaderAnswer(**answer_kwargs)
def test_answer_invalid_score_answer_in_document():
    """A score_answer_in_document of 20 must be rejected with an AssertionError."""
    answer_kwargs = dict(
        text="test",
        span=(1, 5),
        long_text="long_test",
        long_text_span=(3, 9),
        score_reader=0.5,
        score_answer_in_document=20,
    )
    with pytest.raises(AssertionError):
        MachineReaderAnswer(**answer_kwargs)
def test_answer_creation():
    """Every constructor argument must be readable back as an attribute of the same name."""
    expected = {
        "text": "test",
        "span": (1, 5),
        "long_text": "long_test",
        "long_text_span": (3, 9),
        "score_reader": 0.5,
        "score_answer_in_document": 0.1,
    }
    answer = MachineReaderAnswer(**expected)
    # Check each field in declaration order against the value it was built with.
    for field_name, field_value in expected.items():
        assert getattr(answer, field_name) == field_value
def test_answer_creation_defaults():
    """Omitting span/long_text/long_text_span must fail: they have no defaults (TypeError)."""
    partial_kwargs = {
        'text': 'test',
        'score_reader': 0.5,
        'score_answer_in_document': 0.1,
    }
    with pytest.raises(TypeError):
        MachineReaderAnswer(**partial_kwargs)
Example #5
0
    def get_answers_from_logits(
        self,
        configuration: MachineReaderConfiguration,
        all_the_logits: List[Tuple[np.ndarray, np.ndarray]],
        all_the_overlaps: List[Tuple[int, int]],
        all_combined_texts: str,
    ) -> Iterable[MachineReaderAnswer]:
        """Combine logit distributions from several documents and generate the highest scoring answers.

        Each document's start/end logits are trimmed by its overlap counts, concatenated
        into one sequence, normalised with a single global softmax, and searched for the
        best answer spans over ``all_combined_texts``.

        This is a generator: the validation errors below are raised when the result is
        first iterated, not when the method is called.

        :param configuration: configuration object to control how answers are produced
            (``top_k``, ``threshold_reader`` and ``threshold_answer_in_document`` are read)
        :param all_the_logits: list of (start_logit_scores, end_logit_scores) for the documents
        :param all_the_overlaps: list of (overlap_start, overlap_end) per document: the number
            of leading and trailing logits that belong to the begin_overlap and end_overlap
            strings and are dropped before concatenation
        :param all_combined_texts: all the document strings as a single big string
        :return: iterable of machine reader answer objects
        :raises MachineReaderError: if the logits or overlaps are empty, differ in length, or
            the trimmed start/end logits do not match the token count of ``all_combined_texts``
        """
        if len(all_the_logits) == 0:
            raise MachineReaderError('Need at least one block of logits')
        if len(all_the_overlaps) == 0:
            raise MachineReaderError('Need at least one block of overlaps')
        if len(all_the_overlaps) != len(all_the_logits):
            raise MachineReaderError(
                'Overlaps and logits need to be the same length')

        # Drop the overlapping logits at both ends of every document so the
        # concatenation lines up one-to-one with the tokens of
        # all_combined_texts (verified just below).
        trimmed = [
            (start_logits[overlap_start:len(start_logits) - overlap_end],
             end_logits[overlap_start:len(end_logits) - overlap_end])
            for (start_logits, end_logits), (overlap_start, overlap_end)
            in zip(all_the_logits, all_the_overlaps)
        ]
        logits_array_start = np.concatenate([start for start, _ in trimmed])
        logits_array_end = np.concatenate([end for _, end in trimmed])

        # Count tokens once and reuse the value for both checks and the message.
        # Both distributions are checked: a mismatched end array would otherwise
        # be passed to find_best_spans unnoticed.
        expected_tokens = self._count_tokens(all_combined_texts)
        for logits_array in (logits_array_start, logits_array_end):
            if len(logits_array) != expected_tokens:
                raise MachineReaderError('logits length mismatch {} {}'.format(
                    len(logits_array), expected_tokens))

        # Perform a global softmax across all documents' tokens at once.
        yp_start = softmax(logits_array_start)
        yp_end = softmax(logits_array_end)

        _, context_offsets = self.model.tokenize(all_combined_texts)
        answer_spans = find_best_spans(all_combined_texts, context_offsets,
                                       yp_start, yp_end, configuration.top_k)

        # NOTE(review): placeholder - no in-document score is computed, so this
        # is always 0.0 and any threshold_answer_in_document > 0 suppresses every
        # answer. Confirm whether a real score was intended here.
        score_answer_in_document = 0.

        for answer_span in answer_spans:
            passes_thresholds = (
                answer_span.score >= configuration.threshold_reader
                and score_answer_in_document >=
                configuration.threshold_answer_in_document)
            # Presumably find_best_spans yields spans in descending score order
            # (TODO confirm), so the first failing span ends the search.
            if not passes_thresholds:
                break
            yield MachineReaderAnswer(
                text=answer_span.answer_text,
                span=answer_span.character_indices,
                long_text=answer_span.long_answer_text,
                long_text_span=answer_span.long_character_indices,
                score_reader=answer_span.score,
                score_answer_in_document=score_answer_in_document)