def squad_span_scores(data: List[ContextAndQuestion], prediction):
    """Score predicted spans against the gold answers of SQuAD data points.

    Args:
        data: Data points to evaluate. Each one must provide
            `get_original_text(start, end)` and an `answer` exposing
            `answer_spans` and `answer_text`; the two are walked in lockstep.
        prediction: Indexable with one predicted span per data point;
            `prediction[i]` is converted to a tuple and its first two entries
            are used as the span bounds for `data[i]`.

    Returns:
        A NumPy array of shape `(len(data), 4)`. For each data point the
        columns hold: whether the predicted span equals any gold span, the
        best `compute_span_f1` value, the best `squad_official_em_score`
        value and the best `squad_official_f1_score` value, each taken over
        the gold answers. A data point without gold answers keeps all zeros.
    """
    scores = np.zeros((len(data), 4))
    for idx, point in enumerate(data):
        guess = tuple(prediction[idx])
        # SQuAD data points are expected to know how to recover the
        # untokenized "raw" text that a span corresponds to.
        guess_text = point.get_original_text(guess[0], guess[1])

        exact_span = False
        best_span_f1 = 0
        best_text_em = 0
        best_text_f1 = 0
        gold = point.answer
        for (gold_start, gold_end), gold_text in zip(gold.answer_spans, gold.answer_text):
            gold_span = (gold_start, gold_end)
            best_span_f1 = max(best_span_f1, compute_span_f1(gold_span, guess))
            exact_span = exact_span or gold_span == guess
            # Text-level metrics compare the raw predicted text with this gold text.
            best_text_f1 = max(best_text_f1, squad_official_f1_score(guess_text, gold_text))
            best_text_em = max(best_text_em, squad_official_em_score(guess_text, gold_text))

        scores[idx] = [exact_span, best_span_f1, best_text_em, best_text_f1]

    return scores
def trivia_span_scores(data: List[ContextAndQuestion],
                       prediction):
    """Score predicted spans against the gold answers of TriviaQA data points.

    Args:
        data: Data points to evaluate. Each one must provide `get_context()`
            (sliced with the predicted bounds, end inclusive) and an `answer`
            exposing `answer_spans` and `answer_text`.
        prediction: Indexable with one predicted span per data point;
            the first two entries of `prediction[i]` are used as the span
            bounds for `data[i]`.

    Returns:
        A NumPy array of shape `(len(data), 4)`. For each data point the
        columns hold: whether the predicted span equals any gold span, the
        best `compute_span_f1` value over the gold spans, the best
        `triviaqa_em_score` value and the best `triviaqa_f1_score` value over
        the gold texts. Spans and texts are scored independently of each other.
    """
    scores = np.zeros((len(data), 4))
    for idx, point in enumerate(data):
        gold = point.answer

        guess = prediction[idx]
        # Joining the context tokens on spaces has generally been considered
        # good enough for TriviaQA: answers tend to be short and the gold
        # standard is better normalized. Using the original text instead
        # might give a very small gain.
        context_tokens = point.get_context()
        guess_text = " ".join(context_tokens[guess[0]:guess[1] + 1])

        exact_span = False
        best_span_f1 = 0
        best_text_em = 0
        best_text_f1 = 0

        # Span-level metrics, over every gold span.
        for gold_start, gold_end in gold.answer_spans:
            gold_span = (gold_start, gold_end)
            best_span_f1 = max(best_span_f1, compute_span_f1(gold_span, guess))
            exact_span = exact_span or gold_span == tuple(guess)

        # Text-level metrics, over every gold answer string.
        for gold_text in gold.answer_text:
            best_text_f1 = max(best_text_f1, triviaqa_f1_score(guess_text, gold_text))
            best_text_em = max(best_text_em, triviaqa_em_score(guess_text, gold_text))

        scores[idx] = [exact_span, best_span_f1, best_text_em, best_text_f1]
    return scores
def span_scores(data: List[ContextAndQuestion], prediction):
    """Score predicted spans against gold spans, using span-level metrics only.

    Args:
        data: Data points to evaluate; each must provide an `answer` exposing
            `answer_spans`, an iterable of two-element spans.
        prediction: Indexable with one predicted span per data point;
            `prediction[i]` is converted to a tuple before being compared
            with the gold spans of `data[i]`.

    Returns:
        A NumPy array of shape `(len(data), 2)`. For each data point the
        columns hold: whether the predicted span equals any gold span, and
        the best `compute_span_f1` value over the gold spans. A data point
        without gold spans keeps both entries at zero.
    """
    scores = np.zeros((len(data), 2))
    for idx, point in enumerate(data):
        guess = tuple(prediction[idx])

        exact_span = False
        best_span_f1 = 0
        gold = point.answer
        for gold_start, gold_end in gold.answer_spans:
            gold_span = (gold_start, gold_end)
            best_span_f1 = max(best_span_f1, compute_span_f1(gold_span, guess))
            exact_span = exact_span or gold_span == guess

        scores[idx] = [exact_span, best_span_f1]

    return scores