Example #1
    def test_get_best_span(self):
        span_begin_probs = torch.FloatTensor([[0.1, 0.3, 0.05, 0.3, 0.25]]).log()
        span_end_probs = torch.FloatTensor([[0.65, 0.05, 0.2, 0.05, 0.05]]).log()
        begin_end_idxs = get_best_span(span_begin_probs, span_end_probs)
        assert_almost_equal(begin_end_idxs.data.numpy(), [[0, 0]])

        # When we were using exclusive span ends, this was an edge case of the dynamic program.
        # We're keeping the test to make sure we still get it right after the switch to
        # inclusive span ends.  The best answer is (1, 1).
        span_begin_probs = torch.FloatTensor([[0.4, 0.5, 0.1]]).log()
        span_end_probs = torch.FloatTensor([[0.3, 0.6, 0.1]]).log()
        begin_end_idxs = get_best_span(span_begin_probs, span_end_probs)
        assert_almost_equal(begin_end_idxs.data.numpy(), [[1, 1]])

        # Another instance that used to be an edge case.
        span_begin_probs = torch.FloatTensor([[0.8, 0.1, 0.1]]).log()
        span_end_probs = torch.FloatTensor([[0.8, 0.1, 0.1]]).log()
        begin_end_idxs = get_best_span(span_begin_probs, span_end_probs)
        assert_almost_equal(begin_end_idxs.data.numpy(), [[0, 0]])

        span_begin_probs = torch.FloatTensor([[0.1, 0.2, 0.05, 0.3, 0.25]]).log()
        span_end_probs = torch.FloatTensor([[0.1, 0.2, 0.5, 0.05, 0.15]]).log()
        begin_end_idxs = get_best_span(span_begin_probs, span_end_probs)
        assert_almost_equal(begin_end_idxs.data.numpy(), [[1, 2]])
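
The expected outputs above follow from the constrained decoding that `get_best_span` performs: it scores every (start, end) pair with start <= end by the sum of the start and end logits and returns the per-instance argmax. The sketch below illustrates that idea only; it is not necessarily the library's exact implementation, and `get_best_span_sketch` is a hypothetical name.

import torch

def get_best_span_sketch(span_start_logits: torch.Tensor,
                         span_end_logits: torch.Tensor) -> torch.Tensor:
    """Constrained argmax over all (start, end) pairs with start <= end (inclusive ends)."""
    batch_size, passage_length = span_start_logits.size()
    # Score of every span is the sum of its start and end logits.
    # Shape: (batch_size, passage_length, passage_length)
    span_scores = span_start_logits.unsqueeze(2) + span_end_logits.unsqueeze(1)
    # Disallow spans whose end precedes their start: log(0) = -inf below the diagonal.
    valid_mask = torch.triu(torch.ones(passage_length, passage_length,
                                       device=span_scores.device))
    span_scores = span_scores + valid_mask.log()
    # Flatten, take the argmax, and recover (start, end) token indices.
    best_flat = span_scores.view(batch_size, -1).argmax(-1)
    start_indices = torch.div(best_flat, passage_length, rounding_mode="floor")
    end_indices = best_flat % passage_length
    return torch.stack([start_indices, end_indices], dim=-1)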
Example #2
def ensemble(subresults: List[Dict[str, torch.Tensor]]) -> torch.Tensor:
    """
    Identifies the best prediction given the results from the submodels.

    Parameters
    ----------
    subresults : List[Dict[str, torch.Tensor]]
        Results of each submodel.

    Returns
    -------
    The best span, computed from the averaged start and end probabilities of the submodels.
    """

    # Choose the highest average confidence span.

    span_start_probs = sum(subresult["span_start_probs"] for subresult in subresults) / len(
        subresults
    )
    span_end_probs = sum(subresult["span_end_probs"] for subresult in subresults) / len(subresults)
    return get_best_span(span_start_probs.log(), span_end_probs.log())  # type: ignore
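
A usage sketch of `ensemble` with two hypothetical submodel results (made-up probabilities; assumes `torch` is imported and `get_best_span` is available as above):

import torch

subresults = [
    {"span_start_probs": torch.tensor([[0.7, 0.2, 0.1]]),
     "span_end_probs": torch.tensor([[0.1, 0.6, 0.3]])},
    {"span_start_probs": torch.tensor([[0.5, 0.4, 0.1]]),
     "span_end_probs": torch.tensor([[0.2, 0.5, 0.3]])},
]
# Averaged start probs are [0.6, 0.3, 0.1] and end probs [0.15, 0.55, 0.3],
# so the best constrained span is start=0, end=1 (inclusive).
best_span = ensemble(subresults)  # tensor([[0, 1]])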
Example #3
    def forward(  # type: ignore
        self,
        question: Dict[str, torch.LongTensor],
        passage: Dict[str, torch.LongTensor],
        span_start: torch.IntTensor = None,
        span_end: torch.IntTensor = None,
        metadata: List[Dict[str, Any]] = None,
    ) -> Dict[str, torch.Tensor]:
        """
        Parameters
        ----------
        question : Dict[str, torch.LongTensor]
            From a ``TextField``.
        passage : Dict[str, torch.LongTensor]
            From a ``TextField``.  The model assumes that this passage contains the answer to the
            question, and predicts the beginning and ending positions of the answer within the
            passage.
        span_start : ``torch.IntTensor``, optional
            From an ``IndexField``.  This is one of the things we are trying to predict - the
            beginning position of the answer within the passage.  This is an `inclusive` token index.
            If this is given, we will compute a loss that gets included in the output dictionary.
        span_end : ``torch.IntTensor``, optional
            From an ``IndexField``.  This is one of the things we are trying to predict - the
            ending position of the answer within the passage.  This is an `inclusive` token index.
            If this is given, we will compute a loss that gets included in the output dictionary.
        metadata : ``List[Dict[str, Any]]``, optional
            If present, this should contain the question tokens, passage tokens, original passage
            text, and token offsets into the passage for each instance in the batch.  The length
            of this list should be the batch size, and each dictionary should have the keys
            ``question_tokens``, ``passage_tokens``, ``original_passage``, and ``token_offsets``.

        Returns
        -------
        An output dictionary consisting of:
        span_start_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, passage_length)`` representing unnormalized log
            probabilities of the span start position.
        span_start_probs : torch.FloatTensor
            The result of ``softmax(span_start_logits)``.
        span_end_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, passage_length)`` representing unnormalized log
            probabilities of the span end position (inclusive).
        span_end_probs : torch.FloatTensor
            The result of ``softmax(span_end_logits)``.
        best_span : torch.IntTensor
            The result of a constrained inference over ``span_start_logits`` and
            ``span_end_logits`` to find the most probable span.  Shape is ``(batch_size, 2)``
            and each offset is a token index.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.
        best_span_str : List[str]
            If sufficient metadata was provided for the instances in the batch, we also return the
            string from the original passage that the model thinks is the best answer to the
            question.
        """
        question_mask = util.get_text_field_mask(question)
        passage_mask = util.get_text_field_mask(passage)

        embedded_question = self._dropout(self._text_field_embedder(question))
        embedded_passage = self._dropout(self._text_field_embedder(passage))
        embedded_question = self._highway_layer(
            self._embedding_proj_layer(embedded_question))
        embedded_passage = self._highway_layer(
            self._embedding_proj_layer(embedded_passage))

        batch_size = embedded_question.size(0)

        projected_embedded_question = self._encoding_proj_layer(
            embedded_question)
        projected_embedded_passage = self._encoding_proj_layer(
            embedded_passage)

        encoded_question = self._dropout(
            self._phrase_layer(projected_embedded_question, question_mask))
        encoded_passage = self._dropout(
            self._phrase_layer(projected_embedded_passage, passage_mask))

        # Shape: (batch_size, passage_length, question_length)
        passage_question_similarity = self._matrix_attention(
            encoded_passage, encoded_question)
        # Shape: (batch_size, passage_length, question_length)
        passage_question_attention = masked_softmax(
            passage_question_similarity, question_mask, memory_efficient=True)
        # Shape: (batch_size, passage_length, encoding_dim)
        passage_question_vectors = util.weighted_sum(
            encoded_question, passage_question_attention)

        # Shape: (batch_size, question_length, passage_length)
        question_passage_attention = masked_softmax(
            passage_question_similarity.transpose(1, 2),
            passage_mask,
            memory_efficient=True)
        # Shape: (batch_size, passage_length, passage_length)
        attention_over_attention = torch.bmm(passage_question_attention,
                                             question_passage_attention)
        # Shape: (batch_size, passage_length, encoding_dim)
        passage_passage_vectors = util.weighted_sum(encoded_passage,
                                                    attention_over_attention)

        # Shape: (batch_size, passage_length, encoding_dim * 4)
        merged_passage_attention_vectors = self._dropout(
            torch.cat(
                [
                    encoded_passage,
                    passage_question_vectors,
                    encoded_passage * passage_question_vectors,
                    encoded_passage * passage_passage_vectors,
                ],
                dim=-1,
            ))

        modeled_passage_list = [
            self._modeling_proj_layer(merged_passage_attention_vectors)
        ]

        for _ in range(3):
            modeled_passage = self._dropout(
                self._modeling_layer(modeled_passage_list[-1], passage_mask))
            modeled_passage_list.append(modeled_passage)

        # Shape: (batch_size, passage_length, modeling_dim * 2)
        span_start_input = torch.cat(
            [modeled_passage_list[-3], modeled_passage_list[-2]], dim=-1)
        # Shape: (batch_size, passage_length)
        span_start_logits = self._span_start_predictor(
            span_start_input).squeeze(-1)

        # Shape: (batch_size, passage_length, modeling_dim * 2)
        span_end_input = torch.cat(
            [modeled_passage_list[-3], modeled_passage_list[-1]], dim=-1)
        span_end_logits = self._span_end_predictor(span_end_input).squeeze(-1)
        span_start_logits = replace_masked_values_with_big_negative_number(
            span_start_logits, passage_mask)
        span_end_logits = replace_masked_values_with_big_negative_number(
            span_end_logits, passage_mask)

        # Shape: (batch_size, passage_length)
        span_start_probs = torch.nn.functional.softmax(span_start_logits,
                                                       dim=-1)
        span_end_probs = torch.nn.functional.softmax(span_end_logits, dim=-1)

        best_span = get_best_span(span_start_logits, span_end_logits)

        output_dict = {
            "passage_question_attention": passage_question_attention,
            "span_start_logits": span_start_logits,
            "span_start_probs": span_start_probs,
            "span_end_logits": span_end_logits,
            "span_end_probs": span_end_probs,
            "best_span": best_span,
        }

        # Compute the loss for training.
        if span_start is not None:
            loss = nll_loss(
                util.masked_log_softmax(span_start_logits, passage_mask),
                span_start.squeeze(-1))
            self._span_start_accuracy(span_start_logits,
                                      span_start.squeeze(-1))
            loss += nll_loss(
                util.masked_log_softmax(span_end_logits, passage_mask),
                span_end.squeeze(-1))
            self._span_end_accuracy(span_end_logits, span_end.squeeze(-1))
            self._span_accuracy(best_span, torch.cat([span_start, span_end],
                                                     -1))
            output_dict["loss"] = loss

        # Compute the EM and F1 on SQuAD and add the tokenized input to the output.
        if metadata is not None:
            output_dict["best_span_str"] = []
            question_tokens = []
            passage_tokens = []
            for i in range(batch_size):
                question_tokens.append(metadata[i]["question_tokens"])
                passage_tokens.append(metadata[i]["passage_tokens"])
                passage_str = metadata[i]["original_passage"]
                offsets = metadata[i]["token_offsets"]
                predicted_span = tuple(best_span[i].detach().cpu().numpy())
                start_offset = offsets[predicted_span[0]][0]
                end_offset = offsets[predicted_span[1]][1]
                best_span_string = passage_str[start_offset:end_offset]
                output_dict["best_span_str"].append(best_span_string)
                answer_texts = metadata[i].get("answer_texts", [])
                if answer_texts:
                    self._metrics(best_span_string, answer_texts)
            output_dict["question_tokens"] = question_tokens
            output_dict["passage_tokens"] = passage_tokens
        return output_dict
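
The final loop converts the predicted token span into a passage substring using the character offsets stored in `metadata`. A self-contained illustration of that offset arithmetic, with made-up values:

passage_str = "The quick brown fox jumps"
token_offsets = [(0, 3), (4, 9), (10, 15), (16, 19), (20, 25)]  # character span per token
predicted_span = (2, 3)  # inclusive token indices, as produced by get_best_span
start_offset = token_offsets[predicted_span[0]][0]  # 10
end_offset = token_offsets[predicted_span[1]][1]    # 19
assert passage_str[start_offset:end_offset] == "brown fox"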
Example #4
    def forward(  # type: ignore
        self,
        question_with_context: Dict[str, Dict[str, torch.LongTensor]],
        context_span: torch.IntTensor,
        cls_index: torch.LongTensor = None,
        answer_span: torch.IntTensor = None,
        metadata: List[Dict[str, Any]] = None,
    ) -> Dict[str, torch.Tensor]:
        """
        # Parameters

        question_with_context : `Dict[str, Dict[str, torch.LongTensor]]`
            From a `TextField`. The model assumes that this text field contains the context followed by the
            question. It further assumes that the tokens have type ids set such that any token that can be part of
            the answer (i.e., tokens from the context) has type id 0, and any other token (including
            `[CLS]` and `[SEP]`) has type id 1.

        context_span : `torch.IntTensor`
            From a `SpanField`. This marks the span of word pieces in `question` from which answers can come.

        cls_index : `torch.LongTensor`, optional
            A tensor of shape `(batch_size,)` that provides the index of the `[CLS]` token
            in the `question_with_context` for each instance.

            This is needed because the `[CLS]` token is used to indicate that the question
            is impossible.

            If this is `None`, it's assumed that the `[CLS]` token is at index 0 for each instance
            in the batch.

        answer_span : `torch.IntTensor`, optional
            From a `SpanField`. This is the thing we are trying to predict - the span of text that marks the
            answer. If given, we compute a loss that gets included in the output dictionary.

        metadata : `List[Dict[str, Any]]`, optional
            If present, this should contain the question id, and the original texts of context, question, tokenized
            version of both, and a list of possible answers. The length of the `metadata` list should be the
            batch size, and each dictionary should have the keys `id`, `question`, `context`,
            `question_tokens`, `context_tokens`, and `answers`.

        # Returns

        `Dict[str, torch.Tensor]` :
            An output dictionary with the following fields:

            - span_start_logits (`torch.FloatTensor`) :
              A tensor of shape `(batch_size, passage_length)` representing unnormalized log
              probabilities of the span start position.
            - span_end_logits (`torch.FloatTensor`) :
              A tensor of shape `(batch_size, passage_length)` representing unnormalized log
              probabilities of the span end position (inclusive).
            - best_span_scores (`torch.FloatTensor`) :
              The score for each of the best spans.
            - loss (`torch.FloatTensor`, optional) :
              A scalar loss to be optimised, evaluated against `answer_span`.
            - best_span (`torch.IntTensor`, optional) :
              Provided when not in train mode and sufficient metadata given for the instance.
              The result of a constrained inference over `span_start_logits` and
              `span_end_logits` to find the most probable span.  Shape is `(batch_size, 2)`
              and each offset is a token index, unless the best span for an instance
              was predicted to be the `[CLS]` token, in which case the span will be (-1, -1).
            - best_span_str (`List[str]`, optional) :
              Provided when not in train mode and sufficient metadata given for the instance.
              This is the string from the original passage that the model thinks is the best answer
              to the question.

        """
        embedded_question = self._text_field_embedder(question_with_context)
        # shape: (batch_size, sequence_length, 2)
        logits = self._linear_layer(embedded_question)
        # shape: (batch_size, sequence_length, 1)
        span_start_logits, span_end_logits = logits.split(1, dim=-1)
        # shape: (batch_size, sequence_length)
        span_start_logits = span_start_logits.squeeze(-1)
        # shape: (batch_size, sequence_length)
        span_end_logits = span_end_logits.squeeze(-1)

        # Create a mask for `question_with_context` to mask out tokens that are not part
        # of the context.
        # shape: (batch_size, sequence_length)
        possible_answer_mask = torch.zeros_like(
            get_token_ids_from_text_field_tensors(question_with_context),
            dtype=torch.bool)
        for i, (start, end) in enumerate(context_span):
            possible_answer_mask[i, start:end + 1] = True
            # Also unmask the [CLS] token since that token is used to indicate that
            # the question is impossible.
            possible_answer_mask[
                i, 0 if cls_index is None else cls_index[i]] = True

        # Replace the masked values with a very negative constant since we're in log-space.
        # shape: (batch_size, sequence_length)
        span_start_logits = replace_masked_values_with_big_negative_number(
            span_start_logits, possible_answer_mask)
        # shape: (batch_size, sequence_length)
        span_end_logits = replace_masked_values_with_big_negative_number(
            span_end_logits, possible_answer_mask)

        # Now calculate the best span.
        # shape: (batch_size, 2)
        best_spans = get_best_span(span_start_logits, span_end_logits)

        # Sum the span start score with the span end score to get an overall score for the span.
        # shape: (batch_size,)
        best_span_scores = torch.gather(
            span_start_logits, 1,
            best_spans[:, 0].unsqueeze(1)) + torch.gather(
                span_end_logits, 1, best_spans[:, 1].unsqueeze(1))
        best_span_scores = best_span_scores.squeeze(1)

        output_dict = {
            "span_start_logits": span_start_logits,
            "span_end_logits": span_end_logits,
            "best_span_scores": best_span_scores,
        }

        # Compute the loss.
        if answer_span is not None:
            output_dict["loss"] = self._evaluate_span(best_spans,
                                                      span_start_logits,
                                                      span_end_logits,
                                                      answer_span)

        # Gather the string of the best span and compute the EM and F1 against the gold span,
        # if given.
        if not self.training and metadata is not None:
            (
                output_dict["best_span_str"],
                output_dict["best_span"],
            ) = self._collect_best_span_strings(best_spans, context_span,
                                                metadata, cls_index)

        return output_dict
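
For reference, the span score used for `best_span_scores` above is just the start logit at the predicted start index plus the end logit at the predicted end index, gathered per batch element. A small standalone illustration with made-up tensors:

import torch

span_start_logits = torch.tensor([[2.0, 0.5, -1.0], [0.1, 3.0, 0.2]])
span_end_logits = torch.tensor([[0.3, 1.5, 0.0], [0.0, 0.4, 2.5]])
best_spans = torch.tensor([[0, 1], [1, 2]])  # (batch_size, 2), inclusive token indices
start_scores = torch.gather(span_start_logits, 1, best_spans[:, 0].unsqueeze(1))
end_scores = torch.gather(span_end_logits, 1, best_spans[:, 1].unsqueeze(1))
best_span_scores = (start_scores + end_scores).squeeze(1)  # tensor([3.5000, 5.5000])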
Example #5
    def forward(  # type: ignore
        self,
        question: Dict[str, torch.LongTensor],
        passage: Dict[str, torch.LongTensor],
        number_indices: torch.LongTensor,
        answer_as_passage_spans: torch.LongTensor = None,
        answer_as_question_spans: torch.LongTensor = None,
        answer_as_add_sub_expressions: torch.LongTensor = None,
        answer_as_counts: torch.LongTensor = None,
        metadata: List[Dict[str, Any]] = None,
    ) -> Dict[str, torch.Tensor]:

        question_mask = util.get_text_field_mask(question)
        passage_mask = util.get_text_field_mask(passage)
        embedded_question = self._dropout(self._text_field_embedder(question))
        embedded_passage = self._dropout(self._text_field_embedder(passage))
        embedded_question = self._highway_layer(
            self._embedding_proj_layer(embedded_question))
        embedded_passage = self._highway_layer(
            self._embedding_proj_layer(embedded_passage))

        batch_size = embedded_question.size(0)

        projected_embedded_question = self._encoding_proj_layer(
            embedded_question)
        projected_embedded_passage = self._encoding_proj_layer(
            embedded_passage)

        encoded_question = self._dropout(
            self._phrase_layer(projected_embedded_question, question_mask))
        encoded_passage = self._dropout(
            self._phrase_layer(projected_embedded_passage, passage_mask))

        # Shape: (batch_size, passage_length, question_length)
        passage_question_similarity = self._matrix_attention(
            encoded_passage, encoded_question)
        # Shape: (batch_size, passage_length, question_length)
        passage_question_attention = masked_softmax(
            passage_question_similarity, question_mask, memory_efficient=True)
        # Shape: (batch_size, passage_length, encoding_dim)
        passage_question_vectors = util.weighted_sum(
            encoded_question, passage_question_attention)

        # Shape: (batch_size, question_length, passage_length)
        question_passage_attention = masked_softmax(
            passage_question_similarity.transpose(1, 2),
            passage_mask,
            memory_efficient=True)

        # Shape: (batch_size, passage_length, passage_length)
        passage_attention_over_attention = torch.bmm(
            passage_question_attention, question_passage_attention)
        # Shape: (batch_size, passage_length, encoding_dim)
        passage_passage_vectors = util.weighted_sum(
            encoded_passage, passage_attention_over_attention)

        # Shape: (batch_size, passage_length, encoding_dim * 4)
        merged_passage_attention_vectors = self._dropout(
            torch.cat(
                [
                    encoded_passage,
                    passage_question_vectors,
                    encoded_passage * passage_question_vectors,
                    encoded_passage * passage_passage_vectors,
                ],
                dim=-1,
            ))

        # The recurrent modeling layers. Since these layers share the same parameters,
        # we don't construct them conditioned on answering abilities.
        modeled_passage_list = [
            self._modeling_proj_layer(merged_passage_attention_vectors)
        ]
        for _ in range(4):
            modeled_passage = self._dropout(
                self._modeling_layer(modeled_passage_list[-1], passage_mask))
            modeled_passage_list.append(modeled_passage)
        # Pop the first one, which is the input
        modeled_passage_list.pop(0)

        # The first modeling layer is used to calculate the vector representation of the passage
        passage_weights = self._passage_weights_predictor(
            modeled_passage_list[0]).squeeze(-1)
        passage_weights = masked_softmax(passage_weights, passage_mask)
        passage_vector = util.weighted_sum(modeled_passage_list[0],
                                           passage_weights)
        # The vector representation of the question is calculated based on the unmatched encoding,
        # because we may want to infer the answer ability based only on the question words.
        question_weights = self._question_weights_predictor(
            encoded_question).squeeze(-1)
        question_weights = masked_softmax(question_weights, question_mask)
        question_vector = util.weighted_sum(encoded_question, question_weights)

        if len(self.answering_abilities) > 1:
            # Shape: (batch_size, number_of_abilities)
            answer_ability_logits = self._answer_ability_predictor(
                torch.cat([passage_vector, question_vector], -1))
            answer_ability_log_probs = torch.nn.functional.log_softmax(
                answer_ability_logits, -1)
            best_answer_ability = torch.argmax(answer_ability_log_probs, 1)

        if "counting" in self.answering_abilities:
            # Shape: (batch_size, 10)
            count_number_logits = self._count_number_predictor(passage_vector)
            count_number_log_probs = torch.nn.functional.log_softmax(
                count_number_logits, -1)
            # Info about the best count number prediction
            # Shape: (batch_size,)
            best_count_number = torch.argmax(count_number_log_probs, -1)
            best_count_log_prob = torch.gather(
                count_number_log_probs, 1,
                best_count_number.unsqueeze(-1)).squeeze(-1)
            if len(self.answering_abilities) > 1:
                best_count_log_prob += answer_ability_log_probs[:, self._counting_index]

        if "passage_span_extraction" in self.answering_abilities:
            # Shape: (batch_size, passage_length, modeling_dim * 2)
            passage_for_span_start = torch.cat(
                [modeled_passage_list[0], modeled_passage_list[1]], dim=-1)
            # Shape: (batch_size, passage_length)
            passage_span_start_logits = self._passage_span_start_predictor(
                passage_for_span_start).squeeze(-1)
            # Shape: (batch_size, passage_length, modeling_dim * 2)
            passage_for_span_end = torch.cat(
                [modeled_passage_list[0], modeled_passage_list[2]], dim=-1)
            # Shape: (batch_size, passage_length)
            passage_span_end_logits = self._passage_span_end_predictor(
                passage_for_span_end).squeeze(-1)
            # Shape: (batch_size, passage_length)
            passage_span_start_log_probs = util.masked_log_softmax(
                passage_span_start_logits, passage_mask)
            passage_span_end_log_probs = util.masked_log_softmax(
                passage_span_end_logits, passage_mask)

            # Info about the best passage span prediction
            passage_span_start_logits = replace_masked_values_with_big_negative_number(
                passage_span_start_logits, passage_mask)
            passage_span_end_logits = replace_masked_values_with_big_negative_number(
                passage_span_end_logits, passage_mask)
            # Shape: (batch_size, 2)
            best_passage_span = get_best_span(passage_span_start_logits,
                                              passage_span_end_logits)
            # Shape: (batch_size,)
            best_passage_start_log_probs = torch.gather(
                passage_span_start_log_probs, 1,
                best_passage_span[:, 0].unsqueeze(-1)).squeeze(-1)
            best_passage_end_log_probs = torch.gather(
                passage_span_end_log_probs, 1,
                best_passage_span[:, 1].unsqueeze(-1)).squeeze(-1)
            # Shape: (batch_size,)
            best_passage_span_log_prob = best_passage_start_log_probs + best_passage_end_log_probs
            if len(self.answering_abilities) > 1:
                best_passage_span_log_prob += answer_ability_log_probs[
                    :, self._passage_span_extraction_index]

        if "question_span_extraction" in self.answering_abilities:
            # Shape: (batch_size, question_length, encoding_dim + modeling_dim)
            encoded_question_for_span_prediction = torch.cat(
                [
                    encoded_question,
                    passage_vector.unsqueeze(1).repeat(
                        1, encoded_question.size(1), 1),
                ],
                -1,
            )
            question_span_start_logits = self._question_span_start_predictor(
                encoded_question_for_span_prediction).squeeze(-1)
            # Shape: (batch_size, question_length)
            question_span_end_logits = self._question_span_end_predictor(
                encoded_question_for_span_prediction).squeeze(-1)
            question_span_start_log_probs = util.masked_log_softmax(
                question_span_start_logits, question_mask)
            question_span_end_log_probs = util.masked_log_softmax(
                question_span_end_logits, question_mask)

            # Info about the best question span prediction
            question_span_start_logits = replace_masked_values_with_big_negative_number(
                question_span_start_logits, question_mask)
            question_span_end_logits = replace_masked_values_with_big_negative_number(
                question_span_end_logits, question_mask)
            # Shape: (batch_size, 2)
            best_question_span = get_best_span(question_span_start_logits,
                                               question_span_end_logits)
            # Shape: (batch_size,)
            best_question_start_log_probs = torch.gather(
                question_span_start_log_probs, 1,
                best_question_span[:, 0].unsqueeze(-1)).squeeze(-1)
            best_question_end_log_probs = torch.gather(
                question_span_end_log_probs, 1,
                best_question_span[:, 1].unsqueeze(-1)).squeeze(-1)
            # Shape: (batch_size,)
            best_question_span_log_prob = (best_question_start_log_probs +
                                           best_question_end_log_probs)
            if len(self.answering_abilities) > 1:
                best_question_span_log_prob += answer_ability_log_probs[
                    :, self._question_span_extraction_index]

        if "addition_subtraction" in self.answering_abilities:
            # Shape: (batch_size, # of numbers in the passage)
            number_indices = number_indices.squeeze(-1)
            number_mask = number_indices != -1
            clamped_number_indices = util.replace_masked_values(
                number_indices, number_mask, 0)
            encoded_passage_for_numbers = torch.cat(
                [modeled_passage_list[0], modeled_passage_list[3]], dim=-1)
            # Shape: (batch_size, # of numbers in the passage, encoding_dim)
            encoded_numbers = torch.gather(
                encoded_passage_for_numbers,
                1,
                clamped_number_indices.unsqueeze(-1).expand(
                    -1, -1, encoded_passage_for_numbers.size(-1)),
            )
            # Concatenate the passage vector to each number representation.
            encoded_numbers = torch.cat(
                [
                    encoded_numbers,
                    passage_vector.unsqueeze(1).repeat(
                        1, encoded_numbers.size(1), 1),
                ],
                -1,
            )

            # Shape: (batch_size, # of numbers in the passage, 3)
            number_sign_logits = self._number_sign_predictor(encoded_numbers)
            number_sign_log_probs = torch.nn.functional.log_softmax(
                number_sign_logits, -1)

            # Shape: (batch_size, # of numbers in passage).
            best_signs_for_numbers = torch.argmax(number_sign_log_probs, -1)
            # For padding numbers, the best sign is masked to 0 (not included).
            best_signs_for_numbers = util.replace_masked_values(
                best_signs_for_numbers, number_mask, 0)
            # Shape: (batch_size, # of numbers in passage)
            best_signs_log_probs = torch.gather(
                number_sign_log_probs, 2,
                best_signs_for_numbers.unsqueeze(-1)).squeeze(-1)
            # The probs of the masked positions should be 1 so that they do not affect the joint probability.
            # TODO: this is not quite right, since if there are many numbers in the passage,
            # TODO: the joint probability would be very small.
            best_signs_log_probs = util.replace_masked_values(
                best_signs_log_probs, number_mask, 0)
            # Shape: (batch_size,)
            best_combination_log_prob = best_signs_log_probs.sum(-1)
            if len(self.answering_abilities) > 1:
                best_combination_log_prob += answer_ability_log_probs[
                    :, self._addition_subtraction_index]

        output_dict = {}

        # If answer is given, compute the loss.
        if (answer_as_passage_spans is not None
                or answer_as_question_spans is not None
                or answer_as_add_sub_expressions is not None
                or answer_as_counts is not None):

            log_marginal_likelihood_list = []

            for answering_ability in self.answering_abilities:
                if answering_ability == "passage_span_extraction":
                    # Shape: (batch_size, # of answer spans)
                    gold_passage_span_starts = answer_as_passage_spans[:, :, 0]
                    gold_passage_span_ends = answer_as_passage_spans[:, :, 1]
                    # Some spans are padded with index -1,
                    # so we clamp those paddings to 0 and then mask after `torch.gather()`.
                    gold_passage_span_mask = gold_passage_span_starts != -1
                    clamped_gold_passage_span_starts = util.replace_masked_values(
                        gold_passage_span_starts, gold_passage_span_mask, 0)
                    clamped_gold_passage_span_ends = util.replace_masked_values(
                        gold_passage_span_ends, gold_passage_span_mask, 0)
                    # Shape: (batch_size, # of answer spans)
                    log_likelihood_for_passage_span_starts = torch.gather(
                        passage_span_start_log_probs, 1,
                        clamped_gold_passage_span_starts)
                    log_likelihood_for_passage_span_ends = torch.gather(
                        passage_span_end_log_probs, 1,
                        clamped_gold_passage_span_ends)
                    # Shape: (batch_size, # of answer spans)
                    log_likelihood_for_passage_spans = (
                        log_likelihood_for_passage_span_starts +
                        log_likelihood_for_passage_span_ends)
                    # For those padded spans, we set their log probabilities to be very small negative value
                    log_likelihood_for_passage_spans = replace_masked_values_with_big_negative_number(
                        log_likelihood_for_passage_spans,
                        gold_passage_span_mask,
                    )
                    # Shape: (batch_size, )
                    log_marginal_likelihood_for_passage_span = util.logsumexp(
                        log_likelihood_for_passage_spans)
                    log_marginal_likelihood_list.append(
                        log_marginal_likelihood_for_passage_span)

                elif answering_ability == "question_span_extraction":
                    # Shape: (batch_size, # of answer spans)
                    gold_question_span_starts = answer_as_question_spans[:, :, 0]
                    gold_question_span_ends = answer_as_question_spans[:, :, 1]
                    # Some spans are padded with index -1,
                    # so we clamp those paddings to 0 and then mask after `torch.gather()`.
                    gold_question_span_mask = gold_question_span_starts != -1
                    clamped_gold_question_span_starts = util.replace_masked_values(
                        gold_question_span_starts, gold_question_span_mask, 0)
                    clamped_gold_question_span_ends = util.replace_masked_values(
                        gold_question_span_ends, gold_question_span_mask, 0)
                    # Shape: (batch_size, # of answer spans)
                    log_likelihood_for_question_span_starts = torch.gather(
                        question_span_start_log_probs, 1,
                        clamped_gold_question_span_starts)
                    log_likelihood_for_question_span_ends = torch.gather(
                        question_span_end_log_probs, 1,
                        clamped_gold_question_span_ends)
                    # Shape: (batch_size, # of answer spans)
                    log_likelihood_for_question_spans = (
                        log_likelihood_for_question_span_starts +
                        log_likelihood_for_question_span_ends)
                    # For those padded spans, we set their log probabilities to be very small negative value
                    log_likelihood_for_question_spans = replace_masked_values_with_big_negative_number(
                        log_likelihood_for_question_spans,
                        gold_question_span_mask,
                    )
                    # Shape: (batch_size, )

                    log_marginal_likelihood_for_question_span = util.logsumexp(
                        log_likelihood_for_question_spans)
                    log_marginal_likelihood_list.append(
                        log_marginal_likelihood_for_question_span)

                elif answering_ability == "addition_subtraction":
                    # The padded add-sub combinations use 0 as the signs for all numbers, and we mask them here.
                    # Shape: (batch_size, # of combinations)
                    gold_add_sub_mask = answer_as_add_sub_expressions.sum(
                        -1) > 0
                    # Shape: (batch_size, # of numbers in the passage, # of combinations)
                    gold_add_sub_signs = answer_as_add_sub_expressions.transpose(
                        1, 2)
                    # Shape: (batch_size, # of numbers in the passage, # of combinations)
                    log_likelihood_for_number_signs = torch.gather(
                        number_sign_log_probs, 2, gold_add_sub_signs)
                    # the log likelihood of the masked positions should be 0
                    # so that it will not affect the joint probability
                    log_likelihood_for_number_signs = util.replace_masked_values(
                        log_likelihood_for_number_signs,
                        number_mask.unsqueeze(-1), 0)
                    # Shape: (batch_size, # of combinations)
                    log_likelihood_for_add_subs = log_likelihood_for_number_signs.sum(
                        1)
                    # For those padded combinations, we set their log probabilities to be very small negative value
                    log_likelihood_for_add_subs = replace_masked_values_with_big_negative_number(
                        log_likelihood_for_add_subs, gold_add_sub_mask)
                    # Shape: (batch_size, )
                    log_marginal_likelihood_for_add_sub = util.logsumexp(
                        log_likelihood_for_add_subs)
                    log_marginal_likelihood_list.append(
                        log_marginal_likelihood_for_add_sub)

                elif answering_ability == "counting":
                    # Count answers are padded with label -1,
                    # so we clamp those paddings to 0 and then mask after `torch.gather()`.
                    # Shape: (batch_size, # of count answers)
                    gold_count_mask = answer_as_counts != -1
                    # Shape: (batch_size, # of count answers)
                    clamped_gold_counts = util.replace_masked_values(
                        answer_as_counts, gold_count_mask, 0)
                    log_likelihood_for_counts = torch.gather(
                        count_number_log_probs, 1, clamped_gold_counts)
                    # For those padded spans, we set their log probabilities to be very small negative value
                    log_likelihood_for_counts = replace_masked_values_with_big_negative_number(
                        log_likelihood_for_counts, gold_count_mask)
                    # Shape: (batch_size, )
                    log_marginal_likelihood_for_count = util.logsumexp(
                        log_likelihood_for_counts)
                    log_marginal_likelihood_list.append(
                        log_marginal_likelihood_for_count)

                else:
                    raise ValueError(
                        f"Unsupported answering ability: {answering_ability}")

            if len(self.answering_abilities) > 1:
                # Add the ability probabilities if there are more than one abilities
                all_log_marginal_likelihoods = torch.stack(
                    log_marginal_likelihood_list, dim=-1)
                all_log_marginal_likelihoods = (all_log_marginal_likelihoods +
                                                answer_ability_log_probs)
                marginal_log_likelihood = util.logsumexp(
                    all_log_marginal_likelihoods)
            else:
                marginal_log_likelihood = log_marginal_likelihood_list[0]

            output_dict["loss"] = -marginal_log_likelihood.mean()

        # Compute the metrics and add the tokenized input to the output.
        if metadata is not None:
            output_dict["question_id"] = []
            output_dict["answer"] = []
            question_tokens = []
            passage_tokens = []
            for i in range(batch_size):
                question_tokens.append(metadata[i]["question_tokens"])
                passage_tokens.append(metadata[i]["passage_tokens"])

                if len(self.answering_abilities) > 1:
                    predicted_ability_str = self.answering_abilities[
                        best_answer_ability[i].detach().cpu().numpy()]
                else:
                    predicted_ability_str = self.answering_abilities[0]

                answer_json: Dict[str, Any] = {}

                # We did not consider multi-mention answers here
                if predicted_ability_str == "passage_span_extraction":
                    answer_json["answer_type"] = "passage_span"
                    passage_str = metadata[i]["original_passage"]
                    offsets = metadata[i]["passage_token_offsets"]
                    predicted_span = tuple(
                        best_passage_span[i].detach().cpu().numpy())
                    start_offset = offsets[predicted_span[0]][0]
                    end_offset = offsets[predicted_span[1]][1]
                    predicted_answer = passage_str[start_offset:end_offset]
                    answer_json["value"] = predicted_answer
                    answer_json["spans"] = [(start_offset, end_offset)]
                elif predicted_ability_str == "question_span_extraction":
                    answer_json["answer_type"] = "question_span"
                    question_str = metadata[i]["original_question"]
                    offsets = metadata[i]["question_token_offsets"]
                    predicted_span = tuple(
                        best_question_span[i].detach().cpu().numpy())
                    start_offset = offsets[predicted_span[0]][0]
                    end_offset = offsets[predicted_span[1]][1]
                    predicted_answer = question_str[start_offset:end_offset]
                    answer_json["value"] = predicted_answer
                    answer_json["spans"] = [(start_offset, end_offset)]
                elif (predicted_ability_str == "addition_subtraction"
                      ):  # plus_minus combination answer
                    answer_json["answer_type"] = "arithmetic"
                    original_numbers = metadata[i]["original_numbers"]
                    sign_remap = {0: 0, 1: 1, 2: -1}
                    predicted_signs = [
                        sign_remap[it] for it in
                        best_signs_for_numbers[i].detach().cpu().numpy()
                    ]
                    result = sum([
                        sign * number for sign, number in zip(
                            predicted_signs, original_numbers)
                    ])
                    predicted_answer = str(result)
                    offsets = metadata[i]["passage_token_offsets"]
                    number_indices = metadata[i]["number_indices"]
                    number_positions = [
                        offsets[index] for index in number_indices
                    ]
                    answer_json["numbers"] = []
                    for offset, value, sign in zip(number_positions,
                                                   original_numbers,
                                                   predicted_signs):
                        answer_json["numbers"].append({
                            "span": offset,
                            "value": value,
                            "sign": sign
                        })
                    if number_indices[-1] == -1:
                        # There is a dummy 0 number at position -1 added in some cases; we are
                        # removing that here.
                        answer_json["numbers"].pop()
                    answer_json["value"] = result
                elif predicted_ability_str == "counting":
                    answer_json["answer_type"] = "count"
                    predicted_count = best_count_number[i].detach().cpu().numpy()
                    predicted_answer = str(predicted_count)
                    answer_json["count"] = predicted_count
                else:
                    raise ValueError(
                        f"Unsupported answer ability: {predicted_ability_str}")

                output_dict["question_id"].append(metadata[i]["question_id"])
                output_dict["answer"].append(answer_json)
                answer_annotations = metadata[i].get("answer_annotations", [])
                if answer_annotations:
                    self._drop_metrics(predicted_answer, answer_annotations)
            # This is used for the demo.
            output_dict["passage_question_attention"] = passage_question_attention
            output_dict["question_tokens"] = question_tokens
            output_dict["passage_tokens"] = passage_tokens
        return output_dict
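
The span-extraction losses above marginalize over several gold spans per instance: padded spans (index -1) are clamped to 0 for the gather, masked back to a very negative value, and then combined with logsumexp. A minimal self-contained sketch of that pattern (using plain masked_fill in place of util.replace_masked_values, with made-up probabilities):

import torch

span_start_log_probs = torch.log(torch.tensor([[0.7, 0.2, 0.1]]))
span_end_log_probs = torch.log(torch.tensor([[0.1, 0.6, 0.3]]))
gold_spans = torch.tensor([[[0, 1], [-1, -1]]])  # one real gold span, one padding span
gold_starts, gold_ends = gold_spans[:, :, 0], gold_spans[:, :, 1]
mask = gold_starts != -1
clamped_starts = gold_starts.masked_fill(~mask, 0)
clamped_ends = gold_ends.masked_fill(~mask, 0)
log_likelihoods = (torch.gather(span_start_log_probs, 1, clamped_starts)
                   + torch.gather(span_end_log_probs, 1, clamped_ends))
log_likelihoods = log_likelihoods.masked_fill(~mask, -1e7)  # padding contributes ~nothing
marginal_log_likelihood = torch.logsumexp(log_likelihoods, dim=-1)  # ~= log(0.7 * 0.6)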
Example #6
    def forward(  # type: ignore
        self,
        question_with_context: Dict[str, Dict[str, torch.LongTensor]],
        context_span: torch.IntTensor,
        answer_span: Optional[torch.IntTensor] = None,
        metadata: List[Dict[str, Any]] = None,
    ) -> Dict[str, torch.Tensor]:
        """
        Parameters
        ----------
        question_with_context : Dict[str, Dict[str, torch.LongTensor]]
            From a ``TextField``. The model assumes that this text field contains the context followed by the
            question. It further assumes that the tokens have type ids set such that any token that can be part of
            the answer (i.e., tokens from the context) has type id 0, and any other token (including [CLS] and
            [SEP]) has type id 1.
        context_span : ``torch.IntTensor``
            From a ``SpanField``. This marks the span of word pieces in ``question`` from which answers can come.
        answer_span : ``torch.IntTensor``, optional
            From a ``SpanField``. This is the thing we are trying to predict - the span of text that marks the
            answer. If given, we compute a loss that gets included in the output dictionary.
        metadata : ``List[Dict[str, Any]]``, optional
            If present, this should contain the question id, and the original texts of context, question, tokenized
            version of both, and a list of possible answers. The length of the ``metadata`` list should be the
            batch size, and each dictionary should have the keys ``id``, ``question``, ``context``,
            ``question_tokens``, ``context_tokens``, and ``answers``.

        Returns
        -------
        An output dictionary consisting of:
        span_start_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, passage_length)`` representing unnormalized log
            probabilities of the span start position.
        span_start_probs : torch.FloatTensor
            The result of ``softmax(span_start_logits)``.
        span_end_logits : torch.FloatTensor
            A tensor of shape ``(batch_size, passage_length)`` representing unnormalized log
            probabilities of the span end position (inclusive).
        span_end_probs : torch.FloatTensor
            The result of ``softmax(span_end_logits)``.
        best_span : torch.IntTensor
            The result of a constrained inference over ``span_start_logits`` and
            ``span_end_logits`` to find the most probable span.  Shape is ``(batch_size, 2)``
            and each offset is a token index.
        best_span_scores : torch.FloatTensor
            The score for each of the best spans.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.
        best_span_str : List[str]
            If sufficient metadata was provided for the instances in the batch, we also return the
            string from the original passage that the model thinks is the best answer to the
            question.
        """
        embedded_question = self._text_field_embedder(question_with_context)
        logits = self._linear_layer(embedded_question)
        span_start_logits, span_end_logits = logits.split(1, dim=-1)
        span_start_logits = span_start_logits.squeeze(-1)
        span_end_logits = span_end_logits.squeeze(-1)

        possible_answer_mask = torch.zeros_like(
            get_token_ids_from_text_field_tensors(question_with_context),
            dtype=torch.bool)
        for i, (start, end) in enumerate(context_span):
            possible_answer_mask[i, start:end + 1] = True

        # Replace the masked values with a very negative constant.
        span_start_logits = replace_masked_values_with_big_negative_number(
            span_start_logits, possible_answer_mask)
        span_end_logits = replace_masked_values_with_big_negative_number(
            span_end_logits, possible_answer_mask)
        span_start_probs = torch.nn.functional.softmax(span_start_logits,
                                                       dim=-1)
        span_end_probs = torch.nn.functional.softmax(span_end_logits, dim=-1)
        best_spans = get_best_span(span_start_logits, span_end_logits)
        best_span_scores = torch.gather(
            span_start_logits, 1,
            best_spans[:, 0].unsqueeze(1)) + torch.gather(
                span_end_logits, 1, best_spans[:, 1].unsqueeze(1))
        best_span_scores = best_span_scores.squeeze(1)

        output_dict = {
            "span_start_logits": span_start_logits,
            "span_start_probs": span_start_probs,
            "span_end_logits": span_end_logits,
            "span_end_probs": span_end_probs,
            "best_span": best_spans,
            "best_span_scores": best_span_scores,
        }

        # Compute the loss for training.
        if answer_span is not None:
            span_start = answer_span[:, 0]
            span_end = answer_span[:, 1]
            span_mask = span_start != -1
            self._span_accuracy(best_spans, answer_span,
                                span_mask.unsqueeze(-1).expand_as(best_spans))

            start_loss = cross_entropy(span_start_logits,
                                       span_start,
                                       ignore_index=-1)
            big_constant = min(torch.finfo(start_loss.dtype).max, 1e9)
            if torch.any(start_loss > big_constant):
                logger.critical("Start loss too high (%r)", start_loss)
                logger.critical("span_start_logits: %r", span_start_logits)
                logger.critical("span_start: %r", span_start)
                assert False

            end_loss = cross_entropy(span_end_logits,
                                     span_end,
                                     ignore_index=-1)
            if torch.any(end_loss > big_constant):
                logger.critical("End loss too high (%r)", end_loss)
                logger.critical("span_end_logits: %r", span_end_logits)
                logger.critical("span_end: %r", span_end)
                assert False

            loss = (start_loss + end_loss) / 2

            self._span_start_accuracy(span_start_logits, span_start, span_mask)
            self._span_end_accuracy(span_end_logits, span_end, span_mask)

            output_dict["loss"] = loss

        # Compute the EM and F1 on SQuAD and add the tokenized input to the output.
        if metadata is not None:
            best_spans = best_spans.detach().cpu().numpy()

            output_dict["best_span_str"] = []
            context_tokens = []
            for metadata_entry, best_span, cspan in zip(
                    metadata, best_spans, context_span):
                context_tokens_for_question = metadata_entry["context_tokens"]
                context_tokens.append(context_tokens_for_question)

                best_span -= int(cspan[0])
                assert np.all(best_span >= 0)

                predicted_start, predicted_end = tuple(best_span)

                while (predicted_start >= 0 and
                       context_tokens_for_question[predicted_start].idx is None):
                    predicted_start -= 1
                if predicted_start < 0:
                    logger.warning(
                        f"Could not map the token '{context_tokens_for_question[best_span[0]].text}' at index "
                        f"'{best_span[0]}' to an offset in the original text.")
                    character_start = 0
                else:
                    character_start = context_tokens_for_question[
                        predicted_start].idx

                while (predicted_end < len(context_tokens_for_question) and
                       context_tokens_for_question[predicted_end].idx is None):
                    predicted_end += 1
                if predicted_end >= len(context_tokens_for_question):
                    logger.warning(
                        f"Could not map the token '{context_tokens_for_question[best_span[1]].text}' at index "
                        f"'{best_span[1]}' to an offset in the original text.")
                    character_end = len(metadata_entry["context"])
                else:
                    end_token = context_tokens_for_question[predicted_end]
                    character_end = end_token.idx + len(
                        sanitize_wordpiece(end_token.text))

                best_span_string = metadata_entry["context"][
                    character_start:character_end]
                output_dict["best_span_str"].append(best_span_string)

                answers = metadata_entry.get("answers")
                if len(answers) > 0:
                    self._per_instance_metrics(best_span_string, answers)
            output_dict["context_tokens"] = context_tokens
        return output_dict
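
In the loss above, `cross_entropy(..., ignore_index=-1)` means that instances whose gold span was padded with -1 (for example, when the answer falls outside the current window) simply drop out of the average. A tiny illustration with made-up logits:

import torch
from torch.nn.functional import cross_entropy

logits = torch.tensor([[2.0, 0.1, 0.3], [0.2, 0.4, 1.5]])
targets = torch.tensor([0, -1])  # the second instance has no usable gold index
loss = cross_entropy(logits, targets, ignore_index=-1)  # averaged over the first instance only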
Example #7
    def forward(  # type: ignore
        self,
        question_with_context: Dict[str, Dict[str, torch.LongTensor]],
        context_span: torch.IntTensor,
        yes_no_span: torch.IntTensor = None,
        answer_span: Optional[torch.IntTensor] = None,
        metadata: List[Dict[str, Any]] = None,
    ) -> Dict[str, torch.Tensor]:
        """
        # Parameters

        question_with_context : `Dict[str, Dict[str, torch.LongTensor]]`
            From a ``TextField``. The model assumes that this text field contains the context followed by the
            question. It further assumes that the tokens have type ids set such that any token that can be part of
            the answer (i.e., tokens from the context) has type id 0, and any other token (including [CLS] and
            [SEP]) has type id 1.
        context_span : `torch.IntTensor`
            From a ``SpanField``. This marks the span of word pieces in ``question`` from which answers can come.
        answer_span : `torch.IntTensor`, optional
            From a ``SpanField``. This is the thing we are trying to predict - the span of text that marks the
            answer. If given, we compute a loss that gets included in the output dictionary.
        metadata : `List[Dict[str, Any]]`, optional
            If present, this should contain the question id, and the original texts of context, question, tokenized
            version of both, and a list of possible answers. The length of the ``metadata`` list should be the
            batch size, and each dictionary should have the keys ``id``, ``question``, ``context``,
            ``question_tokens``, ``context_tokens``, and ``answers``.

        # Returns

        An output dictionary consisting of:

        span_start_logits : `torch.FloatTensor`
            A tensor of shape ``(batch_size, passage_length)`` representing unnormalized log
            probabilities of the span start position.
        span_start_probs : `torch.FloatTensor`
            The result of `softmax(span_start_logits)`.
        span_end_logits : `torch.FloatTensor`
            A tensor of shape ``(batch_size, passage_length)`` representing unnormalized log
            probabilities of the span end position (inclusive).
        span_end_probs : `torch.FloatTensor`
            The result of ``softmax(span_end_logits)``.
        best_span : `torch.IntTensor`
            The result of a constrained inference over ``span_start_logits`` and
            ``span_end_logits`` to find the most probable span.  Shape is ``(batch_size, 2)``
            and each offset is a token index.
        best_span_scores : `torch.FloatTensor`
            The score for each of the best spans.
        loss : `torch.FloatTensor`, optional
            A scalar loss to be optimised.
        best_span_str : `List[str]`
            If sufficient metadata was provided for the instances in the batch, we also return the
            string from the original passage that the model thinks is the best answer to the
            question.
        """
        outputs = self._qa_model(**question_with_context)
        span_start_logits = outputs["start_logits"]
        span_end_logits = outputs["end_logits"]

        with torch.no_grad():
            possible_answer_mask = torch.zeros_like(
                question_with_context["input_ids"],
                dtype=torch.bool,
            )
            if not self.force_yes_no:
                for i, (start, end) in enumerate(context_span):
                    if start != -1 and end != -1:
                        possible_answer_mask[i, start:end + 1] = True

            if yes_no_span is not None:
                for i, (start, end) in enumerate(yes_no_span):
                    if start != -1 and end != -1:
                        possible_answer_mask[i, start:end + 1] = True

            for i in range(len(possible_answer_mask)):
                assert any(possible_answer_mask[i])

            # Replace the masked values with a very negative constant.
            context_masked_span_start_logits = replace_masked_values_with_big_negative_number(
                span_start_logits, possible_answer_mask)
            context_masked_span_end_logits = replace_masked_values_with_big_negative_number(
                span_end_logits, possible_answer_mask)
            best_spans = get_best_span(context_masked_span_start_logits,
                                       context_masked_span_end_logits)
            best_span_scores = torch.gather(
                context_masked_span_start_logits, 1,
                best_spans[:, 0].unsqueeze(1)) + torch.gather(
                    context_masked_span_end_logits, 1,
                    best_spans[:, 1].unsqueeze(1))
            best_span_scores = best_span_scores.squeeze(1)

            output_dict = {
                "best_span": best_spans,
                "best_span_scores": best_span_scores,
                "yes_scores": span_start_logits[:, yes_no_span[:, 0]]
                + span_end_logits[:, yes_no_span[:, 0]],
                "no_scores": span_start_logits[:, yes_no_span[:, 1]]
                + span_end_logits[:, yes_no_span[:, 1]],
            }
            if self._enable_no_answer:
                no_answer_scores = span_start_logits[:, 0] + span_end_logits[:, 0]
                output_dict.update({"no_answer_scores": no_answer_scores})

        # Compute metrics and set loss
        if answer_span is not None:
            span_start = answer_span[:, 0]
            span_end = answer_span[:, 1]
            span_mask = span_start != -1
            if self._enable_no_answer:
                span_mask &= span_start != 0

            self._span_accuracy(best_spans, answer_span,
                                span_mask.unsqueeze(-1).expand_as(best_spans))

            self._span_start_accuracy(context_masked_span_start_logits,
                                      span_start, span_mask)
            self._span_end_accuracy(context_masked_span_end_logits, span_end,
                                    span_mask)

            if self._enable_no_answer:
                possible_answer_mask[:, 0] = True
            # Replace the masked values with a very negative constant.
            masked_span_start_logits = replace_masked_values_with_big_negative_number(
                span_start_logits, possible_answer_mask)
            masked_span_end_logits = replace_masked_values_with_big_negative_number(
                span_end_logits, possible_answer_mask)

            loss_fct = CrossEntropyLoss(ignore_index=-1)
            start_loss = loss_fct(masked_span_start_logits, span_start)
            end_loss = loss_fct(masked_span_end_logits, span_end)
            total_loss = (start_loss + end_loss) / 2
            output_dict["loss"] = total_loss

        # Compute the EM and F1 on SQuAD and add the tokenized input to the output.
        if metadata is not None:
            best_spans = best_spans.detach().cpu().numpy()

            output_dict["best_span_str"] = []
            for i, (metadata_entry,
                    best_span) in enumerate(zip(metadata, best_spans)):
                best_span_string = TokensInterpreter.extract_span_string_from_origin_texts(
                    Span(*best_span),
                    [
                        metadata_entry["modified_question"],
                        metadata_entry["context"]
                    ],
                    metadata_entry["offset_mapping"],
                    metadata_entry["special_tokens_mask"],
                )

                if self.force_yes_no:
                    if output_dict["yes_scores"][i].item(
                    ) > output_dict["no_scores"][i].item():
                        overriding_best_span_string = "yes"
                    else:
                        overriding_best_span_string = "no"
                    if overriding_best_span_string != best_span_string:
                        best_span_string = overriding_best_span_string

                output_dict["best_span_str"].append(best_span_string)

                answers = metadata_entry.get("answers")
                if answers is not None and len(answers) > 0:
                    if self._enable_no_answer:
                        final_pred = (best_span_string if
                                      best_span_scores[i] > no_answer_scores[i]
                                      else "")
                    else:
                        final_pred = best_span_string

                    if metadata_entry["is_boolq"]:
                        self._boolq_accuracy(final_pred, answers)
                    else:
                        self._per_instance_metrics(final_pred, answers)

        return output_dict