def test_answer_unclosed_span():
    """A span whose end position is ``None`` must raise AssertionError."""
    open_ended = dict(
        text="test",
        span=(1, None),
        long_text="long_test",
        long_text_span=(3, 9),
        score_reader=0.5,
        score_answer_in_document=0.1,
    )
    with pytest.raises(AssertionError):
        MachineReaderAnswer(**open_ended)
def test_answer_invalid_score_answer_in_document():
    """A ``score_answer_in_document`` of 20 must raise AssertionError.

    NOTE(review): presumably the valid range is [0, 1]; confirm against
    MachineReaderAnswer's validation.
    """
    bad_score = dict(
        text="test",
        span=(1, 5),
        long_text="long_test",
        long_text_span=(3, 9),
        score_reader=0.5,
        score_answer_in_document=20,
    )
    with pytest.raises(AssertionError):
        MachineReaderAnswer(**bad_score)
def test_answer_creation():
    """Every constructor argument is exposed unchanged as an attribute."""
    expected = {
        "text": "test",
        "span": (1, 5),
        "long_text": "long_test",
        "long_text_span": (3, 9),
        "score_reader": 0.5,
        "score_answer_in_document": 0.1,
    }
    answer = MachineReaderAnswer(**expected)
    # Check each attribute in the same order it was passed in.
    for attr, value in expected.items():
        assert getattr(answer, attr) == value
def test_answer_creation_defaults():
    """Omitting ``span``, ``long_text`` and ``long_text_span`` raises TypeError."""
    partial_args = {
        'text': 'test',
        'score_reader': 0.5,
        'score_answer_in_document': 0.1,
    }
    with pytest.raises(TypeError):
        MachineReaderAnswer(**partial_args)
def get_answers_from_logits(
        self,
        configuration: MachineReaderConfiguration,
        all_the_logits: List[Tuple[np.ndarray, np.ndarray]],
        all_the_overlaps: List[Tuple[int, int]],
        all_combined_texts: str,
) -> Iterable[MachineReaderAnswer]:
    """Combine logit distributions from several documents and generate the
    highest scoring answers.

    Each document's start/end logits are trimmed of their overlapping
    tokens, concatenated into one array per end, and normalised with a
    single (global) softmax so spans from different documents are scored
    on the same scale.

    NOTE: this is a generator, so the argument validation below (and the
    ``MachineReaderError`` it raises) only runs when iteration starts.

    :param configuration: configuration object to control how answers are
        produced; ``top_k``, ``threshold_reader`` and
        ``threshold_answer_in_document`` are read here
    :param all_the_logits: list of (start_logit_scores, end_logit_scores)
        for the documents
    :param all_the_overlaps: list of (begin_overlap, end_overlap) pairs, one
        per document: the number of tokens to drop from the start and from
        the end of that document's logits, respectively
    :param all_combined_texts: all the document strings as a single big
        string
    :return: iterable of machine reader answer objects
    :raises MachineReaderError: if logits or overlaps are empty, if their
        lengths differ, or if the trimmed start/end logits do not match the
        token count of ``all_combined_texts``
    """
    if len(all_the_logits) == 0:
        raise MachineReaderError('Need at least one block of logits')
    if len(all_the_overlaps) == 0:
        raise MachineReaderError('Need at least one block of overlaps')
    if len(all_the_overlaps) != len(all_the_logits):
        raise MachineReaderError(
            'Overlaps and logits need to be the same length')

    # Trim the overlapping tokens at each document's edges. ``len - end`` is
    # used rather than ``-end`` so that an end overlap of 0 keeps the tail.
    trimmed = [
        (start_logits[begin:len(start_logits) - end],
         end_logits[begin:len(end_logits) - end])
        for (start_logits, end_logits), (begin, end)
        in zip(all_the_logits, all_the_overlaps)
    ]
    logits_array_start = np.concatenate([start for start, _ in trimmed])
    logits_array_end = np.concatenate([stop for _, stop in trimmed])

    n_tokens = self._count_tokens(all_combined_texts)
    if len(logits_array_start) != n_tokens:
        raise MachineReaderError('logits length mismatch {} {}'.format(
            len(logits_array_start), n_tokens))
    # The end logits must line up with the same tokens as the start logits.
    if len(logits_array_end) != n_tokens:
        raise MachineReaderError('logits length mismatch {} {}'.format(
            len(logits_array_end), n_tokens))

    # Perform global softmax
    yp_start = softmax(logits_array_start)
    yp_end = softmax(logits_array_end)

    _, context_offsets = self.model.tokenize(all_combined_texts)
    answer_spans = find_best_spans(all_combined_texts, context_offsets,
                                   yp_start, yp_end, configuration.top_k)
    for answer_span in answer_spans:
        # NOTE(review): placeholder. No document-level score is computed yet,
        # so this is always 0.0 and any threshold_answer_in_document > 0
        # suppresses every answer. TODO confirm intended behaviour.
        score_answer_in_document = 0.
        if (answer_span.score >= configuration.threshold_reader and
                score_answer_in_document >=
                configuration.threshold_answer_in_document):
            yield MachineReaderAnswer(
                text=answer_span.answer_text,
                span=answer_span.character_indices,
                long_text=answer_span.long_answer_text,
                long_text_span=answer_span.long_character_indices,
                score_reader=answer_span.score,
                score_answer_in_document=score_answer_in_document)
        else:
            # Stop at the first span below threshold. This presumes
            # find_best_spans yields spans in descending score order.
            break